\", self.on_canvas_click)\n\n self.rg_data.clear(), self.current_box.clear()\n\n def on_canvas_click(self, event) -> None:\n \"\"\"Handle mouse clicks to add points for bounding boxes on the canvas.\"\"\"\n self.current_box.append((event.x, event.y))\n self.canvas.create_oval(event.x - 3, event.y - 3, event.x + 3, event.y + 3, fill=\"red\")\n if len(self.current_box) == 4:\n self.rg_data.append(self.current_box.copy())\n self.draw_box(self.current_box)\n self.current_box.clear()\n\n def draw_box(self, box: List[Tuple[int, int]]) -> None:\n \"\"\"Draw a bounding box on the canvas using the provided coordinates.\"\"\"\n for i in range(4):\n self.canvas.create_line(box[i], box[(i + 1) % 4], fill=\"blue\", width=2)\n\n def remove_last_bounding_box(self) -> None:\n \"\"\"Remove the last bounding box from the list and redraw the canvas.\"\"\"\n if not self.rg_data:\n self.messagebox.showwarning(\"Warning\", \"No bounding boxes to remove.\")\n return\n self.rg_data.pop()\n self.redraw_canvas()\n\n def redraw_canvas(self) -> None:\n \"\"\"Redraw the canvas with the image and all bounding boxes.\"\"\"\n self.canvas.delete(\"all\")\n self.canvas.create_image(0, 0, anchor=self.tk.NW, image=self.canvas_image)\n for box in self.rg_data:\n self.draw_box(box)\n\n def save_to_json(self) -> None:\n \"\"\"Save the selected parking zone points to a JSON file with scaled coordinates.\"\"\"\n scale_w, scale_h = self.imgw / self.canvas.winfo_width(), self.imgh / self.canvas.winfo_height()\n data = [{\"points\": [(int(x * scale_w), int(y * scale_h)) for x, y in box]} for box in self.rg_data]\n\n from io import StringIO # Function level import, as it's only required to store coordinates\n\n write_buffer = StringIO()\n json.dump(data, write_buffer, indent=4)\n with open(\"bounding_boxes.json\", \"w\", encoding=\"utf-8\") as f:\n f.write(write_buffer.getvalue())\n self.messagebox.showinfo(\"Success\", \"Bounding boxes saved to bounding_boxes.json\")",
"chunk_type": "class",
"name": "ParkingPtsSelection",
"file_path": "ultralytics\\ultralytics\\solutions\\parking_management.py",
"start_line": 14,
"end_line": 175,
"start_col": 0,
"end_col": 90,
"parent_name": null,
"docstring": "A class for selecting and managing parking zone points on images using a Tkinter-based UI.\n\nThis class provides functionality to upload an image, select points to define parking zones, and save the\nselected points to a JSON file. It uses Tkinter for the graphical user interface.\n\nAttributes:\n tk (module): The Tkinter module for GUI operations.\n filedialog (module): Tkinter's filedialog module for file selection operations.\n messagebox (module): Tkinter's messagebox module for displaying message boxes.\n master (tk.Tk): The main Tkinter window.\n canvas (tk.Canvas): The canvas widget for displaying the image and drawing bounding boxes.\n image (PIL.Image.Image): The uploaded image.\n canvas_image (ImageTk.PhotoImage): The image displayed on the canvas.\n rg_data (List[List[Tuple[int, int]]]): List of bounding boxes, each defined by 4 points.\n current_box (List[Tuple[int, int]]): Temporary storage for the points of the current bounding box.\n imgw (int): Original width of the uploaded image.\n imgh (int): Original height of the uploaded image.\n canvas_max_width (int): Maximum width of the canvas.\n canvas_max_height (int): Maximum height of the canvas.\n\nMethods:\n initialize_properties: Initialize properties for image, canvas, bounding boxes, and dimensions.\n upload_image: Upload and display an image on the canvas, resizing it to fit within specified dimensions.\n on_canvas_click: Handle mouse clicks to add points for bounding boxes on the canvas.\n draw_box: Draw a bounding box on the canvas using the provided coordinates.\n remove_last_bounding_box: Remove the last bounding box from the list and redraw the canvas.\n redraw_canvas: Redraw the canvas with the image and all bounding boxes.\n save_to_json: Save the selected parking zone points to a JSON file with scaled coordinates.\n\nExamples:\n >>> parking_selector = ParkingPtsSelection()\n >>> # Use the GUI to upload an image, select parking zones, and save the data",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"json",
"typing.Any",
"typing.List",
"typing.Tuple",
"cv2",
"numpy",
"ultralytics.solutions.solutions.BaseSolution",
"ultralytics.solutions.solutions.SolutionAnnotator",
"ultralytics.solutions.solutions.SolutionResults",
"ultralytics.utils.LOGGER",
"ultralytics.utils.checks.check_imshow",
"PIL.Image",
"PIL.ImageTk",
"io.StringIO",
"tkinter",
"tkinter.filedialog",
"tkinter.messagebox",
"platform"
],
"chunk_id": "class_ParkingPtsSelection_a14389a1"
},
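The `ParkingPtsSelection` chunk above saves its selections via `save_to_json`, which writes a list of `{"points": [...]}` entries scaled back to the original image size. A minimal sketch of producing that same `bounding_boxes.json` layout by hand (the coordinates below are made-up placeholders):

```python
import json

# Each region is four (x, y) corners in original-image pixels, matching what
# ParkingPtsSelection.save_to_json writes (tuples become JSON arrays on dump).
regions = [
    {"points": [(120, 80), (260, 80), (260, 210), (120, 210)]},  # placeholder slot 1
    {"points": [(300, 90), (440, 90), (440, 220), (300, 220)]},  # placeholder slot 2
]

with open("bounding_boxes.json", "w", encoding="utf-8") as f:
    json.dump(regions, f, indent=4)
```

This file can then be passed to `ParkingManagement` through its `json_file` argument.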
{
"content": "class ParkingManagement(BaseSolution):\n \"\"\"\n Manages parking occupancy and availability using YOLO model for real-time monitoring and visualization.\n\n This class extends BaseSolution to provide functionality for parking lot management, including detection of\n occupied spaces, visualization of parking regions, and display of occupancy statistics.\n\n Attributes:\n json_file (str): Path to the JSON file containing parking region details.\n json (List[Dict]): Loaded JSON data containing parking region information.\n pr_info (Dict[str, int]): Dictionary storing parking information (Occupancy and Available spaces).\n arc (Tuple[int, int, int]): RGB color tuple for available region visualization.\n occ (Tuple[int, int, int]): RGB color tuple for occupied region visualization.\n dc (Tuple[int, int, int]): RGB color tuple for centroid visualization of detected objects.\n\n Methods:\n process: Process the input image for parking lot management and visualization.\n\n Examples:\n >>> from ultralytics.solutions import ParkingManagement\n >>> parking_manager = ParkingManagement(model=\"yolo11n.pt\", json_file=\"parking_regions.json\")\n >>> print(f\"Occupied spaces: {parking_manager.pr_info['Occupancy']}\")\n >>> print(f\"Available spaces: {parking_manager.pr_info['Available']}\")\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"Initialize the parking management system with a YOLO model and visualization settings.\"\"\"\n super().__init__(**kwargs)\n\n self.json_file = self.CFG[\"json_file\"] # Load parking regions JSON data\n if self.json_file is None:\n LOGGER.warning(\"json_file argument missing. Parking region details required.\")\n raise ValueError(\"❌ Json file path can not be empty\")\n\n with open(self.json_file) as f:\n self.json = json.load(f)\n\n self.pr_info = {\"Occupancy\": 0, \"Available\": 0} # Dictionary for parking information\n\n self.arc = (0, 0, 255) # Available region color\n self.occ = (0, 255, 0) # Occupied region color\n self.dc = (255, 0, 189) # Centroid color for each box\n\n def process(self, im0: np.ndarray) -> SolutionResults:\n \"\"\"\n Process the input image for parking lot management and visualization.\n\n This function analyzes the input image, extracts tracks, and determines the occupancy status of parking\n regions defined in the JSON file. 
It annotates the image with occupied and available parking spots,\n and updates the parking information.\n\n Args:\n im0 (np.ndarray): The input inference image.\n\n Returns:\n (SolutionResults): Contains processed image `plot_im`, 'filled_slots' (number of occupied parking slots),\n 'available_slots' (number of available parking slots), and 'total_tracks' (total number of tracked objects).\n\n Examples:\n >>> parking_manager = ParkingManagement(json_file=\"parking_regions.json\")\n >>> image = cv2.imread(\"parking_lot.jpg\")\n >>> results = parking_manager.process(image)\n \"\"\"\n self.extract_tracks(im0) # Extract tracks from im0\n es, fs = len(self.json), 0 # Empty slots, filled slots\n annotator = SolutionAnnotator(im0, self.line_width) # Initialize annotator\n\n for region in self.json:\n # Convert points to a NumPy array with the correct dtype and reshape properly\n pts_array = np.array(region[\"points\"], dtype=np.int32).reshape((-1, 1, 2))\n rg_occupied = False # Occupied region initialization\n for box, cls in zip(self.boxes, self.clss):\n xc, yc = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)\n dist = cv2.pointPolygonTest(pts_array, (xc, yc), False)\n if dist >= 0:\n # cv2.circle(im0, (xc, yc), radius=self.line_width * 4, color=self.dc, thickness=-1)\n annotator.display_objects_labels(\n im0, self.model.names[int(cls)], (104, 31, 17), (255, 255, 255), xc, yc, 10\n )\n rg_occupied = True\n break\n fs, es = (fs + 1, es - 1) if rg_occupied else (fs, es)\n # Plot regions\n cv2.polylines(im0, [pts_array], isClosed=True, color=self.occ if rg_occupied else self.arc, thickness=2)\n\n self.pr_info[\"Occupancy\"], self.pr_info[\"Available\"] = fs, es\n\n annotator.display_analytics(im0, self.pr_info, (104, 31, 17), (255, 255, 255), 10)\n\n plot_im = annotator.result()\n self.display_output(plot_im) # Display output with base class function\n\n # Return SolutionResults\n return SolutionResults(\n plot_im=plot_im,\n filled_slots=self.pr_info[\"Occupancy\"],\n available_slots=self.pr_info[\"Available\"],\n total_tracks=len(self.track_ids),\n )",
"chunk_type": "class",
"name": "ParkingManagement",
"file_path": "ultralytics\\ultralytics\\solutions\\parking_management.py",
"start_line": 178,
"end_line": 276,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "Manages parking occupancy and availability using YOLO model for real-time monitoring and visualization.\n\nThis class extends BaseSolution to provide functionality for parking lot management, including detection of\noccupied spaces, visualization of parking regions, and display of occupancy statistics.\n\nAttributes:\n json_file (str): Path to the JSON file containing parking region details.\n json (List[Dict]): Loaded JSON data containing parking region information.\n pr_info (Dict[str, int]): Dictionary storing parking information (Occupancy and Available spaces).\n arc (Tuple[int, int, int]): RGB color tuple for available region visualization.\n occ (Tuple[int, int, int]): RGB color tuple for occupied region visualization.\n dc (Tuple[int, int, int]): RGB color tuple for centroid visualization of detected objects.\n\nMethods:\n process: Process the input image for parking lot management and visualization.\n\nExamples:\n >>> from ultralytics.solutions import ParkingManagement\n >>> parking_manager = ParkingManagement(model=\"yolo11n.pt\", json_file=\"parking_regions.json\")\n >>> print(f\"Occupied spaces: {parking_manager.pr_info['Occupancy']}\")\n >>> print(f\"Available spaces: {parking_manager.pr_info['Available']}\")",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"json",
"typing.Any",
"typing.List",
"typing.Tuple",
"cv2",
"numpy",
"ultralytics.solutions.solutions.BaseSolution",
"ultralytics.solutions.solutions.SolutionAnnotator",
"ultralytics.solutions.solutions.SolutionResults",
"ultralytics.utils.LOGGER",
"ultralytics.utils.checks.check_imshow",
"PIL.Image",
"PIL.ImageTk",
"io.StringIO",
"tkinter",
"tkinter.filedialog",
"tkinter.messagebox",
"platform",
"BaseSolution"
],
"chunk_id": "class_ParkingManagement_08332990"
},
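Building on the docstring example in the `ParkingManagement` chunk above, a minimal usage sketch on a video stream; the model weights, region file, and video path are placeholders:

```python
import cv2
from ultralytics.solutions import ParkingManagement

parking_manager = ParkingManagement(model="yolo11n.pt", json_file="bounding_boxes.json")

cap = cv2.VideoCapture("parking_lot.mp4")  # placeholder video source
while cap.isOpened():
    success, im0 = cap.read()
    if not success:
        break
    results = parking_manager.process(im0)  # SolutionResults: plot_im, filled_slots, available_slots
    print(f"Occupied: {results.filled_slots} | Available: {results.available_slots}")
cap.release()
cv2.destroyAllWindows()
```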
{
"content": "from typing import Any",
"chunk_type": "import",
"name": "Any",
"file_path": "ultralytics\\ultralytics\\solutions\\queue_management.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any_9aed8648"
},
{
"content": "from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults",
"chunk_type": "import",
"name": "BaseSolution, SolutionAnnotator, SolutionResults",
"file_path": "ultralytics\\ultralytics\\solutions\\queue_management.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 92,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BaseSolution, SolutionAnnotator, SolutionResults_78f87b94"
},
{
"content": "from ultralytics.utils.plotting import colors",
"chunk_type": "import",
"name": "colors",
"file_path": "ultralytics\\ultralytics\\solutions\\queue_management.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_colors_46e65bda"
},
{
"content": "class QueueManager(BaseSolution):\n \"\"\"\n Manages queue counting in real-time video streams based on object tracks.\n\n This class extends BaseSolution to provide functionality for tracking and counting objects within a specified\n region in video frames.\n\n Attributes:\n counts (int): The current count of objects in the queue.\n rect_color (Tuple[int, int, int]): RGB color tuple for drawing the queue region rectangle.\n region_length (int): The number of points defining the queue region.\n track_line (List[Tuple[int, int]]): List of track line coordinates.\n track_history (Dict[int, List[Tuple[int, int]]]): Dictionary storing tracking history for each object.\n\n Methods:\n initialize_region: Initialize the queue region.\n process: Process a single frame for queue management.\n extract_tracks: Extract object tracks from the current frame.\n store_tracking_history: Store the tracking history for an object.\n display_output: Display the processed output.\n\n Examples:\n >>> cap = cv2.VideoCapture(\"path/to/video.mp4\")\n >>> queue_manager = QueueManager(region=[100, 100, 200, 200, 300, 300])\n >>> while cap.isOpened():\n >>> success, im0 = cap.read()\n >>> if not success:\n >>> break\n >>> results = queue_manager.process(im0)\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"Initialize the QueueManager with parameters for tracking and counting objects in a video stream.\"\"\"\n super().__init__(**kwargs)\n self.initialize_region()\n self.counts = 0 # Queue counts information\n self.rect_color = (255, 255, 255) # Rectangle color for visualization\n self.region_length = len(self.region) # Store region length for further usage\n\n def process(self, im0) -> SolutionResults:\n \"\"\"\n Process queue management for a single frame of video.\n\n Args:\n im0 (np.ndarray): Input image for processing, typically a frame from a video stream.\n\n Returns:\n (SolutionResults): Contains processed image `im0`, 'queue_count' (int, number of objects in the queue) and\n 'total_tracks' (int, total number of tracked objects).\n\n Examples:\n >>> queue_manager = QueueManager()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = queue_manager.process(frame)\n \"\"\"\n self.counts = 0 # Reset counts every frame\n self.extract_tracks(im0) # Extract tracks from the current frame\n annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator\n annotator.draw_region(reg_pts=self.region, color=self.rect_color, thickness=self.line_width * 2) # Draw region\n\n for box, track_id, cls, conf in zip(self.boxes, self.track_ids, self.clss, self.confs):\n # Draw bounding box and counting region\n annotator.box_label(box, label=self.adjust_box_label(cls, conf, track_id), color=colors(track_id, True))\n self.store_tracking_history(track_id, box) # Store track history\n\n # Cache frequently accessed attributes\n track_history = self.track_history.get(track_id, [])\n\n # Store previous position of track and check if the object is inside the counting region\n prev_position = None\n if len(track_history) > 1:\n prev_position = track_history[-2]\n if self.region_length >= 3 and prev_position and self.r_s.contains(self.Point(self.track_line[-1])):\n self.counts += 1\n\n # Display queue counts\n annotator.queue_counts_display(\n f\"Queue Counts : {str(self.counts)}\",\n points=self.region,\n region_color=self.rect_color,\n txt_color=(104, 31, 17),\n )\n plot_im = annotator.result()\n self.display_output(plot_im) # Display output with base class function\n\n # Return a 
SolutionResults object with processed data\n return SolutionResults(plot_im=plot_im, queue_count=self.counts, total_tracks=len(self.track_ids))",
"chunk_type": "class",
"name": "QueueManager",
"file_path": "ultralytics\\ultralytics\\solutions\\queue_management.py",
"start_line": 9,
"end_line": 95,
"start_col": 0,
"end_col": 106,
"parent_name": null,
"docstring": "Manages queue counting in real-time video streams based on object tracks.\n\nThis class extends BaseSolution to provide functionality for tracking and counting objects within a specified\nregion in video frames.\n\nAttributes:\n counts (int): The current count of objects in the queue.\n rect_color (Tuple[int, int, int]): RGB color tuple for drawing the queue region rectangle.\n region_length (int): The number of points defining the queue region.\n track_line (List[Tuple[int, int]]): List of track line coordinates.\n track_history (Dict[int, List[Tuple[int, int]]]): Dictionary storing tracking history for each object.\n\nMethods:\n initialize_region: Initialize the queue region.\n process: Process a single frame for queue management.\n extract_tracks: Extract object tracks from the current frame.\n store_tracking_history: Store the tracking history for an object.\n display_output: Display the processed output.\n\nExamples:\n >>> cap = cv2.VideoCapture(\"path/to/video.mp4\")\n >>> queue_manager = QueueManager(region=[100, 100, 200, 200, 300, 300])\n >>> while cap.isOpened():\n >>> success, im0 = cap.read()\n >>> if not success:\n >>> break\n >>> results = queue_manager.process(im0)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"ultralytics.solutions.solutions.BaseSolution",
"ultralytics.solutions.solutions.SolutionAnnotator",
"ultralytics.solutions.solutions.SolutionResults",
"ultralytics.utils.plotting.colors",
"BaseSolution"
],
"chunk_id": "class_QueueManager_74fd260d"
},
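A usage sketch for the `QueueManager` chunk above; the queue region is given as four corner points (placeholder coordinates), matching the polygon form handled by `initialize_region`:

```python
import cv2
from ultralytics.solutions import QueueManager  # assumed export, mirroring the docstring example

queue_manager = QueueManager(region=[(20, 400), (1080, 400), (1080, 360), (20, 360)])

cap = cv2.VideoCapture("path/to/video.mp4")  # placeholder video source
while cap.isOpened():
    success, im0 = cap.read()
    if not success:
        break
    results = queue_manager.process(im0)
    print(f"Queue count: {results.queue_count} | Tracks: {results.total_tracks}")
cap.release()
```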
{
"content": "from typing import Any, Dict, List, Tuple",
"chunk_type": "import",
"name": "Any, Dict, List, Tuple",
"file_path": "ultralytics\\ultralytics\\solutions\\region_counter.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, List, Tuple_388808c1"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\solutions\\region_counter.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_17796638"
},
{
"content": "from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults",
"chunk_type": "import",
"name": "BaseSolution, SolutionAnnotator, SolutionResults",
"file_path": "ultralytics\\ultralytics\\solutions\\region_counter.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 92,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BaseSolution, SolutionAnnotator, SolutionResults_0b029d21"
},
{
"content": "from ultralytics.utils.plotting import colors",
"chunk_type": "import",
"name": "colors",
"file_path": "ultralytics\\ultralytics\\solutions\\region_counter.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_colors_c1b76362"
},
{
"content": "class RegionCounter(BaseSolution):\n \"\"\"\n A class for real-time counting of objects within user-defined regions in a video stream.\n\n This class inherits from `BaseSolution` and provides functionality to define polygonal regions in a video frame,\n track objects, and count those objects that pass through each defined region. Useful for applications requiring\n counting in specified areas, such as monitoring zones or segmented sections.\n\n Attributes:\n region_template (dict): Template for creating new counting regions with default attributes including name,\n polygon coordinates, and display colors.\n counting_regions (list): List storing all defined regions, where each entry is based on `region_template`\n and includes specific region settings like name, coordinates, and color.\n region_counts (dict): Dictionary storing the count of objects for each named region.\n\n Methods:\n add_region: Add a new counting region with specified attributes.\n process: Process video frames to count objects in each region.\n initialize_regions: Initialize zones to count the objects in each one. Zones could be multiple as well.\n\n Examples:\n Initialize a RegionCounter and add a counting region\n >>> counter = RegionCounter()\n >>> counter.add_region(\"Zone1\", [(100, 100), (200, 100), (200, 200), (100, 200)], (255, 0, 0), (255, 255, 255))\n >>> results = counter.process(frame)\n >>> print(f\"Total tracks: {results.total_tracks}\")\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"Initialize the RegionCounter for real-time object counting in user-defined regions.\"\"\"\n super().__init__(**kwargs)\n self.region_template = {\n \"name\": \"Default Region\",\n \"polygon\": None,\n \"counts\": 0,\n \"region_color\": (255, 255, 255),\n \"text_color\": (0, 0, 0),\n }\n self.region_counts = {}\n self.counting_regions = []\n self.initialize_regions()\n\n def add_region(\n self,\n name: str,\n polygon_points: List[Tuple],\n region_color: Tuple[int, int, int],\n text_color: Tuple[int, int, int],\n ) -> Dict[str, Any]:\n \"\"\"\n Add a new region to the counting list based on the provided template with specific attributes.\n\n Args:\n name (str): Name assigned to the new region.\n polygon_points (List[Tuple]): List of (x, y) coordinates defining the region's polygon.\n region_color (Tuple[int, int, int]): BGR color for region visualization.\n text_color (Tuple[int, int, int]): BGR color for the text within the region.\n\n Returns:\n (Dict[str, any]): Returns a dictionary including the region information i.e. 
name, region_color etc.\n \"\"\"\n region = self.region_template.copy()\n region.update(\n {\n \"name\": name,\n \"polygon\": self.Polygon(polygon_points),\n \"region_color\": region_color,\n \"text_color\": text_color,\n }\n )\n self.counting_regions.append(region)\n return region\n\n def initialize_regions(self):\n \"\"\"Initialize regions only once.\"\"\"\n if self.region is None:\n self.initialize_region()\n if not isinstance(self.region, dict): # Ensure self.region is initialized and structured as a dictionary\n self.region = {\"Region#01\": self.region}\n for i, (name, pts) in enumerate(self.region.items()):\n region = self.add_region(name, pts, colors(i, True), (255, 255, 255))\n region[\"prepared_polygon\"] = self.prep(region[\"polygon\"])\n\n def process(self, im0: np.ndarray) -> SolutionResults:\n \"\"\"\n Process the input frame to detect and count objects within each defined region.\n\n Args:\n im0 (np.ndarray): Input image frame where objects and regions are annotated.\n\n Returns:\n (SolutionResults): Contains processed image `plot_im`, 'total_tracks' (int, total number of tracked objects),\n and 'region_counts' (dict, counts of objects per region).\n \"\"\"\n self.extract_tracks(im0)\n annotator = SolutionAnnotator(im0, line_width=self.line_width)\n\n for box, cls, track_id, conf in zip(self.boxes, self.clss, self.track_ids, self.confs):\n annotator.box_label(box, label=self.adjust_box_label(cls, conf, track_id), color=colors(track_id, True))\n center = self.Point(((box[0] + box[2]) / 2, (box[1] + box[3]) / 2))\n for region in self.counting_regions:\n if region[\"prepared_polygon\"].contains(center):\n region[\"counts\"] += 1\n self.region_counts[region[\"name\"]] = region[\"counts\"]\n\n # Display region counts\n for region in self.counting_regions:\n x1, y1, x2, y2 = map(int, region[\"polygon\"].bounds)\n pts = [(x1, y1), (x2, y1), (x2, y2), (x1, y2)]\n annotator.draw_region(pts, region[\"region_color\"], self.line_width * 2)\n annotator.text_label(\n [x1, y1, x2, y2],\n label=str(region[\"counts\"]),\n color=region[\"region_color\"],\n txt_color=region[\"text_color\"],\n margin=self.line_width * 4,\n )\n region[\"counts\"] = 0 # Reset for next frame\n plot_im = annotator.result()\n self.display_output(plot_im)\n\n return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids), region_counts=self.region_counts)",
"chunk_type": "class",
"name": "RegionCounter",
"file_path": "ultralytics\\ultralytics\\solutions\\region_counter.py",
"start_line": 11,
"end_line": 132,
"start_col": 0,
"end_col": 115,
"parent_name": null,
"docstring": "A class for real-time counting of objects within user-defined regions in a video stream.\n\nThis class inherits from `BaseSolution` and provides functionality to define polygonal regions in a video frame,\ntrack objects, and count those objects that pass through each defined region. Useful for applications requiring\ncounting in specified areas, such as monitoring zones or segmented sections.\n\nAttributes:\n region_template (dict): Template for creating new counting regions with default attributes including name,\n polygon coordinates, and display colors.\n counting_regions (list): List storing all defined regions, where each entry is based on `region_template`\n and includes specific region settings like name, coordinates, and color.\n region_counts (dict): Dictionary storing the count of objects for each named region.\n\nMethods:\n add_region: Add a new counting region with specified attributes.\n process: Process video frames to count objects in each region.\n initialize_regions: Initialize zones to count the objects in each one. Zones could be multiple as well.\n\nExamples:\n Initialize a RegionCounter and add a counting region\n >>> counter = RegionCounter()\n >>> counter.add_region(\"Zone1\", [(100, 100), (200, 100), (200, 200), (100, 200)], (255, 0, 0), (255, 255, 255))\n >>> results = counter.process(frame)\n >>> print(f\"Total tracks: {results.total_tracks}\")",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"numpy",
"ultralytics.solutions.solutions.BaseSolution",
"ultralytics.solutions.solutions.SolutionAnnotator",
"ultralytics.solutions.solutions.SolutionResults",
"ultralytics.utils.plotting.colors",
"BaseSolution"
],
"chunk_id": "class_RegionCounter_9cb98056"
},
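As `initialize_regions` in the `RegionCounter` chunk above shows, `region` can also be supplied as a dict of named zones. A short sketch with two placeholder zones:

```python
import cv2
from ultralytics.solutions import RegionCounter

zones = {  # placeholder polygons, one entry per named counting zone
    "Zone-A": [(100, 100), (400, 100), (400, 300), (100, 300)],
    "Zone-B": [(500, 100), (800, 100), (800, 300), (500, 300)],
}
counter = RegionCounter(region=zones, model="yolo11n.pt")

frame = cv2.imread("frame.jpg")  # placeholder frame
results = counter.process(frame)
print(results.region_counts)  # per-zone counts, e.g. {"Zone-A": 2, "Zone-B": 0}
```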
{
"content": "from typing import Any",
"chunk_type": "import",
"name": "Any",
"file_path": "ultralytics\\ultralytics\\solutions\\security_alarm.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any_d5fd2f85"
},
{
"content": "from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults",
"chunk_type": "import",
"name": "BaseSolution, SolutionAnnotator, SolutionResults",
"file_path": "ultralytics\\ultralytics\\solutions\\security_alarm.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 92,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BaseSolution, SolutionAnnotator, SolutionResults_a1445258"
},
{
"content": "from ultralytics.utils import LOGGER",
"chunk_type": "import",
"name": "LOGGER",
"file_path": "ultralytics\\ultralytics\\solutions\\security_alarm.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 36,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER_934d92dd"
},
{
"content": "from ultralytics.utils.plotting import colors",
"chunk_type": "import",
"name": "colors",
"file_path": "ultralytics\\ultralytics\\solutions\\security_alarm.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_colors_0fe10d62"
},
{
"content": "class SecurityAlarm(BaseSolution):\n \"\"\"\n A class to manage security alarm functionalities for real-time monitoring.\n\n This class extends the BaseSolution class and provides features to monitor objects in a frame, send email\n notifications when specific thresholds are exceeded for total detections, and annotate the output frame for\n visualization.\n\n Attributes:\n email_sent (bool): Flag to track if an email has already been sent for the current event.\n records (int): Threshold for the number of detected objects to trigger an alert.\n server (smtplib.SMTP): SMTP server connection for sending email alerts.\n to_email (str): Recipient's email address for alerts.\n from_email (str): Sender's email address for alerts.\n\n Methods:\n authenticate: Set up email server authentication for sending alerts.\n send_email: Send an email notification with details and an image attachment.\n process: Monitor the frame, process detections, and trigger alerts if thresholds are crossed.\n\n Examples:\n >>> security = SecurityAlarm()\n >>> security.authenticate(\"abc@gmail.com\", \"1111222233334444\", \"xyz@gmail.com\")\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = security.process(frame)\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"\n Initialize the SecurityAlarm class with parameters for real-time object monitoring.\n\n Args:\n **kwargs (Any): Additional keyword arguments passed to the parent class.\n \"\"\"\n super().__init__(**kwargs)\n self.email_sent = False\n self.records = self.CFG[\"records\"]\n self.server = None\n self.to_email = \"\"\n self.from_email = \"\"\n\n def authenticate(self, from_email: str, password: str, to_email: str) -> None:\n \"\"\"\n Authenticate the email server for sending alert notifications.\n\n Args:\n from_email (str): Sender's email address.\n password (str): Password for the sender's email account.\n to_email (str): Recipient's email address.\n\n This method initializes a secure connection with the SMTP server and logs in using the provided credentials.\n\n Examples:\n >>> alarm = SecurityAlarm()\n >>> alarm.authenticate(\"sender@example.com\", \"password123\", \"recipient@example.com\")\n \"\"\"\n import smtplib\n\n self.server = smtplib.SMTP(\"smtp.gmail.com: 587\")\n self.server.starttls()\n self.server.login(from_email, password)\n self.to_email = to_email\n self.from_email = from_email\n\n def send_email(self, im0, records: int = 5) -> None:\n \"\"\"\n Send an email notification with an image attachment indicating the number of objects detected.\n\n Args:\n im0 (np.ndarray): The input image or frame to be attached to the email.\n records (int, optional): The number of detected objects to be included in the email message.\n\n This method encodes the input image, composes the email message with details about the detection, and sends it\n to the specified recipient.\n\n Examples:\n >>> alarm = SecurityAlarm()\n >>> frame = cv2.imread(\"path/to/image.jpg\")\n >>> alarm.send_email(frame, records=10)\n \"\"\"\n from email.mime.image import MIMEImage\n from email.mime.multipart import MIMEMultipart\n from email.mime.text import MIMEText\n\n import cv2\n\n img_bytes = cv2.imencode(\".jpg\", im0)[1].tobytes() # Encode the image as JPEG\n\n # Create the email\n message = MIMEMultipart()\n message[\"From\"] = self.from_email\n message[\"To\"] = self.to_email\n message[\"Subject\"] = \"Security Alert\"\n\n # Add the text message body\n message_body = f\"Ultralytics ALERT!!! 
{records} objects have been detected!!\"\n message.attach(MIMEText(message_body))\n\n # Attach the image\n image_attachment = MIMEImage(img_bytes, name=\"ultralytics.jpg\")\n message.attach(image_attachment)\n\n # Send the email\n try:\n self.server.send_message(message)\n LOGGER.info(\"Email sent successfully!\")\n except Exception as e:\n LOGGER.error(f\"Failed to send email: {e}\")\n\n def process(self, im0) -> SolutionResults:\n \"\"\"\n Monitor the frame, process object detections, and trigger alerts if thresholds are exceeded.\n\n Args:\n im0 (np.ndarray): The input image or frame to be processed and annotated.\n\n Returns:\n (SolutionResults): Contains processed image `plot_im`, 'total_tracks' (total number of tracked objects) and\n 'email_sent' (whether an email alert was triggered).\n\n This method processes the input frame, extracts detections, annotates the frame with bounding boxes, and sends\n an email notification if the number of detected objects surpasses the specified threshold and an alert has not\n already been sent.\n\n Examples:\n >>> alarm = SecurityAlarm()\n >>> frame = cv2.imread(\"path/to/image.jpg\")\n >>> results = alarm.process(frame)\n \"\"\"\n self.extract_tracks(im0) # Extract tracks\n annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator\n\n # Iterate over bounding boxes and classes index\n for box, cls in zip(self.boxes, self.clss):\n # Draw bounding box\n annotator.box_label(box, label=self.names[cls], color=colors(cls, True))\n\n total_det = len(self.clss)\n if total_det >= self.records and not self.email_sent: # Only send email if not sent before\n self.send_email(im0, total_det)\n self.email_sent = True\n\n plot_im = annotator.result()\n self.display_output(plot_im) # Display output with base class function\n\n # Return a SolutionResults\n return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids), email_sent=self.email_sent)",
"chunk_type": "class",
"name": "SecurityAlarm",
"file_path": "ultralytics\\ultralytics\\solutions\\security_alarm.py",
"start_line": 10,
"end_line": 156,
"start_col": 0,
"end_col": 109,
"parent_name": null,
"docstring": "A class to manage security alarm functionalities for real-time monitoring.\n\nThis class extends the BaseSolution class and provides features to monitor objects in a frame, send email\nnotifications when specific thresholds are exceeded for total detections, and annotate the output frame for\nvisualization.\n\nAttributes:\n email_sent (bool): Flag to track if an email has already been sent for the current event.\n records (int): Threshold for the number of detected objects to trigger an alert.\n server (smtplib.SMTP): SMTP server connection for sending email alerts.\n to_email (str): Recipient's email address for alerts.\n from_email (str): Sender's email address for alerts.\n\nMethods:\n authenticate: Set up email server authentication for sending alerts.\n send_email: Send an email notification with details and an image attachment.\n process: Monitor the frame, process detections, and trigger alerts if thresholds are crossed.\n\nExamples:\n >>> security = SecurityAlarm()\n >>> security.authenticate(\"abc@gmail.com\", \"1111222233334444\", \"xyz@gmail.com\")\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = security.process(frame)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"ultralytics.solutions.solutions.BaseSolution",
"ultralytics.solutions.solutions.SolutionAnnotator",
"ultralytics.solutions.solutions.SolutionResults",
"ultralytics.utils.LOGGER",
"ultralytics.utils.plotting.colors",
"smtplib",
"email.mime.image.MIMEImage",
"email.mime.multipart.MIMEMultipart",
"email.mime.text.MIMEText",
"cv2",
"BaseSolution"
],
"chunk_id": "class_SecurityAlarm_c056dd1c"
},
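A sketch of the `SecurityAlarm` workflow documented above; the email credentials are placeholders (for Gmail this would be an app password), and `records` is the detection threshold read from the solution config:

```python
import cv2
from ultralytics.solutions import SecurityAlarm

alarm = SecurityAlarm(model="yolo11n.pt", records=5)  # alert once 5+ objects are detected
alarm.authenticate("sender@example.com", "app-password", "recipient@example.com")

frame = cv2.imread("frame.jpg")  # placeholder frame
results = alarm.process(frame)
print(f"Tracks: {results.total_tracks} | Email sent: {results.email_sent}")
```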
{
"content": "import os",
"chunk_type": "import",
"name": "os",
"file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_os_031fa0f7"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_b1440c31"
},
{
"content": "from typing import Any, List",
"chunk_type": "import",
"name": "Any, List",
"file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, List_8a82c947"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_4dc09686"
},
{
"content": "from PIL import Image",
"chunk_type": "import",
"name": "Image",
"file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Image_46aafba3"
},
{
"content": "from ultralytics.data.utils import IMG_FORMATS",
"chunk_type": "import",
"name": "IMG_FORMATS",
"file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_IMG_FORMATS_b9b5e2e6"
},
{
"content": "from ultralytics.nn.text_model import build_text_model",
"chunk_type": "import",
"name": "build_text_model",
"file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 54,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_build_text_model_9485a61e"
},
{
"content": "from ultralytics.utils import LOGGER",
"chunk_type": "import",
"name": "LOGGER",
"file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 36,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER_9e0460cc"
},
{
"content": "from ultralytics.utils.checks import check_requirements",
"chunk_type": "import",
"name": "check_requirements",
"file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 55,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_check_requirements_ed830f39"
},
{
"content": "from ultralytics.utils.torch_utils import select_device",
"chunk_type": "import",
"name": "select_device",
"file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 55,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_select_device_ecc55285"
},
{
"content": "class VisualAISearch:\n \"\"\"\n A semantic image search system that leverages OpenCLIP for generating high-quality image and text embeddings and\n FAISS for fast similarity-based retrieval.\n\n This class aligns image and text embeddings in a shared semantic space, enabling users to search large collections\n of images using natural language queries with high accuracy and speed.\n\n Attributes:\n data (str): Directory containing images.\n device (str): Computation device, e.g., 'cpu' or 'cuda'.\n faiss_index (str): Path to the FAISS index file.\n data_path_npy (str): Path to the numpy file storing image paths.\n data_dir (Path): Path object for the data directory.\n model: Loaded CLIP model.\n index: FAISS index for similarity search.\n image_paths (List[str]): List of image file paths.\n\n Methods:\n extract_image_feature: Extract CLIP embedding from an image.\n extract_text_feature: Extract CLIP embedding from text.\n load_or_build_index: Load existing FAISS index or build new one.\n search: Perform semantic search for similar images.\n\n Examples:\n Initialize and search for images\n >>> searcher = VisualAISearch(data=\"path/to/images\", device=\"cuda\")\n >>> results = searcher.search(\"a cat sitting on a chair\", k=10)\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"Initialize the VisualAISearch class with FAISS index and CLIP model.\"\"\"\n check_requirements(\"faiss-cpu\")\n\n self.faiss = __import__(\"faiss\")\n self.faiss_index = \"faiss.index\"\n self.data_path_npy = \"paths.npy\"\n self.data_dir = Path(kwargs.get(\"data\", \"images\"))\n self.device = select_device(kwargs.get(\"device\", \"cpu\"))\n\n if not self.data_dir.exists():\n from ultralytics.utils import ASSETS_URL\n\n LOGGER.warning(f\"{self.data_dir} not found. Downloading images.zip from {ASSETS_URL}/images.zip\")\n from ultralytics.utils.downloads import safe_download\n\n safe_download(url=f\"{ASSETS_URL}/images.zip\", unzip=True, retry=3)\n self.data_dir = Path(\"images\")\n\n self.model = build_text_model(\"clip:ViT-B/32\", device=self.device)\n\n self.index = None\n self.image_paths = []\n\n self.load_or_build_index()\n\n def extract_image_feature(self, path: Path) -> np.ndarray:\n \"\"\"Extract CLIP image embedding from the given image path.\"\"\"\n return self.model.encode_image(Image.open(path)).cpu().numpy()\n\n def extract_text_feature(self, text: str) -> np.ndarray:\n \"\"\"Extract CLIP text embedding from the given text query.\"\"\"\n return self.model.encode_text(self.model.tokenize([text])).cpu().numpy()\n\n def load_or_build_index(self) -> None:\n \"\"\"\n Load existing FAISS index or build a new one from image features.\n\n Checks if FAISS index and image paths exist on disk. If found, loads them directly. 
Otherwise, builds a new\n index by extracting features from all images in the data directory, normalizes the features, and saves both the\n index and image paths for future use.\n \"\"\"\n # Check if the FAISS index and corresponding image paths already exist\n if Path(self.faiss_index).exists() and Path(self.data_path_npy).exists():\n LOGGER.info(\"Loading existing FAISS index...\")\n self.index = self.faiss.read_index(self.faiss_index) # Load the FAISS index from disk\n self.image_paths = np.load(self.data_path_npy) # Load the saved image path list\n return # Exit the function as the index is successfully loaded\n\n # If the index doesn't exist, start building it from scratch\n LOGGER.info(\"Building FAISS index from images...\")\n vectors = [] # List to store feature vectors of images\n\n # Iterate over all image files in the data directory\n for file in self.data_dir.iterdir():\n # Skip files that are not valid image formats\n if file.suffix.lower().lstrip(\".\") not in IMG_FORMATS:\n continue\n try:\n # Extract feature vector for the image and add to the list\n vectors.append(self.extract_image_feature(file))\n self.image_paths.append(file.name) # Store the corresponding image name\n except Exception as e:\n LOGGER.warning(f\"Skipping {file.name}: {e}\")\n\n # If no vectors were successfully created, raise an error\n if not vectors:\n raise RuntimeError(\"No image embeddings could be generated.\")\n\n vectors = np.vstack(vectors).astype(\"float32\") # Stack all vectors into a NumPy array and convert to float32\n self.faiss.normalize_L2(vectors) # Normalize vectors to unit length for cosine similarity\n\n self.index = self.faiss.IndexFlatIP(vectors.shape[1]) # Create a new FAISS index using inner product\n self.index.add(vectors) # Add the normalized vectors to the FAISS index\n self.faiss.write_index(self.index, self.faiss_index) # Save the newly built FAISS index to disk\n np.save(self.data_path_npy, np.array(self.image_paths)) # Save the list of image paths to disk\n\n LOGGER.info(f\"Indexed {len(self.image_paths)} images.\")\n\n def search(self, query: str, k: int = 30, similarity_thresh: float = 0.1) -> List[str]:\n \"\"\"\n Return top-k semantically similar images to the given query.\n\n Args:\n query (str): Natural language text query to search for.\n k (int, optional): Maximum number of results to return.\n similarity_thresh (float, optional): Minimum similarity threshold for filtering results.\n\n Returns:\n (List[str]): List of image filenames ranked by similarity score.\n\n Examples:\n Search for images matching a query\n >>> searcher = VisualAISearch(data=\"images\")\n >>> results = searcher.search(\"red car\", k=5, similarity_thresh=0.2)\n \"\"\"\n text_feat = self.extract_text_feature(query).astype(\"float32\")\n self.faiss.normalize_L2(text_feat)\n\n D, index = self.index.search(text_feat, k)\n results = [\n (self.image_paths[i], float(D[0][idx])) for idx, i in enumerate(index[0]) if D[0][idx] >= similarity_thresh\n ]\n results.sort(key=lambda x: x[1], reverse=True)\n\n LOGGER.info(\"\\nRanked Results:\")\n for name, score in results:\n LOGGER.info(f\" - {name} | Similarity: {score:.4f}\")\n\n return [r[0] for r in results]\n\n def __call__(self, query: str) -> List[str]:\n \"\"\"Direct call interface for the search function.\"\"\"\n return self.search(query)",
"chunk_type": "class",
"name": "VisualAISearch",
"file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py",
"start_line": 19,
"end_line": 162,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": "A semantic image search system that leverages OpenCLIP for generating high-quality image and text embeddings and\nFAISS for fast similarity-based retrieval.\n\nThis class aligns image and text embeddings in a shared semantic space, enabling users to search large collections\nof images using natural language queries with high accuracy and speed.\n\nAttributes:\n data (str): Directory containing images.\n device (str): Computation device, e.g., 'cpu' or 'cuda'.\n faiss_index (str): Path to the FAISS index file.\n data_path_npy (str): Path to the numpy file storing image paths.\n data_dir (Path): Path object for the data directory.\n model: Loaded CLIP model.\n index: FAISS index for similarity search.\n image_paths (List[str]): List of image file paths.\n\nMethods:\n extract_image_feature: Extract CLIP embedding from an image.\n extract_text_feature: Extract CLIP embedding from text.\n load_or_build_index: Load existing FAISS index or build new one.\n search: Perform semantic search for similar images.\n\nExamples:\n Initialize and search for images\n >>> searcher = VisualAISearch(data=\"path/to/images\", device=\"cuda\")\n >>> results = searcher.search(\"a cat sitting on a chair\", k=10)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"os",
"pathlib.Path",
"typing.Any",
"typing.List",
"numpy",
"PIL.Image",
"ultralytics.data.utils.IMG_FORMATS",
"ultralytics.nn.text_model.build_text_model",
"ultralytics.utils.LOGGER",
"ultralytics.utils.checks.check_requirements",
"ultralytics.utils.torch_utils.select_device",
"flask.Flask",
"flask.render_template",
"flask.request",
"ultralytics.utils.ASSETS_URL",
"ultralytics.utils.downloads.safe_download"
],
"chunk_id": "class_VisualAISearch_06b9dda1"
},
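A minimal sketch of `VisualAISearch`, assuming it is exported from `ultralytics.solutions` like the other solution classes; the image directory is a placeholder:

```python
from ultralytics.solutions import VisualAISearch  # assumed export path

searcher = VisualAISearch(data="path/to/images", device="cpu")

# The constructor builds or loads the FAISS index; then run a natural-language query.
matches = searcher.search("a red car parked on the street", k=5, similarity_thresh=0.1)
print(matches)  # image filenames ranked by similarity
```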
{
"content": "class SearchApp:\n \"\"\"\n A Flask-based web interface for semantic image search with natural language queries.\n\n This class provides a clean, responsive frontend that enables users to input natural language queries and\n instantly view the most relevant images retrieved from the indexed database.\n\n Attributes:\n render_template: Flask template rendering function.\n request: Flask request object.\n searcher (VisualAISearch): Instance of the VisualAISearch class.\n app (Flask): Flask application instance.\n\n Methods:\n index: Process user queries and display search results.\n run: Start the Flask web application.\n\n Examples:\n Start a search application\n >>> app = SearchApp(data=\"path/to/images\", device=\"cuda\")\n >>> app.run(debug=True)\n \"\"\"\n\n def __init__(self, data: str = \"images\", device: str = None) -> None:\n \"\"\"\n Initialize the SearchApp with VisualAISearch backend.\n\n Args:\n data (str, optional): Path to directory containing images to index and search.\n device (str, optional): Device to run inference on (e.g. 'cpu', 'cuda').\n \"\"\"\n check_requirements(\"flask>=3.0.1\")\n from flask import Flask, render_template, request\n\n self.render_template = render_template\n self.request = request\n self.searcher = VisualAISearch(data=data, device=device)\n self.app = Flask(\n __name__,\n template_folder=\"templates\",\n static_folder=Path(data).resolve(), # Absolute path to serve images\n static_url_path=\"/images\", # URL prefix for images\n )\n self.app.add_url_rule(\"/\", view_func=self.index, methods=[\"GET\", \"POST\"])\n\n def index(self) -> str:\n \"\"\"Process user query and display search results in the web interface.\"\"\"\n results = []\n if self.request.method == \"POST\":\n query = self.request.form.get(\"query\", \"\").strip()\n results = self.searcher(query)\n return self.render_template(\"similarity-search.html\", results=results)\n\n def run(self, debug: bool = False) -> None:\n \"\"\"Start the Flask web application server.\"\"\"\n self.app.run(debug=debug)",
"chunk_type": "class",
"name": "SearchApp",
"file_path": "ultralytics\\ultralytics\\solutions\\similarity_search.py",
"start_line": 165,
"end_line": 220,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": "A Flask-based web interface for semantic image search with natural language queries.\n\nThis class provides a clean, responsive frontend that enables users to input natural language queries and\ninstantly view the most relevant images retrieved from the indexed database.\n\nAttributes:\n render_template: Flask template rendering function.\n request: Flask request object.\n searcher (VisualAISearch): Instance of the VisualAISearch class.\n app (Flask): Flask application instance.\n\nMethods:\n index: Process user queries and display search results.\n run: Start the Flask web application.\n\nExamples:\n Start a search application\n >>> app = SearchApp(data=\"path/to/images\", device=\"cuda\")\n >>> app.run(debug=True)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"os",
"pathlib.Path",
"typing.Any",
"typing.List",
"numpy",
"PIL.Image",
"ultralytics.data.utils.IMG_FORMATS",
"ultralytics.nn.text_model.build_text_model",
"ultralytics.utils.LOGGER",
"ultralytics.utils.checks.check_requirements",
"ultralytics.utils.torch_utils.select_device",
"flask.Flask",
"flask.render_template",
"flask.request",
"ultralytics.utils.ASSETS_URL",
"ultralytics.utils.downloads.safe_download"
],
"chunk_id": "class_SearchApp_25d60119"
},
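And the corresponding `SearchApp` launch, again assuming the same export path; Flask's development server listens on port 5000 by default:

```python
from ultralytics.solutions import SearchApp  # assumed export path

app = SearchApp(data="path/to/images", device="cpu")
app.run(debug=True)  # then open http://127.0.0.1:5000 and submit a text query
```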
{
"content": "import math",
"chunk_type": "import",
"name": "math",
"file_path": "ultralytics\\ultralytics\\solutions\\solutions.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_math_36e75246"
},
{
"content": "from collections import defaultdict",
"chunk_type": "import",
"name": "defaultdict",
"file_path": "ultralytics\\ultralytics\\solutions\\solutions.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_defaultdict_8a57e2e2"
},
{
"content": "from functools import lru_cache",
"chunk_type": "import",
"name": "lru_cache",
"file_path": "ultralytics\\ultralytics\\solutions\\solutions.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_lru_cache_73213f9b"
},
{
"content": "from typing import Any, Dict, List, Optional, Tuple",
"chunk_type": "import",
"name": "Any, Dict, List, Optional, Tuple",
"file_path": "ultralytics\\ultralytics\\solutions\\solutions.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 51,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, List, Optional, Tuple_ed8c10c2"
},
{
"content": "import cv2",
"chunk_type": "import",
"name": "cv2",
"file_path": "ultralytics\\ultralytics\\solutions\\solutions.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 10,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_cv2_f1bc611f"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\solutions\\solutions.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_13d88f0c"
},
{
"content": "from ultralytics import YOLO",
"chunk_type": "import",
"name": "YOLO",
"file_path": "ultralytics\\ultralytics\\solutions\\solutions.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLO_6b9996f5"
},
{
"content": "from ultralytics.solutions.config import SolutionConfig",
"chunk_type": "import",
"name": "SolutionConfig",
"file_path": "ultralytics\\ultralytics\\solutions\\solutions.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 55,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SolutionConfig_51782812"
},
{
"content": "from ultralytics.utils import ASSETS_URL, LOGGER, ops",
"chunk_type": "import",
"name": "ASSETS_URL, LOGGER, ops",
"file_path": "ultralytics\\ultralytics\\solutions\\solutions.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ASSETS_URL, LOGGER, ops_c02d76db"
},
{
"content": "from ultralytics.utils.checks import check_imshow, check_requirements",
"chunk_type": "import",
"name": "check_imshow, check_requirements",
"file_path": "ultralytics\\ultralytics\\solutions\\solutions.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 69,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_check_imshow, check_requirements_58a3f4ea"
},
{
"content": "from ultralytics.utils.plotting import Annotator",
"chunk_type": "import",
"name": "Annotator",
"file_path": "ultralytics\\ultralytics\\solutions\\solutions.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 48,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Annotator_ee834069"
},
{
"content": "class BaseSolution:\n \"\"\"\n A base class for managing Ultralytics Solutions.\n\n This class provides core functionality for various Ultralytics Solutions, including model loading, object tracking,\n and region initialization. It serves as the foundation for implementing specific computer vision solutions such as\n object counting, pose estimation, and analytics.\n\n Attributes:\n LineString: Class for creating line string geometries from shapely.\n Polygon: Class for creating polygon geometries from shapely.\n Point: Class for creating point geometries from shapely.\n prep: Prepared geometry function from shapely for optimized spatial operations.\n CFG (Dict[str, Any]): Configuration dictionary loaded from YAML file and updated with kwargs.\n LOGGER: Logger instance for solution-specific logging.\n annotator: Annotator instance for drawing on images.\n tracks: YOLO tracking results from the latest inference.\n track_data: Extracted tracking data (boxes or OBB) from tracks.\n boxes (List): Bounding box coordinates from tracking results.\n clss (List[int]): Class indices from tracking results.\n track_ids (List[int]): Track IDs from tracking results.\n confs (List[float]): Confidence scores from tracking results.\n track_line: Current track line for storing tracking history.\n masks: Segmentation masks from tracking results.\n r_s: Region or line geometry object for spatial operations.\n frame_no (int): Current frame number for logging purposes.\n region (List[Tuple[int, int]]): List of coordinate tuples defining region of interest.\n line_width (int): Width of lines used in visualizations.\n model (YOLO): Loaded YOLO model instance.\n names (Dict[int, str]): Dictionary mapping class indices to class names.\n classes (List[int]): List of class indices to track.\n show_conf (bool): Flag to show confidence scores in annotations.\n show_labels (bool): Flag to show class labels in annotations.\n device (str): Device for model inference.\n track_add_args (Dict[str, Any]): Additional arguments for tracking configuration.\n env_check (bool): Flag indicating whether environment supports image display.\n track_history (defaultdict): Dictionary storing tracking history for each object.\n profilers (Tuple): Profiler instances for performance monitoring.\n\n Methods:\n adjust_box_label: Generate formatted label for bounding box.\n extract_tracks: Apply object tracking and extract tracks from input image.\n store_tracking_history: Store object tracking history for given track ID and bounding box.\n initialize_region: Initialize counting region and line segment based on configuration.\n display_output: Display processing results including frames or saved results.\n process: Process method to be implemented by each Solution subclass.\n\n Examples:\n >>> solution = BaseSolution(model=\"yolo11n.pt\", region=[(0, 0), (100, 0), (100, 100), (0, 100)])\n >>> solution.initialize_region()\n >>> image = cv2.imread(\"image.jpg\")\n >>> solution.extract_tracks(image)\n >>> solution.display_output(image)\n \"\"\"\n\n def __init__(self, is_cli: bool = False, **kwargs: Any) -> None:\n \"\"\"\n Initialize the BaseSolution class with configuration settings and YOLO model.\n\n Args:\n is_cli (bool): Enable CLI mode if set to True.\n **kwargs (Any): Additional configuration parameters that override defaults.\n \"\"\"\n self.CFG = vars(SolutionConfig().update(**kwargs))\n self.LOGGER = LOGGER # Store logger object to be used in multiple solution classes\n\n check_requirements(\"shapely>=2.0.0\")\n from 
shapely.geometry import LineString, Point, Polygon\n from shapely.prepared import prep\n\n self.LineString = LineString\n self.Polygon = Polygon\n self.Point = Point\n self.prep = prep\n self.annotator = None # Initialize annotator\n self.tracks = None\n self.track_data = None\n self.boxes = []\n self.clss = []\n self.track_ids = []\n self.track_line = None\n self.masks = None\n self.r_s = None\n self.frame_no = -1 # Only for logging\n\n self.LOGGER.info(f\"Ultralytics Solutions: ✅ {self.CFG}\")\n self.region = self.CFG[\"region\"] # Store region data for other classes usage\n self.line_width = self.CFG[\"line_width\"]\n\n # Load Model and store additional information (classes, show_conf, show_label)\n if self.CFG[\"model\"] is None:\n self.CFG[\"model\"] = \"yolo11n.pt\"\n self.model = YOLO(self.CFG[\"model\"])\n self.names = self.model.names\n self.classes = self.CFG[\"classes\"]\n self.show_conf = self.CFG[\"show_conf\"]\n self.show_labels = self.CFG[\"show_labels\"]\n self.device = self.CFG[\"device\"]\n\n self.track_add_args = { # Tracker additional arguments for advance configuration\n k: self.CFG[k] for k in {\"iou\", \"conf\", \"device\", \"max_det\", \"half\", \"tracker\"}\n } # verbose must be passed to track method; setting it False in YOLO still logs the track information.\n\n if is_cli and self.CFG[\"source\"] is None:\n d_s = \"solutions_ci_demo.mp4\" if \"-pose\" not in self.CFG[\"model\"] else \"solution_ci_pose_demo.mp4\"\n self.LOGGER.warning(f\"source not provided. using default source {ASSETS_URL}/{d_s}\")\n from ultralytics.utils.downloads import safe_download\n\n safe_download(f\"{ASSETS_URL}/{d_s}\") # download source from ultralytics assets\n self.CFG[\"source\"] = d_s # set default source\n\n # Initialize environment and region setup\n self.env_check = check_imshow(warn=True)\n self.track_history = defaultdict(list)\n\n self.profilers = (\n ops.Profile(device=self.device), # track\n ops.Profile(device=self.device), # solution\n )\n\n def adjust_box_label(self, cls: int, conf: float, track_id: Optional[int] = None) -> Optional[str]:\n \"\"\"\n Generate a formatted label for a bounding box.\n\n This method constructs a label string for a bounding box using the class index and confidence score.\n Optionally includes the track ID if provided. 
The label format adapts based on the display settings\n defined in `self.show_conf` and `self.show_labels`.\n\n Args:\n cls (int): The class index of the detected object.\n conf (float): The confidence score of the detection.\n track_id (int, optional): The unique identifier for the tracked object.\n\n Returns:\n (str | None): The formatted label string if `self.show_labels` is True; otherwise, None.\n \"\"\"\n name = (\"\" if track_id is None else f\"{track_id} \") + self.names[cls]\n return (f\"{name} {conf:.2f}\" if self.show_conf else name) if self.show_labels else None\n\n def extract_tracks(self, im0: np.ndarray) -> None:\n \"\"\"\n Apply object tracking and extract tracks from an input image or frame.\n\n Args:\n im0 (np.ndarray): The input image or frame.\n\n Examples:\n >>> solution = BaseSolution()\n >>> frame = cv2.imread(\"path/to/image.jpg\")\n >>> solution.extract_tracks(frame)\n \"\"\"\n with self.profilers[0]:\n self.tracks = self.model.track(\n source=im0, persist=True, classes=self.classes, verbose=False, **self.track_add_args\n )[0]\n is_obb = self.tracks.obb is not None\n self.track_data = self.tracks.obb if is_obb else self.tracks.boxes # Extract tracks for OBB or object detection\n\n if self.track_data and self.track_data.is_track:\n self.boxes = (self.track_data.xyxyxyxy if is_obb else self.track_data.xyxy).cpu()\n self.clss = self.track_data.cls.cpu().tolist()\n self.track_ids = self.track_data.id.int().cpu().tolist()\n self.confs = self.track_data.conf.cpu().tolist()\n else:\n self.LOGGER.warning(\"no tracks found!\")\n self.boxes, self.clss, self.track_ids, self.confs = [], [], [], []\n\n def store_tracking_history(self, track_id: int, box) -> None:\n \"\"\"\n Store the tracking history of an object.\n\n This method updates the tracking history for a given object by appending the center point of its\n bounding box to the track line. It maintains a maximum of 30 points in the tracking history.\n\n Args:\n track_id (int): The unique identifier for the tracked object.\n box (List[float]): The bounding box coordinates of the object in the format [x1, y1, x2, y2].\n\n Examples:\n >>> solution = BaseSolution()\n >>> solution.store_tracking_history(1, [100, 200, 300, 400])\n \"\"\"\n # Store tracking history\n self.track_line = self.track_history[track_id]\n self.track_line.append(tuple(box.mean(dim=0)) if box.numel() > 4 else (box[:4:2].mean(), box[1:4:2].mean()))\n if len(self.track_line) > 30:\n self.track_line.pop(0)\n\n def initialize_region(self) -> None:\n \"\"\"Initialize the counting region and line segment based on configuration settings.\"\"\"\n if self.region is None:\n self.region = [(10, 200), (540, 200), (540, 180), (10, 180)]\n self.r_s = (\n self.Polygon(self.region) if len(self.region) >= 3 else self.LineString(self.region)\n ) # region or line\n\n def display_output(self, plot_im: np.ndarray) -> None:\n \"\"\"\n Display the results of the processing, which could involve showing frames, printing counts, or saving results.\n\n This method is responsible for visualizing the output of the object detection and tracking process. 
It displays\n the processed frame with annotations, and allows for user interaction to close the display.\n\n Args:\n plot_im (np.ndarray): The image or frame that has been processed and annotated.\n\n Examples:\n >>> solution = BaseSolution()\n >>> frame = cv2.imread(\"path/to/image.jpg\")\n >>> solution.display_output(frame)\n\n Notes:\n - This method will only display output if the 'show' configuration is set to True and the environment\n supports image display.\n - The display can be closed by pressing the 'q' key.\n \"\"\"\n if self.CFG.get(\"show\") and self.env_check:\n cv2.imshow(\"Ultralytics Solutions\", plot_im)\n if cv2.waitKey(1) & 0xFF == ord(\"q\"):\n cv2.destroyAllWindows() # Closes current frame window\n return\n\n def process(self, *args: Any, **kwargs: Any):\n \"\"\"Process method should be implemented by each Solution subclass.\"\"\"\n\n def __call__(self, *args: Any, **kwargs: Any):\n \"\"\"Allow instances to be called like a function with flexible arguments.\"\"\"\n with self.profilers[1]:\n result = self.process(*args, **kwargs) # Call the subclass-specific process method\n track_or_predict = \"predict\" if type(self).__name__ == \"ObjectCropper\" else \"track\"\n track_or_predict_speed = self.profilers[0].dt * 1e3\n solution_speed = (self.profilers[1].dt - self.profilers[0].dt) * 1e3 # solution time = process - track\n result.speed = {track_or_predict: track_or_predict_speed, \"solution\": solution_speed}\n if self.CFG[\"verbose\"]:\n self.frame_no += 1\n LOGGER.info(\n f\"{self.frame_no}: {result.plot_im.shape[0]}x{result.plot_im.shape[1]} {solution_speed:.1f}ms\\n\"\n f\"Speed: {track_or_predict_speed:.1f}ms {track_or_predict}, \"\n f\"{solution_speed:.1f}ms solution per image at shape \"\n f\"(1, {getattr(self.model, 'ch', 3)}, {result.plot_im.shape[0]}, {result.plot_im.shape[1]})\\n\"\n )\n return result",
"chunk_type": "class",
"name": "BaseSolution",
"file_path": "ultralytics\\ultralytics\\solutions\\solutions.py",
"start_line": 18,
"end_line": 259,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": "A base class for managing Ultralytics Solutions.\n\nThis class provides core functionality for various Ultralytics Solutions, including model loading, object tracking,\nand region initialization. It serves as the foundation for implementing specific computer vision solutions such as\nobject counting, pose estimation, and analytics.\n\nAttributes:\n LineString: Class for creating line string geometries from shapely.\n Polygon: Class for creating polygon geometries from shapely.\n Point: Class for creating point geometries from shapely.\n prep: Prepared geometry function from shapely for optimized spatial operations.\n CFG (Dict[str, Any]): Configuration dictionary loaded from YAML file and updated with kwargs.\n LOGGER: Logger instance for solution-specific logging.\n annotator: Annotator instance for drawing on images.\n tracks: YOLO tracking results from the latest inference.\n track_data: Extracted tracking data (boxes or OBB) from tracks.\n boxes (List): Bounding box coordinates from tracking results.\n clss (List[int]): Class indices from tracking results.\n track_ids (List[int]): Track IDs from tracking results.\n confs (List[float]): Confidence scores from tracking results.\n track_line: Current track line for storing tracking history.\n masks: Segmentation masks from tracking results.\n r_s: Region or line geometry object for spatial operations.\n frame_no (int): Current frame number for logging purposes.\n region (List[Tuple[int, int]]): List of coordinate tuples defining region of interest.\n line_width (int): Width of lines used in visualizations.\n model (YOLO): Loaded YOLO model instance.\n names (Dict[int, str]): Dictionary mapping class indices to class names.\n classes (List[int]): List of class indices to track.\n show_conf (bool): Flag to show confidence scores in annotations.\n show_labels (bool): Flag to show class labels in annotations.\n device (str): Device for model inference.\n track_add_args (Dict[str, Any]): Additional arguments for tracking configuration.\n env_check (bool): Flag indicating whether environment supports image display.\n track_history (defaultdict): Dictionary storing tracking history for each object.\n profilers (Tuple): Profiler instances for performance monitoring.\n\nMethods:\n adjust_box_label: Generate formatted label for bounding box.\n extract_tracks: Apply object tracking and extract tracks from input image.\n store_tracking_history: Store object tracking history for given track ID and bounding box.\n initialize_region: Initialize counting region and line segment based on configuration.\n display_output: Display processing results including frames or saved results.\n process: Process method to be implemented by each Solution subclass.\n\nExamples:\n >>> solution = BaseSolution(model=\"yolo11n.pt\", region=[(0, 0), (100, 0), (100, 100), (0, 100)])\n >>> solution.initialize_region()\n >>> image = cv2.imread(\"image.jpg\")\n >>> solution.extract_tracks(image)\n >>> solution.display_output(image)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"collections.defaultdict",
"functools.lru_cache",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Tuple",
"cv2",
"numpy",
"ultralytics.YOLO",
"ultralytics.solutions.config.SolutionConfig",
"ultralytics.utils.ASSETS_URL",
"ultralytics.utils.LOGGER",
"ultralytics.utils.ops",
"ultralytics.utils.checks.check_imshow",
"ultralytics.utils.checks.check_requirements",
"ultralytics.utils.plotting.Annotator",
"shapely.geometry.LineString",
"shapely.geometry.Point",
"shapely.geometry.Polygon",
"shapely.prepared.prep",
"ultralytics.utils.downloads.safe_download"
],
"chunk_id": "class_BaseSolution_d8c7373e"
},
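The BaseSolution chunk above describes the pattern every solution follows: load the model from CFG, call extract_tracks to populate boxes/classes/IDs, draw with an annotator, and return a SolutionResults from process so that __call__ can attach timing. Below is a minimal, hedged sketch of that pattern as a custom subclass; the class name ObjectAnnotatorDemo and the file "my_frame.jpg" are invented for illustration, and the exact behaviour of box_label/adjust_box_label is assumed from the signatures shown in these chunks rather than verified against the installed package.

# Hedged sketch: a hypothetical BaseSolution subclass, not an official Ultralytics solution.
import cv2
from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults


class ObjectAnnotatorDemo(BaseSolution):  # hypothetical name, for illustration only
    def process(self, im0):
        # extract_tracks() fills self.boxes, self.clss, self.track_ids, self.confs
        self.extract_tracks(im0)
        annotator = SolutionAnnotator(im0, line_width=self.line_width)
        for box, cls, conf in zip(self.boxes, self.clss, self.confs):
            # adjust_box_label honours show_labels/show_conf from the configuration
            annotator.box_label(box, label=self.adjust_box_label(cls, conf))
        plot_im = annotator.result()
        self.display_output(plot_im)  # shows a window only if CFG["show"] is True
        return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))


demo = ObjectAnnotatorDemo(model="yolo11n.pt", show=False)
results = demo(cv2.imread("my_frame.jpg"))  # __call__ wraps process() with the profilers
print(results.speed)  # per-frame track and solution timings collected by __call__

Calling the instance rather than process() directly is what populates results.speed, since the profiler bookkeeping lives in __call__.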
{
"content": "class SolutionAnnotator(Annotator):\n \"\"\"\n A specialized annotator class for visualizing and analyzing computer vision tasks.\n\n This class extends the base Annotator class, providing additional methods for drawing regions, centroids, tracking\n trails, and visual annotations for Ultralytics Solutions. It offers comprehensive visualization capabilities for\n various computer vision applications including object detection, tracking, pose estimation, and analytics.\n\n Attributes:\n im (np.ndarray): The image being annotated.\n line_width (int): Thickness of lines used in annotations.\n font_size (int): Size of the font used for text annotations.\n font (str): Path to the font file used for text rendering.\n pil (bool): Whether to use PIL for text rendering.\n example (str): An example attribute for demonstration purposes.\n\n Methods:\n draw_region: Draw a region using specified points, colors, and thickness.\n queue_counts_display: Display queue counts in the specified region.\n display_analytics: Display overall statistics for parking lot management.\n estimate_pose_angle: Calculate the angle between three points in an object pose.\n draw_specific_kpts: Draw specific keypoints on the image.\n plot_workout_information: Draw a labeled text box on the image.\n plot_angle_and_count_and_stage: Visualize angle, step count, and stage for workout monitoring.\n plot_distance_and_line: Display the distance between centroids and connect them with a line.\n display_objects_labels: Annotate bounding boxes with object class labels.\n sweep_annotator: Visualize a vertical sweep line and optional label.\n visioneye: Map and connect object centroids to a visual \"eye\" point.\n circle_label: Draw a circular label within a bounding box.\n text_label: Draw a rectangular label within a bounding box.\n\n Examples:\n >>> annotator = SolutionAnnotator(image)\n >>> annotator.draw_region([(0, 0), (100, 100)], color=(0, 255, 0), thickness=5)\n >>> annotator.display_analytics(\n ... image, text={\"Available Spots\": 5}, txt_color=(0, 0, 0), bg_color=(255, 255, 255), margin=10\n ... 
)\n \"\"\"\n\n def __init__(\n self,\n im: np.ndarray,\n line_width: Optional[int] = None,\n font_size: Optional[int] = None,\n font: str = \"Arial.ttf\",\n pil: bool = False,\n example: str = \"abc\",\n ):\n \"\"\"\n Initialize the SolutionAnnotator class with an image for annotation.\n\n Args:\n im (np.ndarray): The image to be annotated.\n line_width (int, optional): Line thickness for drawing on the image.\n font_size (int, optional): Font size for text annotations.\n font (str): Path to the font file.\n pil (bool): Indicates whether to use PIL for rendering text.\n example (str): An example parameter for demonstration purposes.\n \"\"\"\n super().__init__(im, line_width, font_size, font, pil, example)\n\n def draw_region(\n self,\n reg_pts: Optional[List[Tuple[int, int]]] = None,\n color: Tuple[int, int, int] = (0, 255, 0),\n thickness: int = 5,\n ):\n \"\"\"\n Draw a region or line on the image.\n\n Args:\n reg_pts (List[Tuple[int, int]], optional): Region points (for line 2 points, for region 4+ points).\n color (Tuple[int, int, int]): RGB color value for the region.\n thickness (int): Line thickness for drawing the region.\n \"\"\"\n cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)\n\n # Draw small circles at the corner points\n for point in reg_pts:\n cv2.circle(self.im, (point[0], point[1]), thickness * 2, color, -1) # -1 fills the circle\n\n def queue_counts_display(\n self,\n label: str,\n points: Optional[List[Tuple[int, int]]] = None,\n region_color: Tuple[int, int, int] = (255, 255, 255),\n txt_color: Tuple[int, int, int] = (0, 0, 0),\n ):\n \"\"\"\n Display queue counts on an image centered at the points with customizable font size and colors.\n\n Args:\n label (str): Queue counts label.\n points (List[Tuple[int, int]], optional): Region points for center point calculation to display text.\n region_color (Tuple[int, int, int]): RGB queue region color.\n txt_color (Tuple[int, int, int]): RGB text display color.\n \"\"\"\n x_values = [point[0] for point in points]\n y_values = [point[1] for point in points]\n center_x = sum(x_values) // len(points)\n center_y = sum(y_values) // len(points)\n\n text_size = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0]\n text_width = text_size[0]\n text_height = text_size[1]\n\n rect_width = text_width + 20\n rect_height = text_height + 20\n rect_top_left = (center_x - rect_width // 2, center_y - rect_height // 2)\n rect_bottom_right = (center_x + rect_width // 2, center_y + rect_height // 2)\n cv2.rectangle(self.im, rect_top_left, rect_bottom_right, region_color, -1)\n\n text_x = center_x - text_width // 2\n text_y = center_y + text_height // 2\n\n # Draw text\n cv2.putText(\n self.im,\n label,\n (text_x, text_y),\n 0,\n fontScale=self.sf,\n color=txt_color,\n thickness=self.tf,\n lineType=cv2.LINE_AA,\n )\n\n def display_analytics(\n self,\n im0: np.ndarray,\n text: Dict[str, Any],\n txt_color: Tuple[int, int, int],\n bg_color: Tuple[int, int, int],\n margin: int,\n ):\n \"\"\"\n Display the overall statistics for parking lots, object counter etc.\n\n Args:\n im0 (np.ndarray): Inference image.\n text (Dict[str, Any]): Labels dictionary.\n txt_color (Tuple[int, int, int]): Display color for text foreground.\n bg_color (Tuple[int, int, int]): Display color for text background.\n margin (int): Gap between text and rectangle for better display.\n \"\"\"\n horizontal_gap = int(im0.shape[1] * 0.02)\n vertical_gap = int(im0.shape[0] * 0.01)\n text_y_offset = 0\n for 
label, value in text.items():\n txt = f\"{label}: {value}\"\n text_size = cv2.getTextSize(txt, 0, self.sf, self.tf)[0]\n if text_size[0] < 5 or text_size[1] < 5:\n text_size = (5, 5)\n text_x = im0.shape[1] - text_size[0] - margin * 2 - horizontal_gap\n text_y = text_y_offset + text_size[1] + margin * 2 + vertical_gap\n rect_x1 = text_x - margin * 2\n rect_y1 = text_y - text_size[1] - margin * 2\n rect_x2 = text_x + text_size[0] + margin * 2\n rect_y2 = text_y + margin * 2\n cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)\n cv2.putText(im0, txt, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)\n text_y_offset = rect_y2\n\n @staticmethod\n @lru_cache(maxsize=256)\n def estimate_pose_angle(a: List[float], b: List[float], c: List[float]) -> float:\n \"\"\"\n Calculate the angle between three points for workout monitoring.\n\n Args:\n a (List[float]): The coordinates of the first point.\n b (List[float]): The coordinates of the second point (vertex).\n c (List[float]): The coordinates of the third point.\n\n Returns:\n (float): The angle in degrees between the three points.\n \"\"\"\n radians = math.atan2(c[1] - b[1], c[0] - b[0]) - math.atan2(a[1] - b[1], a[0] - b[0])\n angle = abs(radians * 180.0 / math.pi)\n return angle if angle <= 180.0 else (360 - angle)\n\n def draw_specific_kpts(\n self,\n keypoints: List[List[float]],\n indices: Optional[List[int]] = None,\n radius: int = 2,\n conf_thresh: float = 0.25,\n ) -> np.ndarray:\n \"\"\"\n Draw specific keypoints for gym steps counting.\n\n Args:\n keypoints (List[List[float]]): Keypoints data to be plotted, each in format [x, y, confidence].\n indices (List[int], optional): Keypoint indices to be plotted.\n radius (int): Keypoint radius.\n conf_thresh (float): Confidence threshold for keypoints.\n\n Returns:\n (np.ndarray): Image with drawn keypoints.\n\n Notes:\n Keypoint format: [x, y] or [x, y, confidence].\n Modifies self.im in-place.\n \"\"\"\n indices = indices or [2, 5, 7]\n points = [(int(k[0]), int(k[1])) for i, k in enumerate(keypoints) if i in indices and k[2] >= conf_thresh]\n\n # Draw lines between consecutive points\n for start, end in zip(points[:-1], points[1:]):\n cv2.line(self.im, start, end, (0, 255, 0), 2, lineType=cv2.LINE_AA)\n\n # Draw circles for keypoints\n for pt in points:\n cv2.circle(self.im, pt, radius, (0, 0, 255), -1, lineType=cv2.LINE_AA)\n\n return self.im\n\n def plot_workout_information(\n self,\n display_text: str,\n position: Tuple[int, int],\n color: Tuple[int, int, int] = (104, 31, 17),\n txt_color: Tuple[int, int, int] = (255, 255, 255),\n ) -> int:\n \"\"\"\n Draw workout text with a background on the image.\n\n Args:\n display_text (str): The text to be displayed.\n position (Tuple[int, int]): Coordinates (x, y) on the image where the text will be placed.\n color (Tuple[int, int, int]): Text background color.\n txt_color (Tuple[int, int, int]): Text foreground color.\n\n Returns:\n (int): The height of the text.\n \"\"\"\n (text_width, text_height), _ = cv2.getTextSize(display_text, 0, fontScale=self.sf, thickness=self.tf)\n\n # Draw background rectangle\n cv2.rectangle(\n self.im,\n (position[0], position[1] - text_height - 5),\n (position[0] + text_width + 10, position[1] - text_height - 5 + text_height + 10 + self.tf),\n color,\n -1,\n )\n # Draw text\n cv2.putText(self.im, display_text, position, 0, self.sf, txt_color, self.tf)\n\n return text_height\n\n def plot_angle_and_count_and_stage(\n self,\n angle_text: str,\n count_text: str,\n 
stage_text: str,\n center_kpt: List[int],\n color: Tuple[int, int, int] = (104, 31, 17),\n txt_color: Tuple[int, int, int] = (255, 255, 255),\n ):\n \"\"\"\n Plot the pose angle, count value, and step stage for workout monitoring.\n\n Args:\n angle_text (str): Angle value for workout monitoring.\n count_text (str): Counts value for workout monitoring.\n stage_text (str): Stage decision for workout monitoring.\n center_kpt (List[int]): Centroid pose index for workout monitoring.\n color (Tuple[int, int, int]): Text background color.\n txt_color (Tuple[int, int, int]): Text foreground color.\n \"\"\"\n # Format text\n angle_text, count_text, stage_text = f\" {angle_text:.2f}\", f\"Steps : {count_text}\", f\" {stage_text}\"\n\n # Draw angle, count and stage text\n angle_height = self.plot_workout_information(\n angle_text, (int(center_kpt[0]), int(center_kpt[1])), color, txt_color\n )\n count_height = self.plot_workout_information(\n count_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + 20), color, txt_color\n )\n self.plot_workout_information(\n stage_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + count_height + 40), color, txt_color\n )\n\n def plot_distance_and_line(\n self,\n pixels_distance: float,\n centroids: List[Tuple[int, int]],\n line_color: Tuple[int, int, int] = (104, 31, 17),\n centroid_color: Tuple[int, int, int] = (255, 0, 255),\n ):\n \"\"\"\n Plot the distance and line between two centroids on the frame.\n\n Args:\n pixels_distance (float): Pixels distance between two bbox centroids.\n centroids (List[Tuple[int, int]]): Bounding box centroids data.\n line_color (Tuple[int, int, int]): Distance line color.\n centroid_color (Tuple[int, int, int]): Bounding box centroid color.\n \"\"\"\n # Get the text size\n text = f\"Pixels Distance: {pixels_distance:.2f}\"\n (text_width_m, text_height_m), _ = cv2.getTextSize(text, 0, self.sf, self.tf)\n\n # Define corners with 10-pixel margin and draw rectangle\n cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 20, 25 + text_height_m + 20), line_color, -1)\n\n # Calculate the position for the text with a 10-pixel margin and draw text\n text_position = (25, 25 + text_height_m + 10)\n cv2.putText(\n self.im,\n text,\n text_position,\n 0,\n self.sf,\n (255, 255, 255),\n self.tf,\n cv2.LINE_AA,\n )\n\n cv2.line(self.im, centroids[0], centroids[1], line_color, 3)\n cv2.circle(self.im, centroids[0], 6, centroid_color, -1)\n cv2.circle(self.im, centroids[1], 6, centroid_color, -1)\n\n def display_objects_labels(\n self,\n im0: np.ndarray,\n text: str,\n txt_color: Tuple[int, int, int],\n bg_color: Tuple[int, int, int],\n x_center: float,\n y_center: float,\n margin: int,\n ):\n \"\"\"\n Display the bounding boxes labels in parking management app.\n\n Args:\n im0 (np.ndarray): Inference image.\n text (str): Object/class name.\n txt_color (Tuple[int, int, int]): Display color for text foreground.\n bg_color (Tuple[int, int, int]): Display color for text background.\n x_center (float): The x position center point for bounding box.\n y_center (float): The y position center point for bounding box.\n margin (int): The gap between text and rectangle for better display.\n \"\"\"\n text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]\n text_x = x_center - text_size[0] // 2\n text_y = y_center + text_size[1] // 2\n\n rect_x1 = text_x - margin\n rect_y1 = text_y - text_size[1] - margin\n rect_x2 = text_x + text_size[0] + margin\n rect_y2 = text_y + margin\n cv2.rectangle(\n im0,\n (int(rect_x1), 
int(rect_y1)),\n (int(rect_x2), int(rect_y2)),\n tuple(map(int, bg_color)), # Ensure color values are int\n -1,\n )\n\n cv2.putText(\n im0,\n text,\n (int(text_x), int(text_y)),\n 0,\n self.sf,\n tuple(map(int, txt_color)), # Ensure color values are int\n self.tf,\n lineType=cv2.LINE_AA,\n )\n\n def sweep_annotator(\n self,\n line_x: int = 0,\n line_y: int = 0,\n label: Optional[str] = None,\n color: Tuple[int, int, int] = (221, 0, 186),\n txt_color: Tuple[int, int, int] = (255, 255, 255),\n ):\n \"\"\"\n Draw a sweep annotation line and an optional label.\n\n Args:\n line_x (int): The x-coordinate of the sweep line.\n line_y (int): The y-coordinate limit of the sweep line.\n label (str, optional): Text label to be drawn in center of sweep line. If None, no label is drawn.\n color (Tuple[int, int, int]): RGB color for the line and label background.\n txt_color (Tuple[int, int, int]): RGB color for the label text.\n \"\"\"\n # Draw the sweep line\n cv2.line(self.im, (line_x, 0), (line_x, line_y), color, self.tf * 2)\n\n # Draw label, if provided\n if label:\n (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf, self.tf)\n cv2.rectangle(\n self.im,\n (line_x - text_width // 2 - 10, line_y // 2 - text_height // 2 - 10),\n (line_x + text_width // 2 + 10, line_y // 2 + text_height // 2 + 10),\n color,\n -1,\n )\n cv2.putText(\n self.im,\n label,\n (line_x - text_width // 2, line_y // 2 + text_height // 2),\n cv2.FONT_HERSHEY_SIMPLEX,\n self.sf,\n txt_color,\n self.tf,\n )\n\n def visioneye(\n self,\n box: List[float],\n center_point: Tuple[int, int],\n color: Tuple[int, int, int] = (235, 219, 11),\n pin_color: Tuple[int, int, int] = (255, 0, 255),\n ):\n \"\"\"\n Perform pinpoint human-vision eye mapping and plotting.\n\n Args:\n box (List[float]): Bounding box coordinates in format [x1, y1, x2, y2].\n center_point (Tuple[int, int]): Center point for vision eye view.\n color (Tuple[int, int, int]): Object centroid and line color.\n pin_color (Tuple[int, int, int]): Visioneye point color.\n \"\"\"\n center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)\n cv2.circle(self.im, center_point, self.tf * 2, pin_color, -1)\n cv2.circle(self.im, center_bbox, self.tf * 2, color, -1)\n cv2.line(self.im, center_point, center_bbox, color, self.tf)\n\n def circle_label(\n self,\n box: Tuple[float, float, float, float],\n label: str = \"\",\n color: Tuple[int, int, int] = (128, 128, 128),\n txt_color: Tuple[int, int, int] = (255, 255, 255),\n margin: int = 2,\n ):\n \"\"\"\n Draw a label with a background circle centered within a given bounding box.\n\n Args:\n box (Tuple[float, float, float, float]): The bounding box coordinates (x1, y1, x2, y2).\n label (str): The text label to be displayed.\n color (Tuple[int, int, int]): The background color of the circle (B, G, R).\n txt_color (Tuple[int, int, int]): The color of the text (R, G, B).\n margin (int): The margin between the text and the circle border.\n \"\"\"\n if len(label) > 3:\n LOGGER.warning(f\"Length of label is {len(label)}, only first 3 letters will be used for circle annotation.\")\n label = label[:3]\n\n # Calculate the center of the box\n x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)\n # Get the text size\n text_size = cv2.getTextSize(str(label), cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0]\n # Calculate the required radius to fit the text with the margin\n required_radius = int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin\n # Draw the 
circle with the required radius\n cv2.circle(self.im, (x_center, y_center), required_radius, color, -1)\n # Calculate the position for the text\n text_x = x_center - text_size[0] // 2\n text_y = y_center + text_size[1] // 2\n # Draw the text\n cv2.putText(\n self.im,\n str(label),\n (text_x, text_y),\n cv2.FONT_HERSHEY_SIMPLEX,\n self.sf - 0.15,\n self.get_txt_color(color, txt_color),\n self.tf,\n lineType=cv2.LINE_AA,\n )\n\n def text_label(\n self,\n box: Tuple[float, float, float, float],\n label: str = \"\",\n color: Tuple[int, int, int] = (128, 128, 128),\n txt_color: Tuple[int, int, int] = (255, 255, 255),\n margin: int = 5,\n ):\n \"\"\"\n Draw a label with a background rectangle centered within a given bounding box.\n\n Args:\n box (Tuple[float, float, float, float]): The bounding box coordinates (x1, y1, x2, y2).\n label (str): The text label to be displayed.\n color (Tuple[int, int, int]): The background color of the rectangle (B, G, R).\n txt_color (Tuple[int, int, int]): The color of the text (R, G, B).\n margin (int): The margin between the text and the rectangle border.\n \"\"\"\n # Calculate the center of the bounding box\n x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)\n # Get the size of the text\n text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.1, self.tf)[0]\n # Calculate the top-left corner of the text (to center it)\n text_x = x_center - text_size[0] // 2\n text_y = y_center + text_size[1] // 2\n # Calculate the coordinates of the background rectangle\n rect_x1 = text_x - margin\n rect_y1 = text_y - text_size[1] - margin\n rect_x2 = text_x + text_size[0] + margin\n rect_y2 = text_y + margin\n # Draw the background rectangle\n cv2.rectangle(self.im, (rect_x1, rect_y1), (rect_x2, rect_y2), color, -1)\n # Draw the text on top of the rectangle\n cv2.putText(\n self.im,\n label,\n (text_x, text_y),\n cv2.FONT_HERSHEY_SIMPLEX,\n self.sf - 0.1,\n self.get_txt_color(color, txt_color),\n self.tf,\n lineType=cv2.LINE_AA,\n )",
"chunk_type": "class",
"name": "SolutionAnnotator",
"file_path": "ultralytics\\ultralytics\\solutions\\solutions.py",
"start_line": 262,
"end_line": 785,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "A specialized annotator class for visualizing and analyzing computer vision tasks.\n\nThis class extends the base Annotator class, providing additional methods for drawing regions, centroids, tracking\ntrails, and visual annotations for Ultralytics Solutions. It offers comprehensive visualization capabilities for\nvarious computer vision applications including object detection, tracking, pose estimation, and analytics.\n\nAttributes:\n im (np.ndarray): The image being annotated.\n line_width (int): Thickness of lines used in annotations.\n font_size (int): Size of the font used for text annotations.\n font (str): Path to the font file used for text rendering.\n pil (bool): Whether to use PIL for text rendering.\n example (str): An example attribute for demonstration purposes.\n\nMethods:\n draw_region: Draw a region using specified points, colors, and thickness.\n queue_counts_display: Display queue counts in the specified region.\n display_analytics: Display overall statistics for parking lot management.\n estimate_pose_angle: Calculate the angle between three points in an object pose.\n draw_specific_kpts: Draw specific keypoints on the image.\n plot_workout_information: Draw a labeled text box on the image.\n plot_angle_and_count_and_stage: Visualize angle, step count, and stage for workout monitoring.\n plot_distance_and_line: Display the distance between centroids and connect them with a line.\n display_objects_labels: Annotate bounding boxes with object class labels.\n sweep_annotator: Visualize a vertical sweep line and optional label.\n visioneye: Map and connect object centroids to a visual \"eye\" point.\n circle_label: Draw a circular label within a bounding box.\n text_label: Draw a rectangular label within a bounding box.\n\nExamples:\n >>> annotator = SolutionAnnotator(image)\n >>> annotator.draw_region([(0, 0), (100, 100)], color=(0, 255, 0), thickness=5)\n >>> annotator.display_analytics(\n ... image, text={\"Available Spots\": 5}, txt_color=(0, 0, 0), bg_color=(255, 255, 255), margin=10\n ... )",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"collections.defaultdict",
"functools.lru_cache",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Tuple",
"cv2",
"numpy",
"ultralytics.YOLO",
"ultralytics.solutions.config.SolutionConfig",
"ultralytics.utils.ASSETS_URL",
"ultralytics.utils.LOGGER",
"ultralytics.utils.ops",
"ultralytics.utils.checks.check_imshow",
"ultralytics.utils.checks.check_requirements",
"ultralytics.utils.plotting.Annotator",
"shapely.geometry.LineString",
"shapely.geometry.Point",
"shapely.geometry.Polygon",
"shapely.prepared.prep",
"ultralytics.utils.downloads.safe_download",
"Annotator"
],
"chunk_id": "class_SolutionAnnotator_faab5a5d"
},
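Since SolutionAnnotator is mostly a bag of OpenCV drawing helpers layered on the base Annotator, a short synthetic example makes the call shapes concrete. This is a hedged sketch on a blank frame; the region points, labels, and output filename are invented, and only methods whose signatures appear in the chunk above are used.

# Hedged sketch of SolutionAnnotator drawing helpers on a synthetic frame.
import cv2
import numpy as np
from ultralytics.solutions.solutions import SolutionAnnotator

frame = np.zeros((480, 640, 3), dtype=np.uint8)  # blank canvas standing in for a video frame
annotator = SolutionAnnotator(frame, line_width=2)

region = [(50, 400), (590, 400), (590, 360), (50, 360)]  # hypothetical counting region
annotator.draw_region(reg_pts=region, color=(0, 255, 0), thickness=2)

# Rectangular and circular labels centred inside invented bounding boxes
annotator.text_label((100, 100, 220, 180), label="car 0.91", color=(128, 128, 128))
annotator.circle_label((300, 100, 420, 180), label="7", color=(255, 0, 0))

# Overall statistics box in the top-right corner, as used by parking management
annotator.display_analytics(
    frame, text={"Available Spots": 5}, txt_color=(0, 0, 0), bg_color=(255, 255, 255), margin=10
)

cv2.imwrite("annotated.jpg", annotator.result())  # result() returns the annotated array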
{
"content": "class SolutionResults:\n \"\"\"\n A class to encapsulate the results of Ultralytics Solutions.\n\n This class is designed to store and manage various outputs generated by the solution pipeline, including counts,\n angles, workout stages, and other analytics data. It provides a structured way to access and manipulate results\n from different computer vision solutions such as object counting, pose estimation, and tracking analytics.\n\n Attributes:\n plot_im (np.ndarray): Processed image with counts, blurred, or other effects from solutions.\n in_count (int): The total number of \"in\" counts in a video stream.\n out_count (int): The total number of \"out\" counts in a video stream.\n classwise_count (Dict[str, int]): A dictionary containing counts of objects categorized by class.\n queue_count (int): The count of objects in a queue or waiting area.\n workout_count (int): The count of workout repetitions.\n workout_angle (float): The angle calculated during a workout exercise.\n workout_stage (str): The current stage of the workout.\n pixels_distance (float): The calculated distance in pixels between two points or objects.\n available_slots (int): The number of available slots in a monitored area.\n filled_slots (int): The number of filled slots in a monitored area.\n email_sent (bool): A flag indicating whether an email notification was sent.\n total_tracks (int): The total number of tracked objects.\n region_counts (Dict[str, int]): The count of objects within a specific region.\n speed_dict (Dict[str, float]): A dictionary containing speed information for tracked objects.\n total_crop_objects (int): Total number of cropped objects using ObjectCropper class.\n speed (Dict[str, float]): Performance timing information for tracking and solution processing.\n \"\"\"\n\n def __init__(self, **kwargs):\n \"\"\"\n Initialize a SolutionResults object with default or user-specified values.\n\n Args:\n **kwargs (Any): Optional arguments to override default attribute values.\n \"\"\"\n self.plot_im = None\n self.in_count = 0\n self.out_count = 0\n self.classwise_count = {}\n self.queue_count = 0\n self.workout_count = 0\n self.workout_angle = 0.0\n self.workout_stage = None\n self.pixels_distance = 0.0\n self.available_slots = 0\n self.filled_slots = 0\n self.email_sent = False\n self.total_tracks = 0\n self.region_counts = {}\n self.speed_dict = {} # for speed estimation\n self.total_crop_objects = 0\n self.speed = {}\n\n # Override with user-defined values\n self.__dict__.update(kwargs)\n\n def __str__(self) -> str:\n \"\"\"\n Return a formatted string representation of the SolutionResults object.\n\n Returns:\n (str): A string representation listing non-null attributes.\n \"\"\"\n attrs = {\n k: v\n for k, v in self.__dict__.items()\n if k != \"plot_im\" and v not in [None, {}, 0, 0.0, False] # Exclude `plot_im` explicitly\n }\n return \", \".join(f\"{k}={v}\" for k, v in attrs.items())",
"chunk_type": "class",
"name": "SolutionResults",
"file_path": "ultralytics\\ultralytics\\solutions\\solutions.py",
"start_line": 788,
"end_line": 856,
"start_col": 0,
"end_col": 62,
"parent_name": null,
"docstring": "A class to encapsulate the results of Ultralytics Solutions.\n\nThis class is designed to store and manage various outputs generated by the solution pipeline, including counts,\nangles, workout stages, and other analytics data. It provides a structured way to access and manipulate results\nfrom different computer vision solutions such as object counting, pose estimation, and tracking analytics.\n\nAttributes:\n plot_im (np.ndarray): Processed image with counts, blurred, or other effects from solutions.\n in_count (int): The total number of \"in\" counts in a video stream.\n out_count (int): The total number of \"out\" counts in a video stream.\n classwise_count (Dict[str, int]): A dictionary containing counts of objects categorized by class.\n queue_count (int): The count of objects in a queue or waiting area.\n workout_count (int): The count of workout repetitions.\n workout_angle (float): The angle calculated during a workout exercise.\n workout_stage (str): The current stage of the workout.\n pixels_distance (float): The calculated distance in pixels between two points or objects.\n available_slots (int): The number of available slots in a monitored area.\n filled_slots (int): The number of filled slots in a monitored area.\n email_sent (bool): A flag indicating whether an email notification was sent.\n total_tracks (int): The total number of tracked objects.\n region_counts (Dict[str, int]): The count of objects within a specific region.\n speed_dict (Dict[str, float]): A dictionary containing speed information for tracked objects.\n total_crop_objects (int): Total number of cropped objects using ObjectCropper class.\n speed (Dict[str, float]): Performance timing information for tracking and solution processing.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"collections.defaultdict",
"functools.lru_cache",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Tuple",
"cv2",
"numpy",
"ultralytics.YOLO",
"ultralytics.solutions.config.SolutionConfig",
"ultralytics.utils.ASSETS_URL",
"ultralytics.utils.LOGGER",
"ultralytics.utils.ops",
"ultralytics.utils.checks.check_imshow",
"ultralytics.utils.checks.check_requirements",
"ultralytics.utils.plotting.Annotator",
"shapely.geometry.LineString",
"shapely.geometry.Point",
"shapely.geometry.Polygon",
"shapely.prepared.prep",
"ultralytics.utils.downloads.safe_download"
],
"chunk_id": "class_SolutionResults_66279a46"
},
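SolutionResults is a plain container whose __str__ hides empty fields, which keeps per-frame logging short. A tiny hedged example with invented counts shows how a caller would read it back:

# Hedged illustration of SolutionResults; the counts below are made-up sample values.
from ultralytics.solutions.solutions import SolutionResults

results = SolutionResults(in_count=12, out_count=7, classwise_count={"car": 15, "bus": 4}, total_tracks=19)
print(results)                                # __str__ lists only non-empty attributes, plot_im excluded
print(results.in_count - results.out_count)   # e.g. net count of objects still inside the region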
{
"content": "from collections import deque",
"chunk_type": "import",
"name": "deque",
"file_path": "ultralytics\\ultralytics\\solutions\\speed_estimation.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_deque_a826132f"
},
{
"content": "from math import sqrt",
"chunk_type": "import",
"name": "sqrt",
"file_path": "ultralytics\\ultralytics\\solutions\\speed_estimation.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_sqrt_d4a1e012"
},
{
"content": "from typing import Any",
"chunk_type": "import",
"name": "Any",
"file_path": "ultralytics\\ultralytics\\solutions\\speed_estimation.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any_b82552cb"
},
{
"content": "from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults",
"chunk_type": "import",
"name": "BaseSolution, SolutionAnnotator, SolutionResults",
"file_path": "ultralytics\\ultralytics\\solutions\\speed_estimation.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 92,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BaseSolution, SolutionAnnotator, SolutionResults_a1387d74"
},
{
"content": "from ultralytics.utils.plotting import colors",
"chunk_type": "import",
"name": "colors",
"file_path": "ultralytics\\ultralytics\\solutions\\speed_estimation.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_colors_7dec7666"
},
{
"content": "class SpeedEstimator(BaseSolution):\n \"\"\"\n A class to estimate the speed of objects in a real-time video stream based on their tracks.\n\n This class extends the BaseSolution class and provides functionality for estimating object speeds using\n tracking data in video streams. Speed is calculated based on pixel displacement over time and converted\n to real-world units using a configurable meters-per-pixel scale factor.\n\n Attributes:\n fps (float): Video frame rate for time calculations.\n frame_count (int): Global frame counter for tracking temporal information.\n trk_frame_ids (dict): Maps track IDs to their first frame index.\n spd (dict): Final speed per object in km/h once locked.\n trk_hist (dict): Maps track IDs to deque of position history.\n locked_ids (set): Track IDs whose speed has been finalized.\n max_hist (int): Required frame history before computing speed.\n meter_per_pixel (float): Real-world meters represented by one pixel for scene scale conversion.\n max_speed (int): Maximum allowed object speed; values above this will be capped.\n\n Methods:\n process: Process input frames to estimate object speeds based on tracking data.\n store_tracking_history: Store the tracking history for an object.\n extract_tracks: Extract tracks from the current frame.\n display_output: Display the output with annotations.\n\n Examples:\n Initialize speed estimator and process a frame\n >>> estimator = SpeedEstimator(meter_per_pixel=0.04, max_speed=120)\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = estimator.process(frame)\n >>> cv2.imshow(\"Speed Estimation\", results.plot_im)\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"\n Initialize the SpeedEstimator object with speed estimation parameters and data structures.\n\n Args:\n **kwargs (Any): Additional keyword arguments passed to the parent class.\n \"\"\"\n super().__init__(**kwargs)\n\n self.fps = self.CFG[\"fps\"] # Video frame rate for time calculations\n self.frame_count = 0 # Global frame counter\n self.trk_frame_ids = {} # Track ID → first frame index\n self.spd = {} # Final speed per object (km/h), once locked\n self.trk_hist = {} # Track ID → deque of (time, position)\n self.locked_ids = set() # Track IDs whose speed has been finalized\n self.max_hist = self.CFG[\"max_hist\"] # Required frame history before computing speed\n self.meter_per_pixel = self.CFG[\"meter_per_pixel\"] # Scene scale, depends on camera details\n self.max_speed = self.CFG[\"max_speed\"] # Maximum speed adjustment\n\n def process(self, im0) -> SolutionResults:\n \"\"\"\n Process an input frame to estimate object speeds based on tracking data.\n\n Args:\n im0 (np.ndarray): Input image for processing with shape (H, W, C) for RGB images.\n\n Returns:\n (SolutionResults): Contains processed image `plot_im` and `total_tracks` (number of tracked objects).\n\n Examples:\n Process a frame for speed estimation\n >>> estimator = SpeedEstimator()\n >>> image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)\n >>> results = estimator.process(image)\n \"\"\"\n self.frame_count += 1\n self.extract_tracks(im0)\n annotator = SolutionAnnotator(im0, line_width=self.line_width)\n\n for box, track_id, _, _ in zip(self.boxes, self.track_ids, self.clss, self.confs):\n self.store_tracking_history(track_id, box)\n\n if track_id not in self.trk_hist: # Initialize history if new track found\n self.trk_hist[track_id] = deque(maxlen=self.max_hist)\n self.trk_frame_ids[track_id] = self.frame_count\n\n if track_id not in self.locked_ids: 
# Update history until speed is locked\n trk_hist = self.trk_hist[track_id]\n trk_hist.append(self.track_line[-1])\n\n # Compute and lock speed once enough history is collected\n if len(trk_hist) == self.max_hist:\n p0, p1 = trk_hist[0], trk_hist[-1] # First and last points of track\n dt = (self.frame_count - self.trk_frame_ids[track_id]) / self.fps # Time in seconds\n if dt > 0:\n dx, dy = p1[0] - p0[0], p1[1] - p0[1] # Pixel displacement\n pixel_distance = sqrt(dx * dx + dy * dy) # Calculate pixel distance\n meters = pixel_distance * self.meter_per_pixel # Convert to meters\n self.spd[track_id] = int(\n min((meters / dt) * 3.6, self.max_speed)\n ) # Convert to km/h and store final speed\n self.locked_ids.add(track_id) # Prevent further updates\n self.trk_hist.pop(track_id, None) # Free memory\n self.trk_frame_ids.pop(track_id, None) # Remove frame start reference\n\n if track_id in self.spd:\n speed_label = f\"{self.spd[track_id]} km/h\"\n annotator.box_label(box, label=speed_label, color=colors(track_id, True)) # Draw bounding box\n\n plot_im = annotator.result()\n self.display_output(plot_im) # Display output with base class function\n\n # Return results with processed image and tracking summary\n return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))",
"chunk_type": "class",
"name": "SpeedEstimator",
"file_path": "ultralytics\\ultralytics\\solutions\\speed_estimation.py",
"start_line": 11,
"end_line": 117,
"start_col": 0,
"end_col": 81,
"parent_name": null,
"docstring": "A class to estimate the speed of objects in a real-time video stream based on their tracks.\n\nThis class extends the BaseSolution class and provides functionality for estimating object speeds using\ntracking data in video streams. Speed is calculated based on pixel displacement over time and converted\nto real-world units using a configurable meters-per-pixel scale factor.\n\nAttributes:\n fps (float): Video frame rate for time calculations.\n frame_count (int): Global frame counter for tracking temporal information.\n trk_frame_ids (dict): Maps track IDs to their first frame index.\n spd (dict): Final speed per object in km/h once locked.\n trk_hist (dict): Maps track IDs to deque of position history.\n locked_ids (set): Track IDs whose speed has been finalized.\n max_hist (int): Required frame history before computing speed.\n meter_per_pixel (float): Real-world meters represented by one pixel for scene scale conversion.\n max_speed (int): Maximum allowed object speed; values above this will be capped.\n\nMethods:\n process: Process input frames to estimate object speeds based on tracking data.\n store_tracking_history: Store the tracking history for an object.\n extract_tracks: Extract tracks from the current frame.\n display_output: Display the output with annotations.\n\nExamples:\n Initialize speed estimator and process a frame\n >>> estimator = SpeedEstimator(meter_per_pixel=0.04, max_speed=120)\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = estimator.process(frame)\n >>> cv2.imshow(\"Speed Estimation\", results.plot_im)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"collections.deque",
"math.sqrt",
"typing.Any",
"ultralytics.solutions.solutions.BaseSolution",
"ultralytics.solutions.solutions.SolutionAnnotator",
"ultralytics.solutions.solutions.SolutionResults",
"ultralytics.utils.plotting.colors",
"BaseSolution"
],
"chunk_id": "class_SpeedEstimator_c026ecda"
},
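The core of SpeedEstimator is the conversion in process(): take the first and last points of a track once max_hist samples exist, turn the pixel displacement into metres with meter_per_pixel, divide by the elapsed time derived from the frame counter and fps, scale by 3.6 for km/h, and cap at max_speed. The standalone function below restates that arithmetic so it can be checked in isolation; the sample coordinates, fps, and scale factor are invented for illustration.

# Hedged sketch of the speed calculation performed once a track has max_hist points.
from math import sqrt

def speed_kmh(p0, p1, frames_elapsed, fps=30.0, meter_per_pixel=0.05, max_speed=120):
    """Convert a pixel displacement observed over frames_elapsed frames into km/h."""
    dt = frames_elapsed / fps                               # elapsed time in seconds
    pixel_distance = sqrt((p1[0] - p0[0]) ** 2 + (p1[1] - p0[1]) ** 2)
    meters = pixel_distance * meter_per_pixel               # scene-scale conversion
    return int(min((meters / dt) * 3.6, max_speed))         # m/s -> km/h, capped

# e.g. a centroid moving ~150 px in 5 frames at 30 FPS with a 0.05 m/px scale:
print(speed_kmh((100, 300), (250, 310), frames_elapsed=5))  # ~162 km/h before capping -> 120

Because the class locks the speed after this single computation and discards the track history, the estimate reflects the first max_hist frames of each track rather than a running average.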
{
"content": "import io",
"chunk_type": "import",
"name": "io",
"file_path": "ultralytics\\ultralytics\\solutions\\streamlit_inference.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_io_58d2a448"
},
{
"content": "from typing import Any, List",
"chunk_type": "import",
"name": "Any, List",
"file_path": "ultralytics\\ultralytics\\solutions\\streamlit_inference.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, List_035f5a5b"
},
{
"content": "import cv2",
"chunk_type": "import",
"name": "cv2",
"file_path": "ultralytics\\ultralytics\\solutions\\streamlit_inference.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 10,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_cv2_4567a341"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\solutions\\streamlit_inference.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_b8239f7b"
},
{
"content": "from ultralytics import YOLO",
"chunk_type": "import",
"name": "YOLO",
"file_path": "ultralytics\\ultralytics\\solutions\\streamlit_inference.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLO_2bd34118"
},
{
"content": "from ultralytics.utils import LOGGER",
"chunk_type": "import",
"name": "LOGGER",
"file_path": "ultralytics\\ultralytics\\solutions\\streamlit_inference.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 36,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER_47f3cb60"
},
{
"content": "from ultralytics.utils.checks import check_requirements",
"chunk_type": "import",
"name": "check_requirements",
"file_path": "ultralytics\\ultralytics\\solutions\\streamlit_inference.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 55,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_check_requirements_e1938d5b"
},
{
"content": "from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS",
"chunk_type": "import",
"name": "GITHUB_ASSETS_STEMS",
"file_path": "ultralytics\\ultralytics\\solutions\\streamlit_inference.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 59,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_GITHUB_ASSETS_STEMS_899663f0"
},
{
"content": "class Inference:\n \"\"\"\n A class to perform object detection, image classification, image segmentation and pose estimation inference.\n\n This class provides functionalities for loading models, configuring settings, uploading video files, and performing\n real-time inference using Streamlit and Ultralytics YOLO models.\n\n Attributes:\n st (module): Streamlit module for UI creation.\n temp_dict (dict): Temporary dictionary to store the model path and other configuration.\n model_path (str): Path to the loaded model.\n model (YOLO): The YOLO model instance.\n source (str): Selected video source (webcam or video file).\n enable_trk (bool): Enable tracking option.\n conf (float): Confidence threshold for detection.\n iou (float): IoU threshold for non-maximum suppression.\n org_frame (Any): Container for the original frame to be displayed.\n ann_frame (Any): Container for the annotated frame to be displayed.\n vid_file_name (str | int): Name of the uploaded video file or webcam index.\n selected_ind (List[int]): List of selected class indices for detection.\n\n Methods:\n web_ui: Set up the Streamlit web interface with custom HTML elements.\n sidebar: Configure the Streamlit sidebar for model and inference settings.\n source_upload: Handle video file uploads through the Streamlit interface.\n configure: Configure the model and load selected classes for inference.\n inference: Perform real-time object detection inference.\n\n Examples:\n Create an Inference instance with a custom model\n >>> inf = Inference(model=\"path/to/model.pt\")\n >>> inf.inference()\n\n Create an Inference instance with default settings\n >>> inf = Inference()\n >>> inf.inference()\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"\n Initialize the Inference class, checking Streamlit requirements and setting up the model path.\n\n Args:\n **kwargs (Any): Additional keyword arguments for model configuration.\n \"\"\"\n check_requirements(\"streamlit>=1.29.0\") # scope imports for faster ultralytics package load speeds\n import streamlit as st\n\n self.st = st # Reference to the Streamlit module\n self.source = None # Video source selection (webcam or video file)\n self.enable_trk = False # Flag to toggle object tracking\n self.conf = 0.25 # Confidence threshold for detection\n self.iou = 0.45 # Intersection-over-Union (IoU) threshold for non-maximum suppression\n self.org_frame = None # Container for the original frame display\n self.ann_frame = None # Container for the annotated frame display\n self.vid_file_name = None # Video file name or webcam index\n self.selected_ind: List[int] = [] # List of selected class indices for detection\n self.model = None # YOLO model instance\n\n self.temp_dict = {\"model\": None, **kwargs}\n self.model_path = None # Model file path\n if self.temp_dict[\"model\"] is not None:\n self.model_path = self.temp_dict[\"model\"]\n\n LOGGER.info(f\"Ultralytics Solutions: ✅ {self.temp_dict}\")\n\n def web_ui(self) -> None:\n \"\"\"Set up the Streamlit web interface with custom HTML elements.\"\"\"\n menu_style_cfg = \"\"\"\"\"\" # Hide main menu style\n\n # Main title of streamlit application\n main_title_cfg = \"\"\"Ultralytics YOLO Streamlit Application
\"\"\"\n\n # Subtitle of streamlit application\n sub_title_cfg = \"\"\"Experience real-time object detection on your webcam with the power \n of Ultralytics YOLO! 🚀
\"\"\"\n\n # Set html page configuration and append custom HTML\n self.st.set_page_config(page_title=\"Ultralytics Streamlit App\", layout=\"wide\")\n self.st.markdown(menu_style_cfg, unsafe_allow_html=True)\n self.st.markdown(main_title_cfg, unsafe_allow_html=True)\n self.st.markdown(sub_title_cfg, unsafe_allow_html=True)\n\n def sidebar(self) -> None:\n \"\"\"Configure the Streamlit sidebar for model and inference settings.\"\"\"\n with self.st.sidebar: # Add Ultralytics LOGO\n logo = \"https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg\"\n self.st.image(logo, width=250)\n\n self.st.sidebar.title(\"User Configuration\") # Add elements to vertical setting menu\n self.source = self.st.sidebar.selectbox(\n \"Video\",\n (\"webcam\", \"video\"),\n ) # Add source selection dropdown\n self.enable_trk = self.st.sidebar.radio(\"Enable Tracking\", (\"Yes\", \"No\")) == \"Yes\" # Enable object tracking\n self.conf = float(\n self.st.sidebar.slider(\"Confidence Threshold\", 0.0, 1.0, self.conf, 0.01)\n ) # Slider for confidence\n self.iou = float(self.st.sidebar.slider(\"IoU Threshold\", 0.0, 1.0, self.iou, 0.01)) # Slider for NMS threshold\n\n col1, col2 = self.st.columns(2) # Create two columns for displaying frames\n self.org_frame = col1.empty() # Container for original frame\n self.ann_frame = col2.empty() # Container for annotated frame\n\n def source_upload(self) -> None:\n \"\"\"Handle video file uploads through the Streamlit interface.\"\"\"\n self.vid_file_name = \"\"\n if self.source == \"video\":\n vid_file = self.st.sidebar.file_uploader(\"Upload Video File\", type=[\"mp4\", \"mov\", \"avi\", \"mkv\"])\n if vid_file is not None:\n g = io.BytesIO(vid_file.read()) # BytesIO Object\n with open(\"ultralytics.mp4\", \"wb\") as out: # Open temporary file as bytes\n out.write(g.read()) # Read bytes into file\n self.vid_file_name = \"ultralytics.mp4\"\n elif self.source == \"webcam\":\n self.vid_file_name = 0 # Use webcam index 0\n\n def configure(self) -> None:\n \"\"\"Configure the model and load selected classes for inference.\"\"\"\n # Add dropdown menu for model selection\n M_ORD, T_ORD = [\"yolo11n\", \"yolo11s\", \"yolo11m\", \"yolo11l\", \"yolo11x\"], [\"\", \"-seg\", \"-pose\", \"-obb\", \"-cls\"]\n available_models = sorted(\n [\n x.replace(\"yolo\", \"YOLO\")\n for x in GITHUB_ASSETS_STEMS\n if any(x.startswith(b) for b in M_ORD) and \"grayscale\" not in x\n ],\n key=lambda x: (M_ORD.index(x[:7].lower()), T_ORD.index(x[7:].lower() or \"\")),\n )\n if self.model_path: # If user provided the custom model, insert model without suffix as *.pt is added later\n available_models.insert(0, self.model_path.split(\".pt\", 1)[0])\n selected_model = self.st.sidebar.selectbox(\"Model\", available_models)\n\n with self.st.spinner(\"Model is downloading...\"):\n self.model = YOLO(f\"{selected_model.lower()}.pt\") # Load the YOLO model\n class_names = list(self.model.names.values()) # Convert dictionary to list of class names\n self.st.success(\"Model loaded successfully!\")\n\n # Multiselect box with class names and get indices of selected classes\n selected_classes = self.st.sidebar.multiselect(\"Classes\", class_names, default=class_names[:3])\n self.selected_ind = [class_names.index(option) for option in selected_classes]\n\n if not isinstance(self.selected_ind, list): # Ensure selected_options is a list\n self.selected_ind = list(self.selected_ind)\n\n def inference(self) -> None:\n \"\"\"Perform real-time object detection inference on video or 
webcam feed.\"\"\"\n self.web_ui() # Initialize the web interface\n self.sidebar() # Create the sidebar\n self.source_upload() # Upload the video source\n self.configure() # Configure the app\n\n if self.st.sidebar.button(\"Start\"):\n stop_button = self.st.button(\"Stop\") # Button to stop the inference\n cap = cv2.VideoCapture(self.vid_file_name) # Capture the video\n if not cap.isOpened():\n self.st.error(\"Could not open webcam or video source.\")\n return\n\n while cap.isOpened():\n success, frame = cap.read()\n if not success:\n self.st.warning(\"Failed to read frame from webcam. Please verify the webcam is connected properly.\")\n break\n\n # Process frame with model\n if self.enable_trk:\n results = self.model.track(\n frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True\n )\n else:\n results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind)\n\n annotated_frame = results[0].plot() # Add annotations on frame\n\n if stop_button:\n cap.release() # Release the capture\n self.st.stop() # Stop streamlit app\n\n self.org_frame.image(frame, channels=\"BGR\") # Display original frame\n self.ann_frame.image(annotated_frame, channels=\"BGR\") # Display processed frame\n\n cap.release() # Release the capture\n cv2.destroyAllWindows() # Destroy all OpenCV windows",
"chunk_type": "class",
"name": "Inference",
"file_path": "ultralytics\\ultralytics\\solutions\\streamlit_inference.py",
"start_line": 17,
"end_line": 202,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": "A class to perform object detection, image classification, image segmentation and pose estimation inference.\n\nThis class provides functionalities for loading models, configuring settings, uploading video files, and performing\nreal-time inference using Streamlit and Ultralytics YOLO models.\n\nAttributes:\n st (module): Streamlit module for UI creation.\n temp_dict (dict): Temporary dictionary to store the model path and other configuration.\n model_path (str): Path to the loaded model.\n model (YOLO): The YOLO model instance.\n source (str): Selected video source (webcam or video file).\n enable_trk (bool): Enable tracking option.\n conf (float): Confidence threshold for detection.\n iou (float): IoU threshold for non-maximum suppression.\n org_frame (Any): Container for the original frame to be displayed.\n ann_frame (Any): Container for the annotated frame to be displayed.\n vid_file_name (str | int): Name of the uploaded video file or webcam index.\n selected_ind (List[int]): List of selected class indices for detection.\n\nMethods:\n web_ui: Set up the Streamlit web interface with custom HTML elements.\n sidebar: Configure the Streamlit sidebar for model and inference settings.\n source_upload: Handle video file uploads through the Streamlit interface.\n configure: Configure the model and load selected classes for inference.\n inference: Perform real-time object detection inference.\n\nExamples:\n Create an Inference instance with a custom model\n >>> inf = Inference(model=\"path/to/model.pt\")\n >>> inf.inference()\n\n Create an Inference instance with default settings\n >>> inf = Inference()\n >>> inf.inference()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"io",
"typing.Any",
"typing.List",
"cv2",
"torch",
"ultralytics.YOLO",
"ultralytics.utils.LOGGER",
"ultralytics.utils.checks.check_requirements",
"ultralytics.utils.downloads.GITHUB_ASSETS_STEMS",
"sys",
"streamlit"
],
"chunk_id": "class_Inference_9c8db937"
},
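The Inference chunk above chains web_ui, sidebar, source_upload and configure before entering the capture loop. A minimal launch sketch, based on the class docstring example: the weight path is a placeholder, and the script is assumed to be started with `streamlit run` so the Streamlit context exists.

```python
# Hedged sketch: launch the Ultralytics Streamlit inference app.
# Assumes the file is executed via `streamlit run app.py`; "yolo11n.pt" is a placeholder model path.
from ultralytics.solutions import Inference

app = Inference(model="yolo11n.pt")  # custom model is optional; it is offered first in the sidebar dropdown
app.inference()  # builds the UI, then runs detection/tracking on webcam or an uploaded video after Start
```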
{
"content": "from typing import Any",
"chunk_type": "import",
"name": "Any",
"file_path": "ultralytics\\ultralytics\\solutions\\trackzone.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any_1cb3a850"
},
{
"content": "import cv2",
"chunk_type": "import",
"name": "cv2",
"file_path": "ultralytics\\ultralytics\\solutions\\trackzone.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 10,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_cv2_ce09cae2"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\solutions\\trackzone.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_87cb7acd"
},
{
"content": "from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults",
"chunk_type": "import",
"name": "BaseSolution, SolutionAnnotator, SolutionResults",
"file_path": "ultralytics\\ultralytics\\solutions\\trackzone.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 92,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BaseSolution, SolutionAnnotator, SolutionResults_8b131528"
},
{
"content": "from ultralytics.utils.plotting import colors",
"chunk_type": "import",
"name": "colors",
"file_path": "ultralytics\\ultralytics\\solutions\\trackzone.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_colors_ecd6be4a"
},
{
"content": "class TrackZone(BaseSolution):\n \"\"\"\n A class to manage region-based object tracking in a video stream.\n\n This class extends the BaseSolution class and provides functionality for tracking objects within a specific region\n defined by a polygonal area. Objects outside the region are excluded from tracking.\n\n Attributes:\n region (np.ndarray): The polygonal region for tracking, represented as a convex hull of points.\n line_width (int): Width of the lines used for drawing bounding boxes and region boundaries.\n names (List[str]): List of class names that the model can detect.\n boxes (List[np.ndarray]): Bounding boxes of tracked objects.\n track_ids (List[int]): Unique identifiers for each tracked object.\n clss (List[int]): Class indices of tracked objects.\n\n Methods:\n process: Process each frame of the video, applying region-based tracking.\n extract_tracks: Extract tracking information from the input frame.\n display_output: Display the processed output.\n\n Examples:\n >>> tracker = TrackZone()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = tracker.process(frame)\n >>> cv2.imshow(\"Tracked Frame\", results.plot_im)\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"\n Initialize the TrackZone class for tracking objects within a defined region in video streams.\n\n Args:\n **kwargs (Any): Additional keyword arguments passed to the parent class.\n \"\"\"\n super().__init__(**kwargs)\n default_region = [(75, 75), (565, 75), (565, 285), (75, 285)]\n self.region = cv2.convexHull(np.array(self.region or default_region, dtype=np.int32))\n self.mask = None\n\n def process(self, im0: np.ndarray) -> SolutionResults:\n \"\"\"\n Process the input frame to track objects within a defined region.\n\n This method initializes the annotator, creates a mask for the specified region, extracts tracks\n only from the masked area, and updates tracking information. Objects outside the region are ignored.\n\n Args:\n im0 (np.ndarray): The input image or frame to be processed.\n\n Returns:\n (SolutionResults): Contains processed image `plot_im` and `total_tracks` (int) representing the\n total number of tracked objects within the defined region.\n\n Examples:\n >>> tracker = TrackZone()\n >>> frame = cv2.imread(\"path/to/image.jpg\")\n >>> results = tracker.process(frame)\n \"\"\"\n annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator\n\n if self.mask is None: # Create a mask for the region\n self.mask = np.zeros_like(im0[:, :, 0])\n cv2.fillPoly(self.mask, [self.region], 255)\n masked_frame = cv2.bitwise_and(im0, im0, mask=self.mask)\n self.extract_tracks(masked_frame)\n\n # Draw the region boundary\n cv2.polylines(im0, [self.region], isClosed=True, color=(255, 255, 255), thickness=self.line_width * 2)\n\n # Iterate over boxes, track ids, classes indexes list and draw bounding boxes\n for box, track_id, cls, conf in zip(self.boxes, self.track_ids, self.clss, self.confs):\n annotator.box_label(\n box, label=self.adjust_box_label(cls, conf, track_id=track_id), color=colors(track_id, True)\n )\n\n plot_im = annotator.result()\n self.display_output(plot_im) # Display output with base class function\n\n # Return a SolutionResults\n return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))",
"chunk_type": "class",
"name": "TrackZone",
"file_path": "ultralytics\\ultralytics\\solutions\\trackzone.py",
"start_line": 12,
"end_line": 91,
"start_col": 0,
"end_col": 81,
"parent_name": null,
"docstring": "A class to manage region-based object tracking in a video stream.\n\nThis class extends the BaseSolution class and provides functionality for tracking objects within a specific region\ndefined by a polygonal area. Objects outside the region are excluded from tracking.\n\nAttributes:\n region (np.ndarray): The polygonal region for tracking, represented as a convex hull of points.\n line_width (int): Width of the lines used for drawing bounding boxes and region boundaries.\n names (List[str]): List of class names that the model can detect.\n boxes (List[np.ndarray]): Bounding boxes of tracked objects.\n track_ids (List[int]): Unique identifiers for each tracked object.\n clss (List[int]): Class indices of tracked objects.\n\nMethods:\n process: Process each frame of the video, applying region-based tracking.\n extract_tracks: Extract tracking information from the input frame.\n display_output: Display the processed output.\n\nExamples:\n >>> tracker = TrackZone()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = tracker.process(frame)\n >>> cv2.imshow(\"Tracked Frame\", results.plot_im)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"cv2",
"numpy",
"ultralytics.solutions.solutions.BaseSolution",
"ultralytics.solutions.solutions.SolutionAnnotator",
"ultralytics.solutions.solutions.SolutionResults",
"ultralytics.utils.plotting.colors",
"BaseSolution"
],
"chunk_id": "class_TrackZone_d7ec1fe6"
},
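A usage sketch for TrackZone adapted from its docstring example. Passing `region=` as a keyword is an assumption based on `self.region or default_region` in `__init__` (the value presumably arrives through the BaseSolution config); the image path is a placeholder.

```python
import cv2

from ultralytics.solutions import TrackZone

# region= is assumed to be forwarded by BaseSolution; omit it to fall back to the
# built-in 4-point default rectangle.
tracker = TrackZone(region=[(75, 75), (565, 75), (565, 285), (75, 285)])

frame = cv2.imread("frame.jpg")           # placeholder path
results = tracker.process(frame)          # only objects inside the region are tracked
print(results.total_tracks)               # number of tracks inside the zone
cv2.imshow("Tracked Frame", results.plot_im)
cv2.waitKey(0)
```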
{
"content": "from typing import Any",
"chunk_type": "import",
"name": "Any",
"file_path": "ultralytics\\ultralytics\\solutions\\vision_eye.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any_c71ce5ab"
},
{
"content": "from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults",
"chunk_type": "import",
"name": "BaseSolution, SolutionAnnotator, SolutionResults",
"file_path": "ultralytics\\ultralytics\\solutions\\vision_eye.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 92,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BaseSolution, SolutionAnnotator, SolutionResults_61009772"
},
{
"content": "from ultralytics.utils.plotting import colors",
"chunk_type": "import",
"name": "colors",
"file_path": "ultralytics\\ultralytics\\solutions\\vision_eye.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_colors_6e3f8bf7"
},
{
"content": "class VisionEye(BaseSolution):\n \"\"\"\n A class to manage object detection and vision mapping in images or video streams.\n\n This class extends the BaseSolution class and provides functionality for detecting objects,\n mapping vision points, and annotating results with bounding boxes and labels.\n\n Attributes:\n vision_point (Tuple[int, int]): Coordinates (x, y) where vision will view objects and draw tracks.\n\n Methods:\n process: Process the input image to detect objects, annotate them, and apply vision mapping.\n\n Examples:\n >>> vision_eye = VisionEye()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = vision_eye.process(frame)\n >>> print(f\"Total detected instances: {results.total_tracks}\")\n \"\"\"\n\n def __init__(self, **kwargs: Any) -> None:\n \"\"\"\n Initialize the VisionEye class for detecting objects and applying vision mapping.\n\n Args:\n **kwargs (Any): Keyword arguments passed to the parent class and for configuring vision_point.\n \"\"\"\n super().__init__(**kwargs)\n # Set the vision point where the system will view objects and draw tracks\n self.vision_point = self.CFG[\"vision_point\"]\n\n def process(self, im0) -> SolutionResults:\n \"\"\"\n Perform object detection, vision mapping, and annotation on the input image.\n\n Args:\n im0 (np.ndarray): The input image for detection and annotation.\n\n Returns:\n (SolutionResults): Object containing the annotated image and tracking statistics.\n - plot_im: Annotated output image with bounding boxes and vision mapping\n - total_tracks: Number of tracked objects in the frame\n\n Examples:\n >>> vision_eye = VisionEye()\n >>> frame = cv2.imread(\"image.jpg\")\n >>> results = vision_eye.process(frame)\n >>> print(f\"Detected {results.total_tracks} objects\")\n \"\"\"\n self.extract_tracks(im0) # Extract tracks (bounding boxes, classes, and masks)\n annotator = SolutionAnnotator(im0, self.line_width)\n\n for cls, t_id, box, conf in zip(self.clss, self.track_ids, self.boxes, self.confs):\n # Annotate the image with bounding boxes, labels, and vision mapping\n annotator.box_label(box, label=self.adjust_box_label(cls, conf, t_id), color=colors(int(t_id), True))\n annotator.visioneye(box, self.vision_point)\n\n plot_im = annotator.result()\n self.display_output(plot_im) # Display the annotated output using the base class function\n\n # Return a SolutionResults object with the annotated image and tracking statistics\n return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))",
"chunk_type": "class",
"name": "VisionEye",
"file_path": "ultralytics\\ultralytics\\solutions\\vision_eye.py",
"start_line": 9,
"end_line": 70,
"start_col": 0,
"end_col": 81,
"parent_name": null,
"docstring": "A class to manage object detection and vision mapping in images or video streams.\n\nThis class extends the BaseSolution class and provides functionality for detecting objects,\nmapping vision points, and annotating results with bounding boxes and labels.\n\nAttributes:\n vision_point (Tuple[int, int]): Coordinates (x, y) where vision will view objects and draw tracks.\n\nMethods:\n process: Process the input image to detect objects, annotate them, and apply vision mapping.\n\nExamples:\n >>> vision_eye = VisionEye()\n >>> frame = cv2.imread(\"frame.jpg\")\n >>> results = vision_eye.process(frame)\n >>> print(f\"Total detected instances: {results.total_tracks}\")",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"ultralytics.solutions.solutions.BaseSolution",
"ultralytics.solutions.solutions.SolutionAnnotator",
"ultralytics.solutions.solutions.SolutionResults",
"ultralytics.utils.plotting.colors",
"BaseSolution"
],
"chunk_id": "class_VisionEye_8a71f3f3"
},
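A usage sketch for VisionEye following its docstring example. The `vision_point=(x, y)` keyword is an assumption inferred from `self.CFG["vision_point"]`; the frame path is a placeholder.

```python
import cv2

from ultralytics.solutions import VisionEye

# vision_point is read from CFG; passing it as a kwarg is assumed to populate that entry.
vision_eye = VisionEye(vision_point=(50, 50))

frame = cv2.imread("frame.jpg")      # placeholder path
results = vision_eye.process(frame)  # boxes, labels and vision-eye lines drawn on the frame
print(f"Total detected instances: {results.total_tracks}")
```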
{
"content": "from .ai_gym import AIGym",
"chunk_type": "import",
"name": "AIGym",
"file_path": "ultralytics\\ultralytics\\solutions\\__init__.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_AIGym_c56e4460"
},
{
"content": "from .analytics import Analytics",
"chunk_type": "import",
"name": "Analytics",
"file_path": "ultralytics\\ultralytics\\solutions\\__init__.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 32,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Analytics_d00bee54"
},
{
"content": "from .distance_calculation import DistanceCalculation",
"chunk_type": "import",
"name": "DistanceCalculation",
"file_path": "ultralytics\\ultralytics\\solutions\\__init__.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DistanceCalculation_91f34e07"
},
{
"content": "from .heatmap import Heatmap",
"chunk_type": "import",
"name": "Heatmap",
"file_path": "ultralytics\\ultralytics\\solutions\\__init__.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Heatmap_08c5d6f2"
},
{
"content": "from .instance_segmentation import InstanceSegmentation",
"chunk_type": "import",
"name": "InstanceSegmentation",
"file_path": "ultralytics\\ultralytics\\solutions\\__init__.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 55,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_InstanceSegmentation_4f5ecc69"
},
{
"content": "from .object_blurrer import ObjectBlurrer",
"chunk_type": "import",
"name": "ObjectBlurrer",
"file_path": "ultralytics\\ultralytics\\solutions\\__init__.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ObjectBlurrer_e71acde0"
},
{
"content": "from .object_counter import ObjectCounter",
"chunk_type": "import",
"name": "ObjectCounter",
"file_path": "ultralytics\\ultralytics\\solutions\\__init__.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ObjectCounter_a05e4b23"
},
{
"content": "from .object_cropper import ObjectCropper",
"chunk_type": "import",
"name": "ObjectCropper",
"file_path": "ultralytics\\ultralytics\\solutions\\__init__.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ObjectCropper_0deaa30e"
},
{
"content": "from .parking_management import ParkingManagement, ParkingPtsSelection",
"chunk_type": "import",
"name": "ParkingManagement, ParkingPtsSelection",
"file_path": "ultralytics\\ultralytics\\solutions\\__init__.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 70,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ParkingManagement, ParkingPtsSelection_29d30cd2"
},
{
"content": "from .queue_management import QueueManager",
"chunk_type": "import",
"name": "QueueManager",
"file_path": "ultralytics\\ultralytics\\solutions\\__init__.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_QueueManager_faed8198"
},
{
"content": "from .region_counter import RegionCounter",
"chunk_type": "import",
"name": "RegionCounter",
"file_path": "ultralytics\\ultralytics\\solutions\\__init__.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_RegionCounter_0bb865cd"
},
{
"content": "from .security_alarm import SecurityAlarm",
"chunk_type": "import",
"name": "SecurityAlarm",
"file_path": "ultralytics\\ultralytics\\solutions\\__init__.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SecurityAlarm_e3cdaa74"
},
{
"content": "from .similarity_search import SearchApp, VisualAISearch",
"chunk_type": "import",
"name": "SearchApp, VisualAISearch",
"file_path": "ultralytics\\ultralytics\\solutions\\__init__.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 56,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SearchApp, VisualAISearch_0adcf032"
},
{
"content": "from .speed_estimation import SpeedEstimator",
"chunk_type": "import",
"name": "SpeedEstimator",
"file_path": "ultralytics\\ultralytics\\solutions\\__init__.py",
"start_line": 16,
"end_line": 16,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SpeedEstimator_5abd8870"
},
{
"content": "from .streamlit_inference import Inference",
"chunk_type": "import",
"name": "Inference",
"file_path": "ultralytics\\ultralytics\\solutions\\__init__.py",
"start_line": 17,
"end_line": 17,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Inference_c613e25e"
},
{
"content": "from .trackzone import TrackZone",
"chunk_type": "import",
"name": "TrackZone",
"file_path": "ultralytics\\ultralytics\\solutions\\__init__.py",
"start_line": 18,
"end_line": 18,
"start_col": 0,
"end_col": 32,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_TrackZone_586fe219"
},
{
"content": "from .vision_eye import VisionEye",
"chunk_type": "import",
"name": "VisionEye",
"file_path": "ultralytics\\ultralytics\\solutions\\__init__.py",
"start_line": 19,
"end_line": 19,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_VisionEye_37d08a7d"
},
{
"content": "__all__ = (\n \"ObjectCounter\",\n \"ObjectCropper\",\n \"ObjectBlurrer\",\n \"AIGym\",\n \"RegionCounter\",\n \"SecurityAlarm\",\n \"Heatmap\",\n \"InstanceSegmentation\",\n \"VisionEye\",\n \"SpeedEstimator\",\n \"DistanceCalculation\",\n \"QueueManager\",\n \"ParkingManagement\",\n \"ParkingPtsSelection\",\n \"Analytics\",\n \"Inference\",\n \"TrackZone\",\n \"SearchApp\",\n \"VisualAISearch\",\n)",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\solutions\\__init__.py",
"start_line": 21,
"end_line": 41,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___5c311cc6"
},
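Everything listed in `__all__` above is re-exported from the package root, so solutions can be imported directly from `ultralytics.solutions`; a small illustrative import using only names exported here:

```python
# All of these names appear in __all__ of ultralytics/solutions/__init__.py.
from ultralytics.solutions import (
    AIGym,
    Heatmap,
    ObjectCounter,
    ParkingPtsSelection,
    TrackZone,
    VisionEye,
)

print([cls.__name__ for cls in (AIGym, Heatmap, ObjectCounter, ParkingPtsSelection, TrackZone, VisionEye)])
```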
{
"content": "from collections import OrderedDict",
"chunk_type": "import",
"name": "OrderedDict",
"file_path": "ultralytics\\ultralytics\\trackers\\basetrack.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_OrderedDict_aa0536c1"
},
{
"content": "from typing import Any",
"chunk_type": "import",
"name": "Any",
"file_path": "ultralytics\\ultralytics\\trackers\\basetrack.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any_879757ea"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\trackers\\basetrack.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_8c9fb52b"
},
{
"content": "class TrackState:\n \"\"\"\n Enumeration class representing the possible states of an object being tracked.\n\n Attributes:\n New (int): State when the object is newly detected.\n Tracked (int): State when the object is successfully tracked in subsequent frames.\n Lost (int): State when the object is no longer tracked.\n Removed (int): State when the object is removed from tracking.\n\n Examples:\n >>> state = TrackState.New\n >>> if state == TrackState.New:\n >>> print(\"Object is newly detected.\")\n \"\"\"\n\n New = 0\n Tracked = 1\n Lost = 2\n Removed = 3",
"chunk_type": "class",
"name": "TrackState",
"file_path": "ultralytics\\ultralytics\\trackers\\basetrack.py",
"start_line": 10,
"end_line": 29,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": "Enumeration class representing the possible states of an object being tracked.\n\nAttributes:\n New (int): State when the object is newly detected.\n Tracked (int): State when the object is successfully tracked in subsequent frames.\n Lost (int): State when the object is no longer tracked.\n Removed (int): State when the object is removed from tracking.\n\nExamples:\n >>> state = TrackState.New\n >>> if state == TrackState.New:\n >>> print(\"Object is newly detected.\")",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"collections.OrderedDict",
"typing.Any",
"numpy"
],
"chunk_id": "class_TrackState_516af351"
},
{
"content": "class BaseTrack:\n \"\"\"\n Base class for object tracking, providing foundational attributes and methods.\n\n Attributes:\n _count (int): Class-level counter for unique track IDs.\n track_id (int): Unique identifier for the track.\n is_activated (bool): Flag indicating whether the track is currently active.\n state (TrackState): Current state of the track.\n history (OrderedDict): Ordered history of the track's states.\n features (list): List of features extracted from the object for tracking.\n curr_feature (Any): The current feature of the object being tracked.\n score (float): The confidence score of the tracking.\n start_frame (int): The frame number where tracking started.\n frame_id (int): The most recent frame ID processed by the track.\n time_since_update (int): Frames passed since the last update.\n location (tuple): The location of the object in the context of multi-camera tracking.\n\n Methods:\n end_frame: Returns the ID of the last frame where the object was tracked.\n next_id: Increments and returns the next global track ID.\n activate: Abstract method to activate the track.\n predict: Abstract method to predict the next state of the track.\n update: Abstract method to update the track with new data.\n mark_lost: Marks the track as lost.\n mark_removed: Marks the track as removed.\n reset_id: Resets the global track ID counter.\n\n Examples:\n Initialize a new track and mark it as lost:\n >>> track = BaseTrack()\n >>> track.mark_lost()\n >>> print(track.state) # Output: 2 (TrackState.Lost)\n \"\"\"\n\n _count = 0\n\n def __init__(self):\n \"\"\"Initialize a new track with a unique ID and foundational tracking attributes.\"\"\"\n self.track_id = 0\n self.is_activated = False\n self.state = TrackState.New\n self.history = OrderedDict()\n self.features = []\n self.curr_feature = None\n self.score = 0\n self.start_frame = 0\n self.frame_id = 0\n self.time_since_update = 0\n self.location = (np.inf, np.inf)\n\n @property\n def end_frame(self) -> int:\n \"\"\"Return the ID of the most recent frame where the object was tracked.\"\"\"\n return self.frame_id\n\n @staticmethod\n def next_id() -> int:\n \"\"\"Increment and return the next unique global track ID for object tracking.\"\"\"\n BaseTrack._count += 1\n return BaseTrack._count\n\n def activate(self, *args: Any) -> None:\n \"\"\"Activate the track with provided arguments, initializing necessary attributes for tracking.\"\"\"\n raise NotImplementedError\n\n def predict(self) -> None:\n \"\"\"Predict the next state of the track based on the current state and tracking model.\"\"\"\n raise NotImplementedError\n\n def update(self, *args: Any, **kwargs: Any) -> None:\n \"\"\"Update the track with new observations and data, modifying its state and attributes accordingly.\"\"\"\n raise NotImplementedError\n\n def mark_lost(self) -> None:\n \"\"\"Mark the track as lost by updating its state to TrackState.Lost.\"\"\"\n self.state = TrackState.Lost\n\n def mark_removed(self) -> None:\n \"\"\"Mark the track as removed by setting its state to TrackState.Removed.\"\"\"\n self.state = TrackState.Removed\n\n @staticmethod\n def reset_id() -> None:\n \"\"\"Reset the global track ID counter to its initial value.\"\"\"\n BaseTrack._count = 0",
"chunk_type": "class",
"name": "BaseTrack",
"file_path": "ultralytics\\ultralytics\\trackers\\basetrack.py",
"start_line": 32,
"end_line": 117,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": "Base class for object tracking, providing foundational attributes and methods.\n\nAttributes:\n _count (int): Class-level counter for unique track IDs.\n track_id (int): Unique identifier for the track.\n is_activated (bool): Flag indicating whether the track is currently active.\n state (TrackState): Current state of the track.\n history (OrderedDict): Ordered history of the track's states.\n features (list): List of features extracted from the object for tracking.\n curr_feature (Any): The current feature of the object being tracked.\n score (float): The confidence score of the tracking.\n start_frame (int): The frame number where tracking started.\n frame_id (int): The most recent frame ID processed by the track.\n time_since_update (int): Frames passed since the last update.\n location (tuple): The location of the object in the context of multi-camera tracking.\n\nMethods:\n end_frame: Returns the ID of the last frame where the object was tracked.\n next_id: Increments and returns the next global track ID.\n activate: Abstract method to activate the track.\n predict: Abstract method to predict the next state of the track.\n update: Abstract method to update the track with new data.\n mark_lost: Marks the track as lost.\n mark_removed: Marks the track as removed.\n reset_id: Resets the global track ID counter.\n\nExamples:\n Initialize a new track and mark it as lost:\n >>> track = BaseTrack()\n >>> track.mark_lost()\n >>> print(track.state) # Output: 2 (TrackState.Lost)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"collections.OrderedDict",
"typing.Any",
"numpy"
],
"chunk_id": "class_BaseTrack_9a34aaf0"
},
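BaseTrack deliberately leaves `activate`, `predict`, and `update` as NotImplementedError stubs; concrete trackers such as STrack fill them in. A hedged sketch of the subclassing pattern, with an illustrative subclass name and trivial bodies:

```python
from ultralytics.trackers.basetrack import BaseTrack, TrackState

class DummyTrack(BaseTrack):
    """Illustrative subclass; a real tracker (e.g. STrack) would hold Kalman state here."""

    def activate(self, frame_id: int) -> None:
        self.track_id = self.next_id()               # draw a fresh global ID
        self.start_frame = self.frame_id = frame_id
        self.state = TrackState.Tracked

    def predict(self) -> None:
        pass  # no motion model in this sketch

    def update(self, *args, **kwargs) -> None:
        self.time_since_update = 0

track = DummyTrack()
track.activate(frame_id=1)
track.mark_lost()
print(track.track_id, track.state == TrackState.Lost)  # e.g. "1 True"
```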
{
"content": "from collections import deque",
"chunk_type": "import",
"name": "deque",
"file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_deque_9cdabf00"
},
{
"content": "from typing import Any, List, Optional",
"chunk_type": "import",
"name": "Any, List, Optional",
"file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 38,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, List, Optional_da76344b"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_b77088f8"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_375aabac"
},
{
"content": "from ultralytics.utils.ops import xywh2xyxy",
"chunk_type": "import",
"name": "xywh2xyxy",
"file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 43,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_xywh2xyxy_e1e81eaf"
},
{
"content": "from ultralytics.utils.plotting import save_one_box",
"chunk_type": "import",
"name": "save_one_box",
"file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 51,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_save_one_box_fa1c4742"
},
{
"content": "from .basetrack import TrackState",
"chunk_type": "import",
"name": "TrackState",
"file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_TrackState_fb6e3ae9"
},
{
"content": "from .byte_tracker import BYTETracker, STrack",
"chunk_type": "import",
"name": "BYTETracker, STrack",
"file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BYTETracker, STrack_99380cc4"
},
{
"content": "from .utils import matching",
"chunk_type": "import",
"name": "matching",
"file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 27,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_matching_ece3f4dd"
},
{
"content": "from .utils.gmc import GMC",
"chunk_type": "import",
"name": "GMC",
"file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 26,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_GMC_35d930cc"
},
{
"content": "from .utils.kalman_filter import KalmanFilterXYWH",
"chunk_type": "import",
"name": "KalmanFilterXYWH",
"file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py",
"start_line": 16,
"end_line": 16,
"start_col": 0,
"end_col": 49,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_KalmanFilterXYWH_23048d5d"
},
{
"content": "class BOTrack(STrack):\n \"\"\"\n An extended version of the STrack class for YOLO, adding object tracking features.\n\n This class extends the STrack class to include additional functionalities for object tracking, such as feature\n smoothing, Kalman filter prediction, and reactivation of tracks.\n\n Attributes:\n shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack.\n smooth_feat (np.ndarray): Smoothed feature vector.\n curr_feat (np.ndarray): Current feature vector.\n features (deque): A deque to store feature vectors with a maximum length defined by `feat_history`.\n alpha (float): Smoothing factor for the exponential moving average of features.\n mean (np.ndarray): The mean state of the Kalman filter.\n covariance (np.ndarray): The covariance matrix of the Kalman filter.\n\n Methods:\n update_features: Update features vector and smooth it using exponential moving average.\n predict: Predict the mean and covariance using Kalman filter.\n re_activate: Reactivate a track with updated features and optionally new ID.\n update: Update the track with new detection and frame ID.\n tlwh: Property that gets the current position in tlwh format `(top left x, top left y, width, height)`.\n multi_predict: Predict the mean and covariance of multiple object tracks using shared Kalman filter.\n convert_coords: Convert tlwh bounding box coordinates to xywh format.\n tlwh_to_xywh: Convert bounding box to xywh format `(center x, center y, width, height)`.\n\n Examples:\n Create a BOTrack instance and update its features\n >>> bo_track = BOTrack(tlwh=[100, 50, 80, 40], score=0.9, cls=1, feat=np.random.rand(128))\n >>> bo_track.predict()\n >>> new_track = BOTrack(tlwh=[110, 60, 80, 40], score=0.85, cls=1, feat=np.random.rand(128))\n >>> bo_track.update(new_track, frame_id=2)\n \"\"\"\n\n shared_kalman = KalmanFilterXYWH()\n\n def __init__(\n self, tlwh: np.ndarray, score: float, cls: int, feat: Optional[np.ndarray] = None, feat_history: int = 50\n ):\n \"\"\"\n Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.\n\n Args:\n tlwh (np.ndarray): Bounding box coordinates in tlwh format (top left x, top left y, width, height).\n score (float): Confidence score of the detection.\n cls (int): Class ID of the detected object.\n feat (np.ndarray, optional): Feature vector associated with the detection.\n feat_history (int): Maximum length of the feature history deque.\n\n Examples:\n Initialize a BOTrack object with bounding box, score, class ID, and feature vector\n >>> tlwh = np.array([100, 50, 80, 120])\n >>> score = 0.9\n >>> cls = 1\n >>> feat = np.random.rand(128)\n >>> bo_track = BOTrack(tlwh, score, cls, feat)\n \"\"\"\n super().__init__(tlwh, score, cls)\n\n self.smooth_feat = None\n self.curr_feat = None\n if feat is not None:\n self.update_features(feat)\n self.features = deque([], maxlen=feat_history)\n self.alpha = 0.9\n\n def update_features(self, feat: np.ndarray) -> None:\n \"\"\"Update the feature vector and apply exponential moving average smoothing.\"\"\"\n feat /= np.linalg.norm(feat)\n self.curr_feat = feat\n if self.smooth_feat is None:\n self.smooth_feat = feat\n else:\n self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat\n self.features.append(feat)\n self.smooth_feat /= np.linalg.norm(self.smooth_feat)\n\n def predict(self) -> None:\n \"\"\"Predict the object's future state using the Kalman filter to update its mean and covariance.\"\"\"\n mean_state = 
self.mean.copy()\n if self.state != TrackState.Tracked:\n mean_state[6] = 0\n mean_state[7] = 0\n\n self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)\n\n def re_activate(self, new_track: \"BOTrack\", frame_id: int, new_id: bool = False) -> None:\n \"\"\"Reactivate a track with updated features and optionally assign a new ID.\"\"\"\n if new_track.curr_feat is not None:\n self.update_features(new_track.curr_feat)\n super().re_activate(new_track, frame_id, new_id)\n\n def update(self, new_track: \"BOTrack\", frame_id: int) -> None:\n \"\"\"Update the track with new detection information and the current frame ID.\"\"\"\n if new_track.curr_feat is not None:\n self.update_features(new_track.curr_feat)\n super().update(new_track, frame_id)\n\n @property\n def tlwh(self) -> np.ndarray:\n \"\"\"Return the current bounding box position in `(top left x, top left y, width, height)` format.\"\"\"\n if self.mean is None:\n return self._tlwh.copy()\n ret = self.mean[:4].copy()\n ret[:2] -= ret[2:] / 2\n return ret\n\n @staticmethod\n def multi_predict(stracks: List[\"BOTrack\"]) -> None:\n \"\"\"Predict the mean and covariance for multiple object tracks using a shared Kalman filter.\"\"\"\n if len(stracks) <= 0:\n return\n multi_mean = np.asarray([st.mean.copy() for st in stracks])\n multi_covariance = np.asarray([st.covariance for st in stracks])\n for i, st in enumerate(stracks):\n if st.state != TrackState.Tracked:\n multi_mean[i][6] = 0\n multi_mean[i][7] = 0\n multi_mean, multi_covariance = BOTrack.shared_kalman.multi_predict(multi_mean, multi_covariance)\n for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):\n stracks[i].mean = mean\n stracks[i].covariance = cov\n\n def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:\n \"\"\"Convert tlwh bounding box coordinates to xywh format.\"\"\"\n return self.tlwh_to_xywh(tlwh)\n\n @staticmethod\n def tlwh_to_xywh(tlwh: np.ndarray) -> np.ndarray:\n \"\"\"Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format.\"\"\"\n ret = np.asarray(tlwh).copy()\n ret[:2] += ret[2:] / 2\n return ret",
"chunk_type": "class",
"name": "BOTrack",
"file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py",
"start_line": 19,
"end_line": 151,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": "An extended version of the STrack class for YOLO, adding object tracking features.\n\nThis class extends the STrack class to include additional functionalities for object tracking, such as feature\nsmoothing, Kalman filter prediction, and reactivation of tracks.\n\nAttributes:\n shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack.\n smooth_feat (np.ndarray): Smoothed feature vector.\n curr_feat (np.ndarray): Current feature vector.\n features (deque): A deque to store feature vectors with a maximum length defined by `feat_history`.\n alpha (float): Smoothing factor for the exponential moving average of features.\n mean (np.ndarray): The mean state of the Kalman filter.\n covariance (np.ndarray): The covariance matrix of the Kalman filter.\n\nMethods:\n update_features: Update features vector and smooth it using exponential moving average.\n predict: Predict the mean and covariance using Kalman filter.\n re_activate: Reactivate a track with updated features and optionally new ID.\n update: Update the track with new detection and frame ID.\n tlwh: Property that gets the current position in tlwh format `(top left x, top left y, width, height)`.\n multi_predict: Predict the mean and covariance of multiple object tracks using shared Kalman filter.\n convert_coords: Convert tlwh bounding box coordinates to xywh format.\n tlwh_to_xywh: Convert bounding box to xywh format `(center x, center y, width, height)`.\n\nExamples:\n Create a BOTrack instance and update its features\n >>> bo_track = BOTrack(tlwh=[100, 50, 80, 40], score=0.9, cls=1, feat=np.random.rand(128))\n >>> bo_track.predict()\n >>> new_track = BOTrack(tlwh=[110, 60, 80, 40], score=0.85, cls=1, feat=np.random.rand(128))\n >>> bo_track.update(new_track, frame_id=2)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"collections.deque",
"typing.Any",
"typing.List",
"typing.Optional",
"numpy",
"torch",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.plotting.save_one_box",
"basetrack.TrackState",
"byte_tracker.BYTETracker",
"byte_tracker.STrack",
"utils.matching",
"utils.gmc.GMC",
"utils.kalman_filter.KalmanFilterXYWH",
"ultralytics.YOLO",
"STrack"
],
"chunk_id": "class_BOTrack_878d3d31"
},
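Expanding the BOTrack docstring example into a runnable sketch: STrack.__init__ asserts 5 or 6 box values (the last being the detection index), and `mean`/`covariance` only exist after activation, so the tracks are activated before predicting. Feature dimensionality and box values are illustrative.

```python
import numpy as np

from ultralytics.trackers.bot_sort import BOTrack

# Boxes carry a trailing detection index because STrack.__init__ expects 5 (or 6) values.
t1 = BOTrack(np.array([100, 50, 80, 40, 0]), score=0.9, cls=1, feat=np.random.rand(128))
t2 = BOTrack(np.array([300, 120, 60, 90, 1]), score=0.8, cls=0, feat=np.random.rand(128))

# Kalman state is initialised by activate(), so do that before any prediction step.
for trk in (t1, t2):
    trk.activate(BOTrack.shared_kalman, frame_id=1)

BOTrack.multi_predict([t1, t2])  # batch prediction through the shared KalmanFilterXYWH
print(t1.tlwh)                   # current (top-left x, top-left y, width, height) estimate
```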
{
"content": "class BOTSORT(BYTETracker):\n \"\"\"\n An extended version of the BYTETracker class for YOLO, designed for object tracking with ReID and GMC algorithm.\n\n Attributes:\n proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.\n appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections.\n encoder (Any): Object to handle ReID embeddings, set to None if ReID is not enabled.\n gmc (GMC): An instance of the GMC algorithm for data association.\n args (Any): Parsed command-line arguments containing tracking parameters.\n\n Methods:\n get_kalmanfilter: Return an instance of KalmanFilterXYWH for object tracking.\n init_track: Initialize track with detections, scores, and classes.\n get_dists: Get distances between tracks and detections using IoU and (optionally) ReID.\n multi_predict: Predict and track multiple objects with a YOLO model.\n reset: Reset the BOTSORT tracker to its initial state.\n\n Examples:\n Initialize BOTSORT and process detections\n >>> bot_sort = BOTSORT(args, frame_rate=30)\n >>> bot_sort.init_track(dets, scores, cls, img)\n >>> bot_sort.multi_predict(tracks)\n\n Note:\n The class is designed to work with a YOLO object detection model and supports ReID only if enabled via args.\n \"\"\"\n\n def __init__(self, args: Any, frame_rate: int = 30):\n \"\"\"\n Initialize BOTSORT object with ReID module and GMC algorithm.\n\n Args:\n args (Any): Parsed command-line arguments containing tracking parameters.\n frame_rate (int): Frame rate of the video being processed.\n\n Examples:\n Initialize BOTSORT with command-line arguments and a specified frame rate:\n >>> args = parse_args()\n >>> bot_sort = BOTSORT(args, frame_rate=30)\n \"\"\"\n super().__init__(args, frame_rate)\n self.gmc = GMC(method=args.gmc_method)\n\n # ReID module\n self.proximity_thresh = args.proximity_thresh\n self.appearance_thresh = args.appearance_thresh\n self.encoder = (\n (lambda feats, s: [f.cpu().numpy() for f in feats]) # native features do not require any model\n if args.with_reid and self.args.model == \"auto\"\n else ReID(args.model)\n if args.with_reid\n else None\n )\n\n def get_kalmanfilter(self) -> KalmanFilterXYWH:\n \"\"\"Return an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process.\"\"\"\n return KalmanFilterXYWH()\n\n def init_track(\n self, dets: np.ndarray, scores: np.ndarray, cls: np.ndarray, img: Optional[np.ndarray] = None\n ) -> List[BOTrack]:\n \"\"\"Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features.\"\"\"\n if len(dets) == 0:\n return []\n if self.args.with_reid and self.encoder is not None:\n features_keep = self.encoder(img, dets)\n return [BOTrack(xyxy, s, c, f) for (xyxy, s, c, f) in zip(dets, scores, cls, features_keep)] # detections\n else:\n return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections\n\n def get_dists(self, tracks: List[BOTrack], detections: List[BOTrack]) -> np.ndarray:\n \"\"\"Calculate distances between tracks and detections using IoU and optionally ReID embeddings.\"\"\"\n dists = matching.iou_distance(tracks, detections)\n dists_mask = dists > (1 - self.proximity_thresh)\n\n if self.args.fuse_score:\n dists = matching.fuse_score(dists, detections)\n\n if self.args.with_reid and self.encoder is not None:\n emb_dists = matching.embedding_distance(tracks, detections) / 2.0\n emb_dists[emb_dists > (1 - self.appearance_thresh)] = 1.0\n 
emb_dists[dists_mask] = 1.0\n dists = np.minimum(dists, emb_dists)\n return dists\n\n def multi_predict(self, tracks: List[BOTrack]) -> None:\n \"\"\"Predict the mean and covariance of multiple object tracks using a shared Kalman filter.\"\"\"\n BOTrack.multi_predict(tracks)\n\n def reset(self) -> None:\n \"\"\"Reset the BOTSORT tracker to its initial state, clearing all tracked objects and internal states.\"\"\"\n super().reset()\n self.gmc.reset_params()",
"chunk_type": "class",
"name": "BOTSORT",
"file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py",
"start_line": 154,
"end_line": 247,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": "An extended version of the BYTETracker class for YOLO, designed for object tracking with ReID and GMC algorithm.\n\nAttributes:\n proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.\n appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections.\n encoder (Any): Object to handle ReID embeddings, set to None if ReID is not enabled.\n gmc (GMC): An instance of the GMC algorithm for data association.\n args (Any): Parsed command-line arguments containing tracking parameters.\n\nMethods:\n get_kalmanfilter: Return an instance of KalmanFilterXYWH for object tracking.\n init_track: Initialize track with detections, scores, and classes.\n get_dists: Get distances between tracks and detections using IoU and (optionally) ReID.\n multi_predict: Predict and track multiple objects with a YOLO model.\n reset: Reset the BOTSORT tracker to its initial state.\n\nExamples:\n Initialize BOTSORT and process detections\n >>> bot_sort = BOTSORT(args, frame_rate=30)\n >>> bot_sort.init_track(dets, scores, cls, img)\n >>> bot_sort.multi_predict(tracks)\n\nNote:\n The class is designed to work with a YOLO object detection model and supports ReID only if enabled via args.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"collections.deque",
"typing.Any",
"typing.List",
"typing.Optional",
"numpy",
"torch",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.plotting.save_one_box",
"basetrack.TrackState",
"byte_tracker.BYTETracker",
"byte_tracker.STrack",
"utils.matching",
"utils.gmc.GMC",
"utils.kalman_filter.KalmanFilterXYWH",
"ultralytics.YOLO",
"BYTETracker"
],
"chunk_id": "class_BOTSORT_02df815b"
},
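In normal use BOTSORT is not constructed by hand; it is selected through the tracker config when calling YOLO's track mode, which builds the args namespace (gmc_method, proximity_thresh, appearance_thresh, with_reid, fuse_score, ...) from botsort.yaml. A short sketch with placeholder weight and video paths:

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")  # placeholder weights
# tracker="botsort.yaml" routes tracking through the BOTSORT class in this chunk.
results = model.track("video.mp4", tracker="botsort.yaml")
```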
{
"content": "class ReID:\n \"\"\"YOLO model as encoder for re-identification.\"\"\"\n\n def __init__(self, model: str):\n \"\"\"\n Initialize encoder for re-identification.\n\n Args:\n model (str): Path to the YOLO model for re-identification.\n \"\"\"\n from ultralytics import YOLO\n\n self.model = YOLO(model)\n self.model(embed=[len(self.model.model.model) - 2 if \".pt\" in model else -1], verbose=False, save=False) # init\n\n def __call__(self, img: np.ndarray, dets: np.ndarray) -> List[np.ndarray]:\n \"\"\"Extract embeddings for detected objects.\"\"\"\n feats = self.model.predictor(\n [save_one_box(det, img, save=False) for det in xywh2xyxy(torch.from_numpy(dets[:, :4]))]\n )\n if len(feats) != dets.shape[0] and feats[0].shape[0] == dets.shape[0]:\n feats = feats[0] # batched prediction with non-PyTorch backend\n return [f.cpu().numpy() for f in feats]",
"chunk_type": "class",
"name": "ReID",
"file_path": "ultralytics\\ultralytics\\trackers\\bot_sort.py",
"start_line": 250,
"end_line": 272,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": "YOLO model as encoder for re-identification.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"collections.deque",
"typing.Any",
"typing.List",
"typing.Optional",
"numpy",
"torch",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.plotting.save_one_box",
"basetrack.TrackState",
"byte_tracker.BYTETracker",
"byte_tracker.STrack",
"utils.matching",
"utils.gmc.GMC",
"utils.kalman_filter.KalmanFilterXYWH",
"ultralytics.YOLO"
],
"chunk_id": "class_ReID_2fea07e5"
},
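A hedged sketch of calling the ReID encoder directly: the weight file, frame path, and detection values are placeholders, and `dets` is expected to hold xywh boxes in its first four columns (matching the `dets[:, :4]` slice above).

```python
import cv2
import numpy as np

from ultralytics.trackers.bot_sort import ReID

encoder = ReID("yolo11n-cls.pt")  # placeholder weights; any YOLO checkpoint can serve as the encoder
img = cv2.imread("frame.jpg")     # placeholder frame
dets = np.array([[120.0, 80.0, 60.0, 90.0, 0.9, 0.0]])  # xywh + extras; only [:, :4] is used
feats = encoder(img, dets)        # one embedding (np.ndarray) per detection
print(len(feats), feats[0].shape)
```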
{
"content": "from typing import Any, List, Optional, Tuple",
"chunk_type": "import",
"name": "Any, List, Optional, Tuple",
"file_path": "ultralytics\\ultralytics\\trackers\\byte_tracker.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, List, Optional, Tuple_eafa25fe"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\trackers\\byte_tracker.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_97b17d82"
},
{
"content": "from ..utils import LOGGER",
"chunk_type": "import",
"name": "LOGGER",
"file_path": "ultralytics\\ultralytics\\trackers\\byte_tracker.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 26,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER_10c4c7b7"
},
{
"content": "from ..utils.ops import xywh2ltwh",
"chunk_type": "import",
"name": "xywh2ltwh",
"file_path": "ultralytics\\ultralytics\\trackers\\byte_tracker.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_xywh2ltwh_6391db06"
},
{
"content": "from .basetrack import BaseTrack, TrackState",
"chunk_type": "import",
"name": "BaseTrack, TrackState",
"file_path": "ultralytics\\ultralytics\\trackers\\byte_tracker.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BaseTrack, TrackState_d0ca7671"
},
{
"content": "from .utils import matching",
"chunk_type": "import",
"name": "matching",
"file_path": "ultralytics\\ultralytics\\trackers\\byte_tracker.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 27,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_matching_89619227"
},
{
"content": "from .utils.kalman_filter import KalmanFilterXYAH",
"chunk_type": "import",
"name": "KalmanFilterXYAH",
"file_path": "ultralytics\\ultralytics\\trackers\\byte_tracker.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 49,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_KalmanFilterXYAH_887c85f7"
},
{
"content": "class STrack(BaseTrack):\n \"\"\"\n Single object tracking representation that uses Kalman filtering for state estimation.\n\n This class is responsible for storing all the information regarding individual tracklets and performs state updates\n and predictions based on Kalman filter.\n\n Attributes:\n shared_kalman (KalmanFilterXYAH): Shared Kalman filter used across all STrack instances for prediction.\n _tlwh (np.ndarray): Private attribute to store top-left corner coordinates and width and height of bounding box.\n kalman_filter (KalmanFilterXYAH): Instance of Kalman filter used for this particular object track.\n mean (np.ndarray): Mean state estimate vector.\n covariance (np.ndarray): Covariance of state estimate.\n is_activated (bool): Boolean flag indicating if the track has been activated.\n score (float): Confidence score of the track.\n tracklet_len (int): Length of the tracklet.\n cls (Any): Class label for the object.\n idx (int): Index or identifier for the object.\n frame_id (int): Current frame ID.\n start_frame (int): Frame where the object was first detected.\n angle (float | None): Optional angle information for oriented bounding boxes.\n\n Methods:\n predict: Predict the next state of the object using Kalman filter.\n multi_predict: Predict the next states for multiple tracks.\n multi_gmc: Update multiple track states using a homography matrix.\n activate: Activate a new tracklet.\n re_activate: Reactivate a previously lost tracklet.\n update: Update the state of a matched track.\n convert_coords: Convert bounding box to x-y-aspect-height format.\n tlwh_to_xyah: Convert tlwh bounding box to xyah format.\n\n Examples:\n Initialize and activate a new track\n >>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls=\"person\")\n >>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1)\n \"\"\"\n\n shared_kalman = KalmanFilterXYAH()\n\n def __init__(self, xywh: List[float], score: float, cls: Any):\n \"\"\"\n Initialize a new STrack instance.\n\n Args:\n xywh (List[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where\n (x, y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id.\n score (float): Confidence score of the detection.\n cls (Any): Class label for the detected object.\n\n Examples:\n >>> xywh = [100.0, 150.0, 50.0, 75.0, 1]\n >>> score = 0.9\n >>> cls = \"person\"\n >>> track = STrack(xywh, score, cls)\n \"\"\"\n super().__init__()\n # xywh+idx or xywha+idx\n assert len(xywh) in {5, 6}, f\"expected 5 or 6 values but got {len(xywh)}\"\n self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32)\n self.kalman_filter = None\n self.mean, self.covariance = None, None\n self.is_activated = False\n\n self.score = score\n self.tracklet_len = 0\n self.cls = cls\n self.idx = xywh[-1]\n self.angle = xywh[4] if len(xywh) == 6 else None\n\n def predict(self):\n \"\"\"Predict the next state (mean and covariance) of the object using the Kalman filter.\"\"\"\n mean_state = self.mean.copy()\n if self.state != TrackState.Tracked:\n mean_state[7] = 0\n self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)\n\n @staticmethod\n def multi_predict(stracks: List[\"STrack\"]):\n \"\"\"Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances.\"\"\"\n if len(stracks) <= 0:\n return\n multi_mean = np.asarray([st.mean.copy() for st in stracks])\n multi_covariance = np.asarray([st.covariance for st in stracks])\n 
for i, st in enumerate(stracks):\n if st.state != TrackState.Tracked:\n multi_mean[i][7] = 0\n multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)\n for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):\n stracks[i].mean = mean\n stracks[i].covariance = cov\n\n @staticmethod\n def multi_gmc(stracks: List[\"STrack\"], H: np.ndarray = np.eye(2, 3)):\n \"\"\"Update state tracks positions and covariances using a homography matrix for multiple tracks.\"\"\"\n if len(stracks) > 0:\n multi_mean = np.asarray([st.mean.copy() for st in stracks])\n multi_covariance = np.asarray([st.covariance for st in stracks])\n\n R = H[:2, :2]\n R8x8 = np.kron(np.eye(4, dtype=float), R)\n t = H[:2, 2]\n\n for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):\n mean = R8x8.dot(mean)\n mean[:2] += t\n cov = R8x8.dot(cov).dot(R8x8.transpose())\n\n stracks[i].mean = mean\n stracks[i].covariance = cov\n\n def activate(self, kalman_filter: KalmanFilterXYAH, frame_id: int):\n \"\"\"Activate a new tracklet using the provided Kalman filter and initialize its state and covariance.\"\"\"\n self.kalman_filter = kalman_filter\n self.track_id = self.next_id()\n self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh))\n\n self.tracklet_len = 0\n self.state = TrackState.Tracked\n if frame_id == 1:\n self.is_activated = True\n self.frame_id = frame_id\n self.start_frame = frame_id\n\n def re_activate(self, new_track: \"STrack\", frame_id: int, new_id: bool = False):\n \"\"\"Reactivate a previously lost track using new detection data and update its state and attributes.\"\"\"\n self.mean, self.covariance = self.kalman_filter.update(\n self.mean, self.covariance, self.convert_coords(new_track.tlwh)\n )\n self.tracklet_len = 0\n self.state = TrackState.Tracked\n self.is_activated = True\n self.frame_id = frame_id\n if new_id:\n self.track_id = self.next_id()\n self.score = new_track.score\n self.cls = new_track.cls\n self.angle = new_track.angle\n self.idx = new_track.idx\n\n def update(self, new_track: \"STrack\", frame_id: int):\n \"\"\"\n Update the state of a matched track.\n\n Args:\n new_track (STrack): The new track containing updated information.\n frame_id (int): The ID of the current frame.\n\n Examples:\n Update the state of a track with new detection information\n >>> track = STrack([100, 200, 50, 80, 0.9, 1])\n >>> new_track = STrack([105, 205, 55, 85, 0.95, 1])\n >>> track.update(new_track, 2)\n \"\"\"\n self.frame_id = frame_id\n self.tracklet_len += 1\n\n new_tlwh = new_track.tlwh\n self.mean, self.covariance = self.kalman_filter.update(\n self.mean, self.covariance, self.convert_coords(new_tlwh)\n )\n self.state = TrackState.Tracked\n self.is_activated = True\n\n self.score = new_track.score\n self.cls = new_track.cls\n self.angle = new_track.angle\n self.idx = new_track.idx\n\n def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:\n \"\"\"Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent.\"\"\"\n return self.tlwh_to_xyah(tlwh)\n\n @property\n def tlwh(self) -> np.ndarray:\n \"\"\"Get the bounding box in top-left-width-height format from the current state estimate.\"\"\"\n if self.mean is None:\n return self._tlwh.copy()\n ret = self.mean[:4].copy()\n ret[2] *= ret[3]\n ret[:2] -= ret[2:] / 2\n return ret\n\n @property\n def xyxy(self) -> np.ndarray:\n \"\"\"Convert bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format.\"\"\"\n 
ret = self.tlwh.copy()\n ret[2:] += ret[:2]\n return ret\n\n @staticmethod\n def tlwh_to_xyah(tlwh: np.ndarray) -> np.ndarray:\n \"\"\"Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format.\"\"\"\n ret = np.asarray(tlwh).copy()\n ret[:2] += ret[2:] / 2\n ret[2] /= ret[3]\n return ret\n\n @property\n def xywh(self) -> np.ndarray:\n \"\"\"Get the current position of the bounding box in (center x, center y, width, height) format.\"\"\"\n ret = np.asarray(self.tlwh).copy()\n ret[:2] += ret[2:] / 2\n return ret\n\n @property\n def xywha(self) -> np.ndarray:\n \"\"\"Get position in (center x, center y, width, height, angle) format, warning if angle is missing.\"\"\"\n if self.angle is None:\n LOGGER.warning(\"`angle` attr not found, returning `xywh` instead.\")\n return self.xywh\n return np.concatenate([self.xywh, self.angle[None]])\n\n @property\n def result(self) -> List[float]:\n \"\"\"Get the current tracking results in the appropriate bounding box format.\"\"\"\n coords = self.xyxy if self.angle is None else self.xywha\n return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]\n\n def __repr__(self) -> str:\n \"\"\"Return a string representation of the STrack object including start frame, end frame, and track ID.\"\"\"\n return f\"OT_{self.track_id}_({self.start_frame}-{self.end_frame})\"",
"chunk_type": "class",
"name": "STrack",
"file_path": "ultralytics\\ultralytics\\trackers\\byte_tracker.py",
"start_line": 14,
"end_line": 235,
"start_col": 0,
"end_col": 74,
"parent_name": null,
"docstring": "Single object tracking representation that uses Kalman filtering for state estimation.\n\nThis class is responsible for storing all the information regarding individual tracklets and performs state updates\nand predictions based on Kalman filter.\n\nAttributes:\n shared_kalman (KalmanFilterXYAH): Shared Kalman filter used across all STrack instances for prediction.\n _tlwh (np.ndarray): Private attribute to store top-left corner coordinates and width and height of bounding box.\n kalman_filter (KalmanFilterXYAH): Instance of Kalman filter used for this particular object track.\n mean (np.ndarray): Mean state estimate vector.\n covariance (np.ndarray): Covariance of state estimate.\n is_activated (bool): Boolean flag indicating if the track has been activated.\n score (float): Confidence score of the track.\n tracklet_len (int): Length of the tracklet.\n cls (Any): Class label for the object.\n idx (int): Index or identifier for the object.\n frame_id (int): Current frame ID.\n start_frame (int): Frame where the object was first detected.\n angle (float | None): Optional angle information for oriented bounding boxes.\n\nMethods:\n predict: Predict the next state of the object using Kalman filter.\n multi_predict: Predict the next states for multiple tracks.\n multi_gmc: Update multiple track states using a homography matrix.\n activate: Activate a new tracklet.\n re_activate: Reactivate a previously lost tracklet.\n update: Update the state of a matched track.\n convert_coords: Convert bounding box to x-y-aspect-height format.\n tlwh_to_xyah: Convert tlwh bounding box to xyah format.\n\nExamples:\n Initialize and activate a new track\n >>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls=\"person\")\n >>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.List",
"typing.Optional",
"typing.Tuple",
"numpy",
"utils.LOGGER",
"utils.ops.xywh2ltwh",
"basetrack.BaseTrack",
"basetrack.TrackState",
"utils.matching",
"utils.kalman_filter.KalmanFilterXYAH",
"BaseTrack"
],
"chunk_id": "class_STrack_55d317ad"
},
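The coordinate bookkeeping in STrack (tlwh, xyah, xyxy, xywh) is plain NumPy arithmetic and can be exercised without the Kalman filter. A minimal self-contained sketch of the tlwh-to-xyah conversion, plus a hypothetical inverse helper (xyah_to_tlwh is not part of the class; it is added here only to verify the round trip):

import numpy as np

def tlwh_to_xyah(tlwh):
    # (top-left x, top-left y, width, height) -> (center x, center y, aspect ratio w/h, height)
    ret = np.asarray(tlwh, dtype=np.float32).copy()
    ret[:2] += ret[2:] / 2
    ret[2] /= ret[3]
    return ret

def xyah_to_tlwh(xyah):
    # Hypothetical inverse, mirroring the tlwh property: undo the aspect-ratio and center shifts
    ret = np.asarray(xyah, dtype=np.float32).copy()
    ret[2] *= ret[3]
    ret[:2] -= ret[2:] / 2
    return ret

box = np.array([100.0, 200.0, 50.0, 80.0])          # x, y, w, h
assert np.allclose(xyah_to_tlwh(tlwh_to_xyah(box)), box)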
{
"content": "class BYTETracker:\n \"\"\"\n BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.\n\n This class encapsulates the functionality for initializing, updating, and managing the tracks for detected objects in a\n video sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for\n predicting the new object locations, and performs data association.\n\n Attributes:\n tracked_stracks (List[STrack]): List of successfully activated tracks.\n lost_stracks (List[STrack]): List of lost tracks.\n removed_stracks (List[STrack]): List of removed tracks.\n frame_id (int): The current frame ID.\n args (Namespace): Command-line arguments.\n max_time_lost (int): The maximum frames for a track to be considered as 'lost'.\n kalman_filter (KalmanFilterXYAH): Kalman Filter object.\n\n Methods:\n update: Update object tracker with new detections.\n get_kalmanfilter: Return a Kalman filter object for tracking bounding boxes.\n init_track: Initialize object tracking with detections.\n get_dists: Calculate the distance between tracks and detections.\n multi_predict: Predict the location of tracks.\n reset_id: Reset the ID counter of STrack.\n reset: Reset the tracker by clearing all tracks.\n joint_stracks: Combine two lists of stracks.\n sub_stracks: Filter out the stracks present in the second list from the first list.\n remove_duplicate_stracks: Remove duplicate stracks based on IoU.\n\n Examples:\n Initialize BYTETracker and update with detection results\n >>> tracker = BYTETracker(args, frame_rate=30)\n >>> results = yolo_model.detect(image)\n >>> tracked_objects = tracker.update(results)\n \"\"\"\n\n def __init__(self, args, frame_rate: int = 30):\n \"\"\"\n Initialize a BYTETracker instance for object tracking.\n\n Args:\n args (Namespace): Command-line arguments containing tracking parameters.\n frame_rate (int): Frame rate of the video sequence.\n\n Examples:\n Initialize BYTETracker with command-line arguments and a frame rate of 30\n >>> args = Namespace(track_buffer=30)\n >>> tracker = BYTETracker(args, frame_rate=30)\n \"\"\"\n self.tracked_stracks = [] # type: List[STrack]\n self.lost_stracks = [] # type: List[STrack]\n self.removed_stracks = [] # type: List[STrack]\n\n self.frame_id = 0\n self.args = args\n self.max_time_lost = int(frame_rate / 30.0 * args.track_buffer)\n self.kalman_filter = self.get_kalmanfilter()\n self.reset_id()\n\n def update(self, results, img: Optional[np.ndarray] = None, feats: Optional[np.ndarray] = None) -> np.ndarray:\n \"\"\"Update the tracker with new detections and return the current list of tracked objects.\"\"\"\n self.frame_id += 1\n activated_stracks = []\n refind_stracks = []\n lost_stracks = []\n removed_stracks = []\n\n scores = results.conf\n bboxes = results.xywhr if hasattr(results, \"xywhr\") else results.xywh\n # Add index\n bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)\n cls = results.cls\n\n remain_inds = scores >= self.args.track_high_thresh\n inds_low = scores > self.args.track_low_thresh\n inds_high = scores < self.args.track_high_thresh\n\n inds_second = inds_low & inds_high\n dets_second = bboxes[inds_second]\n dets = bboxes[remain_inds]\n scores_keep = scores[remain_inds]\n scores_second = scores[inds_second]\n cls_keep = cls[remain_inds]\n cls_second = cls[inds_second]\n\n detections = self.init_track(dets, scores_keep, cls_keep, img if feats is None else feats)\n # Add newly detected tracklets to 
tracked_stracks\n unconfirmed = []\n tracked_stracks = [] # type: List[STrack]\n for track in self.tracked_stracks:\n if not track.is_activated:\n unconfirmed.append(track)\n else:\n tracked_stracks.append(track)\n # Step 2: First association, with high score detection boxes\n strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks)\n # Predict the current location with KF\n self.multi_predict(strack_pool)\n if hasattr(self, \"gmc\") and img is not None:\n # use try-except here to bypass errors from gmc module\n try:\n warp = self.gmc.apply(img, dets)\n except Exception:\n warp = np.eye(2, 3)\n STrack.multi_gmc(strack_pool, warp)\n STrack.multi_gmc(unconfirmed, warp)\n\n dists = self.get_dists(strack_pool, detections)\n matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.args.match_thresh)\n\n for itracked, idet in matches:\n track = strack_pool[itracked]\n det = detections[idet]\n if track.state == TrackState.Tracked:\n track.update(det, self.frame_id)\n activated_stracks.append(track)\n else:\n track.re_activate(det, self.frame_id, new_id=False)\n refind_stracks.append(track)\n # Step 3: Second association, with low score detection boxes association the untrack to the low score detections\n detections_second = self.init_track(dets_second, scores_second, cls_second, img if feats is None else feats)\n r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]\n # TODO\n dists = matching.iou_distance(r_tracked_stracks, detections_second)\n matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5)\n for itracked, idet in matches:\n track = r_tracked_stracks[itracked]\n det = detections_second[idet]\n if track.state == TrackState.Tracked:\n track.update(det, self.frame_id)\n activated_stracks.append(track)\n else:\n track.re_activate(det, self.frame_id, new_id=False)\n refind_stracks.append(track)\n\n for it in u_track:\n track = r_tracked_stracks[it]\n if track.state != TrackState.Lost:\n track.mark_lost()\n lost_stracks.append(track)\n # Deal with unconfirmed tracks, usually tracks with only one beginning frame\n detections = [detections[i] for i in u_detection]\n dists = self.get_dists(unconfirmed, detections)\n matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)\n for itracked, idet in matches:\n unconfirmed[itracked].update(detections[idet], self.frame_id)\n activated_stracks.append(unconfirmed[itracked])\n for it in u_unconfirmed:\n track = unconfirmed[it]\n track.mark_removed()\n removed_stracks.append(track)\n # Step 4: Init new stracks\n for inew in u_detection:\n track = detections[inew]\n if track.score < self.args.new_track_thresh:\n continue\n track.activate(self.kalman_filter, self.frame_id)\n activated_stracks.append(track)\n # Step 5: Update state\n for track in self.lost_stracks:\n if self.frame_id - track.end_frame > self.max_time_lost:\n track.mark_removed()\n removed_stracks.append(track)\n\n self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]\n self.tracked_stracks = self.joint_stracks(self.tracked_stracks, activated_stracks)\n self.tracked_stracks = self.joint_stracks(self.tracked_stracks, refind_stracks)\n self.lost_stracks = self.sub_stracks(self.lost_stracks, self.tracked_stracks)\n self.lost_stracks.extend(lost_stracks)\n self.lost_stracks = self.sub_stracks(self.lost_stracks, self.removed_stracks)\n self.tracked_stracks, self.lost_stracks = self.remove_duplicate_stracks(self.tracked_stracks, 
self.lost_stracks)\n self.removed_stracks.extend(removed_stracks)\n if len(self.removed_stracks) > 1000:\n self.removed_stracks = self.removed_stracks[-999:] # clip remove stracks to 1000 maximum\n\n return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)\n\n def get_kalmanfilter(self) -> KalmanFilterXYAH:\n \"\"\"Return a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH.\"\"\"\n return KalmanFilterXYAH()\n\n def init_track(\n self, dets: np.ndarray, scores: np.ndarray, cls: np.ndarray, img: Optional[np.ndarray] = None\n ) -> List[STrack]:\n \"\"\"Initialize object tracking with given detections, scores, and class labels using the STrack algorithm.\"\"\"\n return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections\n\n def get_dists(self, tracks: List[STrack], detections: List[STrack]) -> np.ndarray:\n \"\"\"Calculate the distance between tracks and detections using IoU and optionally fuse scores.\"\"\"\n dists = matching.iou_distance(tracks, detections)\n if self.args.fuse_score:\n dists = matching.fuse_score(dists, detections)\n return dists\n\n def multi_predict(self, tracks: List[STrack]):\n \"\"\"Predict the next states for multiple tracks using Kalman filter.\"\"\"\n STrack.multi_predict(tracks)\n\n @staticmethod\n def reset_id():\n \"\"\"Reset the ID counter for STrack instances to ensure unique track IDs across tracking sessions.\"\"\"\n STrack.reset_id()\n\n def reset(self):\n \"\"\"Reset the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter.\"\"\"\n self.tracked_stracks = [] # type: List[STrack]\n self.lost_stracks = [] # type: List[STrack]\n self.removed_stracks = [] # type: List[STrack]\n self.frame_id = 0\n self.kalman_filter = self.get_kalmanfilter()\n self.reset_id()\n\n @staticmethod\n def joint_stracks(tlista: List[STrack], tlistb: List[STrack]) -> List[STrack]:\n \"\"\"Combine two lists of STrack objects into a single list, ensuring no duplicates based on track IDs.\"\"\"\n exists = {}\n res = []\n for t in tlista:\n exists[t.track_id] = 1\n res.append(t)\n for t in tlistb:\n tid = t.track_id\n if not exists.get(tid, 0):\n exists[tid] = 1\n res.append(t)\n return res\n\n @staticmethod\n def sub_stracks(tlista: List[STrack], tlistb: List[STrack]) -> List[STrack]:\n \"\"\"Filter out the stracks present in the second list from the first list.\"\"\"\n track_ids_b = {t.track_id for t in tlistb}\n return [t for t in tlista if t.track_id not in track_ids_b]\n\n @staticmethod\n def remove_duplicate_stracks(stracksa: List[STrack], stracksb: List[STrack]) -> Tuple[List[STrack], List[STrack]]:\n \"\"\"Remove duplicate stracks from two lists based on Intersection over Union (IoU) distance.\"\"\"\n pdist = matching.iou_distance(stracksa, stracksb)\n pairs = np.where(pdist < 0.15)\n dupa, dupb = [], []\n for p, q in zip(*pairs):\n timep = stracksa[p].frame_id - stracksa[p].start_frame\n timeq = stracksb[q].frame_id - stracksb[q].start_frame\n if timep > timeq:\n dupb.append(q)\n else:\n dupa.append(p)\n resa = [t for i, t in enumerate(stracksa) if i not in dupa]\n resb = [t for i, t in enumerate(stracksb) if i not in dupb]\n return resa, resb",
"chunk_type": "class",
"name": "BYTETracker",
"file_path": "ultralytics\\ultralytics\\trackers\\byte_tracker.py",
"start_line": 238,
"end_line": 486,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": "BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.\n\nThis class encapsulates the functionality for initializing, updating, and managing the tracks for detected objects in a\nvideo sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for\npredicting the new object locations, and performs data association.\n\nAttributes:\n tracked_stracks (List[STrack]): List of successfully activated tracks.\n lost_stracks (List[STrack]): List of lost tracks.\n removed_stracks (List[STrack]): List of removed tracks.\n frame_id (int): The current frame ID.\n args (Namespace): Command-line arguments.\n max_time_lost (int): The maximum frames for a track to be considered as 'lost'.\n kalman_filter (KalmanFilterXYAH): Kalman Filter object.\n\nMethods:\n update: Update object tracker with new detections.\n get_kalmanfilter: Return a Kalman filter object for tracking bounding boxes.\n init_track: Initialize object tracking with detections.\n get_dists: Calculate the distance between tracks and detections.\n multi_predict: Predict the location of tracks.\n reset_id: Reset the ID counter of STrack.\n reset: Reset the tracker by clearing all tracks.\n joint_stracks: Combine two lists of stracks.\n sub_stracks: Filter out the stracks present in the second list from the first list.\n remove_duplicate_stracks: Remove duplicate stracks based on IoU.\n\nExamples:\n Initialize BYTETracker and update with detection results\n >>> tracker = BYTETracker(args, frame_rate=30)\n >>> results = yolo_model.detect(image)\n >>> tracked_objects = tracker.update(results)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.List",
"typing.Optional",
"typing.Tuple",
"numpy",
"utils.LOGGER",
"utils.ops.xywh2ltwh",
"basetrack.BaseTrack",
"basetrack.TrackState",
"utils.matching",
"utils.kalman_filter.KalmanFilterXYAH"
],
"chunk_id": "class_BYTETracker_2c47ebb5"
},
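The update() method implements the two-stage association that gives BYTE its name: confident detections are matched first, and tracks left unmatched get a second chance against the low-score pool. A toy sketch of just that score split (threshold values are illustrative, not the library defaults):

import numpy as np

scores = np.array([0.92, 0.81, 0.34, 0.12, 0.55])    # hypothetical per-detection confidences
track_high_thresh, track_low_thresh = 0.5, 0.1        # illustrative thresholds

first_stage = scores >= track_high_thresh                                   # used in the first association
second_stage = (scores > track_low_thresh) & (scores < track_high_thresh)   # recovered in the second pass

print(np.flatnonzero(first_stage))    # [0 1 4]
print(np.flatnonzero(second_stage))   # [2 3]

In practice the args namespace passed to the constructor needs at least track_buffer, track_high_thresh, track_low_thresh, new_track_thresh, match_thresh and fuse_score, since those are the only fields the class reads.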
{
"content": "from functools import partial",
"chunk_type": "import",
"name": "partial",
"file_path": "ultralytics\\ultralytics\\trackers\\track.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_partial_e3539b4b"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\trackers\\track.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_1d7eecc6"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\trackers\\track.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_65d492ab"
},
{
"content": "from ultralytics.utils import YAML, IterableSimpleNamespace",
"chunk_type": "import",
"name": "YAML, IterableSimpleNamespace",
"file_path": "ultralytics\\ultralytics\\trackers\\track.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 59,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YAML, IterableSimpleNamespace_cd5d2008"
},
{
"content": "from ultralytics.utils.checks import check_yaml",
"chunk_type": "import",
"name": "check_yaml",
"file_path": "ultralytics\\ultralytics\\trackers\\track.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_check_yaml_b0cae41a"
},
{
"content": "from .bot_sort import BOTSORT",
"chunk_type": "import",
"name": "BOTSORT",
"file_path": "ultralytics\\ultralytics\\trackers\\track.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BOTSORT_62cff40a"
},
{
"content": "from .byte_tracker import BYTETracker",
"chunk_type": "import",
"name": "BYTETracker",
"file_path": "ultralytics\\ultralytics\\trackers\\track.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 37,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BYTETracker_11b96f56"
},
{
"content": "TRACKER_MAP = {\"bytetrack\": BYTETracker, \"botsort\": BOTSORT}",
"chunk_type": "variable",
"name": "TRACKER_MAP",
"file_path": "ultralytics\\ultralytics\\trackers\\track.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 60,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_TRACKER_MAP_31ce8c23"
},
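TRACKER_MAP is a plain name-to-class registry, so constructing a tracker elsewhere reduces to a dictionary lookup keyed by the config's tracker_type. A hedged sketch of that lookup with a hand-built stand-in config (the real cfg is loaded from the tracker YAML and may carry more fields than shown here):

from types import SimpleNamespace

from ultralytics.trackers.track import TRACKER_MAP

# Stand-in for the YAML-loaded config; field values are illustrative only
cfg = SimpleNamespace(
    tracker_type="bytetrack",
    track_buffer=30,
    track_high_thresh=0.5,
    track_low_thresh=0.1,
    new_track_thresh=0.6,
    match_thresh=0.8,
    fuse_score=True,
)

tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)   # -> BYTETracker instance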
{
"content": "def on_predict_start(predictor: object, persist: bool = False) -> None:\n \"\"\"\n Initialize trackers for object tracking during prediction.\n\n Args:\n predictor (ultralytics.engine.predictor.BasePredictor): The predictor object to initialize trackers for.\n persist (bool, optional): Whether to persist the trackers if they already exist.\n\n Examples:\n Initialize trackers for a predictor object\n >>> predictor = SomePredictorClass()\n >>> on_predict_start(predictor, persist=True)\n \"\"\"\n if predictor.args.task == \"classify\":\n raise ValueError(\"❌ Classification doesn't support 'mode=track'\")\n\n if hasattr(predictor, \"trackers\") and persist:\n return\n\n tracker = check_yaml(predictor.args.tracker)\n cfg = IterableSimpleNamespace(**YAML.load(tracker))\n\n if cfg.tracker_type not in {\"bytetrack\", \"botsort\"}:\n raise AssertionError(f\"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'\")\n\n predictor._feats = None # reset in case used earlier\n if hasattr(predictor, \"_hook\"):\n predictor._hook.remove()\n if cfg.tracker_type == \"botsort\" and cfg.with_reid and cfg.model == \"auto\":\n from ultralytics.nn.modules.head import Detect\n\n if not (\n isinstance(predictor.model.model, torch.nn.Module)\n and isinstance(predictor.model.model.model[-1], Detect)\n and not predictor.model.model.model[-1].end2end\n ):\n cfg.model = \"yolo11n-cls.pt\"\n else:\n # Register hook to extract input of Detect layer\n def pre_hook(module, input):\n predictor._feats = list(input[0]) # unroll to new list to avoid mutation in forward\n\n predictor._hook = predictor.model.model.model[-1].register_forward_pre_hook(pre_hook)\n\n trackers = []\n for _ in range(predictor.dataset.bs):\n tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)\n trackers.append(tracker)\n if predictor.dataset.mode != \"stream\": # only need one tracker for other modes\n break\n predictor.trackers = trackers\n predictor.vid_path = [None] * predictor.dataset.bs # for determining when to reset tracker on new video",
"chunk_type": "function",
"name": "on_predict_start",
"file_path": "ultralytics\\ultralytics\\trackers\\track.py",
"start_line": 18,
"end_line": 69,
"start_col": 0,
"end_col": 54,
"parent_name": null,
"docstring": "Initialize trackers for object tracking during prediction.\n\nArgs:\n predictor (ultralytics.engine.predictor.BasePredictor): The predictor object to initialize trackers for.\n persist (bool, optional): Whether to persist the trackers if they already exist.\n\nExamples:\n Initialize trackers for a predictor object\n >>> predictor = SomePredictorClass()\n >>> on_predict_start(predictor, persist=True)",
"parameters": [
"predictor: object",
"persist: bool"
],
"return_type": "None",
"decorators": [],
"complexity_score": 9,
"dependencies": [
"functools.partial",
"pathlib.Path",
"torch",
"ultralytics.utils.YAML",
"ultralytics.utils.IterableSimpleNamespace",
"ultralytics.utils.checks.check_yaml",
"bot_sort.BOTSORT",
"byte_tracker.BYTETracker",
"ultralytics.nn.modules.head.Detect"
],
"chunk_id": "function_on_predict_start_b9207017"
},
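The ReID branch of on_predict_start captures Detect-layer inputs by registering a forward pre-hook on the last module of the model. A generic sketch of that PyTorch hook pattern on a toy module (the real hook targets model.model.model[-1] and stores the features on the predictor):

import torch

head = torch.nn.Linear(8, 4)        # toy stand-in for the Detect head
captured = {}

def pre_hook(module, inputs):
    # inputs is the tuple of positional arguments passed to forward(); copy them before forward runs
    captured["feats"] = [t.detach().clone() for t in inputs]

handle = head.register_forward_pre_hook(pre_hook)
head(torch.randn(2, 8))             # hook fires first, then forward executes
handle.remove()                     # counterpart of predictor._hook.remove()
print(captured["feats"][0].shape)   # torch.Size([2, 8])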
{
"content": "def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:\n \"\"\"\n Postprocess detected boxes and update with object tracking.\n\n Args:\n predictor (object): The predictor object containing the predictions.\n persist (bool, optional): Whether to persist the trackers if they already exist.\n\n Examples:\n Postprocess predictions and update with tracking\n >>> predictor = YourPredictorClass()\n >>> on_predict_postprocess_end(predictor, persist=True)\n \"\"\"\n is_obb = predictor.args.task == \"obb\"\n is_stream = predictor.dataset.mode == \"stream\"\n for i, result in enumerate(predictor.results):\n tracker = predictor.trackers[i if is_stream else 0]\n vid_path = predictor.save_dir / Path(result.path).name\n if not persist and predictor.vid_path[i if is_stream else 0] != vid_path:\n tracker.reset()\n predictor.vid_path[i if is_stream else 0] = vid_path\n\n det = (result.obb if is_obb else result.boxes).cpu().numpy()\n tracks = tracker.update(det, result.orig_img, getattr(result, \"feats\", None))\n if len(tracks) == 0:\n continue\n idx = tracks[:, -1].astype(int)\n predictor.results[i] = result[idx]\n\n update_args = {\"obb\" if is_obb else \"boxes\": torch.as_tensor(tracks[:, :-1])}\n predictor.results[i].update(**update_args)",
"chunk_type": "function",
"name": "on_predict_postprocess_end",
"file_path": "ultralytics\\ultralytics\\trackers\\track.py",
"start_line": 72,
"end_line": 102,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": "Postprocess detected boxes and update with object tracking.\n\nArgs:\n predictor (object): The predictor object containing the predictions.\n persist (bool, optional): Whether to persist the trackers if they already exist.\n\nExamples:\n Postprocess predictions and update with tracking\n >>> predictor = YourPredictorClass()\n >>> on_predict_postprocess_end(predictor, persist=True)",
"parameters": [
"predictor: object",
"persist: bool"
],
"return_type": "None",
"decorators": [],
"complexity_score": 4,
"dependencies": [
"functools.partial",
"pathlib.Path",
"torch",
"ultralytics.utils.YAML",
"ultralytics.utils.IterableSimpleNamespace",
"ultralytics.utils.checks.check_yaml",
"bot_sort.BOTSORT",
"byte_tracker.BYTETracker",
"ultralytics.nn.modules.head.Detect"
],
"chunk_id": "function_on_predict_postprocess_end_6cfbae9e"
},
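The post-processing step relies on the convention that each track row returned by update() ends with the index of the detection it was matched to (see STrack.result). A small NumPy sketch of how that last column filters and reorders the results (values are synthetic):

import numpy as np

# row layout: [x1, y1, x2, y2, track_id, score, cls, det_idx]
tracks = np.array([
    [10, 10, 50, 60, 1, 0.90, 0, 2],
    [80, 40, 120, 90, 2, 0.80, 0, 0],
], dtype=np.float32)

idx = tracks[:, -1].astype(int)    # detection indices kept by the tracker -> result[idx]
boxes_with_ids = tracks[:, :-1]    # everything but det_idx is written back via results.update()
print(idx)                         # [2 0]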
{
"content": "def register_tracker(model: object, persist: bool) -> None:\n \"\"\"\n Register tracking callbacks to the model for object tracking during prediction.\n\n Args:\n model (object): The model object to register tracking callbacks for.\n persist (bool): Whether to persist the trackers if they already exist.\n\n Examples:\n Register tracking callbacks to a YOLO model\n >>> model = YOLOModel()\n >>> register_tracker(model, persist=True)\n \"\"\"\n model.add_callback(\"on_predict_start\", partial(on_predict_start, persist=persist))\n model.add_callback(\"on_predict_postprocess_end\", partial(on_predict_postprocess_end, persist=persist))",
"chunk_type": "function",
"name": "register_tracker",
"file_path": "ultralytics\\ultralytics\\trackers\\track.py",
"start_line": 105,
"end_line": 119,
"start_col": 0,
"end_col": 106,
"parent_name": null,
"docstring": "Register tracking callbacks to the model for object tracking during prediction.\n\nArgs:\n model (object): The model object to register tracking callbacks for.\n persist (bool): Whether to persist the trackers if they already exist.\n\nExamples:\n Register tracking callbacks to a YOLO model\n >>> model = YOLOModel()\n >>> register_tracker(model, persist=True)",
"parameters": [
"model: object",
"persist: bool"
],
"return_type": "None",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools.partial",
"pathlib.Path",
"torch",
"ultralytics.utils.YAML",
"ultralytics.utils.IterableSimpleNamespace",
"ultralytics.utils.checks.check_yaml",
"bot_sort.BOTSORT",
"byte_tracker.BYTETracker",
"ultralytics.nn.modules.head.Detect"
],
"chunk_id": "function_register_tracker_a54d453d"
},
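register_tracker uses functools.partial to bake the persist flag into callbacks that only receive the predictor at call time. A self-contained sketch of that pattern with a minimal stand-in for the callback registry (TinyModel is hypothetical; the real add_callback lives on the YOLO model):

from functools import partial

class TinyModel:
    # Minimal stand-in exposing the add_callback interface used above
    def __init__(self):
        self.callbacks = {}

    def add_callback(self, event, func):
        self.callbacks.setdefault(event, []).append(func)

    def fire(self, event, predictor):
        for func in self.callbacks.get(event, []):
            func(predictor)

def on_predict_start_demo(predictor, persist=False):
    print(f"init trackers, persist={persist}")

model = TinyModel()
model.add_callback("on_predict_start", partial(on_predict_start_demo, persist=True))
model.fire("on_predict_start", predictor=object())   # -> init trackers, persist=True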
{
"content": "from .bot_sort import BOTSORT",
"chunk_type": "import",
"name": "BOTSORT",
"file_path": "ultralytics\\ultralytics\\trackers\\__init__.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BOTSORT_cca8eb05"
},
{
"content": "from .byte_tracker import BYTETracker",
"chunk_type": "import",
"name": "BYTETracker",
"file_path": "ultralytics\\ultralytics\\trackers\\__init__.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 37,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BYTETracker_15e97245"
},
{
"content": "from .track import register_tracker",
"chunk_type": "import",
"name": "register_tracker",
"file_path": "ultralytics\\ultralytics\\trackers\\__init__.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_register_tracker_4ebabca9"
},
{
"content": "__all__ = \"register_tracker\", \"BOTSORT\", \"BYTETracker\" # allow simpler import",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\trackers\\__init__.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 54,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___6e5980f6"
},
{
"content": "import os",
"chunk_type": "import",
"name": "os",
"file_path": "ultralytics\\ultralytics\\utils\\autobatch.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_os_79cd52cc"
},
{
"content": "from copy import deepcopy",
"chunk_type": "import",
"name": "deepcopy",
"file_path": "ultralytics\\ultralytics\\utils\\autobatch.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_deepcopy_a261682b"
},
{
"content": "from typing import Union",
"chunk_type": "import",
"name": "Union",
"file_path": "ultralytics\\ultralytics\\utils\\autobatch.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Union_a82fbc39"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\utils\\autobatch.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_17b39cfb"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\utils\\autobatch.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_cc39b187"
},
{
"content": "from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr",
"chunk_type": "import",
"name": "DEFAULT_CFG, LOGGER, colorstr",
"file_path": "ultralytics\\ultralytics\\utils\\autobatch.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 59,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DEFAULT_CFG, LOGGER, colorstr_765852ba"
},
{
"content": "from ultralytics.utils.torch_utils import autocast, profile_ops",
"chunk_type": "import",
"name": "autocast, profile_ops",
"file_path": "ultralytics\\ultralytics\\utils\\autobatch.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 63,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_autocast, profile_ops_8fe6c7a2"
},
{
"content": "def check_train_batch_size(\n model: torch.nn.Module,\n imgsz: int = 640,\n amp: bool = True,\n batch: Union[int, float] = -1,\n max_num_obj: int = 1,\n) -> int:\n \"\"\"\n Compute optimal YOLO training batch size using the autobatch() function.\n\n Args:\n model (torch.nn.Module): YOLO model to check batch size for.\n imgsz (int, optional): Image size used for training.\n amp (bool, optional): Use automatic mixed precision if True.\n batch (int | float, optional): Fraction of GPU memory to use. If -1, use default.\n max_num_obj (int, optional): The maximum number of objects from dataset.\n\n Returns:\n (int): Optimal batch size computed using the autobatch() function.\n\n Notes:\n If 0.0 < batch < 1.0, it's used as the fraction of GPU memory to use.\n Otherwise, a default fraction of 0.6 is used.\n \"\"\"\n with autocast(enabled=amp):\n return autobatch(\n deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6, max_num_obj=max_num_obj\n )",
"chunk_type": "function",
"name": "check_train_batch_size",
"file_path": "ultralytics\\ultralytics\\utils\\autobatch.py",
"start_line": 15,
"end_line": 42,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "Compute optimal YOLO training batch size using the autobatch() function.\n\nArgs:\n model (torch.nn.Module): YOLO model to check batch size for.\n imgsz (int, optional): Image size used for training.\n amp (bool, optional): Use automatic mixed precision if True.\n batch (int | float, optional): Fraction of GPU memory to use. If -1, use default.\n max_num_obj (int, optional): The maximum number of objects from dataset.\n\nReturns:\n (int): Optimal batch size computed using the autobatch() function.\n\nNotes:\n If 0.0 < batch < 1.0, it's used as the fraction of GPU memory to use.\n Otherwise, a default fraction of 0.6 is used.",
"parameters": [
"model: torch.nn.Module",
"imgsz: int",
"amp: bool",
"batch: Union[int, float]",
"max_num_obj: int"
],
"return_type": "int",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"os",
"copy.deepcopy",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.LOGGER",
"ultralytics.utils.colorstr",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.profile_ops"
],
"chunk_id": "function_check_train_batch_size_ae6db0ae"
},
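Besides delegating to autobatch(), the only decision made here is how the batch argument becomes a memory fraction: values strictly between 0 and 1 are taken literally, everything else falls back to 0.6. A tiny sketch of that selection rule:

def resolve_fraction(batch):
    # Mirrors the expression above: fractional batch -> memory fraction, else the 0.60 default
    return batch if 0.0 < batch < 1.0 else 0.60

print(resolve_fraction(-1))     # 0.6  (auto mode)
print(resolve_fraction(0.75))   # 0.75 (use ~75% of CUDA memory)
print(resolve_fraction(16))     # 0.6  (non-fractional values fall back to the default)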
{
"content": "def autobatch(\n model: torch.nn.Module,\n imgsz: int = 640,\n fraction: float = 0.60,\n batch_size: int = DEFAULT_CFG.batch,\n max_num_obj: int = 1,\n) -> int:\n \"\"\"\n Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.\n\n Args:\n model (torch.nn.Module): YOLO model to compute batch size for.\n imgsz (int, optional): The image size used as input for the YOLO model.\n fraction (float, optional): The fraction of available CUDA memory to use.\n batch_size (int, optional): The default batch size to use if an error is detected.\n max_num_obj (int, optional): The maximum number of objects from dataset.\n\n Returns:\n (int): The optimal batch size.\n \"\"\"\n # Check device\n prefix = colorstr(\"AutoBatch: \")\n LOGGER.info(f\"{prefix}Computing optimal batch size for imgsz={imgsz} at {fraction * 100}% CUDA memory utilization.\")\n device = next(model.parameters()).device # get model device\n if device.type in {\"cpu\", \"mps\"}:\n LOGGER.warning(f\"{prefix}intended for CUDA devices, using default batch-size {batch_size}\")\n return batch_size\n if torch.backends.cudnn.benchmark:\n LOGGER.warning(f\"{prefix}Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}\")\n return batch_size\n\n # Inspect CUDA memory\n gb = 1 << 30 # bytes to GiB (1024 ** 3)\n d = f\"CUDA:{os.getenv('CUDA_VISIBLE_DEVICES', '0').strip()[0]}\" # 'CUDA:0'\n properties = torch.cuda.get_device_properties(device) # device properties\n t = properties.total_memory / gb # GiB total\n r = torch.cuda.memory_reserved(device) / gb # GiB reserved\n a = torch.cuda.memory_allocated(device) / gb # GiB allocated\n f = t - (r + a) # GiB free\n LOGGER.info(f\"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free\")\n\n # Profile batch sizes\n batch_sizes = [1, 2, 4, 8, 16] if t < 16 else [1, 2, 4, 8, 16, 32, 64]\n try:\n img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]\n results = profile_ops(img, model, n=1, device=device, max_num_obj=max_num_obj)\n\n # Fit a solution\n xy = [\n [x, y[2]]\n for i, (x, y) in enumerate(zip(batch_sizes, results))\n if y # valid result\n and isinstance(y[2], (int, float)) # is numeric\n and 0 < y[2] < t # between 0 and GPU limit\n and (i == 0 or not results[i - 1] or y[2] > results[i - 1][2]) # first item or increasing memory\n ]\n fit_x, fit_y = zip(*xy) if xy else ([], [])\n p = np.polyfit(fit_x, fit_y, deg=1) # first-degree polynomial fit in log space\n b = int((round(f * fraction) - p[1]) / p[0]) # y intercept (optimal batch size)\n if None in results: # some sizes failed\n i = results.index(None) # first fail index\n if b >= batch_sizes[i]: # y intercept above failure point\n b = batch_sizes[max(i - 1, 0)] # select prior safe point\n if b < 1 or b > 1024: # b outside of safe range\n LOGGER.warning(f\"{prefix}batch={b} outside safe range, using default batch-size {batch_size}.\")\n b = batch_size\n\n fraction = (np.polyval(p, b) + r + a) / t # predicted fraction\n LOGGER.info(f\"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅\")\n return b\n except Exception as e:\n LOGGER.warning(f\"{prefix}error detected: {e}, using default batch-size {batch_size}.\")\n return batch_size\n finally:\n torch.cuda.empty_cache()",
"chunk_type": "function",
"name": "autobatch",
"file_path": "ultralytics\\ultralytics\\utils\\autobatch.py",
"start_line": 45,
"end_line": 119,
"start_col": 0,
"end_col": 32,
"parent_name": null,
"docstring": "Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.\n\nArgs:\n model (torch.nn.Module): YOLO model to compute batch size for.\n imgsz (int, optional): The image size used as input for the YOLO model.\n fraction (float, optional): The fraction of available CUDA memory to use.\n batch_size (int, optional): The default batch size to use if an error is detected.\n max_num_obj (int, optional): The maximum number of objects from dataset.\n\nReturns:\n (int): The optimal batch size.",
"parameters": [
"model: torch.nn.Module",
"imgsz: int",
"fraction: float",
"batch_size: int",
"max_num_obj: int"
],
"return_type": "int",
"decorators": [],
"complexity_score": 9,
"dependencies": [
"os",
"copy.deepcopy",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.LOGGER",
"ultralytics.utils.colorstr",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.profile_ops"
],
"chunk_id": "function_autobatch_582df94b"
},
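The sizing logic profiles a handful of batch sizes, fits memory-vs-batch with a first-degree polynomial, and solves for the batch whose predicted footprint hits the requested fraction of free memory. A self-contained numeric sketch of that fit-and-solve step with synthetic measurements (no GPU required):

import numpy as np

batch_sizes = [1, 2, 4, 8, 16]
mem_gib = [0.9, 1.5, 2.7, 5.1, 9.9]      # synthetic profiled memory per batch size (GiB)
free_gib, fraction = 14.0, 0.60          # synthetic free CUDA memory and target utilization

p = np.polyfit(batch_sizes, mem_gib, deg=1)           # mem ~= p[0] * batch + p[1]
b = int((round(free_gib * fraction) - p[1]) / p[0])   # batch size that lands on the target
print(p.round(2), b)                                  # [0.6 0.3] 12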
{
"content": "from typing import Any, Dict, List, Optional",
"chunk_type": "import",
"name": "Any, Dict, List, Optional",
"file_path": "ultralytics\\ultralytics\\utils\\autodevice.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, List, Optional_805eb627"
},
{
"content": "from ultralytics.utils import LOGGER",
"chunk_type": "import",
"name": "LOGGER",
"file_path": "ultralytics\\ultralytics\\utils\\autodevice.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 36,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER_6ae07218"
},
{
"content": "from ultralytics.utils.checks import check_requirements",
"chunk_type": "import",
"name": "check_requirements",
"file_path": "ultralytics\\ultralytics\\utils\\autodevice.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 55,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_check_requirements_24e0b7a5"
},
{
"content": "class GPUInfo:\n \"\"\"\n Manages NVIDIA GPU information via pynvml with robust error handling.\n\n Provides methods to query detailed GPU statistics (utilization, memory, temp, power) and select the most idle\n GPUs based on configurable criteria. It safely handles the absence or initialization failure of the pynvml\n library by logging warnings and disabling related features, preventing application crashes.\n\n Includes fallback logic using `torch.cuda` for basic device counting if NVML is unavailable during GPU\n selection. Manages NVML initialization and shutdown internally.\n\n Attributes:\n pynvml (module | None): The `pynvml` module if successfully imported and initialized, otherwise `None`.\n nvml_available (bool): Indicates if `pynvml` is ready for use. True if import and `nvmlInit()` succeeded,\n False otherwise.\n gpu_stats (List[Dict[str, Any]]): A list of dictionaries, each holding stats for one GPU. Populated on\n initialization and by `refresh_stats()`. Keys include: 'index', 'name', 'utilization' (%),\n 'memory_used' (MiB), 'memory_total' (MiB), 'memory_free' (MiB), 'temperature' (C), 'power_draw' (W),\n 'power_limit' (W or 'N/A'). Empty if NVML is unavailable or queries fail.\n\n Methods:\n refresh_stats: Refresh the internal gpu_stats list by querying NVML.\n print_status: Print GPU status in a compact table format using current stats.\n select_idle_gpu: Select the most idle GPUs based on utilization and free memory.\n shutdown: Shut down NVML if it was initialized.\n\n Examples:\n Initialize GPUInfo and print status\n >>> gpu_info = GPUInfo()\n >>> gpu_info.print_status()\n\n Select idle GPUs with minimum memory requirements\n >>> selected = gpu_info.select_idle_gpu(count=2, min_memory_fraction=0.2)\n >>> print(f\"Selected GPU indices: {selected}\")\n \"\"\"\n\n def __init__(self):\n \"\"\"Initialize GPUInfo, attempting to import and initialize pynvml.\"\"\"\n self.pynvml: Optional[Any] = None\n self.nvml_available: bool = False\n self.gpu_stats: List[Dict[str, Any]] = []\n\n try:\n check_requirements(\"pynvml>=12.0.0\")\n self.pynvml = __import__(\"pynvml\")\n self.pynvml.nvmlInit()\n self.nvml_available = True\n self.refresh_stats()\n except Exception as e:\n LOGGER.warning(f\"Failed to initialize pynvml, GPU stats disabled: {e}\")\n\n def __del__(self):\n \"\"\"Ensure NVML is shut down when the object is garbage collected.\"\"\"\n self.shutdown()\n\n def shutdown(self):\n \"\"\"Shut down NVML if it was initialized.\"\"\"\n if self.nvml_available and self.pynvml:\n try:\n self.pynvml.nvmlShutdown()\n except Exception:\n pass\n self.nvml_available = False\n\n def refresh_stats(self):\n \"\"\"Refresh the internal gpu_stats list by querying NVML.\"\"\"\n self.gpu_stats = []\n if not self.nvml_available or not self.pynvml:\n return\n\n try:\n device_count = self.pynvml.nvmlDeviceGetCount()\n for i in range(device_count):\n self.gpu_stats.append(self._get_device_stats(i))\n except Exception as e:\n LOGGER.warning(f\"Error during device query: {e}\")\n self.gpu_stats = []\n\n def _get_device_stats(self, index: int) -> Dict[str, Any]:\n \"\"\"Get stats for a single GPU device.\"\"\"\n handle = self.pynvml.nvmlDeviceGetHandleByIndex(index)\n memory = self.pynvml.nvmlDeviceGetMemoryInfo(handle)\n util = self.pynvml.nvmlDeviceGetUtilizationRates(handle)\n\n def safe_get(func, *args, default=-1, divisor=1):\n try:\n val = func(*args)\n return val // divisor if divisor != 1 and isinstance(val, (int, float)) else val\n except Exception:\n return default\n\n temp_type = 
getattr(self.pynvml, \"NVML_TEMPERATURE_GPU\", -1)\n\n return {\n \"index\": index,\n \"name\": self.pynvml.nvmlDeviceGetName(handle),\n \"utilization\": util.gpu if util else -1,\n \"memory_used\": memory.used >> 20 if memory else -1, # Convert bytes to MiB\n \"memory_total\": memory.total >> 20 if memory else -1,\n \"memory_free\": memory.free >> 20 if memory else -1,\n \"temperature\": safe_get(self.pynvml.nvmlDeviceGetTemperature, handle, temp_type),\n \"power_draw\": safe_get(self.pynvml.nvmlDeviceGetPowerUsage, handle, divisor=1000), # Convert mW to W\n \"power_limit\": safe_get(self.pynvml.nvmlDeviceGetEnforcedPowerLimit, handle, divisor=1000),\n }\n\n def print_status(self):\n \"\"\"Print GPU status in a compact table format using current stats.\"\"\"\n self.refresh_stats()\n if not self.gpu_stats:\n LOGGER.warning(\"No GPU stats available.\")\n return\n\n stats = self.gpu_stats\n name_len = max(len(gpu.get(\"name\", \"N/A\")) for gpu in stats)\n hdr = f\"{'Idx':<3} {'Name':<{name_len}} {'Util':>6} {'Mem (MiB)':>15} {'Temp':>5} {'Pwr (W)':>10}\"\n LOGGER.info(f\"\\n--- GPU Status ---\\n{hdr}\\n{'-' * len(hdr)}\")\n\n for gpu in stats:\n u = f\"{gpu['utilization']:>5}%\" if gpu[\"utilization\"] >= 0 else \" N/A \"\n m = f\"{gpu['memory_used']:>6}/{gpu['memory_total']:<6}\" if gpu[\"memory_used\"] >= 0 else \" N/A / N/A \"\n t = f\"{gpu['temperature']}C\" if gpu[\"temperature\"] >= 0 else \" N/A \"\n p = f\"{gpu['power_draw']:>3}/{gpu['power_limit']:<3}\" if gpu[\"power_draw\"] >= 0 else \" N/A \"\n\n LOGGER.info(f\"{gpu.get('index'):<3d} {gpu.get('name', 'N/A'):<{name_len}} {u:>6} {m:>15} {t:>5} {p:>10}\")\n\n LOGGER.info(f\"{'-' * len(hdr)}\\n\")\n\n def select_idle_gpu(\n self, count: int = 1, min_memory_fraction: float = 0, min_util_fraction: float = 0\n ) -> List[int]:\n \"\"\"\n Select the most idle GPUs based on utilization and free memory.\n\n Args:\n count (int): The number of idle GPUs to select.\n min_memory_fraction (float): Minimum free memory required as a fraction of total memory.\n min_util_fraction (float): Minimum free utilization rate required from 0.0 - 1.0.\n\n Returns:\n (List[int]): Indices of the selected GPUs, sorted by idleness (lowest utilization first).\n\n Notes:\n Returns fewer than 'count' if not enough qualify or exist.\n Returns basic CUDA indices if NVML fails. 
Empty list if no GPUs found.\n \"\"\"\n assert min_memory_fraction <= 1.0, f\"min_memory_fraction must be <= 1.0, got {min_memory_fraction}\"\n assert min_util_fraction <= 1.0, f\"min_util_fraction must be <= 1.0, got {min_util_fraction}\"\n LOGGER.info(\n f\"Searching for {count} idle GPUs with free memory >= {min_memory_fraction * 100:.1f}% and free utilization >= {min_util_fraction * 100:.1f}%...\"\n )\n\n if count <= 0:\n return []\n\n self.refresh_stats()\n if not self.gpu_stats:\n LOGGER.warning(\"NVML stats unavailable.\")\n return []\n\n # Filter and sort eligible GPUs\n eligible_gpus = [\n gpu\n for gpu in self.gpu_stats\n if gpu.get(\"memory_free\", 0) / gpu.get(\"memory_total\", 1) >= min_memory_fraction\n and (100 - gpu.get(\"utilization\", 100)) >= min_util_fraction * 100\n ]\n eligible_gpus.sort(key=lambda x: (x.get(\"utilization\", 101), -x.get(\"memory_free\", 0)))\n\n # Select top 'count' indices\n selected = [gpu[\"index\"] for gpu in eligible_gpus[:count]]\n\n if selected:\n LOGGER.info(f\"Selected idle CUDA devices {selected}\")\n else:\n LOGGER.warning(\n f\"No GPUs met criteria (Free Mem >= {min_memory_fraction * 100:.1f}% and Free Util >= {min_util_fraction * 100:.1f}%).\"\n )\n\n return selected",
"chunk_type": "class",
"name": "GPUInfo",
"file_path": "ultralytics\\ultralytics\\utils\\autodevice.py",
"start_line": 9,
"end_line": 187,
"start_col": 0,
"end_col": 23,
"parent_name": null,
"docstring": "Manages NVIDIA GPU information via pynvml with robust error handling.\n\nProvides methods to query detailed GPU statistics (utilization, memory, temp, power) and select the most idle\nGPUs based on configurable criteria. It safely handles the absence or initialization failure of the pynvml\nlibrary by logging warnings and disabling related features, preventing application crashes.\n\nIncludes fallback logic using `torch.cuda` for basic device counting if NVML is unavailable during GPU\nselection. Manages NVML initialization and shutdown internally.\n\nAttributes:\n pynvml (module | None): The `pynvml` module if successfully imported and initialized, otherwise `None`.\n nvml_available (bool): Indicates if `pynvml` is ready for use. True if import and `nvmlInit()` succeeded,\n False otherwise.\n gpu_stats (List[Dict[str, Any]]): A list of dictionaries, each holding stats for one GPU. Populated on\n initialization and by `refresh_stats()`. Keys include: 'index', 'name', 'utilization' (%),\n 'memory_used' (MiB), 'memory_total' (MiB), 'memory_free' (MiB), 'temperature' (C), 'power_draw' (W),\n 'power_limit' (W or 'N/A'). Empty if NVML is unavailable or queries fail.\n\nMethods:\n refresh_stats: Refresh the internal gpu_stats list by querying NVML.\n print_status: Print GPU status in a compact table format using current stats.\n select_idle_gpu: Select the most idle GPUs based on utilization and free memory.\n shutdown: Shut down NVML if it was initialized.\n\nExamples:\n Initialize GPUInfo and print status\n >>> gpu_info = GPUInfo()\n >>> gpu_info.print_status()\n\n Select idle GPUs with minimum memory requirements\n >>> selected = gpu_info.select_idle_gpu(count=2, min_memory_fraction=0.2)\n >>> print(f\"Selected GPU indices: {selected}\")",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.checks.check_requirements"
],
"chunk_id": "class_GPUInfo_e0b9c1fd"
},
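A short usage sketch of GPUInfo, assuming the import path follows the file_path recorded above; all calls degrade gracefully (warnings and empty results) when pynvml or an NVIDIA GPU is absent:

from ultralytics.utils.autodevice import GPUInfo

info = GPUInfo()                  # tries to import and initialize pynvml, warns on failure
info.print_status()               # per-GPU utilization, memory, temperature and power table

# Up to two GPUs with >= 20% free memory and >= 30% idle utilization headroom
idle = info.select_idle_gpu(count=2, min_memory_fraction=0.2, min_util_fraction=0.3)
print(idle)                       # e.g. [1, 0]; empty list if nothing qualifies
info.shutdown()                   # explicit NVML shutdown (also attempted in __del__)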
{
"content": "import glob",
"chunk_type": "import",
"name": "glob",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 30,
"end_line": 30,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_glob_daf13c20"
},
{
"content": "import os",
"chunk_type": "import",
"name": "os",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 31,
"end_line": 31,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_os_f2fede17"
},
{
"content": "import platform",
"chunk_type": "import",
"name": "platform",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 32,
"end_line": 32,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_platform_a51b6978"
},
{
"content": "import re",
"chunk_type": "import",
"name": "re",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 33,
"end_line": 33,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_re_7f62a565"
},
{
"content": "import shutil",
"chunk_type": "import",
"name": "shutil",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 34,
"end_line": 34,
"start_col": 0,
"end_col": 13,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_shutil_0ce18268"
},
{
"content": "import time",
"chunk_type": "import",
"name": "time",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 35,
"end_line": 35,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_time_7ca1fa09"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 36,
"end_line": 36,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_404f7d8b"
},
{
"content": "from typing import List, Optional, Tuple, Union",
"chunk_type": "import",
"name": "List, Optional, Tuple, Union",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 37,
"end_line": 37,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_List, Optional, Tuple, Union_379ccf47"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 39,
"end_line": 39,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_0143459b"
},
{
"content": "import torch.cuda",
"chunk_type": "import",
"name": "torch.cuda",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 40,
"end_line": 40,
"start_col": 0,
"end_col": 17,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.cuda_8e6ff85d"
},
{
"content": "from ultralytics import YOLO, YOLOWorld",
"chunk_type": "import",
"name": "YOLO, YOLOWorld",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 42,
"end_line": 42,
"start_col": 0,
"end_col": 39,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLO, YOLOWorld_48830e91"
},
{
"content": "from ultralytics.cfg import TASK2DATA, TASK2METRIC",
"chunk_type": "import",
"name": "TASK2DATA, TASK2METRIC",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 43,
"end_line": 43,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_TASK2DATA, TASK2METRIC_5a2739ff"
},
{
"content": "from ultralytics.engine.exporter import export_formats",
"chunk_type": "import",
"name": "export_formats",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 44,
"end_line": 44,
"start_col": 0,
"end_col": 54,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_export_formats_39ba1379"
},
{
"content": "from ultralytics.utils import ARM64, ASSETS, IS_JETSON, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR, YAML",
"chunk_type": "import",
"name": "ARM64, ASSETS, IS_JETSON, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR, YAML",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 45,
"end_line": 45,
"start_col": 0,
"end_col": 101,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ARM64, ASSETS, IS_JETSON, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR, YAML_36d5d8bd"
},
{
"content": "from ultralytics.utils.checks import IS_PYTHON_3_13, check_imgsz, check_requirements, check_yolo, is_rockchip",
"chunk_type": "import",
"name": "IS_PYTHON_3_13, check_imgsz, check_requirements, check_yolo, is_rockchip",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 46,
"end_line": 46,
"start_col": 0,
"end_col": 109,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_IS_PYTHON_3_13, check_imgsz, check_requirements, check_yolo, is_rockchip_2ed1aae2"
},
{
"content": "from ultralytics.utils.downloads import safe_download",
"chunk_type": "import",
"name": "safe_download",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 47,
"end_line": 47,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_safe_download_15b261a3"
},
{
"content": "from ultralytics.utils.files import file_size",
"chunk_type": "import",
"name": "file_size",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 48,
"end_line": 48,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_file_size_aae2992d"
},
{
"content": "from ultralytics.utils.torch_utils import get_cpu_info, select_device",
"chunk_type": "import",
"name": "get_cpu_info, select_device",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 49,
"end_line": 49,
"start_col": 0,
"end_col": 69,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_get_cpu_info, select_device_e4e5353e"
},
{
"content": "def benchmark(\n model=WEIGHTS_DIR / \"yolo11n.pt\",\n data=None,\n imgsz=160,\n half=False,\n int8=False,\n device=\"cpu\",\n verbose=False,\n eps=1e-3,\n format=\"\",\n **kwargs,\n):\n \"\"\"\n Benchmark a YOLO model across different formats for speed and accuracy.\n\n Args:\n model (str | Path): Path to the model file or directory.\n data (str | None): Dataset to evaluate on, inherited from TASK2DATA if not passed.\n imgsz (int): Image size for the benchmark.\n half (bool): Use half-precision for the model if True.\n int8 (bool): Use int8-precision for the model if True.\n device (str): Device to run the benchmark on, either 'cpu' or 'cuda'.\n verbose (bool | float): If True or a float, assert benchmarks pass with given metric.\n eps (float): Epsilon value for divide by zero prevention.\n format (str): Export format for benchmarking. If not supplied all formats are benchmarked.\n **kwargs (Any): Additional keyword arguments for exporter.\n\n Returns:\n (pandas.DataFrame): A pandas DataFrame with benchmark results for each format, including file size, metric,\n and inference time.\n\n Examples:\n Benchmark a YOLO model with default settings:\n >>> from ultralytics.utils.benchmarks import benchmark\n >>> benchmark(model=\"yolo11n.pt\", imgsz=640)\n \"\"\"\n imgsz = check_imgsz(imgsz)\n assert imgsz[0] == imgsz[1] if isinstance(imgsz, list) else True, \"benchmark() only supports square imgsz.\"\n\n import pandas as pd # scope for faster 'import ultralytics'\n\n pd.options.display.max_columns = 10\n pd.options.display.width = 120\n device = select_device(device, verbose=False)\n if isinstance(model, (str, Path)):\n model = YOLO(model)\n is_end2end = getattr(model.model.model[-1], \"end2end\", False)\n data = data or TASK2DATA[model.task] # task to dataset, i.e. coco8.yaml for task=detect\n key = TASK2METRIC[model.task] # task to metric, i.e. 
metrics/mAP50-95(B) for task=detect\n\n y = []\n t0 = time.time()\n\n format_arg = format.lower()\n if format_arg:\n formats = frozenset(export_formats()[\"Argument\"])\n assert format in formats, f\"Expected format to be one of {formats}, but got '{format_arg}'.\"\n for name, format, suffix, cpu, gpu, _ in zip(*export_formats().values()):\n emoji, filename = \"❌\", None # export defaults\n try:\n if format_arg and format_arg != format:\n continue\n\n # Checks\n if format == \"pb\":\n assert model.task != \"obb\", \"TensorFlow GraphDef not supported for OBB task\"\n elif format == \"edgetpu\":\n assert LINUX and not ARM64, \"Edge TPU export only supported on non-aarch64 Linux\"\n elif format in {\"coreml\", \"tfjs\"}:\n assert MACOS or (LINUX and not ARM64), (\n \"CoreML and TF.js export only supported on macOS and non-aarch64 Linux\"\n )\n if format == \"coreml\":\n assert not IS_PYTHON_3_13, \"CoreML not supported on Python 3.13\"\n if format in {\"saved_model\", \"pb\", \"tflite\", \"edgetpu\", \"tfjs\"}:\n assert not isinstance(model, YOLOWorld), \"YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet\"\n # assert not IS_PYTHON_MINIMUM_3_12, \"TFLite exports not supported on Python>=3.12 yet\"\n if format == \"paddle\":\n assert not isinstance(model, YOLOWorld), \"YOLOWorldv2 Paddle exports not supported yet\"\n assert model.task != \"obb\", \"Paddle OBB bug https://github.com/PaddlePaddle/Paddle/issues/72024\"\n assert not is_end2end, \"End-to-end models not supported by PaddlePaddle yet\"\n assert (LINUX and not IS_JETSON) or MACOS, \"Windows and Jetson Paddle exports not supported yet\"\n if format == \"mnn\":\n assert not isinstance(model, YOLOWorld), \"YOLOWorldv2 MNN exports not supported yet\"\n if format == \"ncnn\":\n assert not isinstance(model, YOLOWorld), \"YOLOWorldv2 NCNN exports not supported yet\"\n if format == \"imx\":\n assert not is_end2end\n assert not isinstance(model, YOLOWorld), \"YOLOWorldv2 IMX exports not supported\"\n assert model.task == \"detect\", \"IMX only supported for detection task\"\n assert \"C2f\" in model.__str__(), \"IMX only supported for YOLOv8\" # TODO: enable for YOLO11\n if format == \"rknn\":\n assert not isinstance(model, YOLOWorld), \"YOLOWorldv2 RKNN exports not supported yet\"\n assert not is_end2end, \"End-to-end models not supported by RKNN yet\"\n assert LINUX, \"RKNN only supported on Linux\"\n assert not is_rockchip(), \"RKNN Inference only supported on Rockchip devices\"\n if \"cpu\" in device.type:\n assert cpu, \"inference not supported on CPU\"\n if \"cuda\" in device.type:\n assert gpu, \"inference not supported on GPU\"\n\n # Export\n if format == \"-\":\n filename = model.pt_path or model.ckpt_path or model.model_name\n exported_model = model # PyTorch format\n else:\n filename = model.export(\n imgsz=imgsz, format=format, half=half, int8=int8, data=data, device=device, verbose=False, **kwargs\n )\n exported_model = YOLO(filename, task=model.task)\n assert suffix in str(filename), \"export failed\"\n emoji = \"❎\" # indicates export succeeded\n\n # Predict\n assert model.task != \"pose\" or format != \"pb\", \"GraphDef Pose inference is not supported\"\n assert format not in {\"edgetpu\", \"tfjs\"}, \"inference not supported\"\n assert format != \"coreml\" or platform.system() == \"Darwin\", \"inference only supported on macOS>=10.13\"\n if format == \"ncnn\":\n assert not is_end2end, \"End-to-end torch.topk operation is not supported for NCNN prediction yet\"\n exported_model.predict(ASSETS / \"bus.jpg\", 
imgsz=imgsz, device=device, half=half, verbose=False)\n\n # Validate\n results = exported_model.val(\n data=data,\n batch=1,\n imgsz=imgsz,\n plots=False,\n device=device,\n half=half,\n int8=int8,\n verbose=False,\n conf=0.001, # all the pre-set benchmark mAP values are based on conf=0.001\n )\n metric, speed = results.results_dict[key], results.speed[\"inference\"]\n fps = round(1000 / (speed + eps), 2) # frames per second\n y.append([name, \"✅\", round(file_size(filename), 1), round(metric, 4), round(speed, 2), fps])\n except Exception as e:\n if verbose:\n assert type(e) is AssertionError, f\"Benchmark failure for {name}: {e}\"\n LOGGER.error(f\"Benchmark failure for {name}: {e}\")\n y.append([name, emoji, round(file_size(filename), 1), None, None, None]) # mAP, t_inference\n\n # Print results\n check_yolo(device=device) # print system info\n df = pd.DataFrame(y, columns=[\"Format\", \"Status❔\", \"Size (MB)\", key, \"Inference time (ms/im)\", \"FPS\"])\n\n name = model.model_name\n dt = time.time() - t0\n legend = \"Benchmarks legend: - ✅ Success - ❎ Export passed but validation failed - ❌️ Export failed\"\n s = f\"\\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({dt:.2f}s)\\n{legend}\\n{df.fillna('-')}\\n\"\n LOGGER.info(s)\n with open(\"benchmarks.log\", \"a\", errors=\"ignore\", encoding=\"utf-8\") as f:\n f.write(s)\n\n if verbose and isinstance(verbose, float):\n metrics = df[key].array # values to compare to floor\n floor = verbose # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n\n assert all(x > floor for x in metrics if pd.notna(x)), f\"Benchmark failure: metric(s) < floor {floor}\"\n\n return df",
"chunk_type": "function",
"name": "benchmark",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 52,
"end_line": 211,
"start_col": 0,
"end_col": 13,
"parent_name": null,
"docstring": "Benchmark a YOLO model across different formats for speed and accuracy.\n\nArgs:\n model (str | Path): Path to the model file or directory.\n data (str | None): Dataset to evaluate on, inherited from TASK2DATA if not passed.\n imgsz (int): Image size for the benchmark.\n half (bool): Use half-precision for the model if True.\n int8 (bool): Use int8-precision for the model if True.\n device (str): Device to run the benchmark on, either 'cpu' or 'cuda'.\n verbose (bool | float): If True or a float, assert benchmarks pass with given metric.\n eps (float): Epsilon value for divide by zero prevention.\n format (str): Export format for benchmarking. If not supplied all formats are benchmarked.\n **kwargs (Any): Additional keyword arguments for exporter.\n\nReturns:\n (pandas.DataFrame): A pandas DataFrame with benchmark results for each format, including file size, metric,\n and inference time.\n\nExamples:\n Benchmark a YOLO model with default settings:\n >>> from ultralytics.utils.benchmarks import benchmark\n >>> benchmark(model=\"yolo11n.pt\", imgsz=640)",
"parameters": [
"model",
"data",
"imgsz",
"half",
"int8",
"device",
"verbose",
"eps",
"format"
],
"return_type": null,
"decorators": [],
"complexity_score": 23,
"dependencies": [
"glob",
"os",
"platform",
"re",
"shutil",
"time",
"pathlib.Path",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"numpy",
"torch.cuda",
"ultralytics.YOLO",
"ultralytics.YOLOWorld",
"ultralytics.cfg.TASK2DATA",
"ultralytics.cfg.TASK2METRIC",
"ultralytics.engine.exporter.export_formats",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.TQDM",
"ultralytics.utils.WEIGHTS_DIR",
"ultralytics.utils.YAML",
"ultralytics.utils.checks.IS_PYTHON_3_13",
"ultralytics.utils.checks.check_imgsz",
"ultralytics.utils.checks.check_requirements",
"ultralytics.utils.checks.check_yolo",
"ultralytics.utils.checks.is_rockchip",
"ultralytics.utils.downloads.safe_download",
"ultralytics.utils.files.file_size",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.select_device",
"pandas",
"roboflow.Roboflow",
"onnxruntime"
],
"chunk_id": "function_benchmark_488ba16b"
},
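A minimal usage sketch for the `benchmark` helper recorded in the chunk above, following its docstring; the weight file, image size, and the single-format `format="onnx"` run are illustrative assumptions, not values taken from the chunk.

```python
# Hedged sketch of ultralytics.utils.benchmarks.benchmark(), per the docstring above.
# "yolo11n.pt" and imgsz=160 mirror the documented defaults; format="onnx" is an
# illustrative single-format run (assumed to be a valid export "Argument").
from ultralytics.utils.benchmarks import benchmark

df = benchmark(model="yolo11n.pt", imgsz=160, device="cpu", format="onnx")
print(df)  # pandas DataFrame: Format, Status, Size (MB), metric, inference time (ms/im), FPS
```

Passing a float for `verbose` would additionally assert that every returned metric stays above that floor, as the function's final check shows.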
{
"content": "class RF100Benchmark:\n \"\"\"\n Benchmark YOLO model performance across various formats for speed and accuracy.\n\n This class provides functionality to benchmark YOLO models on the RF100 dataset collection.\n\n Attributes:\n ds_names (List[str]): Names of datasets used for benchmarking.\n ds_cfg_list (List[Path]): List of paths to dataset configuration files.\n rf (Roboflow): Roboflow instance for accessing datasets.\n val_metrics (List[str]): Metrics used for validation.\n\n Methods:\n set_key: Set Roboflow API key for accessing datasets.\n parse_dataset: Parse dataset links and download datasets.\n fix_yaml: Fix train and validation paths in YAML files.\n evaluate: Evaluate model performance on validation results.\n \"\"\"\n\n def __init__(self):\n \"\"\"Initialize the RF100Benchmark class for benchmarking YOLO model performance across various formats.\"\"\"\n self.ds_names = []\n self.ds_cfg_list = []\n self.rf = None\n self.val_metrics = [\"class\", \"images\", \"targets\", \"precision\", \"recall\", \"map50\", \"map95\"]\n\n def set_key(self, api_key: str):\n \"\"\"\n Set Roboflow API key for processing.\n\n Args:\n api_key (str): The API key.\n\n Examples:\n Set the Roboflow API key for accessing datasets:\n >>> benchmark = RF100Benchmark()\n >>> benchmark.set_key(\"your_roboflow_api_key\")\n \"\"\"\n check_requirements(\"roboflow\")\n from roboflow import Roboflow\n\n self.rf = Roboflow(api_key=api_key)\n\n def parse_dataset(self, ds_link_txt: str = \"datasets_links.txt\"):\n \"\"\"\n Parse dataset links and download datasets.\n\n Args:\n ds_link_txt (str): Path to the file containing dataset links.\n\n Returns:\n ds_names (List[str]): List of dataset names.\n ds_cfg_list (List[Path]): List of paths to dataset configuration files.\n\n Examples:\n >>> benchmark = RF100Benchmark()\n >>> benchmark.set_key(\"api_key\")\n >>> benchmark.parse_dataset(\"datasets_links.txt\")\n \"\"\"\n (shutil.rmtree(\"rf-100\"), os.mkdir(\"rf-100\")) if os.path.exists(\"rf-100\") else os.mkdir(\"rf-100\")\n os.chdir(\"rf-100\")\n os.mkdir(\"ultralytics-benchmarks\")\n safe_download(\"https://github.com/ultralytics/assets/releases/download/v0.0.0/datasets_links.txt\")\n\n with open(ds_link_txt, encoding=\"utf-8\") as file:\n for line in file:\n try:\n _, url, workspace, project, version = re.split(\"/+\", line.strip())\n self.ds_names.append(project)\n proj_version = f\"{project}-{version}\"\n if not Path(proj_version).exists():\n self.rf.workspace(workspace).project(project).version(version).download(\"yolov8\")\n else:\n LOGGER.info(\"Dataset already downloaded.\")\n self.ds_cfg_list.append(Path.cwd() / proj_version / \"data.yaml\")\n except Exception:\n continue\n\n return self.ds_names, self.ds_cfg_list\n\n @staticmethod\n def fix_yaml(path: Path):\n \"\"\"Fix the train and validation paths in a given YAML file.\"\"\"\n yaml_data = YAML.load(path)\n yaml_data[\"train\"] = \"train/images\"\n yaml_data[\"val\"] = \"valid/images\"\n YAML.dump(yaml_data, path)\n\n def evaluate(self, yaml_path: str, val_log_file: str, eval_log_file: str, list_ind: int):\n \"\"\"\n Evaluate model performance on validation results.\n\n Args:\n yaml_path (str): Path to the YAML configuration file.\n val_log_file (str): Path to the validation log file.\n eval_log_file (str): Path to the evaluation log file.\n list_ind (int): Index of the current dataset in the list.\n\n Returns:\n (float): The mean average precision (mAP) value for the evaluated model.\n\n Examples:\n Evaluate a model on a specific dataset\n >>> 
benchmark = RF100Benchmark()\n >>> benchmark.evaluate(\"path/to/data.yaml\", \"path/to/val_log.txt\", \"path/to/eval_log.txt\", 0)\n \"\"\"\n skip_symbols = [\"🚀\", \"⚠️\", \"💡\", \"❌\"]\n class_names = YAML.load(yaml_path)[\"names\"]\n with open(val_log_file, encoding=\"utf-8\") as f:\n lines = f.readlines()\n eval_lines = []\n for line in lines:\n if any(symbol in line for symbol in skip_symbols):\n continue\n entries = line.split(\" \")\n entries = list(filter(lambda val: val != \"\", entries))\n entries = [e.strip(\"\\n\") for e in entries]\n eval_lines.extend(\n {\n \"class\": entries[0],\n \"images\": entries[1],\n \"targets\": entries[2],\n \"precision\": entries[3],\n \"recall\": entries[4],\n \"map50\": entries[5],\n \"map95\": entries[6],\n }\n for e in entries\n if e in class_names or (e == \"all\" and \"(AP)\" not in entries and \"(AR)\" not in entries)\n )\n map_val = 0.0\n if len(eval_lines) > 1:\n LOGGER.info(\"Multiple dicts found\")\n for lst in eval_lines:\n if lst[\"class\"] == \"all\":\n map_val = lst[\"map50\"]\n else:\n LOGGER.info(\"Single dict found\")\n map_val = [res[\"map50\"] for res in eval_lines][0]\n\n with open(eval_log_file, \"a\", encoding=\"utf-8\") as f:\n f.write(f\"{self.ds_names[list_ind]}: {map_val}\\n\")\n\n return float(map_val)",
"chunk_type": "class",
"name": "RF100Benchmark",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 214,
"end_line": 357,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": "Benchmark YOLO model performance across various formats for speed and accuracy.\n\nThis class provides functionality to benchmark YOLO models on the RF100 dataset collection.\n\nAttributes:\n ds_names (List[str]): Names of datasets used for benchmarking.\n ds_cfg_list (List[Path]): List of paths to dataset configuration files.\n rf (Roboflow): Roboflow instance for accessing datasets.\n val_metrics (List[str]): Metrics used for validation.\n\nMethods:\n set_key: Set Roboflow API key for accessing datasets.\n parse_dataset: Parse dataset links and download datasets.\n fix_yaml: Fix train and validation paths in YAML files.\n evaluate: Evaluate model performance on validation results.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"glob",
"os",
"platform",
"re",
"shutil",
"time",
"pathlib.Path",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"numpy",
"torch.cuda",
"ultralytics.YOLO",
"ultralytics.YOLOWorld",
"ultralytics.cfg.TASK2DATA",
"ultralytics.cfg.TASK2METRIC",
"ultralytics.engine.exporter.export_formats",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.TQDM",
"ultralytics.utils.WEIGHTS_DIR",
"ultralytics.utils.YAML",
"ultralytics.utils.checks.IS_PYTHON_3_13",
"ultralytics.utils.checks.check_imgsz",
"ultralytics.utils.checks.check_requirements",
"ultralytics.utils.checks.check_yolo",
"ultralytics.utils.checks.is_rockchip",
"ultralytics.utils.downloads.safe_download",
"ultralytics.utils.files.file_size",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.select_device",
"pandas",
"roboflow.Roboflow",
"onnxruntime"
],
"chunk_id": "class_RF100Benchmark_a473c734"
},
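A short sketch of the RF100Benchmark workflow described by the docstrings in the chunk above; the API key string and the links file name are placeholders.

```python
# Hedged sketch following the RF100Benchmark docstrings above; key and paths are placeholders.
from ultralytics.utils.benchmarks import RF100Benchmark

rf_bench = RF100Benchmark()
rf_bench.set_key("your_roboflow_api_key")                         # requires a valid Roboflow account
ds_names, ds_cfgs = rf_bench.parse_dataset("datasets_links.txt")  # downloads datasets under rf-100/
for cfg in ds_cfgs:
    RF100Benchmark.fix_yaml(cfg)  # rewrite train/val to train/images and valid/images in each data.yaml
```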
{
"content": "class ProfileModels:\n \"\"\"\n ProfileModels class for profiling different models on ONNX and TensorRT.\n\n This class profiles the performance of different models, returning results such as model speed and FLOPs.\n\n Attributes:\n paths (List[str]): Paths of the models to profile.\n num_timed_runs (int): Number of timed runs for the profiling.\n num_warmup_runs (int): Number of warmup runs before profiling.\n min_time (float): Minimum number of seconds to profile for.\n imgsz (int): Image size used in the models.\n half (bool): Flag to indicate whether to use FP16 half-precision for TensorRT profiling.\n trt (bool): Flag to indicate whether to profile using TensorRT.\n device (torch.device): Device used for profiling.\n\n Methods:\n run: Profile YOLO models for speed and accuracy across various formats.\n get_files: Get all relevant model files.\n get_onnx_model_info: Extract metadata from an ONNX model.\n iterative_sigma_clipping: Apply sigma clipping to remove outliers.\n profile_tensorrt_model: Profile a TensorRT model.\n profile_onnx_model: Profile an ONNX model.\n generate_table_row: Generate a table row with model metrics.\n generate_results_dict: Generate a dictionary of profiling results.\n print_table: Print a formatted table of results.\n\n Examples:\n Profile models and print results\n >>> from ultralytics.utils.benchmarks import ProfileModels\n >>> profiler = ProfileModels([\"yolo11n.yaml\", \"yolov8s.yaml\"], imgsz=640)\n >>> profiler.run()\n \"\"\"\n\n def __init__(\n self,\n paths: List[str],\n num_timed_runs: int = 100,\n num_warmup_runs: int = 10,\n min_time: float = 60,\n imgsz: int = 640,\n half: bool = True,\n trt: bool = True,\n device: Optional[Union[torch.device, str]] = None,\n ):\n \"\"\"\n Initialize the ProfileModels class for profiling models.\n\n Args:\n paths (List[str]): List of paths of the models to be profiled.\n num_timed_runs (int): Number of timed runs for the profiling.\n num_warmup_runs (int): Number of warmup runs before the actual profiling starts.\n min_time (float): Minimum time in seconds for profiling a model.\n imgsz (int): Size of the image used during profiling.\n half (bool): Flag to indicate whether to use FP16 half-precision for TensorRT profiling.\n trt (bool): Flag to indicate whether to profile using TensorRT.\n device (torch.device | str | None): Device used for profiling. 
If None, it is determined automatically.\n\n Notes:\n FP16 'half' argument option removed for ONNX as slower on CPU than FP32.\n\n Examples:\n Initialize and profile models\n >>> from ultralytics.utils.benchmarks import ProfileModels\n >>> profiler = ProfileModels([\"yolo11n.yaml\", \"yolov8s.yaml\"], imgsz=640)\n >>> profiler.run()\n \"\"\"\n self.paths = paths\n self.num_timed_runs = num_timed_runs\n self.num_warmup_runs = num_warmup_runs\n self.min_time = min_time\n self.imgsz = imgsz\n self.half = half\n self.trt = trt # run TensorRT profiling\n self.device = device if isinstance(device, torch.device) else select_device(device)\n\n def run(self):\n \"\"\"\n Profile YOLO models for speed and accuracy across various formats including ONNX and TensorRT.\n\n Returns:\n (List[dict]): List of dictionaries containing profiling results for each model.\n\n Examples:\n Profile models and print results\n >>> from ultralytics.utils.benchmarks import ProfileModels\n >>> profiler = ProfileModels([\"yolo11n.yaml\", \"yolov8s.yaml\"])\n >>> results = profiler.run()\n \"\"\"\n files = self.get_files()\n\n if not files:\n LOGGER.warning(\"No matching *.pt or *.onnx files found.\")\n return []\n\n table_rows = []\n output = []\n for file in files:\n engine_file = file.with_suffix(\".engine\")\n if file.suffix in {\".pt\", \".yaml\", \".yml\"}:\n model = YOLO(str(file))\n model.fuse() # to report correct params and GFLOPs in model.info()\n model_info = model.info()\n if self.trt and self.device.type != \"cpu\" and not engine_file.is_file():\n engine_file = model.export(\n format=\"engine\",\n half=self.half,\n imgsz=self.imgsz,\n device=self.device,\n verbose=False,\n )\n onnx_file = model.export(\n format=\"onnx\",\n imgsz=self.imgsz,\n device=self.device,\n verbose=False,\n )\n elif file.suffix == \".onnx\":\n model_info = self.get_onnx_model_info(file)\n onnx_file = file\n else:\n continue\n\n t_engine = self.profile_tensorrt_model(str(engine_file))\n t_onnx = self.profile_onnx_model(str(onnx_file))\n table_rows.append(self.generate_table_row(file.stem, t_onnx, t_engine, model_info))\n output.append(self.generate_results_dict(file.stem, t_onnx, t_engine, model_info))\n\n self.print_table(table_rows)\n return output\n\n def get_files(self):\n \"\"\"\n Return a list of paths for all relevant model files given by the user.\n\n Returns:\n (List[Path]): List of Path objects for the model files.\n \"\"\"\n files = []\n for path in self.paths:\n path = Path(path)\n if path.is_dir():\n extensions = [\"*.pt\", \"*.onnx\", \"*.yaml\"]\n files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])\n elif path.suffix in {\".pt\", \".yaml\", \".yml\"}: # add non-existing\n files.append(str(path))\n else:\n files.extend(glob.glob(str(path)))\n\n LOGGER.info(f\"Profiling: {sorted(files)}\")\n return [Path(file) for file in sorted(files)]\n\n @staticmethod\n def get_onnx_model_info(onnx_file: str):\n \"\"\"Extract metadata from an ONNX model file including parameters, GFLOPs, and input shape.\"\"\"\n return 0.0, 0.0, 0.0, 0.0 # return (num_layers, num_params, num_gradients, num_flops)\n\n @staticmethod\n def iterative_sigma_clipping(data: np.ndarray, sigma: float = 2, max_iters: int = 3):\n \"\"\"\n Apply iterative sigma clipping to data to remove outliers.\n\n Args:\n data (np.ndarray): Input data array.\n sigma (float): Number of standard deviations to use for clipping.\n max_iters (int): Maximum number of iterations for the clipping process.\n\n Returns:\n (np.ndarray): Clipped data array with 
outliers removed.\n \"\"\"\n data = np.array(data)\n for _ in range(max_iters):\n mean, std = np.mean(data), np.std(data)\n clipped_data = data[(data > mean - sigma * std) & (data < mean + sigma * std)]\n if len(clipped_data) == len(data):\n break\n data = clipped_data\n return data\n\n def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3):\n \"\"\"\n Profile YOLO model performance with TensorRT, measuring average run time and standard deviation.\n\n Args:\n engine_file (str): Path to the TensorRT engine file.\n eps (float): Small epsilon value to prevent division by zero.\n\n Returns:\n mean_time (float): Mean inference time in milliseconds.\n std_time (float): Standard deviation of inference time in milliseconds.\n \"\"\"\n if not self.trt or not Path(engine_file).is_file():\n return 0.0, 0.0\n\n # Model and input\n model = YOLO(engine_file)\n input_data = np.zeros((self.imgsz, self.imgsz, 3), dtype=np.uint8) # use uint8 for Classify\n\n # Warmup runs\n elapsed = 0.0\n for _ in range(3):\n start_time = time.time()\n for _ in range(self.num_warmup_runs):\n model(input_data, imgsz=self.imgsz, verbose=False)\n elapsed = time.time() - start_time\n\n # Compute number of runs as higher of min_time or num_timed_runs\n num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs * 50)\n\n # Timed runs\n run_times = []\n for _ in TQDM(range(num_runs), desc=engine_file):\n results = model(input_data, imgsz=self.imgsz, verbose=False)\n run_times.append(results[0].speed[\"inference\"]) # Convert to milliseconds\n\n run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping\n return np.mean(run_times), np.std(run_times)\n\n def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):\n \"\"\"\n Profile an ONNX model, measuring average inference time and standard deviation across multiple runs.\n\n Args:\n onnx_file (str): Path to the ONNX model file.\n eps (float): Small epsilon value to prevent division by zero.\n\n Returns:\n mean_time (float): Mean inference time in milliseconds.\n std_time (float): Standard deviation of inference time in milliseconds.\n \"\"\"\n check_requirements(\"onnxruntime\")\n import onnxruntime as ort\n\n # Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'\n sess_options = ort.SessionOptions()\n sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL\n sess_options.intra_op_num_threads = 8 # Limit the number of threads\n sess = ort.InferenceSession(onnx_file, sess_options, providers=[\"CPUExecutionProvider\"])\n\n input_tensor = sess.get_inputs()[0]\n input_type = input_tensor.type\n dynamic = not all(isinstance(dim, int) and dim >= 0 for dim in input_tensor.shape) # dynamic input shape\n input_shape = (1, 3, self.imgsz, self.imgsz) if dynamic else input_tensor.shape\n\n # Mapping ONNX datatype to numpy datatype\n if \"float16\" in input_type:\n input_dtype = np.float16\n elif \"float\" in input_type:\n input_dtype = np.float32\n elif \"double\" in input_type:\n input_dtype = np.float64\n elif \"int64\" in input_type:\n input_dtype = np.int64\n elif \"int32\" in input_type:\n input_dtype = np.int32\n else:\n raise ValueError(f\"Unsupported ONNX datatype {input_type}\")\n\n input_data = np.random.rand(*input_shape).astype(input_dtype)\n input_name = input_tensor.name\n output_name = sess.get_outputs()[0].name\n\n # Warmup runs\n elapsed = 0.0\n for _ in range(3):\n start_time = time.time()\n for _ in 
range(self.num_warmup_runs):\n sess.run([output_name], {input_name: input_data})\n elapsed = time.time() - start_time\n\n # Compute number of runs as higher of min_time or num_timed_runs\n num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs)\n\n # Timed runs\n run_times = []\n for _ in TQDM(range(num_runs), desc=onnx_file):\n start_time = time.time()\n sess.run([output_name], {input_name: input_data})\n run_times.append((time.time() - start_time) * 1000) # Convert to milliseconds\n\n run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=5) # sigma clipping\n return np.mean(run_times), np.std(run_times)\n\n def generate_table_row(\n self,\n model_name: str,\n t_onnx: Tuple[float, float],\n t_engine: Tuple[float, float],\n model_info: Tuple[float, float, float, float],\n ):\n \"\"\"\n Generate a table row string with model performance metrics.\n\n Args:\n model_name (str): Name of the model.\n t_onnx (tuple): ONNX model inference time statistics (mean, std).\n t_engine (tuple): TensorRT engine inference time statistics (mean, std).\n model_info (tuple): Model information (layers, params, gradients, flops).\n\n Returns:\n (str): Formatted table row string with model metrics.\n \"\"\"\n layers, params, gradients, flops = model_info\n return (\n f\"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.1f}±{t_onnx[1]:.1f} ms | {t_engine[0]:.1f}±\"\n f\"{t_engine[1]:.1f} ms | {params / 1e6:.1f} | {flops:.1f} |\"\n )\n\n @staticmethod\n def generate_results_dict(\n model_name: str,\n t_onnx: Tuple[float, float],\n t_engine: Tuple[float, float],\n model_info: Tuple[float, float, float, float],\n ):\n \"\"\"\n Generate a dictionary of profiling results.\n\n Args:\n model_name (str): Name of the model.\n t_onnx (tuple): ONNX model inference time statistics (mean, std).\n t_engine (tuple): TensorRT engine inference time statistics (mean, std).\n model_info (tuple): Model information (layers, params, gradients, flops).\n\n Returns:\n (dict): Dictionary containing profiling results.\n \"\"\"\n layers, params, gradients, flops = model_info\n return {\n \"model/name\": model_name,\n \"model/parameters\": params,\n \"model/GFLOPs\": round(flops, 3),\n \"model/speed_ONNX(ms)\": round(t_onnx[0], 3),\n \"model/speed_TensorRT(ms)\": round(t_engine[0], 3),\n }\n\n @staticmethod\n def print_table(table_rows: List[str]):\n \"\"\"\n Print a formatted table of model profiling results.\n\n Args:\n table_rows (List[str]): List of formatted table row strings.\n \"\"\"\n gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"GPU\"\n headers = [\n \"Model\",\n \"size
(pixels)\",\n \"mAPval
50-95\",\n f\"Speed
CPU ({get_cpu_info()}) ONNX
(ms)\",\n f\"Speed
{gpu} TensorRT
(ms)\",\n \"params
(M)\",\n \"FLOPs
(B)\",\n ]\n header = \"|\" + \"|\".join(f\" {h} \" for h in headers) + \"|\"\n separator = \"|\" + \"|\".join(\"-\" * (len(h) + 2) for h in headers) + \"|\"\n\n LOGGER.info(f\"\\n\\n{header}\")\n LOGGER.info(separator)\n for row in table_rows:\n LOGGER.info(row)",
"chunk_type": "class",
"name": "ProfileModels",
"file_path": "ultralytics\\ultralytics\\utils\\benchmarks.py",
"start_line": 360,
"end_line": 720,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": "ProfileModels class for profiling different models on ONNX and TensorRT.\n\nThis class profiles the performance of different models, returning results such as model speed and FLOPs.\n\nAttributes:\n paths (List[str]): Paths of the models to profile.\n num_timed_runs (int): Number of timed runs for the profiling.\n num_warmup_runs (int): Number of warmup runs before profiling.\n min_time (float): Minimum number of seconds to profile for.\n imgsz (int): Image size used in the models.\n half (bool): Flag to indicate whether to use FP16 half-precision for TensorRT profiling.\n trt (bool): Flag to indicate whether to profile using TensorRT.\n device (torch.device): Device used for profiling.\n\nMethods:\n run: Profile YOLO models for speed and accuracy across various formats.\n get_files: Get all relevant model files.\n get_onnx_model_info: Extract metadata from an ONNX model.\n iterative_sigma_clipping: Apply sigma clipping to remove outliers.\n profile_tensorrt_model: Profile a TensorRT model.\n profile_onnx_model: Profile an ONNX model.\n generate_table_row: Generate a table row with model metrics.\n generate_results_dict: Generate a dictionary of profiling results.\n print_table: Print a formatted table of results.\n\nExamples:\n Profile models and print results\n >>> from ultralytics.utils.benchmarks import ProfileModels\n >>> profiler = ProfileModels([\"yolo11n.yaml\", \"yolov8s.yaml\"], imgsz=640)\n >>> profiler.run()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"glob",
"os",
"platform",
"re",
"shutil",
"time",
"pathlib.Path",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"numpy",
"torch.cuda",
"ultralytics.YOLO",
"ultralytics.YOLOWorld",
"ultralytics.cfg.TASK2DATA",
"ultralytics.cfg.TASK2METRIC",
"ultralytics.engine.exporter.export_formats",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.TQDM",
"ultralytics.utils.WEIGHTS_DIR",
"ultralytics.utils.YAML",
"ultralytics.utils.checks.IS_PYTHON_3_13",
"ultralytics.utils.checks.check_imgsz",
"ultralytics.utils.checks.check_requirements",
"ultralytics.utils.checks.check_yolo",
"ultralytics.utils.checks.is_rockchip",
"ultralytics.utils.downloads.safe_download",
"ultralytics.utils.files.file_size",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.select_device",
"pandas",
"roboflow.Roboflow",
"onnxruntime"
],
"chunk_id": "class_ProfileModels_c4eb9a5b"
},
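A usage sketch for ProfileModels based on the class docstring above; the model list is illustrative, and `trt=False` simply skips the TensorRT path on CPU-only machines.

```python
# Hedged sketch of ProfileModels per the docstring above; model paths are illustrative.
from ultralytics.utils.benchmarks import ProfileModels

profiler = ProfileModels(["yolo11n.yaml", "yolov8s.yaml"], imgsz=640, trt=False)
results = profiler.run()  # list of dicts keyed as in generate_results_dict()
for r in results:
    print(r["model/name"], r["model/GFLOPs"], r["model/speed_ONNX(ms)"])
```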
{
"content": "import functools",
"chunk_type": "import",
"name": "functools",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_functools_6378b944"
},
{
"content": "import glob",
"chunk_type": "import",
"name": "glob",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_glob_83cd0773"
},
{
"content": "import inspect",
"chunk_type": "import",
"name": "inspect",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 14,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_inspect_f54848db"
},
{
"content": "import math",
"chunk_type": "import",
"name": "math",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_math_0f8e4529"
},
{
"content": "import os",
"chunk_type": "import",
"name": "os",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_os_414f6710"
},
{
"content": "import platform",
"chunk_type": "import",
"name": "platform",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_platform_a1f5096d"
},
{
"content": "import re",
"chunk_type": "import",
"name": "re",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_re_b84d5a11"
},
{
"content": "import shutil",
"chunk_type": "import",
"name": "shutil",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 13,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_shutil_d7f542ca"
},
{
"content": "import subprocess",
"chunk_type": "import",
"name": "subprocess",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 17,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_subprocess_21ebf160"
},
{
"content": "import time",
"chunk_type": "import",
"name": "time",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_time_a6472836"
},
{
"content": "from importlib import metadata",
"chunk_type": "import",
"name": "metadata",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 30,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_metadata_f5ca7ec6"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_0f6f55a6"
},
{
"content": "from types import SimpleNamespace",
"chunk_type": "import",
"name": "SimpleNamespace",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SimpleNamespace_1f99f4c4"
},
{
"content": "from typing import Optional",
"chunk_type": "import",
"name": "Optional",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 16,
"end_line": 16,
"start_col": 0,
"end_col": 27,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Optional_024c48a9"
},
{
"content": "import cv2",
"chunk_type": "import",
"name": "cv2",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 18,
"end_line": 18,
"start_col": 0,
"end_col": 10,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_cv2_a2b45917"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 19,
"end_line": 19,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_47f34676"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 20,
"end_line": 20,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_133e9c35"
},
{
"content": "from ultralytics.utils import (\n ARM64,\n ASSETS,\n AUTOINSTALL,\n IS_COLAB,\n IS_GIT_DIR,\n IS_JETSON,\n IS_KAGGLE,\n IS_PIP_PACKAGE,\n LINUX,\n LOGGER,\n MACOS,\n ONLINE,\n PYTHON_VERSION,\n RKNN_CHIPS,\n ROOT,\n TORCHVISION_VERSION,\n USER_CONFIG_DIR,\n WINDOWS,\n Retry,\n ThreadingLocked,\n TryExcept,\n clean_url,\n colorstr,\n downloads,\n is_github_action_running,\n url2file,\n)",
"chunk_type": "import",
"name": "ARM64, ASSETS, AUTOINSTALL, IS_COLAB, IS_GIT_DIR, IS_JETSON, IS_KAGGLE, IS_PIP_PACKAGE, LINUX, LOGGER, MACOS, ONLINE, PYTHON_VERSION, RKNN_CHIPS, ROOT, TORCHVISION_VERSION, USER_CONFIG_DIR, WINDOWS, Retry, ThreadingLocked, TryExcept, clean_url, colorstr, downloads, is_github_action_running, url2file",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 22,
"end_line": 49,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ARM64, ASSETS, AUTOINSTALL, IS_COLAB, IS_GIT_DIR, IS_JETSON, IS_KAGGLE, IS_PIP_PACKAGE, LINUX, LOGGER, MACOS, ONLINE, PYTHON_VERSION, RKNN_CHIPS, ROOT, TORCHVISION_VERSION, USER_CONFIG_DIR, WINDOWS, Retry, ThreadingLocked, TryExcept, clean_url, colorstr, downloads, is_github_action_running, url2file_8c9fe9a2"
},
{
"content": "def parse_requirements(file_path=ROOT.parent / \"requirements.txt\", package=\"\"):\n \"\"\"\n Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.\n\n Args:\n file_path (Path): Path to the requirements.txt file.\n package (str, optional): Python package to use instead of requirements.txt file.\n\n Returns:\n requirements (List[SimpleNamespace]): List of parsed requirements as SimpleNamespace objects with `name` and\n `specifier` attributes.\n\n Examples:\n >>> from ultralytics.utils.checks import parse_requirements\n >>> parse_requirements(package=\"ultralytics\")\n \"\"\"\n if package:\n requires = [x for x in metadata.distribution(package).requires if \"extra == \" not in x]\n else:\n requires = Path(file_path).read_text().splitlines()\n\n requirements = []\n for line in requires:\n line = line.strip()\n if line and not line.startswith(\"#\"):\n line = line.partition(\"#\")[0].strip() # ignore inline comments\n if match := re.match(r\"([a-zA-Z0-9-_]+)\\s*([<>!=~]+.*)?\", line):\n requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else \"\"))\n\n return requirements",
"chunk_type": "function",
"name": "parse_requirements",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 52,
"end_line": 81,
"start_col": 0,
"end_col": 23,
"parent_name": null,
"docstring": "Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.\n\nArgs:\n file_path (Path): Path to the requirements.txt file.\n package (str, optional): Python package to use instead of requirements.txt file.\n\nReturns:\n requirements (List[SimpleNamespace]): List of parsed requirements as SimpleNamespace objects with `name` and\n `specifier` attributes.\n\nExamples:\n >>> from ultralytics.utils.checks import parse_requirements\n >>> parse_requirements(package=\"ultralytics\")",
"parameters": [
"file_path",
"package"
],
"return_type": null,
"decorators": [],
"complexity_score": 6,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_parse_requirements_dc571047"
},
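A small sketch of parse_requirements() as documented above; each returned item is a SimpleNamespace with `name` and `specifier` attributes, and the "ultralytics" package name follows the docstring example.

```python
# Hedged sketch of parse_requirements(), per its docstring above.
from ultralytics.utils.checks import parse_requirements

for req in parse_requirements(package="ultralytics"):
    print(f"{req.name}{req.specifier}")  # e.g. "numpy>=1.23.0"; specifier may be an empty string
```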
{
"content": "def parse_version(version=\"0.0.0\") -> tuple:\n \"\"\"\n Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version.\n\n Args:\n version (str): Version string, i.e. '2.0.1+cpu'\n\n Returns:\n (tuple): Tuple of integers representing the numeric part of the version, i.e. (2, 0, 1)\n \"\"\"\n try:\n return tuple(map(int, re.findall(r\"\\d+\", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1)\n except Exception as e:\n LOGGER.warning(f\"failure for parse_version({version}), returning (0, 0, 0): {e}\")\n return 0, 0, 0",
"chunk_type": "function",
"name": "parse_version",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 85,
"end_line": 99,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": "Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version.\n\nArgs:\n version (str): Version string, i.e. '2.0.1+cpu'\n\nReturns:\n (tuple): Tuple of integers representing the numeric part of the version, i.e. (2, 0, 1)",
"parameters": [
"version"
],
"return_type": "tuple",
"decorators": [
"functools.lru_cache"
],
"complexity_score": 2,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_parse_version_06c0fcd5"
},
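The behaviour of parse_version() documented above can be checked with a couple of doctest-style assertions; the second input string is an illustrative extra case, not one from the chunk.

```python
# parse_version() keeps only the first three numeric groups of the string, per the code above.
from ultralytics.utils.checks import parse_version

assert parse_version("2.0.1+cpu") == (2, 0, 1)   # example taken from the docstring
assert parse_version("11.0.0rc1") == (11, 0, 0)  # illustrative: pre-release suffix is ignored
```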
{
"content": "def is_ascii(s) -> bool:\n \"\"\"\n Check if a string is composed of only ASCII characters.\n\n Args:\n s (str | list | tuple | dict): Input to be checked (all are converted to string for checking).\n\n Returns:\n (bool): True if the string is composed only of ASCII characters, False otherwise.\n \"\"\"\n return all(ord(c) < 128 for c in str(s))",
"chunk_type": "function",
"name": "is_ascii",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 102,
"end_line": 112,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": "Check if a string is composed of only ASCII characters.\n\nArgs:\n s (str | list | tuple | dict): Input to be checked (all are converted to string for checking).\n\nReturns:\n (bool): True if the string is composed only of ASCII characters, False otherwise.",
"parameters": [
"s"
],
"return_type": "bool",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_is_ascii_78a4aa5d"
},
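A tiny sketch of is_ascii() per its docstring above: the input is stringified before the per-character check, so lists and dicts are accepted too.

```python
# is_ascii() stringifies its input and checks every character's code point, per the code above.
from ultralytics.utils.checks import is_ascii

assert is_ascii("yolo11n.pt") is True
assert is_ascii("école") is False    # non-ASCII characters
assert is_ascii([1, 2, 3]) is True   # non-strings are converted with str() first
```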
{
"content": "def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):\n \"\"\"\n Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the\n stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.\n\n Args:\n imgsz (int | List[int]): Image size.\n stride (int): Stride value.\n min_dim (int): Minimum number of dimensions.\n max_dim (int): Maximum number of dimensions.\n floor (int): Minimum allowed value for image size.\n\n Returns:\n (List[int] | int): Updated image size.\n \"\"\"\n # Convert stride to integer if it is a tensor\n stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)\n\n # Convert image size to list if it is an integer\n if isinstance(imgsz, int):\n imgsz = [imgsz]\n elif isinstance(imgsz, (list, tuple)):\n imgsz = list(imgsz)\n elif isinstance(imgsz, str): # i.e. '640' or '[640,640]'\n imgsz = [int(imgsz)] if imgsz.isnumeric() else eval(imgsz)\n else:\n raise TypeError(\n f\"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. \"\n f\"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'\"\n )\n\n # Apply max_dim\n if len(imgsz) > max_dim:\n msg = (\n \"'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list \"\n \"or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'\"\n )\n if max_dim != 1:\n raise ValueError(f\"imgsz={imgsz} is not a valid image size. {msg}\")\n LOGGER.warning(f\"updating to 'imgsz={max(imgsz)}'. {msg}\")\n imgsz = [max(imgsz)]\n # Make image size a multiple of the stride\n sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]\n\n # Print warning message if image size was updated\n if sz != imgsz:\n LOGGER.warning(f\"imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}\")\n\n # Add missing dimensions if necessary\n sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz\n\n return sz",
"chunk_type": "function",
"name": "check_imgsz",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 115,
"end_line": 166,
"start_col": 0,
"end_col": 13,
"parent_name": null,
"docstring": "Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the\nstride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.\n\nArgs:\n imgsz (int | List[int]): Image size.\n stride (int): Stride value.\n min_dim (int): Minimum number of dimensions.\n max_dim (int): Maximum number of dimensions.\n floor (int): Minimum allowed value for image size.\n\nReturns:\n (List[int] | int): Updated image size.",
"parameters": [
"imgsz",
"stride",
"min_dim",
"max_dim",
"floor"
],
"return_type": null,
"decorators": [],
"complexity_score": 8,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_check_imgsz_d00cc53c"
},
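A few assertions illustrating check_imgsz() as implemented above: sizes are rounded up to a multiple of the stride, and min_dim controls whether an int or an [h, w] list comes back.

```python
# check_imgsz() behaviour sketch, per the function above.
from ultralytics.utils.checks import check_imgsz

assert check_imgsz(640, stride=32) == 640                    # already a multiple; scalar in, scalar out
assert check_imgsz(630, stride=32) == 640                    # rounded up to the next multiple of 32 (warns)
assert check_imgsz(640, stride=32, min_dim=2) == [640, 640]  # expanded to [h, w] when min_dim=2
```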
{
"content": "def check_uv():\n \"\"\"Check if uv package manager is installed and can run successfully.\"\"\"\n try:\n return subprocess.run([\"uv\", \"-V\"], capture_output=True).returncode == 0\n except FileNotFoundError:\n return False",
"chunk_type": "function",
"name": "check_uv",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 170,
"end_line": 175,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Check if uv package manager is installed and can run successfully.",
"parameters": [],
"return_type": null,
"decorators": [
"functools.lru_cache"
],
"complexity_score": 2,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_check_uv_b28ad8d6"
},
{
"content": "def check_version(\n current: str = \"0.0.0\",\n required: str = \"0.0.0\",\n name: str = \"version\",\n hard: bool = False,\n verbose: bool = False,\n msg: str = \"\",\n) -> bool:\n \"\"\"\n Check current version against the required version or range.\n\n Args:\n current (str): Current version or package name to get version from.\n required (str): Required version or range (in pip-style format).\n name (str): Name to be used in warning message.\n hard (bool): If True, raise an AssertionError if the requirement is not met.\n verbose (bool): If True, print warning message if requirement is not met.\n msg (str): Extra message to display if verbose.\n\n Returns:\n (bool): True if requirement is met, False otherwise.\n\n Examples:\n Check if current version is exactly 22.04\n >>> check_version(current=\"22.04\", required=\"==22.04\")\n\n Check if current version is greater than or equal to 22.04\n >>> check_version(current=\"22.10\", required=\"22.04\") # assumes '>=' inequality if none passed\n\n Check if current version is less than or equal to 22.04\n >>> check_version(current=\"22.04\", required=\"<=22.04\")\n\n Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)\n >>> check_version(current=\"21.10\", required=\">20.04,<22.04\")\n \"\"\"\n if not current: # if current is '' or None\n LOGGER.warning(f\"invalid check_version({current}, {required}) requested, please check values.\")\n return True\n elif not current[0].isdigit(): # current is package name rather than version string, i.e. current='ultralytics'\n try:\n name = current # assigned package name to 'name' arg\n current = metadata.version(current) # get version string from package name\n except metadata.PackageNotFoundError as e:\n if hard:\n raise ModuleNotFoundError(f\"{current} package is required but not installed\") from e\n else:\n return False\n\n if not required: # if required is '' or None\n return True\n\n if \"sys_platform\" in required and ( # i.e. required='<2.4.0,>=1.8.0; sys_platform == \"win32\"'\n (WINDOWS and \"win32\" not in required)\n or (LINUX and \"linux\" not in required)\n or (MACOS and \"macos\" not in required and \"darwin\" not in required)\n ):\n return True\n\n op = \"\"\n version = \"\"\n result = True\n c = parse_version(current) # '1.2.3' -> (1, 2, 3)\n for r in required.strip(\",\").split(\",\"):\n op, version = re.match(r\"([^0-9]*)([\\d.]+)\", r).groups() # split '>=22.04' -> ('>=', '22.04')\n if not op:\n op = \">=\" # assume >= if no op passed\n v = parse_version(version) # '1.2.3' -> (1, 2, 3)\n if op == \"==\" and c != v:\n result = False\n elif op == \"!=\" and c == v:\n result = False\n elif op == \">=\" and not (c >= v):\n result = False\n elif op == \"<=\" and not (c <= v):\n result = False\n elif op == \">\" and not (c > v):\n result = False\n elif op == \"<\" and not (c < v):\n result = False\n if not result:\n warning = f\"{name}{required} is required, but {name}=={current} is currently installed {msg}\"\n if hard:\n raise ModuleNotFoundError(warning) # assert version requirements met\n if verbose:\n LOGGER.warning(warning)\n return result",
"chunk_type": "function",
"name": "check_version",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 179,
"end_line": 264,
"start_col": 0,
"end_col": 17,
"parent_name": null,
"docstring": "Check current version against the required version or range.\n\nArgs:\n current (str): Current version or package name to get version from.\n required (str): Required version or range (in pip-style format).\n name (str): Name to be used in warning message.\n hard (bool): If True, raise an AssertionError if the requirement is not met.\n verbose (bool): If True, print warning message if requirement is not met.\n msg (str): Extra message to display if verbose.\n\nReturns:\n (bool): True if requirement is met, False otherwise.\n\nExamples:\n Check if current version is exactly 22.04\n >>> check_version(current=\"22.04\", required=\"==22.04\")\n\n Check if current version is greater than or equal to 22.04\n >>> check_version(current=\"22.10\", required=\"22.04\") # assumes '>=' inequality if none passed\n\n Check if current version is less than or equal to 22.04\n >>> check_version(current=\"22.04\", required=\"<=22.04\")\n\n Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)\n >>> check_version(current=\"21.10\", required=\">20.04,<22.04\")",
"parameters": [
"current: str",
"required: str",
"name: str",
"hard: bool",
"verbose: bool",
"msg: str"
],
"return_type": "bool",
"decorators": [
"functools.lru_cache"
],
"complexity_score": 18,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_check_version_ae0b5792"
},
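Beyond the doctest-style examples in its docstring, `check_version` also accepts a package name as `current` and resolves the installed version via `importlib.metadata`, as the chunk above shows. A short hedged sketch, assuming `ultralytics` and its `torch` dependency are installed:

```python
# Hedged sketch of check_version usage, assuming ultralytics (and torch) are installed.
from ultralytics.utils.checks import check_version

# Plain version-range comparison; '>=' is assumed when no operator is given.
assert check_version("8.3.0", "8.0.0")

# Passing a package name resolves its installed version before comparing.
if check_version("torch", ">=1.8.0"):
    print("torch satisfies the minimum requirement")
```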
{
"content": "def check_latest_pypi_version(package_name=\"ultralytics\"):\n \"\"\"\n Return the latest version of a PyPI package without downloading or installing it.\n\n Args:\n package_name (str): The name of the package to find the latest version for.\n\n Returns:\n (str): The latest version of the package.\n \"\"\"\n import requests # slow import\n\n try:\n requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning\n response = requests.get(f\"https://pypi.org/pypi/{package_name}/json\", timeout=3)\n if response.status_code == 200:\n return response.json()[\"info\"][\"version\"]\n except Exception:\n return None",
"chunk_type": "function",
"name": "check_latest_pypi_version",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 267,
"end_line": 285,
"start_col": 0,
"end_col": 19,
"parent_name": null,
"docstring": "Return the latest version of a PyPI package without downloading or installing it.\n\nArgs:\n package_name (str): The name of the package to find the latest version for.\n\nReturns:\n (str): The latest version of the package.",
"parameters": [
"package_name"
],
"return_type": null,
"decorators": [],
"complexity_score": 3,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_check_latest_pypi_version_58a0c5a5"
},
{
"content": "def check_pip_update_available():\n \"\"\"\n Check if a new version of the ultralytics package is available on PyPI.\n\n Returns:\n (bool): True if an update is available, False otherwise.\n \"\"\"\n if ONLINE and IS_PIP_PACKAGE:\n try:\n from ultralytics import __version__\n\n latest = check_latest_pypi_version()\n if check_version(__version__, f\"<{latest}\"): # check if current version is < latest version\n LOGGER.info(\n f\"New https://pypi.org/project/ultralytics/{latest} available 😃 \"\n f\"Update with 'pip install -U ultralytics'\"\n )\n return True\n except Exception:\n pass\n return False",
"chunk_type": "function",
"name": "check_pip_update_available",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 288,
"end_line": 308,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "Check if a new version of the ultralytics package is available on PyPI.\n\nReturns:\n (bool): True if an update is available, False otherwise.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 4,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_check_pip_update_available_406d331e"
},
{
"content": "def check_font(font=\"Arial.ttf\"):\n \"\"\"\n Find font locally or download to user's configuration directory if it does not already exist.\n\n Args:\n font (str): Path or name of font.\n\n Returns:\n (Path): Resolved font file path.\n \"\"\"\n from matplotlib import font_manager # scope for faster 'import ultralytics'\n\n # Check USER_CONFIG_DIR\n name = Path(font).name\n file = USER_CONFIG_DIR / name\n if file.exists():\n return file\n\n # Check system fonts\n matches = [s for s in font_manager.findSystemFonts() if font in s]\n if any(matches):\n return matches[0]\n\n # Download to USER_CONFIG_DIR if missing\n url = f\"https://github.com/ultralytics/assets/releases/download/v0.0.0/{name}\"\n if downloads.is_url(url, check=True):\n downloads.safe_download(url=url, file=file)\n return file",
"chunk_type": "function",
"name": "check_font",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 313,
"end_line": 340,
"start_col": 0,
"end_col": 19,
"parent_name": null,
"docstring": "Find font locally or download to user's configuration directory if it does not already exist.\n\nArgs:\n font (str): Path or name of font.\n\nReturns:\n (Path): Resolved font file path.",
"parameters": [
"font"
],
"return_type": null,
"decorators": [
"ThreadingLocked()",
"functools.lru_cache"
],
"complexity_score": 5,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_check_font_fd626ed1"
},
{
"content": "def check_python(minimum: str = \"3.8.0\", hard: bool = True, verbose: bool = False) -> bool:\n \"\"\"\n Check current python version against the required minimum version.\n\n Args:\n minimum (str): Required minimum version of python.\n hard (bool): If True, raise an AssertionError if the requirement is not met.\n verbose (bool): If True, print warning message if requirement is not met.\n\n Returns:\n (bool): Whether the installed Python version meets the minimum constraints.\n \"\"\"\n return check_version(PYTHON_VERSION, minimum, name=\"Python\", hard=hard, verbose=verbose)",
"chunk_type": "function",
"name": "check_python",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 343,
"end_line": 355,
"start_col": 0,
"end_col": 92,
"parent_name": null,
"docstring": "Check current python version against the required minimum version.\n\nArgs:\n minimum (str): Required minimum version of python.\n hard (bool): If True, raise an AssertionError if the requirement is not met.\n verbose (bool): If True, print warning message if requirement is not met.\n\nReturns:\n (bool): Whether the installed Python version meets the minimum constraints.",
"parameters": [
"minimum: str",
"hard: bool",
"verbose: bool"
],
"return_type": "bool",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_check_python_f5574175"
},
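A small sketch of `check_python`, which simply delegates to `check_version` with `PYTHON_VERSION`; with `hard=False` it reports a boolean instead of raising. Assumes `ultralytics` is installed.

```python
# Sketch of check_python, assuming ultralytics is installed.
from ultralytics.utils.checks import check_python

print(check_python("3.8.0"))               # True on any interpreter ultralytics supports
print(check_python("99.0.0", hard=False))  # False, returned rather than raised
```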
{
"content": "def check_requirements(requirements=ROOT.parent / \"requirements.txt\", exclude=(), install=True, cmds=\"\"):\n \"\"\"\n Check if installed dependencies meet Ultralytics YOLO models requirements and attempt to auto-update if needed.\n\n Args:\n requirements (Path | str | List[str]): Path to a requirements.txt file, a single package requirement as a\n string, or a list of package requirements as strings.\n exclude (tuple): Tuple of package names to exclude from checking.\n install (bool): If True, attempt to auto-update packages that don't meet requirements.\n cmds (str): Additional commands to pass to the pip install command when auto-updating.\n\n Examples:\n >>> from ultralytics.utils.checks import check_requirements\n\n Check a requirements.txt file\n >>> check_requirements(\"path/to/requirements.txt\")\n\n Check a single package\n >>> check_requirements(\"ultralytics>=8.0.0\")\n\n Check multiple packages\n >>> check_requirements([\"numpy\", \"ultralytics>=8.0.0\"])\n \"\"\"\n prefix = colorstr(\"red\", \"bold\", \"requirements:\")\n if isinstance(requirements, Path): # requirements.txt file\n file = requirements.resolve()\n assert file.exists(), f\"{prefix} {file} not found, check failed.\"\n requirements = [f\"{x.name}{x.specifier}\" for x in parse_requirements(file) if x.name not in exclude]\n elif isinstance(requirements, str):\n requirements = [requirements]\n\n pkgs = []\n for r in requirements:\n r_stripped = r.rpartition(\"/\")[-1].replace(\".git\", \"\") # replace git+https://org/repo.git -> 'repo'\n match = re.match(r\"([a-zA-Z0-9-_]+)([<>!=~]+.*)?\", r_stripped)\n name, required = match[1], match[2].strip() if match[2] else \"\"\n try:\n assert check_version(metadata.version(name), required) # exception if requirements not met\n except (AssertionError, metadata.PackageNotFoundError):\n pkgs.append(r)\n\n @Retry(times=2, delay=1)\n def attempt_install(packages, commands, use_uv):\n \"\"\"Attempt package installation with uv if available, falling back to pip.\"\"\"\n if use_uv:\n base = f\"uv pip install --no-cache-dir {packages} {commands} --index-strategy=unsafe-best-match --break-system-packages --prerelease=allow\"\n try:\n return subprocess.check_output(base, shell=True, stderr=subprocess.PIPE).decode()\n except subprocess.CalledProcessError as e:\n if e.stderr and \"No virtual environment found\" in e.stderr.decode():\n return subprocess.check_output(\n base.replace(\"uv pip install\", \"uv pip install --system\"), shell=True\n ).decode()\n raise\n return subprocess.check_output(f\"pip install --no-cache-dir {packages} {commands}\", shell=True).decode()\n\n s = \" \".join(f'\"{x}\"' for x in pkgs) # console string\n if s:\n if install and AUTOINSTALL: # check environment variable\n # Note uv fails on arm64 macOS and Raspberry Pi runners\n n = len(pkgs) # number of packages updates\n LOGGER.info(f\"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...\")\n try:\n t = time.time()\n assert ONLINE, \"AutoUpdate skipped (offline)\"\n LOGGER.info(attempt_install(s, cmds, use_uv=not ARM64 and check_uv()))\n dt = time.time() - t\n LOGGER.info(f\"{prefix} AutoUpdate success ✅ {dt:.1f}s\")\n LOGGER.warning(\n f\"{prefix} {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\\n\"\n )\n except Exception as e:\n LOGGER.warning(f\"{prefix} ❌ {e}\")\n return False\n else:\n return False\n\n return True",
"chunk_type": "function",
"name": "check_requirements",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 359,
"end_line": 436,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": "Check if installed dependencies meet Ultralytics YOLO models requirements and attempt to auto-update if needed.\n\nArgs:\n requirements (Path | str | List[str]): Path to a requirements.txt file, a single package requirement as a\n string, or a list of package requirements as strings.\n exclude (tuple): Tuple of package names to exclude from checking.\n install (bool): If True, attempt to auto-update packages that don't meet requirements.\n cmds (str): Additional commands to pass to the pip install command when auto-updating.\n\nExamples:\n >>> from ultralytics.utils.checks import check_requirements\n\n Check a requirements.txt file\n >>> check_requirements(\"path/to/requirements.txt\")\n\n Check a single package\n >>> check_requirements(\"ultralytics>=8.0.0\")\n\n Check multiple packages\n >>> check_requirements([\"numpy\", \"ultralytics>=8.0.0\"])",
"parameters": [
"requirements",
"exclude",
"install",
"cmds"
],
"return_type": null,
"decorators": [
"TryExcept()"
],
"complexity_score": 13,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_check_requirements_8db63c68"
},
{
"content": "def check_torchvision():\n \"\"\"\n Check the installed versions of PyTorch and Torchvision to ensure they're compatible.\n\n This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according\n to the compatibility table based on: https://github.com/pytorch/vision#installation.\n \"\"\"\n compatibility_table = {\n \"2.7\": [\"0.22\"],\n \"2.6\": [\"0.21\"],\n \"2.5\": [\"0.20\"],\n \"2.4\": [\"0.19\"],\n \"2.3\": [\"0.18\"],\n \"2.2\": [\"0.17\"],\n \"2.1\": [\"0.16\"],\n \"2.0\": [\"0.15\"],\n \"1.13\": [\"0.14\"],\n \"1.12\": [\"0.13\"],\n }\n\n # Check major and minor versions\n v_torch = \".\".join(torch.__version__.split(\"+\", 1)[0].split(\".\")[:2])\n if v_torch in compatibility_table:\n compatible_versions = compatibility_table[v_torch]\n v_torchvision = \".\".join(TORCHVISION_VERSION.split(\"+\", 1)[0].split(\".\")[:2])\n if all(v_torchvision != v for v in compatible_versions):\n LOGGER.warning(\n f\"torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\\n\"\n f\"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or \"\n \"'pip install -U torch torchvision' to update both.\\n\"\n \"For a full compatibility table see https://github.com/pytorch/vision#installation\"\n )",
"chunk_type": "function",
"name": "check_torchvision",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 439,
"end_line": 470,
"start_col": 0,
"end_col": 13,
"parent_name": null,
"docstring": "Check the installed versions of PyTorch and Torchvision to ensure they're compatible.\n\nThis function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according\nto the compatibility table based on: https://github.com/pytorch/vision#installation.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 4,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_check_torchvision_9377ec30"
},
{
"content": "def check_suffix(file=\"yolo11n.pt\", suffix=\".pt\", msg=\"\"):\n \"\"\"\n Check file(s) for acceptable suffix.\n\n Args:\n file (str | List[str]): File or list of files to check.\n suffix (str | tuple): Acceptable suffix or tuple of suffixes.\n msg (str): Additional message to display in case of error.\n \"\"\"\n if file and suffix:\n if isinstance(suffix, str):\n suffix = {suffix}\n for f in file if isinstance(file, (list, tuple)) else [file]:\n if s := str(f).rpartition(\".\")[-1].lower().strip(): # file suffix\n assert f\".{s}\" in suffix, f\"{msg}{f} acceptable suffix is {suffix}, not .{s}\"",
"chunk_type": "function",
"name": "check_suffix",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 473,
"end_line": 487,
"start_col": 0,
"end_col": 93,
"parent_name": null,
"docstring": "Check file(s) for acceptable suffix.\n\nArgs:\n file (str | List[str]): File or list of files to check.\n suffix (str | tuple): Acceptable suffix or tuple of suffixes.\n msg (str): Additional message to display in case of error.",
"parameters": [
"file",
"suffix",
"msg"
],
"return_type": null,
"decorators": [],
"complexity_score": 5,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_check_suffix_2c408922"
},
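A minimal sketch of `check_suffix`, based directly on the chunk above: valid suffixes pass silently, while a mismatch raises an `AssertionError`. Assumes `ultralytics` is installed; the filenames are illustrative.

```python
from ultralytics.utils.checks import check_suffix

check_suffix("yolo11n.pt", suffix=".pt")                                   # passes silently
check_suffix(["model.onnx", "model.engine"], suffix=(".onnx", ".engine"))  # also fine

try:
    check_suffix("model.onnx", suffix=".pt")
except AssertionError as e:
    print("rejected:", e)
```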
{
"content": "def check_yolov5u_filename(file: str, verbose: bool = True):\n \"\"\"\n Replace legacy YOLOv5 filenames with updated YOLOv5u filenames.\n\n Args:\n file (str): Filename to check and potentially update.\n verbose (bool): Whether to print information about the replacement.\n\n Returns:\n (str): Updated filename.\n \"\"\"\n if \"yolov3\" in file or \"yolov5\" in file:\n if \"u.yaml\" in file:\n file = file.replace(\"u.yaml\", \".yaml\") # i.e. yolov5nu.yaml -> yolov5n.yaml\n elif \".pt\" in file and \"u\" not in file:\n original_file = file\n file = re.sub(r\"(.*yolov5([nsmlx]))\\.pt\", \"\\\\1u.pt\", file) # i.e. yolov5n.pt -> yolov5nu.pt\n file = re.sub(r\"(.*yolov5([nsmlx])6)\\.pt\", \"\\\\1u.pt\", file) # i.e. yolov5n6.pt -> yolov5n6u.pt\n file = re.sub(r\"(.*yolov3(|-tiny|-spp))\\.pt\", \"\\\\1u.pt\", file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt\n if file != original_file and verbose:\n LOGGER.info(\n f\"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\\nYOLOv5 'u' models are \"\n f\"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs \"\n f\"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\\n\"\n )\n return file",
"chunk_type": "function",
"name": "check_yolov5u_filename",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 490,
"end_line": 515,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": "Replace legacy YOLOv5 filenames with updated YOLOv5u filenames.\n\nArgs:\n file (str): Filename to check and potentially update.\n verbose (bool): Whether to print information about the replacement.\n\nReturns:\n (str): Updated filename.",
"parameters": [
"file: str",
"verbose: bool"
],
"return_type": null,
"decorators": [],
"complexity_score": 5,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_check_yolov5u_filename_c6c07a40"
},
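A short sketch of the legacy-filename rewrite performed by `check_yolov5u_filename`; the outputs follow from the regex substitutions in the chunk above. Assumes `ultralytics` is installed.

```python
from ultralytics.utils.checks import check_yolov5u_filename

print(check_yolov5u_filename("yolov5n.pt", verbose=False))     # -> 'yolov5nu.pt'
print(check_yolov5u_filename("yolov3-spp.pt", verbose=False))  # -> 'yolov3-sppu.pt'
print(check_yolov5u_filename("yolo11n.pt", verbose=False))     # unchanged
```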
{
"content": "def check_model_file_from_stem(model=\"yolo11n\"):\n \"\"\"\n Return a model filename from a valid model stem.\n\n Args:\n model (str): Model stem to check.\n\n Returns:\n (str | Path): Model filename with appropriate suffix.\n \"\"\"\n path = Path(model)\n if not path.suffix and path.stem in downloads.GITHUB_ASSETS_STEMS:\n return path.with_suffix(\".pt\") # add suffix, i.e. yolo11n -> yolo11n.pt\n return model",
"chunk_type": "function",
"name": "check_model_file_from_stem",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 518,
"end_line": 531,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "Return a model filename from a valid model stem.\n\nArgs:\n model (str): Model stem to check.\n\nReturns:\n (str | Path): Model filename with appropriate suffix.",
"parameters": [
"model"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_check_model_file_from_stem_585ca3c7"
},
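A hedged sketch of `check_model_file_from_stem`, assuming `ultralytics` is installed and that "yolo11n" is among `downloads.GITHUB_ASSETS_STEMS`, as the chunk's default argument and inline comment suggest:

```python
from ultralytics.utils.checks import check_model_file_from_stem

print(check_model_file_from_stem("yolo11n"))       # -> Path('yolo11n.pt'), suffix added
print(check_model_file_from_stem("yolo11n.onnx"))  # returned unchanged, suffix present
```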
{
"content": "def check_file(file, suffix=\"\", download=True, download_dir=\".\", hard=True):\n \"\"\"\n Search/download file (if necessary), check suffix (if provided), and return path.\n\n Args:\n file (str): File name or path.\n suffix (str | tuple): Acceptable suffix or tuple of suffixes to validate against the file.\n download (bool): Whether to download the file if it doesn't exist locally.\n download_dir (str): Directory to download the file to.\n hard (bool): Whether to raise an error if the file is not found.\n\n Returns:\n (str): Path to the file.\n \"\"\"\n check_suffix(file, suffix) # optional\n file = str(file).strip() # convert to string and strip spaces\n file = check_yolov5u_filename(file) # yolov5n -> yolov5nu\n if (\n not file\n or (\"://\" not in file and Path(file).exists()) # '://' check required in Windows Python<3.10\n or file.lower().startswith(\"grpc://\")\n ): # file exists or gRPC Triton images\n return file\n elif download and file.lower().startswith((\"https://\", \"http://\", \"rtsp://\", \"rtmp://\", \"tcp://\")): # download\n url = file # warning: Pathlib turns :// -> :/\n file = Path(download_dir) / url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth\n if file.exists():\n LOGGER.info(f\"Found {clean_url(url)} locally at {file}\") # file already exists\n else:\n downloads.safe_download(url=url, file=file, unzip=False)\n return str(file)\n else: # search\n files = glob.glob(str(ROOT / \"**\" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file\n if not files and hard:\n raise FileNotFoundError(f\"'{file}' does not exist\")\n elif len(files) > 1 and hard:\n raise FileNotFoundError(f\"Multiple files match '{file}', specify exact path: {files}\")\n return files[0] if len(files) else [] # return file",
"chunk_type": "function",
"name": "check_file",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 534,
"end_line": 571,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": "Search/download file (if necessary), check suffix (if provided), and return path.\n\nArgs:\n file (str): File name or path.\n suffix (str | tuple): Acceptable suffix or tuple of suffixes to validate against the file.\n download (bool): Whether to download the file if it doesn't exist locally.\n download_dir (str): Directory to download the file to.\n hard (bool): Whether to raise an error if the file is not found.\n\nReturns:\n (str): Path to the file.",
"parameters": [
"file",
"suffix",
"download",
"download_dir",
"hard"
],
"return_type": null,
"decorators": [],
"complexity_score": 6,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_check_file_f59712e4"
},
{
"content": "def check_yaml(file, suffix=(\".yaml\", \".yml\"), hard=True):\n \"\"\"\n Search/download YAML file (if necessary) and return path, checking suffix.\n\n Args:\n file (str | Path): File name or path.\n suffix (tuple): Tuple of acceptable YAML file suffixes.\n hard (bool): Whether to raise an error if the file is not found or multiple files are found.\n\n Returns:\n (str): Path to the YAML file.\n \"\"\"\n return check_file(file, suffix, hard=hard)",
"chunk_type": "function",
"name": "check_yaml",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 574,
"end_line": 586,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": "Search/download YAML file (if necessary) and return path, checking suffix.\n\nArgs:\n file (str | Path): File name or path.\n suffix (tuple): Tuple of acceptable YAML file suffixes.\n hard (bool): Whether to raise an error if the file is not found or multiple files are found.\n\nReturns:\n (str): Path to the YAML file.",
"parameters": [
"file",
"suffix",
"hard"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_check_yaml_3af24e9b"
},
{
"content": "def check_is_path_safe(basedir, path):\n \"\"\"\n Check if the resolved path is under the intended directory to prevent path traversal.\n\n Args:\n basedir (Path | str): The intended directory.\n path (Path | str): The path to check.\n\n Returns:\n (bool): True if the path is safe, False otherwise.\n \"\"\"\n base_dir_resolved = Path(basedir).resolve()\n path_resolved = Path(path).resolve()\n\n return path_resolved.exists() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts",
"chunk_type": "function",
"name": "check_is_path_safe",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 589,
"end_line": 603,
"start_col": 0,
"end_col": 116,
"parent_name": null,
"docstring": "Check if the resolved path is under the intended directory to prevent path traversal.\n\nArgs:\n basedir (Path | str): The intended directory.\n path (Path | str): The path to check.\n\nReturns:\n (bool): True if the path is safe, False otherwise.",
"parameters": [
"basedir",
"path"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_check_is_path_safe_1b44c9a8"
},
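A small sketch of the path-traversal guard in `check_is_path_safe`: a file that exists under the base directory is accepted, while a path that escapes it is rejected. Assumes `ultralytics` is installed; the temporary directory is only for illustration.

```python
import tempfile
from pathlib import Path

from ultralytics.utils.checks import check_is_path_safe

with tempfile.TemporaryDirectory() as base:
    inside = Path(base) / "weights.pt"
    inside.touch()
    print(check_is_path_safe(base, inside))                 # True: exists and under base
    print(check_is_path_safe(base, Path(base) / "../etc"))  # False: escapes the base dir
```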
{
"content": "def check_imshow(warn=False):\n \"\"\"\n Check if environment supports image displays.\n\n Args:\n warn (bool): Whether to warn if environment doesn't support image displays.\n\n Returns:\n (bool): True if environment supports image displays, False otherwise.\n \"\"\"\n try:\n if LINUX:\n assert not IS_COLAB and not IS_KAGGLE\n assert \"DISPLAY\" in os.environ, \"The DISPLAY environment variable isn't set.\"\n cv2.imshow(\"test\", np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image\n cv2.waitKey(1)\n cv2.destroyAllWindows()\n cv2.waitKey(1)\n return True\n except Exception as e:\n if warn:\n LOGGER.warning(f\"Environment does not support cv2.imshow() or PIL Image.show()\\n{e}\")\n return False",
"chunk_type": "function",
"name": "check_imshow",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 607,
"end_line": 629,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Check if environment supports image displays.\n\nArgs:\n warn (bool): Whether to warn if environment doesn't support image displays.\n\nReturns:\n (bool): True if environment supports image displays, False otherwise.",
"parameters": [
"warn"
],
"return_type": null,
"decorators": [
"functools.lru_cache"
],
"complexity_score": 4,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_check_imshow_b8705279"
},
{
"content": "def check_yolo(verbose=True, device=\"\"):\n \"\"\"\n Return a human-readable YOLO software and hardware summary.\n\n Args:\n verbose (bool): Whether to print verbose information.\n device (str | torch.device): Device to use for YOLO.\n \"\"\"\n import psutil\n\n from ultralytics.utils.torch_utils import select_device\n\n if IS_COLAB:\n shutil.rmtree(\"sample_data\", ignore_errors=True) # remove colab /sample_data directory\n\n if verbose:\n # System info\n gib = 1 << 30 # bytes per GiB\n ram = psutil.virtual_memory().total\n total, used, free = shutil.disk_usage(\"/\")\n s = f\"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)\"\n try:\n from IPython import display\n\n display.clear_output() # clear display if notebook\n except ImportError:\n pass\n else:\n s = \"\"\n\n select_device(device=device, newline=False)\n LOGGER.info(f\"Setup complete ✅ {s}\")",
"chunk_type": "function",
"name": "check_yolo",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 632,
"end_line": 663,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": "Return a human-readable YOLO software and hardware summary.\n\nArgs:\n verbose (bool): Whether to print verbose information.\n device (str | torch.device): Device to use for YOLO.",
"parameters": [
"verbose",
"device"
],
"return_type": null,
"decorators": [],
"complexity_score": 4,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_check_yolo_7a025cfe"
},
{
"content": "def collect_system_info():\n \"\"\"\n Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA.\n\n Returns:\n (dict): Dictionary containing system information.\n \"\"\"\n import psutil\n\n from ultralytics.utils import ENVIRONMENT # scope to avoid circular import\n from ultralytics.utils.torch_utils import get_cpu_info, get_gpu_info\n\n gib = 1 << 30 # bytes per GiB\n cuda = torch.cuda.is_available()\n check_yolo()\n total, used, free = shutil.disk_usage(\"/\")\n\n info_dict = {\n \"OS\": platform.platform(),\n \"Environment\": ENVIRONMENT,\n \"Python\": PYTHON_VERSION,\n \"Install\": \"git\" if IS_GIT_DIR else \"pip\" if IS_PIP_PACKAGE else \"other\",\n \"Path\": str(ROOT),\n \"RAM\": f\"{psutil.virtual_memory().total / gib:.2f} GB\",\n \"Disk\": f\"{(total - free) / gib:.1f}/{total / gib:.1f} GB\",\n \"CPU\": get_cpu_info(),\n \"CPU count\": os.cpu_count(),\n \"GPU\": get_gpu_info(index=0) if cuda else None,\n \"GPU count\": torch.cuda.device_count() if cuda else None,\n \"CUDA\": torch.version.cuda if cuda else None,\n }\n LOGGER.info(\"\\n\" + \"\\n\".join(f\"{k:<20}{v}\" for k, v in info_dict.items()) + \"\\n\")\n\n package_info = {}\n for r in parse_requirements(package=\"ultralytics\"):\n try:\n current = metadata.version(r.name)\n is_met = \"✅ \" if check_version(current, str(r.specifier), name=r.name, hard=True) else \"❌ \"\n except metadata.PackageNotFoundError:\n current = \"(not installed)\"\n is_met = \"❌ \"\n package_info[r.name] = f\"{is_met}{current}{r.specifier}\"\n LOGGER.info(f\"{r.name:<20}{package_info[r.name]}\")\n\n info_dict[\"Package Info\"] = package_info\n\n if is_github_action_running():\n github_info = {\n \"RUNNER_OS\": os.getenv(\"RUNNER_OS\"),\n \"GITHUB_EVENT_NAME\": os.getenv(\"GITHUB_EVENT_NAME\"),\n \"GITHUB_WORKFLOW\": os.getenv(\"GITHUB_WORKFLOW\"),\n \"GITHUB_ACTOR\": os.getenv(\"GITHUB_ACTOR\"),\n \"GITHUB_REPOSITORY\": os.getenv(\"GITHUB_REPOSITORY\"),\n \"GITHUB_REPOSITORY_OWNER\": os.getenv(\"GITHUB_REPOSITORY_OWNER\"),\n }\n LOGGER.info(\"\\n\" + \"\\n\".join(f\"{k}: {v}\" for k, v in github_info.items()))\n info_dict[\"GitHub Info\"] = github_info\n\n return info_dict",
"chunk_type": "function",
"name": "collect_system_info",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 666,
"end_line": 724,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA.\n\nReturns:\n (dict): Dictionary containing system information.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 6,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_collect_system_info_5319b4d4"
},
{
"content": "def check_amp(model):\n \"\"\"\n Check the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO model.\n\n If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP\n results, so AMP will be disabled during training.\n\n Args:\n model (torch.nn.Module): A YOLO model instance.\n\n Returns:\n (bool): Returns True if the AMP functionality works correctly with YOLO11 model, else False.\n\n Examples:\n >>> from ultralytics import YOLO\n >>> from ultralytics.utils.checks import check_amp\n >>> model = YOLO(\"yolo11n.pt\").model.cuda()\n >>> check_amp(model)\n \"\"\"\n from ultralytics.utils.torch_utils import autocast\n\n device = next(model.parameters()).device # get model device\n prefix = colorstr(\"AMP: \")\n if device.type in {\"cpu\", \"mps\"}:\n return False # AMP only used on CUDA devices\n else:\n # GPUs that have issues with AMP\n pattern = re.compile(\n r\"(nvidia|geforce|quadro|tesla).*?(1660|1650|1630|t400|t550|t600|t1000|t1200|t2000|k40m)\", re.IGNORECASE\n )\n\n gpu = torch.cuda.get_device_name(device)\n if bool(pattern.search(gpu)):\n LOGGER.warning(\n f\"{prefix}checks failed ❌. AMP training on {gpu} GPU may cause \"\n f\"NaN losses or zero-mAP results, so AMP will be disabled during training.\"\n )\n return False\n\n def amp_allclose(m, im):\n \"\"\"All close FP32 vs AMP results.\"\"\"\n batch = [im] * 8\n imgsz = max(256, int(model.stride.max() * 4)) # max stride P5-32 and P6-64\n a = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # FP32 inference\n with autocast(enabled=True):\n b = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # AMP inference\n del m\n return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance\n\n im = ASSETS / \"bus.jpg\" # image to check\n LOGGER.info(f\"{prefix}running Automatic Mixed Precision (AMP) checks...\")\n warning_msg = \"Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False.\"\n try:\n from ultralytics import YOLO\n\n assert amp_allclose(YOLO(\"yolo11n.pt\"), im)\n LOGGER.info(f\"{prefix}checks passed ✅\")\n except ConnectionError:\n LOGGER.warning(f\"{prefix}checks skipped. Offline and unable to download YOLO11n for AMP checks. {warning_msg}\")\n except (AttributeError, ModuleNotFoundError):\n LOGGER.warning(\n f\"{prefix}checks skipped. \"\n f\"Unable to load YOLO11n for AMP checks due to possible Ultralytics package modifications. {warning_msg}\"\n )\n except AssertionError:\n LOGGER.error(\n f\"{prefix}checks failed. Anomalies were detected with AMP on your system that may lead to \"\n f\"NaN losses or zero-mAP results, so AMP will be disabled during training.\"\n )\n return False\n return True",
"chunk_type": "function",
"name": "check_amp",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 727,
"end_line": 797,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": "Check the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO model.\n\nIf the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP\nresults, so AMP will be disabled during training.\n\nArgs:\n model (torch.nn.Module): A YOLO model instance.\n\nReturns:\n (bool): Returns True if the AMP functionality works correctly with YOLO11 model, else False.\n\nExamples:\n >>> from ultralytics import YOLO\n >>> from ultralytics.utils.checks import check_amp\n >>> model = YOLO(\"yolo11n.pt\").model.cuda()\n >>> check_amp(model)",
"parameters": [
"model"
],
"return_type": null,
"decorators": [],
"complexity_score": 6,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_check_amp_5eaf9c64"
},
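The core of `check_amp` is an FP32-versus-AMP consistency test: run the same input through a model twice, once in full precision and once under autocast, and require the outputs to agree within an absolute tolerance. A minimal sketch of that idea, assuming a CUDA device and using a small throwaway convolution rather than the YOLO11n model the function actually downloads:

```python
import torch

# Hypothetical stand-in model; check_amp itself runs YOLO11n on the ASSETS bus.jpg image.
model = torch.nn.Conv2d(3, 8, 3, padding=1).cuda().eval()
x = torch.rand(1, 3, 64, 64, device="cuda")

with torch.inference_mode():
    fp32_out = model(x)                       # full-precision reference
    with torch.autocast("cuda", enabled=True):
        amp_out = model(x)                    # mixed-precision result (float16)

# Same acceptance rule as amp_allclose above: matching shapes and values within atol
ok = fp32_out.shape == amp_out.shape and torch.allclose(fp32_out, amp_out.float(), atol=0.5)
print(f"AMP numerically consistent: {ok}")
```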
{
"content": "def git_describe(path=ROOT): # path must be a directory\n \"\"\"\n Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe.\n\n Args:\n path (Path): Path to git repository.\n\n Returns:\n (str): Human-readable git description.\n \"\"\"\n try:\n return subprocess.check_output(f\"git -C {path} describe --tags --long --always\", shell=True).decode()[:-1]\n except Exception:\n return \"\"",
"chunk_type": "function",
"name": "git_describe",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 800,
"end_line": 813,
"start_col": 0,
"end_col": 17,
"parent_name": null,
"docstring": "Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe.\n\nArgs:\n path (Path): Path to git repository.\n\nReturns:\n (str): Human-readable git description.",
"parameters": [
"path"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_git_describe_0ce7e654"
},
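A quick usage sketch: inside a git checkout `git_describe` returns a tag-based description in the `v5.0-5-g3e25f1e` style documented above, and an empty string anywhere else (or when git is unavailable).

```python
from ultralytics.utils.checks import git_describe

desc = git_describe()  # defaults to the package ROOT directory
print(desc or "no git description available")
```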
{
"content": "def print_args(args: Optional[dict] = None, show_file=True, show_func=False):\n \"\"\"\n Print function arguments (optional args dict).\n\n Args:\n args (dict, optional): Arguments to print.\n show_file (bool): Whether to show the file name.\n show_func (bool): Whether to show the function name.\n \"\"\"\n\n def strip_auth(v):\n \"\"\"Clean longer Ultralytics HUB URLs by stripping potential authentication information.\"\"\"\n return clean_url(v) if (isinstance(v, str) and v.startswith(\"http\") and len(v) > 100) else v\n\n x = inspect.currentframe().f_back # previous frame\n file, _, func, _, _ = inspect.getframeinfo(x)\n if args is None: # get args automatically\n args, _, _, frm = inspect.getargvalues(x)\n args = {k: v for k, v in frm.items() if k in args}\n try:\n file = Path(file).resolve().relative_to(ROOT).with_suffix(\"\")\n except ValueError:\n file = Path(file).stem\n s = (f\"{file}: \" if show_file else \"\") + (f\"{func}: \" if show_func else \"\")\n LOGGER.info(colorstr(s) + \", \".join(f\"{k}={strip_auth(v)}\" for k, v in sorted(args.items())))",
"chunk_type": "function",
"name": "print_args",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 816,
"end_line": 840,
"start_col": 0,
"end_col": 97,
"parent_name": null,
"docstring": "Print function arguments (optional args dict).\n\nArgs:\n args (dict, optional): Arguments to print.\n show_file (bool): Whether to show the file name.\n show_func (bool): Whether to show the function name.",
"parameters": [
"args: Optional[dict]",
"show_file",
"show_func"
],
"return_type": null,
"decorators": [],
"complexity_score": 5,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_print_args_927eccde"
},
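Because `print_args` inspects the calling frame when `args` is None, it is typically dropped into the top of a function to log that function's own arguments. A small example (the `demo` function here is hypothetical, not part of Ultralytics):

```python
from ultralytics.utils.checks import print_args

def demo(imgsz=640, epochs=3, model="yolo11n.pt"):
    # logs the calling file plus "epochs=3, imgsz=640, model=yolo11n.pt" (keys sorted)
    print_args()

demo()
```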
{
"content": "def cuda_device_count() -> int:\n \"\"\"\n Get the number of NVIDIA GPUs available in the environment.\n\n Returns:\n (int): The number of NVIDIA GPUs available.\n \"\"\"\n if IS_JETSON:\n # NVIDIA Jetson does not fully support nvidia-smi and therefore use PyTorch instead\n return torch.cuda.device_count()\n else:\n try:\n # Run the nvidia-smi command and capture its output\n output = subprocess.check_output(\n [\"nvidia-smi\", \"--query-gpu=count\", \"--format=csv,noheader,nounits\"], encoding=\"utf-8\"\n )\n\n # Take the first line and strip any leading/trailing white space\n first_line = output.strip().split(\"\\n\", 1)[0]\n\n return int(first_line)\n except (subprocess.CalledProcessError, FileNotFoundError, ValueError):\n # If the command fails, nvidia-smi is not found, or output is not an integer, assume no GPUs are available\n return 0",
"chunk_type": "function",
"name": "cuda_device_count",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 843,
"end_line": 866,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Get the number of NVIDIA GPUs available in the environment.\n\nReturns:\n (int): The number of NVIDIA GPUs available.",
"parameters": [],
"return_type": "int",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_cuda_device_count_464b91d4"
},
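Because `cuda_device_count` shells out to `nvidia-smi` (except on Jetson, where it falls back to PyTorch), it can report GPU availability without touching the CUDA runtime in the current Python process, which appears to be why it is used alongside `torch.cuda.device_count()` rather than replaced by it. Comparing the two is a quick sanity check:

```python
import torch
from ultralytics.utils.checks import cuda_device_count, cuda_is_available

print(f"nvidia-smi count: {cuda_device_count()}")
print(f"torch count:      {torch.cuda.device_count()}")
print(f"CUDA available:   {cuda_is_available()}")
```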
{
"content": "def cuda_is_available() -> bool:\n \"\"\"\n Check if CUDA is available in the environment.\n\n Returns:\n (bool): True if one or more NVIDIA GPUs are available, False otherwise.\n \"\"\"\n return cuda_device_count() > 0",
"chunk_type": "function",
"name": "cuda_is_available",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 869,
"end_line": 876,
"start_col": 0,
"end_col": 34,
"parent_name": null,
"docstring": "Check if CUDA is available in the environment.\n\nReturns:\n (bool): True if one or more NVIDIA GPUs are available, False otherwise.",
"parameters": [],
"return_type": "bool",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_cuda_is_available_61f8d92d"
},
{
"content": "def is_rockchip():\n \"\"\"\n Check if the current environment is running on a Rockchip SoC.\n\n Returns:\n (bool): True if running on a Rockchip SoC, False otherwise.\n \"\"\"\n if LINUX and ARM64:\n try:\n with open(\"/proc/device-tree/compatible\") as f:\n dev_str = f.read()\n *_, soc = dev_str.split(\",\")\n if soc.replace(\"\\x00\", \"\") in RKNN_CHIPS:\n return True\n except OSError:\n return False\n else:\n return False",
"chunk_type": "function",
"name": "is_rockchip",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 879,
"end_line": 896,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Check if the current environment is running on a Rockchip SoC.\n\nReturns:\n (bool): True if running on a Rockchip SoC, False otherwise.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 4,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_is_rockchip_2fe49297"
},
{
"content": "def is_intel():\n \"\"\"\n Check if the system has Intel hardware (CPU or GPU).\n\n Returns:\n (bool): True if Intel hardware is detected, False otherwise.\n \"\"\"\n from ultralytics.utils.torch_utils import get_cpu_info\n\n # Check CPU\n if \"intel\" in get_cpu_info().lower():\n return True\n\n # Check GPU via xpu-smi\n try:\n result = subprocess.run([\"xpu-smi\", \"discovery\"], capture_output=True, text=True, timeout=5)\n return \"intel\" in result.stdout.lower()\n except (subprocess.TimeoutExpired, FileNotFoundError, subprocess.SubprocessError):\n return False",
"chunk_type": "function",
"name": "is_intel",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 899,
"end_line": 917,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Check if the system has Intel hardware (CPU or GPU).\n\nReturns:\n (bool): True if Intel hardware is detected, False otherwise.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 3,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_is_intel_375f87f7"
},
{
"content": "def is_sudo_available() -> bool:\n \"\"\"\n Check if the sudo command is available in the environment.\n\n Returns:\n (bool): True if the sudo command is available, False otherwise.\n \"\"\"\n if WINDOWS:\n return False\n cmd = \"sudo --version\"\n return subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0",
"chunk_type": "function",
"name": "is_sudo_available",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 920,
"end_line": 930,
"start_col": 0,
"end_col": 112,
"parent_name": null,
"docstring": "Check if the sudo command is available in the environment.\n\nReturns:\n (bool): True if the sudo command is available, False otherwise.",
"parameters": [],
"return_type": "bool",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"functools",
"glob",
"inspect",
"math",
"os",
"platform",
"re",
"shutil",
"subprocess",
"time",
"importlib.metadata",
"pathlib.Path",
"types.SimpleNamespace",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.ARM64",
"ultralytics.utils.ASSETS",
"ultralytics.utils.AUTOINSTALL",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_GIT_DIR",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.IS_PIP_PACKAGE",
"ultralytics.utils.LINUX",
"ultralytics.utils.LOGGER",
"ultralytics.utils.MACOS",
"ultralytics.utils.ONLINE",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.RKNN_CHIPS",
"ultralytics.utils.ROOT",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.USER_CONFIG_DIR",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.Retry",
"ultralytics.utils.ThreadingLocked",
"ultralytics.utils.TryExcept",
"ultralytics.utils.clean_url",
"ultralytics.utils.colorstr",
"ultralytics.utils.downloads",
"ultralytics.utils.is_github_action_running",
"ultralytics.utils.url2file",
"requests",
"matplotlib.font_manager",
"psutil",
"ultralytics.utils.torch_utils.select_device",
"psutil",
"ultralytics.utils.ENVIRONMENT",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.utils.torch_utils.get_gpu_info",
"ultralytics.utils.torch_utils.autocast",
"ultralytics.utils.torch_utils.get_cpu_info",
"ultralytics.YOLO",
"ultralytics.__version__",
"IPython.display"
],
"chunk_id": "function_is_sudo_available_7bc7ce4d"
},
{
"content": "IS_PYTHON_3_8 = PYTHON_VERSION.startswith(\"3.8\")",
"chunk_type": "variable",
"name": "IS_PYTHON_3_8",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 938,
"end_line": 938,
"start_col": 0,
"end_col": 48,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_IS_PYTHON_3_8_17224886"
},
{
"content": "IS_PYTHON_3_12 = PYTHON_VERSION.startswith(\"3.12\")",
"chunk_type": "variable",
"name": "IS_PYTHON_3_12",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 939,
"end_line": 939,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_IS_PYTHON_3_12_be9b4c7b"
},
{
"content": "IS_PYTHON_3_13 = PYTHON_VERSION.startswith(\"3.13\")",
"chunk_type": "variable",
"name": "IS_PYTHON_3_13",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 940,
"end_line": 940,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_IS_PYTHON_3_13_915d7a72"
},
{
"content": "IS_PYTHON_MINIMUM_3_10 = check_python(\"3.10\", hard=False)",
"chunk_type": "variable",
"name": "IS_PYTHON_MINIMUM_3_10",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 942,
"end_line": 942,
"start_col": 0,
"end_col": 57,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_IS_PYTHON_MINIMUM_3_10_0cceec0b"
},
{
"content": "IS_PYTHON_MINIMUM_3_12 = check_python(\"3.12\", hard=False)",
"chunk_type": "variable",
"name": "IS_PYTHON_MINIMUM_3_12",
"file_path": "ultralytics\\ultralytics\\utils\\checks.py",
"start_line": 943,
"end_line": 943,
"start_col": 0,
"end_col": 57,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_IS_PYTHON_MINIMUM_3_12_26a33f45"
},
{
"content": "import os",
"chunk_type": "import",
"name": "os",
"file_path": "ultralytics\\ultralytics\\utils\\dist.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_os_1fbc79e0"
},
{
"content": "import shutil",
"chunk_type": "import",
"name": "shutil",
"file_path": "ultralytics\\ultralytics\\utils\\dist.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 13,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_shutil_8a4c838b"
},
{
"content": "import sys",
"chunk_type": "import",
"name": "sys",
"file_path": "ultralytics\\ultralytics\\utils\\dist.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 10,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_sys_28c13f74"
},
{
"content": "import tempfile",
"chunk_type": "import",
"name": "tempfile",
"file_path": "ultralytics\\ultralytics\\utils\\dist.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_tempfile_20828fc4"
},
{
"content": "from . import USER_CONFIG_DIR",
"chunk_type": "import",
"name": "USER_CONFIG_DIR",
"file_path": "ultralytics\\ultralytics\\utils\\dist.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_USER_CONFIG_DIR_d2b6230b"
},
{
"content": "from .torch_utils import TORCH_1_9",
"chunk_type": "import",
"name": "TORCH_1_9",
"file_path": "ultralytics\\ultralytics\\utils\\dist.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 34,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_TORCH_1_9_b327005c"
},
{
"content": "def find_free_network_port() -> int:\n \"\"\"\n Find a free port on localhost.\n\n It is useful in single-node training when we don't want to connect to a real main node but have to set the\n `MASTER_PORT` environment variable.\n\n Returns:\n (int): The available network port number.\n \"\"\"\n import socket\n\n with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n s.bind((\"127.0.0.1\", 0))\n return s.getsockname()[1] # port",
"chunk_type": "function",
"name": "find_free_network_port",
"file_path": "ultralytics\\ultralytics\\utils\\dist.py",
"start_line": 12,
"end_line": 26,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": "Find a free port on localhost.\n\nIt is useful in single-node training when we don't want to connect to a real main node but have to set the\n`MASTER_PORT` environment variable.\n\nReturns:\n (int): The available network port number.",
"parameters": [],
"return_type": "int",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"os",
"shutil",
"sys",
"tempfile",
"USER_CONFIG_DIR",
"torch_utils.TORCH_1_9",
"socket",
"__main__"
],
"chunk_id": "function_find_free_network_port_f7ebaaa5"
},
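As the docstring notes, the free port is typically fed into the `MASTER_PORT` environment variable for single-node distributed training. A hedged sketch; `MASTER_ADDR` and `MASTER_PORT` are the standard `torch.distributed` environment variables, not names defined in this module:

```python
import os
from ultralytics.utils.dist import find_free_network_port

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ["MASTER_PORT"] = str(find_free_network_port())  # OS-assigned free port
print("MASTER_PORT =", os.environ["MASTER_PORT"])
```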
{
"content": "def generate_ddp_file(trainer):\n \"\"\"\n Generate a DDP (Distributed Data Parallel) file for multi-GPU training.\n\n This function creates a temporary Python file that enables distributed training across multiple GPUs.\n The file contains the necessary configuration to initialize the trainer in a distributed environment.\n\n Args:\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing training configuration and arguments.\n Must have args attribute and be a class instance.\n\n Returns:\n (str): Path to the generated temporary DDP file.\n\n Notes:\n The generated file is saved in the USER_CONFIG_DIR/DDP directory and includes:\n - Trainer class import\n - Configuration overrides from the trainer arguments\n - Model path configuration\n - Training initialization code\n \"\"\"\n module, name = f\"{trainer.__class__.__module__}.{trainer.__class__.__name__}\".rsplit(\".\", 1)\n\n content = f\"\"\"\n# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)\noverrides = {vars(trainer.args)}\n\nif __name__ == \"__main__\":\n from {module} import {name}\n from ultralytics.utils import DEFAULT_CFG_DICT\n\n cfg = DEFAULT_CFG_DICT.copy()\n cfg.update(save_dir='') # handle the extra key 'save_dir'\n trainer = {name}(cfg=cfg, overrides=overrides)\n trainer.args.model = \"{getattr(trainer.hub_session, \"model_url\", trainer.args.model)}\"\n results = trainer.train()\n\"\"\"\n (USER_CONFIG_DIR / \"DDP\").mkdir(exist_ok=True)\n with tempfile.NamedTemporaryFile(\n prefix=\"_temp_\",\n suffix=f\"{id(trainer)}.py\",\n mode=\"w+\",\n encoding=\"utf-8\",\n dir=USER_CONFIG_DIR / \"DDP\",\n delete=False,\n ) as file:\n file.write(content)\n return file.name",
"chunk_type": "function",
"name": "generate_ddp_file",
"file_path": "ultralytics\\ultralytics\\utils\\dist.py",
"start_line": 29,
"end_line": 76,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Generate a DDP (Distributed Data Parallel) file for multi-GPU training.\n\nThis function creates a temporary Python file that enables distributed training across multiple GPUs.\nThe file contains the necessary configuration to initialize the trainer in a distributed environment.\n\nArgs:\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing training configuration and arguments.\n Must have args attribute and be a class instance.\n\nReturns:\n (str): Path to the generated temporary DDP file.\n\nNotes:\n The generated file is saved in the USER_CONFIG_DIR/DDP directory and includes:\n - Trainer class import\n - Configuration overrides from the trainer arguments\n - Model path configuration\n - Training initialization code",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"os",
"shutil",
"sys",
"tempfile",
"USER_CONFIG_DIR",
"torch_utils.TORCH_1_9",
"socket",
"__main__"
],
"chunk_id": "function_generate_ddp_file_2558acd0"
},
{
"content": "def generate_ddp_command(world_size: int, trainer):\n \"\"\"\n Generate command for distributed training.\n\n Args:\n world_size (int): Number of processes to spawn for distributed training.\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing configuration for distributed training.\n\n Returns:\n cmd (List[str]): The command to execute for distributed training.\n file (str): Path to the temporary file created for DDP training.\n \"\"\"\n import __main__ # noqa local import to avoid https://github.com/Lightning-AI/pytorch-lightning/issues/15218\n\n if not trainer.resume:\n shutil.rmtree(trainer.save_dir) # remove the save_dir\n file = generate_ddp_file(trainer)\n dist_cmd = \"torch.distributed.run\" if TORCH_1_9 else \"torch.distributed.launch\"\n port = find_free_network_port()\n cmd = [sys.executable, \"-m\", dist_cmd, \"--nproc_per_node\", f\"{world_size}\", \"--master_port\", f\"{port}\", file]\n return cmd, file",
"chunk_type": "function",
"name": "generate_ddp_command",
"file_path": "ultralytics\\ultralytics\\utils\\dist.py",
"start_line": 79,
"end_line": 99,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Generate command for distributed training.\n\nArgs:\n world_size (int): Number of processes to spawn for distributed training.\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing configuration for distributed training.\n\nReturns:\n cmd (List[str]): The command to execute for distributed training.\n file (str): Path to the temporary file created for DDP training.",
"parameters": [
"world_size: int",
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"os",
"shutil",
"sys",
"tempfile",
"USER_CONFIG_DIR",
"torch_utils.TORCH_1_9",
"socket",
"__main__"
],
"chunk_id": "function_generate_ddp_command_377d6fdd"
},
{
"content": "def ddp_cleanup(trainer, file):\n \"\"\"\n Delete temporary file if created during distributed data parallel (DDP) training.\n\n This function checks if the provided file contains the trainer's ID in its name, indicating it was created\n as a temporary file for DDP training, and deletes it if so.\n\n Args:\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer used for distributed training.\n file (str): Path to the file that might need to be deleted.\n\n Examples:\n >>> trainer = YOLOTrainer()\n >>> file = \"/tmp/ddp_temp_123456789.py\"\n >>> ddp_cleanup(trainer, file)\n \"\"\"\n if f\"{id(trainer)}.py\" in file: # if temp_file suffix in file\n os.remove(file)",
"chunk_type": "function",
"name": "ddp_cleanup",
"file_path": "ultralytics\\ultralytics\\utils\\dist.py",
"start_line": 102,
"end_line": 119,
"start_col": 0,
"end_col": 23,
"parent_name": null,
"docstring": "Delete temporary file if created during distributed data parallel (DDP) training.\n\nThis function checks if the provided file contains the trainer's ID in its name, indicating it was created\nas a temporary file for DDP training, and deletes it if so.\n\nArgs:\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer used for distributed training.\n file (str): Path to the file that might need to be deleted.\n\nExamples:\n >>> trainer = YOLOTrainer()\n >>> file = \"/tmp/ddp_temp_123456789.py\"\n >>> ddp_cleanup(trainer, file)",
"parameters": [
"trainer",
"file"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"os",
"shutil",
"sys",
"tempfile",
"USER_CONFIG_DIR",
"torch_utils.TORCH_1_9",
"socket",
"__main__"
],
"chunk_id": "function_ddp_cleanup_d6aa6b54"
},
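Taken together, the three helpers above are normally driven internally by the trainer when more than one GPU is requested: build the command, launch it, then clean up the temp file. A hedged end-to-end sketch assuming a `DetectionTrainer` configured for two GPUs; the exact orchestration inside `BaseTrainer` may differ:

```python
import subprocess
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command

# Hypothetical trainer setup; any BaseTrainer subclass with args/save_dir should work
trainer = DetectionTrainer(overrides={"model": "yolo11n.pt", "data": "coco8.yaml", "epochs": 1, "device": "0,1"})

cmd, ddp_file = generate_ddp_command(world_size=2, trainer=trainer)  # writes the temp DDP file
try:
    subprocess.run(cmd, check=True)      # launches torch.distributed.run on this node
finally:
    ddp_cleanup(trainer, str(ddp_file))  # removes the temp file if it belongs to this trainer
```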
{
"content": "import re",
"chunk_type": "import",
"name": "re",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_re_69b018a7"
},
{
"content": "import shutil",
"chunk_type": "import",
"name": "shutil",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 13,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_shutil_fd555cdc"
},
{
"content": "import subprocess",
"chunk_type": "import",
"name": "subprocess",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 17,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_subprocess_de40055c"
},
{
"content": "from itertools import repeat",
"chunk_type": "import",
"name": "repeat",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_repeat_380fc23f"
},
{
"content": "from multiprocessing.pool import ThreadPool",
"chunk_type": "import",
"name": "ThreadPool",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 43,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ThreadPool_b6d1e36d"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_7ae7dad5"
},
{
"content": "from typing import List, Tuple",
"chunk_type": "import",
"name": "List, Tuple",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 30,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_List, Tuple_328d12b9"
},
{
"content": "from urllib import parse, request",
"chunk_type": "import",
"name": "parse, request",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_parse, request_36875657"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_364b3d6b"
},
{
"content": "from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file",
"chunk_type": "import",
"name": "LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 90,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file_4cd4534a"
},
{
"content": "GITHUB_ASSETS_REPO = \"ultralytics/assets\"",
"chunk_type": "variable",
"name": "GITHUB_ASSETS_REPO",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 17,
"end_line": 17,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_GITHUB_ASSETS_REPO_74b554b4"
},
{
"content": "GITHUB_ASSETS_NAMES = frozenset(\n [f\"yolov8{k}{suffix}.pt\" for k in \"nsmlx\" for suffix in (\"\", \"-cls\", \"-seg\", \"-pose\", \"-obb\", \"-oiv7\")]\n + [f\"yolo11{k}{suffix}.pt\" for k in \"nsmlx\" for suffix in (\"\", \"-cls\", \"-seg\", \"-pose\", \"-obb\")]\n + [f\"yolo12{k}{suffix}.pt\" for k in \"nsmlx\" for suffix in (\"\",)] # detect models only currently\n + [f\"yolov5{k}{resolution}u.pt\" for k in \"nsmlx\" for resolution in (\"\", \"6\")]\n + [f\"yolov3{k}u.pt\" for k in (\"\", \"-spp\", \"-tiny\")]\n + [f\"yolov8{k}-world.pt\" for k in \"smlx\"]\n + [f\"yolov8{k}-worldv2.pt\" for k in \"smlx\"]\n + [f\"yoloe-v8{k}{suffix}.pt\" for k in \"sml\" for suffix in (\"-seg\", \"-seg-pf\")]\n + [f\"yoloe-11{k}{suffix}.pt\" for k in \"sml\" for suffix in (\"-seg\", \"-seg-pf\")]\n + [f\"yolov9{k}.pt\" for k in \"tsmce\"]\n + [f\"yolov10{k}.pt\" for k in \"nsmblx\"]\n + [f\"yolo_nas_{k}.pt\" for k in \"sml\"]\n + [f\"sam_{k}.pt\" for k in \"bl\"]\n + [f\"sam2_{k}.pt\" for k in \"blst\"]\n + [f\"sam2.1_{k}.pt\" for k in \"blst\"]\n + [f\"FastSAM-{k}.pt\" for k in \"sx\"]\n + [f\"rtdetr-{k}.pt\" for k in \"lx\"]\n + [\n \"mobile_sam.pt\",\n \"mobileclip_blt.ts\",\n \"yolo11n-grayscale.pt\",\n \"calibration_image_sample_data_20x128x128x3_float32.npy.zip\",\n ]\n)",
"chunk_type": "variable",
"name": "GITHUB_ASSETS_NAMES",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 18,
"end_line": 42,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_GITHUB_ASSETS_NAMES_ba359a8b"
},
{
"content": "GITHUB_ASSETS_STEMS = frozenset(k.rpartition(\".\")[0] for k in GITHUB_ASSETS_NAMES)",
"chunk_type": "variable",
"name": "GITHUB_ASSETS_STEMS",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 43,
"end_line": 43,
"start_col": 0,
"end_col": 82,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_GITHUB_ASSETS_STEMS_97b36524"
},
{
"content": "def is_url(url, check: bool = False) -> bool:\n \"\"\"\n Validate if the given string is a URL and optionally check if the URL exists online.\n\n Args:\n url (str): The string to be validated as a URL.\n check (bool, optional): If True, performs an additional check to see if the URL exists online.\n\n Returns:\n (bool): True for a valid URL. If 'check' is True, also returns True if the URL exists online.\n\n Examples:\n >>> valid = is_url(\"https://www.example.com\")\n >>> valid_and_exists = is_url(\"https://www.example.com\", check=True)\n \"\"\"\n try:\n url = str(url)\n result = parse.urlparse(url)\n assert all([result.scheme, result.netloc]) # check if is url\n if check:\n with request.urlopen(url) as response:\n return response.getcode() == 200 # check if exists online\n return True\n except Exception:\n return False",
"chunk_type": "function",
"name": "is_url",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 46,
"end_line": 70,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Validate if the given string is a URL and optionally check if the URL exists online.\n\nArgs:\n url (str): The string to be validated as a URL.\n check (bool, optional): If True, performs an additional check to see if the URL exists online.\n\nReturns:\n (bool): True for a valid URL. If 'check' is True, also returns True if the URL exists online.\n\nExamples:\n >>> valid = is_url(\"https://www.example.com\")\n >>> valid_and_exists = is_url(\"https://www.example.com\", check=True)",
"parameters": [
"url",
"check: bool"
],
"return_type": "bool",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"re",
"shutil",
"subprocess",
"itertools.repeat",
"multiprocessing.pool.ThreadPool",
"pathlib.Path",
"typing.List",
"typing.Tuple",
"urllib.parse",
"urllib.request",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TQDM",
"ultralytics.utils.checks",
"ultralytics.utils.clean_url",
"ultralytics.utils.emojis",
"ultralytics.utils.is_online",
"ultralytics.utils.url2file",
"zipfile.ZIP_DEFLATED",
"zipfile.ZIP_STORED",
"zipfile.ZipFile",
"zipfile.BadZipFile",
"zipfile.ZipFile",
"zipfile.is_zipfile",
"requests",
"requests",
"requests",
"ultralytics.utils.SETTINGS",
"zipfile.is_zipfile"
],
"chunk_id": "function_is_url_791f4eb3"
},
{
"content": "def delete_dsstore(path, files_to_delete=(\".DS_Store\", \"__MACOSX\")):\n \"\"\"\n Delete all specified system files in a directory.\n\n Args:\n path (str | Path): The directory path where the files should be deleted.\n files_to_delete (tuple): The files to be deleted.\n\n Examples:\n >>> from ultralytics.utils.downloads import delete_dsstore\n >>> delete_dsstore(\"path/to/dir\")\n\n Notes:\n \".DS_store\" files are created by the Apple operating system and contain metadata about folders and files. They\n are hidden system files and can cause issues when transferring files between different operating systems.\n \"\"\"\n for file in files_to_delete:\n matches = list(Path(path).rglob(file))\n LOGGER.info(f\"Deleting {file} files: {matches}\")\n for f in matches:\n f.unlink()",
"chunk_type": "function",
"name": "delete_dsstore",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 73,
"end_line": 93,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": "Delete all specified system files in a directory.\n\nArgs:\n path (str | Path): The directory path where the files should be deleted.\n files_to_delete (tuple): The files to be deleted.\n\nExamples:\n >>> from ultralytics.utils.downloads import delete_dsstore\n >>> delete_dsstore(\"path/to/dir\")\n\nNotes:\n \".DS_store\" files are created by the Apple operating system and contain metadata about folders and files. They\n are hidden system files and can cause issues when transferring files between different operating systems.",
"parameters": [
"path",
"files_to_delete"
],
"return_type": null,
"decorators": [],
"complexity_score": 3,
"dependencies": [
"re",
"shutil",
"subprocess",
"itertools.repeat",
"multiprocessing.pool.ThreadPool",
"pathlib.Path",
"typing.List",
"typing.Tuple",
"urllib.parse",
"urllib.request",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TQDM",
"ultralytics.utils.checks",
"ultralytics.utils.clean_url",
"ultralytics.utils.emojis",
"ultralytics.utils.is_online",
"ultralytics.utils.url2file",
"zipfile.ZIP_DEFLATED",
"zipfile.ZIP_STORED",
"zipfile.ZipFile",
"zipfile.BadZipFile",
"zipfile.ZipFile",
"zipfile.is_zipfile",
"requests",
"requests",
"requests",
"ultralytics.utils.SETTINGS",
"zipfile.is_zipfile"
],
"chunk_id": "function_delete_dsstore_cfe15d33"
},
{
"content": "def zip_directory(directory, compress: bool = True, exclude=(\".DS_Store\", \"__MACOSX\"), progress: bool = True) -> Path:\n \"\"\"\n Zip the contents of a directory, excluding specified files.\n\n The resulting zip file is named after the directory and placed alongside it.\n\n Args:\n directory (str | Path): The path to the directory to be zipped.\n compress (bool): Whether to compress the files while zipping.\n exclude (tuple, optional): A tuple of filename strings to be excluded.\n progress (bool, optional): Whether to display a progress bar.\n\n Returns:\n (Path): The path to the resulting zip file.\n\n Examples:\n >>> from ultralytics.utils.downloads import zip_directory\n >>> file = zip_directory(\"path/to/dir\")\n \"\"\"\n from zipfile import ZIP_DEFLATED, ZIP_STORED, ZipFile\n\n delete_dsstore(directory)\n directory = Path(directory)\n if not directory.is_dir():\n raise FileNotFoundError(f\"Directory '{directory}' does not exist.\")\n\n # Zip with progress bar\n files_to_zip = [f for f in directory.rglob(\"*\") if f.is_file() and all(x not in f.name for x in exclude)]\n zip_file = directory.with_suffix(\".zip\")\n compression = ZIP_DEFLATED if compress else ZIP_STORED\n with ZipFile(zip_file, \"w\", compression) as f:\n for file in TQDM(files_to_zip, desc=f\"Zipping {directory} to {zip_file}...\", unit=\"file\", disable=not progress):\n f.write(file, file.relative_to(directory))\n\n return zip_file # return path to zip file",
"chunk_type": "function",
"name": "zip_directory",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 96,
"end_line": 130,
"start_col": 0,
"end_col": 19,
"parent_name": null,
"docstring": "Zip the contents of a directory, excluding specified files.\n\nThe resulting zip file is named after the directory and placed alongside it.\n\nArgs:\n directory (str | Path): The path to the directory to be zipped.\n compress (bool): Whether to compress the files while zipping.\n exclude (tuple, optional): A tuple of filename strings to be excluded.\n progress (bool, optional): Whether to display a progress bar.\n\nReturns:\n (Path): The path to the resulting zip file.\n\nExamples:\n >>> from ultralytics.utils.downloads import zip_directory\n >>> file = zip_directory(\"path/to/dir\")",
"parameters": [
"directory",
"compress: bool",
"exclude",
"progress: bool"
],
"return_type": "Path",
"decorators": [],
"complexity_score": 5,
"dependencies": [
"re",
"shutil",
"subprocess",
"itertools.repeat",
"multiprocessing.pool.ThreadPool",
"pathlib.Path",
"typing.List",
"typing.Tuple",
"urllib.parse",
"urllib.request",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TQDM",
"ultralytics.utils.checks",
"ultralytics.utils.clean_url",
"ultralytics.utils.emojis",
"ultralytics.utils.is_online",
"ultralytics.utils.url2file",
"zipfile.ZIP_DEFLATED",
"zipfile.ZIP_STORED",
"zipfile.ZipFile",
"zipfile.BadZipFile",
"zipfile.ZipFile",
"zipfile.is_zipfile",
"requests",
"requests",
"requests",
"ultralytics.utils.SETTINGS",
"zipfile.is_zipfile"
],
"chunk_id": "function_zip_directory_d38bafc2"
},
{
"content": "def unzip_file(\n file,\n path=None,\n exclude=(\".DS_Store\", \"__MACOSX\"),\n exist_ok: bool = False,\n progress: bool = True,\n) -> Path:\n \"\"\"\n Unzip a *.zip file to the specified path, excluding specified files.\n\n If the zipfile does not contain a single top-level directory, the function will create a new\n directory with the same name as the zipfile (without the extension) to extract its contents.\n If a path is not provided, the function will use the parent directory of the zipfile as the default path.\n\n Args:\n file (str | Path): The path to the zipfile to be extracted.\n path (str | Path, optional): The path to extract the zipfile to.\n exclude (tuple, optional): A tuple of filename strings to be excluded.\n exist_ok (bool, optional): Whether to overwrite existing contents if they exist.\n progress (bool, optional): Whether to display a progress bar.\n\n Returns:\n (Path): The path to the directory where the zipfile was extracted.\n\n Raises:\n BadZipFile: If the provided file does not exist or is not a valid zipfile.\n\n Examples:\n >>> from ultralytics.utils.downloads import unzip_file\n >>> directory = unzip_file(\"path/to/file.zip\")\n \"\"\"\n from zipfile import BadZipFile, ZipFile, is_zipfile\n\n if not (Path(file).exists() and is_zipfile(file)):\n raise BadZipFile(f\"File '{file}' does not exist or is a bad zip file.\")\n if path is None:\n path = Path(file).parent # default path\n\n # Unzip the file contents\n with ZipFile(file) as zipObj:\n files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)]\n top_level_dirs = {Path(f).parts[0] for f in files}\n\n # Decide to unzip directly or unzip into a directory\n unzip_as_dir = len(top_level_dirs) == 1 # (len(files) > 1 and not files[0].endswith(\"/\"))\n if unzip_as_dir:\n # Zip has 1 top-level directory\n extract_path = path # i.e. ../datasets\n path = Path(path) / list(top_level_dirs)[0] # i.e. extract coco8/ dir to ../datasets/\n else:\n # Zip has multiple files at top level\n path = extract_path = Path(path) / Path(file).stem # i.e. extract multiple files to ../datasets/coco8/\n\n # Check if destination directory already exists and contains files\n if path.exists() and any(path.iterdir()) and not exist_ok:\n # If it exists and is not empty, return the path without unzipping\n LOGGER.warning(f\"Skipping {file} unzip as destination directory {path} is not empty.\")\n return path\n\n for f in TQDM(files, desc=f\"Unzipping {file} to {Path(path).resolve()}...\", unit=\"file\", disable=not progress):\n # Ensure the file is within the extract_path to avoid path traversal security vulnerability\n if \"..\" in Path(f).parts:\n LOGGER.warning(f\"Potentially insecure file path: {f}, skipping extraction.\")\n continue\n zipObj.extract(f, extract_path)\n\n return path # return unzip dir",
"chunk_type": "function",
"name": "unzip_file",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 133,
"end_line": 199,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": "Unzip a *.zip file to the specified path, excluding specified files.\n\nIf the zipfile does not contain a single top-level directory, the function will create a new\ndirectory with the same name as the zipfile (without the extension) to extract its contents.\nIf a path is not provided, the function will use the parent directory of the zipfile as the default path.\n\nArgs:\n file (str | Path): The path to the zipfile to be extracted.\n path (str | Path, optional): The path to extract the zipfile to.\n exclude (tuple, optional): A tuple of filename strings to be excluded.\n exist_ok (bool, optional): Whether to overwrite existing contents if they exist.\n progress (bool, optional): Whether to display a progress bar.\n\nReturns:\n (Path): The path to the directory where the zipfile was extracted.\n\nRaises:\n BadZipFile: If the provided file does not exist or is not a valid zipfile.\n\nExamples:\n >>> from ultralytics.utils.downloads import unzip_file\n >>> directory = unzip_file(\"path/to/file.zip\")",
"parameters": [
"file",
"path",
"exclude",
"exist_ok: bool",
"progress: bool"
],
"return_type": "Path",
"decorators": [],
"complexity_score": 10,
"dependencies": [
"re",
"shutil",
"subprocess",
"itertools.repeat",
"multiprocessing.pool.ThreadPool",
"pathlib.Path",
"typing.List",
"typing.Tuple",
"urllib.parse",
"urllib.request",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TQDM",
"ultralytics.utils.checks",
"ultralytics.utils.clean_url",
"ultralytics.utils.emojis",
"ultralytics.utils.is_online",
"ultralytics.utils.url2file",
"zipfile.ZIP_DEFLATED",
"zipfile.ZIP_STORED",
"zipfile.ZipFile",
"zipfile.BadZipFile",
"zipfile.ZipFile",
"zipfile.is_zipfile",
"requests",
"requests",
"requests",
"ultralytics.utils.SETTINGS",
"zipfile.is_zipfile"
],
"chunk_id": "function_unzip_file_941bdc4b"
},
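A small round-trip sketch tying `zip_directory` and `unzip_file` together; `datasets/coco8` is a hypothetical local directory used only for illustration:

```python
from ultralytics.utils.downloads import unzip_file, zip_directory

zip_path = zip_directory("datasets/coco8")       # creates datasets/coco8.zip alongside the directory
out_dir = unzip_file(zip_path, path="datasets")  # extracts back, returning the target directory
print(zip_path, out_dir)
```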
{
"content": "def check_disk_space(\n url: str = \"https://ultralytics.com/assets/coco8.zip\",\n path=Path.cwd(),\n sf: float = 1.5,\n hard: bool = True,\n) -> bool:\n \"\"\"\n Check if there is sufficient disk space to download and store a file.\n\n Args:\n url (str, optional): The URL to the file.\n path (str | Path, optional): The path or drive to check the available free space on.\n sf (float, optional): Safety factor, the multiplier for the required free space.\n hard (bool, optional): Whether to throw an error or not on insufficient disk space.\n\n Returns:\n (bool): True if there is sufficient disk space, False otherwise.\n \"\"\"\n import requests # slow import\n\n try:\n r = requests.head(url) # response\n assert r.status_code < 400, f\"URL error for {url}: {r.status_code} {r.reason}\" # check response\n except Exception:\n return True # requests issue, default to True\n\n # Check file size\n gib = 1 << 30 # bytes per GiB\n data = int(r.headers.get(\"Content-Length\", 0)) / gib # file size (GB)\n total, used, free = (x / gib for x in shutil.disk_usage(path)) # bytes\n\n if data * sf < free:\n return True # sufficient space\n\n # Insufficient space\n text = (\n f\"Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, \"\n f\"Please free {data * sf - free:.1f} GB additional disk space and try again.\"\n )\n if hard:\n raise MemoryError(text)\n LOGGER.warning(text)\n return False",
"chunk_type": "function",
"name": "check_disk_space",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 202,
"end_line": 244,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "Check if there is sufficient disk space to download and store a file.\n\nArgs:\n url (str, optional): The URL to the file.\n path (str | Path, optional): The path or drive to check the available free space on.\n sf (float, optional): Safety factor, the multiplier for the required free space.\n hard (bool, optional): Whether to throw an error or not on insufficient disk space.\n\nReturns:\n (bool): True if there is sufficient disk space, False otherwise.",
"parameters": [
"url: str",
"path",
"sf: float",
"hard: bool"
],
"return_type": "bool",
"decorators": [],
"complexity_score": 5,
"dependencies": [
"re",
"shutil",
"subprocess",
"itertools.repeat",
"multiprocessing.pool.ThreadPool",
"pathlib.Path",
"typing.List",
"typing.Tuple",
"urllib.parse",
"urllib.request",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TQDM",
"ultralytics.utils.checks",
"ultralytics.utils.clean_url",
"ultralytics.utils.emojis",
"ultralytics.utils.is_online",
"ultralytics.utils.url2file",
"zipfile.ZIP_DEFLATED",
"zipfile.ZIP_STORED",
"zipfile.ZipFile",
"zipfile.BadZipFile",
"zipfile.ZipFile",
"zipfile.is_zipfile",
"requests",
"requests",
"requests",
"ultralytics.utils.SETTINGS",
"zipfile.is_zipfile"
],
"chunk_id": "function_check_disk_space_f4244111"
},
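The space test above reduces to simple arithmetic: the required space is the Content-Length (converted to GiB) times the safety factor `sf`, compared against the free space reported by `shutil.disk_usage`. A worked example with a hypothetical 1.2 GiB download:

```python
import shutil

gib = 1 << 30                       # bytes per GiB, as in check_disk_space
data = 1.2                          # hypothetical file size in GiB (from Content-Length)
sf = 1.5                            # safety factor
free = shutil.disk_usage(".").free / gib
print(f"need {data * sf:.2f} GiB, have {free:.2f} GiB free -> ok={data * sf < free}")
```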
{
"content": "def get_google_drive_file_info(link: str) -> Tuple[str, str]:\n \"\"\"\n Retrieve the direct download link and filename for a shareable Google Drive file link.\n\n Args:\n link (str): The shareable link of the Google Drive file.\n\n Returns:\n url (str): Direct download URL for the Google Drive file.\n filename (str | None): Original filename of the Google Drive file. If filename extraction fails, returns None.\n\n Examples:\n >>> from ultralytics.utils.downloads import get_google_drive_file_info\n >>> link = \"https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link\"\n >>> url, filename = get_google_drive_file_info(link)\n \"\"\"\n import requests # slow import\n\n file_id = link.split(\"/d/\")[1].split(\"/view\", 1)[0]\n drive_url = f\"https://drive.google.com/uc?export=download&id={file_id}\"\n filename = None\n\n # Start session\n with requests.Session() as session:\n response = session.get(drive_url, stream=True)\n if \"quota exceeded\" in str(response.content.lower()):\n raise ConnectionError(\n emojis(\n f\"❌ Google Drive file download quota exceeded. \"\n f\"Please try again later or download this file manually at {link}.\"\n )\n )\n for k, v in response.cookies.items():\n if k.startswith(\"download_warning\"):\n drive_url += f\"&confirm={v}\" # v is token\n if cd := response.headers.get(\"content-disposition\"):\n filename = re.findall('filename=\"(.+)\"', cd)[0]\n return drive_url, filename",
"chunk_type": "function",
"name": "get_google_drive_file_info",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 247,
"end_line": 284,
"start_col": 0,
"end_col": 30,
"parent_name": null,
"docstring": "Retrieve the direct download link and filename for a shareable Google Drive file link.\n\nArgs:\n link (str): The shareable link of the Google Drive file.\n\nReturns:\n url (str): Direct download URL for the Google Drive file.\n filename (str | None): Original filename of the Google Drive file. If filename extraction fails, returns None.\n\nExamples:\n >>> from ultralytics.utils.downloads import get_google_drive_file_info\n >>> link = \"https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link\"\n >>> url, filename = get_google_drive_file_info(link)",
"parameters": [
"link: str"
],
"return_type": "Tuple[str, str]",
"decorators": [],
"complexity_score": 5,
"dependencies": [
"re",
"shutil",
"subprocess",
"itertools.repeat",
"multiprocessing.pool.ThreadPool",
"pathlib.Path",
"typing.List",
"typing.Tuple",
"urllib.parse",
"urllib.request",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TQDM",
"ultralytics.utils.checks",
"ultralytics.utils.clean_url",
"ultralytics.utils.emojis",
"ultralytics.utils.is_online",
"ultralytics.utils.url2file",
"zipfile.ZIP_DEFLATED",
"zipfile.ZIP_STORED",
"zipfile.ZipFile",
"zipfile.BadZipFile",
"zipfile.ZipFile",
"zipfile.is_zipfile",
"requests",
"requests",
"requests",
"ultralytics.utils.SETTINGS",
"zipfile.is_zipfile"
],
"chunk_id": "function_get_google_drive_file_info_3e2a9616"
},
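A short sketch showing how the chunk above is typically combined with `safe_download`; the share link is the one from the chunk's own docstring, and the returned `filename` may be `None` when the content-disposition header is absent:

```python
from ultralytics.utils.downloads import get_google_drive_file_info, safe_download

link = "https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link"
url, filename = get_google_drive_file_info(link)  # direct download URL + original filename
safe_download(url=url, file=filename)              # hand the direct URL to the downloader
```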
{
"content": "def safe_download(\n url,\n file=None,\n dir=None,\n unzip: bool = True,\n delete: bool = False,\n curl: bool = False,\n retry: int = 3,\n min_bytes: float = 1e0,\n exist_ok: bool = False,\n progress: bool = True,\n):\n \"\"\"\n Download files from a URL with options for retrying, unzipping, and deleting the downloaded file.\n\n Args:\n url (str): The URL of the file to be downloaded.\n file (str, optional): The filename of the downloaded file.\n If not provided, the file will be saved with the same name as the URL.\n dir (str | Path, optional): The directory to save the downloaded file.\n If not provided, the file will be saved in the current working directory.\n unzip (bool, optional): Whether to unzip the downloaded file.\n delete (bool, optional): Whether to delete the downloaded file after unzipping.\n curl (bool, optional): Whether to use curl command line tool for downloading.\n retry (int, optional): The number of times to retry the download in case of failure.\n min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered\n a successful download.\n exist_ok (bool, optional): Whether to overwrite existing contents during unzipping.\n progress (bool, optional): Whether to display a progress bar during the download.\n\n Returns:\n (Path | str): The path to the downloaded file or extracted directory.\n\n Examples:\n >>> from ultralytics.utils.downloads import safe_download\n >>> link = \"https://ultralytics.com/assets/bus.jpg\"\n >>> path = safe_download(link)\n \"\"\"\n gdrive = url.startswith(\"https://drive.google.com/\") # check if the URL is a Google Drive link\n if gdrive:\n url, file = get_google_drive_file_info(url)\n\n f = Path(dir or \".\") / (file or url2file(url)) # URL converted to filename\n if \"://\" not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)\n f = Path(url) # filename\n elif not f.is_file(): # URL and file do not exist\n uri = (url if gdrive else clean_url(url)).replace( # cleaned and aliased url\n \"https://github.com/ultralytics/assets/releases/download/v0.0.0/\",\n \"https://ultralytics.com/assets/\", # assets alias\n )\n desc = f\"Downloading {uri} to '{f}'\"\n LOGGER.info(f\"{desc}...\")\n f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing\n check_disk_space(url, path=f.parent)\n curl_installed = shutil.which(\"curl\")\n for i in range(retry + 1):\n try:\n if (curl or i > 0) and curl_installed: # curl download with retry, continue\n s = \"sS\" * (not progress) # silent\n r = subprocess.run([\"curl\", \"-#\", f\"-{s}L\", url, \"-o\", f, \"--retry\", \"3\", \"-C\", \"-\"]).returncode\n assert r == 0, f\"Curl return value {r}\"\n else: # urllib download\n method = \"torch\"\n if method == \"torch\":\n torch.hub.download_url_to_file(url, f, progress=progress)\n else:\n with request.urlopen(url) as response, TQDM(\n total=int(response.getheader(\"Content-Length\", 0)),\n desc=desc,\n disable=not progress,\n unit=\"B\",\n unit_scale=True,\n unit_divisor=1024,\n ) as pbar:\n with open(f, \"wb\") as f_opened:\n for data in response:\n f_opened.write(data)\n pbar.update(len(data))\n\n if f.exists():\n if f.stat().st_size > min_bytes:\n break # success\n f.unlink() # remove partial downloads\n except Exception as e:\n if i == 0 and not is_online():\n raise ConnectionError(emojis(f\"❌ Download failure for {uri}. Environment is not online.\")) from e\n elif i >= retry:\n raise ConnectionError(emojis(f\"❌ Download failure for {uri}. 
Retry limit reached.\")) from e\n LOGGER.warning(f\"Download failure, retrying {i + 1}/{retry} {uri}...\")\n\n if unzip and f.exists() and f.suffix in {\"\", \".zip\", \".tar\", \".gz\"}:\n from zipfile import is_zipfile\n\n unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place\n if is_zipfile(f):\n unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip\n elif f.suffix in {\".tar\", \".gz\"}:\n LOGGER.info(f\"Unzipping {f} to {unzip_dir}...\")\n subprocess.run([\"tar\", \"xf\" if f.suffix == \".tar\" else \"xfz\", f, \"--directory\", unzip_dir], check=True)\n if delete:\n f.unlink() # remove zip\n return unzip_dir\n return f",
"chunk_type": "function",
"name": "safe_download",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 287,
"end_line": 389,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": "Download files from a URL with options for retrying, unzipping, and deleting the downloaded file.\n\nArgs:\n url (str): The URL of the file to be downloaded.\n file (str, optional): The filename of the downloaded file.\n If not provided, the file will be saved with the same name as the URL.\n dir (str | Path, optional): The directory to save the downloaded file.\n If not provided, the file will be saved in the current working directory.\n unzip (bool, optional): Whether to unzip the downloaded file.\n delete (bool, optional): Whether to delete the downloaded file after unzipping.\n curl (bool, optional): Whether to use curl command line tool for downloading.\n retry (int, optional): The number of times to retry the download in case of failure.\n min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered\n a successful download.\n exist_ok (bool, optional): Whether to overwrite existing contents during unzipping.\n progress (bool, optional): Whether to display a progress bar during the download.\n\nReturns:\n (Path | str): The path to the downloaded file or extracted directory.\n\nExamples:\n >>> from ultralytics.utils.downloads import safe_download\n >>> link = \"https://ultralytics.com/assets/bus.jpg\"\n >>> path = safe_download(link)",
"parameters": [
"url",
"file",
"dir",
"unzip: bool",
"delete: bool",
"curl: bool",
"retry: int",
"min_bytes: float",
"exist_ok: bool",
"progress: bool"
],
"return_type": null,
"decorators": [],
"complexity_score": 17,
"dependencies": [
"re",
"shutil",
"subprocess",
"itertools.repeat",
"multiprocessing.pool.ThreadPool",
"pathlib.Path",
"typing.List",
"typing.Tuple",
"urllib.parse",
"urllib.request",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TQDM",
"ultralytics.utils.checks",
"ultralytics.utils.clean_url",
"ultralytics.utils.emojis",
"ultralytics.utils.is_online",
"ultralytics.utils.url2file",
"zipfile.ZIP_DEFLATED",
"zipfile.ZIP_STORED",
"zipfile.ZipFile",
"zipfile.BadZipFile",
"zipfile.ZipFile",
"zipfile.is_zipfile",
"requests",
"requests",
"requests",
"ultralytics.utils.SETTINGS",
"zipfile.is_zipfile"
],
"chunk_id": "function_safe_download_915d599d"
},
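A hedged usage sketch for `safe_download` as indexed above; the asset URL and the `datasets` target directory are illustrative only:

```python
from ultralytics.utils.downloads import safe_download

# Download, extract next to the archive, then delete the zip; returns the
# extracted directory (or the file path when no unzip is performed).
path = safe_download(
    "https://ultralytics.com/assets/coco8.zip",  # illustrative URL
    dir="datasets",
    unzip=True,
    delete=True,
    retry=3,
)
print(path)
```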
{
"content": "def get_github_assets(\n repo: str = \"ultralytics/assets\",\n version: str = \"latest\",\n retry: bool = False,\n) -> Tuple[str, List[str]]:\n \"\"\"\n Retrieve the specified version's tag and assets from a GitHub repository.\n\n If the version is not specified, the function fetches the latest release assets.\n\n Args:\n repo (str, optional): The GitHub repository in the format 'owner/repo'.\n version (str, optional): The release version to fetch assets from.\n retry (bool, optional): Flag to retry the request in case of a failure.\n\n Returns:\n tag (str): The release tag.\n assets (List[str]): A list of asset names.\n\n Examples:\n >>> tag, assets = get_github_assets(repo=\"ultralytics/assets\", version=\"latest\")\n \"\"\"\n import requests # slow import\n\n if version != \"latest\":\n version = f\"tags/{version}\" # i.e. tags/v6.2\n url = f\"https://api.github.com/repos/{repo}/releases/{version}\"\n r = requests.get(url) # github api\n if r.status_code != 200 and r.reason != \"rate limit exceeded\" and retry: # failed and not 403 rate limit exceeded\n r = requests.get(url) # try again\n if r.status_code != 200:\n LOGGER.warning(f\"GitHub assets check failure for {url}: {r.status_code} {r.reason}\")\n return \"\", []\n data = r.json()\n return data[\"tag_name\"], [x[\"name\"] for x in data[\"assets\"]] # tag, assets i.e. ['yolo11n.pt', 'yolov8s.pt', ...]",
"chunk_type": "function",
"name": "get_github_assets",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 392,
"end_line": 426,
"start_col": 0,
"end_col": 64,
"parent_name": null,
"docstring": "Retrieve the specified version's tag and assets from a GitHub repository.\n\nIf the version is not specified, the function fetches the latest release assets.\n\nArgs:\n repo (str, optional): The GitHub repository in the format 'owner/repo'.\n version (str, optional): The release version to fetch assets from.\n retry (bool, optional): Flag to retry the request in case of a failure.\n\nReturns:\n tag (str): The release tag.\n assets (List[str]): A list of asset names.\n\nExamples:\n >>> tag, assets = get_github_assets(repo=\"ultralytics/assets\", version=\"latest\")",
"parameters": [
"repo: str",
"version: str",
"retry: bool"
],
"return_type": "Tuple[str, List[str]]",
"decorators": [],
"complexity_score": 5,
"dependencies": [
"re",
"shutil",
"subprocess",
"itertools.repeat",
"multiprocessing.pool.ThreadPool",
"pathlib.Path",
"typing.List",
"typing.Tuple",
"urllib.parse",
"urllib.request",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TQDM",
"ultralytics.utils.checks",
"ultralytics.utils.clean_url",
"ultralytics.utils.emojis",
"ultralytics.utils.is_online",
"ultralytics.utils.url2file",
"zipfile.ZIP_DEFLATED",
"zipfile.ZIP_STORED",
"zipfile.ZipFile",
"zipfile.BadZipFile",
"zipfile.ZipFile",
"zipfile.is_zipfile",
"requests",
"requests",
"requests",
"ultralytics.utils.SETTINGS",
"zipfile.is_zipfile"
],
"chunk_id": "function_get_github_assets_f59536ca"
},
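A small sketch of the GitHub assets lookup above; the printed values are examples only, and an empty tag plus empty list is returned when the API call fails:

```python
from ultralytics.utils.downloads import get_github_assets

tag, assets = get_github_assets(repo="ultralytics/assets", version="latest")
print(tag)         # release tag string; "" on API failure
print(assets[:3])  # first few asset filenames, e.g. model weights
```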
{
"content": "def attempt_download_asset(file, repo: str = \"ultralytics/assets\", release: str = \"v8.3.0\", **kwargs) -> str:\n \"\"\"\n Attempt to download a file from GitHub release assets if it is not found locally.\n\n Args:\n file (str | Path): The filename or file path to be downloaded.\n repo (str, optional): The GitHub repository in the format 'owner/repo'.\n release (str, optional): The specific release version to be downloaded.\n **kwargs (Any): Additional keyword arguments for the download process.\n\n Returns:\n (str): The path to the downloaded file.\n\n Examples:\n >>> file_path = attempt_download_asset(\"yolo11n.pt\", repo=\"ultralytics/assets\", release=\"latest\")\n \"\"\"\n from ultralytics.utils import SETTINGS # scoped for circular import\n\n # YOLOv3/5u updates\n file = str(file)\n file = checks.check_yolov5u_filename(file)\n file = Path(file.strip().replace(\"'\", \"\"))\n if file.exists():\n return str(file)\n elif (SETTINGS[\"weights_dir\"] / file).exists():\n return str(SETTINGS[\"weights_dir\"] / file)\n else:\n # URL specified\n name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc.\n download_url = f\"https://github.com/{repo}/releases/download\"\n if str(file).startswith((\"http:/\", \"https:/\")): # download\n url = str(file).replace(\":/\", \"://\") # Pathlib turns :// -> :/\n file = url2file(name) # parse authentication https://url.com/file.txt?auth...\n if Path(file).is_file():\n LOGGER.info(f\"Found {clean_url(url)} locally at {file}\") # file already exists\n else:\n safe_download(url=url, file=file, min_bytes=1e5, **kwargs)\n\n elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES:\n safe_download(url=f\"{download_url}/{release}/{name}\", file=file, min_bytes=1e5, **kwargs)\n\n else:\n tag, assets = get_github_assets(repo, release)\n if not assets:\n tag, assets = get_github_assets(repo) # latest release\n if name in assets:\n safe_download(url=f\"{download_url}/{tag}/{name}\", file=file, min_bytes=1e5, **kwargs)\n\n return str(file)",
"chunk_type": "function",
"name": "attempt_download_asset",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 429,
"end_line": 477,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": "Attempt to download a file from GitHub release assets if it is not found locally.\n\nArgs:\n file (str | Path): The filename or file path to be downloaded.\n repo (str, optional): The GitHub repository in the format 'owner/repo'.\n release (str, optional): The specific release version to be downloaded.\n **kwargs (Any): Additional keyword arguments for the download process.\n\nReturns:\n (str): The path to the downloaded file.\n\nExamples:\n >>> file_path = attempt_download_asset(\"yolo11n.pt\", repo=\"ultralytics/assets\", release=\"latest\")",
"parameters": [
"file",
"repo: str",
"release: str"
],
"return_type": "str",
"decorators": [],
"complexity_score": 8,
"dependencies": [
"re",
"shutil",
"subprocess",
"itertools.repeat",
"multiprocessing.pool.ThreadPool",
"pathlib.Path",
"typing.List",
"typing.Tuple",
"urllib.parse",
"urllib.request",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TQDM",
"ultralytics.utils.checks",
"ultralytics.utils.clean_url",
"ultralytics.utils.emojis",
"ultralytics.utils.is_online",
"ultralytics.utils.url2file",
"zipfile.ZIP_DEFLATED",
"zipfile.ZIP_STORED",
"zipfile.ZipFile",
"zipfile.BadZipFile",
"zipfile.ZipFile",
"zipfile.is_zipfile",
"requests",
"requests",
"requests",
"ultralytics.utils.SETTINGS",
"zipfile.is_zipfile"
],
"chunk_id": "function_attempt_download_asset_d74cab11"
},
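A usage sketch for `attempt_download_asset`; the weight name comes from the chunk's own docstring example, and local resolution depends on the configured SETTINGS["weights_dir"]:

```python
from ultralytics.utils.downloads import attempt_download_asset

# Returns the local path if the file already exists (cwd or the weights dir),
# otherwise fetches it from the GitHub release assets and returns the saved path.
weights = attempt_download_asset("yolo11n.pt")
print(weights)
```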
{
"content": "def download(\n url,\n dir=Path.cwd(),\n unzip: bool = True,\n delete: bool = False,\n curl: bool = False,\n threads: int = 1,\n retry: int = 3,\n exist_ok: bool = False,\n):\n \"\"\"\n Download files from specified URLs to a given directory.\n\n Supports concurrent downloads if multiple threads are specified.\n\n Args:\n url (str | List[str]): The URL or list of URLs of the files to be downloaded.\n dir (Path, optional): The directory where the files will be saved.\n unzip (bool, optional): Flag to unzip the files after downloading.\n delete (bool, optional): Flag to delete the zip files after extraction.\n curl (bool, optional): Flag to use curl for downloading.\n threads (int, optional): Number of threads to use for concurrent downloads.\n retry (int, optional): Number of retries in case of download failure.\n exist_ok (bool, optional): Whether to overwrite existing contents during unzipping.\n\n Examples:\n >>> download(\"https://ultralytics.com/assets/example.zip\", dir=\"path/to/dir\", unzip=True)\n \"\"\"\n dir = Path(dir)\n dir.mkdir(parents=True, exist_ok=True) # make directory\n if threads > 1:\n with ThreadPool(threads) as pool:\n pool.map(\n lambda x: safe_download(\n url=x[0],\n dir=x[1],\n unzip=unzip,\n delete=delete,\n curl=curl,\n retry=retry,\n exist_ok=exist_ok,\n progress=threads <= 1,\n ),\n zip(url, repeat(dir)),\n )\n pool.close()\n pool.join()\n else:\n for u in [url] if isinstance(url, (str, Path)) else url:\n safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok)",
"chunk_type": "function",
"name": "download",
"file_path": "ultralytics\\ultralytics\\utils\\downloads.py",
"start_line": 480,
"end_line": 529,
"start_col": 0,
"end_col": 112,
"parent_name": null,
"docstring": "Download files from specified URLs to a given directory.\n\nSupports concurrent downloads if multiple threads are specified.\n\nArgs:\n url (str | List[str]): The URL or list of URLs of the files to be downloaded.\n dir (Path, optional): The directory where the files will be saved.\n unzip (bool, optional): Flag to unzip the files after downloading.\n delete (bool, optional): Flag to delete the zip files after extraction.\n curl (bool, optional): Flag to use curl for downloading.\n threads (int, optional): Number of threads to use for concurrent downloads.\n retry (int, optional): Number of retries in case of download failure.\n exist_ok (bool, optional): Whether to overwrite existing contents during unzipping.\n\nExamples:\n >>> download(\"https://ultralytics.com/assets/example.zip\", dir=\"path/to/dir\", unzip=True)",
"parameters": [
"url",
"dir",
"unzip: bool",
"delete: bool",
"curl: bool",
"threads: int",
"retry: int",
"exist_ok: bool"
],
"return_type": null,
"decorators": [],
"complexity_score": 3,
"dependencies": [
"re",
"shutil",
"subprocess",
"itertools.repeat",
"multiprocessing.pool.ThreadPool",
"pathlib.Path",
"typing.List",
"typing.Tuple",
"urllib.parse",
"urllib.request",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TQDM",
"ultralytics.utils.checks",
"ultralytics.utils.clean_url",
"ultralytics.utils.emojis",
"ultralytics.utils.is_online",
"ultralytics.utils.url2file",
"zipfile.ZIP_DEFLATED",
"zipfile.ZIP_STORED",
"zipfile.ZipFile",
"zipfile.BadZipFile",
"zipfile.ZipFile",
"zipfile.is_zipfile",
"requests",
"requests",
"requests",
"ultralytics.utils.SETTINGS",
"zipfile.is_zipfile"
],
"chunk_id": "function_download_a573ab84"
},
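A sketch of concurrent downloads with the `download` helper above; the URLs and the `tmp_images` directory are illustrative:

```python
from pathlib import Path

from ultralytics.utils.downloads import download

urls = [
    "https://ultralytics.com/assets/bus.jpg",
    "https://ultralytics.com/assets/zidane.jpg",
]
# threads > 1 maps safe_download over a ThreadPool; per-file progress bars are disabled.
download(urls, dir=Path("tmp_images"), threads=2, unzip=False)
```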
{
"content": "from ultralytics.utils import emojis",
"chunk_type": "import",
"name": "emojis",
"file_path": "ultralytics\\ultralytics\\utils\\errors.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 36,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_emojis_003e07a6"
},
{
"content": "class HUBModelError(Exception):\n \"\"\"\n Exception raised when a model cannot be found or retrieved from Ultralytics HUB.\n\n This custom exception is used specifically for handling errors related to model fetching in Ultralytics YOLO.\n The error message is processed to include emojis for better user experience.\n\n Attributes:\n message (str): The error message displayed when the exception is raised.\n\n Methods:\n __init__: Initialize the HUBModelError with a custom message.\n\n Examples:\n >>> try:\n ... # Code that might fail to find a model\n ... raise HUBModelError(\"Custom model not found message\")\n ... except HUBModelError as e:\n ... print(e) # Displays the emoji-enhanced error message\n \"\"\"\n\n def __init__(self, message: str = \"Model not found. Please check model URL and try again.\"):\n \"\"\"\n Initialize a HUBModelError exception.\n\n This exception is raised when a requested model is not found or cannot be retrieved from Ultralytics HUB.\n The message is processed to include emojis for better user experience.\n\n Args:\n message (str, optional): The error message to display when the exception is raised.\n\n Examples:\n >>> try:\n ... raise HUBModelError(\"Custom model error message\")\n ... except HUBModelError as e:\n ... print(e)\n \"\"\"\n super().__init__(emojis(message))",
"chunk_type": "class",
"name": "HUBModelError",
"file_path": "ultralytics\\ultralytics\\utils\\errors.py",
"start_line": 6,
"end_line": 43,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": "Exception raised when a model cannot be found or retrieved from Ultralytics HUB.\n\nThis custom exception is used specifically for handling errors related to model fetching in Ultralytics YOLO.\nThe error message is processed to include emojis for better user experience.\n\nAttributes:\n message (str): The error message displayed when the exception is raised.\n\nMethods:\n __init__: Initialize the HUBModelError with a custom message.\n\nExamples:\n >>> try:\n ... # Code that might fail to find a model\n ... raise HUBModelError(\"Custom model not found message\")\n ... except HUBModelError as e:\n ... print(e) # Displays the emoji-enhanced error message",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"ultralytics.utils.emojis",
"Exception"
],
"chunk_id": "class_HUBModelError_4875e1d8"
},
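A brief sketch of raising and catching `HUBModelError`; the custom message is illustrative:

```python
from ultralytics.utils.errors import HUBModelError

try:
    raise HUBModelError("Model 'my-model' could not be fetched from Ultralytics HUB.")
except HUBModelError as e:
    print(e)  # the message is passed through emojis() before being stored
```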
{
"content": "import json",
"chunk_type": "import",
"name": "json",
"file_path": "ultralytics\\ultralytics\\utils\\export.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_json_83c7175c"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\utils\\export.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_c24a6087"
},
{
"content": "from typing import Dict, List, Optional, Tuple, Union",
"chunk_type": "import",
"name": "Dict, List, Optional, Tuple, Union",
"file_path": "ultralytics\\ultralytics\\utils\\export.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Dict, List, Optional, Tuple, Union_31e6da6e"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\utils\\export.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_4e281efa"
},
{
"content": "from ultralytics.utils import IS_JETSON, LOGGER",
"chunk_type": "import",
"name": "IS_JETSON, LOGGER",
"file_path": "ultralytics\\ultralytics\\utils\\export.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_IS_JETSON, LOGGER_ae878683"
},
{
"content": "def export_onnx(\n torch_model: torch.nn.Module,\n im: torch.Tensor,\n onnx_file: str,\n opset: int = 14,\n input_names: List[str] = [\"images\"],\n output_names: List[str] = [\"output0\"],\n dynamic: Union[bool, Dict] = False,\n) -> None:\n \"\"\"\n Export a PyTorch model to ONNX format.\n\n Args:\n torch_model (torch.nn.Module): The PyTorch model to export.\n im (torch.Tensor): Example input tensor for the model.\n onnx_file (str): Path to save the exported ONNX file.\n opset (int): ONNX opset version to use for export.\n input_names (List[str]): List of input tensor names.\n output_names (List[str]): List of output tensor names.\n dynamic (bool | Dict, optional): Whether to enable dynamic axes.\n\n Notes:\n Setting `do_constant_folding=True` may cause issues with DNN inference for torch>=1.12.\n \"\"\"\n torch.onnx.export(\n torch_model,\n im,\n onnx_file,\n verbose=False,\n opset_version=opset,\n do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False\n input_names=input_names,\n output_names=output_names,\n dynamic_axes=dynamic or None,\n )",
"chunk_type": "function",
"name": "export_onnx",
"file_path": "ultralytics\\ultralytics\\utils\\export.py",
"start_line": 12,
"end_line": 46,
"start_col": 0,
"end_col": 5,
"parent_name": null,
"docstring": "Export a PyTorch model to ONNX format.\n\nArgs:\n torch_model (torch.nn.Module): The PyTorch model to export.\n im (torch.Tensor): Example input tensor for the model.\n onnx_file (str): Path to save the exported ONNX file.\n opset (int): ONNX opset version to use for export.\n input_names (List[str]): List of input tensor names.\n output_names (List[str]): List of output tensor names.\n dynamic (bool | Dict, optional): Whether to enable dynamic axes.\n\nNotes:\n Setting `do_constant_folding=True` may cause issues with DNN inference for torch>=1.12.",
"parameters": [
"torch_model: torch.nn.Module",
"im: torch.Tensor",
"onnx_file: str",
"opset: int",
"input_names: List[str]",
"output_names: List[str]",
"dynamic: Union[bool, Dict]"
],
"return_type": "None",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"json",
"pathlib.Path",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.LOGGER",
"tensorrt"
],
"chunk_id": "function_export_onnx_918c7e90"
},
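A minimal sketch of `export_onnx` with a stand-in model; it assumes the `onnx` package is installed, and the tiny Sequential model and output filename are illustrative:

```python
import torch

from ultralytics.utils.export import export_onnx

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()  # stand-in model
im = torch.zeros(1, 3, 640, 640)  # example input used to trace the export
export_onnx(model, im, "tiny.onnx", opset=14)  # writes tiny.onnx to the working directory
```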
{
"content": "def export_engine(\n onnx_file: str,\n engine_file: Optional[str] = None,\n workspace: Optional[int] = None,\n half: bool = False,\n int8: bool = False,\n dynamic: bool = False,\n shape: Tuple[int, int, int, int] = (1, 3, 640, 640),\n dla: Optional[int] = None,\n dataset=None,\n metadata: Optional[Dict] = None,\n verbose: bool = False,\n prefix: str = \"\",\n) -> None:\n \"\"\"\n Export a YOLO model to TensorRT engine format.\n\n Args:\n onnx_file (str): Path to the ONNX file to be converted.\n engine_file (str, optional): Path to save the generated TensorRT engine file.\n workspace (int, optional): Workspace size in GB for TensorRT.\n half (bool, optional): Enable FP16 precision.\n int8 (bool, optional): Enable INT8 precision.\n dynamic (bool, optional): Enable dynamic input shapes.\n shape (Tuple[int, int, int, int], optional): Input shape (batch, channels, height, width).\n dla (int, optional): DLA core to use (Jetson devices only).\n dataset (ultralytics.data.build.InfiniteDataLoader, optional): Dataset for INT8 calibration.\n metadata (Dict, optional): Metadata to include in the engine file.\n verbose (bool, optional): Enable verbose logging.\n prefix (str, optional): Prefix for log messages.\n\n Raises:\n ValueError: If DLA is enabled on non-Jetson devices or required precision is not set.\n RuntimeError: If the ONNX file cannot be parsed.\n\n Notes:\n TensorRT version compatibility is handled for workspace size and engine building.\n INT8 calibration requires a dataset and generates a calibration cache.\n Metadata is serialized and written to the engine file if provided.\n \"\"\"\n import tensorrt as trt # noqa\n\n engine_file = engine_file or Path(onnx_file).with_suffix(\".engine\")\n\n logger = trt.Logger(trt.Logger.INFO)\n if verbose:\n logger.min_severity = trt.Logger.Severity.VERBOSE\n\n # Engine builder\n builder = trt.Builder(logger)\n config = builder.create_builder_config()\n workspace = int((workspace or 0) * (1 << 30))\n is_trt10 = int(trt.__version__.split(\".\", 1)[0]) >= 10 # is TensorRT >= 10\n if is_trt10 and workspace > 0:\n config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)\n elif workspace > 0: # TensorRT versions 7, 8\n config.max_workspace_size = workspace\n flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)\n network = builder.create_network(flag)\n half = builder.platform_has_fast_fp16 and half\n int8 = builder.platform_has_fast_int8 and int8\n\n # Optionally switch to DLA if enabled\n if dla is not None:\n if not IS_JETSON:\n raise ValueError(\"DLA is only available on NVIDIA Jetson devices\")\n LOGGER.info(f\"{prefix} enabling DLA on core {dla}...\")\n if not half and not int8:\n raise ValueError(\n \"DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. 
Please enable one of them and try again.\"\n )\n config.default_device_type = trt.DeviceType.DLA\n config.DLA_core = int(dla)\n config.set_flag(trt.BuilderFlag.GPU_FALLBACK)\n\n # Read ONNX file\n parser = trt.OnnxParser(network, logger)\n if not parser.parse_from_file(onnx_file):\n raise RuntimeError(f\"failed to load ONNX file: {onnx_file}\")\n\n # Network inputs\n inputs = [network.get_input(i) for i in range(network.num_inputs)]\n outputs = [network.get_output(i) for i in range(network.num_outputs)]\n for inp in inputs:\n LOGGER.info(f'{prefix} input \"{inp.name}\" with shape{inp.shape} {inp.dtype}')\n for out in outputs:\n LOGGER.info(f'{prefix} output \"{out.name}\" with shape{out.shape} {out.dtype}')\n\n if dynamic:\n profile = builder.create_optimization_profile()\n min_shape = (1, shape[1], 32, 32) # minimum input shape\n max_shape = (*shape[:2], *(int(max(2, workspace or 2) * d) for d in shape[2:])) # max input shape\n for inp in inputs:\n profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)\n config.add_optimization_profile(profile)\n if int8:\n config.set_calibration_profile(profile)\n\n LOGGER.info(f\"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {engine_file}\")\n if int8:\n config.set_flag(trt.BuilderFlag.INT8)\n config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED\n\n class EngineCalibrator(trt.IInt8Calibrator):\n \"\"\"\n Custom INT8 calibrator for TensorRT engine optimization.\n\n This calibrator provides the necessary interface for TensorRT to perform INT8 quantization calibration\n using a dataset. It handles batch generation, caching, and calibration algorithm selection.\n\n Attributes:\n dataset: Dataset for calibration.\n data_iter: Iterator over the calibration dataset.\n algo (trt.CalibrationAlgoType): Calibration algorithm type.\n batch (int): Batch size for calibration.\n cache (Path): Path to save the calibration cache.\n\n Methods:\n get_algorithm: Get the calibration algorithm to use.\n get_batch_size: Get the batch size to use for calibration.\n get_batch: Get the next batch to use for calibration.\n read_calibration_cache: Use existing cache instead of calibrating again.\n write_calibration_cache: Write calibration cache to disk.\n \"\"\"\n\n def __init__(\n self,\n dataset, # ultralytics.data.build.InfiniteDataLoader\n cache: str = \"\",\n ) -> None:\n \"\"\"Initialize the INT8 calibrator with dataset and cache path.\"\"\"\n trt.IInt8Calibrator.__init__(self)\n self.dataset = dataset\n self.data_iter = iter(dataset)\n self.algo = (\n trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 # DLA quantization needs ENTROPY_CALIBRATION_2\n if dla is not None\n else trt.CalibrationAlgoType.MINMAX_CALIBRATION\n )\n self.batch = dataset.batch_size\n self.cache = Path(cache)\n\n def get_algorithm(self) -> trt.CalibrationAlgoType:\n \"\"\"Get the calibration algorithm to use.\"\"\"\n return self.algo\n\n def get_batch_size(self) -> int:\n \"\"\"Get the batch size to use for calibration.\"\"\"\n return self.batch or 1\n\n def get_batch(self, names) -> Optional[List[int]]:\n \"\"\"Get the next batch to use for calibration, as a list of device memory pointers.\"\"\"\n try:\n im0s = next(self.data_iter)[\"img\"] / 255.0\n im0s = im0s.to(\"cuda\") if im0s.device.type == \"cpu\" else im0s\n return [int(im0s.data_ptr())]\n except StopIteration:\n # Return None to signal to TensorRT there is no calibration data remaining\n return None\n\n def read_calibration_cache(self) -> Optional[bytes]:\n \"\"\"Use existing cache 
instead of calibrating again, otherwise, implicitly return None.\"\"\"\n if self.cache.exists() and self.cache.suffix == \".cache\":\n return self.cache.read_bytes()\n\n def write_calibration_cache(self, cache: bytes) -> None:\n \"\"\"Write calibration cache to disk.\"\"\"\n _ = self.cache.write_bytes(cache)\n\n # Load dataset w/ builder (for batching) and calibrate\n config.int8_calibrator = EngineCalibrator(\n dataset=dataset,\n cache=str(Path(onnx_file).with_suffix(\".cache\")),\n )\n\n elif half:\n config.set_flag(trt.BuilderFlag.FP16)\n\n # Write file\n build = builder.build_serialized_network if is_trt10 else builder.build_engine\n with build(network, config) as engine, open(engine_file, \"wb\") as t:\n # Metadata\n if metadata is not None:\n meta = json.dumps(metadata)\n t.write(len(meta).to_bytes(4, byteorder=\"little\", signed=True))\n t.write(meta.encode())\n # Model\n t.write(engine if is_trt10 else engine.serialize())",
"chunk_type": "function",
"name": "export_engine",
"file_path": "ultralytics\\ultralytics\\utils\\export.py",
"start_line": 49,
"end_line": 236,
"start_col": 0,
"end_col": 59,
"parent_name": null,
"docstring": "Export a YOLO model to TensorRT engine format.\n\nArgs:\n onnx_file (str): Path to the ONNX file to be converted.\n engine_file (str, optional): Path to save the generated TensorRT engine file.\n workspace (int, optional): Workspace size in GB for TensorRT.\n half (bool, optional): Enable FP16 precision.\n int8 (bool, optional): Enable INT8 precision.\n dynamic (bool, optional): Enable dynamic input shapes.\n shape (Tuple[int, int, int, int], optional): Input shape (batch, channels, height, width).\n dla (int, optional): DLA core to use (Jetson devices only).\n dataset (ultralytics.data.build.InfiniteDataLoader, optional): Dataset for INT8 calibration.\n metadata (Dict, optional): Metadata to include in the engine file.\n verbose (bool, optional): Enable verbose logging.\n prefix (str, optional): Prefix for log messages.\n\nRaises:\n ValueError: If DLA is enabled on non-Jetson devices or required precision is not set.\n RuntimeError: If the ONNX file cannot be parsed.\n\nNotes:\n TensorRT version compatibility is handled for workspace size and engine building.\n INT8 calibration requires a dataset and generates a calibration cache.\n Metadata is serialized and written to the engine file if provided.",
"parameters": [
"onnx_file: str",
"engine_file: Optional[str]",
"workspace: Optional[int]",
"half: bool",
"int8: bool",
"dynamic: bool",
"shape: Tuple[int, int, int, int]",
"dla: Optional[int]",
"dataset",
"metadata: Optional[Dict]",
"verbose: bool",
"prefix: str"
],
"return_type": "None",
"decorators": [],
"complexity_score": 21,
"dependencies": [
"json",
"pathlib.Path",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"ultralytics.utils.IS_JETSON",
"ultralytics.utils.LOGGER",
"tensorrt"
],
"chunk_id": "function_export_engine_5e8b025f"
},
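A hedged sketch of the simplest `export_engine` call; it requires the `tensorrt` package and a CUDA-capable GPU, and every path and metadata value here is illustrative:

```python
from ultralytics.utils.export import export_engine

export_engine(
    "tiny.onnx",                 # ONNX file produced beforehand
    engine_file="tiny.engine",   # defaults to the .onnx path with an .engine suffix
    workspace=2,                 # GB of builder workspace
    half=True,                   # FP16 build (needs fast-FP16 hardware support)
    metadata={"stride": 32, "names": {0: "person"}},  # prepended to the engine file
)
```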
{
"content": "import contextlib",
"chunk_type": "import",
"name": "contextlib",
"file_path": "ultralytics\\ultralytics\\utils\\files.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 17,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_contextlib_bacf7ad7"
},
{
"content": "import glob",
"chunk_type": "import",
"name": "glob",
"file_path": "ultralytics\\ultralytics\\utils\\files.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_glob_b6b5d75e"
},
{
"content": "import os",
"chunk_type": "import",
"name": "os",
"file_path": "ultralytics\\ultralytics\\utils\\files.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_os_778704ca"
},
{
"content": "import shutil",
"chunk_type": "import",
"name": "shutil",
"file_path": "ultralytics\\ultralytics\\utils\\files.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 13,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_shutil_3896adf2"
},
{
"content": "import tempfile",
"chunk_type": "import",
"name": "tempfile",
"file_path": "ultralytics\\ultralytics\\utils\\files.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_tempfile_5965034a"
},
{
"content": "from contextlib import contextmanager",
"chunk_type": "import",
"name": "contextmanager",
"file_path": "ultralytics\\ultralytics\\utils\\files.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 37,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_contextmanager_91e65499"
},
{
"content": "from datetime import datetime",
"chunk_type": "import",
"name": "datetime",
"file_path": "ultralytics\\ultralytics\\utils\\files.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_datetime_ea6c4397"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\utils\\files.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_613eb98b"
},
{
"content": "from typing import Union",
"chunk_type": "import",
"name": "Union",
"file_path": "ultralytics\\ultralytics\\utils\\files.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Union_19336d67"
},
{
"content": "class WorkingDirectory(contextlib.ContextDecorator):\n \"\"\"\n A context manager and decorator for temporarily changing the working directory.\n\n This class allows for the temporary change of the working directory using a context manager or decorator.\n It ensures that the original working directory is restored after the context or decorated function completes.\n\n Attributes:\n dir (Path | str): The new directory to switch to.\n cwd (Path): The original current working directory before the switch.\n\n Methods:\n __enter__: Changes the current directory to the specified directory.\n __exit__: Restores the original working directory on context exit.\n\n Examples:\n Using as a context manager:\n >>> with WorkingDirectory('/path/to/new/dir'):\n >>> # Perform operations in the new directory\n >>> pass\n\n Using as a decorator:\n >>> @WorkingDirectory('/path/to/new/dir')\n >>> def some_function():\n >>> # Perform operations in the new directory\n >>> pass\n \"\"\"\n\n def __init__(self, new_dir: Union[str, Path]):\n \"\"\"Initialize the WorkingDirectory context manager with the target directory.\"\"\"\n self.dir = new_dir # new dir\n self.cwd = Path.cwd().resolve() # current dir\n\n def __enter__(self):\n \"\"\"Change the current working directory to the specified directory upon entering the context.\"\"\"\n os.chdir(self.dir)\n\n def __exit__(self, exc_type, exc_val, exc_tb): # noqa\n \"\"\"Restore the original working directory when exiting the context.\"\"\"\n os.chdir(self.cwd)",
"chunk_type": "class",
"name": "WorkingDirectory",
"file_path": "ultralytics\\ultralytics\\utils\\files.py",
"start_line": 14,
"end_line": 53,
"start_col": 0,
"end_col": 26,
"parent_name": null,
"docstring": "A context manager and decorator for temporarily changing the working directory.\n\nThis class allows for the temporary change of the working directory using a context manager or decorator.\nIt ensures that the original working directory is restored after the context or decorated function completes.\n\nAttributes:\n dir (Path | str): The new directory to switch to.\n cwd (Path): The original current working directory before the switch.\n\nMethods:\n __enter__: Changes the current directory to the specified directory.\n __exit__: Restores the original working directory on context exit.\n\nExamples:\n Using as a context manager:\n >>> with WorkingDirectory('/path/to/new/dir'):\n >>> # Perform operations in the new directory\n >>> pass\n\n Using as a decorator:\n >>> @WorkingDirectory('/path/to/new/dir')\n >>> def some_function():\n >>> # Perform operations in the new directory\n >>> pass",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"contextlib",
"glob",
"os",
"shutil",
"tempfile",
"contextlib.contextmanager",
"datetime.datetime",
"pathlib.Path",
"typing.Union",
"ultralytics.YOLO",
"ultralytics.nn.autobackend.default_class_names",
"contextlib.ContextDecorator"
],
"chunk_id": "class_WorkingDirectory_98c3f7ce"
},
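A short sketch of `WorkingDirectory` used as a context manager; the `".."` target is simply an existing directory chosen for illustration:

```python
from pathlib import Path

from ultralytics.utils.files import WorkingDirectory

print(Path.cwd())          # original directory
with WorkingDirectory(".."):
    print(Path.cwd())      # temporarily the parent directory
print(Path.cwd())          # restored on exit, even if the block raised
```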
{
"content": "def spaces_in_path(path: Union[str, Path]):\n \"\"\"\n Context manager to handle paths with spaces in their names.\n\n If a path contains spaces, it replaces them with underscores, copies the file/directory to the new path, executes\n the context code block, then copies the file/directory back to its original location.\n\n Args:\n path (str | Path): The original path that may contain spaces.\n\n Yields:\n (Path | str): Temporary path with spaces replaced by underscores if spaces were present, otherwise the\n original path.\n\n Examples:\n >>> with spaces_in_path('/path/with spaces') as new_path:\n >>> # Your code here\n >>> pass\n \"\"\"\n # If path has spaces, replace them with underscores\n if \" \" in str(path):\n string = isinstance(path, str) # input type\n path = Path(path)\n\n # Create a temporary directory and construct the new path\n with tempfile.TemporaryDirectory() as tmp_dir:\n tmp_path = Path(tmp_dir) / path.name.replace(\" \", \"_\")\n\n # Copy file/directory\n if path.is_dir():\n shutil.copytree(path, tmp_path)\n elif path.is_file():\n tmp_path.parent.mkdir(parents=True, exist_ok=True)\n shutil.copy2(path, tmp_path)\n\n try:\n # Yield the temporary path\n yield str(tmp_path) if string else tmp_path\n\n finally:\n # Copy file/directory back\n if tmp_path.is_dir():\n shutil.copytree(tmp_path, path, dirs_exist_ok=True)\n elif tmp_path.is_file():\n shutil.copy2(tmp_path, path) # Copy back the file\n\n else:\n # If there are no spaces, just yield the original path\n yield path",
"chunk_type": "function",
"name": "spaces_in_path",
"file_path": "ultralytics\\ultralytics\\utils\\files.py",
"start_line": 57,
"end_line": 105,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": "Context manager to handle paths with spaces in their names.\n\nIf a path contains spaces, it replaces them with underscores, copies the file/directory to the new path, executes\nthe context code block, then copies the file/directory back to its original location.\n\nArgs:\n path (str | Path): The original path that may contain spaces.\n\nYields:\n (Path | str): Temporary path with spaces replaced by underscores if spaces were present, otherwise the\n original path.\n\nExamples:\n >>> with spaces_in_path('/path/with spaces') as new_path:\n >>> # Your code here\n >>> pass",
"parameters": [
"path: Union[str, Path]"
],
"return_type": null,
"decorators": [
"contextmanager"
],
"complexity_score": 6,
"dependencies": [
"contextlib",
"glob",
"os",
"shutil",
"tempfile",
"contextlib.contextmanager",
"datetime.datetime",
"pathlib.Path",
"typing.Union",
"ultralytics.YOLO",
"ultralytics.nn.autobackend.default_class_names"
],
"chunk_id": "function_spaces_in_path_fac3d854"
},
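A usage sketch for the `spaces_in_path` context manager above; the path is illustrative and must exist for the copy and copy-back behaviour to take effect:

```python
from ultralytics.utils.files import spaces_in_path

# Inside the block, a temporary copy with spaces replaced by underscores is used;
# changes are copied back on exit. Paths without spaces are yielded unchanged.
with spaces_in_path("runs/my results.txt") as safe_path:
    print(safe_path)  # e.g. a temporary .../my_results.txt copy
```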
{
"content": "def increment_path(path: Union[str, Path], exist_ok: bool = False, sep: str = \"\", mkdir: bool = False) -> Path:\n \"\"\"\n Increment a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.\n\n If the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to\n the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the\n number will be appended directly to the end of the path.\n\n Args:\n path (str | Path): Path to increment.\n exist_ok (bool, optional): If True, the path will not be incremented and returned as-is.\n sep (str, optional): Separator to use between the path and the incrementation number.\n mkdir (bool, optional): Create a directory if it does not exist.\n\n Returns:\n (Path): Incremented path.\n\n Examples:\n Increment a directory path:\n >>> from pathlib import Path\n >>> path = Path(\"runs/exp\")\n >>> new_path = increment_path(path)\n >>> print(new_path)\n runs/exp2\n\n Increment a file path:\n >>> path = Path(\"runs/exp/results.txt\")\n >>> new_path = increment_path(path)\n >>> print(new_path)\n runs/exp/results2.txt\n \"\"\"\n path = Path(path) # os-agnostic\n if path.exists() and not exist_ok:\n path, suffix = (path.with_suffix(\"\"), path.suffix) if path.is_file() else (path, \"\")\n\n # Method 1\n for n in range(2, 9999):\n p = f\"{path}{sep}{n}{suffix}\" # increment path\n if not os.path.exists(p):\n break\n path = Path(p)\n\n if mkdir:\n path.mkdir(parents=True, exist_ok=True) # make directory\n\n return path",
"chunk_type": "function",
"name": "increment_path",
"file_path": "ultralytics\\ultralytics\\utils\\files.py",
"start_line": 108,
"end_line": 153,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": "Increment a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.\n\nIf the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to\nthe end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the\nnumber will be appended directly to the end of the path.\n\nArgs:\n path (str | Path): Path to increment.\n exist_ok (bool, optional): If True, the path will not be incremented and returned as-is.\n sep (str, optional): Separator to use between the path and the incrementation number.\n mkdir (bool, optional): Create a directory if it does not exist.\n\nReturns:\n (Path): Incremented path.\n\nExamples:\n Increment a directory path:\n >>> from pathlib import Path\n >>> path = Path(\"runs/exp\")\n >>> new_path = increment_path(path)\n >>> print(new_path)\n runs/exp2\n\n Increment a file path:\n >>> path = Path(\"runs/exp/results.txt\")\n >>> new_path = increment_path(path)\n >>> print(new_path)\n runs/exp/results2.txt",
"parameters": [
"path: Union[str, Path]",
"exist_ok: bool",
"sep: str",
"mkdir: bool"
],
"return_type": "Path",
"decorators": [],
"complexity_score": 5,
"dependencies": [
"contextlib",
"glob",
"os",
"shutil",
"tempfile",
"contextlib.contextmanager",
"datetime.datetime",
"pathlib.Path",
"typing.Union",
"ultralytics.YOLO",
"ultralytics.nn.autobackend.default_class_names"
],
"chunk_id": "function_increment_path_f13e1833"
},
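A quick sketch of `increment_path`; the `runs/exp` base path matches the chunk's own docstring example:

```python
from ultralytics.utils.files import increment_path

run_dir = increment_path("runs/exp", mkdir=True)  # runs/exp, then runs/exp2, runs/exp3, ...
print(run_dir)
```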
{
"content": "def file_age(path: Union[str, Path] = __file__) -> int:\n \"\"\"Return days since the last modification of the specified file.\"\"\"\n dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta\n return dt.days # + dt.seconds / 86400 # fractional days",
"chunk_type": "function",
"name": "file_age",
"file_path": "ultralytics\\ultralytics\\utils\\files.py",
"start_line": 156,
"end_line": 159,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": "Return days since the last modification of the specified file.",
"parameters": [
"path: Union[str, Path]"
],
"return_type": "int",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"glob",
"os",
"shutil",
"tempfile",
"contextlib.contextmanager",
"datetime.datetime",
"pathlib.Path",
"typing.Union",
"ultralytics.YOLO",
"ultralytics.nn.autobackend.default_class_names"
],
"chunk_id": "function_file_age_ba03809c"
},
{
"content": "def file_date(path: Union[str, Path] = __file__) -> str:\n \"\"\"Return the file modification date in 'YYYY-M-D' format.\"\"\"\n t = datetime.fromtimestamp(Path(path).stat().st_mtime)\n return f\"{t.year}-{t.month}-{t.day}\"",
"chunk_type": "function",
"name": "file_date",
"file_path": "ultralytics\\ultralytics\\utils\\files.py",
"start_line": 162,
"end_line": 165,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": "Return the file modification date in 'YYYY-M-D' format.",
"parameters": [
"path: Union[str, Path]"
],
"return_type": "str",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"glob",
"os",
"shutil",
"tempfile",
"contextlib.contextmanager",
"datetime.datetime",
"pathlib.Path",
"typing.Union",
"ultralytics.YOLO",
"ultralytics.nn.autobackend.default_class_names"
],
"chunk_id": "function_file_date_0e5fa0e2"
},
{
"content": "def file_size(path: Union[str, Path]) -> float:\n \"\"\"Return the size of a file or directory in megabytes (MB).\"\"\"\n if isinstance(path, (str, Path)):\n mb = 1 << 20 # bytes to MiB (1024 ** 2)\n path = Path(path)\n if path.is_file():\n return path.stat().st_size / mb\n elif path.is_dir():\n return sum(f.stat().st_size for f in path.glob(\"**/*\") if f.is_file()) / mb\n return 0.0",
"chunk_type": "function",
"name": "file_size",
"file_path": "ultralytics\\ultralytics\\utils\\files.py",
"start_line": 168,
"end_line": 177,
"start_col": 0,
"end_col": 14,
"parent_name": null,
"docstring": "Return the size of a file or directory in megabytes (MB).",
"parameters": [
"path: Union[str, Path]"
],
"return_type": "float",
"decorators": [],
"complexity_score": 5,
"dependencies": [
"contextlib",
"glob",
"os",
"shutil",
"tempfile",
"contextlib.contextmanager",
"datetime.datetime",
"pathlib.Path",
"typing.Union",
"ultralytics.YOLO",
"ultralytics.nn.autobackend.default_class_names"
],
"chunk_id": "function_file_size_04b9f1b2"
},
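A combined sketch for the small file-inspection helpers indexed above (`file_age`, `file_date`, `file_size`); the `yolo11n.pt` argument is an illustrative local file path:

```python
from ultralytics.utils.files import file_age, file_date, file_size

print(file_age("yolo11n.pt"))   # whole days since last modification
print(file_date("yolo11n.pt"))  # modification date as 'YYYY-M-D'
print(file_size("yolo11n.pt"))  # size in MB; directories are summed recursively
```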
{
"content": "def get_latest_run(search_dir: str = \".\") -> str:\n \"\"\"Return the path to the most recent 'last.pt' file in the specified directory for resuming training.\"\"\"\n last_list = glob.glob(f\"{search_dir}/**/last*.pt\", recursive=True)\n return max(last_list, key=os.path.getctime) if last_list else \"\"",
"chunk_type": "function",
"name": "get_latest_run",
"file_path": "ultralytics\\ultralytics\\utils\\files.py",
"start_line": 180,
"end_line": 183,
"start_col": 0,
"end_col": 68,
"parent_name": null,
"docstring": "Return the path to the most recent 'last.pt' file in the specified directory for resuming training.",
"parameters": [
"search_dir: str"
],
"return_type": "str",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"glob",
"os",
"shutil",
"tempfile",
"contextlib.contextmanager",
"datetime.datetime",
"pathlib.Path",
"typing.Union",
"ultralytics.YOLO",
"ultralytics.nn.autobackend.default_class_names"
],
"chunk_id": "function_get_latest_run_dcf5fd05"
},
{
"content": "def update_models(model_names: tuple = (\"yolo11n.pt\",), source_dir: Path = Path(\".\"), update_names: bool = False):\n \"\"\"\n Update and re-save specified YOLO models in an 'updated_models' subdirectory.\n\n Args:\n model_names (tuple, optional): Model filenames to update.\n source_dir (Path, optional): Directory containing models and target subdirectory.\n update_names (bool, optional): Update model names from a data YAML.\n\n Examples:\n Update specified YOLO models and save them in 'updated_models' subdirectory:\n >>> from ultralytics.utils.files import update_models\n >>> model_names = (\"yolo11n.pt\", \"yolov8s.pt\")\n >>> update_models(model_names, source_dir=Path(\"/models\"), update_names=True)\n \"\"\"\n from ultralytics import YOLO\n from ultralytics.nn.autobackend import default_class_names\n\n target_dir = source_dir / \"updated_models\"\n target_dir.mkdir(parents=True, exist_ok=True) # Ensure target directory exists\n\n for model_name in model_names:\n model_path = source_dir / model_name\n print(f\"Loading model from {model_path}\")\n\n # Load model\n model = YOLO(model_path)\n model.half()\n if update_names: # update model names from a dataset YAML\n model.model.names = default_class_names(\"coco8.yaml\")\n\n # Define new save path\n save_path = target_dir / model_name\n\n # Save model using model.save()\n print(f\"Re-saving {model_name} model to {save_path}\")\n model.save(save_path)",
"chunk_type": "function",
"name": "update_models",
"file_path": "ultralytics\\ultralytics\\utils\\files.py",
"start_line": 186,
"end_line": 222,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": "Update and re-save specified YOLO models in an 'updated_models' subdirectory.\n\nArgs:\n model_names (tuple, optional): Model filenames to update.\n source_dir (Path, optional): Directory containing models and target subdirectory.\n update_names (bool, optional): Update model names from a data YAML.\n\nExamples:\n Update specified YOLO models and save them in 'updated_models' subdirectory:\n >>> from ultralytics.utils.files import update_models\n >>> model_names = (\"yolo11n.pt\", \"yolov8s.pt\")\n >>> update_models(model_names, source_dir=Path(\"/models\"), update_names=True)",
"parameters": [
"model_names: tuple",
"source_dir: Path",
"update_names: bool"
],
"return_type": null,
"decorators": [],
"complexity_score": 3,
"dependencies": [
"contextlib",
"glob",
"os",
"shutil",
"tempfile",
"contextlib.contextmanager",
"datetime.datetime",
"pathlib.Path",
"typing.Union",
"ultralytics.YOLO",
"ultralytics.nn.autobackend.default_class_names"
],
"chunk_id": "function_update_models_b08d3edc"
},
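A usage sketch for `update_models`; it assumes `yolo11n.pt` is present in the current directory, with the half-precision copy re-saved into `./updated_models/`:

```python
from pathlib import Path

from ultralytics.utils.files import update_models

update_models(("yolo11n.pt",), source_dir=Path("."), update_names=False)
```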
{
"content": "from collections import abc",
"chunk_type": "import",
"name": "abc",
"file_path": "ultralytics\\ultralytics\\utils\\instance.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 27,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_abc_1a55c307"
},
{
"content": "from itertools import repeat",
"chunk_type": "import",
"name": "repeat",
"file_path": "ultralytics\\ultralytics\\utils\\instance.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_repeat_818933e2"
},
{
"content": "from numbers import Number",
"chunk_type": "import",
"name": "Number",
"file_path": "ultralytics\\ultralytics\\utils\\instance.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 26,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Number_b9d16c87"
},
{
"content": "from typing import List, Union",
"chunk_type": "import",
"name": "List, Union",
"file_path": "ultralytics\\ultralytics\\utils\\instance.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 30,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_List, Union_2df48fed"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\utils\\instance.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_a2efdae5"
},
{
"content": "from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh",
"chunk_type": "import",
"name": "ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh",
"file_path": "ultralytics\\ultralytics\\utils\\instance.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 100,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh_9b94f05f"
},
{
"content": "def _ntuple(n):\n \"\"\"Create a function that converts input to n-tuple by repeating singleton values.\"\"\"\n\n def parse(x):\n \"\"\"Parse input to return n-tuple by repeating singleton values n times.\"\"\"\n return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n))\n\n return parse",
"chunk_type": "function",
"name": "_ntuple",
"file_path": "ultralytics\\ultralytics\\utils\\instance.py",
"start_line": 13,
"end_line": 20,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "Create a function that converts input to n-tuple by repeating singleton values.",
"parameters": [
"n"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.abc",
"itertools.repeat",
"numbers.Number",
"typing.List",
"typing.Union",
"numpy",
"ops.ltwh2xywh",
"ops.ltwh2xyxy",
"ops.resample_segments",
"ops.xywh2ltwh",
"ops.xywh2xyxy",
"ops.xyxy2ltwh",
"ops.xyxy2xywh"
],
"chunk_id": "function__ntuple_833ec773"
},
{
"content": "to_2tuple = _ntuple(2)",
"chunk_type": "variable",
"name": "to_2tuple",
"file_path": "ultralytics\\ultralytics\\utils\\instance.py",
"start_line": 23,
"end_line": 23,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_to_2tuple_0c3837fe"
},
{
"content": "to_4tuple = _ntuple(4)",
"chunk_type": "variable",
"name": "to_4tuple",
"file_path": "ultralytics\\ultralytics\\utils\\instance.py",
"start_line": 24,
"end_line": 24,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_to_4tuple_a9f230e9"
},
{
"content": "_formats = [\"xyxy\", \"xywh\", \"ltwh\"]",
"chunk_type": "variable",
"name": "_formats",
"file_path": "ultralytics\\ultralytics\\utils\\instance.py",
"start_line": 29,
"end_line": 29,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable__formats_736251a6"
},
{
"content": "__all__ = (\"Bboxes\", \"Instances\") # tuple or list",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\utils\\instance.py",
"start_line": 31,
"end_line": 31,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___bf54126f"
},
{
"content": "class Bboxes:\n \"\"\"\n A class for handling bounding boxes in multiple formats.\n\n The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh' and provides methods for format\n conversion, scaling, and area calculation. Bounding box data should be provided as numpy arrays.\n\n Attributes:\n bboxes (np.ndarray): The bounding boxes stored in a 2D numpy array with shape (N, 4).\n format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh').\n\n Methods:\n convert: Convert bounding box format from one type to another.\n areas: Calculate the area of bounding boxes.\n mul: Multiply bounding box coordinates by scale factor(s).\n add: Add offset to bounding box coordinates.\n concatenate: Concatenate multiple Bboxes objects.\n\n Examples:\n Create bounding boxes in YOLO format\n >>> bboxes = Bboxes(np.array([[100, 50, 150, 100]]), format=\"xywh\")\n >>> bboxes.convert(\"xyxy\")\n >>> print(bboxes.areas())\n\n Notes:\n This class does not handle normalization or denormalization of bounding boxes.\n \"\"\"\n\n def __init__(self, bboxes: np.ndarray, format: str = \"xyxy\") -> None:\n \"\"\"\n Initialize the Bboxes class with bounding box data in a specified format.\n\n Args:\n bboxes (np.ndarray): Array of bounding boxes with shape (N, 4) or (4,).\n format (str): Format of the bounding boxes, one of 'xyxy', 'xywh', or 'ltwh'.\n \"\"\"\n assert format in _formats, f\"Invalid bounding box format: {format}, format must be one of {_formats}\"\n bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes\n assert bboxes.ndim == 2\n assert bboxes.shape[1] == 4\n self.bboxes = bboxes\n self.format = format\n\n def convert(self, format: str) -> None:\n \"\"\"\n Convert bounding box format from one type to another.\n\n Args:\n format (str): Target format for conversion, one of 'xyxy', 'xywh', or 'ltwh'.\n \"\"\"\n assert format in _formats, f\"Invalid bounding box format: {format}, format must be one of {_formats}\"\n if self.format == format:\n return\n elif self.format == \"xyxy\":\n func = xyxy2xywh if format == \"xywh\" else xyxy2ltwh\n elif self.format == \"xywh\":\n func = xywh2xyxy if format == \"xyxy\" else xywh2ltwh\n else:\n func = ltwh2xyxy if format == \"xyxy\" else ltwh2xywh\n self.bboxes = func(self.bboxes)\n self.format = format\n\n def areas(self) -> np.ndarray:\n \"\"\"Calculate the area of bounding boxes.\"\"\"\n return (\n (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1]) # format xyxy\n if self.format == \"xyxy\"\n else self.bboxes[:, 3] * self.bboxes[:, 2] # format xywh or ltwh\n )\n\n def mul(self, scale: Union[int, tuple, list]) -> None:\n \"\"\"\n Multiply bounding box coordinates by scale factor(s).\n\n Args:\n scale (int | tuple | list): Scale factor(s) for four coordinates. If int, the same scale is applied to\n all coordinates.\n \"\"\"\n if isinstance(scale, Number):\n scale = to_4tuple(scale)\n assert isinstance(scale, (tuple, list))\n assert len(scale) == 4\n self.bboxes[:, 0] *= scale[0]\n self.bboxes[:, 1] *= scale[1]\n self.bboxes[:, 2] *= scale[2]\n self.bboxes[:, 3] *= scale[3]\n\n def add(self, offset: Union[int, tuple, list]) -> None:\n \"\"\"\n Add offset to bounding box coordinates.\n\n Args:\n offset (int | tuple | list): Offset(s) for four coordinates. 
If int, the same offset is applied to\n all coordinates.\n \"\"\"\n if isinstance(offset, Number):\n offset = to_4tuple(offset)\n assert isinstance(offset, (tuple, list))\n assert len(offset) == 4\n self.bboxes[:, 0] += offset[0]\n self.bboxes[:, 1] += offset[1]\n self.bboxes[:, 2] += offset[2]\n self.bboxes[:, 3] += offset[3]\n\n def __len__(self) -> int:\n \"\"\"Return the number of bounding boxes.\"\"\"\n return len(self.bboxes)\n\n @classmethod\n def concatenate(cls, boxes_list: List[\"Bboxes\"], axis: int = 0) -> \"Bboxes\":\n \"\"\"\n Concatenate a list of Bboxes objects into a single Bboxes object.\n\n Args:\n boxes_list (List[Bboxes]): A list of Bboxes objects to concatenate.\n axis (int, optional): The axis along which to concatenate the bounding boxes.\n\n Returns:\n (Bboxes): A new Bboxes object containing the concatenated bounding boxes.\n\n Notes:\n The input should be a list or tuple of Bboxes objects.\n \"\"\"\n assert isinstance(boxes_list, (list, tuple))\n if not boxes_list:\n return cls(np.empty(0))\n assert all(isinstance(box, Bboxes) for box in boxes_list)\n\n if len(boxes_list) == 1:\n return boxes_list[0]\n return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))\n\n def __getitem__(self, index: Union[int, np.ndarray, slice]) -> \"Bboxes\":\n \"\"\"\n Retrieve a specific bounding box or a set of bounding boxes using indexing.\n\n Args:\n index (int | slice | np.ndarray): The index, slice, or boolean array to select the desired bounding boxes.\n\n Returns:\n (Bboxes): A new Bboxes object containing the selected bounding boxes.\n\n Notes:\n When using boolean indexing, make sure to provide a boolean array with the same length as the number of\n bounding boxes.\n \"\"\"\n if isinstance(index, int):\n return Bboxes(self.bboxes[index].reshape(1, -1))\n b = self.bboxes[index]\n assert b.ndim == 2, f\"Indexing on Bboxes with {index} failed to return a matrix!\"\n return Bboxes(b)",
"chunk_type": "class",
"name": "Bboxes",
"file_path": "ultralytics\\ultralytics\\utils\\instance.py",
"start_line": 34,
"end_line": 184,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": "A class for handling bounding boxes in multiple formats.\n\nThe class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh' and provides methods for format\nconversion, scaling, and area calculation. Bounding box data should be provided as numpy arrays.\n\nAttributes:\n bboxes (np.ndarray): The bounding boxes stored in a 2D numpy array with shape (N, 4).\n format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh').\n\nMethods:\n convert: Convert bounding box format from one type to another.\n areas: Calculate the area of bounding boxes.\n mul: Multiply bounding box coordinates by scale factor(s).\n add: Add offset to bounding box coordinates.\n concatenate: Concatenate multiple Bboxes objects.\n\nExamples:\n Create bounding boxes in YOLO format\n >>> bboxes = Bboxes(np.array([[100, 50, 150, 100]]), format=\"xywh\")\n >>> bboxes.convert(\"xyxy\")\n >>> print(bboxes.areas())\n\nNotes:\n This class does not handle normalization or denormalization of bounding boxes.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"collections.abc",
"itertools.repeat",
"numbers.Number",
"typing.List",
"typing.Union",
"numpy",
"ops.ltwh2xywh",
"ops.ltwh2xyxy",
"ops.resample_segments",
"ops.xywh2ltwh",
"ops.xywh2xyxy",
"ops.xyxy2ltwh",
"ops.xyxy2xywh"
],
"chunk_id": "class_Bboxes_8691b2b0"
},
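An illustrative usage sketch for the Bboxes class above, assuming the ultralytics package is importable so that ultralytics.utils.instance resolves; the box values follow the docstring example:

import numpy as np
from ultralytics.utils.instance import Bboxes

boxes = Bboxes(np.array([[100.0, 50.0, 150.0, 100.0]]), format="xywh")  # (cx, cy, w, h)
print(boxes.areas())   # [15000.] -> w * h while in 'xywh'/'ltwh' format
boxes.convert("xyxy")  # in-place conversion to (x1, y1, x2, y2)
print(boxes.bboxes)    # [[ 25.   0. 175. 100.]]
boxes.mul(0.5)         # a scalar is broadcast to all four coordinates via to_4tuple
print(len(boxes))      # 1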
{
"content": "class Instances:\n \"\"\"\n Container for bounding boxes, segments, and keypoints of detected objects in an image.\n\n This class provides a unified interface for handling different types of object annotations including bounding\n boxes, segmentation masks, and keypoints. It supports various operations like scaling, normalization, clipping,\n and format conversion.\n\n Attributes:\n _bboxes (Bboxes): Internal object for handling bounding box operations.\n keypoints (np.ndarray): Keypoints with shape (N, 17, 3) in format (x, y, visible).\n normalized (bool): Flag indicating whether the bounding box coordinates are normalized.\n segments (np.ndarray): Segments array with shape (N, M, 2) after resampling.\n\n Methods:\n convert_bbox: Convert bounding box format.\n scale: Scale coordinates by given factors.\n denormalize: Convert normalized coordinates to absolute coordinates.\n normalize: Convert absolute coordinates to normalized coordinates.\n add_padding: Add padding to coordinates.\n flipud: Flip coordinates vertically.\n fliplr: Flip coordinates horizontally.\n clip: Clip coordinates to stay within image boundaries.\n remove_zero_area_boxes: Remove boxes with zero area.\n update: Update instance variables.\n concatenate: Concatenate multiple Instances objects.\n\n Examples:\n Create instances with bounding boxes and segments\n >>> instances = Instances(\n ... bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]),\n ... segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])],\n ... keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]]),\n ... )\n \"\"\"\n\n def __init__(\n self,\n bboxes: np.ndarray,\n segments: np.ndarray = None,\n keypoints: np.ndarray = None,\n bbox_format: str = \"xywh\",\n normalized: bool = True,\n ) -> None:\n \"\"\"\n Initialize the Instances object with bounding boxes, segments, and keypoints.\n\n Args:\n bboxes (np.ndarray): Bounding boxes with shape (N, 4).\n segments (np.ndarray, optional): Segmentation masks.\n keypoints (np.ndarray, optional): Keypoints with shape (N, 17, 3) in format (x, y, visible).\n bbox_format (str): Format of bboxes.\n normalized (bool): Whether the coordinates are normalized.\n \"\"\"\n self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)\n self.keypoints = keypoints\n self.normalized = normalized\n self.segments = segments\n\n def convert_bbox(self, format: str) -> None:\n \"\"\"\n Convert bounding box format.\n\n Args:\n format (str): Target format for conversion, one of 'xyxy', 'xywh', or 'ltwh'.\n \"\"\"\n self._bboxes.convert(format=format)\n\n @property\n def bbox_areas(self) -> np.ndarray:\n \"\"\"Calculate the area of bounding boxes.\"\"\"\n return self._bboxes.areas()\n\n def scale(self, scale_w: float, scale_h: float, bbox_only: bool = False):\n \"\"\"\n Scale coordinates by given factors.\n\n Args:\n scale_w (float): Scale factor for width.\n scale_h (float): Scale factor for height.\n bbox_only (bool, optional): Whether to scale only bounding boxes.\n \"\"\"\n self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))\n if bbox_only:\n return\n self.segments[..., 0] *= scale_w\n self.segments[..., 1] *= scale_h\n if self.keypoints is not None:\n self.keypoints[..., 0] *= scale_w\n self.keypoints[..., 1] *= scale_h\n\n def denormalize(self, w: int, h: int) -> None:\n \"\"\"\n Convert normalized coordinates to absolute coordinates.\n\n Args:\n w (int): Image width.\n h (int): Image height.\n \"\"\"\n if not self.normalized:\n return\n self._bboxes.mul(scale=(w, h, w, 
h))\n self.segments[..., 0] *= w\n self.segments[..., 1] *= h\n if self.keypoints is not None:\n self.keypoints[..., 0] *= w\n self.keypoints[..., 1] *= h\n self.normalized = False\n\n def normalize(self, w: int, h: int) -> None:\n \"\"\"\n Convert absolute coordinates to normalized coordinates.\n\n Args:\n w (int): Image width.\n h (int): Image height.\n \"\"\"\n if self.normalized:\n return\n self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))\n self.segments[..., 0] /= w\n self.segments[..., 1] /= h\n if self.keypoints is not None:\n self.keypoints[..., 0] /= w\n self.keypoints[..., 1] /= h\n self.normalized = True\n\n def add_padding(self, padw: int, padh: int) -> None:\n \"\"\"\n Add padding to coordinates.\n\n Args:\n padw (int): Padding width.\n padh (int): Padding height.\n \"\"\"\n assert not self.normalized, \"you should add padding with absolute coordinates.\"\n self._bboxes.add(offset=(padw, padh, padw, padh))\n self.segments[..., 0] += padw\n self.segments[..., 1] += padh\n if self.keypoints is not None:\n self.keypoints[..., 0] += padw\n self.keypoints[..., 1] += padh\n\n def __getitem__(self, index: Union[int, np.ndarray, slice]) -> \"Instances\":\n \"\"\"\n Retrieve a specific instance or a set of instances using indexing.\n\n Args:\n index (int | slice | np.ndarray): The index, slice, or boolean array to select the desired instances.\n\n Returns:\n (Instances): A new Instances object containing the selected boxes, segments, and keypoints if present.\n\n Notes:\n When using boolean indexing, make sure to provide a boolean array with the same length as the number of\n instances.\n \"\"\"\n segments = self.segments[index] if len(self.segments) else self.segments\n keypoints = self.keypoints[index] if self.keypoints is not None else None\n bboxes = self.bboxes[index]\n bbox_format = self._bboxes.format\n return Instances(\n bboxes=bboxes,\n segments=segments,\n keypoints=keypoints,\n bbox_format=bbox_format,\n normalized=self.normalized,\n )\n\n def flipud(self, h: int) -> None:\n \"\"\"\n Flip coordinates vertically.\n\n Args:\n h (int): Image height.\n \"\"\"\n if self._bboxes.format == \"xyxy\":\n y1 = self.bboxes[:, 1].copy()\n y2 = self.bboxes[:, 3].copy()\n self.bboxes[:, 1] = h - y2\n self.bboxes[:, 3] = h - y1\n else:\n self.bboxes[:, 1] = h - self.bboxes[:, 1]\n self.segments[..., 1] = h - self.segments[..., 1]\n if self.keypoints is not None:\n self.keypoints[..., 1] = h - self.keypoints[..., 1]\n\n def fliplr(self, w: int) -> None:\n \"\"\"\n Flip coordinates horizontally.\n\n Args:\n w (int): Image width.\n \"\"\"\n if self._bboxes.format == \"xyxy\":\n x1 = self.bboxes[:, 0].copy()\n x2 = self.bboxes[:, 2].copy()\n self.bboxes[:, 0] = w - x2\n self.bboxes[:, 2] = w - x1\n else:\n self.bboxes[:, 0] = w - self.bboxes[:, 0]\n self.segments[..., 0] = w - self.segments[..., 0]\n if self.keypoints is not None:\n self.keypoints[..., 0] = w - self.keypoints[..., 0]\n\n def clip(self, w: int, h: int) -> None:\n \"\"\"\n Clip coordinates to stay within image boundaries.\n\n Args:\n w (int): Image width.\n h (int): Image height.\n \"\"\"\n ori_format = self._bboxes.format\n self.convert_bbox(format=\"xyxy\")\n self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)\n self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)\n if ori_format != \"xyxy\":\n self.convert_bbox(format=ori_format)\n self.segments[..., 0] = self.segments[..., 0].clip(0, w)\n self.segments[..., 1] = self.segments[..., 1].clip(0, h)\n if self.keypoints is not None:\n # Set out of bounds 
visibility to zero\n self.keypoints[..., 2][\n (self.keypoints[..., 0] < 0)\n | (self.keypoints[..., 0] > w)\n | (self.keypoints[..., 1] < 0)\n | (self.keypoints[..., 1] > h)\n ] = 0.0\n self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)\n self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)\n\n def remove_zero_area_boxes(self) -> np.ndarray:\n \"\"\"\n Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height.\n\n Returns:\n (np.ndarray): Boolean array indicating which boxes were kept.\n \"\"\"\n good = self.bbox_areas > 0\n if not all(good):\n self._bboxes = self._bboxes[good]\n if len(self.segments):\n self.segments = self.segments[good]\n if self.keypoints is not None:\n self.keypoints = self.keypoints[good]\n return good\n\n def update(self, bboxes: np.ndarray, segments: np.ndarray = None, keypoints: np.ndarray = None):\n \"\"\"\n Update instance variables.\n\n Args:\n bboxes (np.ndarray): New bounding boxes.\n segments (np.ndarray, optional): New segments.\n keypoints (np.ndarray, optional): New keypoints.\n \"\"\"\n self._bboxes = Bboxes(bboxes, format=self._bboxes.format)\n if segments is not None:\n self.segments = segments\n if keypoints is not None:\n self.keypoints = keypoints\n\n def __len__(self) -> int:\n \"\"\"Return the number of instances.\"\"\"\n return len(self.bboxes)\n\n @classmethod\n def concatenate(cls, instances_list: List[\"Instances\"], axis=0) -> \"Instances\":\n \"\"\"\n Concatenate a list of Instances objects into a single Instances object.\n\n Args:\n instances_list (List[Instances]): A list of Instances objects to concatenate.\n axis (int, optional): The axis along which the arrays will be concatenated.\n\n Returns:\n (Instances): A new Instances object containing the concatenated bounding boxes, segments, and keypoints\n if present.\n\n Notes:\n The `Instances` objects in the list should have the same properties, such as the format of the bounding\n boxes, whether keypoints are present, and if the coordinates are normalized.\n \"\"\"\n assert isinstance(instances_list, (list, tuple))\n if not instances_list:\n return cls(np.empty(0))\n assert all(isinstance(instance, Instances) for instance in instances_list)\n\n if len(instances_list) == 1:\n return instances_list[0]\n\n use_keypoint = instances_list[0].keypoints is not None\n bbox_format = instances_list[0]._bboxes.format\n normalized = instances_list[0].normalized\n\n cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)\n seg_len = [b.segments.shape[1] for b in instances_list]\n if len(frozenset(seg_len)) > 1: # resample segments if there's different length\n max_len = max(seg_len)\n cat_segments = np.concatenate(\n [\n resample_segments(list(b.segments), max_len)\n if len(b.segments)\n else np.zeros((0, max_len, 2), dtype=np.float32) # re-generating empty segments\n for b in instances_list\n ],\n axis=axis,\n )\n else:\n cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)\n cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None\n return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)\n\n @property\n def bboxes(self) -> np.ndarray:\n \"\"\"Return bounding boxes.\"\"\"\n return self._bboxes.bboxes",
"chunk_type": "class",
"name": "Instances",
"file_path": "ultralytics\\ultralytics\\utils\\instance.py",
"start_line": 187,
"end_line": 504,
"start_col": 0,
"end_col": 34,
"parent_name": null,
"docstring": "Container for bounding boxes, segments, and keypoints of detected objects in an image.\n\nThis class provides a unified interface for handling different types of object annotations including bounding\nboxes, segmentation masks, and keypoints. It supports various operations like scaling, normalization, clipping,\nand format conversion.\n\nAttributes:\n _bboxes (Bboxes): Internal object for handling bounding box operations.\n keypoints (np.ndarray): Keypoints with shape (N, 17, 3) in format (x, y, visible).\n normalized (bool): Flag indicating whether the bounding box coordinates are normalized.\n segments (np.ndarray): Segments array with shape (N, M, 2) after resampling.\n\nMethods:\n convert_bbox: Convert bounding box format.\n scale: Scale coordinates by given factors.\n denormalize: Convert normalized coordinates to absolute coordinates.\n normalize: Convert absolute coordinates to normalized coordinates.\n add_padding: Add padding to coordinates.\n flipud: Flip coordinates vertically.\n fliplr: Flip coordinates horizontally.\n clip: Clip coordinates to stay within image boundaries.\n remove_zero_area_boxes: Remove boxes with zero area.\n update: Update instance variables.\n concatenate: Concatenate multiple Instances objects.\n\nExamples:\n Create instances with bounding boxes and segments\n >>> instances = Instances(\n ... bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]),\n ... segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])],\n ... keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]]),\n ... )",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"collections.abc",
"itertools.repeat",
"numbers.Number",
"typing.List",
"typing.Union",
"numpy",
"ops.ltwh2xywh",
"ops.ltwh2xyxy",
"ops.resample_segments",
"ops.xywh2ltwh",
"ops.xywh2xyxy",
"ops.xyxy2ltwh",
"ops.xyxy2xywh"
],
"chunk_id": "class_Instances_76a2c85b"
},
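A usage sketch for Instances, again assuming ultralytics is importable. The methods index self.segments with [..., 0] and [..., 1], so segments should be passed as an already-resampled (N, M, 2) array (here a zero placeholder):

import numpy as np
from ultralytics.utils.instance import Instances

inst = Instances(
    bboxes=np.array([[0.5, 0.5, 0.2, 0.2], [0.25, 0.25, 0.1, 0.1]]),  # normalized xywh
    segments=np.zeros((2, 4, 2), dtype=np.float32),                   # (N, M, 2) placeholder
    keypoints=None,
    bbox_format="xywh",
    normalized=True,
)
inst.denormalize(w=640, h=480)   # scale boxes and segments to pixel coordinates
inst.convert_bbox("xyxy")
inst.clip(w=640, h=480)          # clamp everything to the image bounds
kept = inst.remove_zero_area_boxes()
print(len(inst), kept)           # 2 [ True  True]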
{
"content": "from typing import Any, Dict, List, Tuple",
"chunk_type": "import",
"name": "Any, Dict, List, Tuple",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, List, Tuple_0099e5e8"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_496c93cf"
},
{
"content": "import torch.nn as nn",
"chunk_type": "import",
"name": "torch.nn",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn_b28e90cb"
},
{
"content": "import torch.nn.functional as F",
"chunk_type": "import",
"name": "torch.nn.functional",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn.functional_9922e2b8"
},
{
"content": "from ultralytics.utils.metrics import OKS_SIGMA",
"chunk_type": "import",
"name": "OKS_SIGMA",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_OKS_SIGMA_beb7ea8c"
},
{
"content": "from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh",
"chunk_type": "import",
"name": "crop_mask, xywh2xyxy, xyxy2xywh",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 65,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_crop_mask, xywh2xyxy, xyxy2xywh_383a0b84"
},
{
"content": "from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors",
"chunk_type": "import",
"name": "RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 117,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors_1722eb43"
},
{
"content": "from ultralytics.utils.torch_utils import autocast",
"chunk_type": "import",
"name": "autocast",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_autocast_4a645288"
},
{
"content": "from .metrics import bbox_iou, probiou",
"chunk_type": "import",
"name": "bbox_iou, probiou",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 38,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_bbox_iou, probiou_3b940619"
},
{
"content": "from .tal import bbox2dist",
"chunk_type": "import",
"name": "bbox2dist",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 26,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_bbox2dist_e462b29f"
},
{
"content": "class VarifocalLoss(nn.Module):\n \"\"\"\n Varifocal loss by Zhang et al.\n\n Implements the Varifocal Loss function for addressing class imbalance in object detection by focusing on\n hard-to-classify examples and balancing positive/negative samples.\n\n Attributes:\n gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.\n alpha (float): The balancing factor used to address class imbalance.\n\n References:\n https://arxiv.org/abs/2008.13367\n \"\"\"\n\n def __init__(self, gamma: float = 2.0, alpha: float = 0.75):\n \"\"\"Initialize the VarifocalLoss class with focusing and balancing parameters.\"\"\"\n super().__init__()\n self.gamma = gamma\n self.alpha = alpha\n\n def forward(self, pred_score: torch.Tensor, gt_score: torch.Tensor, label: torch.Tensor) -> torch.Tensor:\n \"\"\"Compute varifocal loss between predictions and ground truth.\"\"\"\n weight = self.alpha * pred_score.sigmoid().pow(self.gamma) * (1 - label) + gt_score * label\n with autocast(enabled=False):\n loss = (\n (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction=\"none\") * weight)\n .mean(1)\n .sum()\n )\n return loss",
"chunk_type": "class",
"name": "VarifocalLoss",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 18,
"end_line": 48,
"start_col": 0,
"end_col": 19,
"parent_name": null,
"docstring": "Varifocal loss by Zhang et al.\n\nImplements the Varifocal Loss function for addressing class imbalance in object detection by focusing on\nhard-to-classify examples and balancing positive/negative samples.\n\nAttributes:\n gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.\n alpha (float): The balancing factor used to address class imbalance.\n\nReferences:\n https://arxiv.org/abs/2008.13367",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.metrics.OKS_SIGMA",
"ultralytics.utils.ops.crop_mask",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.ops.xyxy2xywh",
"ultralytics.utils.tal.RotatedTaskAlignedAssigner",
"ultralytics.utils.tal.TaskAlignedAssigner",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.autocast",
"metrics.bbox_iou",
"metrics.probiou",
"tal.bbox2dist",
"nn.Module"
],
"chunk_id": "class_VarifocalLoss_0b82ffe4"
},
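A minimal smoke test for VarifocalLoss, assuming a standard PyTorch install; the (batch, anchors, classes) tensor shapes are arbitrary placeholders, not values prescribed by the library:

import torch
from ultralytics.utils.loss import VarifocalLoss

vfl = VarifocalLoss(gamma=2.0, alpha=0.75)
pred_score = torch.randn(2, 8, 4)        # raw class logits
gt_score = torch.rand(2, 8, 4)           # IoU-aware soft targets in [0, 1]
label = (gt_score > 0.5).float()         # binary positive mask
print(vfl(pred_score, gt_score, label))  # scalar loss tensor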
{
"content": "class FocalLoss(nn.Module):\n \"\"\"\n Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).\n\n Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing\n on hard negatives during training.\n\n Attributes:\n gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.\n alpha (torch.Tensor): The balancing factor used to address class imbalance.\n \"\"\"\n\n def __init__(self, gamma: float = 1.5, alpha: float = 0.25):\n \"\"\"Initialize FocalLoss class with focusing and balancing parameters.\"\"\"\n super().__init__()\n self.gamma = gamma\n self.alpha = torch.tensor(alpha)\n\n def forward(self, pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:\n \"\"\"Calculate focal loss with modulating factors for class imbalance.\"\"\"\n loss = F.binary_cross_entropy_with_logits(pred, label, reduction=\"none\")\n # p_t = torch.exp(-loss)\n # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability\n\n # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py\n pred_prob = pred.sigmoid() # prob from logits\n p_t = label * pred_prob + (1 - label) * (1 - pred_prob)\n modulating_factor = (1.0 - p_t) ** self.gamma\n loss *= modulating_factor\n if (self.alpha > 0).any():\n self.alpha = self.alpha.to(device=pred.device, dtype=pred.dtype)\n alpha_factor = label * self.alpha + (1 - label) * (1 - self.alpha)\n loss *= alpha_factor\n return loss.mean(1).sum()",
"chunk_type": "class",
"name": "FocalLoss",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 51,
"end_line": 84,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": "Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).\n\nImplements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing\non hard negatives during training.\n\nAttributes:\n gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.\n alpha (torch.Tensor): The balancing factor used to address class imbalance.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.metrics.OKS_SIGMA",
"ultralytics.utils.ops.crop_mask",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.ops.xyxy2xywh",
"ultralytics.utils.tal.RotatedTaskAlignedAssigner",
"ultralytics.utils.tal.TaskAlignedAssigner",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.autocast",
"metrics.bbox_iou",
"metrics.probiou",
"tal.bbox2dist",
"nn.Module"
],
"chunk_id": "class_FocalLoss_070521f6"
},
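The same kind of smoke test for FocalLoss; shapes are placeholders:

import torch
from ultralytics.utils.loss import FocalLoss

fl = FocalLoss(gamma=1.5, alpha=0.25)
pred = torch.randn(2, 8, 4)                      # logits
label = torch.randint(0, 2, (2, 8, 4)).float()   # hard binary targets
print(fl(pred, label))                           # per-image mean over dim 1, then summed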
{
"content": "class DFLoss(nn.Module):\n \"\"\"Criterion class for computing Distribution Focal Loss (DFL).\"\"\"\n\n def __init__(self, reg_max: int = 16) -> None:\n \"\"\"Initialize the DFL module with regularization maximum.\"\"\"\n super().__init__()\n self.reg_max = reg_max\n\n def __call__(self, pred_dist: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n \"\"\"Return sum of left and right DFL losses from https://ieeexplore.ieee.org/document/9792391.\"\"\"\n target = target.clamp_(0, self.reg_max - 1 - 0.01)\n tl = target.long() # target left\n tr = tl + 1 # target right\n wl = tr - target # weight left\n wr = 1 - wl # weight right\n return (\n F.cross_entropy(pred_dist, tl.view(-1), reduction=\"none\").view(tl.shape) * wl\n + F.cross_entropy(pred_dist, tr.view(-1), reduction=\"none\").view(tl.shape) * wr\n ).mean(-1, keepdim=True)",
"chunk_type": "class",
"name": "DFLoss",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 87,
"end_line": 105,
"start_col": 0,
"end_col": 32,
"parent_name": null,
"docstring": "Criterion class for computing Distribution Focal Loss (DFL).",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.metrics.OKS_SIGMA",
"ultralytics.utils.ops.crop_mask",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.ops.xyxy2xywh",
"ultralytics.utils.tal.RotatedTaskAlignedAssigner",
"ultralytics.utils.tal.TaskAlignedAssigner",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.autocast",
"metrics.bbox_iou",
"metrics.probiou",
"tal.bbox2dist",
"nn.Module"
],
"chunk_id": "class_DFLoss_77fc59e8"
},
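DFLoss expects flattened per-side logits of shape (n*4, reg_max) together with continuous (n, 4) left/top/right/bottom distances, which is how BboxLoss calls it. A shape-only sketch with synthetic tensors:

import torch
from ultralytics.utils.loss import DFLoss

dfl = DFLoss(reg_max=16)
n = 3                                  # foreground anchors
pred_dist = torch.randn(n * 4, 16)     # one 16-bin distribution per box side
target = torch.rand(n, 4) * 14.0       # continuous l/t/r/b distances in bin units
print(dfl(pred_dist, target).shape)    # torch.Size([3, 1]) -> per-anchor DFL loss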
{
"content": "class BboxLoss(nn.Module):\n \"\"\"Criterion class for computing training losses for bounding boxes.\"\"\"\n\n def __init__(self, reg_max: int = 16):\n \"\"\"Initialize the BboxLoss module with regularization maximum and DFL settings.\"\"\"\n super().__init__()\n self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None\n\n def forward(\n self,\n pred_dist: torch.Tensor,\n pred_bboxes: torch.Tensor,\n anchor_points: torch.Tensor,\n target_bboxes: torch.Tensor,\n target_scores: torch.Tensor,\n target_scores_sum: torch.Tensor,\n fg_mask: torch.Tensor,\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Compute IoU and DFL losses for bounding boxes.\"\"\"\n weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)\n iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)\n loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum\n\n # DFL loss\n if self.dfl_loss:\n target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)\n loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight\n loss_dfl = loss_dfl.sum() / target_scores_sum\n else:\n loss_dfl = torch.tensor(0.0).to(pred_dist.device)\n\n return loss_iou, loss_dfl",
"chunk_type": "class",
"name": "BboxLoss",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 108,
"end_line": 139,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": "Criterion class for computing training losses for bounding boxes.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.metrics.OKS_SIGMA",
"ultralytics.utils.ops.crop_mask",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.ops.xyxy2xywh",
"ultralytics.utils.tal.RotatedTaskAlignedAssigner",
"ultralytics.utils.tal.TaskAlignedAssigner",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.autocast",
"metrics.bbox_iou",
"metrics.probiou",
"tal.bbox2dist",
"nn.Module"
],
"chunk_id": "class_BboxLoss_1de7f202"
},
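A synthetic end-to-end call of BboxLoss. The anchor points, boxes and masks below are made-up placeholders in stride-normalized (grid) units, chosen only so that every shape matches what forward expects:

import torch
from ultralytics.utils.loss import BboxLoss

torch.manual_seed(0)
B, A, reg_max, nc = 1, 4, 16, 2
bbox_loss = BboxLoss(reg_max)

anchor_points = torch.tensor([[0.5, 0.5], [1.5, 0.5], [0.5, 1.5], [1.5, 1.5]])
pred_dist = torch.randn(B, A, reg_max * 4)            # raw DFL logits
pred_bboxes = torch.tensor([[[0.0, 0.0, 1.0, 1.0],
                             [1.0, 0.0, 2.0, 1.0],
                             [0.0, 1.0, 1.0, 2.0],
                             [1.0, 1.0, 2.0, 2.0]]])  # xyxy, grid units
target_bboxes = pred_bboxes + 0.1                     # slightly shifted targets
target_scores = torch.rand(B, A, nc)
target_scores_sum = target_scores.sum().clamp(min=1)
fg_mask = torch.tensor([[True, True, False, True]])

loss_iou, loss_dfl = bbox_loss(
    pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
)
print(loss_iou.item(), loss_dfl.item())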
{
"content": "class RotatedBboxLoss(BboxLoss):\n \"\"\"Criterion class for computing training losses for rotated bounding boxes.\"\"\"\n\n def __init__(self, reg_max: int):\n \"\"\"Initialize the RotatedBboxLoss module with regularization maximum and DFL settings.\"\"\"\n super().__init__(reg_max)\n\n def forward(\n self,\n pred_dist: torch.Tensor,\n pred_bboxes: torch.Tensor,\n anchor_points: torch.Tensor,\n target_bboxes: torch.Tensor,\n target_scores: torch.Tensor,\n target_scores_sum: torch.Tensor,\n fg_mask: torch.Tensor,\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Compute IoU and DFL losses for rotated bounding boxes.\"\"\"\n weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)\n iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])\n loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum\n\n # DFL loss\n if self.dfl_loss:\n target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)\n loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight\n loss_dfl = loss_dfl.sum() / target_scores_sum\n else:\n loss_dfl = torch.tensor(0.0).to(pred_dist.device)\n\n return loss_iou, loss_dfl",
"chunk_type": "class",
"name": "RotatedBboxLoss",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 142,
"end_line": 172,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": "Criterion class for computing training losses for rotated bounding boxes.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.metrics.OKS_SIGMA",
"ultralytics.utils.ops.crop_mask",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.ops.xyxy2xywh",
"ultralytics.utils.tal.RotatedTaskAlignedAssigner",
"ultralytics.utils.tal.TaskAlignedAssigner",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.autocast",
"metrics.bbox_iou",
"metrics.probiou",
"tal.bbox2dist",
"BboxLoss"
],
"chunk_id": "class_RotatedBboxLoss_0b2451d8"
},
{
"content": "class KeypointLoss(nn.Module):\n \"\"\"Criterion class for computing keypoint losses.\"\"\"\n\n def __init__(self, sigmas: torch.Tensor) -> None:\n \"\"\"Initialize the KeypointLoss class with keypoint sigmas.\"\"\"\n super().__init__()\n self.sigmas = sigmas\n\n def forward(\n self, pred_kpts: torch.Tensor, gt_kpts: torch.Tensor, kpt_mask: torch.Tensor, area: torch.Tensor\n ) -> torch.Tensor:\n \"\"\"Calculate keypoint loss factor and Euclidean distance loss for keypoints.\"\"\"\n d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)\n kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)\n # e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula\n e = d / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2) # from cocoeval\n return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean()",
"chunk_type": "class",
"name": "KeypointLoss",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 175,
"end_line": 191,
"start_col": 0,
"end_col": 86,
"parent_name": null,
"docstring": "Criterion class for computing keypoint losses.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.metrics.OKS_SIGMA",
"ultralytics.utils.ops.crop_mask",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.ops.xyxy2xywh",
"ultralytics.utils.tal.RotatedTaskAlignedAssigner",
"ultralytics.utils.tal.TaskAlignedAssigner",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.autocast",
"metrics.bbox_iou",
"metrics.probiou",
"tal.bbox2dist",
"nn.Module"
],
"chunk_id": "class_KeypointLoss_a5a78e11"
},
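KeypointLoss can be exercised on its own with the COCO OKS sigmas; the instance count, coordinate range and visibility threshold below are arbitrary:

import torch
from ultralytics.utils.loss import KeypointLoss
from ultralytics.utils.metrics import OKS_SIGMA

kpt_loss = KeypointLoss(sigmas=torch.from_numpy(OKS_SIGMA))
n, k = 2, 17
gt_kpts = torch.rand(n, k, 3) * 64           # (x, y, visibility) per keypoint
pred_kpts = gt_kpts.clone()
pred_kpts[..., :2] += torch.randn(n, k, 2)   # perturb the predicted coordinates
kpt_mask = gt_kpts[..., 2] > 0.5             # keypoints that count as labelled
area = torch.full((n, 1), 64.0 * 64.0)       # per-instance box area
print(kpt_loss(pred_kpts, gt_kpts, kpt_mask, area))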
{
"content": "class v8DetectionLoss:\n \"\"\"Criterion class for computing training losses for YOLOv8 object detection.\"\"\"\n\n def __init__(self, model, tal_topk: int = 10): # model must be de-paralleled\n \"\"\"Initialize v8DetectionLoss with model parameters and task-aligned assignment settings.\"\"\"\n device = next(model.parameters()).device # get model device\n h = model.args # hyperparameters\n\n m = model.model[-1] # Detect() module\n self.bce = nn.BCEWithLogitsLoss(reduction=\"none\")\n self.hyp = h\n self.stride = m.stride # model strides\n self.nc = m.nc # number of classes\n self.no = m.nc + m.reg_max * 4\n self.reg_max = m.reg_max\n self.device = device\n\n self.use_dfl = m.reg_max > 1\n\n self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)\n self.bbox_loss = BboxLoss(m.reg_max).to(device)\n self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)\n\n def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:\n \"\"\"Preprocess targets by converting to tensor format and scaling coordinates.\"\"\"\n nl, ne = targets.shape\n if nl == 0:\n out = torch.zeros(batch_size, 0, ne - 1, device=self.device)\n else:\n i = targets[:, 0] # image index\n _, counts = i.unique(return_counts=True)\n counts = counts.to(dtype=torch.int32)\n out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)\n for j in range(batch_size):\n matches = i == j\n if n := matches.sum():\n out[j, :n] = targets[matches, 1:]\n out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))\n return out\n\n def bbox_decode(self, anchor_points: torch.Tensor, pred_dist: torch.Tensor) -> torch.Tensor:\n \"\"\"Decode predicted object bounding box coordinates from anchor points and distribution.\"\"\"\n if self.use_dfl:\n b, a, c = pred_dist.shape # batch, anchors, channels\n pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))\n # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))\n # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)\n return dist2bbox(pred_dist, anchor_points, xywh=False)\n\n def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Calculate the sum of the loss for box, cls and dfl multiplied by batch size.\"\"\"\n loss = torch.zeros(3, device=self.device) # box, cls, dfl\n feats = preds[1] if isinstance(preds, tuple) else preds\n pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(\n (self.reg_max * 4, self.nc), 1\n )\n\n pred_scores = pred_scores.permute(0, 2, 1).contiguous()\n pred_distri = pred_distri.permute(0, 2, 1).contiguous()\n\n dtype = pred_scores.dtype\n batch_size = pred_scores.shape[0]\n imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)\n anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)\n\n # Targets\n targets = torch.cat((batch[\"batch_idx\"].view(-1, 1), batch[\"cls\"].view(-1, 1), batch[\"bboxes\"]), 1)\n targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])\n gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy\n mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)\n\n # Pboxes\n pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)\n # dfl_conf = 
pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)\n # dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2\n\n _, target_bboxes, target_scores, fg_mask, _ = self.assigner(\n # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,\n pred_scores.detach().sigmoid(),\n (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),\n anchor_points * stride_tensor,\n gt_labels,\n gt_bboxes,\n mask_gt,\n )\n\n target_scores_sum = max(target_scores.sum(), 1)\n\n # Cls loss\n # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way\n loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE\n\n # Bbox loss\n if fg_mask.sum():\n target_bboxes /= stride_tensor\n loss[0], loss[2] = self.bbox_loss(\n pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask\n )\n\n loss[0] *= self.hyp.box # box gain\n loss[1] *= self.hyp.cls # cls gain\n loss[2] *= self.hyp.dfl # dfl gain\n\n return loss * batch_size, loss.detach() # loss(box, cls, dfl)",
"chunk_type": "class",
"name": "v8DetectionLoss",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 194,
"end_line": 297,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": "Criterion class for computing training losses for YOLOv8 object detection.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.metrics.OKS_SIGMA",
"ultralytics.utils.ops.crop_mask",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.ops.xyxy2xywh",
"ultralytics.utils.tal.RotatedTaskAlignedAssigner",
"ultralytics.utils.tal.TaskAlignedAssigner",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.autocast",
"metrics.bbox_iou",
"metrics.probiou",
"tal.bbox2dist"
],
"chunk_id": "class_v8DetectionLoss_0f5de2a7"
},
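Building a full v8DetectionLoss requires a de-paralleled model, so as a lighter illustration here is a standalone sketch of just the ragged-to-padded conversion that its preprocess method performs (preprocess_targets is a hypothetical name for this note, not a library function):

import torch
from ultralytics.utils.ops import xywh2xyxy

def preprocess_targets(targets, batch_size, scale_tensor):
    # targets rows are (batch_idx, cls, x, y, w, h) with normalized xywh boxes
    nl, ne = targets.shape
    if nl == 0:
        return torch.zeros(batch_size, 0, ne - 1)
    i = targets[:, 0]                                  # image index per label row
    _, counts = i.unique(return_counts=True)
    out = torch.zeros(batch_size, int(counts.max()), ne - 1)
    for j in range(batch_size):
        matches = i == j
        if n := matches.sum():
            out[j, :n] = targets[matches, 1:]
    out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))  # scale, then convert to xyxy
    return out

targets = torch.tensor([[0, 1, 0.5, 0.5, 0.2, 0.2],    # image 0, class 1
                        [0, 0, 0.3, 0.3, 0.1, 0.1],    # image 0, class 0
                        [1, 2, 0.6, 0.4, 0.4, 0.2]])   # image 1, class 2
padded = preprocess_targets(targets, batch_size=2, scale_tensor=torch.tensor([640.0, 640.0, 640.0, 640.0]))
print(padded.shape)  # torch.Size([2, 2, 5]) -> padded to the per-image maximum of 2 boxes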
{
"content": "class v8SegmentationLoss(v8DetectionLoss):\n \"\"\"Criterion class for computing training losses for YOLOv8 segmentation.\"\"\"\n\n def __init__(self, model): # model must be de-paralleled\n \"\"\"Initialize the v8SegmentationLoss class with model parameters and mask overlap setting.\"\"\"\n super().__init__(model)\n self.overlap = model.args.overlap_mask\n\n def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Calculate and return the combined loss for detection and segmentation.\"\"\"\n loss = torch.zeros(4, device=self.device) # box, seg, cls, dfl\n feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]\n batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width\n pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(\n (self.reg_max * 4, self.nc), 1\n )\n\n # B, grids, ..\n pred_scores = pred_scores.permute(0, 2, 1).contiguous()\n pred_distri = pred_distri.permute(0, 2, 1).contiguous()\n pred_masks = pred_masks.permute(0, 2, 1).contiguous()\n\n dtype = pred_scores.dtype\n imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)\n anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)\n\n # Targets\n try:\n batch_idx = batch[\"batch_idx\"].view(-1, 1)\n targets = torch.cat((batch_idx, batch[\"cls\"].view(-1, 1), batch[\"bboxes\"]), 1)\n targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])\n gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy\n mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)\n except RuntimeError as e:\n raise TypeError(\n \"ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\\n\"\n \"This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, \"\n \"i.e. 
'yolo train model=yolo11n-seg.pt data=coco8.yaml'.\\nVerify your dataset is a \"\n \"correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' \"\n \"as an example.\\nSee https://docs.ultralytics.com/datasets/segment/ for help.\"\n ) from e\n\n # Pboxes\n pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)\n\n _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(\n pred_scores.detach().sigmoid(),\n (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),\n anchor_points * stride_tensor,\n gt_labels,\n gt_bboxes,\n mask_gt,\n )\n\n target_scores_sum = max(target_scores.sum(), 1)\n\n # Cls loss\n # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way\n loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE\n\n if fg_mask.sum():\n # Bbox loss\n loss[0], loss[3] = self.bbox_loss(\n pred_distri,\n pred_bboxes,\n anchor_points,\n target_bboxes / stride_tensor,\n target_scores,\n target_scores_sum,\n fg_mask,\n )\n # Masks loss\n masks = batch[\"masks\"].to(self.device).float()\n if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample\n masks = F.interpolate(masks[None], (mask_h, mask_w), mode=\"nearest\")[0]\n\n loss[1] = self.calculate_segmentation_loss(\n fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap\n )\n\n # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove\n else:\n loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss\n\n loss[0] *= self.hyp.box # box gain\n loss[1] *= self.hyp.box # seg gain\n loss[2] *= self.hyp.cls # cls gain\n loss[3] *= self.hyp.dfl # dfl gain\n\n return loss * batch_size, loss.detach() # loss(box, cls, dfl)\n\n @staticmethod\n def single_mask_loss(\n gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor\n ) -> torch.Tensor:\n \"\"\"\n Compute the instance segmentation loss for a single image.\n\n Args:\n gt_mask (torch.Tensor): Ground truth mask of shape (N, H, W), where N is the number of objects.\n pred (torch.Tensor): Predicted mask coefficients of shape (N, 32).\n proto (torch.Tensor): Prototype masks of shape (32, H, W).\n xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (N, 4).\n area (torch.Tensor): Area of each ground truth bounding box of shape (N,).\n\n Returns:\n (torch.Tensor): The calculated mask loss for a single image.\n\n Notes:\n The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the\n predicted masks from the prototype masks and predicted mask coefficients.\n \"\"\"\n pred_mask = torch.einsum(\"in,nhw->ihw\", pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80)\n loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction=\"none\")\n return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()\n\n def calculate_segmentation_loss(\n self,\n fg_mask: torch.Tensor,\n masks: torch.Tensor,\n target_gt_idx: torch.Tensor,\n target_bboxes: torch.Tensor,\n batch_idx: torch.Tensor,\n proto: torch.Tensor,\n pred_masks: torch.Tensor,\n imgsz: torch.Tensor,\n overlap: bool,\n ) -> torch.Tensor:\n \"\"\"\n Calculate the loss for instance segmentation.\n\n Args:\n fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.\n masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is 
False, otherwise (BS, ?, H, W).\n target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).\n target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).\n batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).\n proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).\n pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).\n imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).\n overlap (bool): Whether the masks in `masks` tensor overlap.\n\n Returns:\n (torch.Tensor): The calculated loss for instance segmentation.\n\n Notes:\n The batch loss can be computed for improved speed at higher memory usage.\n For example, pred_mask can be computed as follows:\n pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (i, 32) @ (32, 160, 160) -> (i, 160, 160)\n \"\"\"\n _, _, mask_h, mask_w = proto.shape\n loss = 0\n\n # Normalize to 0-1\n target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]\n\n # Areas of target bboxes\n marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2)\n\n # Normalize to mask size\n mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)\n\n for i, single_i in enumerate(zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)):\n fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i\n if fg_mask_i.any():\n mask_idx = target_gt_idx_i[fg_mask_i]\n if overlap:\n gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)\n gt_mask = gt_mask.float()\n else:\n gt_mask = masks[batch_idx.view(-1) == i][mask_idx]\n\n loss += self.single_mask_loss(\n gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]\n )\n\n # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove\n else:\n loss += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss\n\n return loss / fg_mask.sum()",
"chunk_type": "class",
"name": "v8SegmentationLoss",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 300,
"end_line": 480,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": "Criterion class for computing training losses for YOLOv8 segmentation.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.metrics.OKS_SIGMA",
"ultralytics.utils.ops.crop_mask",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.ops.xyxy2xywh",
"ultralytics.utils.tal.RotatedTaskAlignedAssigner",
"ultralytics.utils.tal.TaskAlignedAssigner",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.autocast",
"metrics.bbox_iou",
"metrics.probiou",
"tal.bbox2dist",
"v8DetectionLoss"
],
"chunk_id": "class_v8SegmentationLoss_3732945c"
},
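single_mask_loss is a static method, so the prototype/coefficient contraction can be sanity-checked without building the full segmentation criterion. The boxes are in mask-pixel coordinates and the areas are normalized box areas, mirroring how calculate_segmentation_loss passes mxyxy and marea; all values here are synthetic:

import torch
from ultralytics.utils.loss import v8SegmentationLoss

n, h, w = 2, 32, 32                               # instances, prototype height/width
gt_mask = (torch.rand(n, h, w) > 0.5).float()     # binary ground-truth masks
pred = torch.randn(n, 32)                         # 32 mask coefficients per instance
proto = torch.randn(32, h, w)                     # prototype masks
xyxy = torch.tensor([[4.0, 4.0, 20.0, 20.0],
                     [8.0, 8.0, 28.0, 28.0]])     # boxes in mask-pixel coordinates
area = torch.tensor([0.25, 0.39])                 # normalized box areas (w * h in 0-1 space)
print(v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area))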
{
"content": "class v8PoseLoss(v8DetectionLoss):\n \"\"\"Criterion class for computing training losses for YOLOv8 pose estimation.\"\"\"\n\n def __init__(self, model): # model must be de-paralleled\n \"\"\"Initialize v8PoseLoss with model parameters and keypoint-specific loss functions.\"\"\"\n super().__init__(model)\n self.kpt_shape = model.model[-1].kpt_shape\n self.bce_pose = nn.BCEWithLogitsLoss()\n is_pose = self.kpt_shape == [17, 3]\n nkpt = self.kpt_shape[0] # number of keypoints\n sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt\n self.keypoint_loss = KeypointLoss(sigmas=sigmas)\n\n def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Calculate the total loss and detach it for pose estimation.\"\"\"\n loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility\n feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]\n pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(\n (self.reg_max * 4, self.nc), 1\n )\n\n # B, grids, ..\n pred_scores = pred_scores.permute(0, 2, 1).contiguous()\n pred_distri = pred_distri.permute(0, 2, 1).contiguous()\n pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()\n\n dtype = pred_scores.dtype\n imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)\n anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)\n\n # Targets\n batch_size = pred_scores.shape[0]\n batch_idx = batch[\"batch_idx\"].view(-1, 1)\n targets = torch.cat((batch_idx, batch[\"cls\"].view(-1, 1), batch[\"bboxes\"]), 1)\n targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])\n gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy\n mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)\n\n # Pboxes\n pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)\n pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3)\n\n _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(\n pred_scores.detach().sigmoid(),\n (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),\n anchor_points * stride_tensor,\n gt_labels,\n gt_bboxes,\n mask_gt,\n )\n\n target_scores_sum = max(target_scores.sum(), 1)\n\n # Cls loss\n # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way\n loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE\n\n # Bbox loss\n if fg_mask.sum():\n target_bboxes /= stride_tensor\n loss[0], loss[4] = self.bbox_loss(\n pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask\n )\n keypoints = batch[\"keypoints\"].to(self.device).float().clone()\n keypoints[..., 0] *= imgsz[1]\n keypoints[..., 1] *= imgsz[0]\n\n loss[1], loss[2] = self.calculate_keypoints_loss(\n fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts\n )\n\n loss[0] *= self.hyp.box # box gain\n loss[1] *= self.hyp.pose # pose gain\n loss[2] *= self.hyp.kobj # kobj gain\n loss[3] *= self.hyp.cls # cls gain\n loss[4] *= self.hyp.dfl # dfl gain\n\n return loss * batch_size, loss.detach() # loss(box, cls, dfl)\n\n @staticmethod\n def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:\n \"\"\"Decode predicted keypoints to image 
coordinates.\"\"\"\n y = pred_kpts.clone()\n y[..., :2] *= 2.0\n y[..., 0] += anchor_points[:, [0]] - 0.5\n y[..., 1] += anchor_points[:, [1]] - 0.5\n return y\n\n def calculate_keypoints_loss(\n self,\n masks: torch.Tensor,\n target_gt_idx: torch.Tensor,\n keypoints: torch.Tensor,\n batch_idx: torch.Tensor,\n stride_tensor: torch.Tensor,\n target_bboxes: torch.Tensor,\n pred_kpts: torch.Tensor,\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"\n Calculate the keypoints loss for the model.\n\n This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is\n based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is\n a binary classification loss that classifies whether a keypoint is present or not.\n\n Args:\n masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).\n target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).\n keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).\n batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).\n stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).\n target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).\n pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).\n\n Returns:\n kpts_loss (torch.Tensor): The keypoints loss.\n kpts_obj_loss (torch.Tensor): The keypoints object loss.\n \"\"\"\n batch_idx = batch_idx.flatten()\n batch_size = len(masks)\n\n # Find the maximum number of keypoints in a single image\n max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()\n\n # Create a tensor to hold batched keypoints\n batched_keypoints = torch.zeros(\n (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device\n )\n\n # TODO: any idea how to vectorize this?\n # Fill batched_keypoints with keypoints based on batch_idx\n for i in range(batch_size):\n keypoints_i = keypoints[batch_idx == i]\n batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i\n\n # Expand dimensions of target_gt_idx to match the shape of batched_keypoints\n target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)\n\n # Use target_gt_idx_expanded to select keypoints from batched_keypoints\n selected_keypoints = batched_keypoints.gather(\n 1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])\n )\n\n # Divide coordinates by stride\n selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)\n\n kpts_loss = 0\n kpts_obj_loss = 0\n\n if masks.any():\n gt_kpt = selected_keypoints[masks]\n area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)\n pred_kpt = pred_kpts[masks]\n kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)\n kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area) # pose loss\n\n if pred_kpt.shape[-1] == 3:\n kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float()) # keypoint obj loss\n\n return kpts_loss, kpts_obj_loss",
"chunk_type": "class",
"name": "v8PoseLoss",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 483,
"end_line": 642,
"start_col": 0,
"end_col": 39,
"parent_name": null,
"docstring": "Criterion class for computing training losses for YOLOv8 pose estimation.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.metrics.OKS_SIGMA",
"ultralytics.utils.ops.crop_mask",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.ops.xyxy2xywh",
"ultralytics.utils.tal.RotatedTaskAlignedAssigner",
"ultralytics.utils.tal.TaskAlignedAssigner",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.autocast",
"metrics.bbox_iou",
"metrics.probiou",
"tal.bbox2dist",
"v8DetectionLoss"
],
"chunk_id": "class_v8PoseLoss_298674cd"
},
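The kpts_decode step in the chunk above maps raw keypoint offsets to feature-map coordinates by doubling the x/y offsets and shifting them onto the anchor centers. A minimal standalone sketch of that same arithmetic on synthetic tensors (the shapes and anchor values below are illustrative assumptions, not taken from a real model):

import torch

def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
    # Same math as v8PoseLoss.kpts_decode: scale offsets by 2 and shift onto anchor centers.
    y = pred_kpts.clone()
    y[..., :2] *= 2.0
    y[..., 0] += anchor_points[:, [0]] - 0.5
    y[..., 1] += anchor_points[:, [1]] - 0.5
    return y

anchors = torch.tensor([[0.5, 0.5], [1.5, 0.5]])   # (num_anchors, 2) grid-cell centers
raw = torch.zeros(1, 2, 17, 3)                     # (batch, num_anchors, 17 keypoints, x/y/visibility)
decoded = kpts_decode(anchors, raw)
print(decoded[0, :, 0, :2])                        # zero offsets decode to (anchor_x - 0.5, anchor_y - 0.5)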
{
"content": "class v8ClassificationLoss:\n \"\"\"Criterion class for computing training losses for classification.\"\"\"\n\n def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Compute the classification loss between predictions and true labels.\"\"\"\n preds = preds[1] if isinstance(preds, (list, tuple)) else preds\n loss = F.cross_entropy(preds, batch[\"cls\"], reduction=\"mean\")\n return loss, loss.detach()",
"chunk_type": "class",
"name": "v8ClassificationLoss",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 645,
"end_line": 652,
"start_col": 0,
"end_col": 34,
"parent_name": null,
"docstring": "Criterion class for computing training losses for classification.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.metrics.OKS_SIGMA",
"ultralytics.utils.ops.crop_mask",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.ops.xyxy2xywh",
"ultralytics.utils.tal.RotatedTaskAlignedAssigner",
"ultralytics.utils.tal.TaskAlignedAssigner",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.autocast",
"metrics.bbox_iou",
"metrics.probiou",
"tal.bbox2dist"
],
"chunk_id": "class_v8ClassificationLoss_649f5dc3"
},
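v8ClassificationLoss reduces to a plain cross-entropy between class logits and integer labels, returning the loss plus a detached copy for logging. A self-contained sketch of that computation (the random tensors are purely illustrative):

import torch
import torch.nn.functional as F

logits = torch.randn(8, 80)                  # (batch, num_classes) raw class scores
labels = torch.randint(0, 80, (8,))          # ground-truth class indices

loss = F.cross_entropy(logits, labels, reduction="mean")
loss_items = loss.detach()                   # detached copy, as returned for logging
print(float(loss), float(loss_items))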
{
"content": "class v8OBBLoss(v8DetectionLoss):\n \"\"\"Calculates losses for object detection, classification, and box distribution in rotated YOLO models.\"\"\"\n\n def __init__(self, model):\n \"\"\"Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled.\"\"\"\n super().__init__(model)\n self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)\n self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)\n\n def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:\n \"\"\"Preprocess targets for oriented bounding box detection.\"\"\"\n if targets.shape[0] == 0:\n out = torch.zeros(batch_size, 0, 6, device=self.device)\n else:\n i = targets[:, 0] # image index\n _, counts = i.unique(return_counts=True)\n counts = counts.to(dtype=torch.int32)\n out = torch.zeros(batch_size, counts.max(), 6, device=self.device)\n for j in range(batch_size):\n matches = i == j\n if n := matches.sum():\n bboxes = targets[matches, 2:]\n bboxes[..., :4].mul_(scale_tensor)\n out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)\n return out\n\n def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Calculate and return the loss for oriented bounding box detection.\"\"\"\n loss = torch.zeros(3, device=self.device) # box, cls, dfl\n feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]\n batch_size = pred_angle.shape[0] # batch size, number of masks, mask height, mask width\n pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(\n (self.reg_max * 4, self.nc), 1\n )\n\n # b, grids, ..\n pred_scores = pred_scores.permute(0, 2, 1).contiguous()\n pred_distri = pred_distri.permute(0, 2, 1).contiguous()\n pred_angle = pred_angle.permute(0, 2, 1).contiguous()\n\n dtype = pred_scores.dtype\n imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)\n anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)\n\n # targets\n try:\n batch_idx = batch[\"batch_idx\"].view(-1, 1)\n targets = torch.cat((batch_idx, batch[\"cls\"].view(-1, 1), batch[\"bboxes\"].view(-1, 5)), 1)\n rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()\n targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training\n targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])\n gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr\n mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)\n except RuntimeError as e:\n raise TypeError(\n \"ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\\n\"\n \"This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, \"\n \"i.e. 
'yolo train model=yolo11n-obb.pt data=coco8.yaml'.\\nVerify your dataset is a \"\n \"correctly formatted 'OBB' dataset using 'data=dota8.yaml' \"\n \"as an example.\\nSee https://docs.ultralytics.com/datasets/obb/ for help.\"\n ) from e\n\n # Pboxes\n pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle) # xyxy, (b, h*w, 4)\n\n bboxes_for_assigner = pred_bboxes.clone().detach()\n # Only the first four elements need to be scaled\n bboxes_for_assigner[..., :4] *= stride_tensor\n _, target_bboxes, target_scores, fg_mask, _ = self.assigner(\n pred_scores.detach().sigmoid(),\n bboxes_for_assigner.type(gt_bboxes.dtype),\n anchor_points * stride_tensor,\n gt_labels,\n gt_bboxes,\n mask_gt,\n )\n\n target_scores_sum = max(target_scores.sum(), 1)\n\n # Cls loss\n # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way\n loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE\n\n # Bbox loss\n if fg_mask.sum():\n target_bboxes[..., :4] /= stride_tensor\n loss[0], loss[2] = self.bbox_loss(\n pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask\n )\n else:\n loss[0] += (pred_angle * 0).sum()\n\n loss[0] *= self.hyp.box # box gain\n loss[1] *= self.hyp.cls # cls gain\n loss[2] *= self.hyp.dfl # dfl gain\n\n return loss * batch_size, loss.detach() # loss(box, cls, dfl)\n\n def bbox_decode(\n self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor\n ) -> torch.Tensor:\n \"\"\"\n Decode predicted object bounding box coordinates from anchor points and distribution.\n\n Args:\n anchor_points (torch.Tensor): Anchor points, (h*w, 2).\n pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).\n pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).\n\n Returns:\n (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).\n \"\"\"\n if self.use_dfl:\n b, a, c = pred_dist.shape # batch, anchors, channels\n pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))\n return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)",
"chunk_type": "class",
"name": "v8OBBLoss",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 655,
"end_line": 770,
"start_col": 0,
"end_col": 95,
"parent_name": null,
"docstring": "Calculates losses for object detection, classification, and box distribution in rotated YOLO models.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.metrics.OKS_SIGMA",
"ultralytics.utils.ops.crop_mask",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.ops.xyxy2xywh",
"ultralytics.utils.tal.RotatedTaskAlignedAssigner",
"ultralytics.utils.tal.TaskAlignedAssigner",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.autocast",
"metrics.bbox_iou",
"metrics.probiou",
"tal.bbox2dist",
"v8DetectionLoss"
],
"chunk_id": "class_v8OBBLoss_9a1d2a7d"
},
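The preprocess step in v8OBBLoss groups flat (image_idx, cls, cx, cy, w, h, angle) target rows by image and pads them to a common per-image length. Below is a simplified standalone sketch of that padding logic; it skips the coordinate scaling done in the original, and the toy target values are invented for illustration:

import torch

def pad_targets(targets: torch.Tensor, batch_size: int) -> torch.Tensor:
    # targets: (num_objects, 7) rows of [image_idx, cls, cx, cy, w, h, angle]
    if targets.shape[0] == 0:
        return torch.zeros(batch_size, 0, 6)
    idx = targets[:, 0]
    counts = idx.unique(return_counts=True)[1]
    out = torch.zeros(batch_size, int(counts.max()), 6)
    for j in range(batch_size):
        rows = targets[idx == j]
        out[j, : rows.shape[0]] = rows[:, 1:]   # keep [cls, cx, cy, w, h, angle], zero-pad the rest
    return out

toy = torch.tensor(
    [[0, 1, 10.0, 10.0, 4.0, 2.0, 0.3],
     [0, 2, 20.0, 15.0, 6.0, 3.0, 0.0],
     [1, 0, 5.0, 5.0, 2.0, 2.0, 1.1]]
)
print(pad_targets(toy, batch_size=2).shape)     # torch.Size([2, 2, 6])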
{
"content": "class E2EDetectLoss:\n \"\"\"Criterion class for computing training losses for end-to-end detection.\"\"\"\n\n def __init__(self, model):\n \"\"\"Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model.\"\"\"\n self.one2many = v8DetectionLoss(model, tal_topk=10)\n self.one2one = v8DetectionLoss(model, tal_topk=1)\n\n def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Calculate the sum of the loss for box, cls and dfl multiplied by batch size.\"\"\"\n preds = preds[1] if isinstance(preds, tuple) else preds\n one2many = preds[\"one2many\"]\n loss_one2many = self.one2many(one2many, batch)\n one2one = preds[\"one2one\"]\n loss_one2one = self.one2one(one2one, batch)\n return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]",
"chunk_type": "class",
"name": "E2EDetectLoss",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 773,
"end_line": 788,
"start_col": 0,
"end_col": 85,
"parent_name": null,
"docstring": "Criterion class for computing training losses for end-to-end detection.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.metrics.OKS_SIGMA",
"ultralytics.utils.ops.crop_mask",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.ops.xyxy2xywh",
"ultralytics.utils.tal.RotatedTaskAlignedAssigner",
"ultralytics.utils.tal.TaskAlignedAssigner",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.autocast",
"metrics.bbox_iou",
"metrics.probiou",
"tal.bbox2dist"
],
"chunk_id": "class_E2EDetectLoss_584a97b6"
},
{
"content": "class TVPDetectLoss:\n \"\"\"Criterion class for computing training losses for text-visual prompt detection.\"\"\"\n\n def __init__(self, model):\n \"\"\"Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model.\"\"\"\n self.vp_criterion = v8DetectionLoss(model)\n # NOTE: store following info as it's changeable in __call__\n self.ori_nc = self.vp_criterion.nc\n self.ori_no = self.vp_criterion.no\n self.ori_reg_max = self.vp_criterion.reg_max\n\n def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Calculate the loss for text-visual prompt detection.\"\"\"\n feats = preds[1] if isinstance(preds, tuple) else preds\n assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it\n\n if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:\n loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)\n return loss, loss.detach()\n\n vp_feats = self._get_vp_features(feats)\n vp_loss = self.vp_criterion(vp_feats, batch)\n box_loss = vp_loss[0][1]\n return box_loss, vp_loss[1]\n\n def _get_vp_features(self, feats: List[torch.Tensor]) -> List[torch.Tensor]:\n \"\"\"Extract visual-prompt features from the model output.\"\"\"\n vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc\n\n self.vp_criterion.nc = vnc\n self.vp_criterion.no = vnc + self.vp_criterion.reg_max * 4\n self.vp_criterion.assigner.num_classes = vnc\n\n return [\n torch.cat((box, cls_vp), dim=1)\n for box, _, cls_vp in [xi.split((self.ori_reg_max * 4, self.ori_nc, vnc), dim=1) for xi in feats]\n ]",
"chunk_type": "class",
"name": "TVPDetectLoss",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 791,
"end_line": 827,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "Criterion class for computing training losses for text-visual prompt detection.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.metrics.OKS_SIGMA",
"ultralytics.utils.ops.crop_mask",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.ops.xyxy2xywh",
"ultralytics.utils.tal.RotatedTaskAlignedAssigner",
"ultralytics.utils.tal.TaskAlignedAssigner",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.autocast",
"metrics.bbox_iou",
"metrics.probiou",
"tal.bbox2dist"
],
"chunk_id": "class_TVPDetectLoss_3756ee0d"
},
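TVPDetectLoss._get_vp_features re-partitions each feature map's channel dimension into (box regression, original text classes, visual-prompt classes) and keeps only the box and visual-prompt slices. A toy illustration of that channel split; all sizes here are made-up assumptions:

import torch

reg_max, ori_nc, vnc = 16, 80, 4                             # DFL bins * 4 sides, text classes, visual-prompt classes
feat = torch.randn(2, reg_max * 4 + ori_nc + vnc, 40, 40)    # (batch, channels, h, w)

box, cls_text, cls_vp = feat.split((reg_max * 4, ori_nc, vnc), dim=1)
vp_feat = torch.cat((box, cls_vp), dim=1)                    # drop the text-class channels
print(vp_feat.shape)                                          # torch.Size([2, 68, 40, 40])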
{
"content": "class TVPSegmentLoss(TVPDetectLoss):\n \"\"\"Criterion class for computing training losses for text-visual prompt segmentation.\"\"\"\n\n def __init__(self, model):\n \"\"\"Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model.\"\"\"\n super().__init__(model)\n self.vp_criterion = v8SegmentationLoss(model)\n\n def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Calculate the loss for text-visual prompt segmentation.\"\"\"\n feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]\n assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it\n\n if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:\n loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)\n return loss, loss.detach()\n\n vp_feats = self._get_vp_features(feats)\n vp_loss = self.vp_criterion((vp_feats, pred_masks, proto), batch)\n cls_loss = vp_loss[0][2]\n return cls_loss, vp_loss[1]",
"chunk_type": "class",
"name": "TVPSegmentLoss",
"file_path": "ultralytics\\ultralytics\\utils\\loss.py",
"start_line": 830,
"end_line": 850,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": "Criterion class for computing training losses for text-visual prompt segmentation.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.metrics.OKS_SIGMA",
"ultralytics.utils.ops.crop_mask",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.ops.xyxy2xywh",
"ultralytics.utils.tal.RotatedTaskAlignedAssigner",
"ultralytics.utils.tal.TaskAlignedAssigner",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.autocast",
"metrics.bbox_iou",
"metrics.probiou",
"tal.bbox2dist",
"TVPDetectLoss"
],
"chunk_id": "class_TVPSegmentLoss_36d5dc11"
},
{
"content": "import math",
"chunk_type": "import",
"name": "math",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_math_e72c7ed4"
},
{
"content": "import warnings",
"chunk_type": "import",
"name": "warnings",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_warnings_6a87eafa"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_5c08fd5b"
},
{
"content": "from typing import Any, Dict, List, Tuple, Union",
"chunk_type": "import",
"name": "Any, Dict, List, Tuple, Union",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 48,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, List, Tuple, Union_6ad3df08"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_bfe3320a"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_855bdd3f"
},
{
"content": "from ultralytics.utils import LOGGER, DataExportMixin, SimpleClass, TryExcept, checks, plt_settings",
"chunk_type": "import",
"name": "LOGGER, DataExportMixin, SimpleClass, TryExcept, checks, plt_settings",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 99,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER, DataExportMixin, SimpleClass, TryExcept, checks, plt_settings_95373bc3"
},
{
"content": "OKS_SIGMA = (\n np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])\n / 10.0\n)",
"chunk_type": "variable",
"name": "OKS_SIGMA",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 14,
"end_line": 17,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_OKS_SIGMA_162900b5"
},
{
"content": "def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = False, eps: float = 1e-7) -> np.ndarray:\n \"\"\"\n Calculate the intersection over box2 area given box1 and box2.\n\n Args:\n box1 (np.ndarray): A numpy array of shape (N, 4) representing N bounding boxes in x1y1x2y2 format.\n box2 (np.ndarray): A numpy array of shape (M, 4) representing M bounding boxes in x1y1x2y2 format.\n iou (bool, optional): Calculate the standard IoU if True else return inter_area/box2_area.\n eps (float, optional): A small value to avoid division by zero.\n\n Returns:\n (np.ndarray): A numpy array of shape (N, M) representing the intersection over box2 area.\n \"\"\"\n # Get the coordinates of bounding boxes\n b1_x1, b1_y1, b1_x2, b1_y2 = box1.T\n b2_x1, b2_y1, b2_x2, b2_y2 = box2.T\n\n # Intersection area\n inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * (\n np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)\n ).clip(0)\n\n # Box2 area\n area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)\n if iou:\n box1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)\n area = area + box1_area[:, None] - inter_area\n\n # Intersection over box2 area\n return inter_area / (area + eps)",
"chunk_type": "function",
"name": "bbox_ioa",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 20,
"end_line": 49,
"start_col": 0,
"end_col": 36,
"parent_name": null,
"docstring": "Calculate the intersection over box2 area given box1 and box2.\n\nArgs:\n box1 (np.ndarray): A numpy array of shape (N, 4) representing N bounding boxes in x1y1x2y2 format.\n box2 (np.ndarray): A numpy array of shape (M, 4) representing M bounding boxes in x1y1x2y2 format.\n iou (bool, optional): Calculate the standard IoU if True else return inter_area/box2_area.\n eps (float, optional): A small value to avoid division by zero.\n\nReturns:\n (np.ndarray): A numpy array of shape (N, M) representing the intersection over box2 area.",
"parameters": [
"box1: np.ndarray",
"box2: np.ndarray",
"iou: bool",
"eps: float"
],
"return_type": "np.ndarray",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re"
],
"chunk_id": "function_bbox_ioa_55f2214d"
},
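Since bbox_ioa is a pure NumPy routine, it is easy to sanity-check on two hand-made boxes. The import path below follows the file_path recorded for this chunk; the box coordinates are invented:

import numpy as np
from ultralytics.utils.metrics import bbox_ioa

box1 = np.array([[0.0, 0.0, 10.0, 10.0]])    # one 10x10 box
box2 = np.array([[5.0, 5.0, 15.0, 15.0]])    # overlaps its bottom-right quarter

print(bbox_ioa(box1, box2))                   # intersection / box2 area = 25 / 100 = 0.25
print(bbox_ioa(box1, box2, iou=True))         # standard IoU = 25 / 175 ≈ 0.143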
{
"content": "def box_iou(box1: torch.Tensor, box2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:\n \"\"\"\n Calculate intersection-over-union (IoU) of boxes.\n\n Args:\n box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes in (x1, y1, x2, y2) format.\n box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes in (x1, y1, x2, y2) format.\n eps (float, optional): A small value to avoid division by zero.\n\n Returns:\n (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.\n\n References:\n https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py\n \"\"\"\n # NOTE: Need .float() to get accurate iou values\n # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)\n (a1, a2), (b1, b2) = box1.float().unsqueeze(1).chunk(2, 2), box2.float().unsqueeze(0).chunk(2, 2)\n inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2)\n\n # IoU = inter / (area1 + area2 - inter)\n return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)",
"chunk_type": "function",
"name": "box_iou",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 52,
"end_line": 73,
"start_col": 0,
"end_col": 72,
"parent_name": null,
"docstring": "Calculate intersection-over-union (IoU) of boxes.\n\nArgs:\n box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes in (x1, y1, x2, y2) format.\n box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes in (x1, y1, x2, y2) format.\n eps (float, optional): A small value to avoid division by zero.\n\nReturns:\n (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.\n\nReferences:\n https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py",
"parameters": [
"box1: torch.Tensor",
"box2: torch.Tensor",
"eps: float"
],
"return_type": "torch.Tensor",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re"
],
"chunk_id": "function_box_iou_6f011eb8"
},
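A quick usage sketch for box_iou on two small box sets, producing the pairwise N x M IoU matrix described in its docstring (coordinates chosen arbitrarily):

import torch
from ultralytics.utils.metrics import box_iou

gt = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                   [20.0, 20.0, 30.0, 30.0]])        # (N, 4) xyxy ground truth
pred = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                     [25.0, 25.0, 35.0, 35.0]])      # (M, 4) xyxy predictions

iou = box_iou(gt, pred)       # (2, 2): row i compares GT box i against every prediction
print(iou)                    # diagonal ≈ [1.0, 25 / 175 ≈ 0.143], off-diagonal 0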
{
"content": "def bbox_iou(\n box1: torch.Tensor,\n box2: torch.Tensor,\n xywh: bool = True,\n GIoU: bool = False,\n DIoU: bool = False,\n CIoU: bool = False,\n eps: float = 1e-7,\n) -> torch.Tensor:\n \"\"\"\n Calculate the Intersection over Union (IoU) between bounding boxes.\n\n This function supports various shapes for `box1` and `box2` as long as the last dimension is 4.\n For instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4).\n Internally, the code will split the last dimension into (x, y, w, h) if `xywh=True`,\n or (x1, y1, x2, y2) if `xywh=False`.\n\n Args:\n box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.\n box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.\n xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in\n (x1, y1, x2, y2) format.\n GIoU (bool, optional): If True, calculate Generalized IoU.\n DIoU (bool, optional): If True, calculate Distance IoU.\n CIoU (bool, optional): If True, calculate Complete IoU.\n eps (float, optional): A small value to avoid division by zero.\n\n Returns:\n (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.\n \"\"\"\n # Get the coordinates of bounding boxes\n if xywh: # transform from xywh to xyxy\n (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)\n w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2\n b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_\n b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_\n else: # x1, y1, x2, y2 = box1\n b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)\n b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)\n w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps\n w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps\n\n # Intersection area\n inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * (\n b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)\n ).clamp_(0)\n\n # Union Area\n union = w1 * h1 + w2 * h2 - inter + eps\n\n # IoU\n iou = inter / union\n if CIoU or DIoU or GIoU:\n cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width\n ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height\n if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1\n c2 = cw.pow(2) + ch.pow(2) + eps # convex diagonal squared\n rho2 = (\n (b2_x1 + b2_x2 - b1_x1 - b1_x2).pow(2) + (b2_y1 + b2_y2 - b1_y1 - b1_y2).pow(2)\n ) / 4 # center dist**2\n if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47\n v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)\n with torch.no_grad():\n alpha = v / (v - iou + (1 + eps))\n return iou - (rho2 / c2 + v * alpha) # CIoU\n return iou - rho2 / c2 # DIoU\n c_area = cw * ch + eps # convex area\n return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf\n return iou # IoU",
"chunk_type": "function",
"name": "bbox_iou",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 76,
"end_line": 144,
"start_col": 0,
"end_col": 14,
"parent_name": null,
"docstring": "Calculate the Intersection over Union (IoU) between bounding boxes.\n\nThis function supports various shapes for `box1` and `box2` as long as the last dimension is 4.\nFor instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4).\nInternally, the code will split the last dimension into (x, y, w, h) if `xywh=True`,\nor (x1, y1, x2, y2) if `xywh=False`.\n\nArgs:\n box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.\n box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.\n xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in\n (x1, y1, x2, y2) format.\n GIoU (bool, optional): If True, calculate Generalized IoU.\n DIoU (bool, optional): If True, calculate Distance IoU.\n CIoU (bool, optional): If True, calculate Complete IoU.\n eps (float, optional): A small value to avoid division by zero.\n\nReturns:\n (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.",
"parameters": [
"box1: torch.Tensor",
"box2: torch.Tensor",
"xywh: bool",
"GIoU: bool",
"DIoU: bool",
"CIoU: bool",
"eps: float"
],
"return_type": "torch.Tensor",
"decorators": [],
"complexity_score": 5,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re"
],
"chunk_id": "function_bbox_iou_459ddd01"
},
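Because bbox_iou broadcasts over the last dimension, it can also score matched box pairs element-wise. The sketch below compares plain IoU and CIoU for center-format boxes; the values are arbitrary and the exact CIoU number is not asserted:

import torch
from ultralytics.utils.metrics import bbox_iou

b1 = torch.tensor([[50.0, 50.0, 20.0, 20.0]])    # (cx, cy, w, h) center format
b2 = torch.tensor([[55.0, 55.0, 20.0, 20.0]])    # same size, shifted by (5, 5)

print(bbox_iou(b1, b2, xywh=True))                # plain IoU of the matched pair
print(bbox_iou(b1, b2, xywh=True, CIoU=True))     # CIoU subtracts a center-distance (and aspect-ratio) penalty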
{
"content": "def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:\n \"\"\"\n Calculate masks IoU.\n\n Args:\n mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the\n product of image width and height.\n mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the\n product of image width and height.\n eps (float, optional): A small value to avoid division by zero.\n\n Returns:\n (torch.Tensor): A tensor of shape (N, M) representing masks IoU.\n \"\"\"\n intersection = torch.matmul(mask1, mask2.T).clamp_(0)\n union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection\n return intersection / (union + eps)",
"chunk_type": "function",
"name": "mask_iou",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 147,
"end_line": 163,
"start_col": 0,
"end_col": 39,
"parent_name": null,
"docstring": "Calculate masks IoU.\n\nArgs:\n mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the\n product of image width and height.\n mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the\n product of image width and height.\n eps (float, optional): A small value to avoid division by zero.\n\nReturns:\n (torch.Tensor): A tensor of shape (N, M) representing masks IoU.",
"parameters": [
"mask1: torch.Tensor",
"mask2: torch.Tensor",
"eps: float"
],
"return_type": "torch.Tensor",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re"
],
"chunk_id": "function_mask_iou_acb5a26e"
},
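mask_iou expects masks flattened to (num_masks, H*W). A tiny example with 2x2 binary masks flattened to length-4 vectors:

import torch
from ultralytics.utils.metrics import mask_iou

gt = torch.tensor([[1.0, 1.0, 0.0, 0.0],      # two ground-truth masks on a 2x2 grid
                   [0.0, 0.0, 1.0, 1.0]])
pred = torch.tensor([[1.0, 0.0, 0.0, 0.0],    # two predicted masks
                     [0.0, 0.0, 1.0, 1.0]])

print(mask_iou(gt, pred))    # [[0.5, 0.0], [0.0, 1.0]] up to the eps term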
{
"content": "def kpt_iou(\n kpt1: torch.Tensor, kpt2: torch.Tensor, area: torch.Tensor, sigma: List[float], eps: float = 1e-7\n) -> torch.Tensor:\n \"\"\"\n Calculate Object Keypoint Similarity (OKS).\n\n Args:\n kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.\n kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints.\n area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth.\n sigma (list): A list containing 17 values representing keypoint scales.\n eps (float, optional): A small value to avoid division by zero.\n\n Returns:\n (torch.Tensor): A tensor of shape (N, M) representing keypoint similarities.\n \"\"\"\n d = (kpt1[:, None, :, 0] - kpt2[..., 0]).pow(2) + (kpt1[:, None, :, 1] - kpt2[..., 1]).pow(2) # (N, M, 17)\n sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype) # (17, )\n kpt_mask = kpt1[..., 2] != 0 # (N, 17)\n e = d / ((2 * sigma).pow(2) * (area[:, None, None] + eps) * 2) # from cocoeval\n # e = d / ((area[None, :, None] + eps) * sigma) ** 2 / 2 # from formula\n return ((-e).exp() * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)",
"chunk_type": "function",
"name": "kpt_iou",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 166,
"end_line": 187,
"start_col": 0,
"end_col": 87,
"parent_name": null,
"docstring": "Calculate Object Keypoint Similarity (OKS).\n\nArgs:\n kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.\n kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints.\n area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth.\n sigma (list): A list containing 17 values representing keypoint scales.\n eps (float, optional): A small value to avoid division by zero.\n\nReturns:\n (torch.Tensor): A tensor of shape (N, M) representing keypoint similarities.",
"parameters": [
"kpt1: torch.Tensor",
"kpt2: torch.Tensor",
"area: torch.Tensor",
"sigma: List[float]",
"eps: float"
],
"return_type": "torch.Tensor",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re"
],
"chunk_id": "function_kpt_iou_95674ad5"
},
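kpt_iou implements COCO-style Object Keypoint Similarity. The sketch below scores a prediction against itself, which should give OKS close to 1; the keypoint values and box area are arbitrary:

import torch
from ultralytics.utils.metrics import OKS_SIGMA, kpt_iou

gt = torch.rand(1, 17, 3) * 100           # (N, 17, 3): x, y, visibility
gt[..., 2] = 2.0                          # mark all 17 keypoints visible
pred = gt.clone()                         # "perfect" prediction
area = torch.tensor([50.0 * 80.0])        # ground-truth box areas, shape (N,)

oks = kpt_iou(gt, pred, area=area, sigma=OKS_SIGMA.tolist())
print(oks)                                # close to 1.0 for identical keypoints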
{
"content": "def _get_covariance_matrix(boxes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n \"\"\"\n Generate covariance matrix from oriented bounding boxes.\n\n Args:\n boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.\n\n Returns:\n (torch.Tensor): Covariance matrices corresponding to original rotated bounding boxes.\n \"\"\"\n # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here.\n gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1)\n a, b, c = gbbs.split(1, dim=-1)\n cos = c.cos()\n sin = c.sin()\n cos2 = cos.pow(2)\n sin2 = sin.pow(2)\n return a * cos2 + b * sin2, a * sin2 + b * cos2, (a - b) * cos * sin",
"chunk_type": "function",
"name": "_get_covariance_matrix",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 190,
"end_line": 207,
"start_col": 0,
"end_col": 72,
"parent_name": null,
"docstring": "Generate covariance matrix from oriented bounding boxes.\n\nArgs:\n boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.\n\nReturns:\n (torch.Tensor): Covariance matrices corresponding to original rotated bounding boxes.",
"parameters": [
"boxes: torch.Tensor"
],
"return_type": "Tuple[torch.Tensor, torch.Tensor, torch.Tensor]",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re"
],
"chunk_id": "function__get_covariance_matrix_9af498db"
},
{
"content": "def probiou(obb1: torch.Tensor, obb2: torch.Tensor, CIoU: bool = False, eps: float = 1e-7) -> torch.Tensor:\n \"\"\"\n Calculate probabilistic IoU between oriented bounding boxes.\n\n Args:\n obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.\n obb2 (torch.Tensor): Predicted OBBs, shape (N, 5), format xywhr.\n CIoU (bool, optional): If True, calculate CIoU.\n eps (float, optional): Small value to avoid division by zero.\n\n Returns:\n (torch.Tensor): OBB similarities, shape (N,).\n\n Notes:\n OBB format: [center_x, center_y, width, height, rotation_angle].\n\n References:\n https://arxiv.org/pdf/2106.06072v1.pdf\n \"\"\"\n x1, y1 = obb1[..., :2].split(1, dim=-1)\n x2, y2 = obb2[..., :2].split(1, dim=-1)\n a1, b1, c1 = _get_covariance_matrix(obb1)\n a2, b2, c2 = _get_covariance_matrix(obb2)\n\n t1 = (\n ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)\n ) * 0.25\n t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5\n t3 = (\n ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))\n / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)\n + eps\n ).log() * 0.5\n bd = (t1 + t2 + t3).clamp(eps, 100.0)\n hd = (1.0 - (-bd).exp() + eps).sqrt()\n iou = 1 - hd\n if CIoU: # only include the wh aspect ratio part\n w1, h1 = obb1[..., 2:4].split(1, dim=-1)\n w2, h2 = obb2[..., 2:4].split(1, dim=-1)\n v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)\n with torch.no_grad():\n alpha = v / (v - iou + (1 + eps))\n return iou - v * alpha # CIoU\n return iou",
"chunk_type": "function",
"name": "probiou",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 210,
"end_line": 253,
"start_col": 0,
"end_col": 14,
"parent_name": null,
"docstring": "Calculate probabilistic IoU between oriented bounding boxes.\n\nArgs:\n obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.\n obb2 (torch.Tensor): Predicted OBBs, shape (N, 5), format xywhr.\n CIoU (bool, optional): If True, calculate CIoU.\n eps (float, optional): Small value to avoid division by zero.\n\nReturns:\n (torch.Tensor): OBB similarities, shape (N,).\n\nNotes:\n OBB format: [center_x, center_y, width, height, rotation_angle].\n\nReferences:\n https://arxiv.org/pdf/2106.06072v1.pdf",
"parameters": [
"obb1: torch.Tensor",
"obb2: torch.Tensor",
"CIoU: bool",
"eps: float"
],
"return_type": "torch.Tensor",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re"
],
"chunk_id": "function_probiou_78b462b7"
},
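probiou compares rotated boxes in xywhr format element-wise (row i of obb1 against row i of obb2). A quick check on an identical pair and a rotated variant (values invented):

import torch
from ultralytics.utils.metrics import probiou

obb1 = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.0]])    # cx, cy, w, h, angle in radians
obb2 = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.5]])    # same box rotated by 0.5 rad

print(probiou(obb1, obb1))    # ≈ 1 for identical boxes
print(probiou(obb1, obb2))    # lower similarity for the rotated copy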
{
"content": "def batch_probiou(\n obb1: Union[torch.Tensor, np.ndarray], obb2: Union[torch.Tensor, np.ndarray], eps: float = 1e-7\n) -> torch.Tensor:\n \"\"\"\n Calculate the probabilistic IoU between oriented bounding boxes.\n\n Args:\n obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.\n obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.\n eps (float, optional): A small value to avoid division by zero.\n\n Returns:\n (torch.Tensor): A tensor of shape (N, M) representing obb similarities.\n\n References:\n https://arxiv.org/pdf/2106.06072v1.pdf\n \"\"\"\n obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1\n obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2\n\n x1, y1 = obb1[..., :2].split(1, dim=-1)\n x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1))\n a1, b1, c1 = _get_covariance_matrix(obb1)\n a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2))\n\n t1 = (\n ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)\n ) * 0.25\n t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5\n t3 = (\n ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))\n / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)\n + eps\n ).log() * 0.5\n bd = (t1 + t2 + t3).clamp(eps, 100.0)\n hd = (1.0 - (-bd).exp() + eps).sqrt()\n return 1 - hd",
"chunk_type": "function",
"name": "batch_probiou",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 256,
"end_line": 292,
"start_col": 0,
"end_col": 17,
"parent_name": null,
"docstring": "Calculate the probabilistic IoU between oriented bounding boxes.\n\nArgs:\n obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.\n obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.\n eps (float, optional): A small value to avoid division by zero.\n\nReturns:\n (torch.Tensor): A tensor of shape (N, M) representing obb similarities.\n\nReferences:\n https://arxiv.org/pdf/2106.06072v1.pdf",
"parameters": [
"obb1: Union[torch.Tensor, np.ndarray]",
"obb2: Union[torch.Tensor, np.ndarray]",
"eps: float"
],
"return_type": "torch.Tensor",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re"
],
"chunk_id": "function_batch_probiou_1b0ac7e4"
},
{
"content": "def smooth_bce(eps: float = 0.1) -> Tuple[float, float]:\n \"\"\"\n Compute smoothed positive and negative Binary Cross-Entropy targets.\n\n Args:\n eps (float, optional): The epsilon value for label smoothing.\n\n Returns:\n pos (float): Positive label smoothing BCE target.\n neg (float): Negative label smoothing BCE target.\n\n References:\n https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441\n \"\"\"\n return 1.0 - 0.5 * eps, 0.5 * eps",
"chunk_type": "function",
"name": "smooth_bce",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 295,
"end_line": 309,
"start_col": 0,
"end_col": 37,
"parent_name": null,
"docstring": "Compute smoothed positive and negative Binary Cross-Entropy targets.\n\nArgs:\n eps (float, optional): The epsilon value for label smoothing.\n\nReturns:\n pos (float): Positive label smoothing BCE target.\n neg (float): Negative label smoothing BCE target.\n\nReferences:\n https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441",
"parameters": [
"eps: float"
],
"return_type": "Tuple[float, float]",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re"
],
"chunk_id": "function_smooth_bce_94ea830c"
},
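smooth_bce is a one-line label-smoothing helper; with eps = 0.1 the positive target becomes 0.95 and the negative target 0.05, directly from the formula in the chunk above:

from ultralytics.utils.metrics import smooth_bce

pos, neg = smooth_bce(eps=0.1)
print(pos, neg)    # 0.95 0.05 (with eps=0 the targets stay at 1.0 and 0.0, i.e. no smoothing)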
{
"content": "class ConfusionMatrix(DataExportMixin):\n \"\"\"\n A class for calculating and updating a confusion matrix for object detection and classification tasks.\n\n Attributes:\n task (str): The type of task, either 'detect' or 'classify'.\n matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.\n nc (int): The number of category.\n names (List[str]): The names of the classes, used as labels on the plot.\n \"\"\"\n\n def __init__(self, names: List[str] = [], task: str = \"detect\"):\n \"\"\"\n Initialize a ConfusionMatrix instance.\n\n Args:\n names (List[str], optional): Names of classes, used as labels on the plot.\n task (str, optional): Type of task, either 'detect' or 'classify'.\n \"\"\"\n self.task = task\n self.nc = len(names) # number of classes\n self.matrix = np.zeros((self.nc, self.nc)) if self.task == \"classify\" else np.zeros((self.nc + 1, self.nc + 1))\n self.names = names # name of classes\n\n def process_cls_preds(self, preds, targets):\n \"\"\"\n Update confusion matrix for classification task.\n\n Args:\n preds (Array[N, min(nc,5)]): Predicted class labels.\n targets (Array[N, 1]): Ground truth class labels.\n \"\"\"\n preds, targets = torch.cat(preds)[:, 0], torch.cat(targets)\n for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):\n self.matrix[p][t] += 1\n\n def process_batch(\n self, detections: Dict[str, torch.Tensor], batch: Dict[str, Any], conf: float = 0.25, iou_thres: float = 0.45\n ) -> None:\n \"\"\"\n Update confusion matrix for object detection task.\n\n Args:\n detections (Dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated information.\n Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be\n Array[N, 4] for regular boxes or Array[N, 5] for OBB with angle.\n batch (Dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M, 5]) and\n 'cls' (Array[M]) keys, where M is the number of ground truth objects.\n conf (float, optional): Confidence threshold for detections.\n iou_thres (float, optional): IoU threshold for matching detections to ground truth.\n \"\"\"\n gt_cls, gt_bboxes = batch[\"cls\"], batch[\"bboxes\"]\n is_obb = gt_bboxes.shape[1] == 5 # check if boxes contains angle for OBB\n conf = 0.25 if conf in {None, 0.01 if is_obb else 0.001} else conf # apply 0.25 if default val conf is passed\n no_pred = len(detections[\"cls\"]) == 0\n if gt_cls.shape[0] == 0: # Check if labels is empty\n if not no_pred:\n detections = {k: detections[k][detections[\"conf\"] > conf] for k in {\"cls\", \"bboxes\"}}\n detection_classes = detections[\"cls\"].int().tolist()\n for dc in detection_classes:\n self.matrix[dc, self.nc] += 1 # false positives\n return\n if no_pred:\n gt_classes = gt_cls.int().tolist()\n for gc in gt_classes:\n self.matrix[self.nc, gc] += 1 # background FN\n return\n\n detections = {k: detections[k][detections[\"conf\"] > conf] for k in {\"cls\", \"bboxes\"}}\n gt_classes = gt_cls.int().tolist()\n detection_classes = detections[\"cls\"].int().tolist()\n bboxes = detections[\"bboxes\"]\n iou = batch_probiou(gt_bboxes, bboxes) if is_obb else box_iou(gt_bboxes, bboxes)\n\n x = torch.where(iou > iou_thres)\n if x[0].shape[0]:\n matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()\n if x[0].shape[0] > 1:\n matches = matches[matches[:, 2].argsort()[::-1]]\n matches = matches[np.unique(matches[:, 1], return_index=True)[1]]\n matches = matches[matches[:, 2].argsort()[::-1]]\n matches = 
matches[np.unique(matches[:, 0], return_index=True)[1]]\n else:\n matches = np.zeros((0, 3))\n\n n = matches.shape[0] > 0\n m0, m1, _ = matches.transpose().astype(int)\n for i, gc in enumerate(gt_classes):\n j = m0 == i\n if n and sum(j) == 1:\n self.matrix[detection_classes[m1[j].item()], gc] += 1 # correct\n else:\n self.matrix[self.nc, gc] += 1 # true background\n\n for i, dc in enumerate(detection_classes):\n if not any(m1 == i):\n self.matrix[dc, self.nc] += 1 # predicted background\n\n def matrix(self):\n \"\"\"Return the confusion matrix.\"\"\"\n return self.matrix\n\n def tp_fp(self) -> Tuple[np.ndarray, np.ndarray]:\n \"\"\"\n Return true positives and false positives.\n\n Returns:\n tp (np.ndarray): True positives.\n fp (np.ndarray): False positives.\n \"\"\"\n tp = self.matrix.diagonal() # true positives\n fp = self.matrix.sum(1) - tp # false positives\n # fn = self.matrix.sum(0) - tp # false negatives (missed detections)\n return (tp, fp) if self.task == \"classify\" else (tp[:-1], fp[:-1]) # remove background class if task=detect\n\n @TryExcept(msg=\"ConfusionMatrix plot failure\")\n @plt_settings()\n def plot(self, normalize: bool = True, save_dir: str = \"\", on_plot=None):\n \"\"\"\n Plot the confusion matrix using matplotlib and save it to a file.\n\n Args:\n normalize (bool, optional): Whether to normalize the confusion matrix.\n save_dir (str, optional): Directory where the plot will be saved.\n on_plot (callable, optional): An optional callback to pass plots path and data when they are rendered.\n \"\"\"\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n\n array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns\n array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)\n\n fig, ax = plt.subplots(1, 1, figsize=(12, 9))\n names, n = self.names, self.nc\n if self.nc >= 100: # downsample for large class count\n k = max(2, self.nc // 60) # step size for downsampling, always > 1\n keep_idx = slice(None, None, k) # create slice instead of array\n names = names[keep_idx] # slice class names\n array = array[keep_idx, :][:, keep_idx] # slice matrix rows and cols\n n = (self.nc + k - 1) // k # number of retained classes\n nc = nn = n if self.task == \"classify\" else n + 1 # adjust for background if needed\n ticklabels = (names + [\"background\"]) if (0 < nn < 99) and (nn == nc) else \"auto\"\n xy_ticks = np.arange(len(ticklabels))\n tick_fontsize = max(6, 15 - 0.1 * nc) # Minimum size is 6\n label_fontsize = max(6, 12 - 0.1 * nc)\n title_fontsize = max(6, 12 - 0.1 * nc)\n btm = max(0.1, 0.25 - 0.001 * nc) # Minimum value is 0.1\n with warnings.catch_warnings():\n warnings.simplefilter(\"ignore\") # suppress empty matrix RuntimeWarning: All-NaN slice encountered\n im = ax.imshow(array, cmap=\"Blues\", vmin=0.0, interpolation=\"none\")\n ax.xaxis.set_label_position(\"bottom\")\n if nc < 30: # Add score for each cell of confusion matrix\n color_threshold = 0.45 * (1 if normalize else np.nanmax(array)) # text color threshold\n for i, row in enumerate(array[:nc]):\n for j, val in enumerate(row[:nc]):\n val = array[i, j]\n if np.isnan(val):\n continue\n ax.text(\n j,\n i,\n f\"{val:.2f}\" if normalize else f\"{int(val)}\",\n ha=\"center\",\n va=\"center\",\n fontsize=10,\n color=\"white\" if val > color_threshold else \"black\",\n )\n cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.05)\n title = \"Confusion Matrix\" + \" Normalized\" * normalize\n ax.set_xlabel(\"True\", 
fontsize=label_fontsize, labelpad=10)\n ax.set_ylabel(\"Predicted\", fontsize=label_fontsize, labelpad=10)\n ax.set_title(title, fontsize=title_fontsize, pad=20)\n ax.set_xticks(xy_ticks)\n ax.set_yticks(xy_ticks)\n ax.tick_params(axis=\"x\", bottom=True, top=False, labelbottom=True, labeltop=False)\n ax.tick_params(axis=\"y\", left=True, right=False, labelleft=True, labelright=False)\n if ticklabels != \"auto\":\n ax.set_xticklabels(ticklabels, fontsize=tick_fontsize, rotation=90, ha=\"center\")\n ax.set_yticklabels(ticklabels, fontsize=tick_fontsize)\n for s in {\"left\", \"right\", \"bottom\", \"top\", \"outline\"}:\n if s != \"outline\":\n ax.spines[s].set_visible(False) # Confusion matrix plot don't have outline\n cbar.ax.spines[s].set_visible(False)\n fig.subplots_adjust(left=0, right=0.84, top=0.94, bottom=btm) # Adjust layout to ensure equal margins\n plot_fname = Path(save_dir) / f\"{title.lower().replace(' ', '_')}.png\"\n fig.savefig(plot_fname, dpi=250)\n plt.close(fig)\n if on_plot:\n on_plot(plot_fname)\n\n def print(self):\n \"\"\"Print the confusion matrix to the console.\"\"\"\n for i in range(self.matrix.shape[0]):\n LOGGER.info(\" \".join(map(str, self.matrix[i])))\n\n def summary(self, normalize: bool = False, decimals: int = 5) -> List[Dict[str, float]]:\n \"\"\"\n Generate a summarized representation of the confusion matrix as a list of dictionaries, with optional\n normalization. This is useful for exporting the matrix to various formats such as CSV, XML, HTML, JSON, or SQL.\n\n Args:\n normalize (bool): Whether to normalize the confusion matrix values.\n decimals (int): Number of decimal places to round the output values to.\n\n Returns:\n (List[Dict[str, float]]): A list of dictionaries, each representing one predicted class with corresponding values for all actual classes.\n\n Examples:\n >>> results = model.val(data=\"coco8.yaml\", plots=True)\n >>> cm_dict = results.confusion_matrix.summary(normalize=True, decimals=5)\n >>> print(cm_dict)\n \"\"\"\n import re\n\n names = self.names if self.task == \"classify\" else self.names + [\"background\"]\n clean_names, seen = [], set()\n for name in names:\n clean_name = re.sub(r\"[^a-zA-Z0-9_]\", \"_\", name)\n original_clean = clean_name\n counter = 1\n while clean_name.lower() in seen:\n clean_name = f\"{original_clean}_{counter}\"\n counter += 1\n seen.add(clean_name.lower())\n clean_names.append(clean_name)\n array = (self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1)).round(decimals)\n return [\n dict({\"Predicted\": clean_names[i]}, **{clean_names[j]: array[i, j] for j in range(len(clean_names))})\n for i in range(len(clean_names))\n ]",
"chunk_type": "class",
"name": "ConfusionMatrix",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 312,
"end_line": 540,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "A class for calculating and updating a confusion matrix for object detection and classification tasks.\n\nAttributes:\n task (str): The type of task, either 'detect' or 'classify'.\n matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.\n nc (int): The number of category.\n names (List[str]): The names of the classes, used as labels on the plot.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re",
"DataExportMixin"
],
"chunk_id": "class_ConfusionMatrix_f68dd0e5"
},
{
"content": "def smooth(y: np.ndarray, f: float = 0.05) -> np.ndarray:\n \"\"\"Box filter of fraction f.\"\"\"\n nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)\n p = np.ones(nf // 2) # ones padding\n yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded\n return np.convolve(yp, np.ones(nf) / nf, mode=\"valid\") # y-smoothed",
"chunk_type": "function",
"name": "smooth",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 543,
"end_line": 548,
"start_col": 0,
"end_col": 58,
"parent_name": null,
"docstring": "Box filter of fraction f.",
"parameters": [
"y: np.ndarray",
"f: float"
],
"return_type": "np.ndarray",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re"
],
"chunk_id": "function_smooth_9012a130"
},
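smooth applies a simple moving-average (box) filter whose window size is derived from the fraction f of the signal length, with edge padding so the output keeps the input length. A short example on a noisy ramp (the signal itself is made up):

import numpy as np
from ultralytics.utils.metrics import smooth

y = np.linspace(0.0, 1.0, 50) + np.random.uniform(-0.05, 0.05, 50)   # noisy ramp
ys = smooth(y, f=0.05)                                                # window of ~3 samples here
print(y.shape, ys.shape)    # both (50,): padding keeps the output the same length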
{
"content": "def plot_pr_curve(\n px: np.ndarray,\n py: np.ndarray,\n ap: np.ndarray,\n save_dir: Path = Path(\"pr_curve.png\"),\n names: Dict[int, str] = {},\n on_plot=None,\n):\n \"\"\"\n Plot precision-recall curve.\n\n Args:\n px (np.ndarray): X values for the PR curve.\n py (np.ndarray): Y values for the PR curve.\n ap (np.ndarray): Average precision values.\n save_dir (Path, optional): Path to save the plot.\n names (Dict[int, str], optional): Dictionary mapping class indices to class names.\n on_plot (callable, optional): Function to call after plot is saved.\n \"\"\"\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n\n fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)\n py = np.stack(py, axis=1)\n\n if 0 < len(names) < 21: # display per-class legend if < 21 classes\n for i, y in enumerate(py.T):\n ax.plot(px, y, linewidth=1, label=f\"{names[i]} {ap[i, 0]:.3f}\") # plot(recall, precision)\n else:\n ax.plot(px, py, linewidth=1, color=\"grey\") # plot(recall, precision)\n\n ax.plot(px, py.mean(1), linewidth=3, color=\"blue\", label=f\"all classes {ap[:, 0].mean():.3f} mAP@0.5\")\n ax.set_xlabel(\"Recall\")\n ax.set_ylabel(\"Precision\")\n ax.set_xlim(0, 1)\n ax.set_ylim(0, 1)\n ax.legend(bbox_to_anchor=(1.04, 1), loc=\"upper left\")\n ax.set_title(\"Precision-Recall Curve\")\n fig.savefig(save_dir, dpi=250)\n plt.close(fig)\n if on_plot:\n on_plot(save_dir)",
"chunk_type": "function",
"name": "plot_pr_curve",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 552,
"end_line": 592,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": "Plot precision-recall curve.\n\nArgs:\n px (np.ndarray): X values for the PR curve.\n py (np.ndarray): Y values for the PR curve.\n ap (np.ndarray): Average precision values.\n save_dir (Path, optional): Path to save the plot.\n names (Dict[int, str], optional): Dictionary mapping class indices to class names.\n on_plot (callable, optional): Function to call after plot is saved.",
"parameters": [
"px: np.ndarray",
"py: np.ndarray",
"ap: np.ndarray",
"save_dir: Path",
"names: Dict[int, str]",
"on_plot"
],
"return_type": null,
"decorators": [
"plt_settings()"
],
"complexity_score": 4,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re"
],
"chunk_id": "function_plot_pr_curve_590548b1"
},
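A hedged sketch of calling `plot_pr_curve` with the shapes its body implies (`py` is stacked along axis 1, `ap[i, 0]` is read per class); the arrays and class names below are placeholders, not repository data:

```python
import numpy as np
from pathlib import Path

from ultralytics.utils.metrics import plot_pr_curve

nc = 3
px = np.linspace(0, 1, 1000)               # recall axis
py = np.random.rand(nc, 1000)              # one precision curve per class (placeholder values)
ap = np.random.rand(nc, 10)                # AP per class at 10 IoU thresholds; column 0 is AP@0.5
names = {0: "person", 1: "car", 2: "dog"}  # hypothetical class names (< 21, so a per-class legend is drawn)

plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=names)
```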
{
"content": "def plot_mc_curve(\n px: np.ndarray,\n py: np.ndarray,\n save_dir: Path = Path(\"mc_curve.png\"),\n names: Dict[int, str] = {},\n xlabel: str = \"Confidence\",\n ylabel: str = \"Metric\",\n on_plot=None,\n):\n \"\"\"\n Plot metric-confidence curve.\n\n Args:\n px (np.ndarray): X values for the metric-confidence curve.\n py (np.ndarray): Y values for the metric-confidence curve.\n save_dir (Path, optional): Path to save the plot.\n names (Dict[int, str], optional): Dictionary mapping class indices to class names.\n xlabel (str, optional): X-axis label.\n ylabel (str, optional): Y-axis label.\n on_plot (callable, optional): Function to call after plot is saved.\n \"\"\"\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n\n fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)\n\n if 0 < len(names) < 21: # display per-class legend if < 21 classes\n for i, y in enumerate(py):\n ax.plot(px, y, linewidth=1, label=f\"{names[i]}\") # plot(confidence, metric)\n else:\n ax.plot(px, py.T, linewidth=1, color=\"grey\") # plot(confidence, metric)\n\n y = smooth(py.mean(0), 0.1)\n ax.plot(px, y, linewidth=3, color=\"blue\", label=f\"all classes {y.max():.2f} at {px[y.argmax()]:.3f}\")\n ax.set_xlabel(xlabel)\n ax.set_ylabel(ylabel)\n ax.set_xlim(0, 1)\n ax.set_ylim(0, 1)\n ax.legend(bbox_to_anchor=(1.04, 1), loc=\"upper left\")\n ax.set_title(f\"{ylabel}-Confidence Curve\")\n fig.savefig(save_dir, dpi=250)\n plt.close(fig)\n if on_plot:\n on_plot(save_dir)",
"chunk_type": "function",
"name": "plot_mc_curve",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 596,
"end_line": 638,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": "Plot metric-confidence curve.\n\nArgs:\n px (np.ndarray): X values for the metric-confidence curve.\n py (np.ndarray): Y values for the metric-confidence curve.\n save_dir (Path, optional): Path to save the plot.\n names (Dict[int, str], optional): Dictionary mapping class indices to class names.\n xlabel (str, optional): X-axis label.\n ylabel (str, optional): Y-axis label.\n on_plot (callable, optional): Function to call after plot is saved.",
"parameters": [
"px: np.ndarray",
"py: np.ndarray",
"save_dir: Path",
"names: Dict[int, str]",
"xlabel: str",
"ylabel: str",
"on_plot"
],
"return_type": null,
"decorators": [
"plt_settings()"
],
"complexity_score": 4,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re"
],
"chunk_id": "function_plot_mc_curve_70749d7c"
},
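The metric-confidence variant follows the same pattern; a small sketch with placeholder curves (`ylabel` selects the metric name used in the title here):

```python
import numpy as np
from pathlib import Path

from ultralytics.utils.metrics import plot_mc_curve

px = np.linspace(0, 1, 1000)               # confidence axis
py = np.random.rand(3, 1000)               # per-class F1 curves (placeholder values)
names = {0: "person", 1: "car", 2: "dog"}  # hypothetical class names

plot_mc_curve(px, py, save_dir=Path("F1_curve.png"), names=names, ylabel="F1")
```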
{
"content": "def compute_ap(recall: List[float], precision: List[float]) -> Tuple[float, np.ndarray, np.ndarray]:\n \"\"\"\n Compute the average precision (AP) given the recall and precision curves.\n\n Args:\n recall (list): The recall curve.\n precision (list): The precision curve.\n\n Returns:\n ap (float): Average precision.\n mpre (np.ndarray): Precision envelope curve.\n mrec (np.ndarray): Modified recall curve with sentinel values added at the beginning and end.\n \"\"\"\n # Append sentinel values to beginning and end\n mrec = np.concatenate(([0.0], recall, [1.0]))\n mpre = np.concatenate(([1.0], precision, [0.0]))\n\n # Compute the precision envelope\n mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))\n\n # Integrate area under curve\n method = \"interp\" # methods: 'continuous', 'interp'\n if method == \"interp\":\n x = np.linspace(0, 1, 101) # 101-point interp (COCO)\n func = np.trapezoid if checks.check_version(np.__version__, \">=2.0\") else np.trapz # np.trapz deprecated\n ap = func(np.interp(x, mrec, mpre), x) # integrate\n else: # 'continuous'\n i = np.where(mrec[1:] != mrec[:-1])[0] # points where x-axis (recall) changes\n ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve\n\n return ap, mpre, mrec",
"chunk_type": "function",
"name": "compute_ap",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 641,
"end_line": 671,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": "Compute the average precision (AP) given the recall and precision curves.\n\nArgs:\n recall (list): The recall curve.\n precision (list): The precision curve.\n\nReturns:\n ap (float): Average precision.\n mpre (np.ndarray): Precision envelope curve.\n mrec (np.ndarray): Modified recall curve with sentinel values added at the beginning and end.",
"parameters": [
"recall: List[float]",
"precision: List[float]"
],
"return_type": "Tuple[float, np.ndarray, np.ndarray]",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re"
],
"chunk_id": "function_compute_ap_0f423827"
},
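`compute_ap` appends sentinel values, takes the precision envelope, and integrates it on a 101-point grid (COCO style); a toy check with a monotonically increasing recall curve:

```python
import numpy as np

from ultralytics.utils.metrics import compute_ap

recall = np.linspace(0.0, 1.0, 50)
precision = np.clip(1.0 - recall + 0.05 * np.random.randn(50), 0.0, 1.0)

ap, mpre, mrec = compute_ap(recall.tolist(), precision.tolist())
print(f"AP = {ap:.3f}")    # area under the interpolated precision envelope
print(mrec[0], mrec[-1])   # sentinel recall values 0.0 and 1.0
```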
{
"content": "def ap_per_class(\n tp: np.ndarray,\n conf: np.ndarray,\n pred_cls: np.ndarray,\n target_cls: np.ndarray,\n plot: bool = False,\n on_plot=None,\n save_dir: Path = Path(),\n names: Dict[int, str] = {},\n eps: float = 1e-16,\n prefix: str = \"\",\n) -> Tuple:\n \"\"\"\n Compute the average precision per class for object detection evaluation.\n\n Args:\n tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).\n conf (np.ndarray): Array of confidence scores of the detections.\n pred_cls (np.ndarray): Array of predicted classes of the detections.\n target_cls (np.ndarray): Array of true classes of the detections.\n plot (bool, optional): Whether to plot PR curves or not.\n on_plot (callable, optional): A callback to pass plots path and data when they are rendered.\n save_dir (Path, optional): Directory to save the PR curves.\n names (Dict[int, str], optional): Dictionary of class names to plot PR curves.\n eps (float, optional): A small value to avoid division by zero.\n prefix (str, optional): A prefix string for saving the plot files.\n\n Returns:\n tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.\n fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class.\n p (np.ndarray): Precision values at threshold given by max F1 metric for each class.\n r (np.ndarray): Recall values at threshold given by max F1 metric for each class.\n f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class.\n ap (np.ndarray): Average precision for each class at different IoU thresholds.\n unique_classes (np.ndarray): An array of unique classes that have data.\n p_curve (np.ndarray): Precision curves for each class.\n r_curve (np.ndarray): Recall curves for each class.\n f1_curve (np.ndarray): F1-score curves for each class.\n x (np.ndarray): X-axis values for the curves.\n prec_values (np.ndarray): Precision values at mAP@0.5 for each class.\n \"\"\"\n # Sort by objectness\n i = np.argsort(-conf)\n tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]\n\n # Find unique classes\n unique_classes, nt = np.unique(target_cls, return_counts=True)\n nc = unique_classes.shape[0] # number of classes, number of detections\n\n # Create Precision-Recall curve and compute AP for each class\n x, prec_values = np.linspace(0, 1, 1000), []\n\n # Average precision, precision and recall curves\n ap, p_curve, r_curve = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))\n for ci, c in enumerate(unique_classes):\n i = pred_cls == c\n n_l = nt[ci] # number of labels\n n_p = i.sum() # number of predictions\n if n_p == 0 or n_l == 0:\n continue\n\n # Accumulate FPs and TPs\n fpc = (1 - tp[i]).cumsum(0)\n tpc = tp[i].cumsum(0)\n\n # Recall\n recall = tpc / (n_l + eps) # recall curve\n r_curve[ci] = np.interp(-x, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases\n\n # Precision\n precision = tpc / (tpc + fpc) # precision curve\n p_curve[ci] = np.interp(-x, -conf[i], precision[:, 0], left=1) # p at pr_score\n\n # AP from recall-precision curve\n for j in range(tp.shape[1]):\n ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])\n if j == 0:\n prec_values.append(np.interp(x, mrec, mpre)) # precision at mAP@0.5\n\n prec_values = np.array(prec_values) if prec_values else np.zeros((1, 1000)) # (nc, 1000)\n\n # Compute F1 (harmonic mean of precision and recall)\n f1_curve = 2 * p_curve * r_curve / (p_curve + r_curve + eps)\n names = {i: names[k] for 
i, k in enumerate(unique_classes) if k in names} # dict: only classes that have data\n if plot:\n plot_pr_curve(x, prec_values, ap, save_dir / f\"{prefix}PR_curve.png\", names, on_plot=on_plot)\n plot_mc_curve(x, f1_curve, save_dir / f\"{prefix}F1_curve.png\", names, ylabel=\"F1\", on_plot=on_plot)\n plot_mc_curve(x, p_curve, save_dir / f\"{prefix}P_curve.png\", names, ylabel=\"Precision\", on_plot=on_plot)\n plot_mc_curve(x, r_curve, save_dir / f\"{prefix}R_curve.png\", names, ylabel=\"Recall\", on_plot=on_plot)\n\n i = smooth(f1_curve.mean(0), 0.1).argmax() # max F1 index\n p, r, f1 = p_curve[:, i], r_curve[:, i], f1_curve[:, i] # max-F1 precision, recall, F1 values\n tp = (r * nt).round() # true positives\n fp = (tp / (p + eps) - tp).round() # false positives\n return tp, fp, p, r, f1, ap, unique_classes.astype(int), p_curve, r_curve, f1_curve, x, prec_values",
"chunk_type": "function",
"name": "ap_per_class",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 674,
"end_line": 768,
"start_col": 0,
"end_col": 103,
"parent_name": null,
"docstring": "Compute the average precision per class for object detection evaluation.\n\nArgs:\n tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).\n conf (np.ndarray): Array of confidence scores of the detections.\n pred_cls (np.ndarray): Array of predicted classes of the detections.\n target_cls (np.ndarray): Array of true classes of the detections.\n plot (bool, optional): Whether to plot PR curves or not.\n on_plot (callable, optional): A callback to pass plots path and data when they are rendered.\n save_dir (Path, optional): Directory to save the PR curves.\n names (Dict[int, str], optional): Dictionary of class names to plot PR curves.\n eps (float, optional): A small value to avoid division by zero.\n prefix (str, optional): A prefix string for saving the plot files.\n\nReturns:\n tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.\n fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class.\n p (np.ndarray): Precision values at threshold given by max F1 metric for each class.\n r (np.ndarray): Recall values at threshold given by max F1 metric for each class.\n f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class.\n ap (np.ndarray): Average precision for each class at different IoU thresholds.\n unique_classes (np.ndarray): An array of unique classes that have data.\n p_curve (np.ndarray): Precision curves for each class.\n r_curve (np.ndarray): Recall curves for each class.\n f1_curve (np.ndarray): F1-score curves for each class.\n x (np.ndarray): X-axis values for the curves.\n prec_values (np.ndarray): Precision values at mAP@0.5 for each class.",
"parameters": [
"tp: np.ndarray",
"conf: np.ndarray",
"pred_cls: np.ndarray",
"target_cls: np.ndarray",
"plot: bool",
"on_plot",
"save_dir: Path",
"names: Dict[int, str]",
"eps: float",
"prefix: str"
],
"return_type": "Tuple",
"decorators": [],
"complexity_score": 7,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re"
],
"chunk_id": "function_ap_per_class_3008df19"
},
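A minimal end-to-end sketch of `ap_per_class` on synthetic detections (random correctness and confidences, three classes); the shapes mirror how the validators assemble the inputs, but every value here is made up:

```python
import numpy as np

from ultralytics.utils.metrics import ap_per_class

rng = np.random.default_rng(0)
n_det, n_iou = 200, 10                  # detections, IoU thresholds 0.50:0.95
tp = rng.random((n_det, n_iou)) > 0.5   # per-detection correctness at each IoU threshold
conf = rng.random(n_det)                # confidence scores
pred_cls = rng.integers(0, 3, n_det)    # predicted class per detection
target_cls = rng.integers(0, 3, 250)    # ground-truth class per label

results = ap_per_class(tp, conf, pred_cls, target_cls, names={0: "a", 1: "b", 2: "c"})
tp_c, fp_c, p, r, f1, ap = results[:6]
print(ap.shape)                         # (3, 10): per-class AP at each IoU threshold
```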
{
"content": "class Metric(SimpleClass):\n \"\"\"\n Class for computing evaluation metrics for Ultralytics YOLO models.\n\n Attributes:\n p (list): Precision for each class. Shape: (nc,).\n r (list): Recall for each class. Shape: (nc,).\n f1 (list): F1 score for each class. Shape: (nc,).\n all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).\n ap_class_index (list): Index of class for each AP score. Shape: (nc,).\n nc (int): Number of classes.\n\n Methods:\n ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or [].\n ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or [].\n mp(): Mean precision of all classes. Returns: Float.\n mr(): Mean recall of all classes. Returns: Float.\n map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float.\n map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float.\n map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float.\n mean_results(): Mean of results, returns mp, mr, map50, map.\n class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i].\n maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,).\n fitness(): Model fitness as a weighted combination of metrics. Returns: Float.\n update(results): Update metric attributes with new evaluation results.\n \"\"\"\n\n def __init__(self) -> None:\n \"\"\"Initialize a Metric instance for computing evaluation metrics for the YOLOv8 model.\"\"\"\n self.p = [] # (nc, )\n self.r = [] # (nc, )\n self.f1 = [] # (nc, )\n self.all_ap = [] # (nc, 10)\n self.ap_class_index = [] # (nc, )\n self.nc = 0\n\n @property\n def ap50(self) -> Union[np.ndarray, List]:\n \"\"\"\n Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.\n\n Returns:\n (np.ndarray | list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.\n \"\"\"\n return self.all_ap[:, 0] if len(self.all_ap) else []\n\n @property\n def ap(self) -> Union[np.ndarray, List]:\n \"\"\"\n Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.\n\n Returns:\n (np.ndarray | list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.\n \"\"\"\n return self.all_ap.mean(1) if len(self.all_ap) else []\n\n @property\n def mp(self) -> float:\n \"\"\"\n Return the Mean Precision of all classes.\n\n Returns:\n (float): The mean precision of all classes.\n \"\"\"\n return self.p.mean() if len(self.p) else 0.0\n\n @property\n def mr(self) -> float:\n \"\"\"\n Return the Mean Recall of all classes.\n\n Returns:\n (float): The mean recall of all classes.\n \"\"\"\n return self.r.mean() if len(self.r) else 0.0\n\n @property\n def map50(self) -> float:\n \"\"\"\n Return the mean Average Precision (mAP) at an IoU threshold of 0.5.\n\n Returns:\n (float): The mAP at an IoU threshold of 0.5.\n \"\"\"\n return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0\n\n @property\n def map75(self) -> float:\n \"\"\"\n Return the mean Average Precision (mAP) at an IoU threshold of 0.75.\n\n Returns:\n (float): The mAP at an IoU threshold of 0.75.\n \"\"\"\n return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0\n\n @property\n def map(self) -> float:\n \"\"\"\n Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.\n\n Returns:\n (float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.\n \"\"\"\n return self.all_ap.mean() if 
len(self.all_ap) else 0.0\n\n def mean_results(self) -> List[float]:\n \"\"\"Return mean of results, mp, mr, map50, map.\"\"\"\n return [self.mp, self.mr, self.map50, self.map]\n\n def class_result(self, i: int) -> Tuple[float, float, float, float]:\n \"\"\"Return class-aware result, p[i], r[i], ap50[i], ap[i].\"\"\"\n return self.p[i], self.r[i], self.ap50[i], self.ap[i]\n\n @property\n def maps(self) -> np.ndarray:\n \"\"\"Return mAP of each class.\"\"\"\n maps = np.zeros(self.nc) + self.map\n for i, c in enumerate(self.ap_class_index):\n maps[c] = self.ap[i]\n return maps\n\n def fitness(self) -> float:\n \"\"\"Return model fitness as a weighted combination of metrics.\"\"\"\n w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]\n return (np.nan_to_num(np.array(self.mean_results())) * w).sum()\n\n def update(self, results: tuple):\n \"\"\"\n Update the evaluation metrics with a new set of results.\n\n Args:\n results (tuple): A tuple containing evaluation metrics:\n - p (list): Precision for each class.\n - r (list): Recall for each class.\n - f1 (list): F1 score for each class.\n - all_ap (list): AP scores for all classes and all IoU thresholds.\n - ap_class_index (list): Index of class for each AP score.\n - p_curve (list): Precision curve for each class.\n - r_curve (list): Recall curve for each class.\n - f1_curve (list): F1 curve for each class.\n - px (list): X values for the curves.\n - prec_values (list): Precision values for each class.\n \"\"\"\n (\n self.p,\n self.r,\n self.f1,\n self.all_ap,\n self.ap_class_index,\n self.p_curve,\n self.r_curve,\n self.f1_curve,\n self.px,\n self.prec_values,\n ) = results\n\n @property\n def curves(self) -> List:\n \"\"\"Return a list of curves for accessing specific metrics curves.\"\"\"\n return []\n\n @property\n def curves_results(self) -> List[List]:\n \"\"\"Return a list of curves for accessing specific metrics curves.\"\"\"\n return [\n [self.px, self.prec_values, \"Recall\", \"Precision\"],\n [self.px, self.f1_curve, \"Confidence\", \"F1\"],\n [self.px, self.p_curve, \"Confidence\", \"Precision\"],\n [self.px, self.r_curve, \"Confidence\", \"Recall\"],\n ]",
"chunk_type": "class",
"name": "Metric",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 771,
"end_line": 941,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "Class for computing evaluation metrics for Ultralytics YOLO models.\n\nAttributes:\n p (list): Precision for each class. Shape: (nc,).\n r (list): Recall for each class. Shape: (nc,).\n f1 (list): F1 score for each class. Shape: (nc,).\n all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).\n ap_class_index (list): Index of class for each AP score. Shape: (nc,).\n nc (int): Number of classes.\n\nMethods:\n ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or [].\n ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or [].\n mp(): Mean precision of all classes. Returns: Float.\n mr(): Mean recall of all classes. Returns: Float.\n map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float.\n map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float.\n map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float.\n mean_results(): Mean of results, returns mp, mr, map50, map.\n class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i].\n maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,).\n fitness(): Model fitness as a weighted combination of metrics. Returns: Float.\n update(results): Update metric attributes with new evaluation results.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re",
"SimpleClass"
],
"chunk_id": "class_Metric_62f28f7e"
},
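`Metric` is a plain container that gets fed the tail of the `ap_per_class` tuple (everything after the per-class TP/FP counts); a sketch of that wiring, reusing the same kind of synthetic inputs as above:

```python
import numpy as np

from ultralytics.utils.metrics import Metric, ap_per_class

rng = np.random.default_rng(0)
tp = rng.random((200, 10)) > 0.5
conf, pred_cls = rng.random(200), rng.integers(0, 3, 200)
target_cls = rng.integers(0, 3, 250)

m = Metric()
m.nc = 3
m.update(ap_per_class(tp, conf, pred_cls, target_cls)[2:])  # drop tp/fp counts, keep p/r/f1/AP/curves

print(m.mp, m.mr, m.map50, m.map)  # mean precision, mean recall, mAP@0.5, mAP@0.5:0.95
print(m.fitness())                 # 0.1 * mAP@0.5 + 0.9 * mAP@0.5:0.95
```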
{
"content": "class DetMetrics(SimpleClass, DataExportMixin):\n \"\"\"\n Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).\n\n Attributes:\n names (Dict[int, str]): A dictionary of class names.\n box (Metric): An instance of the Metric class for storing detection results.\n speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.\n task (str): The task type, set to 'detect'.\n stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.\n nt_per_class: Number of targets per class.\n nt_per_image: Number of targets per image.\n \"\"\"\n\n def __init__(self, names: Dict[int, str] = {}) -> None:\n \"\"\"\n Initialize a DetMetrics instance with a save directory, plot flag, and class names.\n\n Args:\n names (Dict[int, str], optional): Dictionary of class names.\n \"\"\"\n self.names = names\n self.box = Metric()\n self.speed = {\"preprocess\": 0.0, \"inference\": 0.0, \"loss\": 0.0, \"postprocess\": 0.0}\n self.task = \"detect\"\n self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])\n self.nt_per_class = None\n self.nt_per_image = None\n\n def update_stats(self, stat: Dict[str, Any]) -> None:\n \"\"\"\n Update statistics by appending new values to existing stat collections.\n\n Args:\n stat (Dict[str, any]): Dictionary containing new statistical values to append.\n Keys should match existing keys in self.stats.\n \"\"\"\n for k in self.stats.keys():\n self.stats[k].append(stat[k])\n\n def process(self, save_dir: Path = Path(\".\"), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:\n \"\"\"\n Process predicted results for object detection and update metrics.\n\n Args:\n save_dir (Path): Directory to save plots. Defaults to Path(\".\").\n plot (bool): Whether to plot precision-recall curves. Defaults to False.\n on_plot (callable, optional): Function to call after plots are generated. 
Defaults to None.\n\n Returns:\n (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.\n \"\"\"\n stats = {k: np.concatenate(v, 0) for k, v in self.stats.items()} # to numpy\n if len(stats) == 0:\n return stats\n results = ap_per_class(\n stats[\"tp\"],\n stats[\"conf\"],\n stats[\"pred_cls\"],\n stats[\"target_cls\"],\n plot=plot,\n save_dir=save_dir,\n names=self.names,\n on_plot=on_plot,\n prefix=\"Box\",\n )[2:]\n self.box.nc = len(self.names)\n self.box.update(results)\n self.nt_per_class = np.bincount(stats[\"target_cls\"].astype(int), minlength=len(self.names))\n self.nt_per_image = np.bincount(stats[\"target_img\"].astype(int), minlength=len(self.names))\n return stats\n\n def clear_stats(self):\n \"\"\"Clear the stored statistics.\"\"\"\n for v in self.stats.values():\n v.clear()\n\n @property\n def keys(self) -> List[str]:\n \"\"\"Return a list of keys for accessing specific metrics.\"\"\"\n return [\"metrics/precision(B)\", \"metrics/recall(B)\", \"metrics/mAP50(B)\", \"metrics/mAP50-95(B)\"]\n\n def mean_results(self) -> List[float]:\n \"\"\"Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95.\"\"\"\n return self.box.mean_results()\n\n def class_result(self, i: int) -> Tuple[float, float, float, float]:\n \"\"\"Return the result of evaluating the performance of an object detection model on a specific class.\"\"\"\n return self.box.class_result(i)\n\n @property\n def maps(self) -> np.ndarray:\n \"\"\"Return mean Average Precision (mAP) scores per class.\"\"\"\n return self.box.maps\n\n @property\n def fitness(self) -> float:\n \"\"\"Return the fitness of box object.\"\"\"\n return self.box.fitness()\n\n @property\n def ap_class_index(self) -> List:\n \"\"\"Return the average precision index per class.\"\"\"\n return self.box.ap_class_index\n\n @property\n def results_dict(self) -> Dict[str, float]:\n \"\"\"Return dictionary of computed performance metrics and statistics.\"\"\"\n return dict(zip(self.keys + [\"fitness\"], self.mean_results() + [self.fitness]))\n\n @property\n def curves(self) -> List[str]:\n \"\"\"Return a list of curves for accessing specific metrics curves.\"\"\"\n return [\"Precision-Recall(B)\", \"F1-Confidence(B)\", \"Precision-Confidence(B)\", \"Recall-Confidence(B)\"]\n\n @property\n def curves_results(self) -> List[List]:\n \"\"\"Return dictionary of computed performance metrics and statistics.\"\"\"\n return self.box.curves_results\n\n def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Any]]:\n \"\"\"\n Generate a summarized representation of per-class detection metrics as a list of dictionaries. 
Includes shared\n scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.\n\n Args:\n normalize (bool): For Detect metrics, everything is normalized by default [0-1].\n decimals (int): Number of decimal places to round the metrics values to.\n\n Returns:\n (List[Dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values.\n\n Examples:\n >>> results = model.val(data=\"coco8.yaml\")\n >>> detection_summary = results.summary()\n >>> print(detection_summary)\n \"\"\"\n per_class = {\n \"Box-P\": self.box.p,\n \"Box-R\": self.box.r,\n \"Box-F1\": self.box.f1,\n }\n return [\n {\n \"Class\": self.names[self.ap_class_index[i]],\n \"Images\": self.nt_per_image[self.ap_class_index[i]],\n \"Instances\": self.nt_per_class[self.ap_class_index[i]],\n **{k: round(v[i], decimals) for k, v in per_class.items()},\n \"mAP50\": round(self.class_result(i)[2], decimals),\n \"mAP50-95\": round(self.class_result(i)[3], decimals),\n }\n for i in range(len(per_class[\"Box-P\"]))\n ]",
"chunk_type": "class",
"name": "DetMetrics",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 944,
"end_line": 1096,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).\n\nAttributes:\n names (Dict[int, str]): A dictionary of class names.\n box (Metric): An instance of the Metric class for storing detection results.\n speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.\n task (str): The task type, set to 'detect'.\n stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.\n nt_per_class: Number of targets per class.\n nt_per_image: Number of targets per image.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re",
"SimpleClass",
"DataExportMixin"
],
"chunk_id": "class_DetMetrics_d7bfdbda"
},
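`DetMetrics` accumulates per-batch stats and defers the math to `ap_per_class`; a sketch with one synthetic batch (the `target_img` entry stands for the classes present in the image, as the validators supply it; all values are placeholders):

```python
import numpy as np

from ultralytics.utils.metrics import DetMetrics

rng = np.random.default_rng(0)
dm = DetMetrics(names={0: "person", 1: "car"})

dm.update_stats(
    dict(
        tp=rng.random((50, 10)) > 0.5,   # correctness at the 10 IoU thresholds
        conf=rng.random(50),
        pred_cls=rng.integers(0, 2, 50),
        target_cls=rng.integers(0, 2, 60),
        target_img=np.array([0, 1]),     # classes seen in this image (placeholder)
    )
)
dm.process(plot=False)
print(dm.results_dict)                   # precision(B), recall(B), mAP50(B), mAP50-95(B), fitness
print(dm.summary())                      # one dict per class
```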
{
"content": "class SegmentMetrics(DetMetrics):\n \"\"\"\n Calculate and aggregate detection and segmentation metrics over a given set of classes.\n\n Attributes:\n names (Dict[int, str]): Dictionary of class names.\n box (Metric): An instance of the Metric class for storing detection results.\n seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.\n speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.\n task (str): The task type, set to 'segment'.\n stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.\n nt_per_class: Number of targets per class.\n nt_per_image: Number of targets per image.\n \"\"\"\n\n def __init__(self, names: Dict[int, str] = {}) -> None:\n \"\"\"\n Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.\n\n Args:\n names (Dict[int, str], optional): Dictionary of class names.\n \"\"\"\n DetMetrics.__init__(self, names)\n self.seg = Metric()\n self.task = \"segment\"\n self.stats[\"tp_m\"] = [] # add additional stats for masks\n\n def process(self, save_dir: Path = Path(\".\"), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:\n \"\"\"\n Process the detection and segmentation metrics over the given set of predictions.\n\n Args:\n save_dir (Path): Directory to save plots. Defaults to Path(\".\").\n plot (bool): Whether to plot precision-recall curves. Defaults to False.\n on_plot (callable, optional): Function to call after plots are generated. Defaults to None.\n\n Returns:\n (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.\n \"\"\"\n stats = DetMetrics.process(self, save_dir, plot, on_plot=on_plot) # process box stats\n results_mask = ap_per_class(\n stats[\"tp_m\"],\n stats[\"conf\"],\n stats[\"pred_cls\"],\n stats[\"target_cls\"],\n plot=plot,\n on_plot=on_plot,\n save_dir=save_dir,\n names=self.names,\n prefix=\"Mask\",\n )[2:]\n self.seg.nc = len(self.names)\n self.seg.update(results_mask)\n return stats\n\n @property\n def keys(self) -> List[str]:\n \"\"\"Return a list of keys for accessing metrics.\"\"\"\n return DetMetrics.keys.fget(self) + [\n \"metrics/precision(M)\",\n \"metrics/recall(M)\",\n \"metrics/mAP50(M)\",\n \"metrics/mAP50-95(M)\",\n ]\n\n def mean_results(self) -> List[float]:\n \"\"\"Return the mean metrics for bounding box and segmentation results.\"\"\"\n return DetMetrics.mean_results(self) + self.seg.mean_results()\n\n def class_result(self, i: int) -> List[float]:\n \"\"\"Return classification results for a specified class index.\"\"\"\n return DetMetrics.class_result(self, i) + self.seg.class_result(i)\n\n @property\n def maps(self) -> np.ndarray:\n \"\"\"Return mAP scores for object detection and semantic segmentation models.\"\"\"\n return DetMetrics.maps.fget(self) + self.seg.maps\n\n @property\n def fitness(self) -> float:\n \"\"\"Return the fitness score for both segmentation and bounding box models.\"\"\"\n return self.seg.fitness() + DetMetrics.fitness.fget(self)\n\n @property\n def curves(self) -> List[str]:\n \"\"\"Return a list of curves for accessing specific metrics curves.\"\"\"\n return DetMetrics.curves.fget(self) + [\n \"Precision-Recall(M)\",\n \"F1-Confidence(M)\",\n \"Precision-Confidence(M)\",\n \"Recall-Confidence(M)\",\n ]\n\n @property\n def curves_results(self) -> List[List]:\n \"\"\"Return dictionary of computed performance metrics and statistics.\"\"\"\n return 
DetMetrics.curves_results.fget(self) + self.seg.curves_results\n\n def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Any]]:\n \"\"\"\n Generate a summarized representation of per-class segmentation metrics as a list of dictionaries. Includes both\n box and mask scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.\n\n Args:\n normalize (bool): For Segment metrics, everything is normalized by default [0-1].\n decimals (int): Number of decimal places to round the metrics values to.\n\n Returns:\n (List[Dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values.\n\n Examples:\n >>> results = model.val(data=\"coco8-seg.yaml\")\n >>> seg_summary = results.summary(decimals=4)\n >>> print(seg_summary)\n \"\"\"\n per_class = {\n \"Mask-P\": self.seg.p,\n \"Mask-R\": self.seg.r,\n \"Mask-F1\": self.seg.f1,\n }\n summary = DetMetrics.summary(self, normalize, decimals) # get box summary\n for i, s in enumerate(summary):\n s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}})\n return summary",
"chunk_type": "class",
"name": "SegmentMetrics",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 1099,
"end_line": 1222,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": "Calculate and aggregate detection and segmentation metrics over a given set of classes.\n\nAttributes:\n names (Dict[int, str]): Dictionary of class names.\n box (Metric): An instance of the Metric class for storing detection results.\n seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.\n speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.\n task (str): The task type, set to 'segment'.\n stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.\n nt_per_class: Number of targets per class.\n nt_per_image: Number of targets per image.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re",
"DetMetrics"
],
"chunk_id": "class_SegmentMetrics_97d3182e"
},
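`SegmentMetrics` only adds a `tp_m` (mask) stream on top of the box stream; the same kind of sketch with that one extra key, again with placeholder data:

```python
import numpy as np

from ultralytics.utils.metrics import SegmentMetrics

rng = np.random.default_rng(0)
sm = SegmentMetrics(names={0: "person", 1: "car"})
sm.update_stats(
    dict(
        tp=rng.random((50, 10)) > 0.5,    # box matches
        tp_m=rng.random((50, 10)) > 0.5,  # mask matches
        conf=rng.random(50),
        pred_cls=rng.integers(0, 2, 50),
        target_cls=rng.integers(0, 2, 60),
        target_img=np.array([0, 1]),      # placeholder
    )
)
sm.process(plot=False)
print(sm.keys)                            # box (B) keys followed by mask (M) keys
```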
{
"content": "class PoseMetrics(DetMetrics):\n \"\"\"\n Calculate and aggregate detection and pose metrics over a given set of classes.\n\n Attributes:\n names (Dict[int, str]): Dictionary of class names.\n pose (Metric): An instance of the Metric class to calculate pose metrics.\n box (Metric): An instance of the Metric class for storing detection results.\n speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.\n task (str): The task type, set to 'pose'.\n stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.\n nt_per_class: Number of targets per class.\n nt_per_image: Number of targets per image.\n\n Methods:\n process(tp_m, tp_b, conf, pred_cls, target_cls): Process metrics over the given set of predictions.\n mean_results(): Return the mean of the detection and segmentation metrics over all the classes.\n class_result(i): Return the detection and segmentation metrics of class `i`.\n maps: Return the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.\n fitness: Return the fitness scores, which are a single weighted combination of metrics.\n ap_class_index: Return the list of indices of classes used to compute Average Precision (AP).\n results_dict: Return the dictionary containing all the detection and segmentation metrics and fitness score.\n \"\"\"\n\n def __init__(self, names: Dict[int, str] = {}) -> None:\n \"\"\"\n Initialize the PoseMetrics class with directory path, class names, and plotting options.\n\n Args:\n names (Dict[int, str], optional): Dictionary of class names.\n \"\"\"\n super().__init__(names)\n self.pose = Metric()\n self.task = \"pose\"\n self.stats[\"tp_p\"] = [] # add additional stats for pose\n\n def process(self, save_dir: Path = Path(\".\"), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:\n \"\"\"\n Process the detection and pose metrics over the given set of predictions.\n\n Args:\n save_dir (Path): Directory to save plots. Defaults to Path(\".\").\n plot (bool): Whether to plot precision-recall curves. 
Defaults to False.\n on_plot (callable, optional): Function to call after plots are generated.\n\n Returns:\n (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.\n \"\"\"\n stats = DetMetrics.process(self, save_dir, plot, on_plot=on_plot) # process box stats\n results_pose = ap_per_class(\n stats[\"tp_p\"],\n stats[\"conf\"],\n stats[\"pred_cls\"],\n stats[\"target_cls\"],\n plot=plot,\n on_plot=on_plot,\n save_dir=save_dir,\n names=self.names,\n prefix=\"Pose\",\n )[2:]\n self.pose.nc = len(self.names)\n self.pose.update(results_pose)\n return stats\n\n @property\n def keys(self) -> List[str]:\n \"\"\"Return list of evaluation metric keys.\"\"\"\n return DetMetrics.keys.fget(self) + [\n \"metrics/precision(P)\",\n \"metrics/recall(P)\",\n \"metrics/mAP50(P)\",\n \"metrics/mAP50-95(P)\",\n ]\n\n def mean_results(self) -> List[float]:\n \"\"\"Return the mean results of box and pose.\"\"\"\n return DetMetrics.mean_results(self) + self.pose.mean_results()\n\n def class_result(self, i: int) -> List[float]:\n \"\"\"Return the class-wise detection results for a specific class i.\"\"\"\n return DetMetrics.class_result(self, i) + self.pose.class_result(i)\n\n @property\n def maps(self) -> np.ndarray:\n \"\"\"Return the mean average precision (mAP) per class for both box and pose detections.\"\"\"\n return DetMetrics.maps.fget(self) + self.pose.maps\n\n @property\n def fitness(self) -> float:\n \"\"\"Return combined fitness score for pose and box detection.\"\"\"\n return self.pose.fitness() + DetMetrics.fitness.fget(self)\n\n @property\n def curves(self) -> List[str]:\n \"\"\"Return a list of curves for accessing specific metrics curves.\"\"\"\n return DetMetrics.curves.fget(self) + [\n \"Precision-Recall(B)\",\n \"F1-Confidence(B)\",\n \"Precision-Confidence(B)\",\n \"Recall-Confidence(B)\",\n \"Precision-Recall(P)\",\n \"F1-Confidence(P)\",\n \"Precision-Confidence(P)\",\n \"Recall-Confidence(P)\",\n ]\n\n @property\n def curves_results(self) -> List[List]:\n \"\"\"Return dictionary of computed performance metrics and statistics.\"\"\"\n return DetMetrics.curves_results.fget(self) + self.pose.curves_results\n\n def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Any]]:\n \"\"\"\n Generate a summarized representation of per-class pose metrics as a list of dictionaries. Includes both box and\n pose scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.\n\n Args:\n normalize (bool): For Pose metrics, everything is normalized by default [0-1].\n decimals (int): Number of decimal places to round the metrics values to.\n\n Returns:\n (List[Dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values.\n\n Examples:\n >>> results = model.val(data=\"coco8-pose.yaml\")\n >>> pose_summary = results.summary(decimals=4)\n >>> print(pose_summary)\n \"\"\"\n per_class = {\n \"Pose-P\": self.pose.p,\n \"Pose-R\": self.pose.r,\n \"Pose-F1\": self.pose.f1,\n }\n summary = DetMetrics.summary(self, normalize, decimals) # get box summary\n for i, s in enumerate(summary):\n s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}})\n return summary",
"chunk_type": "class",
"name": "PoseMetrics",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 1225,
"end_line": 1361,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": "Calculate and aggregate detection and pose metrics over a given set of classes.\n\nAttributes:\n names (Dict[int, str]): Dictionary of class names.\n pose (Metric): An instance of the Metric class to calculate pose metrics.\n box (Metric): An instance of the Metric class for storing detection results.\n speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.\n task (str): The task type, set to 'pose'.\n stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.\n nt_per_class: Number of targets per class.\n nt_per_image: Number of targets per image.\n\nMethods:\n process(tp_m, tp_b, conf, pred_cls, target_cls): Process metrics over the given set of predictions.\n mean_results(): Return the mean of the detection and segmentation metrics over all the classes.\n class_result(i): Return the detection and segmentation metrics of class `i`.\n maps: Return the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.\n fitness: Return the fitness scores, which are a single weighted combination of metrics.\n ap_class_index: Return the list of indices of classes used to compute Average Precision (AP).\n results_dict: Return the dictionary containing all the detection and segmentation metrics and fitness score.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re",
"DetMetrics"
],
"chunk_id": "class_PoseMetrics_25476aa9"
},
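`PoseMetrics` mirrors the segmentation subclass with a `tp_p` (keypoint) stream instead of `tp_m`; only the differing key is shown, with made-up single-class data:

```python
import numpy as np

from ultralytics.utils.metrics import PoseMetrics

rng = np.random.default_rng(0)
pm = PoseMetrics(names={0: "person"})
pm.update_stats(
    dict(
        tp=rng.random((30, 10)) > 0.5,    # box matches
        tp_p=rng.random((30, 10)) > 0.5,  # pose (keypoint) matches
        conf=rng.random(30),
        pred_cls=np.zeros(30, dtype=int),
        target_cls=np.zeros(40, dtype=int),
        target_img=np.array([0]),         # placeholder
    )
)
pm.process(plot=False)
print(pm.mean_results())                  # box P/R/mAP50/mAP50-95 followed by the pose equivalents
```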
{
"content": "class ClassifyMetrics(SimpleClass, DataExportMixin):\n \"\"\"\n Class for computing classification metrics including top-1 and top-5 accuracy.\n\n Attributes:\n top1 (float): The top-1 accuracy.\n top5 (float): The top-5 accuracy.\n speed (dict): A dictionary containing the time taken for each step in the pipeline.\n task (str): The task type, set to 'classify'.\n \"\"\"\n\n def __init__(self) -> None:\n \"\"\"Initialize a ClassifyMetrics instance.\"\"\"\n self.top1 = 0\n self.top5 = 0\n self.speed = {\"preprocess\": 0.0, \"inference\": 0.0, \"loss\": 0.0, \"postprocess\": 0.0}\n self.task = \"classify\"\n\n def process(self, targets: torch.Tensor, pred: torch.Tensor):\n \"\"\"\n Process target classes and predicted classes to compute metrics.\n\n Args:\n targets (torch.Tensor): Target classes.\n pred (torch.Tensor): Predicted classes.\n \"\"\"\n pred, targets = torch.cat(pred), torch.cat(targets)\n correct = (targets[:, None] == pred).float()\n acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy\n self.top1, self.top5 = acc.mean(0).tolist()\n\n @property\n def fitness(self) -> float:\n \"\"\"Return mean of top-1 and top-5 accuracies as fitness score.\"\"\"\n return (self.top1 + self.top5) / 2\n\n @property\n def results_dict(self) -> Dict[str, float]:\n \"\"\"Return a dictionary with model's performance metrics and fitness score.\"\"\"\n return dict(zip(self.keys + [\"fitness\"], [self.top1, self.top5, self.fitness]))\n\n @property\n def keys(self) -> List[str]:\n \"\"\"Return a list of keys for the results_dict property.\"\"\"\n return [\"metrics/accuracy_top1\", \"metrics/accuracy_top5\"]\n\n @property\n def curves(self) -> List:\n \"\"\"Return a list of curves for accessing specific metrics curves.\"\"\"\n return []\n\n @property\n def curves_results(self) -> List:\n \"\"\"Return a list of curves for accessing specific metrics curves.\"\"\"\n return []\n\n def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, float]]:\n \"\"\"\n Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy).\n\n Args:\n normalize (bool): For Classify metrics, everything is normalized by default [0-1].\n decimals (int): Number of decimal places to round the metrics values to.\n\n Returns:\n (List[Dict[str, float]]): A list with one dictionary containing Top-1 and Top-5 classification accuracy.\n\n Examples:\n >>> results = model.val(data=\"imagenet10\")\n >>> classify_summary = results.summary(decimals=4)\n >>> print(classify_summary)\n \"\"\"\n return [{\"top1_acc\": round(self.top1, decimals), \"top5_acc\": round(self.top5, decimals)}]",
"chunk_type": "class",
"name": "ClassifyMetrics",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 1364,
"end_line": 1436,
"start_col": 0,
"end_col": 97,
"parent_name": null,
"docstring": "Class for computing classification metrics including top-1 and top-5 accuracy.\n\nAttributes:\n top1 (float): The top-1 accuracy.\n top5 (float): The top-5 accuracy.\n speed (dict): A dictionary containing the time taken for each step in the pipeline.\n task (str): The task type, set to 'classify'.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re",
"SimpleClass",
"DataExportMixin"
],
"chunk_id": "class_ClassifyMetrics_2b5ec4ae"
},
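`ClassifyMetrics.process` expects lists of per-batch tensors: integer targets and the top-5 class indices per sample; a tiny hand-made example with placeholder predictions:

```python
import torch

from ultralytics.utils.metrics import ClassifyMetrics

cm = ClassifyMetrics()
targets = [torch.tensor([0, 1, 2]), torch.tensor([2, 0])]  # per-batch ground-truth classes
pred = [
    torch.tensor([[0, 3, 4, 1, 2], [1, 0, 2, 3, 4], [3, 2, 4, 0, 1]]),  # top-5 indices per sample
    torch.tensor([[2, 1, 0, 3, 4], [4, 3, 2, 1, 0]]),
]

cm.process(targets, pred)
print(cm.top1, cm.top5, cm.fitness)  # 0.6, 1.0, 0.8 for this toy batch
```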
{
"content": "class OBBMetrics(DetMetrics):\n \"\"\"\n Metrics for evaluating oriented bounding box (OBB) detection.\n\n Attributes:\n names (Dict[int, str]): Dictionary of class names.\n box (Metric): An instance of the Metric class for storing detection results.\n speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.\n task (str): The task type, set to 'obb'.\n stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.\n nt_per_class: Number of targets per class.\n nt_per_image: Number of targets per image.\n\n References:\n https://arxiv.org/pdf/2106.06072.pdf\n \"\"\"\n\n def __init__(self, names: Dict[int, str] = {}) -> None:\n \"\"\"\n Initialize an OBBMetrics instance with directory, plotting, and class names.\n\n Args:\n names (Dict[int, str], optional): Dictionary of class names.\n \"\"\"\n DetMetrics.__init__(self, names)\n # TODO: probably remove task as well\n self.task = \"obb\"",
"chunk_type": "class",
"name": "OBBMetrics",
"file_path": "ultralytics\\ultralytics\\utils\\metrics.py",
"start_line": 1439,
"end_line": 1465,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": "Metrics for evaluating oriented bounding box (OBB) detection.\n\nAttributes:\n names (Dict[int, str]): Dictionary of class names.\n box (Metric): An instance of the Metric class for storing detection results.\n speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.\n task (str): The task type, set to 'obb'.\n stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.\n nt_per_class: Number of targets per class.\n nt_per_image: Number of targets per image.\n\nReferences:\n https://arxiv.org/pdf/2106.06072.pdf",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.utils.LOGGER",
"ultralytics.utils.DataExportMixin",
"ultralytics.utils.SimpleClass",
"ultralytics.utils.TryExcept",
"ultralytics.utils.checks",
"ultralytics.utils.plt_settings",
"matplotlib.pyplot",
"matplotlib.pyplot",
"matplotlib.pyplot",
"re",
"DetMetrics"
],
"chunk_id": "class_OBBMetrics_cccc183b"
},
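`OBBMetrics` changes nothing beyond setting `task = "obb"`; rotated-box matching (e.g. via `batch_probiou`) happens upstream in the validator, so usage is identical to `DetMetrics`. A very short sketch with hypothetical DOTA-style class names:

```python
from ultralytics.utils.metrics import OBBMetrics

om = OBBMetrics(names={0: "plane", 1: "ship"})  # hypothetical class names
print(om.task)  # "obb"
print(om.keys)  # same box-metric keys as DetMetrics
```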
{
"content": "import contextlib",
"chunk_type": "import",
"name": "contextlib",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 17,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_contextlib_7046437c"
},
{
"content": "import math",
"chunk_type": "import",
"name": "math",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_math_6b55b476"
},
{
"content": "import re",
"chunk_type": "import",
"name": "re",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_re_30a394c5"
},
{
"content": "import time",
"chunk_type": "import",
"name": "time",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_time_55bf7bb5"
},
{
"content": "from typing import Optional",
"chunk_type": "import",
"name": "Optional",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 27,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Optional_fac61e14"
},
{
"content": "import cv2",
"chunk_type": "import",
"name": "cv2",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 10,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_cv2_0fea62e2"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_3ae73d6e"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_790c7dfa"
},
{
"content": "import torch.nn.functional as F",
"chunk_type": "import",
"name": "torch.nn.functional",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn.functional_4e111fa3"
},
{
"content": "from ultralytics.utils import LOGGER",
"chunk_type": "import",
"name": "LOGGER",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 36,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER_131edebf"
},
{
"content": "from ultralytics.utils.metrics import batch_probiou",
"chunk_type": "import",
"name": "batch_probiou",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 51,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_batch_probiou_5553b42e"
},
{
"content": "class Profile(contextlib.ContextDecorator):\n \"\"\"\n Ultralytics Profile class for timing code execution.\n\n Use as a decorator with @Profile() or as a context manager with 'with Profile():'. Provides accurate timing\n measurements with CUDA synchronization support for GPU operations.\n\n Attributes:\n t (float): Accumulated time in seconds.\n device (torch.device): Device used for model inference.\n cuda (bool): Whether CUDA is being used for timing synchronization.\n\n Examples:\n Use as a context manager to time code execution\n >>> with Profile(device=device) as dt:\n ... pass # slow operation here\n >>> print(dt) # prints \"Elapsed time is 9.5367431640625e-07 s\"\n\n Use as a decorator to time function execution\n >>> @Profile()\n ... def slow_function():\n ... time.sleep(0.1)\n \"\"\"\n\n def __init__(self, t: float = 0.0, device: Optional[torch.device] = None):\n \"\"\"\n Initialize the Profile class.\n\n Args:\n t (float): Initial accumulated time in seconds.\n device (torch.device, optional): Device used for model inference to enable CUDA synchronization.\n \"\"\"\n self.t = t\n self.device = device\n self.cuda = bool(device and str(device).startswith(\"cuda\"))\n\n def __enter__(self):\n \"\"\"Start timing.\"\"\"\n self.start = self.time()\n return self\n\n def __exit__(self, type, value, traceback): # noqa\n \"\"\"Stop timing.\"\"\"\n self.dt = self.time() - self.start # delta-time\n self.t += self.dt # accumulate dt\n\n def __str__(self):\n \"\"\"Return a human-readable string representing the accumulated elapsed time.\"\"\"\n return f\"Elapsed time is {self.t} s\"\n\n def time(self):\n \"\"\"Get current time with CUDA synchronization if applicable.\"\"\"\n if self.cuda:\n torch.cuda.synchronize(self.device)\n return time.perf_counter()",
"chunk_type": "class",
"name": "Profile",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 18,
"end_line": 72,
"start_col": 0,
"end_col": 34,
"parent_name": null,
"docstring": "Ultralytics Profile class for timing code execution.\n\nUse as a decorator with @Profile() or as a context manager with 'with Profile():'. Provides accurate timing\nmeasurements with CUDA synchronization support for GPU operations.\n\nAttributes:\n t (float): Accumulated time in seconds.\n device (torch.device): Device used for model inference.\n cuda (bool): Whether CUDA is being used for timing synchronization.\n\nExamples:\n Use as a context manager to time code execution\n >>> with Profile(device=device) as dt:\n ... pass # slow operation here\n >>> print(dt) # prints \"Elapsed time is 9.5367431640625e-07 s\"\n\n Use as a decorator to time function execution\n >>> @Profile()\n ... def slow_function():\n ... time.sleep(0.1)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment",
"contextlib.ContextDecorator"
],
"chunk_id": "class_Profile_6b5e8a32"
},
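`Profile` accumulates elapsed time across uses and works both as a context manager and, via `ContextDecorator`, as a decorator; a small CPU-only sketch:

```python
import time

from ultralytics.utils.ops import Profile

dt = Profile()             # device=None, so plain perf_counter timing (no CUDA sync)
for _ in range(3):
    with dt:
        time.sleep(0.01)   # stand-in for preprocess / inference / postprocess work
print(dt)                  # "Elapsed time is ... s", roughly 0.03 s accumulated

@Profile()
def slow_function():
    time.sleep(0.01)

slow_function()            # timed transparently by the decorator form
```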
{
"content": "def segment2box(segment, width: int = 640, height: int = 640):\n \"\"\"\n Convert segment coordinates to bounding box coordinates.\n\n Converts a single segment label to a box label by finding the minimum and maximum x and y coordinates.\n Applies inside-image constraint and clips coordinates when necessary.\n\n Args:\n segment (torch.Tensor): Segment coordinates in format (N, 2) where N is number of points.\n width (int): Width of the image in pixels.\n height (int): Height of the image in pixels.\n\n Returns:\n (np.ndarray): Bounding box coordinates in xyxy format [x1, y1, x2, y2].\n \"\"\"\n x, y = segment.T # segment xy\n # Clip coordinates if 3 out of 4 sides are outside the image\n if np.array([x.min() < 0, y.min() < 0, x.max() > width, y.max() > height]).sum() >= 3:\n x = x.clip(0, width)\n y = y.clip(0, height)\n inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)\n x = x[inside]\n y = y[inside]\n return (\n np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)\n if any(x)\n else np.zeros(4, dtype=segment.dtype)\n ) # xyxy",
"chunk_type": "function",
"name": "segment2box",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 75,
"end_line": 102,
"start_col": 0,
"end_col": 5,
"parent_name": null,
"docstring": "Convert segment coordinates to bounding box coordinates.\n\nConverts a single segment label to a box label by finding the minimum and maximum x and y coordinates.\nApplies inside-image constraint and clips coordinates when necessary.\n\nArgs:\n segment (torch.Tensor): Segment coordinates in format (N, 2) where N is number of points.\n width (int): Width of the image in pixels.\n height (int): Height of the image in pixels.\n\nReturns:\n (np.ndarray): Bounding box coordinates in xyxy format [x1, y1, x2, y2].",
"parameters": [
"segment",
"width: int",
"height: int"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_segment2box_3f4ab63a"
},
{
"content": "def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding: bool = True, xywh: bool = False):\n \"\"\"\n Rescale bounding boxes from one image shape to another.\n\n Rescales bounding boxes from img1_shape to img0_shape, accounting for padding and aspect ratio changes.\n Supports both xyxy and xywh box formats.\n\n Args:\n img1_shape (tuple): Shape of the source image (height, width).\n boxes (torch.Tensor): Bounding boxes to rescale in format (N, 4).\n img0_shape (tuple): Shape of the target image (height, width).\n ratio_pad (tuple, optional): Tuple of (ratio, pad) for scaling. If None, calculated from image shapes.\n padding (bool): Whether boxes are based on YOLO-style augmented images with padding.\n xywh (bool): Whether box format is xywh (True) or xyxy (False).\n\n Returns:\n (torch.Tensor): Rescaled bounding boxes in the same format as input.\n \"\"\"\n if ratio_pad is None: # calculate from img0_shape\n gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new\n pad = (\n round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),\n round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),\n ) # wh padding\n else:\n gain = ratio_pad[0][0]\n pad = ratio_pad[1]\n\n if padding:\n boxes[..., 0] -= pad[0] # x padding\n boxes[..., 1] -= pad[1] # y padding\n if not xywh:\n boxes[..., 2] -= pad[0] # x padding\n boxes[..., 3] -= pad[1] # y padding\n boxes[..., :4] /= gain\n return clip_boxes(boxes, img0_shape)",
"chunk_type": "function",
"name": "scale_boxes",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 105,
"end_line": 140,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": "Rescale bounding boxes from one image shape to another.\n\nRescales bounding boxes from img1_shape to img0_shape, accounting for padding and aspect ratio changes.\nSupports both xyxy and xywh box formats.\n\nArgs:\n img1_shape (tuple): Shape of the source image (height, width).\n boxes (torch.Tensor): Bounding boxes to rescale in format (N, 4).\n img0_shape (tuple): Shape of the target image (height, width).\n ratio_pad (tuple, optional): Tuple of (ratio, pad) for scaling. If None, calculated from image shapes.\n padding (bool): Whether boxes are based on YOLO-style augmented images with padding.\n xywh (bool): Whether box format is xywh (True) or xyxy (False).\n\nReturns:\n (torch.Tensor): Rescaled bounding boxes in the same format as input.",
"parameters": [
"img1_shape",
"boxes",
"img0_shape",
"ratio_pad",
"padding: bool",
"xywh: bool"
],
"return_type": null,
"decorators": [],
"complexity_score": 4,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_scale_boxes_e90f89fd"
},
{
"content": "def make_divisible(x: int, divisor):\n \"\"\"\n Return the nearest number that is divisible by the given divisor.\n\n Args:\n x (int): The number to make divisible.\n divisor (int | torch.Tensor): The divisor.\n\n Returns:\n (int): The nearest number divisible by the divisor.\n \"\"\"\n if isinstance(divisor, torch.Tensor):\n divisor = int(divisor.max()) # to int\n return math.ceil(x / divisor) * divisor",
"chunk_type": "function",
"name": "make_divisible",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 143,
"end_line": 156,
"start_col": 0,
"end_col": 43,
"parent_name": null,
"docstring": "Return the nearest number that is divisible by the given divisor.\n\nArgs:\n x (int): The number to make divisible.\n divisor (int | torch.Tensor): The divisor.\n\nReturns:\n (int): The nearest number divisible by the divisor.",
"parameters": [
"x: int",
"divisor"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_make_divisible_3cc87115"
},
{
"content": "def nms_rotated(boxes, scores, threshold: float = 0.45, use_triu: bool = True):\n \"\"\"\n Perform NMS on oriented bounding boxes using probiou and fast-nms.\n\n Args:\n boxes (torch.Tensor): Rotated bounding boxes with shape (N, 5) in xywhr format.\n scores (torch.Tensor): Confidence scores with shape (N,).\n threshold (float): IoU threshold for NMS.\n use_triu (bool): Whether to use torch.triu operator for upper triangular matrix operations.\n\n Returns:\n (torch.Tensor): Indices of boxes to keep after NMS.\n \"\"\"\n sorted_idx = torch.argsort(scores, descending=True)\n boxes = boxes[sorted_idx]\n ious = batch_probiou(boxes, boxes)\n if use_triu:\n ious = ious.triu_(diagonal=1)\n # NOTE: handle the case when len(boxes) hence exportable by eliminating if-else condition\n pick = torch.nonzero((ious >= threshold).sum(0) <= 0).squeeze_(-1)\n else:\n n = boxes.shape[0]\n row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n)\n col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1)\n upper_mask = row_idx < col_idx\n ious = ious * upper_mask\n # Zeroing these scores ensures the additional indices would not affect the final results\n scores[~((ious >= threshold).sum(0) <= 0)] = 0\n # NOTE: return indices with fixed length to avoid TFLite reshape error\n pick = torch.topk(scores, scores.shape[0]).indices\n return sorted_idx[pick]",
"chunk_type": "function",
"name": "nms_rotated",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 159,
"end_line": 189,
"start_col": 0,
"end_col": 27,
"parent_name": null,
"docstring": "Perform NMS on oriented bounding boxes using probiou and fast-nms.\n\nArgs:\n boxes (torch.Tensor): Rotated bounding boxes with shape (N, 5) in xywhr format.\n scores (torch.Tensor): Confidence scores with shape (N,).\n threshold (float): IoU threshold for NMS.\n use_triu (bool): Whether to use torch.triu operator for upper triangular matrix operations.\n\nReturns:\n (torch.Tensor): Indices of boxes to keep after NMS.",
"parameters": [
"boxes",
"scores",
"threshold: float",
"use_triu: bool"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_nms_rotated_4d120576"
},
{
"content": "def non_max_suppression(\n prediction,\n conf_thres: float = 0.25,\n iou_thres: float = 0.45,\n classes=None,\n agnostic: bool = False,\n multi_label: bool = False,\n labels=(),\n max_det: int = 300,\n nc: int = 0, # number of classes (optional)\n max_time_img: float = 0.05,\n max_nms: int = 30000,\n max_wh: int = 7680,\n in_place: bool = True,\n rotated: bool = False,\n end2end: bool = False,\n return_idxs: bool = False,\n):\n \"\"\"\n Perform non-maximum suppression (NMS) on prediction results.\n\n Applies NMS to filter overlapping bounding boxes based on confidence and IoU thresholds. Supports multiple\n detection formats including standard boxes, rotated boxes, and masks.\n\n Args:\n prediction (torch.Tensor): Predictions with shape (batch_size, num_classes + 4 + num_masks, num_boxes)\n containing boxes, classes, and optional masks.\n conf_thres (float): Confidence threshold for filtering detections. Valid values are between 0.0 and 1.0.\n iou_thres (float): IoU threshold for NMS filtering. Valid values are between 0.0 and 1.0.\n classes (List[int], optional): List of class indices to consider. If None, all classes are considered.\n agnostic (bool): Whether to perform class-agnostic NMS.\n multi_label (bool): Whether each box can have multiple labels.\n labels (List[List[Union[int, float, torch.Tensor]]]): A priori labels for each image.\n max_det (int): Maximum number of detections to keep per image.\n nc (int): Number of classes. Indices after this are considered masks.\n max_time_img (float): Maximum time in seconds for processing one image.\n max_nms (int): Maximum number of boxes for torchvision.ops.nms().\n max_wh (int): Maximum box width and height in pixels.\n in_place (bool): Whether to modify the input prediction tensor in place.\n rotated (bool): Whether to handle Oriented Bounding Boxes (OBB).\n end2end (bool): Whether the model is end-to-end and doesn't require NMS.\n return_idxs (bool): Whether to return the indices of kept detections.\n\n Returns:\n output (List[torch.Tensor]): List of detections per image with shape (num_boxes, 6 + num_masks)\n containing (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).\n keepi (List[torch.Tensor]): Indices of kept detections if return_idxs=True.\n \"\"\"\n import torchvision # scope for faster 'import ultralytics'\n\n # Checks\n assert 0 <= conf_thres <= 1, f\"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0\"\n assert 0 <= iou_thres <= 1, f\"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0\"\n if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out)\n prediction = prediction[0] # select only inference output\n if classes is not None:\n classes = torch.tensor(classes, device=prediction.device)\n\n if prediction.shape[-1] == 6 or end2end: # end-to-end model (BNC, i.e. 1,300,6)\n output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]\n if classes is not None:\n output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]\n return output\n\n bs = prediction.shape[0] # batch size (BCN, i.e. 
1,84,6300)\n nc = nc or (prediction.shape[1] - 4) # number of classes\n extra = prediction.shape[1] - nc - 4 # number of extra info\n mi = 4 + nc # mask start index\n xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates\n xinds = torch.stack([torch.arange(len(i), device=prediction.device) for i in xc])[..., None] # to track idxs\n\n # Settings\n # min_wh = 2 # (pixels) minimum box width and height\n time_limit = 2.0 + max_time_img * bs # seconds to quit after\n multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)\n\n prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)\n if not rotated:\n if in_place:\n prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy\n else:\n prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1) # xywh to xyxy\n\n t = time.time()\n output = [torch.zeros((0, 6 + extra), device=prediction.device)] * bs\n keepi = [torch.zeros((0, 1), device=prediction.device)] * bs # to store the kept idxs\n for xi, (x, xk) in enumerate(zip(prediction, xinds)): # image index, (preds, preds indices)\n # Apply constraints\n # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height\n filt = xc[xi] # confidence\n x, xk = x[filt], xk[filt]\n\n # Cat apriori labels if autolabelling\n if labels and len(labels[xi]) and not rotated:\n lb = labels[xi]\n v = torch.zeros((len(lb), nc + extra + 4), device=x.device)\n v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box\n v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls\n x = torch.cat((x, v), 0)\n\n # If none remain process next image\n if not x.shape[0]:\n continue\n\n # Detections matrix nx6 (xyxy, conf, cls)\n box, cls, mask = x.split((4, nc, extra), 1)\n\n if multi_label:\n i, j = torch.where(cls > conf_thres)\n x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)\n xk = xk[i]\n else: # best class only\n conf, j = cls.max(1, keepdim=True)\n filt = conf.view(-1) > conf_thres\n x = torch.cat((box, conf, j.float(), mask), 1)[filt]\n xk = xk[filt]\n\n # Filter by class\n if classes is not None:\n filt = (x[:, 5:6] == classes).any(1)\n x, xk = x[filt], xk[filt]\n\n # Check shape\n n = x.shape[0] # number of boxes\n if not n: # no boxes\n continue\n if n > max_nms: # excess boxes\n filt = x[:, 4].argsort(descending=True)[:max_nms] # sort by confidence and remove excess boxes\n x, xk = x[filt], xk[filt]\n\n # Batched NMS\n c = x[:, 5:6] * (0 if agnostic else max_wh) # classes\n scores = x[:, 4] # scores\n if rotated:\n boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1) # xywhr\n i = nms_rotated(boxes, scores, iou_thres)\n else:\n boxes = x[:, :4] + c # boxes (offset by class)\n i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS\n i = i[:max_det] # limit detections\n\n output[xi], keepi[xi] = x[i], xk[i].reshape(-1)\n if (time.time() - t) > time_limit:\n LOGGER.warning(f\"NMS time limit {time_limit:.3f}s exceeded\")\n break # time limit exceeded\n\n return (output, keepi) if return_idxs else output",
"chunk_type": "function",
"name": "non_max_suppression",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 192,
"end_line": 338,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": "Perform non-maximum suppression (NMS) on prediction results.\n\nApplies NMS to filter overlapping bounding boxes based on confidence and IoU thresholds. Supports multiple\ndetection formats including standard boxes, rotated boxes, and masks.\n\nArgs:\n prediction (torch.Tensor): Predictions with shape (batch_size, num_classes + 4 + num_masks, num_boxes)\n containing boxes, classes, and optional masks.\n conf_thres (float): Confidence threshold for filtering detections. Valid values are between 0.0 and 1.0.\n iou_thres (float): IoU threshold for NMS filtering. Valid values are between 0.0 and 1.0.\n classes (List[int], optional): List of class indices to consider. If None, all classes are considered.\n agnostic (bool): Whether to perform class-agnostic NMS.\n multi_label (bool): Whether each box can have multiple labels.\n labels (List[List[Union[int, float, torch.Tensor]]]): A priori labels for each image.\n max_det (int): Maximum number of detections to keep per image.\n nc (int): Number of classes. Indices after this are considered masks.\n max_time_img (float): Maximum time in seconds for processing one image.\n max_nms (int): Maximum number of boxes for torchvision.ops.nms().\n max_wh (int): Maximum box width and height in pixels.\n in_place (bool): Whether to modify the input prediction tensor in place.\n rotated (bool): Whether to handle Oriented Bounding Boxes (OBB).\n end2end (bool): Whether the model is end-to-end and doesn't require NMS.\n return_idxs (bool): Whether to return the indices of kept detections.\n\nReturns:\n output (List[torch.Tensor]): List of detections per image with shape (num_boxes, 6 + num_masks)\n containing (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).\n keepi (List[torch.Tensor]): Indices of kept detections if return_idxs=True.",
"parameters": [
"prediction",
"conf_thres: float",
"iou_thres: float",
"classes",
"agnostic: bool",
"multi_label: bool",
"labels",
"max_det: int",
"nc: int",
"max_time_img: float",
"max_nms: int",
"max_wh: int",
"in_place: bool",
"rotated: bool",
"end2end: bool",
"return_idxs: bool"
],
"return_type": null,
"decorators": [],
"complexity_score": 19,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_non_max_suppression_2bb4cfbb"
},
{
"content": "def clip_boxes(boxes, shape):\n \"\"\"\n Clip bounding boxes to image boundaries.\n\n Args:\n boxes (torch.Tensor | np.ndarray): Bounding boxes to clip.\n shape (tuple): Image shape as (height, width).\n\n Returns:\n (torch.Tensor | np.ndarray): Clipped bounding boxes.\n \"\"\"\n if isinstance(boxes, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug)\n boxes[..., 0] = boxes[..., 0].clamp(0, shape[1]) # x1\n boxes[..., 1] = boxes[..., 1].clamp(0, shape[0]) # y1\n boxes[..., 2] = boxes[..., 2].clamp(0, shape[1]) # x2\n boxes[..., 3] = boxes[..., 3].clamp(0, shape[0]) # y2\n else: # np.array (faster grouped)\n boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2\n boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2\n return boxes",
"chunk_type": "function",
"name": "clip_boxes",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 341,
"end_line": 360,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "Clip bounding boxes to image boundaries.\n\nArgs:\n boxes (torch.Tensor | np.ndarray): Bounding boxes to clip.\n shape (tuple): Image shape as (height, width).\n\nReturns:\n (torch.Tensor | np.ndarray): Clipped bounding boxes.",
"parameters": [
"boxes",
"shape"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_clip_boxes_7ca129ed"
},
{
"content": "def clip_coords(coords, shape):\n \"\"\"\n Clip line coordinates to image boundaries.\n\n Args:\n coords (torch.Tensor | np.ndarray): Line coordinates to clip.\n shape (tuple): Image shape as (height, width).\n\n Returns:\n (torch.Tensor | np.ndarray): Clipped coordinates.\n \"\"\"\n if isinstance(coords, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug)\n coords[..., 0] = coords[..., 0].clamp(0, shape[1]) # x\n coords[..., 1] = coords[..., 1].clamp(0, shape[0]) # y\n else: # np.array (faster grouped)\n coords[..., 0] = coords[..., 0].clip(0, shape[1]) # x\n coords[..., 1] = coords[..., 1].clip(0, shape[0]) # y\n return coords",
"chunk_type": "function",
"name": "clip_coords",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 363,
"end_line": 380,
"start_col": 0,
"end_col": 17,
"parent_name": null,
"docstring": "Clip line coordinates to image boundaries.\n\nArgs:\n coords (torch.Tensor | np.ndarray): Line coordinates to clip.\n shape (tuple): Image shape as (height, width).\n\nReturns:\n (torch.Tensor | np.ndarray): Clipped coordinates.",
"parameters": [
"coords",
"shape"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_clip_coords_a9c40ad7"
},
{
"content": "def scale_image(masks, im0_shape, ratio_pad=None):\n \"\"\"\n Rescale masks to original image size.\n\n Takes resized and padded masks and rescales them back to the original image dimensions, removing any padding\n that was applied during preprocessing.\n\n Args:\n masks (np.ndarray): Resized and padded masks with shape [H, W, N] or [H, W, 3].\n im0_shape (tuple): Original image shape as (height, width).\n ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).\n\n Returns:\n (np.ndarray): Rescaled masks with shape [H, W, N] matching original image dimensions.\n \"\"\"\n # Rescale coordinates (xyxy) from im1_shape to im0_shape\n im1_shape = masks.shape\n if im1_shape[:2] == im0_shape[:2]:\n return masks\n if ratio_pad is None: # calculate from im0_shape\n gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new\n pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding\n else:\n pad = ratio_pad[1]\n\n top, left = (int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1)))\n bottom, right = (\n im1_shape[0] - int(round(pad[1] + 0.1)),\n im1_shape[1] - int(round(pad[0] + 0.1)),\n )\n\n if len(masks.shape) < 2:\n raise ValueError(f'\"len of masks shape\" should be 2 or 3, but got {len(masks.shape)}')\n masks = masks[top:bottom, left:right]\n masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))\n if len(masks.shape) == 2:\n masks = masks[:, :, None]\n\n return masks",
"chunk_type": "function",
"name": "scale_image",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 383,
"end_line": 421,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "Rescale masks to original image size.\n\nTakes resized and padded masks and rescales them back to the original image dimensions, removing any padding\nthat was applied during preprocessing.\n\nArgs:\n masks (np.ndarray): Resized and padded masks with shape [H, W, N] or [H, W, 3].\n im0_shape (tuple): Original image shape as (height, width).\n ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).\n\nReturns:\n (np.ndarray): Rescaled masks with shape [H, W, N] matching original image dimensions.",
"parameters": [
"masks",
"im0_shape",
"ratio_pad"
],
"return_type": null,
"decorators": [],
"complexity_score": 5,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_scale_image_d3e50193"
},
{
"content": "def xyxy2xywh(x):\n \"\"\"\n Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the\n top-left corner and (x2, y2) is the bottom-right corner.\n\n Args:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.\n\n Returns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in (x, y, width, height) format.\n \"\"\"\n assert x.shape[-1] == 4, f\"input shape last dimension expected 4 but input shape is {x.shape}\"\n y = empty_like(x) # faster than clone/copy\n y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center\n y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center\n y[..., 2] = x[..., 2] - x[..., 0] # width\n y[..., 3] = x[..., 3] - x[..., 1] # height\n return y",
"chunk_type": "function",
"name": "xyxy2xywh",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 424,
"end_line": 441,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": "Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the\ntop-left corner and (x2, y2) is the bottom-right corner.\n\nArgs:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.\n\nReturns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in (x, y, width, height) format.",
"parameters": [
"x"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_xyxy2xywh_011f3ef6"
},
{
"content": "def xywh2xyxy(x):\n \"\"\"\n Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the\n top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel.\n\n Args:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x, y, width, height) format.\n\n Returns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in (x1, y1, x2, y2) format.\n \"\"\"\n assert x.shape[-1] == 4, f\"input shape last dimension expected 4 but input shape is {x.shape}\"\n y = empty_like(x) # faster than clone/copy\n xy = x[..., :2] # centers\n wh = x[..., 2:] / 2 # half width-height\n y[..., :2] = xy - wh # top left xy\n y[..., 2:] = xy + wh # bottom right xy\n return y",
"chunk_type": "function",
"name": "xywh2xyxy",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 444,
"end_line": 461,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": "Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the\ntop-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel.\n\nArgs:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x, y, width, height) format.\n\nReturns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in (x1, y1, x2, y2) format.",
"parameters": [
"x"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_xywh2xyxy_e2f5b2e0"
},
{
"content": "def xywhn2xyxy(x, w: int = 640, h: int = 640, padw: int = 0, padh: int = 0):\n \"\"\"\n Convert normalized bounding box coordinates to pixel coordinates.\n\n Args:\n x (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, w, h) format.\n w (int): Image width in pixels.\n h (int): Image height in pixels.\n padw (int): Padding width in pixels.\n padh (int): Padding height in pixels.\n\n Returns:\n y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where\n x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.\n \"\"\"\n assert x.shape[-1] == 4, f\"input shape last dimension expected 4 but input shape is {x.shape}\"\n y = empty_like(x) # faster than clone/copy\n y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x\n y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y\n y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x\n y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y\n return y",
"chunk_type": "function",
"name": "xywhn2xyxy",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 464,
"end_line": 485,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": "Convert normalized bounding box coordinates to pixel coordinates.\n\nArgs:\n x (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, w, h) format.\n w (int): Image width in pixels.\n h (int): Image height in pixels.\n padw (int): Padding width in pixels.\n padh (int): Padding height in pixels.\n\nReturns:\n y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where\n x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.",
"parameters": [
"x",
"w: int",
"h: int",
"padw: int",
"padh: int"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_xywhn2xyxy_00ae5793"
},
{
"content": "def xyxy2xywhn(x, w: int = 640, h: int = 640, clip: bool = False, eps: float = 0.0):\n \"\"\"\n Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,\n width and height are normalized to image dimensions.\n\n Args:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.\n w (int): Image width in pixels.\n h (int): Image height in pixels.\n clip (bool): Whether to clip boxes to image boundaries.\n eps (float): Minimum value for box width and height.\n\n Returns:\n (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, width, height) format.\n \"\"\"\n if clip:\n x = clip_boxes(x, (h - eps, w - eps))\n assert x.shape[-1] == 4, f\"input shape last dimension expected 4 but input shape is {x.shape}\"\n y = empty_like(x) # faster than clone/copy\n y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center\n y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center\n y[..., 2] = (x[..., 2] - x[..., 0]) / w # width\n y[..., 3] = (x[..., 3] - x[..., 1]) / h # height\n return y",
"chunk_type": "function",
"name": "xyxy2xywhn",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 488,
"end_line": 511,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": "Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,\nwidth and height are normalized to image dimensions.\n\nArgs:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.\n w (int): Image width in pixels.\n h (int): Image height in pixels.\n clip (bool): Whether to clip boxes to image boundaries.\n eps (float): Minimum value for box width and height.\n\nReturns:\n (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, width, height) format.",
"parameters": [
"x",
"w: int",
"h: int",
"clip: bool",
"eps: float"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_xyxy2xywhn_bb3f38d9"
},
{
"content": "def xywh2ltwh(x):\n \"\"\"\n Convert bounding box format from [x, y, w, h] to [x1, y1, w, h] where x1, y1 are top-left coordinates.\n\n Args:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in xywh format.\n\n Returns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.\n \"\"\"\n y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)\n y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x\n y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y\n return y",
"chunk_type": "function",
"name": "xywh2ltwh",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 514,
"end_line": 527,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": "Convert bounding box format from [x, y, w, h] to [x1, y1, w, h] where x1, y1 are top-left coordinates.\n\nArgs:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in xywh format.\n\nReturns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.",
"parameters": [
"x"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_xywh2ltwh_dd88b01f"
},
{
"content": "def xyxy2ltwh(x):\n \"\"\"\n Convert bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h] format.\n\n Args:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in xyxy format.\n\n Returns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.\n \"\"\"\n y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)\n y[..., 2] = x[..., 2] - x[..., 0] # width\n y[..., 3] = x[..., 3] - x[..., 1] # height\n return y",
"chunk_type": "function",
"name": "xyxy2ltwh",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 530,
"end_line": 543,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": "Convert bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h] format.\n\nArgs:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates in xyxy format.\n\nReturns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.",
"parameters": [
"x"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_xyxy2ltwh_d71804d8"
},
{
"content": "def ltwh2xywh(x):\n \"\"\"\n Convert bounding boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.\n\n Args:\n x (torch.Tensor): Input bounding box coordinates.\n\n Returns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in xywh format.\n \"\"\"\n y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)\n y[..., 0] = x[..., 0] + x[..., 2] / 2 # center x\n y[..., 1] = x[..., 1] + x[..., 3] / 2 # center y\n return y",
"chunk_type": "function",
"name": "ltwh2xywh",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 546,
"end_line": 559,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": "Convert bounding boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.\n\nArgs:\n x (torch.Tensor): Input bounding box coordinates.\n\nReturns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in xywh format.",
"parameters": [
"x"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_ltwh2xywh_bd19410c"
},
{
"content": "def xyxyxyxy2xywhr(x):\n \"\"\"\n Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation] format.\n\n Args:\n x (np.ndarray | torch.Tensor): Input box corners with shape (N, 8) in [xy1, xy2, xy3, xy4] format.\n\n Returns:\n (np.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format with shape (N, 5).\n Rotation values are in radians from 0 to pi/2.\n \"\"\"\n is_torch = isinstance(x, torch.Tensor)\n points = x.cpu().numpy() if is_torch else x\n points = points.reshape(len(x), -1, 2)\n rboxes = []\n for pts in points:\n # NOTE: Use cv2.minAreaRect to get accurate xywhr,\n # especially some objects are cut off by augmentations in dataloader.\n (cx, cy), (w, h), angle = cv2.minAreaRect(pts)\n rboxes.append([cx, cy, w, h, angle / 180 * np.pi])\n return torch.tensor(rboxes, device=x.device, dtype=x.dtype) if is_torch else np.asarray(rboxes)",
"chunk_type": "function",
"name": "xyxyxyxy2xywhr",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 562,
"end_line": 582,
"start_col": 0,
"end_col": 99,
"parent_name": null,
"docstring": "Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation] format.\n\nArgs:\n x (np.ndarray | torch.Tensor): Input box corners with shape (N, 8) in [xy1, xy2, xy3, xy4] format.\n\nReturns:\n (np.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format with shape (N, 5).\n Rotation values are in radians from 0 to pi/2.",
"parameters": [
"x"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_xyxyxyxy2xywhr_a60266e7"
},
{
"content": "def xywhr2xyxyxyxy(x):\n \"\"\"\n Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4] format.\n\n Args:\n x (np.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format with shape (N, 5) or (B, N, 5).\n Rotation values should be in radians from 0 to pi/2.\n\n Returns:\n (np.ndarray | torch.Tensor): Converted corner points with shape (N, 4, 2) or (B, N, 4, 2).\n \"\"\"\n cos, sin, cat, stack = (\n (torch.cos, torch.sin, torch.cat, torch.stack)\n if isinstance(x, torch.Tensor)\n else (np.cos, np.sin, np.concatenate, np.stack)\n )\n\n ctr = x[..., :2]\n w, h, angle = (x[..., i : i + 1] for i in range(2, 5))\n cos_value, sin_value = cos(angle), sin(angle)\n vec1 = [w / 2 * cos_value, w / 2 * sin_value]\n vec2 = [-h / 2 * sin_value, h / 2 * cos_value]\n vec1 = cat(vec1, -1)\n vec2 = cat(vec2, -1)\n pt1 = ctr + vec1 + vec2\n pt2 = ctr + vec1 - vec2\n pt3 = ctr - vec1 - vec2\n pt4 = ctr - vec1 + vec2\n return stack([pt1, pt2, pt3, pt4], -2)",
"chunk_type": "function",
"name": "xywhr2xyxyxyxy",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 585,
"end_line": 613,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": "Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4] format.\n\nArgs:\n x (np.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format with shape (N, 5) or (B, N, 5).\n Rotation values should be in radians from 0 to pi/2.\n\nReturns:\n (np.ndarray | torch.Tensor): Converted corner points with shape (N, 4, 2) or (B, N, 4, 2).",
"parameters": [
"x"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_xywhr2xyxyxyxy_eb8258ba"
},
{
"content": "def ltwh2xyxy(x):\n \"\"\"\n Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.\n\n Args:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates.\n\n Returns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in xyxy format.\n \"\"\"\n y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)\n y[..., 2] = x[..., 2] + x[..., 0] # width\n y[..., 3] = x[..., 3] + x[..., 1] # height\n return y",
"chunk_type": "function",
"name": "ltwh2xyxy",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 616,
"end_line": 629,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": "Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.\n\nArgs:\n x (np.ndarray | torch.Tensor): Input bounding box coordinates.\n\nReturns:\n (np.ndarray | torch.Tensor): Bounding box coordinates in xyxy format.",
"parameters": [
"x"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_ltwh2xyxy_31ad8c0b"
},
{
"content": "def segments2boxes(segments):\n \"\"\"\n Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).\n\n Args:\n segments (list): List of segments where each segment is a list of points, each point is [x, y] coordinates.\n\n Returns:\n (np.ndarray): Bounding box coordinates in xywh format.\n \"\"\"\n boxes = []\n for s in segments:\n x, y = s.T # segment xy\n boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy\n return xyxy2xywh(np.array(boxes)) # cls, xywh",
"chunk_type": "function",
"name": "segments2boxes",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 632,
"end_line": 646,
"start_col": 0,
"end_col": 37,
"parent_name": null,
"docstring": "Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).\n\nArgs:\n segments (list): List of segments where each segment is a list of points, each point is [x, y] coordinates.\n\nReturns:\n (np.ndarray): Bounding box coordinates in xywh format.",
"parameters": [
"segments"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_segments2boxes_3e6c5756"
},
{
"content": "def resample_segments(segments, n: int = 1000):\n \"\"\"\n Resample segments to n points each using linear interpolation.\n\n Args:\n segments (list): List of (N, 2) arrays where N is the number of points in each segment.\n n (int): Number of points to resample each segment to.\n\n Returns:\n (list): Resampled segments with n points each.\n \"\"\"\n for i, s in enumerate(segments):\n if len(s) == n:\n continue\n s = np.concatenate((s, s[0:1, :]), axis=0)\n x = np.linspace(0, len(s) - 1, n - len(s) if len(s) < n else n)\n xp = np.arange(len(s))\n x = np.insert(x, np.searchsorted(x, xp), xp) if len(s) < n else x\n segments[i] = (\n np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T\n ) # segment xy\n return segments",
"chunk_type": "function",
"name": "resample_segments",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 649,
"end_line": 670,
"start_col": 0,
"end_col": 19,
"parent_name": null,
"docstring": "Resample segments to n points each using linear interpolation.\n\nArgs:\n segments (list): List of (N, 2) arrays where N is the number of points in each segment.\n n (int): Number of points to resample each segment to.\n\nReturns:\n (list): Resampled segments with n points each.",
"parameters": [
"segments",
"n: int"
],
"return_type": null,
"decorators": [],
"complexity_score": 4,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_resample_segments_bf05b57c"
},
{
"content": "def crop_mask(masks, boxes):\n \"\"\"\n Crop masks to bounding box regions.\n\n Args:\n masks (torch.Tensor): Masks with shape (N, H, W).\n boxes (torch.Tensor): Bounding box coordinates with shape (N, 4) in relative point form.\n\n Returns:\n (torch.Tensor): Cropped masks.\n \"\"\"\n _, h, w = masks.shape\n x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)\n r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w)\n c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1)\n\n return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))",
"chunk_type": "function",
"name": "crop_mask",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 673,
"end_line": 689,
"start_col": 0,
"end_col": 64,
"parent_name": null,
"docstring": "Crop masks to bounding box regions.\n\nArgs:\n masks (torch.Tensor): Masks with shape (N, H, W).\n boxes (torch.Tensor): Bounding box coordinates with shape (N, 4) in relative point form.\n\nReturns:\n (torch.Tensor): Cropped masks.",
"parameters": [
"masks",
"boxes"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_crop_mask_25034c2c"
},
{
"content": "def process_mask(protos, masks_in, bboxes, shape, upsample: bool = False):\n \"\"\"\n Apply masks to bounding boxes using mask head output.\n\n Args:\n protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).\n masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.\n bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.\n shape (tuple): Input image size as (height, width).\n upsample (bool): Whether to upsample masks to original image size.\n\n Returns:\n (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w\n are the height and width of the input image. The mask is applied to the bounding boxes.\n \"\"\"\n c, mh, mw = protos.shape # CHW\n ih, iw = shape\n masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) # CHW\n width_ratio = mw / iw\n height_ratio = mh / ih\n\n downsampled_bboxes = bboxes.clone()\n downsampled_bboxes[:, 0] *= width_ratio\n downsampled_bboxes[:, 2] *= width_ratio\n downsampled_bboxes[:, 3] *= height_ratio\n downsampled_bboxes[:, 1] *= height_ratio\n\n masks = crop_mask(masks, downsampled_bboxes) # CHW\n if upsample:\n masks = F.interpolate(masks[None], shape, mode=\"bilinear\", align_corners=False)[0] # CHW\n return masks.gt_(0.0)",
"chunk_type": "function",
"name": "process_mask",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 692,
"end_line": 722,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": "Apply masks to bounding boxes using mask head output.\n\nArgs:\n protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).\n masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.\n bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.\n shape (tuple): Input image size as (height, width).\n upsample (bool): Whether to upsample masks to original image size.\n\nReturns:\n (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w\n are the height and width of the input image. The mask is applied to the bounding boxes.",
"parameters": [
"protos",
"masks_in",
"bboxes",
"shape",
"upsample: bool"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_process_mask_7c305c15"
},
{
"content": "def process_mask_native(protos, masks_in, bboxes, shape):\n \"\"\"\n Apply masks to bounding boxes using mask head output with native upsampling.\n\n Args:\n protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).\n masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.\n bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.\n shape (tuple): Input image size as (height, width).\n\n Returns:\n (torch.Tensor): Binary mask tensor with shape (H, W, N).\n \"\"\"\n c, mh, mw = protos.shape # CHW\n masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)\n masks = scale_masks(masks[None], shape)[0] # CHW\n masks = crop_mask(masks, bboxes) # CHW\n return masks.gt_(0.0)",
"chunk_type": "function",
"name": "process_mask_native",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 725,
"end_line": 742,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": "Apply masks to bounding boxes using mask head output with native upsampling.\n\nArgs:\n protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).\n masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.\n bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.\n shape (tuple): Input image size as (height, width).\n\nReturns:\n (torch.Tensor): Binary mask tensor with shape (H, W, N).",
"parameters": [
"protos",
"masks_in",
"bboxes",
"shape"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_process_mask_native_556480d4"
},
{
"content": "def scale_masks(masks, shape, padding: bool = True):\n \"\"\"\n Rescale segment masks to target shape.\n\n Args:\n masks (torch.Tensor): Masks with shape (N, C, H, W).\n shape (tuple): Target height and width as (height, width).\n padding (bool): Whether masks are based on YOLO-style augmented images with padding.\n\n Returns:\n (torch.Tensor): Rescaled masks.\n \"\"\"\n mh, mw = masks.shape[2:]\n gain = min(mh / shape[0], mw / shape[1]) # gain = old / new\n pad = [mw - shape[1] * gain, mh - shape[0] * gain] # wh padding\n if padding:\n pad[0] /= 2\n pad[1] /= 2\n top, left = (int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1))) if padding else (0, 0) # y, x\n bottom, right = (\n mh - int(round(pad[1] + 0.1)),\n mw - int(round(pad[0] + 0.1)),\n )\n masks = masks[..., top:bottom, left:right]\n\n masks = F.interpolate(masks, shape, mode=\"bilinear\", align_corners=False) # NCHW\n return masks",
"chunk_type": "function",
"name": "scale_masks",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 745,
"end_line": 771,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "Rescale segment masks to target shape.\n\nArgs:\n masks (torch.Tensor): Masks with shape (N, C, H, W).\n shape (tuple): Target height and width as (height, width).\n padding (bool): Whether masks are based on YOLO-style augmented images with padding.\n\nReturns:\n (torch.Tensor): Rescaled masks.",
"parameters": [
"masks",
"shape",
"padding: bool"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_scale_masks_29488e17"
},
{
"content": "def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize: bool = False, padding: bool = True):\n \"\"\"\n Rescale segment coordinates from img1_shape to img0_shape.\n\n Args:\n img1_shape (tuple): Shape of the source image.\n coords (torch.Tensor): Coordinates to scale with shape (N, 2).\n img0_shape (tuple): Shape of the target image.\n ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).\n normalize (bool): Whether to normalize coordinates to range [0, 1].\n padding (bool): Whether coordinates are based on YOLO-style augmented images with padding.\n\n Returns:\n (torch.Tensor): Scaled coordinates.\n \"\"\"\n if ratio_pad is None: # calculate from img0_shape\n gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new\n pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding\n else:\n gain = ratio_pad[0][0]\n pad = ratio_pad[1]\n\n if padding:\n coords[..., 0] -= pad[0] # x padding\n coords[..., 1] -= pad[1] # y padding\n coords[..., 0] /= gain\n coords[..., 1] /= gain\n coords = clip_coords(coords, img0_shape)\n if normalize:\n coords[..., 0] /= img0_shape[1] # width\n coords[..., 1] /= img0_shape[0] # height\n return coords",
"chunk_type": "function",
"name": "scale_coords",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 774,
"end_line": 805,
"start_col": 0,
"end_col": 17,
"parent_name": null,
"docstring": "Rescale segment coordinates from img1_shape to img0_shape.\n\nArgs:\n img1_shape (tuple): Shape of the source image.\n coords (torch.Tensor): Coordinates to scale with shape (N, 2).\n img0_shape (tuple): Shape of the target image.\n ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).\n normalize (bool): Whether to normalize coordinates to range [0, 1].\n padding (bool): Whether coordinates are based on YOLO-style augmented images with padding.\n\nReturns:\n (torch.Tensor): Scaled coordinates.",
"parameters": [
"img1_shape",
"coords",
"img0_shape",
"ratio_pad",
"normalize: bool",
"padding: bool"
],
"return_type": null,
"decorators": [],
"complexity_score": 4,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_scale_coords_af97eefd"
},
{
"content": "def regularize_rboxes(rboxes):\n \"\"\"\n Regularize rotated bounding boxes to range [0, pi/2].\n\n Args:\n rboxes (torch.Tensor): Input rotated boxes with shape (N, 5) in xywhr format.\n\n Returns:\n (torch.Tensor): Regularized rotated boxes.\n \"\"\"\n x, y, w, h, t = rboxes.unbind(dim=-1)\n # Swap edge if t >= pi/2 while not being symmetrically opposite\n swap = t % math.pi >= math.pi / 2\n w_ = torch.where(swap, h, w)\n h_ = torch.where(swap, w, h)\n t = t % (math.pi / 2)\n return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes",
"chunk_type": "function",
"name": "regularize_rboxes",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 808,
"end_line": 824,
"start_col": 0,
"end_col": 49,
"parent_name": null,
"docstring": "Regularize rotated bounding boxes to range [0, pi/2].\n\nArgs:\n rboxes (torch.Tensor): Input rotated boxes with shape (N, 5) in xywhr format.\n\nReturns:\n (torch.Tensor): Regularized rotated boxes.",
"parameters": [
"rboxes"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_regularize_rboxes_c10e57ca"
},
{
"content": "def masks2segments(masks, strategy: str = \"all\"):\n \"\"\"\n Convert masks to segments using contour detection.\n\n Args:\n masks (torch.Tensor): Binary masks with shape (batch_size, 160, 160).\n strategy (str): Segmentation strategy, either 'all' or 'largest'.\n\n Returns:\n (list): List of segment masks as float32 arrays.\n \"\"\"\n from ultralytics.data.converter import merge_multi_segment\n\n segments = []\n for x in masks.int().cpu().numpy().astype(\"uint8\"):\n c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]\n if c:\n if strategy == \"all\": # merge and concatenate all segments\n c = (\n np.concatenate(merge_multi_segment([x.reshape(-1, 2) for x in c]))\n if len(c) > 1\n else c[0].reshape(-1, 2)\n )\n elif strategy == \"largest\": # select largest segment\n c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)\n else:\n c = np.zeros((0, 2)) # no segments found\n segments.append(c.astype(\"float32\"))\n return segments",
"chunk_type": "function",
"name": "masks2segments",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 827,
"end_line": 855,
"start_col": 0,
"end_col": 19,
"parent_name": null,
"docstring": "Convert masks to segments using contour detection.\n\nArgs:\n masks (torch.Tensor): Binary masks with shape (batch_size, 160, 160).\n strategy (str): Segmentation strategy, either 'all' or 'largest'.\n\nReturns:\n (list): List of segment masks as float32 arrays.",
"parameters": [
"masks",
"strategy: str"
],
"return_type": null,
"decorators": [],
"complexity_score": 7,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_masks2segments_1e64c283"
},
{
"content": "def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:\n \"\"\"\n Convert a batch of FP32 torch tensors to NumPy uint8 arrays, changing from BCHW to BHWC layout.\n\n Args:\n batch (torch.Tensor): Input tensor batch with shape (Batch, Channels, Height, Width) and dtype torch.float32.\n\n Returns:\n (np.ndarray): Output NumPy array batch with shape (Batch, Height, Width, Channels) and dtype uint8.\n \"\"\"\n return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()",
"chunk_type": "function",
"name": "convert_torch2numpy_batch",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 858,
"end_line": 868,
"start_col": 0,
"end_col": 101,
"parent_name": null,
"docstring": "Convert a batch of FP32 torch tensors to NumPy uint8 arrays, changing from BCHW to BHWC layout.\n\nArgs:\n batch (torch.Tensor): Input tensor batch with shape (Batch, Channels, Height, Width) and dtype torch.float32.\n\nReturns:\n (np.ndarray): Output NumPy array batch with shape (Batch, Height, Width, Channels) and dtype uint8.",
"parameters": [
"batch: torch.Tensor"
],
"return_type": "np.ndarray",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_convert_torch2numpy_batch_47ba17e5"
},
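A quick usage sketch of the conversion shown above, reproducing the same permute/clamp pipeline on a random batch (values assumed to be FP32 in [0, 1]):

```python
import torch

batch = torch.rand(4, 3, 640, 640)  # BCHW float32 in [0, 1]

# BCHW float -> BHWC uint8, the same expression used in convert_torch2numpy_batch.
out = (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
print(out.shape, out.dtype)  # (4, 640, 640, 3) uint8
```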
{
"content": "def clean_str(s):\n \"\"\"\n Clean a string by replacing special characters with '_' character.\n\n Args:\n s (str): A string needing special characters replaced.\n\n Returns:\n (str): A string with special characters replaced by an underscore _.\n \"\"\"\n return re.sub(pattern=\"[|@#!¡·$€%&()=?¿^*;:,¨´><+]\", repl=\"_\", string=s)",
"chunk_type": "function",
"name": "clean_str",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 871,
"end_line": 881,
"start_col": 0,
"end_col": 83,
"parent_name": null,
"docstring": "Clean a string by replacing special characters with '_' character.\n\nArgs:\n s (str): A string needing special characters replaced.\n\nReturns:\n (str): A string with special characters replaced by an underscore _.",
"parameters": [
"s"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_clean_str_119c2be3"
},
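For reference, the regex above behaves as follows; this is a stand-alone copy of the same pattern applied to a hypothetical filename.

```python
import re

def clean_str(s):
    # Same character class as the chunk above: reserved/special characters become '_'.
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)

print(clean_str("video#1 (copy)?.mp4"))  # -> 'video_1 _copy__.mp4'
```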
{
"content": "def empty_like(x):\n \"\"\"Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype.\"\"\"\n return (\n torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32)\n )",
"chunk_type": "function",
"name": "empty_like",
"file_path": "ultralytics\\ultralytics\\utils\\ops.py",
"start_line": 884,
"end_line": 888,
"start_col": 0,
"end_col": 5,
"parent_name": null,
"docstring": "Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype.",
"parameters": [
"x"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"math",
"re",
"time",
"typing.Optional",
"cv2",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.batch_probiou",
"torchvision",
"ultralytics.data.converter.merge_multi_segment"
],
"chunk_id": "function_empty_like_35cd77b9"
},
{
"content": "import time",
"chunk_type": "import",
"name": "time",
"file_path": "ultralytics\\ultralytics\\utils\\patches.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_time_9e1c91d6"
},
{
"content": "from contextlib import contextmanager",
"chunk_type": "import",
"name": "contextmanager",
"file_path": "ultralytics\\ultralytics\\utils\\patches.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 37,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_contextmanager_860a4511"
},
{
"content": "from copy import copy",
"chunk_type": "import",
"name": "copy",
"file_path": "ultralytics\\ultralytics\\utils\\patches.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_copy_5b142c7f"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\utils\\patches.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_e82f0c22"
},
{
"content": "from typing import Any, Dict, List, Optional",
"chunk_type": "import",
"name": "Any, Dict, List, Optional",
"file_path": "ultralytics\\ultralytics\\utils\\patches.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, List, Optional_3383557e"
},
{
"content": "import cv2",
"chunk_type": "import",
"name": "cv2",
"file_path": "ultralytics\\ultralytics\\utils\\patches.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 10,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_cv2_c903a9b5"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\utils\\patches.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_ff8c444e"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\utils\\patches.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_57df7b1e"
},
{
"content": "_imshow = cv2.imshow # copy to avoid recursion errors",
"chunk_type": "variable",
"name": "_imshow",
"file_path": "ultralytics\\ultralytics\\utils\\patches.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable__imshow_015f4d04"
},
{
"content": "def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> Optional[np.ndarray]:\n \"\"\"\n Read an image from a file with multilanguage filename support.\n\n Args:\n filename (str): Path to the file to read.\n flags (int, optional): Flag that can take values of cv2.IMREAD_*. Controls how the image is read.\n\n Returns:\n (np.ndarray | None): The read image array, or None if reading fails.\n\n Examples:\n >>> img = imread(\"path/to/image.jpg\")\n >>> img = imread(\"path/to/image.jpg\", cv2.IMREAD_GRAYSCALE)\n \"\"\"\n file_bytes = np.fromfile(filename, np.uint8)\n if filename.endswith((\".tiff\", \".tif\")):\n success, frames = cv2.imdecodemulti(file_bytes, cv2.IMREAD_UNCHANGED)\n if success:\n # Handle RGB images in tif/tiff format\n return frames[0] if len(frames) == 1 and frames[0].ndim == 3 else np.stack(frames, axis=2)\n return None\n else:\n im = cv2.imdecode(file_bytes, flags)\n return im[..., None] if im.ndim == 2 else im # Always ensure 3 dimensions",
"chunk_type": "function",
"name": "imread",
"file_path": "ultralytics\\ultralytics\\utils\\patches.py",
"start_line": 18,
"end_line": 42,
"start_col": 0,
"end_col": 52,
"parent_name": null,
"docstring": "Read an image from a file with multilanguage filename support.\n\nArgs:\n filename (str): Path to the file to read.\n flags (int, optional): Flag that can take values of cv2.IMREAD_*. Controls how the image is read.\n\nReturns:\n (np.ndarray | None): The read image array, or None if reading fails.\n\nExamples:\n >>> img = imread(\"path/to/image.jpg\")\n >>> img = imread(\"path/to/image.jpg\", cv2.IMREAD_GRAYSCALE)",
"parameters": [
"filename: str",
"flags: int"
],
"return_type": "Optional[np.ndarray]",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"time",
"contextlib.contextmanager",
"copy.copy",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.torch_utils.TORCH_1_13"
],
"chunk_id": "function_imread_53ca1577"
},
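The byte-level decode above is what gives imread its multilanguage-path support. A minimal sketch of the same idea, assuming a non-ASCII filename already exists on disk (the path here is illustrative):

```python
import cv2
import numpy as np

path = "测试图像.jpg"  # hypothetical filename that cv2.imread may fail to open directly on Windows

file_bytes = np.fromfile(path, np.uint8)         # np.fromfile handles the path at the byte level
im = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)  # decode in memory instead of calling cv2.imread
if im is not None and im.ndim == 2:
    im = im[..., None]  # keep a consistent HWC layout, as the patched imread does
print(None if im is None else im.shape)
```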
{
"content": "def imwrite(filename: str, img: np.ndarray, params: Optional[List[int]] = None) -> bool:\n \"\"\"\n Write an image to a file with multilanguage filename support.\n\n Args:\n filename (str): Path to the file to write.\n img (np.ndarray): Image to write.\n params (List[int], optional): Additional parameters for image encoding.\n\n Returns:\n (bool): True if the file was written successfully, False otherwise.\n\n Examples:\n >>> import numpy as np\n >>> img = np.zeros((100, 100, 3), dtype=np.uint8) # Create a black image\n >>> success = imwrite(\"output.jpg\", img) # Write image to file\n >>> print(success)\n True\n \"\"\"\n try:\n cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename)\n return True\n except Exception:\n return False",
"chunk_type": "function",
"name": "imwrite",
"file_path": "ultralytics\\ultralytics\\utils\\patches.py",
"start_line": 45,
"end_line": 68,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Write an image to a file with multilanguage filename support.\n\nArgs:\n filename (str): Path to the file to write.\n img (np.ndarray): Image to write.\n params (List[int], optional): Additional parameters for image encoding.\n\nReturns:\n (bool): True if the file was written successfully, False otherwise.\n\nExamples:\n >>> import numpy as np\n >>> img = np.zeros((100, 100, 3), dtype=np.uint8) # Create a black image\n >>> success = imwrite(\"output.jpg\", img) # Write image to file\n >>> print(success)\n True",
"parameters": [
"filename: str",
"img: np.ndarray",
"params: Optional[List[int]]"
],
"return_type": "bool",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"time",
"contextlib.contextmanager",
"copy.copy",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.torch_utils.TORCH_1_13"
],
"chunk_id": "function_imwrite_bfb76116"
},
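The write path mirrors the read path: encode in memory with cv2.imencode and dump the bytes with NumPy, which is the approach imwrite above relies on. The destination path is illustrative.

```python
from pathlib import Path

import cv2
import numpy as np

img = np.zeros((100, 100, 3), dtype=np.uint8)
filename = Path("输出/result.png")  # hypothetical destination with non-ASCII characters

filename.parent.mkdir(parents=True, exist_ok=True)
ok, buf = cv2.imencode(filename.suffix, img)  # encode to an in-memory buffer
if ok:
    buf.tofile(str(filename))                 # NumPy writes the bytes, bypassing cv2's path handling
print(ok)
```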
{
"content": "def imshow(winname: str, mat: np.ndarray) -> None:\n \"\"\"\n Display an image in the specified window with multilanguage window name support.\n\n This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It handles\n multilanguage window names by encoding them properly for OpenCV compatibility.\n\n Args:\n winname (str): Name of the window where the image will be displayed. If a window with this name already\n exists, the image will be displayed in that window.\n mat (np.ndarray): Image to be shown. Should be a valid numpy array representing an image.\n\n Examples:\n >>> import numpy as np\n >>> img = np.zeros((300, 300, 3), dtype=np.uint8) # Create a black image\n >>> img[:100, :100] = [255, 0, 0] # Add a blue square\n >>> imshow(\"Example Window\", img) # Display the image\n \"\"\"\n _imshow(winname.encode(\"unicode_escape\").decode(), mat)",
"chunk_type": "function",
"name": "imshow",
"file_path": "ultralytics\\ultralytics\\utils\\patches.py",
"start_line": 71,
"end_line": 89,
"start_col": 0,
"end_col": 59,
"parent_name": null,
"docstring": "Display an image in the specified window with multilanguage window name support.\n\nThis function is a wrapper around OpenCV's imshow function that displays an image in a named window. It handles\nmultilanguage window names by encoding them properly for OpenCV compatibility.\n\nArgs:\n winname (str): Name of the window where the image will be displayed. If a window with this name already\n exists, the image will be displayed in that window.\n mat (np.ndarray): Image to be shown. Should be a valid numpy array representing an image.\n\nExamples:\n >>> import numpy as np\n >>> img = np.zeros((300, 300, 3), dtype=np.uint8) # Create a black image\n >>> img[:100, :100] = [255, 0, 0] # Add a blue square\n >>> imshow(\"Example Window\", img) # Display the image",
"parameters": [
"winname: str",
"mat: np.ndarray"
],
"return_type": "None",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"time",
"contextlib.contextmanager",
"copy.copy",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.torch_utils.TORCH_1_13"
],
"chunk_id": "function_imshow_a3da89bb"
},
{
"content": "_torch_save = torch.save",
"chunk_type": "variable",
"name": "_torch_save",
"file_path": "ultralytics\\ultralytics\\utils\\patches.py",
"start_line": 93,
"end_line": 93,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable__torch_save_18ff7425"
},
{
"content": "def torch_load(*args, **kwargs):\n \"\"\"\n Load a PyTorch model with updated arguments to avoid warnings.\n\n This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.\n\n Args:\n *args (Any): Variable length argument list to pass to torch.load.\n **kwargs (Any): Arbitrary keyword arguments to pass to torch.load.\n\n Returns:\n (Any): The loaded PyTorch object.\n\n Notes:\n For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False'\n if the argument is not provided, to avoid deprecation warnings.\n \"\"\"\n from ultralytics.utils.torch_utils import TORCH_1_13\n\n if TORCH_1_13 and \"weights_only\" not in kwargs:\n kwargs[\"weights_only\"] = False\n\n return torch.load(*args, **kwargs)",
"chunk_type": "function",
"name": "torch_load",
"file_path": "ultralytics\\ultralytics\\utils\\patches.py",
"start_line": 96,
"end_line": 118,
"start_col": 0,
"end_col": 38,
"parent_name": null,
"docstring": "Load a PyTorch model with updated arguments to avoid warnings.\n\nThis function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.\n\nArgs:\n *args (Any): Variable length argument list to pass to torch.load.\n **kwargs (Any): Arbitrary keyword arguments to pass to torch.load.\n\nReturns:\n (Any): The loaded PyTorch object.\n\nNotes:\n For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False'\n if the argument is not provided, to avoid deprecation warnings.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"time",
"contextlib.contextmanager",
"copy.copy",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.torch_utils.TORCH_1_13"
],
"chunk_id": "function_torch_load_8dfa9a14"
},
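A hedged sketch of the same version-gated keyword injection, using a plain version check in place of the internal TORCH_1_13 constant; the helper name is hypothetical.

```python
import torch

def load_checkpoint(path, **kwargs):
    # Only pass weights_only where torch.load supports it (added in PyTorch 1.13).
    major, minor = (int(v) for v in torch.__version__.split("+")[0].split(".")[:2])
    if (major, minor) >= (1, 13):
        kwargs.setdefault("weights_only", False)
    return torch.load(path, **kwargs)

# state = load_checkpoint("model.pt")  # behaves like plain torch.load on older versions
```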
{
"content": "def torch_save(*args, **kwargs):\n \"\"\"\n Save PyTorch objects with retry mechanism for robustness.\n\n This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur\n due to device flushing delays or antivirus scanning.\n\n Args:\n *args (Any): Positional arguments to pass to torch.save.\n **kwargs (Any): Keyword arguments to pass to torch.save.\n\n Examples:\n >>> model = torch.nn.Linear(10, 1)\n >>> torch_save(model.state_dict(), \"model.pt\")\n \"\"\"\n for i in range(4): # 3 retries\n try:\n return _torch_save(*args, **kwargs)\n except RuntimeError as e: # Unable to save, possibly waiting for device to flush or antivirus scan\n if i == 3:\n raise e\n time.sleep((2**i) / 2) # Exponential backoff: 0.5s, 1.0s, 2.0s",
"chunk_type": "function",
"name": "torch_save",
"file_path": "ultralytics\\ultralytics\\utils\\patches.py",
"start_line": 121,
"end_line": 142,
"start_col": 0,
"end_col": 34,
"parent_name": null,
"docstring": "Save PyTorch objects with retry mechanism for robustness.\n\nThis function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur\ndue to device flushing delays or antivirus scanning.\n\nArgs:\n *args (Any): Positional arguments to pass to torch.save.\n **kwargs (Any): Keyword arguments to pass to torch.save.\n\nExamples:\n >>> model = torch.nn.Linear(10, 1)\n >>> torch_save(model.state_dict(), \"model.pt\")",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 4,
"dependencies": [
"time",
"contextlib.contextmanager",
"copy.copy",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.torch_utils.TORCH_1_13"
],
"chunk_id": "function_torch_save_48c194ec"
},
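The retry loop above is a generic robustness pattern; here is a self-contained version with the same backoff schedule (function name hypothetical):

```python
import time

import torch

def save_with_retry(obj, path, retries: int = 3):
    # Retry torch.save with exponential backoff: sleeps 0.5s, 1.0s, 2.0s between attempts.
    for attempt in range(retries + 1):
        try:
            return torch.save(obj, path)
        except RuntimeError:
            if attempt == retries:
                raise
            time.sleep((2 ** attempt) / 2)

save_with_retry(torch.nn.Linear(10, 1).state_dict(), "model.pt")
```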
{
"content": "def arange_patch(args):\n \"\"\"\n Workaround for ONNX torch.arange incompatibility with FP16.\n\n https://github.com/pytorch/pytorch/issues/148041.\n \"\"\"\n if args.dynamic and args.half and args.format == \"onnx\":\n func = torch.arange\n\n def arange(*args, dtype=None, **kwargs):\n \"\"\"Return a 1-D tensor of size with values from the interval and common difference.\"\"\"\n return func(*args, **kwargs).to(dtype) # cast to dtype instead of passing dtype\n\n torch.arange = arange # patch\n yield\n torch.arange = func # unpatch\n else:\n yield",
"chunk_type": "function",
"name": "arange_patch",
"file_path": "ultralytics\\ultralytics\\utils\\patches.py",
"start_line": 146,
"end_line": 163,
"start_col": 0,
"end_col": 13,
"parent_name": null,
"docstring": "Workaround for ONNX torch.arange incompatibility with FP16.\n\nhttps://github.com/pytorch/pytorch/issues/148041.",
"parameters": [
"args"
],
"return_type": null,
"decorators": [
"contextmanager"
],
"complexity_score": 2,
"dependencies": [
"time",
"contextlib.contextmanager",
"copy.copy",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.torch_utils.TORCH_1_13"
],
"chunk_id": "function_arange_patch_143727d9"
},
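arange_patch is a @contextmanager that temporarily swaps torch.arange so the dtype is applied as a cast after construction. The same monkeypatch pattern is shown in isolation below; the context-manager name is hypothetical and a try/finally is added so the patch is always reverted.

```python
from contextlib import contextmanager

import torch

@contextmanager
def cast_arange_dtype():
    original = torch.arange

    def patched(*args, dtype=None, **kwargs):
        # Build with the default dtype, then cast, mirroring the ONNX FP16 workaround above.
        out = original(*args, **kwargs)
        return out.to(dtype) if dtype is not None else out

    torch.arange = patched
    try:
        yield
    finally:
        torch.arange = original  # always unpatch, even if export raises

with cast_arange_dtype():
    print(torch.arange(3, dtype=torch.float16))  # tensor([0., 1., 2.], dtype=torch.float16)
```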
{
"content": "def override_configs(args, overrides: Optional[Dict[str, Any]] = None):\n \"\"\"\n Context manager to temporarily override configurations in args.\n\n Args:\n args (IterableSimpleNamespace): Original configuration arguments.\n overrides (Dict[str, Any]): Dictionary of overrides to apply.\n\n Yields:\n (IterableSimpleNamespace): Configuration arguments with overrides applied.\n \"\"\"\n if overrides:\n original_args = copy(args)\n for key, value in overrides.items():\n setattr(args, key, value)\n try:\n yield args\n finally:\n args.__dict__.update(original_args.__dict__)\n else:\n yield args",
"chunk_type": "function",
"name": "override_configs",
"file_path": "ultralytics\\ultralytics\\utils\\patches.py",
"start_line": 167,
"end_line": 187,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": "Context manager to temporarily override configurations in args.\n\nArgs:\n args (IterableSimpleNamespace): Original configuration arguments.\n overrides (Dict[str, Any]): Dictionary of overrides to apply.\n\nYields:\n (IterableSimpleNamespace): Configuration arguments with overrides applied.",
"parameters": [
"args",
"overrides: Optional[Dict[str, Any]]"
],
"return_type": null,
"decorators": [
"contextmanager"
],
"complexity_score": 3,
"dependencies": [
"time",
"contextlib.contextmanager",
"copy.copy",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"torch",
"ultralytics.utils.torch_utils.TORCH_1_13"
],
"chunk_id": "function_override_configs_29af933f"
},
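Usage sketch for the override_configs context manager, assuming the ultralytics package is installed and exposes it at the path recorded above; a plain SimpleNamespace stands in for the IterableSimpleNamespace mentioned in the docstring.

```python
from types import SimpleNamespace

from ultralytics.utils.patches import override_configs

cfg = SimpleNamespace(imgsz=640, half=False)
with override_configs(cfg, {"half": True, "imgsz": 320}) as cfg:
    print(cfg.half, cfg.imgsz)  # True 320 inside the block
print(cfg.half, cfg.imgsz)      # False 640 afterwards, originals restored
```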
{
"content": "import math",
"chunk_type": "import",
"name": "math",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_math_4cce6133"
},
{
"content": "import warnings",
"chunk_type": "import",
"name": "warnings",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_warnings_8fac1370"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_fd97ea76"
},
{
"content": "from typing import Any, Callable, Dict, List, Optional, Union",
"chunk_type": "import",
"name": "Any, Callable, Dict, List, Optional, Union",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 61,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Callable, Dict, List, Optional, Union_297079a9"
},
{
"content": "import cv2",
"chunk_type": "import",
"name": "cv2",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 10,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_cv2_e3ab8784"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_3015f2ee"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_0baed0fc"
},
{
"content": "from PIL import Image, ImageDraw, ImageFont",
"chunk_type": "import",
"name": "Image, ImageDraw, ImageFont",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 43,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Image, ImageDraw, ImageFont_d1d0340f"
},
{
"content": "from PIL import __version__ as pil_version",
"chunk_type": "import",
"name": "__version__",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import___version___08b5082a"
},
{
"content": "from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, TryExcept, ops, plt_settings, threaded",
"chunk_type": "import",
"name": "IS_COLAB, IS_KAGGLE, LOGGER, TryExcept, ops, plt_settings, threaded",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 97,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_IS_COLAB, IS_KAGGLE, LOGGER, TryExcept, ops, plt_settings, threaded_bccb5e0f"
},
{
"content": "from ultralytics.utils.checks import check_font, check_version, is_ascii",
"chunk_type": "import",
"name": "check_font, check_version, is_ascii",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 72,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_check_font, check_version, is_ascii_78665ee3"
},
{
"content": "from ultralytics.utils.files import increment_path",
"chunk_type": "import",
"name": "increment_path",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 16,
"end_line": 16,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_increment_path_6587a485"
},
{
"content": "class Colors:\n \"\"\"\n Ultralytics color palette for visualization and plotting.\n\n This class provides methods to work with the Ultralytics color palette, including converting hex color codes to\n RGB values and accessing predefined color schemes for object detection and pose estimation.\n\n Attributes:\n palette (List[tuple]): List of RGB color tuples for general use.\n n (int): The number of colors in the palette.\n pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.\n\n Examples:\n >>> from ultralytics.utils.plotting import Colors\n >>> colors = Colors()\n >>> colors(5, True) # Returns BGR format: (221, 111, 255)\n >>> colors(5, False) # Returns RGB format: (255, 111, 221)\n\n ## Ultralytics Color Palette\n\n | Index | Color | HEX | RGB |\n |-------|-------------------------------------------------------------------|-----------|-------------------|\n | 0 | | `#042aff` | (4, 42, 255) |\n | 1 | | `#0bdbeb` | (11, 219, 235) |\n | 2 | | `#f3f3f3` | (243, 243, 243) |\n | 3 | | `#00dfb7` | (0, 223, 183) |\n | 4 | | `#111f68` | (17, 31, 104) |\n | 5 | | `#ff6fdd` | (255, 111, 221) |\n | 6 | | `#ff444f` | (255, 68, 79) |\n | 7 | | `#cced00` | (204, 237, 0) |\n | 8 | | `#00f344` | (0, 243, 68) |\n | 9 | | `#bd00ff` | (189, 0, 255) |\n | 10 | | `#00b4ff` | (0, 180, 255) |\n | 11 | | `#dd00ba` | (221, 0, 186) |\n | 12 | | `#00ffff` | (0, 255, 255) |\n | 13 | | `#26c000` | (38, 192, 0) |\n | 14 | | `#01ffb3` | (1, 255, 179) |\n | 15 | | `#7d24ff` | (125, 36, 255) |\n | 16 | | `#7b0068` | (123, 0, 104) |\n | 17 | | `#ff1b6c` | (255, 27, 108) |\n | 18 | | `#fc6d2f` | (252, 109, 47) |\n | 19 | | `#a2ff0b` | (162, 255, 11) |\n\n ## Pose Color Palette\n\n | Index | Color | HEX | RGB |\n |-------|-------------------------------------------------------------------|-----------|-------------------|\n | 0 | | `#ff8000` | (255, 128, 0) |\n | 1 | | `#ff9933` | (255, 153, 51) |\n | 2 | | `#ffb266` | (255, 178, 102) |\n | 3 | | `#e6e600` | (230, 230, 0) |\n | 4 | | `#ff99ff` | (255, 153, 255) |\n | 5 | | `#99ccff` | (153, 204, 255) |\n | 6 | | `#ff66ff` | (255, 102, 255) |\n | 7 | | `#ff33ff` | (255, 51, 255) |\n | 8 | | `#66b2ff` | (102, 178, 255) |\n | 9 | | `#3399ff` | (51, 153, 255) |\n | 10 | | `#ff9999` | (255, 153, 153) |\n | 11 | | `#ff6666` | (255, 102, 102) |\n | 12 | | `#ff3333` | (255, 51, 51) |\n | 13 | | `#99ff99` | (153, 255, 153) |\n | 14 | | `#66ff66` | (102, 255, 102) |\n | 15 | | `#33ff33` | (51, 255, 51) |\n | 16 | | `#00ff00` | (0, 255, 0) |\n | 17 | | `#0000ff` | (0, 0, 255) |\n | 18 | | `#ff0000` | (255, 0, 0) |\n | 19 | | `#ffffff` | (255, 255, 255) |\n\n !!! 
note \"Ultralytics Brand Colors\"\n\n For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).\n Please use the official Ultralytics colors for all marketing materials.\n \"\"\"\n\n def __init__(self):\n \"\"\"Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values().\"\"\"\n hexs = (\n \"042AFF\",\n \"0BDBEB\",\n \"F3F3F3\",\n \"00DFB7\",\n \"111F68\",\n \"FF6FDD\",\n \"FF444F\",\n \"CCED00\",\n \"00F344\",\n \"BD00FF\",\n \"00B4FF\",\n \"DD00BA\",\n \"00FFFF\",\n \"26C000\",\n \"01FFB3\",\n \"7D24FF\",\n \"7B0068\",\n \"FF1B6C\",\n \"FC6D2F\",\n \"A2FF0B\",\n )\n self.palette = [self.hex2rgb(f\"#{c}\") for c in hexs]\n self.n = len(self.palette)\n self.pose_palette = np.array(\n [\n [255, 128, 0],\n [255, 153, 51],\n [255, 178, 102],\n [230, 230, 0],\n [255, 153, 255],\n [153, 204, 255],\n [255, 102, 255],\n [255, 51, 255],\n [102, 178, 255],\n [51, 153, 255],\n [255, 153, 153],\n [255, 102, 102],\n [255, 51, 51],\n [153, 255, 153],\n [102, 255, 102],\n [51, 255, 51],\n [0, 255, 0],\n [0, 0, 255],\n [255, 0, 0],\n [255, 255, 255],\n ],\n dtype=np.uint8,\n )\n\n def __call__(self, i: int, bgr: bool = False) -> tuple:\n \"\"\"\n Convert hex color codes to RGB values.\n\n Args:\n i (int): Color index.\n bgr (bool, optional): Whether to return BGR format instead of RGB.\n\n Returns:\n (tuple): RGB or BGR color tuple.\n \"\"\"\n c = self.palette[int(i) % self.n]\n return (c[2], c[1], c[0]) if bgr else c\n\n @staticmethod\n def hex2rgb(h: str) -> tuple:\n \"\"\"Convert hex color codes to RGB values (i.e. default PIL order).\"\"\"\n return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))",
"chunk_type": "class",
"name": "Colors",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 19,
"end_line": 162,
"start_col": 0,
"end_col": 70,
"parent_name": null,
"docstring": "Ultralytics color palette for visualization and plotting.\n\nThis class provides methods to work with the Ultralytics color palette, including converting hex color codes to\nRGB values and accessing predefined color schemes for object detection and pose estimation.\n\nAttributes:\n palette (List[tuple]): List of RGB color tuples for general use.\n n (int): The number of colors in the palette.\n pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.\n\nExamples:\n >>> from ultralytics.utils.plotting import Colors\n >>> colors = Colors()\n >>> colors(5, True) # Returns BGR format: (221, 111, 255)\n >>> colors(5, False) # Returns RGB format: (255, 111, 221)\n\n## Ultralytics Color Palette\n\n| Index | Color | HEX | RGB |\n|-------|-------------------------------------------------------------------|-----------|-------------------|\n| 0 | | `#042aff` | (4, 42, 255) |\n| 1 | | `#0bdbeb` | (11, 219, 235) |\n| 2 | | `#f3f3f3` | (243, 243, 243) |\n| 3 | | `#00dfb7` | (0, 223, 183) |\n| 4 | | `#111f68` | (17, 31, 104) |\n| 5 | | `#ff6fdd` | (255, 111, 221) |\n| 6 | | `#ff444f` | (255, 68, 79) |\n| 7 | | `#cced00` | (204, 237, 0) |\n| 8 | | `#00f344` | (0, 243, 68) |\n| 9 | | `#bd00ff` | (189, 0, 255) |\n| 10 | | `#00b4ff` | (0, 180, 255) |\n| 11 | | `#dd00ba` | (221, 0, 186) |\n| 12 | | `#00ffff` | (0, 255, 255) |\n| 13 | | `#26c000` | (38, 192, 0) |\n| 14 | | `#01ffb3` | (1, 255, 179) |\n| 15 | | `#7d24ff` | (125, 36, 255) |\n| 16 | | `#7b0068` | (123, 0, 104) |\n| 17 | | `#ff1b6c` | (255, 27, 108) |\n| 18 | | `#fc6d2f` | (252, 109, 47) |\n| 19 | | `#a2ff0b` | (162, 255, 11) |\n\n## Pose Color Palette\n\n| Index | Color | HEX | RGB |\n|-------|-------------------------------------------------------------------|-----------|-------------------|\n| 0 | | `#ff8000` | (255, 128, 0) |\n| 1 | | `#ff9933` | (255, 153, 51) |\n| 2 | | `#ffb266` | (255, 178, 102) |\n| 3 | | `#e6e600` | (230, 230, 0) |\n| 4 | | `#ff99ff` | (255, 153, 255) |\n| 5 | | `#99ccff` | (153, 204, 255) |\n| 6 | | `#ff66ff` | (255, 102, 255) |\n| 7 | | `#ff33ff` | (255, 51, 255) |\n| 8 | | `#66b2ff` | (102, 178, 255) |\n| 9 | | `#3399ff` | (51, 153, 255) |\n| 10 | | `#ff9999` | (255, 153, 153) |\n| 11 | | `#ff6666` | (255, 102, 102) |\n| 12 | | `#ff3333` | (255, 51, 51) |\n| 13 | | `#99ff99` | (153, 255, 153) |\n| 14 | | `#66ff66` | (102, 255, 102) |\n| 15 | | `#33ff33` | (51, 255, 51) |\n| 16 | | `#00ff00` | (0, 255, 0) |\n| 17 | | `#0000ff` | (0, 0, 255) |\n| 18 | | `#ff0000` | (255, 0, 0) |\n| 19 | | `#ffffff` | (255, 255, 255) |\n\n!!! note \"Ultralytics Brand Colors\"\n\n For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).\n Please use the official Ultralytics colors for all marketing materials.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Callable",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Union",
"cv2",
"numpy",
"torch",
"PIL.Image",
"PIL.ImageDraw",
"PIL.ImageFont",
"PIL.__version__",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TryExcept",
"ultralytics.utils.ops",
"ultralytics.utils.plt_settings",
"ultralytics.utils.threaded",
"ultralytics.utils.checks.check_font",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.checks.is_ascii",
"ultralytics.utils.files.increment_path",
"matplotlib.pyplot",
"pandas",
"matplotlib.colors.LinearSegmentedColormap",
"matplotlib.pyplot",
"pandas",
"scipy.ndimage.gaussian_filter1d",
"matplotlib.pyplot",
"matplotlib.pyplot",
"pandas",
"scipy.ndimage.gaussian_filter1d",
"matplotlib.pyplot",
"seaborn"
],
"chunk_id": "class_Colors_52a994d8"
},
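A short usage sketch for the Colors palette class above, assuming ultralytics is installed; the expected values come from the class docstring.

```python
from ultralytics.utils.plotting import Colors

colors = Colors()
print(colors.hex2rgb("#FF6FDD"))  # (255, 111, 221)
print(colors(5))                  # RGB tuple for palette index 5: (255, 111, 221)
print(colors(5, bgr=True))        # same colour in BGR order: (221, 111, 255)
print(colors(25))                 # indices wrap modulo the 20-colour palette
```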
{
"content": "colors = Colors() # create instance for 'from utils.plots import colors'",
"chunk_type": "variable",
"name": "colors",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 165,
"end_line": 165,
"start_col": 0,
"end_col": 17,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_colors_0f7f05d2"
},
{
"content": "class Annotator:\n \"\"\"\n Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.\n\n Attributes:\n im (Image.Image | np.ndarray): The image to annotate.\n pil (bool): Whether to use PIL or cv2 for drawing annotations.\n font (ImageFont.truetype | ImageFont.load_default): Font used for text annotations.\n lw (float): Line width for drawing.\n skeleton (List[List[int]]): Skeleton structure for keypoints.\n limb_color (List[int]): Color palette for limbs.\n kpt_color (List[int]): Color palette for keypoints.\n dark_colors (set): Set of colors considered dark for text contrast.\n light_colors (set): Set of colors considered light for text contrast.\n\n Examples:\n >>> from ultralytics.utils.plotting import Annotator\n >>> im0 = cv2.imread(\"test.png\")\n >>> annotator = Annotator(im0, line_width=10)\n >>> annotator.box_label([10, 10, 100, 100], \"person\", (255, 0, 0))\n \"\"\"\n\n def __init__(\n self,\n im,\n line_width: Optional[int] = None,\n font_size: Optional[int] = None,\n font: str = \"Arial.ttf\",\n pil: bool = False,\n example: str = \"abc\",\n ):\n \"\"\"Initialize the Annotator class with image and line width along with color palette for keypoints and limbs.\"\"\"\n non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic\n input_is_pil = isinstance(im, Image.Image)\n self.pil = pil or non_ascii or input_is_pil\n self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2)\n if not input_is_pil:\n if im.shape[2] == 1: # handle grayscale\n im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)\n elif im.shape[2] > 3: # multispectral\n im = np.ascontiguousarray(im[..., :3])\n if self.pil: # use PIL\n self.im = im if input_is_pil else Image.fromarray(im)\n if self.im.mode not in {\"RGB\", \"RGBA\"}: # multispectral\n self.im = self.im.convert(\"RGB\")\n self.draw = ImageDraw.Draw(self.im, \"RGBA\")\n try:\n font = check_font(\"Arial.Unicode.ttf\" if non_ascii else font)\n size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)\n self.font = ImageFont.truetype(str(font), size)\n except Exception:\n self.font = ImageFont.load_default()\n # Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string)\n if check_version(pil_version, \"9.2.0\"):\n self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height\n else: # use cv2\n assert im.data.contiguous, \"Image not contiguous. 
Apply np.ascontiguousarray(im) to Annotator input images.\"\n self.im = im if im.flags.writeable else im.copy()\n self.tf = max(self.lw - 1, 1) # font thickness\n self.sf = self.lw / 3 # font scale\n # Pose\n self.skeleton = [\n [16, 14],\n [14, 12],\n [17, 15],\n [15, 13],\n [12, 13],\n [6, 12],\n [7, 13],\n [6, 7],\n [6, 8],\n [7, 9],\n [8, 10],\n [9, 11],\n [2, 3],\n [1, 2],\n [1, 3],\n [2, 4],\n [3, 5],\n [4, 6],\n [5, 7],\n ]\n\n self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]\n self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]\n self.dark_colors = {\n (235, 219, 11),\n (243, 243, 243),\n (183, 223, 0),\n (221, 111, 255),\n (0, 237, 204),\n (68, 243, 0),\n (255, 255, 0),\n (179, 255, 1),\n (11, 255, 162),\n }\n self.light_colors = {\n (255, 42, 4),\n (79, 68, 255),\n (255, 0, 189),\n (255, 180, 0),\n (186, 0, 221),\n (0, 192, 38),\n (255, 36, 125),\n (104, 0, 123),\n (108, 27, 255),\n (47, 109, 252),\n (104, 31, 17),\n }\n\n def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple:\n \"\"\"\n Assign text color based on background color.\n\n Args:\n color (tuple, optional): The background color of the rectangle for text (B, G, R).\n txt_color (tuple, optional): The color of the text (R, G, B).\n\n Returns:\n (tuple): Text color for label.\n\n Examples:\n >>> from ultralytics.utils.plotting import Annotator\n >>> im0 = cv2.imread(\"test.png\")\n >>> annotator = Annotator(im0, line_width=10)\n >>> annotator.get_txt_color(color=(104, 31, 17)) # return (255, 255, 255)\n \"\"\"\n if color in self.dark_colors:\n return 104, 31, 17\n elif color in self.light_colors:\n return 255, 255, 255\n else:\n return txt_color\n\n def box_label(self, box, label: str = \"\", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)):\n \"\"\"\n Draw a bounding box on an image with a given label.\n\n Args:\n box (tuple): The bounding box coordinates (x1, y1, x2, y2).\n label (str, optional): The text label to be displayed.\n color (tuple, optional): The background color of the rectangle (B, G, R).\n txt_color (tuple, optional): The color of the text (R, G, B).\n\n Examples:\n >>> from ultralytics.utils.plotting import Annotator\n >>> im0 = cv2.imread(\"test.png\")\n >>> annotator = Annotator(im0, line_width=10)\n >>> annotator.box_label(box=[10, 20, 30, 40], label=\"person\")\n \"\"\"\n txt_color = self.get_txt_color(color, txt_color)\n if isinstance(box, torch.Tensor):\n box = box.tolist()\n\n multi_points = isinstance(box[0], list) # multiple points with shape (n, 2)\n p1 = [int(b) for b in box[0]] if multi_points else (int(box[0]), int(box[1]))\n if self.pil:\n self.draw.polygon(\n [tuple(b) for b in box], width=self.lw, outline=color\n ) if multi_points else self.draw.rectangle(box, width=self.lw, outline=color)\n if label:\n w, h = self.font.getsize(label) # text width, height\n outside = p1[1] >= h # label fits outside box\n if p1[0] > self.im.size[0] - w: # size is (w, h), check if label extend beyond right side of image\n p1 = self.im.size[0] - w, p1[1]\n self.draw.rectangle(\n (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1),\n fill=color,\n )\n # self.draw.text([box[0], box[1]], label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0\n self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font)\n else: # cv2\n cv2.polylines(\n self.im, [np.asarray(box, 
dtype=int)], True, color, self.lw\n ) if multi_points else cv2.rectangle(\n self.im, p1, (int(box[2]), int(box[3])), color, thickness=self.lw, lineType=cv2.LINE_AA\n )\n if label:\n w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height\n h += 3 # add pixels to pad text\n outside = p1[1] >= h # label fits outside box\n if p1[0] > self.im.shape[1] - w: # shape is (h, w), check if label extend beyond right side of image\n p1 = self.im.shape[1] - w, p1[1]\n p2 = p1[0] + w, p1[1] - h if outside else p1[1] + h\n cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled\n cv2.putText(\n self.im,\n label,\n (p1[0], p1[1] - 2 if outside else p1[1] + h - 1),\n 0,\n self.sf,\n txt_color,\n thickness=self.tf,\n lineType=cv2.LINE_AA,\n )\n\n def masks(self, masks, colors, im_gpu, alpha: float = 0.5, retina_masks: bool = False):\n \"\"\"\n Plot masks on image.\n\n Args:\n masks (torch.Tensor): Predicted masks on cuda, shape: [n, h, w]\n colors (List[List[int]]): Colors for predicted masks, [[r, g, b] * n]\n im_gpu (torch.Tensor): Image is in cuda, shape: [3, h, w], range: [0, 1]\n alpha (float, optional): Mask transparency: 0.0 fully transparent, 1.0 opaque.\n retina_masks (bool, optional): Whether to use high resolution masks or not.\n \"\"\"\n if self.pil:\n # Convert to numpy first\n self.im = np.asarray(self.im).copy()\n if len(masks) == 0:\n self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255\n if im_gpu.device != masks.device:\n im_gpu = im_gpu.to(masks.device)\n colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)\n colors = colors[:, None, None] # shape(n,1,1,3)\n masks = masks.unsqueeze(3) # shape(n,h,w,1)\n masks_color = masks * (colors * alpha) # shape(n,h,w,3)\n\n inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)\n mcs = masks_color.max(dim=0).values # shape(n,h,w,3)\n\n im_gpu = im_gpu.flip(dims=[0]) # flip channel\n im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)\n im_gpu = im_gpu * inv_alpha_masks[-1] + mcs\n im_mask = im_gpu * 255\n im_mask_np = im_mask.byte().cpu().numpy()\n self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)\n if self.pil:\n # Convert im back to PIL and update draw\n self.fromarray(self.im)\n\n def kpts(\n self,\n kpts,\n shape: tuple = (640, 640),\n radius: Optional[int] = None,\n kpt_line: bool = True,\n conf_thres: float = 0.25,\n kpt_color: Optional[tuple] = None,\n ):\n \"\"\"\n Plot keypoints on the image.\n\n Args:\n kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).\n shape (tuple, optional): Image shape (h, w).\n radius (int, optional): Keypoint radius.\n kpt_line (bool, optional): Draw lines between keypoints.\n conf_thres (float, optional): Confidence threshold.\n kpt_color (tuple, optional): Keypoint color (B, G, R).\n\n Note:\n - `kpt_line=True` currently only supports human pose plotting.\n - Modifies self.im in-place.\n - If self.pil is True, converts image to numpy array and back to PIL.\n \"\"\"\n radius = radius if radius is not None else self.lw\n if self.pil:\n # Convert to numpy first\n self.im = np.asarray(self.im).copy()\n nkpt, ndim = kpts.shape\n is_pose = nkpt == 17 and ndim in {2, 3}\n kpt_line &= is_pose # `kpt_line=True` for now only supports human pose plotting\n for i, k in enumerate(kpts):\n color_k = kpt_color or (self.kpt_color[i].tolist() if is_pose else colors(i))\n x_coord, y_coord = k[0], k[1]\n if x_coord % shape[1] != 0 and y_coord % shape[0] != 
0:\n if len(k) == 3:\n conf = k[2]\n if conf < conf_thres:\n continue\n cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA)\n\n if kpt_line:\n ndim = kpts.shape[-1]\n for i, sk in enumerate(self.skeleton):\n pos1 = (int(kpts[(sk[0] - 1), 0]), int(kpts[(sk[0] - 1), 1]))\n pos2 = (int(kpts[(sk[1] - 1), 0]), int(kpts[(sk[1] - 1), 1]))\n if ndim == 3:\n conf1 = kpts[(sk[0] - 1), 2]\n conf2 = kpts[(sk[1] - 1), 2]\n if conf1 < conf_thres or conf2 < conf_thres:\n continue\n if pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0:\n continue\n if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0:\n continue\n cv2.line(\n self.im,\n pos1,\n pos2,\n kpt_color or self.limb_color[i].tolist(),\n thickness=int(np.ceil(self.lw / 2)),\n lineType=cv2.LINE_AA,\n )\n if self.pil:\n # Convert im back to PIL and update draw\n self.fromarray(self.im)\n\n def rectangle(self, xy, fill=None, outline=None, width: int = 1):\n \"\"\"Add rectangle to image (PIL-only).\"\"\"\n self.draw.rectangle(xy, fill, outline, width)\n\n def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = \"top\", box_color: tuple = ()):\n \"\"\"\n Add text to an image using PIL or cv2.\n\n Args:\n xy (List[int]): Top-left coordinates for text placement.\n text (str): Text to be drawn.\n txt_color (tuple, optional): Text color (R, G, B).\n anchor (str, optional): Text anchor position ('top' or 'bottom').\n box_color (tuple, optional): Box color (R, G, B, A) with optional alpha.\n \"\"\"\n if self.pil:\n w, h = self.font.getsize(text)\n if anchor == \"bottom\": # start y from font bottom\n xy[1] += 1 - h\n for line in text.split(\"\\n\"):\n if box_color:\n # Draw rectangle for each line\n w, h = self.font.getsize(line)\n self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=box_color)\n self.draw.text(xy, line, fill=txt_color, font=self.font)\n xy[1] += h\n else:\n if box_color:\n w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]\n h += 3 # add pixels to pad text\n outside = xy[1] >= h # label fits outside box\n p2 = xy[0] + w, xy[1] - h if outside else xy[1] + h\n cv2.rectangle(self.im, xy, p2, box_color, -1, cv2.LINE_AA) # filled\n cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)\n\n def fromarray(self, im):\n \"\"\"Update self.im from a numpy array.\"\"\"\n self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)\n self.draw = ImageDraw.Draw(self.im)\n\n def result(self):\n \"\"\"Return annotated image as array.\"\"\"\n return np.asarray(self.im)\n\n def show(self, title: Optional[str] = None):\n \"\"\"Show the annotated image.\"\"\"\n im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert numpy array to PIL Image with RGB to BGR\n if IS_COLAB or IS_KAGGLE: # can not use IS_JUPYTER as will run for all ipython environments\n try:\n display(im) # noqa - display() function only available in ipython environments\n except ImportError as e:\n LOGGER.warning(f\"Unable to display image in Jupyter notebooks: {e}\")\n else:\n im.show(title=title)\n\n def save(self, filename: str = \"image.jpg\"):\n \"\"\"Save the annotated image to 'filename'.\"\"\"\n cv2.imwrite(filename, np.asarray(self.im))\n\n @staticmethod\n def get_bbox_dimension(bbox: Optional[tuple] = None):\n \"\"\"\n Calculate the dimensions and area of a bounding box.\n\n Args:\n bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).\n\n 
Returns:\n width (float): Width of the bounding box.\n height (float): Height of the bounding box.\n area (float): Area enclosed by the bounding box.\n\n Examples:\n >>> from ultralytics.utils.plotting import Annotator\n >>> im0 = cv2.imread(\"test.png\")\n >>> annotator = Annotator(im0, line_width=10)\n >>> annotator.get_bbox_dimension(bbox=[10, 20, 30, 40])\n \"\"\"\n x_min, y_min, x_max, y_max = bbox\n width = x_max - x_min\n height = y_max - y_min\n return width, height, width * height",
"chunk_type": "class",
"name": "Annotator",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 168,
"end_line": 549,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": "Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.\n\nAttributes:\n im (Image.Image | np.ndarray): The image to annotate.\n pil (bool): Whether to use PIL or cv2 for drawing annotations.\n font (ImageFont.truetype | ImageFont.load_default): Font used for text annotations.\n lw (float): Line width for drawing.\n skeleton (List[List[int]]): Skeleton structure for keypoints.\n limb_color (List[int]): Color palette for limbs.\n kpt_color (List[int]): Color palette for keypoints.\n dark_colors (set): Set of colors considered dark for text contrast.\n light_colors (set): Set of colors considered light for text contrast.\n\nExamples:\n >>> from ultralytics.utils.plotting import Annotator\n >>> im0 = cv2.imread(\"test.png\")\n >>> annotator = Annotator(im0, line_width=10)\n >>> annotator.box_label([10, 10, 100, 100], \"person\", (255, 0, 0))",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Callable",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Union",
"cv2",
"numpy",
"torch",
"PIL.Image",
"PIL.ImageDraw",
"PIL.ImageFont",
"PIL.__version__",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TryExcept",
"ultralytics.utils.ops",
"ultralytics.utils.plt_settings",
"ultralytics.utils.threaded",
"ultralytics.utils.checks.check_font",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.checks.is_ascii",
"ultralytics.utils.files.increment_path",
"matplotlib.pyplot",
"pandas",
"matplotlib.colors.LinearSegmentedColormap",
"matplotlib.pyplot",
"pandas",
"scipy.ndimage.gaussian_filter1d",
"matplotlib.pyplot",
"matplotlib.pyplot",
"pandas",
"scipy.ndimage.gaussian_filter1d",
"matplotlib.pyplot",
"seaborn"
],
"chunk_id": "class_Annotator_82473fd1"
},
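Minimal end-to-end use of the Annotator class above, assuming ultralytics is installed; a blank canvas replaces the test.png used in the docstring example, and the output filename is illustrative.

```python
import cv2
import numpy as np

from ultralytics.utils.plotting import Annotator

im0 = np.full((480, 640, 3), 255, dtype=np.uint8)  # blank white frame instead of cv2.imread("test.png")
annotator = Annotator(im0, line_width=2)
annotator.box_label([50, 60, 200, 220], "person", color=(255, 0, 0))
cv2.imwrite("annotated.jpg", annotator.result())   # result() returns the annotated image as a numpy array
```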
{
"content": "def plot_labels(boxes, cls, names=(), save_dir=Path(\"\"), on_plot=None):\n \"\"\"\n Plot training labels including class histograms and box statistics.\n\n Args:\n boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height].\n cls (np.ndarray): Class indices.\n names (dict, optional): Dictionary mapping class indices to class names.\n save_dir (Path, optional): Directory to save the plot.\n on_plot (Callable, optional): Function to call after plot is saved.\n \"\"\"\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n import pandas\n from matplotlib.colors import LinearSegmentedColormap\n\n # Filter matplotlib>=3.7.2 warning\n warnings.filterwarnings(\"ignore\", category=UserWarning, message=\"The figure layout has changed to tight\")\n warnings.filterwarnings(\"ignore\", category=FutureWarning)\n\n # Plot dataset labels\n LOGGER.info(f\"Plotting labels to {save_dir / 'labels.jpg'}... \")\n nc = int(cls.max() + 1) # number of classes\n boxes = boxes[:1000000] # limit to 1M boxes\n x = pandas.DataFrame(boxes, columns=[\"x\", \"y\", \"width\", \"height\"])\n\n try: # Seaborn correlogram\n import seaborn\n\n seaborn.pairplot(x, corner=True, diag_kind=\"auto\", kind=\"hist\", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))\n plt.savefig(save_dir / \"labels_correlogram.jpg\", dpi=200)\n plt.close()\n except ImportError:\n pass # Skip if seaborn is not installed\n\n # Matplotlib labels\n subplot_3_4_color = LinearSegmentedColormap.from_list(\"white_blue\", [\"white\", \"blue\"])\n ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()\n y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)\n for i in range(nc):\n y[2].patches[i].set_color([x / 255 for x in colors(i)])\n ax[0].set_ylabel(\"instances\")\n if 0 < len(names) < 30:\n ax[0].set_xticks(range(len(names)))\n ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)\n else:\n ax[0].set_xlabel(\"classes\")\n boxes = np.column_stack([0.5 - boxes[:, 2:4] / 2, 0.5 + boxes[:, 2:4] / 2]) * 1000\n img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)\n for cls, box in zip(cls[:500], boxes[:500]):\n ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot\n ax[1].imshow(img)\n ax[1].axis(\"off\")\n\n ax[2].hist2d(x[\"x\"], x[\"y\"], bins=50, cmap=subplot_3_4_color)\n ax[2].set_xlabel(\"x\")\n ax[2].set_ylabel(\"y\")\n ax[3].hist2d(x[\"width\"], x[\"height\"], bins=50, cmap=subplot_3_4_color)\n ax[3].set_xlabel(\"width\")\n ax[3].set_ylabel(\"height\")\n for a in {0, 1, 2, 3}:\n for s in {\"top\", \"right\", \"left\", \"bottom\"}:\n ax[a].spines[s].set_visible(False)\n\n fname = save_dir / \"labels.jpg\"\n plt.savefig(fname, dpi=200)\n plt.close()\n if on_plot:\n on_plot(fname)",
"chunk_type": "function",
"name": "plot_labels",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 554,
"end_line": 621,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": "Plot training labels including class histograms and box statistics.\n\nArgs:\n boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height].\n cls (np.ndarray): Class indices.\n names (dict, optional): Dictionary mapping class indices to class names.\n save_dir (Path, optional): Directory to save the plot.\n on_plot (Callable, optional): Function to call after plot is saved.",
"parameters": [
"boxes",
"cls",
"names",
"save_dir",
"on_plot"
],
"return_type": null,
"decorators": [
"TryExcept()",
"plt_settings()"
],
"complexity_score": 9,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Callable",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Union",
"cv2",
"numpy",
"torch",
"PIL.Image",
"PIL.ImageDraw",
"PIL.ImageFont",
"PIL.__version__",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TryExcept",
"ultralytics.utils.ops",
"ultralytics.utils.plt_settings",
"ultralytics.utils.threaded",
"ultralytics.utils.checks.check_font",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.checks.is_ascii",
"ultralytics.utils.files.increment_path",
"matplotlib.pyplot",
"pandas",
"matplotlib.colors.LinearSegmentedColormap",
"matplotlib.pyplot",
"pandas",
"scipy.ndimage.gaussian_filter1d",
"matplotlib.pyplot",
"matplotlib.pyplot",
"pandas",
"scipy.ndimage.gaussian_filter1d",
"matplotlib.pyplot",
"seaborn"
],
"chunk_id": "function_plot_labels_0423c306"
},
{
"content": "def save_one_box(\n xyxy,\n im,\n file: Path = Path(\"im.jpg\"),\n gain: float = 1.02,\n pad: int = 10,\n square: bool = False,\n BGR: bool = False,\n save: bool = True,\n):\n \"\"\"\n Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.\n\n This function takes a bounding box and an image, and then saves a cropped portion of the image according\n to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding\n adjustments to the bounding box.\n\n Args:\n xyxy (torch.Tensor | list): A tensor or list representing the bounding box in xyxy format.\n im (np.ndarray): The input image.\n file (Path, optional): The path where the cropped image will be saved.\n gain (float, optional): A multiplicative factor to increase the size of the bounding box.\n pad (int, optional): The number of pixels to add to the width and height of the bounding box.\n square (bool, optional): If True, the bounding box will be transformed into a square.\n BGR (bool, optional): If True, the image will be returned in BGR format, otherwise in RGB.\n save (bool, optional): If True, the cropped image will be saved to disk.\n\n Returns:\n (np.ndarray): The cropped image.\n\n Examples:\n >>> from ultralytics.utils.plotting import save_one_box\n >>> xyxy = [50, 50, 150, 150]\n >>> im = cv2.imread(\"image.jpg\")\n >>> cropped_im = save_one_box(xyxy, im, file=\"cropped.jpg\", square=True)\n \"\"\"\n if not isinstance(xyxy, torch.Tensor): # may be list\n xyxy = torch.stack(xyxy)\n b = ops.xyxy2xywh(xyxy.view(-1, 4)) # boxes\n if square:\n b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square\n b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad\n xyxy = ops.xywh2xyxy(b).long()\n xyxy = ops.clip_boxes(xyxy, im.shape)\n grayscale = im.shape[2] == 1 # grayscale image\n crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR or grayscale else -1)]\n if save:\n file.parent.mkdir(parents=True, exist_ok=True) # make directory\n f = str(increment_path(file).with_suffix(\".jpg\"))\n # cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue\n crop = crop.squeeze(-1) if grayscale else crop[..., ::-1] if BGR else crop\n Image.fromarray(crop).save(f, quality=95, subsampling=0) # save RGB\n return crop",
"chunk_type": "function",
"name": "save_one_box",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 624,
"end_line": 676,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": "Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.\n\nThis function takes a bounding box and an image, and then saves a cropped portion of the image according\nto the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding\nadjustments to the bounding box.\n\nArgs:\n xyxy (torch.Tensor | list): A tensor or list representing the bounding box in xyxy format.\n im (np.ndarray): The input image.\n file (Path, optional): The path where the cropped image will be saved.\n gain (float, optional): A multiplicative factor to increase the size of the bounding box.\n pad (int, optional): The number of pixels to add to the width and height of the bounding box.\n square (bool, optional): If True, the bounding box will be transformed into a square.\n BGR (bool, optional): If True, the image will be returned in BGR format, otherwise in RGB.\n save (bool, optional): If True, the cropped image will be saved to disk.\n\nReturns:\n (np.ndarray): The cropped image.\n\nExamples:\n >>> from ultralytics.utils.plotting import save_one_box\n >>> xyxy = [50, 50, 150, 150]\n >>> im = cv2.imread(\"image.jpg\")\n >>> cropped_im = save_one_box(xyxy, im, file=\"cropped.jpg\", square=True)",
"parameters": [
"xyxy",
"im",
"file: Path",
"gain: float",
"pad: int",
"square: bool",
"BGR: bool",
"save: bool"
],
"return_type": null,
"decorators": [],
"complexity_score": 4,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Callable",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Union",
"cv2",
"numpy",
"torch",
"PIL.Image",
"PIL.ImageDraw",
"PIL.ImageFont",
"PIL.__version__",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TryExcept",
"ultralytics.utils.ops",
"ultralytics.utils.plt_settings",
"ultralytics.utils.threaded",
"ultralytics.utils.checks.check_font",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.checks.is_ascii",
"ultralytics.utils.files.increment_path",
"matplotlib.pyplot",
"pandas",
"matplotlib.colors.LinearSegmentedColormap",
"matplotlib.pyplot",
"pandas",
"scipy.ndimage.gaussian_filter1d",
"matplotlib.pyplot",
"matplotlib.pyplot",
"pandas",
"scipy.ndimage.gaussian_filter1d",
"matplotlib.pyplot",
"seaborn"
],
"chunk_id": "function_save_one_box_081a27c2"
},
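A minimal usage sketch for save_one_box (not from the source; the image path and crop destination are illustrative). It crops one detection from a BGR image loaded with OpenCV and writes an enlarged, squared crop to disk:

from pathlib import Path

import cv2
import torch

from ultralytics.utils.plotting import save_one_box

im = cv2.imread("image.jpg")                       # hypothetical BGR image loaded by OpenCV
xyxy = torch.tensor([50, 50, 150, 150])            # bounding box in pixel xyxy format
crop = save_one_box(xyxy, im, file=Path("crops/crop.jpg"), gain=1.02, pad=10, square=True, save=True)
print(crop.shape)                                  # cropped region, returned in RGB since BGR=False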
{
"content": "def plot_images(\n labels: Dict[str, Any],\n images: Union[torch.Tensor, np.ndarray] = np.zeros((0, 3, 640, 640), dtype=np.float32),\n paths: Optional[List[str]] = None,\n fname: str = \"images.jpg\",\n names: Optional[Dict[int, str]] = None,\n on_plot: Optional[Callable] = None,\n max_size: int = 1920,\n max_subplots: int = 16,\n save: bool = True,\n conf_thres: float = 0.25,\n) -> Optional[np.ndarray]:\n \"\"\"\n Plot image grid with labels, bounding boxes, masks, and keypoints.\n\n Args:\n labels (Dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks', 'keypoints', 'batch_idx', 'img'.\n images (torch.Tensor | np.ndarray]): Batch of images to plot. Shape: (batch_size, channels, height, width).\n paths (Optional[List[str]]): List of file paths for each image in the batch.\n fname (str): Output filename for the plotted image grid.\n names (Optional[Dict[int, str]]): Dictionary mapping class indices to class names.\n on_plot (Optional[Callable]): Optional callback function to be called after saving the plot.\n max_size (int): Maximum size of the output image grid.\n max_subplots (int): Maximum number of subplots in the image grid.\n save (bool): Whether to save the plotted image grid to a file.\n conf_thres (float): Confidence threshold for displaying detections.\n\n Returns:\n (np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.\n\n Note:\n This function supports both tensor and numpy array inputs. It will automatically\n convert tensor inputs to numpy arrays for processing.\n \"\"\"\n for k in {\"cls\", \"bboxes\", \"conf\", \"masks\", \"keypoints\", \"batch_idx\", \"images\"}:\n if k not in labels:\n continue\n if k == \"cls\" and labels[k].ndim == 2:\n labels[k] = labels[k].squeeze(1) # squeeze if shape is (n, 1)\n if isinstance(labels[k], torch.Tensor):\n labels[k] = labels[k].cpu().numpy()\n\n cls = labels.get(\"cls\", np.zeros(0, dtype=np.int64))\n batch_idx = labels.get(\"batch_idx\", np.zeros(cls.shape, dtype=np.int64))\n bboxes = labels.get(\"bboxes\", np.zeros(0, dtype=np.float32))\n confs = labels.get(\"conf\", None)\n masks = labels.get(\"masks\", np.zeros(0, dtype=np.uint8))\n kpts = labels.get(\"keypoints\", np.zeros(0, dtype=np.float32))\n images = labels.get(\"img\", images) # default to input images\n\n if len(images) and isinstance(images, torch.Tensor):\n images = images.cpu().float().numpy()\n if images.shape[1] > 3:\n images = images[:, :3] # crop multispectral images to first 3 channels\n\n bs, _, h, w = images.shape # batch size, _, height, width\n bs = min(bs, max_subplots) # limit plot images\n ns = np.ceil(bs**0.5) # number of subplots (square)\n if np.max(images[0]) <= 1:\n images *= 255 # de-normalise (optional)\n\n # Build Image\n mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init\n for i in range(bs):\n x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin\n mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)\n\n # Resize (optional)\n scale = max_size / ns / max(h, w)\n if scale < 1:\n h = math.ceil(scale * h)\n w = math.ceil(scale * w)\n mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))\n\n # Annotate\n fs = int((h + w) * ns * 0.01) # font size\n fs = max(fs, 18) # ensure that the font size is large enough to be easily readable.\n annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=str(names))\n for i in range(bs):\n x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin\n 
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders\n if paths:\n annotator.text([x + 5, y + 5], text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames\n if len(cls) > 0:\n idx = batch_idx == i\n classes = cls[idx].astype(\"int\")\n labels = confs is None\n\n if len(bboxes):\n boxes = bboxes[idx]\n conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)\n if len(boxes):\n if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1\n boxes[..., [0, 2]] *= w # scale to pixels\n boxes[..., [1, 3]] *= h\n elif scale < 1: # absolute coords need scale if image scales\n boxes[..., :4] *= scale\n boxes[..., 0] += x\n boxes[..., 1] += y\n is_obb = boxes.shape[-1] == 5 # xywhr\n # TODO: this transformation might be unnecessary\n boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)\n for j, box in enumerate(boxes.astype(np.int64).tolist()):\n c = classes[j]\n color = colors(c)\n c = names.get(c, c) if names else c\n if labels or conf[j] > conf_thres:\n label = f\"{c}\" if labels else f\"{c} {conf[j]:.1f}\"\n annotator.box_label(box, label, color=color)\n\n elif len(classes):\n for c in classes:\n color = colors(c)\n c = names.get(c, c) if names else c\n annotator.text([x, y], f\"{c}\", txt_color=color, box_color=(64, 64, 64, 128))\n\n # Plot keypoints\n if len(kpts):\n kpts_ = kpts[idx].copy()\n if len(kpts_):\n if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01: # if normalized with tolerance .01\n kpts_[..., 0] *= w # scale to pixels\n kpts_[..., 1] *= h\n elif scale < 1: # absolute coords need scale if image scales\n kpts_ *= scale\n kpts_[..., 0] += x\n kpts_[..., 1] += y\n for j in range(len(kpts_)):\n if labels or conf[j] > conf_thres:\n annotator.kpts(kpts_[j], conf_thres=conf_thres)\n\n # Plot masks\n if len(masks):\n if idx.shape[0] == masks.shape[0]: # overlap_masks=False\n image_masks = masks[idx]\n else: # overlap_masks=True\n image_masks = masks[[i]] # (1, 640, 640)\n nl = idx.sum()\n index = np.arange(nl).reshape((nl, 1, 1)) + 1\n image_masks = np.repeat(image_masks, nl, axis=0)\n image_masks = np.where(image_masks == index, 1.0, 0.0)\n\n im = np.asarray(annotator.im).copy()\n for j in range(len(image_masks)):\n if labels or conf[j] > conf_thres:\n color = colors(classes[j])\n mh, mw = image_masks[j].shape\n if mh != h or mw != w:\n mask = image_masks[j].astype(np.uint8)\n mask = cv2.resize(mask, (w, h))\n mask = mask.astype(bool)\n else:\n mask = image_masks[j].astype(bool)\n try:\n im[y : y + h, x : x + w, :][mask] = (\n im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6\n )\n except Exception:\n pass\n annotator.fromarray(im)\n if not save:\n return np.asarray(annotator.im)\n annotator.im.save(fname) # save\n if on_plot:\n on_plot(fname)",
"chunk_type": "function",
"name": "plot_images",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 680,
"end_line": 844,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": "Plot image grid with labels, bounding boxes, masks, and keypoints.\n\nArgs:\n labels (Dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks', 'keypoints', 'batch_idx', 'img'.\n images (torch.Tensor | np.ndarray]): Batch of images to plot. Shape: (batch_size, channels, height, width).\n paths (Optional[List[str]]): List of file paths for each image in the batch.\n fname (str): Output filename for the plotted image grid.\n names (Optional[Dict[int, str]]): Dictionary mapping class indices to class names.\n on_plot (Optional[Callable]): Optional callback function to be called after saving the plot.\n max_size (int): Maximum size of the output image grid.\n max_subplots (int): Maximum number of subplots in the image grid.\n save (bool): Whether to save the plotted image grid to a file.\n conf_thres (float): Confidence threshold for displaying detections.\n\nReturns:\n (np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.\n\nNote:\n This function supports both tensor and numpy array inputs. It will automatically\n convert tensor inputs to numpy arrays for processing.",
"parameters": [
"labels: Dict[str, Any]",
"images: Union[torch.Tensor, np.ndarray]",
"paths: Optional[List[str]]",
"fname: str",
"names: Optional[Dict[int, str]]",
"on_plot: Optional[Callable]",
"max_size: int",
"max_subplots: int",
"save: bool",
"conf_thres: float"
],
"return_type": "Optional[np.ndarray]",
"decorators": [
"threaded"
],
"complexity_score": 36,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Callable",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Union",
"cv2",
"numpy",
"torch",
"PIL.Image",
"PIL.ImageDraw",
"PIL.ImageFont",
"PIL.__version__",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TryExcept",
"ultralytics.utils.ops",
"ultralytics.utils.plt_settings",
"ultralytics.utils.threaded",
"ultralytics.utils.checks.check_font",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.checks.is_ascii",
"ultralytics.utils.files.increment_path",
"matplotlib.pyplot",
"pandas",
"matplotlib.colors.LinearSegmentedColormap",
"matplotlib.pyplot",
"pandas",
"scipy.ndimage.gaussian_filter1d",
"matplotlib.pyplot",
"matplotlib.pyplot",
"pandas",
"scipy.ndimage.gaussian_filter1d",
"matplotlib.pyplot",
"seaborn"
],
"chunk_id": "function_plot_images_4ec7aab8"
},
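A minimal sketch of calling plot_images with an in-memory batch (all values are illustrative, not from the source). Normalized xywh boxes, integer class ids, and the image tensor are passed through the labels dict:

import torch

from ultralytics.utils.plotting import plot_images

labels = {
    "img": torch.rand(2, 3, 640, 640),                    # batch of 2 RGB images with values in [0, 1]
    "batch_idx": torch.tensor([0, 1]),                    # which image each box belongs to
    "cls": torch.tensor([[0], [3]]),                      # class ids, shape (n, 1) is squeezed internally
    "bboxes": torch.tensor([[0.5, 0.5, 0.2, 0.3],
                            [0.4, 0.4, 0.1, 0.1]]),       # normalized xywh boxes
}
plot_images(labels, names={0: "person", 3: "car"}, fname="train_batch0.jpg")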
{
"content": "def plot_results(\n file: str = \"path/to/results.csv\",\n dir: str = \"\",\n segment: bool = False,\n pose: bool = False,\n classify: bool = False,\n on_plot: Optional[Callable] = None,\n):\n \"\"\"\n Plot training results from a results CSV file. The function supports various types of data including segmentation,\n pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.\n\n Args:\n file (str, optional): Path to the CSV file containing the training results.\n dir (str, optional): Directory where the CSV file is located if 'file' is not provided.\n segment (bool, optional): Flag to indicate if the data is for segmentation.\n pose (bool, optional): Flag to indicate if the data is for pose estimation.\n classify (bool, optional): Flag to indicate if the data is for classification.\n on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.\n\n Examples:\n >>> from ultralytics.utils.plotting import plot_results\n >>> plot_results(\"path/to/results.csv\", segment=True)\n \"\"\"\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n import pandas as pd\n from scipy.ndimage import gaussian_filter1d\n\n save_dir = Path(file).parent if file else Path(dir)\n if classify:\n fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)\n index = [2, 5, 3, 4]\n elif segment:\n fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)\n index = [2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 8, 9, 12, 13]\n elif pose:\n fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)\n index = [2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 17, 18, 19, 9, 10, 13, 14]\n else:\n fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)\n index = [2, 3, 4, 5, 6, 9, 10, 11, 7, 8]\n ax = ax.ravel()\n files = list(save_dir.glob(\"results*.csv\"))\n assert len(files), f\"No results.csv files found in {save_dir.resolve()}, nothing to plot.\"\n for f in files:\n try:\n data = pd.read_csv(f)\n s = [x.strip() for x in data.columns]\n x = data.values[:, 0]\n for i, j in enumerate(index):\n y = data.values[:, j].astype(\"float\")\n # y[y == 0] = np.nan # don't show zero values\n ax[i].plot(x, y, marker=\".\", label=f.stem, linewidth=2, markersize=8) # actual results\n ax[i].plot(x, gaussian_filter1d(y, sigma=3), \":\", label=\"smooth\", linewidth=2) # smoothing line\n ax[i].set_title(s[j], fontsize=12)\n # if j in {8, 9, 10}: # share train and val loss y axes\n # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])\n except Exception as e:\n LOGGER.error(f\"Plotting error for {f}: {e}\")\n ax[1].legend()\n fname = save_dir / \"results.png\"\n fig.savefig(fname, dpi=200)\n plt.close()\n if on_plot:\n on_plot(fname)",
"chunk_type": "function",
"name": "plot_results",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 848,
"end_line": 912,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": "Plot training results from a results CSV file. The function supports various types of data including segmentation,\npose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.\n\nArgs:\n file (str, optional): Path to the CSV file containing the training results.\n dir (str, optional): Directory where the CSV file is located if 'file' is not provided.\n segment (bool, optional): Flag to indicate if the data is for segmentation.\n pose (bool, optional): Flag to indicate if the data is for pose estimation.\n classify (bool, optional): Flag to indicate if the data is for classification.\n on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.\n\nExamples:\n >>> from ultralytics.utils.plotting import plot_results\n >>> plot_results(\"path/to/results.csv\", segment=True)",
"parameters": [
"file: str",
"dir: str",
"segment: bool",
"pose: bool",
"classify: bool",
"on_plot: Optional[Callable]"
],
"return_type": null,
"decorators": [
"plt_settings()"
],
"complexity_score": 9,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Callable",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Union",
"cv2",
"numpy",
"torch",
"PIL.Image",
"PIL.ImageDraw",
"PIL.ImageFont",
"PIL.__version__",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TryExcept",
"ultralytics.utils.ops",
"ultralytics.utils.plt_settings",
"ultralytics.utils.threaded",
"ultralytics.utils.checks.check_font",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.checks.is_ascii",
"ultralytics.utils.files.increment_path",
"matplotlib.pyplot",
"pandas",
"matplotlib.colors.LinearSegmentedColormap",
"matplotlib.pyplot",
"pandas",
"scipy.ndimage.gaussian_filter1d",
"matplotlib.pyplot",
"matplotlib.pyplot",
"pandas",
"scipy.ndimage.gaussian_filter1d",
"matplotlib.pyplot",
"seaborn"
],
"chunk_id": "function_plot_results_72e5b78e"
},
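A minimal sketch of plotting a finished training run (the CSV paths are illustrative); results.png is written next to the CSV:

from ultralytics.utils.plotting import plot_results

# Assumes a results.csv produced by a detection training run exists at this path.
plot_results(file="runs/detect/train/results.csv")

# For a segmentation run, the extra mask metrics are selected with segment=True.
plot_results(file="runs/segment/train/results.csv", segment=True)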
{
"content": "def plt_color_scatter(v, f, bins: int = 20, cmap: str = \"viridis\", alpha: float = 0.8, edgecolors: str = \"none\"):\n \"\"\"\n Plot a scatter plot with points colored based on a 2D histogram.\n\n Args:\n v (array-like): Values for the x-axis.\n f (array-like): Values for the y-axis.\n bins (int, optional): Number of bins for the histogram.\n cmap (str, optional): Colormap for the scatter plot.\n alpha (float, optional): Alpha for the scatter plot.\n edgecolors (str, optional): Edge colors for the scatter plot.\n\n Examples:\n >>> v = np.random.rand(100)\n >>> f = np.random.rand(100)\n >>> plt_color_scatter(v, f)\n \"\"\"\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n\n # Calculate 2D histogram and corresponding colors\n hist, xedges, yedges = np.histogram2d(v, f, bins=bins)\n colors = [\n hist[\n min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),\n min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1),\n ]\n for i in range(len(v))\n ]\n\n # Scatter plot\n plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)",
"chunk_type": "function",
"name": "plt_color_scatter",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 915,
"end_line": 945,
"start_col": 0,
"end_col": 78,
"parent_name": null,
"docstring": "Plot a scatter plot with points colored based on a 2D histogram.\n\nArgs:\n v (array-like): Values for the x-axis.\n f (array-like): Values for the y-axis.\n bins (int, optional): Number of bins for the histogram.\n cmap (str, optional): Colormap for the scatter plot.\n alpha (float, optional): Alpha for the scatter plot.\n edgecolors (str, optional): Edge colors for the scatter plot.\n\nExamples:\n >>> v = np.random.rand(100)\n >>> f = np.random.rand(100)\n >>> plt_color_scatter(v, f)",
"parameters": [
"v",
"f",
"bins: int",
"cmap: str",
"alpha: float",
"edgecolors: str"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Callable",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Union",
"cv2",
"numpy",
"torch",
"PIL.Image",
"PIL.ImageDraw",
"PIL.ImageFont",
"PIL.__version__",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TryExcept",
"ultralytics.utils.ops",
"ultralytics.utils.plt_settings",
"ultralytics.utils.threaded",
"ultralytics.utils.checks.check_font",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.checks.is_ascii",
"ultralytics.utils.files.increment_path",
"matplotlib.pyplot",
"pandas",
"matplotlib.colors.LinearSegmentedColormap",
"matplotlib.pyplot",
"pandas",
"scipy.ndimage.gaussian_filter1d",
"matplotlib.pyplot",
"matplotlib.pyplot",
"pandas",
"scipy.ndimage.gaussian_filter1d",
"matplotlib.pyplot",
"seaborn"
],
"chunk_id": "function_plt_color_scatter_95489aff"
},
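A small sketch of plt_color_scatter on random data (illustrative values); each point is colored by the density of the 2D histogram bin it falls into:

import matplotlib.pyplot as plt
import numpy as np

from ultralytics.utils.plotting import plt_color_scatter

v = np.random.rand(200)          # x-axis values
f = np.random.rand(200)          # y-axis values
plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8)
plt.savefig("color_scatter.png", dpi=200)
plt.close()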
{
"content": "def plot_tune_results(csv_file: str = \"tune_results.csv\"):\n \"\"\"\n Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key\n in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.\n\n Args:\n csv_file (str, optional): Path to the CSV file containing the tuning results.\n\n Examples:\n >>> plot_tune_results(\"path/to/tune_results.csv\")\n \"\"\"\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n import pandas as pd\n from scipy.ndimage import gaussian_filter1d\n\n def _save_one_file(file):\n \"\"\"Save one matplotlib plot to 'file'.\"\"\"\n plt.savefig(file, dpi=200)\n plt.close()\n LOGGER.info(f\"Saved {file}\")\n\n # Scatter plots for each hyperparameter\n csv_file = Path(csv_file)\n data = pd.read_csv(csv_file)\n num_metrics_columns = 1\n keys = [x.strip() for x in data.columns][num_metrics_columns:]\n x = data.values\n fitness = x[:, 0] # fitness\n j = np.argmax(fitness) # max fitness index\n n = math.ceil(len(keys) ** 0.5) # columns and rows in plot\n plt.figure(figsize=(10, 10), tight_layout=True)\n for i, k in enumerate(keys):\n v = x[:, i + num_metrics_columns]\n mu = v[j] # best single result\n plt.subplot(n, n, i + 1)\n plt_color_scatter(v, fitness, cmap=\"viridis\", alpha=0.8, edgecolors=\"none\")\n plt.plot(mu, fitness.max(), \"k+\", markersize=15)\n plt.title(f\"{k} = {mu:.3g}\", fontdict={\"size\": 9}) # limit to 40 characters\n plt.tick_params(axis=\"both\", labelsize=8) # Set axis label size to 8\n if i % n != 0:\n plt.yticks([])\n _save_one_file(csv_file.with_name(\"tune_scatter_plots.png\"))\n\n # Fitness vs iteration\n x = range(1, len(fitness) + 1)\n plt.figure(figsize=(10, 6), tight_layout=True)\n plt.plot(x, fitness, marker=\"o\", linestyle=\"none\", label=\"fitness\")\n plt.plot(x, gaussian_filter1d(fitness, sigma=3), \":\", label=\"smoothed\", linewidth=2) # smoothing line\n plt.title(\"Fitness vs Iteration\")\n plt.xlabel(\"Iteration\")\n plt.ylabel(\"Fitness\")\n plt.grid(True)\n plt.legend()\n _save_one_file(csv_file.with_name(\"tune_fitness.png\"))",
"chunk_type": "function",
"name": "plot_tune_results",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 948,
"end_line": 1001,
"start_col": 0,
"end_col": 58,
"parent_name": null,
"docstring": "Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key\nin the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.\n\nArgs:\n csv_file (str, optional): Path to the CSV file containing the tuning results.\n\nExamples:\n >>> plot_tune_results(\"path/to/tune_results.csv\")",
"parameters": [
"csv_file: str"
],
"return_type": null,
"decorators": [],
"complexity_score": 4,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Callable",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Union",
"cv2",
"numpy",
"torch",
"PIL.Image",
"PIL.ImageDraw",
"PIL.ImageFont",
"PIL.__version__",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TryExcept",
"ultralytics.utils.ops",
"ultralytics.utils.plt_settings",
"ultralytics.utils.threaded",
"ultralytics.utils.checks.check_font",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.checks.is_ascii",
"ultralytics.utils.files.increment_path",
"matplotlib.pyplot",
"pandas",
"matplotlib.colors.LinearSegmentedColormap",
"matplotlib.pyplot",
"pandas",
"scipy.ndimage.gaussian_filter1d",
"matplotlib.pyplot",
"matplotlib.pyplot",
"pandas",
"scipy.ndimage.gaussian_filter1d",
"matplotlib.pyplot",
"seaborn"
],
"chunk_id": "function_plot_tune_results_08ba113c"
},
{
"content": "def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path(\"runs/detect/exp\")):\n \"\"\"\n Visualize feature maps of a given model module during inference.\n\n Args:\n x (torch.Tensor): Features to be visualized.\n module_type (str): Module type.\n stage (int): Module stage within the model.\n n (int, optional): Maximum number of feature maps to plot.\n save_dir (Path, optional): Directory to save results.\n \"\"\"\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n\n for m in {\"Detect\", \"Segment\", \"Pose\", \"Classify\", \"OBB\", \"RTDETRDecoder\"}: # all model heads\n if m in module_type:\n return\n if isinstance(x, torch.Tensor):\n _, channels, height, width = x.shape # batch, channels, height, width\n if height > 1 and width > 1:\n f = save_dir / f\"stage{stage}_{module_type.rsplit('.', 1)[-1]}_features.png\" # filename\n\n blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels\n n = min(n, channels) # number of plots\n _, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols\n ax = ax.ravel()\n plt.subplots_adjust(wspace=0.05, hspace=0.05)\n for i in range(n):\n ax[i].imshow(blocks[i].squeeze()) # cmap='gray'\n ax[i].axis(\"off\")\n\n LOGGER.info(f\"Saving {f}... ({n}/{channels})\")\n plt.savefig(f, dpi=300, bbox_inches=\"tight\")\n plt.close()\n np.save(str(f.with_suffix(\".npy\")), x[0].cpu().numpy()) # npy save",
"chunk_type": "function",
"name": "feature_visualization",
"file_path": "ultralytics\\ultralytics\\utils\\plotting.py",
"start_line": 1004,
"end_line": 1037,
"start_col": 0,
"end_col": 67,
"parent_name": null,
"docstring": "Visualize feature maps of a given model module during inference.\n\nArgs:\n x (torch.Tensor): Features to be visualized.\n module_type (str): Module type.\n stage (int): Module stage within the model.\n n (int, optional): Maximum number of feature maps to plot.\n save_dir (Path, optional): Directory to save results.",
"parameters": [
"x",
"module_type: str",
"stage: int",
"n: int",
"save_dir: Path"
],
"return_type": null,
"decorators": [],
"complexity_score": 6,
"dependencies": [
"math",
"warnings",
"pathlib.Path",
"typing.Any",
"typing.Callable",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Union",
"cv2",
"numpy",
"torch",
"PIL.Image",
"PIL.ImageDraw",
"PIL.ImageFont",
"PIL.__version__",
"ultralytics.utils.IS_COLAB",
"ultralytics.utils.IS_KAGGLE",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TryExcept",
"ultralytics.utils.ops",
"ultralytics.utils.plt_settings",
"ultralytics.utils.threaded",
"ultralytics.utils.checks.check_font",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.checks.is_ascii",
"ultralytics.utils.files.increment_path",
"matplotlib.pyplot",
"pandas",
"matplotlib.colors.LinearSegmentedColormap",
"matplotlib.pyplot",
"pandas",
"scipy.ndimage.gaussian_filter1d",
"matplotlib.pyplot",
"matplotlib.pyplot",
"pandas",
"scipy.ndimage.gaussian_filter1d",
"matplotlib.pyplot",
"seaborn"
],
"chunk_id": "function_feature_visualization_aab84be0"
},
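A minimal sketch of calling feature_visualization directly on a dummy feature map (the tensor shape and module name are illustrative); model heads such as Detect are skipped by the function itself:

from pathlib import Path

import torch

from ultralytics.utils.plotting import feature_visualization

save_dir = Path("runs/visualize")
save_dir.mkdir(parents=True, exist_ok=True)       # plt.savefig does not create directories

x = torch.randn(1, 32, 80, 80)                    # (batch, channels, height, width) from an intermediate layer
feature_visualization(x, module_type="model.4.C2f", stage=4, n=16, save_dir=save_dir)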
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\utils\\tal.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_3f7c74dd"
},
{
"content": "import torch.nn as nn",
"chunk_type": "import",
"name": "torch.nn",
"file_path": "ultralytics\\ultralytics\\utils\\tal.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn_88e7eb7d"
},
{
"content": "from . import LOGGER",
"chunk_type": "import",
"name": "LOGGER",
"file_path": "ultralytics\\ultralytics\\utils\\tal.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER_911cfcde"
},
{
"content": "from .checks import check_version",
"chunk_type": "import",
"name": "check_version",
"file_path": "ultralytics\\ultralytics\\utils\\tal.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_check_version_2d98f8c9"
},
{
"content": "from .metrics import bbox_iou, probiou",
"chunk_type": "import",
"name": "bbox_iou, probiou",
"file_path": "ultralytics\\ultralytics\\utils\\tal.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 38,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_bbox_iou, probiou_0f7c9b26"
},
{
"content": "from .ops import xywhr2xyxyxyxy",
"chunk_type": "import",
"name": "xywhr2xyxyxyxy",
"file_path": "ultralytics\\ultralytics\\utils\\tal.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_xywhr2xyxyxyxy_847e17a6"
},
{
"content": "TORCH_1_10 = check_version(torch.__version__, \"1.10.0\")",
"chunk_type": "variable",
"name": "TORCH_1_10",
"file_path": "ultralytics\\ultralytics\\utils\\tal.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 55,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_TORCH_1_10_3afeeabf"
},
{
"content": "class TaskAlignedAssigner(nn.Module):\n \"\"\"\n A task-aligned assigner for object detection.\n\n This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both\n classification and localization information.\n\n Attributes:\n topk (int): The number of top candidates to consider.\n num_classes (int): The number of object classes.\n alpha (float): The alpha parameter for the classification component of the task-aligned metric.\n beta (float): The beta parameter for the localization component of the task-aligned metric.\n eps (float): A small value to prevent division by zero.\n \"\"\"\n\n def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):\n \"\"\"\n Initialize a TaskAlignedAssigner object with customizable hyperparameters.\n\n Args:\n topk (int, optional): The number of top candidates to consider.\n num_classes (int, optional): The number of object classes.\n alpha (float, optional): The alpha parameter for the classification component of the task-aligned metric.\n beta (float, optional): The beta parameter for the localization component of the task-aligned metric.\n eps (float, optional): A small value to prevent division by zero.\n \"\"\"\n super().__init__()\n self.topk = topk\n self.num_classes = num_classes\n self.alpha = alpha\n self.beta = beta\n self.eps = eps\n\n @torch.no_grad()\n def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):\n \"\"\"\n Compute the task-aligned assignment.\n\n Args:\n pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).\n pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).\n anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).\n gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).\n gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).\n mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).\n\n Returns:\n target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).\n target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).\n target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).\n fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).\n target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).\n\n References:\n https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py\n \"\"\"\n self.bs = pd_scores.shape[0]\n self.n_max_boxes = gt_bboxes.shape[1]\n device = gt_bboxes.device\n\n if self.n_max_boxes == 0:\n return (\n torch.full_like(pd_scores[..., 0], self.num_classes),\n torch.zeros_like(pd_bboxes),\n torch.zeros_like(pd_scores),\n torch.zeros_like(pd_scores[..., 0]),\n torch.zeros_like(pd_scores[..., 0]),\n )\n\n try:\n return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)\n except torch.cuda.OutOfMemoryError:\n # Move tensors to CPU, compute, then move back to original device\n LOGGER.warning(\"CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU\")\n cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]\n result = self._forward(*cpu_tensors)\n return tuple(t.to(device) for t in result)\n\n def _forward(self, pd_scores, pd_bboxes, anc_points, 
gt_labels, gt_bboxes, mask_gt):\n \"\"\"\n Compute the task-aligned assignment.\n\n Args:\n pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).\n pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).\n anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).\n gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).\n gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).\n mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).\n\n Returns:\n target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).\n target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).\n target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).\n fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).\n target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).\n \"\"\"\n mask_pos, align_metric, overlaps = self.get_pos_mask(\n pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt\n )\n\n target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)\n\n # Assigned target\n target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)\n\n # Normalize\n align_metric *= mask_pos\n pos_align_metrics = align_metric.amax(dim=-1, keepdim=True) # b, max_num_obj\n pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True) # b, max_num_obj\n norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)\n target_scores = target_scores * norm_align_metric\n\n return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx\n\n def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):\n \"\"\"\n Get positive mask for each ground truth box.\n\n Args:\n pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).\n pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).\n gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).\n gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).\n anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).\n mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).\n\n Returns:\n mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).\n align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).\n overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w).\n \"\"\"\n mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)\n # Get anchor_align metric, (b, max_num_obj, h*w)\n align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)\n # Get topk_metric mask, (b, max_num_obj, h*w)\n mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())\n # Merge all mask to a final mask, (b, max_num_obj, h*w)\n mask_pos = mask_topk * mask_in_gts * mask_gt\n\n return mask_pos, align_metric, overlaps\n\n def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):\n \"\"\"\n Compute alignment metric 
given predicted and ground truth bounding boxes.\n\n Args:\n pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).\n pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).\n gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).\n gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).\n mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, h*w).\n\n Returns:\n align_metric (torch.Tensor): Alignment metric combining classification and localization.\n overlaps (torch.Tensor): IoU overlaps between predicted and ground truth boxes.\n \"\"\"\n na = pd_bboxes.shape[-2]\n mask_gt = mask_gt.bool() # b, max_num_obj, h*w\n overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)\n bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)\n\n ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj\n ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes) # b, max_num_obj\n ind[1] = gt_labels.squeeze(-1) # b, max_num_obj\n # Get the scores of each grid for each gt cls\n bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt] # b, max_num_obj, h*w\n\n # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)\n pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]\n gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]\n overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)\n\n align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)\n return align_metric, overlaps\n\n def iou_calculation(self, gt_bboxes, pd_bboxes):\n \"\"\"\n Calculate IoU for horizontal bounding boxes.\n\n Args:\n gt_bboxes (torch.Tensor): Ground truth boxes.\n pd_bboxes (torch.Tensor): Predicted boxes.\n\n Returns:\n (torch.Tensor): IoU values between each pair of boxes.\n \"\"\"\n return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)\n\n def select_topk_candidates(self, metrics, topk_mask=None):\n \"\"\"\n Select the top-k candidates based on the given metrics.\n\n Args:\n metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is\n the maximum number of objects, and h*w represents the total number of anchor points.\n topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where\n topk is the number of top candidates to consider. 
If not provided, the top-k values are automatically\n computed based on the given metrics.\n\n Returns:\n (torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.\n \"\"\"\n # (b, max_num_obj, topk)\n topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=True)\n if topk_mask is None:\n topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)\n # (b, max_num_obj, topk)\n topk_idxs.masked_fill_(~topk_mask, 0)\n\n # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)\n count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)\n ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)\n for k in range(self.topk):\n # Expand topk_idxs for each value of k and add 1 at the specified positions\n count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)\n # Filter invalid bboxes\n count_tensor.masked_fill_(count_tensor > 1, 0)\n\n return count_tensor.to(metrics.dtype)\n\n def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):\n \"\"\"\n Compute target labels, target bounding boxes, and target scores for the positive anchor points.\n\n Args:\n gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the\n batch size and max_num_obj is the maximum number of objects.\n gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).\n target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive\n anchor points, with shape (b, h*w), where h*w is the total\n number of anchor points.\n fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive\n (foreground) anchor points.\n\n Returns:\n target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).\n target_bboxes (torch.Tensor): Target bounding boxes for positive anchor points with shape (b, h*w, 4).\n target_scores (torch.Tensor): Target scores for positive anchor points with shape (b, h*w, num_classes).\n \"\"\"\n # Assigned target labels, (b, 1)\n batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]\n target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w)\n target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w)\n\n # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)\n target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]\n\n # Assigned target scores\n target_labels.clamp_(0)\n\n # 10x faster than F.one_hot()\n target_scores = torch.zeros(\n (target_labels.shape[0], target_labels.shape[1], self.num_classes),\n dtype=torch.int64,\n device=target_labels.device,\n ) # (b, h*w, 80)\n target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)\n\n fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80)\n target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)\n\n return target_labels, target_bboxes, target_scores\n\n @staticmethod\n def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):\n \"\"\"\n Select positive anchor centers within ground truth bounding boxes.\n\n Args:\n xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).\n gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).\n eps (float, optional): Small value for numerical stability.\n\n Returns:\n (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).\n\n Note:\n b: batch size, n_boxes: number of ground truth 
boxes, h: height, w: width.\n Bounding box format: [x_min, y_min, x_max, y_max].\n \"\"\"\n n_anchors = xy_centers.shape[0]\n bs, n_boxes, _ = gt_bboxes.shape\n lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom\n bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)\n return bbox_deltas.amin(3).gt_(eps)\n\n @staticmethod\n def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):\n \"\"\"\n Select anchor boxes with highest IoU when assigned to multiple ground truths.\n\n Args:\n mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).\n overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w).\n n_max_boxes (int): Maximum number of ground truth boxes.\n\n Returns:\n target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).\n fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).\n mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).\n \"\"\"\n # Convert (b, n_max_boxes, h*w) -> (b, h*w)\n fg_mask = mask_pos.sum(-2)\n if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes\n mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w)\n max_overlaps_idx = overlaps.argmax(1) # (b, h*w)\n\n is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)\n is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)\n\n mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() # (b, n_max_boxes, h*w)\n fg_mask = mask_pos.sum(-2)\n # Find each grid serve which gt(index)\n target_gt_idx = mask_pos.argmax(-2) # (b, h*w)\n return target_gt_idx, fg_mask, mask_pos",
"chunk_type": "class",
"name": "TaskAlignedAssigner",
"file_path": "ultralytics\\ultralytics\\utils\\tal.py",
"start_line": 14,
"end_line": 329,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": "A task-aligned assigner for object detection.\n\nThis class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both\nclassification and localization information.\n\nAttributes:\n topk (int): The number of top candidates to consider.\n num_classes (int): The number of object classes.\n alpha (float): The alpha parameter for the classification component of the task-aligned metric.\n beta (float): The beta parameter for the localization component of the task-aligned metric.\n eps (float): A small value to prevent division by zero.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"torch",
"torch.nn",
"LOGGER",
"checks.check_version",
"metrics.bbox_iou",
"metrics.probiou",
"ops.xywhr2xyxyxyxy",
"nn.Module"
],
"chunk_id": "class_TaskAlignedAssigner_d717a5d2"
},
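A shape-level sketch of running the assigner on random tensors (all sizes and values are illustrative, chosen only to show the expected input and output shapes):

import torch

from ultralytics.utils.tal import TaskAlignedAssigner

bs, num_anchors, num_classes, n_max_boxes = 2, 8400, 80, 4
assigner = TaskAlignedAssigner(topk=10, num_classes=num_classes, alpha=0.5, beta=6.0)

pd_scores = torch.rand(bs, num_anchors, num_classes)           # predicted class scores
pd_bboxes = torch.rand(bs, num_anchors, 4) * 640               # predicted boxes, xyxy pixels
anc_points = torch.rand(num_anchors, 2) * 640                  # anchor center coordinates
gt_labels = torch.randint(0, num_classes, (bs, n_max_boxes, 1))
gt_bboxes = torch.rand(bs, n_max_boxes, 4) * 640               # ground-truth boxes, xyxy pixels
mask_gt = torch.ones(bs, n_max_boxes, 1)                       # mark all ground truths as valid

t_labels, t_bboxes, t_scores, fg_mask, t_gt_idx = assigner(
    pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt
)
print(t_labels.shape, t_bboxes.shape, t_scores.shape)          # (bs, num_anchors), (bs, num_anchors, 4), (bs, num_anchors, num_classes)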
{
"content": "class RotatedTaskAlignedAssigner(TaskAlignedAssigner):\n \"\"\"Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric.\"\"\"\n\n def iou_calculation(self, gt_bboxes, pd_bboxes):\n \"\"\"Calculate IoU for rotated bounding boxes.\"\"\"\n return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)\n\n @staticmethod\n def select_candidates_in_gts(xy_centers, gt_bboxes):\n \"\"\"\n Select the positive anchor center in gt for rotated bounding boxes.\n\n Args:\n xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).\n gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5).\n\n Returns:\n (torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w).\n \"\"\"\n # (b, n_boxes, 5) --> (b, n_boxes, 4, 2)\n corners = xywhr2xyxyxyxy(gt_bboxes)\n # (b, n_boxes, 1, 2)\n a, b, _, d = corners.split(1, dim=-2)\n ab = b - a\n ad = d - a\n\n # (b, n_boxes, h*w, 2)\n ap = xy_centers - a\n norm_ab = (ab * ab).sum(dim=-1)\n norm_ad = (ad * ad).sum(dim=-1)\n ap_dot_ab = (ap * ab).sum(dim=-1)\n ap_dot_ad = (ap * ad).sum(dim=-1)\n return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad) # is_in_box",
"chunk_type": "class",
"name": "RotatedTaskAlignedAssigner",
"file_path": "ultralytics\\ultralytics\\utils\\tal.py",
"start_line": 332,
"end_line": 364,
"start_col": 0,
"end_col": 100,
"parent_name": null,
"docstring": "Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"torch",
"torch.nn",
"LOGGER",
"checks.check_version",
"metrics.bbox_iou",
"metrics.probiou",
"ops.xywhr2xyxyxyxy",
"TaskAlignedAssigner"
],
"chunk_id": "class_RotatedTaskAlignedAssigner_483ffc95"
},
{
"content": "def make_anchors(feats, strides, grid_cell_offset=0.5):\n \"\"\"Generate anchors from features.\"\"\"\n anchor_points, stride_tensor = [], []\n assert feats is not None\n dtype, device = feats[0].dtype, feats[0].device\n for i, stride in enumerate(strides):\n h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))\n sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x\n sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y\n sy, sx = torch.meshgrid(sy, sx, indexing=\"ij\") if TORCH_1_10 else torch.meshgrid(sy, sx)\n anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))\n stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))\n return torch.cat(anchor_points), torch.cat(stride_tensor)",
"chunk_type": "function",
"name": "make_anchors",
"file_path": "ultralytics\\ultralytics\\utils\\tal.py",
"start_line": 367,
"end_line": 379,
"start_col": 0,
"end_col": 61,
"parent_name": null,
"docstring": "Generate anchors from features.",
"parameters": [
"feats",
"strides",
"grid_cell_offset"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"torch",
"torch.nn",
"LOGGER",
"checks.check_version",
"metrics.bbox_iou",
"metrics.probiou",
"ops.xywhr2xyxyxyxy"
],
"chunk_id": "function_make_anchors_d7e2eb74"
},
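A quick sketch showing how anchor points are generated from three FPN feature maps (shapes are illustrative for a 640x640 input with strides 8, 16, and 32):

import torch

from ultralytics.utils.tal import make_anchors

feats = [torch.zeros(1, 64, 80, 80), torch.zeros(1, 64, 40, 40), torch.zeros(1, 64, 20, 20)]
anchor_points, stride_tensor = make_anchors(feats, strides=[8, 16, 32], grid_cell_offset=0.5)
print(anchor_points.shape, stride_tensor.shape)   # torch.Size([8400, 2]) torch.Size([8400, 1])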
{
"content": "def dist2bbox(distance, anchor_points, xywh=True, dim=-1):\n \"\"\"Transform distance(ltrb) to box(xywh or xyxy).\"\"\"\n lt, rb = distance.chunk(2, dim)\n x1y1 = anchor_points - lt\n x2y2 = anchor_points + rb\n if xywh:\n c_xy = (x1y1 + x2y2) / 2\n wh = x2y2 - x1y1\n return torch.cat((c_xy, wh), dim) # xywh bbox\n return torch.cat((x1y1, x2y2), dim) # xyxy bbox",
"chunk_type": "function",
"name": "dist2bbox",
"file_path": "ultralytics\\ultralytics\\utils\\tal.py",
"start_line": 382,
"end_line": 391,
"start_col": 0,
"end_col": 39,
"parent_name": null,
"docstring": "Transform distance(ltrb) to box(xywh or xyxy).",
"parameters": [
"distance",
"anchor_points",
"xywh",
"dim"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"torch",
"torch.nn",
"LOGGER",
"checks.check_version",
"metrics.bbox_iou",
"metrics.probiou",
"ops.xywhr2xyxyxyxy"
],
"chunk_id": "function_dist2bbox_efa5c776"
},
{
"content": "def bbox2dist(anchor_points, bbox, reg_max):\n \"\"\"Transform bbox(xyxy) to dist(ltrb).\"\"\"\n x1y1, x2y2 = bbox.chunk(2, -1)\n return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01) # dist (lt, rb)",
"chunk_type": "function",
"name": "bbox2dist",
"file_path": "ultralytics\\ultralytics\\utils\\tal.py",
"start_line": 394,
"end_line": 397,
"start_col": 0,
"end_col": 96,
"parent_name": null,
"docstring": "Transform bbox(xyxy) to dist(ltrb).",
"parameters": [
"anchor_points",
"bbox",
"reg_max"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"torch",
"torch.nn",
"LOGGER",
"checks.check_version",
"metrics.bbox_iou",
"metrics.probiou",
"ops.xywhr2xyxyxyxy"
],
"chunk_id": "function_bbox2dist_7eed802e"
},
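A small round-trip sketch for dist2bbox and bbox2dist (values are illustrative): the distances are measured from the anchor point to the left, top, right, and bottom box edges:

import torch

from ultralytics.utils.tal import bbox2dist, dist2bbox

anchor_points = torch.tensor([[10.0, 10.0]])
dist = torch.tensor([[2.0, 3.0, 4.0, 5.0]])            # (left, top, right, bottom)

xyxy = dist2bbox(dist, anchor_points, xywh=False)      # tensor([[ 8.,  7., 14., 15.]])
ltrb = bbox2dist(anchor_points, xyxy, reg_max=16)      # tensor([[2., 3., 4., 5.]]), clamped to [0, reg_max - 0.01]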
{
"content": "def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):\n \"\"\"\n Decode predicted rotated bounding box coordinates from anchor points and distribution.\n\n Args:\n pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).\n pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).\n anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).\n dim (int, optional): Dimension along which to split.\n\n Returns:\n (torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).\n \"\"\"\n lt, rb = pred_dist.split(2, dim=dim)\n cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)\n # (bs, h*w, 1)\n xf, yf = ((rb - lt) / 2).split(1, dim=dim)\n x, y = xf * cos - yf * sin, xf * sin + yf * cos\n xy = torch.cat([x, y], dim=dim) + anchor_points\n return torch.cat([xy, lt + rb], dim=dim)",
"chunk_type": "function",
"name": "dist2rbox",
"file_path": "ultralytics\\ultralytics\\utils\\tal.py",
"start_line": 400,
"end_line": 419,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": "Decode predicted rotated bounding box coordinates from anchor points and distribution.\n\nArgs:\n pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).\n pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).\n anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).\n dim (int, optional): Dimension along which to split.\n\nReturns:\n (torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).",
"parameters": [
"pred_dist",
"pred_angle",
"anchor_points",
"dim"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"torch",
"torch.nn",
"LOGGER",
"checks.check_version",
"metrics.bbox_iou",
"metrics.probiou",
"ops.xywhr2xyxyxyxy"
],
"chunk_id": "function_dist2rbox_40bbab6d"
},
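A shape-level sketch for dist2rbox on random predictions (sizes are illustrative); the predicted angle rotates the offset from the anchor before the box center is recovered:

import torch

from ultralytics.utils.tal import dist2rbox

pred_dist = torch.rand(1, 8400, 4)          # predicted (l, t, r, b) distances
pred_angle = torch.rand(1, 8400, 1)         # predicted rotation angle in radians
anchor_points = torch.rand(8400, 2) * 640   # anchor center coordinates

rboxes = dist2rbox(pred_dist, pred_angle, anchor_points)
print(rboxes.shape)                         # torch.Size([1, 8400, 4]) -> (cx, cy, w, h); the angle is handled separately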
{
"content": "import functools",
"chunk_type": "import",
"name": "functools",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_functools_1761ad8e"
},
{
"content": "import gc",
"chunk_type": "import",
"name": "gc",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_gc_79b36253"
},
{
"content": "import math",
"chunk_type": "import",
"name": "math",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_math_94b5469f"
},
{
"content": "import os",
"chunk_type": "import",
"name": "os",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_os_196299bc"
},
{
"content": "import random",
"chunk_type": "import",
"name": "random",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 13,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_random_11b50b2e"
},
{
"content": "import time",
"chunk_type": "import",
"name": "time",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_time_ba590618"
},
{
"content": "from contextlib import contextmanager",
"chunk_type": "import",
"name": "contextmanager",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 37,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_contextmanager_db6adbfc"
},
{
"content": "from copy import deepcopy",
"chunk_type": "import",
"name": "deepcopy",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_deepcopy_87deb03a"
},
{
"content": "from datetime import datetime",
"chunk_type": "import",
"name": "datetime",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_datetime_7e96701c"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_229a89f2"
},
{
"content": "from typing import Any, Dict, Union",
"chunk_type": "import",
"name": "Any, Dict, Union",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, Union_eec6484e"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_38987177"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 16,
"end_line": 16,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_b805e815"
},
{
"content": "import torch.distributed as dist",
"chunk_type": "import",
"name": "torch.distributed",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 17,
"end_line": 17,
"start_col": 0,
"end_col": 32,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.distributed_ce6f6f07"
},
{
"content": "import torch.nn as nn",
"chunk_type": "import",
"name": "torch.nn",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 18,
"end_line": 18,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn_25306e3a"
},
{
"content": "import torch.nn.functional as F",
"chunk_type": "import",
"name": "torch.nn.functional",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 19,
"end_line": 19,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn.functional_4d818dae"
},
{
"content": "from ultralytics import __version__",
"chunk_type": "import",
"name": "__version__",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 21,
"end_line": 21,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import___version___aca33539"
},
{
"content": "from ultralytics.utils import (\n DEFAULT_CFG_DICT,\n DEFAULT_CFG_KEYS,\n LOGGER,\n NUM_THREADS,\n PYTHON_VERSION,\n TORCHVISION_VERSION,\n WINDOWS,\n colorstr,\n)",
"chunk_type": "import",
"name": "DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, NUM_THREADS, PYTHON_VERSION, TORCHVISION_VERSION, WINDOWS, colorstr",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 22,
"end_line": 31,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, NUM_THREADS, PYTHON_VERSION, TORCHVISION_VERSION, WINDOWS, colorstr_60e35477"
},
{
"content": "from ultralytics.utils.checks import check_version",
"chunk_type": "import",
"name": "check_version",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 32,
"end_line": 32,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_check_version_3bb078e0"
},
{
"content": "from ultralytics.utils.patches import torch_load",
"chunk_type": "import",
"name": "torch_load",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 33,
"end_line": 33,
"start_col": 0,
"end_col": 48,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_load_bff04f79"
},
{
"content": "TORCH_1_9 = check_version(torch.__version__, \"1.9.0\")",
"chunk_type": "variable",
"name": "TORCH_1_9",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 36,
"end_line": 36,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_TORCH_1_9_deaf6ac0"
},
{
"content": "TORCH_1_13 = check_version(torch.__version__, \"1.13.0\")",
"chunk_type": "variable",
"name": "TORCH_1_13",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 37,
"end_line": 37,
"start_col": 0,
"end_col": 55,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_TORCH_1_13_0b85fd27"
},
{
"content": "TORCH_2_0 = check_version(torch.__version__, \"2.0.0\")",
"chunk_type": "variable",
"name": "TORCH_2_0",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 38,
"end_line": 38,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_TORCH_2_0_d41a9c8d"
},
{
"content": "TORCH_2_4 = check_version(torch.__version__, \"2.4.0\")",
"chunk_type": "variable",
"name": "TORCH_2_4",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 39,
"end_line": 39,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_TORCH_2_4_259b5df7"
},
{
"content": "TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, \"0.10.0\")",
"chunk_type": "variable",
"name": "TORCHVISION_0_10",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 40,
"end_line": 40,
"start_col": 0,
"end_col": 63,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_TORCHVISION_0_10_63667314"
},
{
"content": "TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, \"0.11.0\")",
"chunk_type": "variable",
"name": "TORCHVISION_0_11",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 41,
"end_line": 41,
"start_col": 0,
"end_col": 63,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_TORCHVISION_0_11_c6dd5441"
},
{
"content": "TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, \"0.13.0\")",
"chunk_type": "variable",
"name": "TORCHVISION_0_13",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 42,
"end_line": 42,
"start_col": 0,
"end_col": 63,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_TORCHVISION_0_13_9aa804a6"
},
{
"content": "TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, \"0.18.0\")",
"chunk_type": "variable",
"name": "TORCHVISION_0_18",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 43,
"end_line": 43,
"start_col": 0,
"end_col": 63,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_TORCHVISION_0_18_240f011e"
},
{
"content": "def torch_distributed_zero_first(local_rank: int):\n \"\"\"Ensure all processes in distributed training wait for the local master (rank 0) to complete a task first.\"\"\"\n initialized = dist.is_available() and dist.is_initialized()\n use_ids = initialized and dist.get_backend() == \"nccl\"\n\n if initialized and local_rank not in {-1, 0}:\n dist.barrier(device_ids=[local_rank]) if use_ids else dist.barrier()\n yield\n if initialized and local_rank == 0:\n dist.barrier(device_ids=[local_rank]) if use_ids else dist.barrier()",
"chunk_type": "function",
"name": "torch_distributed_zero_first",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 52,
"end_line": 61,
"start_col": 0,
"end_col": 76,
"parent_name": null,
"docstring": "Ensure all processes in distributed training wait for the local master (rank 0) to complete a task first.",
"parameters": [
"local_rank: int"
],
"return_type": null,
"decorators": [
"contextmanager"
],
"complexity_score": 3,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_torch_distributed_zero_first_f35ec42b"
},
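Illustrative usage sketch (editorial addition, not part of the indexed source file): the `torch_distributed_zero_first` context manager recorded above is typically wrapped around one-time work such as dataset preparation so that only rank 0 does it while the other ranks wait. `LOCAL_RANK` here is a placeholder for whatever the DDP launcher provides; with `-1` (single process) the barriers are skipped.

```python
# Minimal sketch, assuming ultralytics is installed; LOCAL_RANK is a stand-in
# for the rank supplied by the launcher (-1 means single-process, no-op barriers).
from ultralytics.utils.torch_utils import torch_distributed_zero_first

LOCAL_RANK = -1
with torch_distributed_zero_first(LOCAL_RANK):
    pass  # e.g. download or cache a dataset exactly once here
```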
{
"content": "def smart_inference_mode():\n \"\"\"Apply torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator.\"\"\"\n\n def decorate(fn):\n \"\"\"Apply appropriate torch decorator for inference mode based on torch version.\"\"\"\n if TORCH_1_9 and torch.is_inference_mode_enabled():\n return fn # already in inference_mode, act as a pass-through\n else:\n return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)\n\n return decorate",
"chunk_type": "function",
"name": "smart_inference_mode",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 64,
"end_line": 74,
"start_col": 0,
"end_col": 19,
"parent_name": null,
"docstring": "Apply torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_smart_inference_mode_a69c8fae"
},
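A hedged example of applying the `smart_inference_mode` decorator factory above; the toy `predict` helper and the `nn.Linear` model are hypothetical and only demonstrate that outputs produced under the wrapper do not track gradients.

```python
# Sketch only: decorate an arbitrary callable so it runs under inference_mode
# (torch>=1.9) or no_grad (older torch). predict() is an illustrative helper.
import torch
from ultralytics.utils.torch_utils import smart_inference_mode

@smart_inference_mode()
def predict(model, x):
    return model(x)

out = predict(torch.nn.Linear(2, 2), torch.ones(1, 2))
print(out.requires_grad)  # False: gradients are not tracked inside the wrapper
```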
{
"content": "def autocast(enabled: bool, device: str = \"cuda\"):\n \"\"\"\n Get the appropriate autocast context manager based on PyTorch version and AMP setting.\n\n This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both\n older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.\n\n Args:\n enabled (bool): Whether to enable automatic mixed precision.\n device (str, optional): The device to use for autocast.\n\n Returns:\n (torch.amp.autocast): The appropriate autocast context manager.\n\n Notes:\n - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.\n - For older versions, it uses `torch.cuda.autocast`.\n\n Examples:\n >>> with autocast(enabled=True):\n ... # Your mixed precision operations here\n ... pass\n \"\"\"\n if TORCH_1_13:\n return torch.amp.autocast(device, enabled=enabled)\n else:\n return torch.cuda.amp.autocast(enabled)",
"chunk_type": "function",
"name": "autocast",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 77,
"end_line": 103,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": "Get the appropriate autocast context manager based on PyTorch version and AMP setting.\n\nThis function returns a context manager for automatic mixed precision (AMP) training that is compatible with both\nolder and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.\n\nArgs:\n enabled (bool): Whether to enable automatic mixed precision.\n device (str, optional): The device to use for autocast.\n\nReturns:\n (torch.amp.autocast): The appropriate autocast context manager.\n\nNotes:\n - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.\n - For older versions, it uses `torch.cuda.autocast`.\n\nExamples:\n >>> with autocast(enabled=True):\n ... # Your mixed precision operations here\n ... pass",
"parameters": [
"enabled: bool",
"device: str"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_autocast_34300a42"
},
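A small usage sketch for the version-aware `autocast` helper above (editorial, not from the indexed file); the tensor shape is arbitrary, and AMP is simply disabled when no CUDA device is present.

```python
# Minimal sketch: enable mixed precision only when a CUDA device is available.
import torch
from ultralytics.utils.torch_utils import autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, 8, device=device)
with autocast(enabled=device == "cuda"):
    y = (x @ x.T).softmax(dim=-1)  # autocast precision on CUDA, full precision on CPU
```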
{
"content": "def get_cpu_info():\n \"\"\"Return a string with system CPU information, i.e. 'Apple M2'.\"\"\"\n from ultralytics.utils import PERSISTENT_CACHE # avoid circular import error\n\n if \"cpu_info\" not in PERSISTENT_CACHE:\n try:\n import cpuinfo # pip install py-cpuinfo\n\n k = \"brand_raw\", \"hardware_raw\", \"arch_string_raw\" # keys sorted by preference\n info = cpuinfo.get_cpu_info() # info dict\n string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], \"unknown\")\n PERSISTENT_CACHE[\"cpu_info\"] = string.replace(\"(R)\", \"\").replace(\"CPU \", \"\").replace(\"@ \", \"\")\n except Exception:\n pass\n return PERSISTENT_CACHE.get(\"cpu_info\", \"unknown\")",
"chunk_type": "function",
"name": "get_cpu_info",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 107,
"end_line": 121,
"start_col": 0,
"end_col": 54,
"parent_name": null,
"docstring": "Return a string with system CPU information, i.e. 'Apple M2'.",
"parameters": [],
"return_type": null,
"decorators": [
"functools.lru_cache"
],
"complexity_score": 3,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_get_cpu_info_d2bd4308"
},
{
"content": "def get_gpu_info(index):\n \"\"\"Return a string with system GPU information, i.e. 'Tesla T4, 15102MiB'.\"\"\"\n properties = torch.cuda.get_device_properties(index)\n return f\"{properties.name}, {properties.total_memory / (1 << 20):.0f}MiB\"",
"chunk_type": "function",
"name": "get_gpu_info",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 125,
"end_line": 128,
"start_col": 0,
"end_col": 77,
"parent_name": null,
"docstring": "Return a string with system GPU information, i.e. 'Tesla T4, 15102MiB'.",
"parameters": [
"index"
],
"return_type": null,
"decorators": [
"functools.lru_cache"
],
"complexity_score": 1,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_get_gpu_info_3a395543"
},
{
"content": "def select_device(device=\"\", batch=0, newline=False, verbose=True):\n \"\"\"\n Select the appropriate PyTorch device based on the provided arguments.\n\n The function takes a string specifying the device or a torch.device object and returns a torch.device object\n representing the selected device. The function also validates the number of available devices and raises an\n exception if the requested device(s) are not available.\n\n Args:\n device (str | torch.device, optional): Device string or torch.device object. Options are 'None', 'cpu', or\n 'cuda', or '0' or '0,1,2,3'. Auto-selects the first available GPU, or CPU if no GPU is available.\n batch (int, optional): Batch size being used in your model.\n newline (bool, optional): If True, adds a newline at the end of the log string.\n verbose (bool, optional): If True, logs the device information.\n\n Returns:\n (torch.device): Selected device.\n\n Raises:\n ValueError: If the specified device is not available or if the batch size is not a multiple of the number of\n devices when using multiple GPUs.\n\n Examples:\n >>> select_device(\"cuda:0\")\n device(type='cuda', index=0)\n\n >>> select_device(\"cpu\")\n device(type='cpu')\n\n Notes:\n Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.\n \"\"\"\n if isinstance(device, torch.device) or str(device).startswith((\"tpu\", \"intel\")):\n return device\n\n s = f\"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} \"\n device = str(device).lower()\n for remove in \"cuda:\", \"none\", \"(\", \")\", \"[\", \"]\", \"'\", \" \":\n device = device.replace(remove, \"\") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'\n\n # Auto-select GPUs\n if \"-1\" in device:\n from ultralytics.utils.autodevice import GPUInfo\n\n # Replace each -1 with a selected GPU or remove it\n parts = device.split(\",\")\n selected = GPUInfo().select_idle_gpu(count=parts.count(\"-1\"), min_memory_fraction=0.2)\n for i in range(len(parts)):\n if parts[i] == \"-1\":\n parts[i] = str(selected.pop(0)) if selected else \"\"\n device = \",\".join(p for p in parts if p)\n\n cpu = device == \"cpu\"\n mps = device in {\"mps\", \"mps:0\"} # Apple Metal Performance Shaders (MPS)\n if cpu or mps:\n os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\" # force torch.cuda.is_available() = False\n elif device: # non-cpu device requested\n if device == \"cuda\":\n device = \"0\"\n if \",\" in device:\n device = \",\".join([x for x in device.split(\",\") if x]) # remove sequential commas, i.e. \"0,,1\" -> \"0,1\"\n visible = os.environ.get(\"CUDA_VISIBLE_DEVICES\", None)\n os.environ[\"CUDA_VISIBLE_DEVICES\"] = device # set environment variable - must be before assert is_available()\n if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(\",\"))):\n LOGGER.info(s)\n install = (\n \"See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no \"\n \"CUDA devices are seen by torch.\\n\"\n if torch.cuda.device_count() == 0\n else \"\"\n )\n raise ValueError(\n f\"Invalid CUDA 'device={device}' requested.\"\n f\" Use 'device=cpu' or pass valid CUDA device(s) if available,\"\n f\" i.e. 
'device=0' or 'device=0,1,2,3' for Multi-GPU.\\n\"\n f\"\\ntorch.cuda.is_available(): {torch.cuda.is_available()}\"\n f\"\\ntorch.cuda.device_count(): {torch.cuda.device_count()}\"\n f\"\\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\\n\"\n f\"{install}\"\n )\n\n if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available\n devices = device.split(\",\") if device else \"0\" # i.e. \"0,1\" -> [\"0\", \"1\"]\n n = len(devices) # device count\n if n > 1: # multi-GPU\n if batch < 1:\n raise ValueError(\n \"AutoBatch with batch<1 not supported for Multi-GPU training, \"\n f\"please specify a valid batch size multiple of GPU count {n}, i.e. batch={n * 8}.\"\n )\n if batch >= 0 and batch % n != 0: # check batch_size is divisible by device_count\n raise ValueError(\n f\"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or \"\n f\"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}.\"\n )\n space = \" \" * (len(s) + 1)\n for i, d in enumerate(devices):\n s += f\"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\\n\" # bytes to MB\n arg = \"cuda:0\"\n elif mps and TORCH_2_0 and torch.backends.mps.is_available():\n # Prefer MPS if available\n s += f\"MPS ({get_cpu_info()})\\n\"\n arg = \"mps\"\n else: # revert to CPU\n s += f\"CPU ({get_cpu_info()})\\n\"\n arg = \"cpu\"\n\n if arg in {\"cpu\", \"mps\"}:\n torch.set_num_threads(NUM_THREADS) # reset OMP_NUM_THREADS for cpu training\n if verbose:\n LOGGER.info(s if newline else s.rstrip())\n return torch.device(arg)",
"chunk_type": "function",
"name": "select_device",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 131,
"end_line": 242,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": "Select the appropriate PyTorch device based on the provided arguments.\n\nThe function takes a string specifying the device or a torch.device object and returns a torch.device object\nrepresenting the selected device. The function also validates the number of available devices and raises an\nexception if the requested device(s) are not available.\n\nArgs:\n device (str | torch.device, optional): Device string or torch.device object. Options are 'None', 'cpu', or\n 'cuda', or '0' or '0,1,2,3'. Auto-selects the first available GPU, or CPU if no GPU is available.\n batch (int, optional): Batch size being used in your model.\n newline (bool, optional): If True, adds a newline at the end of the log string.\n verbose (bool, optional): If True, logs the device information.\n\nReturns:\n (torch.device): Selected device.\n\nRaises:\n ValueError: If the specified device is not available or if the batch size is not a multiple of the number of\n devices when using multiple GPUs.\n\nExamples:\n >>> select_device(\"cuda:0\")\n device(type='cuda', index=0)\n\n >>> select_device(\"cpu\")\n device(type='cpu')\n\nNotes:\n Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.",
"parameters": [
"device",
"batch",
"newline",
"verbose"
],
"return_type": null,
"decorators": [],
"complexity_score": 21,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_select_device_6a27eb04"
},
{
"content": "def time_sync():\n \"\"\"Return PyTorch-accurate time.\"\"\"\n if torch.cuda.is_available():\n torch.cuda.synchronize()\n return time.time()",
"chunk_type": "function",
"name": "time_sync",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 245,
"end_line": 249,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": "Return PyTorch-accurate time.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_time_sync_6134daef"
},
{
"content": "def fuse_conv_and_bn(conv, bn):\n \"\"\"Fuse Conv2d() and BatchNorm2d() layers.\"\"\"\n fusedconv = (\n nn.Conv2d(\n conv.in_channels,\n conv.out_channels,\n kernel_size=conv.kernel_size,\n stride=conv.stride,\n padding=conv.padding,\n dilation=conv.dilation,\n groups=conv.groups,\n bias=True,\n )\n .requires_grad_(False)\n .to(conv.weight.device)\n )\n\n # Prepare filters\n w_conv = conv.weight.view(conv.out_channels, -1)\n w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))\n fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))\n\n # Prepare spatial bias\n b_conv = (\n torch.zeros(conv.weight.shape[0], dtype=conv.weight.dtype, device=conv.weight.device)\n if conv.bias is None\n else conv.bias\n )\n b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))\n fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)\n\n return fusedconv",
"chunk_type": "function",
"name": "fuse_conv_and_bn",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 252,
"end_line": 283,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Fuse Conv2d() and BatchNorm2d() layers.",
"parameters": [
"conv",
"bn"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_fuse_conv_and_bn_1390fc68"
},
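Hedged sketch (editorial addition) of fusing a Conv2d/BatchNorm2d pair with the `fuse_conv_and_bn` helper above and comparing it against the unfused stack in eval mode; the layer sizes and input shape are arbitrary placeholders.

```python
# Sketch: a fused layer should match conv -> bn in eval mode to within float error.
import torch
import torch.nn as nn
from ultralytics.utils.torch_utils import fuse_conv_and_bn

conv = nn.Conv2d(8, 16, 3, padding=1, bias=False).eval()
bn = nn.BatchNorm2d(16).eval()
with torch.no_grad():  # only weight copies and a forward check happen here
    fused = fuse_conv_and_bn(conv, bn)
    x = torch.randn(1, 8, 32, 32)
    print((bn(conv(x)) - fused(x)).abs().max())  # expected to be tiny, on the order of 1e-6
```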
{
"content": "def fuse_deconv_and_bn(deconv, bn):\n \"\"\"Fuse ConvTranspose2d() and BatchNorm2d() layers.\"\"\"\n fuseddconv = (\n nn.ConvTranspose2d(\n deconv.in_channels,\n deconv.out_channels,\n kernel_size=deconv.kernel_size,\n stride=deconv.stride,\n padding=deconv.padding,\n output_padding=deconv.output_padding,\n dilation=deconv.dilation,\n groups=deconv.groups,\n bias=True,\n )\n .requires_grad_(False)\n .to(deconv.weight.device)\n )\n\n # Prepare filters\n w_deconv = deconv.weight.view(deconv.out_channels, -1)\n w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))\n fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))\n\n # Prepare spatial bias\n b_conv = torch.zeros(deconv.weight.shape[1], device=deconv.weight.device) if deconv.bias is None else deconv.bias\n b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))\n fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)\n\n return fuseddconv",
"chunk_type": "function",
"name": "fuse_deconv_and_bn",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 286,
"end_line": 314,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": "Fuse ConvTranspose2d() and BatchNorm2d() layers.",
"parameters": [
"deconv",
"bn"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_fuse_deconv_and_bn_2d6caedb"
},
{
"content": "def model_info(model, detailed=False, verbose=True, imgsz=640):\n \"\"\"\n Print and return detailed model information layer by layer.\n\n Args:\n model (nn.Module): Model to analyze.\n detailed (bool, optional): Whether to print detailed layer information.\n verbose (bool, optional): Whether to print model information.\n imgsz (int | list, optional): Input image size.\n\n Returns:\n n_l (int): Number of layers.\n n_p (int): Number of parameters.\n n_g (int): Number of gradients.\n flops (float): GFLOPs.\n \"\"\"\n if not verbose:\n return\n n_p = get_num_params(model) # number of parameters\n n_g = get_num_gradients(model) # number of gradients\n layers = __import__(\"collections\").OrderedDict((n, m) for n, m in model.named_modules() if len(m._modules) == 0)\n n_l = len(layers) # number of layers\n if detailed:\n h = f\"{'layer':>5}{'name':>40}{'type':>20}{'gradient':>10}{'parameters':>12}{'shape':>20}{'mu':>10}{'sigma':>10}\"\n LOGGER.info(h)\n for i, (mn, m) in enumerate(layers.items()):\n mn = mn.replace(\"module_list.\", \"\")\n mt = m.__class__.__name__\n if len(m._parameters):\n for pn, p in m.named_parameters():\n LOGGER.info(\n f\"{i:>5g}{f'{mn}.{pn}':>40}{mt:>20}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20}{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype).replace('torch.', ''):>15}\"\n )\n else: # layers with no learnable params\n LOGGER.info(f\"{i:>5g}{mn:>40}{mt:>20}{False!r:>10}{0:>12g}{str([]):>20}{'-':>10}{'-':>10}{'-':>15}\")\n\n flops = get_flops(model, imgsz) # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]\n fused = \" (fused)\" if getattr(model, \"is_fused\", lambda: False)() else \"\"\n fs = f\", {flops:.1f} GFLOPs\" if flops else \"\"\n yaml_file = getattr(model, \"yaml_file\", \"\") or getattr(model, \"yaml\", {}).get(\"yaml_file\", \"\")\n model_name = Path(yaml_file).stem.replace(\"yolo\", \"YOLO\") or \"Model\"\n LOGGER.info(f\"{model_name} summary{fused}: {n_l:,} layers, {n_p:,} parameters, {n_g:,} gradients{fs}\")\n return n_l, n_p, n_g, flops",
"chunk_type": "function",
"name": "model_info",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 317,
"end_line": 359,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": "Print and return detailed model information layer by layer.\n\nArgs:\n model (nn.Module): Model to analyze.\n detailed (bool, optional): Whether to print detailed layer information.\n verbose (bool, optional): Whether to print model information.\n imgsz (int | list, optional): Input image size.\n\nReturns:\n n_l (int): Number of layers.\n n_p (int): Number of parameters.\n n_g (int): Number of gradients.\n flops (float): GFLOPs.",
"parameters": [
"model",
"detailed",
"verbose",
"imgsz"
],
"return_type": null,
"decorators": [],
"complexity_score": 7,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_model_info_88b9a323"
},
{
"content": "def get_num_params(model):\n \"\"\"Return the total number of parameters in a YOLO model.\"\"\"\n return sum(x.numel() for x in model.parameters())",
"chunk_type": "function",
"name": "get_num_params",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 362,
"end_line": 364,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": "Return the total number of parameters in a YOLO model.",
"parameters": [
"model"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_get_num_params_69f2ae44"
},
{
"content": "def get_num_gradients(model):\n \"\"\"Return the total number of parameters with gradients in a YOLO model.\"\"\"\n return sum(x.numel() for x in model.parameters() if x.requires_grad)",
"chunk_type": "function",
"name": "get_num_gradients",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 367,
"end_line": 369,
"start_col": 0,
"end_col": 72,
"parent_name": null,
"docstring": "Return the total number of parameters with gradients in a YOLO model.",
"parameters": [
"model"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_get_num_gradients_b2843e9c"
},
{
"content": "def model_info_for_loggers(trainer):\n \"\"\"\n Return model info dict with useful model information.\n\n Args:\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing model and validation data.\n\n Returns:\n (dict): Dictionary containing model parameters, GFLOPs, and inference speeds.\n\n Examples:\n YOLOv8n info for loggers\n >>> results = {\n ... \"model/parameters\": 3151904,\n ... \"model/GFLOPs\": 8.746,\n ... \"model/speed_ONNX(ms)\": 41.244,\n ... \"model/speed_TensorRT(ms)\": 3.211,\n ... \"model/speed_PyTorch(ms)\": 18.755,\n ...}\n \"\"\"\n if trainer.args.profile: # profile ONNX and TensorRT times\n from ultralytics.utils.benchmarks import ProfileModels\n\n results = ProfileModels([trainer.last], device=trainer.device).run()[0]\n results.pop(\"model/name\")\n else: # only return PyTorch times from most recent validation\n results = {\n \"model/parameters\": get_num_params(trainer.model),\n \"model/GFLOPs\": round(get_flops(trainer.model), 3),\n }\n results[\"model/speed_PyTorch(ms)\"] = round(trainer.validator.speed[\"inference\"], 3)\n return results",
"chunk_type": "function",
"name": "model_info_for_loggers",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 372,
"end_line": 403,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": "Return model info dict with useful model information.\n\nArgs:\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing model and validation data.\n\nReturns:\n (dict): Dictionary containing model parameters, GFLOPs, and inference speeds.\n\nExamples:\n YOLOv8n info for loggers\n >>> results = {\n ... \"model/parameters\": 3151904,\n ... \"model/GFLOPs\": 8.746,\n ... \"model/speed_ONNX(ms)\": 41.244,\n ... \"model/speed_TensorRT(ms)\": 3.211,\n ... \"model/speed_PyTorch(ms)\": 18.755,\n ...}",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_model_info_for_loggers_613e00b0"
},
{
"content": "def get_flops(model, imgsz=640):\n \"\"\"\n Calculate FLOPs (floating point operations) for a model in billions.\n\n Attempts two calculation methods: first with a stride-based tensor for efficiency,\n then falls back to full image size if needed (e.g., for RTDETR models). Returns 0.0\n if thop library is unavailable or calculation fails.\n\n Args:\n model (nn.Module): The model to calculate FLOPs for.\n imgsz (int | list, optional): Input image size.\n\n Returns:\n (float): The model FLOPs in billions.\n \"\"\"\n try:\n import thop\n except ImportError:\n thop = None # conda support without 'ultralytics-thop' installed\n\n if not thop:\n return 0.0 # if not installed return 0.0 GFLOPs\n\n try:\n model = de_parallel(model)\n p = next(model.parameters())\n if not isinstance(imgsz, list):\n imgsz = [imgsz, imgsz] # expand if int/float\n try:\n # Method 1: Use stride-based input tensor\n stride = max(int(model.stride.max()), 32) if hasattr(model, \"stride\") else 32 # max stride\n im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format\n flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # stride GFLOPs\n return flops * imgsz[0] / stride * imgsz[1] / stride # imgsz GFLOPs\n except Exception:\n # Method 2: Use actual image size (required for RTDETR models)\n im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format\n return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # imgsz GFLOPs\n except Exception:\n return 0.0",
"chunk_type": "function",
"name": "get_flops",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 406,
"end_line": 445,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": "Calculate FLOPs (floating point operations) for a model in billions.\n\nAttempts two calculation methods: first with a stride-based tensor for efficiency,\nthen falls back to full image size if needed (e.g., for RTDETR models). Returns 0.0\nif thop library is unavailable or calculation fails.\n\nArgs:\n model (nn.Module): The model to calculate FLOPs for.\n imgsz (int | list, optional): Input image size.\n\nReturns:\n (float): The model FLOPs in billions.",
"parameters": [
"model",
"imgsz"
],
"return_type": null,
"decorators": [],
"complexity_score": 6,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_get_flops_17bd5430"
},
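A small sketch (editorial, not from the indexed file) combining `get_flops` above with the parameter counter defined later in the same module; the toy model is arbitrary, and `get_flops` quietly returns 0.0 when the optional `ultralytics-thop` package is missing.

```python
# Sketch: parameter count and GFLOPs for a toy model (0.0 GFLOPs if thop is absent).
import torch.nn as nn
from ultralytics.utils.torch_utils import get_flops, get_num_params

toy = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.SiLU(), nn.Conv2d(16, 3, 3, padding=1))
print(get_num_params(toy), get_flops(toy, imgsz=640))
```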
{
"content": "def get_flops_with_torch_profiler(model, imgsz=640):\n \"\"\"\n Compute model FLOPs using torch profiler (alternative to thop package, but 2-10x slower).\n\n Args:\n model (nn.Module): The model to calculate FLOPs for.\n imgsz (int | list, optional): Input image size.\n\n Returns:\n (float): The model's FLOPs in billions.\n \"\"\"\n if not TORCH_2_0: # torch profiler implemented in torch>=2.0\n return 0.0\n model = de_parallel(model)\n p = next(model.parameters())\n if not isinstance(imgsz, list):\n imgsz = [imgsz, imgsz] # expand if int/float\n try:\n # Use stride size for input tensor\n stride = (max(int(model.stride.max()), 32) if hasattr(model, \"stride\") else 32) * 2 # max stride\n im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format\n with torch.profiler.profile(with_flops=True) as prof:\n model(im)\n flops = sum(x.flops for x in prof.key_averages()) / 1e9\n flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs\n except Exception:\n # Use actual image size for input tensor (i.e. required for RTDETR models)\n im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format\n with torch.profiler.profile(with_flops=True) as prof:\n model(im)\n flops = sum(x.flops for x in prof.key_averages()) / 1e9\n return flops",
"chunk_type": "function",
"name": "get_flops_with_torch_profiler",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 448,
"end_line": 479,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "Compute model FLOPs using torch profiler (alternative to thop package, but 2-10x slower).\n\nArgs:\n model (nn.Module): The model to calculate FLOPs for.\n imgsz (int | list, optional): Input image size.\n\nReturns:\n (float): The model's FLOPs in billions.",
"parameters": [
"model",
"imgsz"
],
"return_type": null,
"decorators": [],
"complexity_score": 6,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_get_flops_with_torch_profiler_1c92f7a3"
},
{
"content": "def initialize_weights(model):\n \"\"\"Initialize model weights to random values.\"\"\"\n for m in model.modules():\n t = type(m)\n if t is nn.Conv2d:\n pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n elif t is nn.BatchNorm2d:\n m.eps = 1e-3\n m.momentum = 0.03\n elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}:\n m.inplace = True",
"chunk_type": "function",
"name": "initialize_weights",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 482,
"end_line": 492,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": "Initialize model weights to random values.",
"parameters": [
"model"
],
"return_type": null,
"decorators": [],
"complexity_score": 5,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_initialize_weights_c22e7ae4"
},
{
"content": "def scale_img(img, ratio=1.0, same_shape=False, gs=32):\n \"\"\"\n Scale and pad an image tensor, optionally maintaining aspect ratio and padding to gs multiple.\n\n Args:\n img (torch.Tensor): Input image tensor.\n ratio (float, optional): Scaling ratio.\n same_shape (bool, optional): Whether to maintain the same shape.\n gs (int, optional): Grid size for padding.\n\n Returns:\n (torch.Tensor): Scaled and padded image tensor.\n \"\"\"\n if ratio == 1.0:\n return img\n h, w = img.shape[2:]\n s = (int(h * ratio), int(w * ratio)) # new size\n img = F.interpolate(img, size=s, mode=\"bilinear\", align_corners=False) # resize\n if not same_shape: # pad/crop img\n h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))\n return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean",
"chunk_type": "function",
"name": "scale_img",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 495,
"end_line": 515,
"start_col": 0,
"end_col": 62,
"parent_name": null,
"docstring": "Scale and pad an image tensor, optionally maintaining aspect ratio and padding to gs multiple.\n\nArgs:\n img (torch.Tensor): Input image tensor.\n ratio (float, optional): Scaling ratio.\n same_shape (bool, optional): Whether to maintain the same shape.\n gs (int, optional): Grid size for padding.\n\nReturns:\n (torch.Tensor): Scaled and padded image tensor.",
"parameters": [
"img",
"ratio",
"same_shape",
"gs"
],
"return_type": null,
"decorators": [],
"complexity_score": 4,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_scale_img_f870b9db"
},
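Worked sketch for `scale_img` above (editorial addition): a 640x480 batch scaled by 0.6 comes back resized and padded so both spatial dimensions remain multiples of `gs=32`; the numbers are illustrative.

```python
# Sketch: bilinear resize by 0.6, then pad so H and W stay multiples of gs=32.
import torch
from ultralytics.utils.torch_utils import scale_img

im = torch.zeros(1, 3, 640, 480)
print(scale_img(im, ratio=0.6, gs=32).shape)  # torch.Size([1, 3, 384, 288])
```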
{
"content": "def copy_attr(a, b, include=(), exclude=()):\n \"\"\"\n Copy attributes from object 'b' to object 'a', with options to include/exclude certain attributes.\n\n Args:\n a (Any): Destination object to copy attributes to.\n b (Any): Source object to copy attributes from.\n include (tuple, optional): Attributes to include. If empty, all attributes are included.\n exclude (tuple, optional): Attributes to exclude.\n \"\"\"\n for k, v in b.__dict__.items():\n if (len(include) and k not in include) or k.startswith(\"_\") or k in exclude:\n continue\n else:\n setattr(a, k, v)",
"chunk_type": "function",
"name": "copy_attr",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 518,
"end_line": 532,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": "Copy attributes from object 'b' to object 'a', with options to include/exclude certain attributes.\n\nArgs:\n a (Any): Destination object to copy attributes to.\n b (Any): Source object to copy attributes from.\n include (tuple, optional): Attributes to include. If empty, all attributes are included.\n exclude (tuple, optional): Attributes to exclude.",
"parameters": [
"a",
"b",
"include",
"exclude"
],
"return_type": null,
"decorators": [],
"complexity_score": 3,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_copy_attr_7422bd0c"
},
{
"content": "def get_latest_opset():\n \"\"\"\n Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity.\n\n Returns:\n (int): The ONNX opset version.\n \"\"\"\n if TORCH_1_13:\n # If the PyTorch>=1.13, dynamically compute the latest opset minus one using 'symbolic_opset'\n return max(int(k[14:]) for k in vars(torch.onnx) if \"symbolic_opset\" in k) - 1\n # Otherwise for PyTorch<=1.12 return the corresponding predefined opset\n version = torch.onnx.producer_version.rsplit(\".\", 1)[0] # i.e. '2.3'\n return {\"1.12\": 15, \"1.11\": 14, \"1.10\": 13, \"1.9\": 12, \"1.8\": 12}.get(version, 12)",
"chunk_type": "function",
"name": "get_latest_opset",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 535,
"end_line": 547,
"start_col": 0,
"end_col": 86,
"parent_name": null,
"docstring": "Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity.\n\nReturns:\n (int): The ONNX opset version.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 3,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_get_latest_opset_6f3df7f2"
},
{
"content": "def intersect_dicts(da, db, exclude=()):\n \"\"\"\n Return a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values.\n\n Args:\n da (dict): First dictionary.\n db (dict): Second dictionary.\n exclude (tuple, optional): Keys to exclude.\n\n Returns:\n (dict): Dictionary of intersecting keys with matching shapes.\n \"\"\"\n return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}",
"chunk_type": "function",
"name": "intersect_dicts",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 550,
"end_line": 562,
"start_col": 0,
"end_col": 115,
"parent_name": null,
"docstring": "Return a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values.\n\nArgs:\n da (dict): First dictionary.\n db (dict): Second dictionary.\n exclude (tuple, optional): Keys to exclude.\n\nReturns:\n (dict): Dictionary of intersecting keys with matching shapes.",
"parameters": [
"da",
"db",
"exclude"
],
"return_type": null,
"decorators": [],
"complexity_score": 3,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_intersect_dicts_de5ef473"
},
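Hedged example of the checkpoint-filtering pattern that `intersect_dicts` above supports; the two `nn.Linear` modules are placeholders standing in for a pretrained state dict and a target model.

```python
# Sketch: keep only keys whose shapes match and that avoid excluded substrings.
import torch.nn as nn
from ultralytics.utils.torch_utils import intersect_dicts

src, dst = nn.Linear(4, 2), nn.Linear(4, 2)
csd = intersect_dicts(src.state_dict(), dst.state_dict(), exclude=("bias",))
dst.load_state_dict(csd, strict=False)
print(list(csd))  # ['weight']
```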
{
"content": "def is_parallel(model):\n \"\"\"\n Return True if model is of type DP or DDP.\n\n Args:\n model (nn.Module): Model to check.\n\n Returns:\n (bool): True if model is DataParallel or DistributedDataParallel.\n \"\"\"\n return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))",
"chunk_type": "function",
"name": "is_parallel",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 565,
"end_line": 575,
"start_col": 0,
"end_col": 93,
"parent_name": null,
"docstring": "Return True if model is of type DP or DDP.\n\nArgs:\n model (nn.Module): Model to check.\n\nReturns:\n (bool): True if model is DataParallel or DistributedDataParallel.",
"parameters": [
"model"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_is_parallel_a6f44a0b"
},
{
"content": "def de_parallel(model):\n \"\"\"\n De-parallelize a model: return single-GPU model if model is of type DP or DDP.\n\n Args:\n model (nn.Module): Model to de-parallelize.\n\n Returns:\n (nn.Module): De-parallelized model.\n \"\"\"\n return model.module if is_parallel(model) else model",
"chunk_type": "function",
"name": "de_parallel",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 578,
"end_line": 588,
"start_col": 0,
"end_col": 56,
"parent_name": null,
"docstring": "De-parallelize a model: return single-GPU model if model is of type DP or DDP.\n\nArgs:\n model (nn.Module): Model to de-parallelize.\n\nReturns:\n (nn.Module): De-parallelized model.",
"parameters": [
"model"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_de_parallel_0b343ee4"
},
{
"content": "def one_cycle(y1=0.0, y2=1.0, steps=100):\n \"\"\"\n Return a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf.\n\n Args:\n y1 (float, optional): Initial value.\n y2 (float, optional): Final value.\n steps (int, optional): Number of steps.\n\n Returns:\n (function): Lambda function for computing the sinusoidal ramp.\n \"\"\"\n return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1",
"chunk_type": "function",
"name": "one_cycle",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 591,
"end_line": 603,
"start_col": 0,
"end_col": 85,
"parent_name": null,
"docstring": "Return a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf.\n\nArgs:\n y1 (float, optional): Initial value.\n y2 (float, optional): Final value.\n steps (int, optional): Number of steps.\n\nReturns:\n (function): Lambda function for computing the sinusoidal ramp.",
"parameters": [
"y1",
"y2",
"steps"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_one_cycle_69bc49bf"
},
{
"content": "def init_seeds(seed=0, deterministic=False):\n \"\"\"\n Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html.\n\n Args:\n seed (int, optional): Random seed.\n deterministic (bool, optional): Whether to set deterministic algorithms.\n \"\"\"\n random.seed(seed)\n np.random.seed(seed)\n torch.manual_seed(seed)\n torch.cuda.manual_seed(seed)\n torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe\n # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287\n if deterministic:\n if TORCH_2_0:\n torch.use_deterministic_algorithms(True, warn_only=True) # warn if deterministic is not possible\n torch.backends.cudnn.deterministic = True\n os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n os.environ[\"PYTHONHASHSEED\"] = str(seed)\n else:\n LOGGER.warning(\"Upgrade to torch>=2.0.0 for deterministic training.\")\n else:\n unset_deterministic()",
"chunk_type": "function",
"name": "init_seeds",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 606,
"end_line": 629,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": "Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html.\n\nArgs:\n seed (int, optional): Random seed.\n deterministic (bool, optional): Whether to set deterministic algorithms.",
"parameters": [
"seed",
"deterministic"
],
"return_type": null,
"decorators": [],
"complexity_score": 3,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_init_seeds_12265020"
},
{
"content": "def unset_deterministic():\n \"\"\"Unset all the configurations applied for deterministic training.\"\"\"\n torch.use_deterministic_algorithms(False)\n torch.backends.cudnn.deterministic = False\n os.environ.pop(\"CUBLAS_WORKSPACE_CONFIG\", None)\n os.environ.pop(\"PYTHONHASHSEED\", None)",
"chunk_type": "function",
"name": "unset_deterministic",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 632,
"end_line": 637,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": "Unset all the configurations applied for deterministic training.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_unset_deterministic_e2ce1747"
},
{
"content": "class ModelEMA:\n \"\"\"\n Updated Exponential Moving Average (EMA) implementation.\n\n Keeps a moving average of everything in the model state_dict (parameters and buffers).\n For EMA details see References.\n\n To disable EMA set the `enabled` attribute to `False`.\n\n Attributes:\n ema (nn.Module): Copy of the model in evaluation mode.\n updates (int): Number of EMA updates.\n decay (function): Decay function that determines the EMA weight.\n enabled (bool): Whether EMA is enabled.\n\n References:\n - https://github.com/rwightman/pytorch-image-models\n - https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage\n \"\"\"\n\n def __init__(self, model, decay=0.9999, tau=2000, updates=0):\n \"\"\"\n Initialize EMA for 'model' with given arguments.\n\n Args:\n model (nn.Module): Model to create EMA for.\n decay (float, optional): Maximum EMA decay rate.\n tau (int, optional): EMA decay time constant.\n updates (int, optional): Initial number of updates.\n \"\"\"\n self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA\n self.updates = updates # number of EMA updates\n self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)\n for p in self.ema.parameters():\n p.requires_grad_(False)\n self.enabled = True\n\n def update(self, model):\n \"\"\"\n Update EMA parameters.\n\n Args:\n model (nn.Module): Model to update EMA from.\n \"\"\"\n if self.enabled:\n self.updates += 1\n d = self.decay(self.updates)\n\n msd = de_parallel(model).state_dict() # model state_dict\n for k, v in self.ema.state_dict().items():\n if v.dtype.is_floating_point: # true for FP16 and FP32\n v *= d\n v += (1 - d) * msd[k].detach()\n # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'\n\n def update_attr(self, model, include=(), exclude=(\"process_group\", \"reducer\")):\n \"\"\"\n Update attributes and save stripped model with optimizer removed.\n\n Args:\n model (nn.Module): Model to update attributes from.\n include (tuple, optional): Attributes to include.\n exclude (tuple, optional): Attributes to exclude.\n \"\"\"\n if self.enabled:\n copy_attr(self.ema, model, include, exclude)",
"chunk_type": "class",
"name": "ModelEMA",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 640,
"end_line": 705,
"start_col": 0,
"end_col": 56,
"parent_name": null,
"docstring": "Updated Exponential Moving Average (EMA) implementation.\n\nKeeps a moving average of everything in the model state_dict (parameters and buffers).\nFor EMA details see References.\n\nTo disable EMA set the `enabled` attribute to `False`.\n\nAttributes:\n ema (nn.Module): Copy of the model in evaluation mode.\n updates (int): Number of EMA updates.\n decay (function): Decay function that determines the EMA weight.\n enabled (bool): Whether EMA is enabled.\n\nReferences:\n - https://github.com/rwightman/pytorch-image-models\n - https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "class_ModelEMA_df773054"
},
{
"content": "def strip_optimizer(f: Union[str, Path] = \"best.pt\", s: str = \"\", updates: Dict[str, Any] = None) -> Dict[str, Any]:\n \"\"\"\n Strip optimizer from 'f' to finalize training, optionally save as 's'.\n\n Args:\n f (str | Path): File path to model to strip the optimizer from.\n s (str, optional): File path to save the model with stripped optimizer to. If not provided, 'f' will be\n overwritten.\n updates (dict, optional): A dictionary of updates to overlay onto the checkpoint before saving.\n\n Returns:\n (dict): The combined checkpoint dictionary.\n\n Examples:\n >>> from pathlib import Path\n >>> from ultralytics.utils.torch_utils import strip_optimizer\n >>> for f in Path(\"path/to/model/checkpoints\").rglob(\"*.pt\"):\n >>> strip_optimizer(f)\n \"\"\"\n try:\n x = torch_load(f, map_location=torch.device(\"cpu\"))\n assert isinstance(x, dict), \"checkpoint is not a Python dictionary\"\n assert \"model\" in x, \"'model' missing from checkpoint\"\n except Exception as e:\n LOGGER.warning(f\"Skipping {f}, not a valid Ultralytics model: {e}\")\n return {}\n\n metadata = {\n \"date\": datetime.now().isoformat(),\n \"version\": __version__,\n \"license\": \"AGPL-3.0 License (https://ultralytics.com/license)\",\n \"docs\": \"https://docs.ultralytics.com\",\n }\n\n # Update model\n if x.get(\"ema\"):\n x[\"model\"] = x[\"ema\"] # replace model with EMA\n if hasattr(x[\"model\"], \"args\"):\n x[\"model\"].args = dict(x[\"model\"].args) # convert from IterableSimpleNamespace to dict\n if hasattr(x[\"model\"], \"criterion\"):\n x[\"model\"].criterion = None # strip loss criterion\n x[\"model\"].half() # to FP16\n for p in x[\"model\"].parameters():\n p.requires_grad = False\n\n # Update other keys\n args = {**DEFAULT_CFG_DICT, **x.get(\"train_args\", {})} # combine args\n for k in \"optimizer\", \"best_fitness\", \"ema\", \"updates\": # keys\n x[k] = None\n x[\"epoch\"] = -1\n x[\"train_args\"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys\n # x['model'].args = x['train_args']\n\n # Save\n combined = {**metadata, **x, **(updates or {})}\n torch.save(combined, s or f) # combine dicts (prefer to the right)\n mb = os.path.getsize(s or f) / 1e6 # file size\n LOGGER.info(f\"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB\")\n return combined",
"chunk_type": "function",
"name": "strip_optimizer",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 708,
"end_line": 766,
"start_col": 0,
"end_col": 19,
"parent_name": null,
"docstring": "Strip optimizer from 'f' to finalize training, optionally save as 's'.\n\nArgs:\n f (str | Path): File path to model to strip the optimizer from.\n s (str, optional): File path to save the model with stripped optimizer to. If not provided, 'f' will be\n overwritten.\n updates (dict, optional): A dictionary of updates to overlay onto the checkpoint before saving.\n\nReturns:\n (dict): The combined checkpoint dictionary.\n\nExamples:\n >>> from pathlib import Path\n >>> from ultralytics.utils.torch_utils import strip_optimizer\n >>> for f in Path(\"path/to/model/checkpoints\").rglob(\"*.pt\"):\n >>> strip_optimizer(f)",
"parameters": [
"f: Union[str, Path]",
"s: str",
"updates: Dict[str, Any]"
],
"return_type": "Dict[str, Any]",
"decorators": [],
"complexity_score": 8,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_strip_optimizer_8410572a"
},
{
"content": "def convert_optimizer_state_dict_to_fp16(state_dict):\n \"\"\"\n Convert the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.\n\n Args:\n state_dict (dict): Optimizer state dictionary.\n\n Returns:\n (dict): Converted optimizer state dictionary with FP16 tensors.\n \"\"\"\n for state in state_dict[\"state\"].values():\n for k, v in state.items():\n if k != \"step\" and isinstance(v, torch.Tensor) and v.dtype is torch.float32:\n state[k] = v.half()\n\n return state_dict",
"chunk_type": "function",
"name": "convert_optimizer_state_dict_to_fp16",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 769,
"end_line": 784,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": "Convert the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.\n\nArgs:\n state_dict (dict): Optimizer state dictionary.\n\nReturns:\n (dict): Converted optimizer state dictionary with FP16 tensors.",
"parameters": [
"state_dict"
],
"return_type": null,
"decorators": [],
"complexity_score": 4,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_convert_optimizer_state_dict_to_fp16_a80c83e7"
},
{
"content": "def cuda_memory_usage(device=None):\n \"\"\"\n Monitor and manage CUDA memory usage.\n\n This function checks if CUDA is available and, if so, empties the CUDA cache to free up unused memory.\n It then yields a dictionary containing memory usage information, which can be updated by the caller.\n Finally, it updates the dictionary with the amount of memory reserved by CUDA on the specified device.\n\n Args:\n device (torch.device, optional): The CUDA device to query memory usage for.\n\n Yields:\n (dict): A dictionary with a key 'memory' initialized to 0, which will be updated with the reserved memory.\n \"\"\"\n cuda_info = dict(memory=0)\n if torch.cuda.is_available():\n torch.cuda.empty_cache()\n try:\n yield cuda_info\n finally:\n cuda_info[\"memory\"] = torch.cuda.memory_reserved(device)\n else:\n yield cuda_info",
"chunk_type": "function",
"name": "cuda_memory_usage",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 788,
"end_line": 810,
"start_col": 0,
"end_col": 23,
"parent_name": null,
"docstring": "Monitor and manage CUDA memory usage.\n\nThis function checks if CUDA is available and, if so, empties the CUDA cache to free up unused memory.\nIt then yields a dictionary containing memory usage information, which can be updated by the caller.\nFinally, it updates the dictionary with the amount of memory reserved by CUDA on the specified device.\n\nArgs:\n device (torch.device, optional): The CUDA device to query memory usage for.\n\nYields:\n (dict): A dictionary with a key 'memory' initialized to 0, which will be updated with the reserved memory.",
"parameters": [
"device"
],
"return_type": null,
"decorators": [
"contextmanager"
],
"complexity_score": 2,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_cuda_memory_usage_8815d84d"
},
{
"content": "def profile_ops(input, ops, n=10, device=None, max_num_obj=0):\n \"\"\"\n Ultralytics speed, memory and FLOPs profiler.\n\n Args:\n input (torch.Tensor | list): Input tensor(s) to profile.\n ops (nn.Module | list): Model or list of operations to profile.\n n (int, optional): Number of iterations to average.\n device (str | torch.device, optional): Device to profile on.\n max_num_obj (int, optional): Maximum number of objects for simulation.\n\n Returns:\n (list): Profile results for each operation.\n\n Examples:\n >>> from ultralytics.utils.torch_utils import profile_ops\n >>> input = torch.randn(16, 3, 640, 640)\n >>> m1 = lambda x: x * torch.sigmoid(x)\n >>> m2 = nn.SiLU()\n >>> profile_ops(input, [m1, m2], n=100) # profile over 100 iterations\n \"\"\"\n try:\n import thop\n except ImportError:\n thop = None # conda support without 'ultralytics-thop' installed\n\n results = []\n if not isinstance(device, torch.device):\n device = select_device(device)\n LOGGER.info(\n f\"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}\"\n f\"{'input':>24s}{'output':>24s}\"\n )\n gc.collect() # attempt to free unused memory\n torch.cuda.empty_cache()\n for x in input if isinstance(input, list) else [input]:\n x = x.to(device)\n x.requires_grad = True\n for m in ops if isinstance(ops, list) else [ops]:\n m = m.to(device) if hasattr(m, \"to\") else m # device\n m = m.half() if hasattr(m, \"half\") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m\n tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward\n try:\n flops = thop.profile(deepcopy(m), inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs\n except Exception:\n flops = 0\n\n try:\n mem = 0\n for _ in range(n):\n with cuda_memory_usage(device) as cuda_info:\n t[0] = time_sync()\n y = m(x)\n t[1] = time_sync()\n try:\n (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()\n t[2] = time_sync()\n except Exception: # no backward method\n # print(e) # for debug\n t[2] = float(\"nan\")\n mem += cuda_info[\"memory\"] / 1e9 # (GB)\n tf += (t[1] - t[0]) * 1000 / n # ms per op forward\n tb += (t[2] - t[1]) * 1000 / n # ms per op backward\n if max_num_obj: # simulate training with predictions per image grid (for AutoBatch)\n with cuda_memory_usage(device) as cuda_info:\n torch.randn(\n x.shape[0],\n max_num_obj,\n int(sum((x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist())),\n device=device,\n dtype=torch.float32,\n )\n mem += cuda_info[\"memory\"] / 1e9 # (GB)\n s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else \"list\" for x in (x, y)) # shapes\n p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters\n LOGGER.info(f\"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}\")\n results.append([p, flops, mem, tf, tb, s_in, s_out])\n except Exception as e:\n LOGGER.info(e)\n results.append(None)\n finally:\n gc.collect() # attempt to free unused memory\n torch.cuda.empty_cache()\n return results",
"chunk_type": "function",
"name": "profile_ops",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 813,
"end_line": 896,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": "Ultralytics speed, memory and FLOPs profiler.\n\nArgs:\n input (torch.Tensor | list): Input tensor(s) to profile.\n ops (nn.Module | list): Model or list of operations to profile.\n n (int, optional): Number of iterations to average.\n device (str | torch.device, optional): Device to profile on.\n max_num_obj (int, optional): Maximum number of objects for simulation.\n\nReturns:\n (list): Profile results for each operation.\n\nExamples:\n >>> from ultralytics.utils.torch_utils import profile_ops\n >>> input = torch.randn(16, 3, 640, 640)\n >>> m1 = lambda x: x * torch.sigmoid(x)\n >>> m2 = nn.SiLU()\n >>> profile_ops(input, [m1, m2], n=100) # profile over 100 iterations",
"parameters": [
"input",
"ops",
"n",
"device",
"max_num_obj"
],
"return_type": null,
"decorators": [],
"complexity_score": 14,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "function_profile_ops_d9138367"
},
{
"content": "class EarlyStopping:\n \"\"\"\n Early stopping class that stops training when a specified number of epochs have passed without improvement.\n\n Attributes:\n best_fitness (float): Best fitness value observed.\n best_epoch (int): Epoch where best fitness was observed.\n patience (int): Number of epochs to wait after fitness stops improving before stopping.\n possible_stop (bool): Flag indicating if stopping may occur next epoch.\n \"\"\"\n\n def __init__(self, patience=50):\n \"\"\"\n Initialize early stopping object.\n\n Args:\n patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.\n \"\"\"\n self.best_fitness = 0.0 # i.e. mAP\n self.best_epoch = 0\n self.patience = patience or float(\"inf\") # epochs to wait after fitness stops improving to stop\n self.possible_stop = False # possible stop may occur next epoch\n\n def __call__(self, epoch, fitness):\n \"\"\"\n Check whether to stop training.\n\n Args:\n epoch (int): Current epoch of training\n fitness (float): Fitness value of current epoch\n\n Returns:\n (bool): True if training should stop, False otherwise\n \"\"\"\n if fitness is None: # check if fitness=None (happens when val=False)\n return False\n\n if fitness > self.best_fitness or self.best_fitness == 0: # allow for early zero-fitness stage of training\n self.best_epoch = epoch\n self.best_fitness = fitness\n delta = epoch - self.best_epoch # epochs without improvement\n self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch\n stop = delta >= self.patience # stop training if patience exceeded\n if stop:\n prefix = colorstr(\"EarlyStopping: \")\n LOGGER.info(\n f\"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. \"\n f\"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\\n\"\n f\"To update EarlyStopping(patience={self.patience}) pass a new patience value, \"\n f\"i.e. `patience=300` or use `patience=0` to disable EarlyStopping.\"\n )\n return stop",
"chunk_type": "class",
"name": "EarlyStopping",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 899,
"end_line": 950,
"start_col": 0,
"end_col": 19,
"parent_name": null,
"docstring": "Early stopping class that stops training when a specified number of epochs have passed without improvement.\n\nAttributes:\n best_fitness (float): Best fitness value observed.\n best_epoch (int): Epoch where best fitness was observed.\n patience (int): Number of epochs to wait after fitness stops improving before stopping.\n possible_stop (bool): Flag indicating if stopping may occur next epoch.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo"
],
"chunk_id": "class_EarlyStopping_c8b60078"
},
{
"content": "class FXModel(nn.Module):\n \"\"\"\n A custom model class for torch.fx compatibility.\n\n This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph\n manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper\n copying.\n\n Attributes:\n model (nn.Module): The original model's layers.\n \"\"\"\n\n def __init__(self, model):\n \"\"\"\n Initialize the FXModel.\n\n Args:\n model (nn.Module): The original model to wrap for torch.fx compatibility.\n \"\"\"\n super().__init__()\n copy_attr(self, model)\n # Explicitly set `model` since `copy_attr` somehow does not copy it.\n self.model = model.model\n\n def forward(self, x):\n \"\"\"\n Forward pass through the model.\n\n This method performs the forward pass through the model, handling the dependencies between layers and saving\n intermediate outputs.\n\n Args:\n x (torch.Tensor): The input tensor to the model.\n\n Returns:\n (torch.Tensor): The output tensor from the model.\n \"\"\"\n y = [] # outputs\n for m in self.model:\n if m.f != -1: # if not from previous layer\n # from earlier layers\n x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]\n x = m(x) # run\n y.append(x) # save output\n return x",
"chunk_type": "class",
"name": "FXModel",
"file_path": "ultralytics\\ultralytics\\utils\\torch_utils.py",
"start_line": 953,
"end_line": 997,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "A custom model class for torch.fx compatibility.\n\nThis class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph\nmanipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper\ncopying.\n\nAttributes:\n model (nn.Module): The original model's layers.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"functools",
"gc",
"math",
"os",
"random",
"time",
"contextlib.contextmanager",
"copy.deepcopy",
"datetime.datetime",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Union",
"numpy",
"torch",
"torch.distributed",
"torch.nn",
"torch.nn.functional",
"ultralytics.__version__",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.DEFAULT_CFG_KEYS",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.PYTHON_VERSION",
"ultralytics.utils.TORCHVISION_VERSION",
"ultralytics.utils.WINDOWS",
"ultralytics.utils.colorstr",
"ultralytics.utils.checks.check_version",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.PERSISTENT_CACHE",
"ultralytics.utils.autodevice.GPUInfo",
"ultralytics.utils.benchmarks.ProfileModels",
"thop",
"thop",
"cpuinfo",
"nn.Module"
],
"chunk_id": "class_FXModel_592a07c2"
},
{
"content": "from typing import List",
"chunk_type": "import",
"name": "List",
"file_path": "ultralytics\\ultralytics\\utils\\triton.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 23,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_List_8c3ed700"
},
{
"content": "from urllib.parse import urlsplit",
"chunk_type": "import",
"name": "urlsplit",
"file_path": "ultralytics\\ultralytics\\utils\\triton.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_urlsplit_621806e4"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\utils\\triton.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_727d2609"
},
{
"content": "class TritonRemoteModel:\n \"\"\"\n Client for interacting with a remote Triton Inference Server model.\n\n This class provides a convenient interface for sending inference requests to a Triton Inference Server\n and processing the responses. Supports both HTTP and gRPC communication protocols.\n\n Attributes:\n endpoint (str): The name of the model on the Triton server.\n url (str): The URL of the Triton server.\n triton_client: The Triton client (either HTTP or gRPC).\n InferInput: The input class for the Triton client.\n InferRequestedOutput: The output request class for the Triton client.\n input_formats (List[str]): The data types of the model inputs.\n np_input_formats (List[type]): The numpy data types of the model inputs.\n input_names (List[str]): The names of the model inputs.\n output_names (List[str]): The names of the model outputs.\n metadata: The metadata associated with the model.\n\n Methods:\n __call__: Call the model with the given inputs and return the outputs.\n\n Examples:\n Initialize a Triton client with HTTP\n >>> model = TritonRemoteModel(url=\"localhost:8000\", endpoint=\"yolov8\", scheme=\"http\")\n\n Make inference with numpy arrays\n >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))\n \"\"\"\n\n def __init__(self, url: str, endpoint: str = \"\", scheme: str = \"\"):\n \"\"\"\n Initialize the TritonRemoteModel for interacting with a remote Triton Inference Server.\n\n Arguments may be provided individually or parsed from a collective 'url' argument of the form\n :////\n\n Args:\n url (str): The URL of the Triton server.\n endpoint (str, optional): The name of the model on the Triton server.\n scheme (str, optional): The communication scheme ('http' or 'grpc').\n\n Examples:\n >>> model = TritonRemoteModel(url=\"localhost:8000\", endpoint=\"yolov8\", scheme=\"http\")\n >>> model = TritonRemoteModel(url=\"http://localhost:8000/yolov8\")\n \"\"\"\n if not endpoint and not scheme: # Parse all args from URL string\n splits = urlsplit(url)\n endpoint = splits.path.strip(\"/\").split(\"/\", 1)[0]\n scheme = splits.scheme\n url = splits.netloc\n\n self.endpoint = endpoint\n self.url = url\n\n # Choose the Triton client based on the communication scheme\n if scheme == \"http\":\n import tritonclient.http as client # noqa\n\n self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)\n config = self.triton_client.get_model_config(endpoint)\n else:\n import tritonclient.grpc as client # noqa\n\n self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)\n config = self.triton_client.get_model_config(endpoint, as_json=True)[\"config\"]\n\n # Sort output names alphabetically, i.e. 
'output0', 'output1', etc.\n config[\"output\"] = sorted(config[\"output\"], key=lambda x: x.get(\"name\"))\n\n # Define model attributes\n type_map = {\"TYPE_FP32\": np.float32, \"TYPE_FP16\": np.float16, \"TYPE_UINT8\": np.uint8}\n self.InferRequestedOutput = client.InferRequestedOutput\n self.InferInput = client.InferInput\n self.input_formats = [x[\"data_type\"] for x in config[\"input\"]]\n self.np_input_formats = [type_map[x] for x in self.input_formats]\n self.input_names = [x[\"name\"] for x in config[\"input\"]]\n self.output_names = [x[\"name\"] for x in config[\"output\"]]\n self.metadata = eval(config.get(\"parameters\", {}).get(\"metadata\", {}).get(\"string_value\", \"None\"))\n\n def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:\n \"\"\"\n Call the model with the given inputs and return inference results.\n\n Args:\n *inputs (np.ndarray): Input data to the model. Each array should match the expected shape and type\n for the corresponding model input.\n\n Returns:\n (List[np.ndarray]): Model outputs with the same dtype as the input. Each element in the list\n corresponds to one of the model's output tensors.\n\n Examples:\n >>> model = TritonRemoteModel(url=\"localhost:8000\", endpoint=\"yolov8\", scheme=\"http\")\n >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))\n \"\"\"\n infer_inputs = []\n input_format = inputs[0].dtype\n for i, x in enumerate(inputs):\n if x.dtype != self.np_input_formats[i]:\n x = x.astype(self.np_input_formats[i])\n infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace(\"TYPE_\", \"\"))\n infer_input.set_data_from_numpy(x)\n infer_inputs.append(infer_input)\n\n infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]\n outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)\n\n return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]",
"chunk_type": "class",
"name": "TritonRemoteModel",
"file_path": "ultralytics\\ultralytics\\utils\\triton.py",
"start_line": 9,
"end_line": 117,
"start_col": 0,
"end_col": 104,
"parent_name": null,
"docstring": "Client for interacting with a remote Triton Inference Server model.\n\nThis class provides a convenient interface for sending inference requests to a Triton Inference Server\nand processing the responses. Supports both HTTP and gRPC communication protocols.\n\nAttributes:\n endpoint (str): The name of the model on the Triton server.\n url (str): The URL of the Triton server.\n triton_client: The Triton client (either HTTP or gRPC).\n InferInput: The input class for the Triton client.\n InferRequestedOutput: The output request class for the Triton client.\n input_formats (List[str]): The data types of the model inputs.\n np_input_formats (List[type]): The numpy data types of the model inputs.\n input_names (List[str]): The names of the model inputs.\n output_names (List[str]): The names of the model outputs.\n metadata: The metadata associated with the model.\n\nMethods:\n __call__: Call the model with the given inputs and return the outputs.\n\nExamples:\n Initialize a Triton client with HTTP\n >>> model = TritonRemoteModel(url=\"localhost:8000\", endpoint=\"yolov8\", scheme=\"http\")\n\n Make inference with numpy arrays\n >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"urllib.parse.urlsplit",
"numpy",
"tritonclient.http",
"tritonclient.grpc"
],
"chunk_id": "class_TritonRemoteModel_b8e84f73"
},
{
"content": "from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_cfg, get_save_dir",
"chunk_type": "import",
"name": "TASK2DATA, TASK2METRIC, get_cfg, get_save_dir",
"file_path": "ultralytics\\ultralytics\\utils\\tuner.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 73,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_TASK2DATA, TASK2METRIC, get_cfg, get_save_dir_7c2ec1a2"
},
{
"content": "from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks, colorstr",
"chunk_type": "import",
"name": "DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks, colorstr",
"file_path": "ultralytics\\ultralytics\\utils\\tuner.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 98,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks, colorstr_e9eb42a8"
},
{
"content": "def run_ray_tune(\n model,\n space: dict = None,\n grace_period: int = 10,\n gpu_per_trial: int = None,\n max_samples: int = 10,\n **train_args,\n):\n \"\"\"\n Run hyperparameter tuning using Ray Tune.\n\n Args:\n model (YOLO): Model to run the tuner on.\n space (dict, optional): The hyperparameter search space. If not provided, uses default space.\n grace_period (int, optional): The grace period in epochs of the ASHA scheduler.\n gpu_per_trial (int, optional): The number of GPUs to allocate per trial.\n max_samples (int, optional): The maximum number of trials to run.\n **train_args (Any): Additional arguments to pass to the `train()` method.\n\n Returns:\n (ray.tune.ResultGrid): A ResultGrid containing the results of the hyperparameter search.\n\n Examples:\n >>> from ultralytics import YOLO\n >>> model = YOLO(\"yolo11n.pt\") # Load a YOLO11n model\n\n Start tuning hyperparameters for YOLO11n training on the COCO8 dataset\n >>> result_grid = model.tune(data=\"coco8.yaml\", use_ray=True)\n \"\"\"\n LOGGER.info(\"💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune\")\n if train_args is None:\n train_args = {}\n\n try:\n checks.check_requirements(\"ray[tune]\")\n\n import ray\n from ray import tune\n from ray.air import RunConfig\n from ray.air.integrations.wandb import WandbLoggerCallback\n from ray.tune.schedulers import ASHAScheduler\n except ImportError:\n raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install \"ray[tune]\"')\n\n try:\n import wandb\n\n assert hasattr(wandb, \"__version__\")\n except (ImportError, AssertionError):\n wandb = False\n\n checks.check_version(ray.__version__, \">=2.0.0\", \"ray\")\n default_space = {\n # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),\n \"lr0\": tune.uniform(1e-5, 1e-1),\n \"lrf\": tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)\n \"momentum\": tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1\n \"weight_decay\": tune.uniform(0.0, 0.001), # optimizer weight decay\n \"warmup_epochs\": tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)\n \"warmup_momentum\": tune.uniform(0.0, 0.95), # warmup initial momentum\n \"box\": tune.uniform(0.02, 0.2), # box loss gain\n \"cls\": tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)\n \"hsv_h\": tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)\n \"hsv_s\": tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)\n \"hsv_v\": tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)\n \"degrees\": tune.uniform(0.0, 45.0), # image rotation (+/- deg)\n \"translate\": tune.uniform(0.0, 0.9), # image translation (+/- fraction)\n \"scale\": tune.uniform(0.0, 0.9), # image scale (+/- gain)\n \"shear\": tune.uniform(0.0, 10.0), # image shear (+/- deg)\n \"perspective\": tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001\n \"flipud\": tune.uniform(0.0, 1.0), # image flip up-down (probability)\n \"fliplr\": tune.uniform(0.0, 1.0), # image flip left-right (probability)\n \"bgr\": tune.uniform(0.0, 1.0), # image channel BGR (probability)\n \"mosaic\": tune.uniform(0.0, 1.0), # image mosaic (probability)\n \"mixup\": tune.uniform(0.0, 1.0), # image mixup (probability)\n \"cutmix\": tune.uniform(0.0, 1.0), # image cutmix (probability)\n \"copy_paste\": tune.uniform(0.0, 1.0), # segment copy-paste (probability)\n }\n\n # Put the model in ray store\n task = model.task\n model_in_store = ray.put(model)\n\n def _tune(config):\n 
\"\"\"Train the YOLO model with the specified hyperparameters and return results.\"\"\"\n model_to_train = ray.get(model_in_store) # get the model from ray store for tuning\n model_to_train.reset_callbacks()\n config.update(train_args)\n results = model_to_train.train(**config)\n return results.results_dict\n\n # Get search space\n if not space and not train_args.get(\"resume\"):\n space = default_space\n LOGGER.warning(\"Search space not provided, using default search space.\")\n\n # Get dataset\n data = train_args.get(\"data\", TASK2DATA[task])\n space[\"data\"] = data\n if \"data\" not in train_args:\n LOGGER.warning(f'Data not provided, using default \"data={data}\".')\n\n # Define the trainable function with allocated resources\n trainable_with_resources = tune.with_resources(_tune, {\"cpu\": NUM_THREADS, \"gpu\": gpu_per_trial or 0})\n\n # Define the ASHA scheduler for hyperparameter search\n asha_scheduler = ASHAScheduler(\n time_attr=\"epoch\",\n metric=TASK2METRIC[task],\n mode=\"max\",\n max_t=train_args.get(\"epochs\") or DEFAULT_CFG_DICT[\"epochs\"] or 100,\n grace_period=grace_period,\n reduction_factor=3,\n )\n\n # Define the callbacks for the hyperparameter search\n tuner_callbacks = [WandbLoggerCallback(project=\"YOLOv8-tune\")] if wandb else []\n\n # Create the Ray Tune hyperparameter search tuner\n tune_dir = get_save_dir(\n get_cfg(\n DEFAULT_CFG,\n {**train_args, **{\"exist_ok\": train_args.pop(\"resume\", False)}}, # resume w/ same tune_dir\n ),\n name=train_args.pop(\"name\", \"tune\"), # runs/{task}/{tune_dir}\n ).resolve() # must be absolute dir\n tune_dir.mkdir(parents=True, exist_ok=True)\n if tune.Tuner.can_restore(tune_dir):\n LOGGER.info(f\"{colorstr('Tuner: ')} Resuming tuning run {tune_dir}...\")\n tuner = tune.Tuner.restore(str(tune_dir), trainable=trainable_with_resources, resume_errored=True)\n else:\n tuner = tune.Tuner(\n trainable_with_resources,\n param_space=space,\n tune_config=tune.TuneConfig(\n scheduler=asha_scheduler,\n num_samples=max_samples,\n trial_name_creator=lambda trial: f\"{trial.trainable_name}_{trial.trial_id}\",\n trial_dirname_creator=lambda trial: f\"{trial.trainable_name}_{trial.trial_id}\",\n ),\n run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir.parent, name=tune_dir.name),\n )\n\n # Run the hyperparameter search\n tuner.fit()\n\n # Get the results of the hyperparameter search\n results = tuner.get_results()\n\n # Shut down Ray to clean up workers\n ray.shutdown()\n\n return results",
"chunk_type": "function",
"name": "run_ray_tune",
"file_path": "ultralytics\\ultralytics\\utils\\tuner.py",
"start_line": 7,
"end_line": 159,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": "Run hyperparameter tuning using Ray Tune.\n\nArgs:\n model (YOLO): Model to run the tuner on.\n space (dict, optional): The hyperparameter search space. If not provided, uses default space.\n grace_period (int, optional): The grace period in epochs of the ASHA scheduler.\n gpu_per_trial (int, optional): The number of GPUs to allocate per trial.\n max_samples (int, optional): The maximum number of trials to run.\n **train_args (Any): Additional arguments to pass to the `train()` method.\n\nReturns:\n (ray.tune.ResultGrid): A ResultGrid containing the results of the hyperparameter search.\n\nExamples:\n >>> from ultralytics import YOLO\n >>> model = YOLO(\"yolo11n.pt\") # Load a YOLO11n model\n\n Start tuning hyperparameters for YOLO11n training on the COCO8 dataset\n >>> result_grid = model.tune(data=\"coco8.yaml\", use_ray=True)",
"parameters": [
"model",
"space: dict",
"grace_period: int",
"gpu_per_trial: int",
"max_samples: int"
],
"return_type": null,
"decorators": [],
"complexity_score": 7,
"dependencies": [
"ultralytics.cfg.TASK2DATA",
"ultralytics.cfg.TASK2METRIC",
"ultralytics.cfg.get_cfg",
"ultralytics.cfg.get_save_dir",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.checks",
"ultralytics.utils.colorstr",
"ray",
"ray.tune",
"ray.air.RunConfig",
"ray.air.integrations.wandb.WandbLoggerCallback",
"ray.tune.schedulers.ASHAScheduler",
"wandb"
],
"chunk_id": "function_run_ray_tune_8bb9d330"
},
{
"content": "import contextlib",
"chunk_type": "import",
"name": "contextlib",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 17,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_contextlib_4d4c47c8"
},
{
"content": "import importlib.metadata",
"chunk_type": "import",
"name": "importlib.metadata",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_importlib.metadata_1f56d5eb"
},
{
"content": "import inspect",
"chunk_type": "import",
"name": "inspect",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 14,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_inspect_8238d903"
},
{
"content": "import json",
"chunk_type": "import",
"name": "json",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_json_a7821fa1"
},
{
"content": "import logging",
"chunk_type": "import",
"name": "logging",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 14,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_logging_2ff27b9b"
},
{
"content": "import os",
"chunk_type": "import",
"name": "os",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_os_a474651e"
},
{
"content": "import platform",
"chunk_type": "import",
"name": "platform",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_platform_aaf3ae19"
},
{
"content": "import re",
"chunk_type": "import",
"name": "re",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_re_1aa51818"
},
{
"content": "import subprocess",
"chunk_type": "import",
"name": "subprocess",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 17,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_subprocess_c14a63da"
},
{
"content": "import sys",
"chunk_type": "import",
"name": "sys",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 10,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_sys_7e23cba0"
},
{
"content": "import threading",
"chunk_type": "import",
"name": "threading",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_threading_defa8e84"
},
{
"content": "import time",
"chunk_type": "import",
"name": "time",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_time_5b823329"
},
{
"content": "import warnings",
"chunk_type": "import",
"name": "warnings",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_warnings_22e1693e"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 16,
"end_line": 16,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_24ab0945"
},
{
"content": "from threading import Lock",
"chunk_type": "import",
"name": "Lock",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 17,
"end_line": 17,
"start_col": 0,
"end_col": 26,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Lock_b62ee358"
},
{
"content": "from types import SimpleNamespace",
"chunk_type": "import",
"name": "SimpleNamespace",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 18,
"end_line": 18,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SimpleNamespace_c9f5c5d7"
},
{
"content": "from typing import Union",
"chunk_type": "import",
"name": "Union",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 19,
"end_line": 19,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Union_7b3abf18"
},
{
"content": "from urllib.parse import unquote",
"chunk_type": "import",
"name": "unquote",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 20,
"end_line": 20,
"start_col": 0,
"end_col": 32,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_unquote_b5450cb4"
},
{
"content": "import cv2",
"chunk_type": "import",
"name": "cv2",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 22,
"end_line": 22,
"start_col": 0,
"end_col": 10,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_cv2_377856b6"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 23,
"end_line": 23,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_615d6760"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 24,
"end_line": 24,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_6e178de0"
},
{
"content": "import tqdm",
"chunk_type": "import",
"name": "tqdm",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 25,
"end_line": 25,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_tqdm_fc1d95d4"
},
{
"content": "from ultralytics import __version__",
"chunk_type": "import",
"name": "__version__",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 27,
"end_line": 27,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import___version___ea70b14d"
},
{
"content": "from ultralytics.utils.patches import imread, imshow, imwrite, torch_save # for patches",
"chunk_type": "import",
"name": "imread, imshow, imwrite, torch_save",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 28,
"end_line": 28,
"start_col": 0,
"end_col": 73,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_imread, imshow, imwrite, torch_save_f1dceae8"
},
{
"content": "RANK = int(os.getenv(\"RANK\", -1))",
"chunk_type": "variable",
"name": "RANK",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 31,
"end_line": 31,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_RANK_9e3ef401"
},
{
"content": "LOCAL_RANK = int(os.getenv(\"LOCAL_RANK\", -1)) # https://pytorch.org/docs/stable/elastic/run.html",
"chunk_type": "variable",
"name": "LOCAL_RANK",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 32,
"end_line": 32,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_LOCAL_RANK_90f2524a"
},
{
"content": "ARGV = sys.argv or [\"\", \"\"] # sometimes sys.argv = []",
"chunk_type": "variable",
"name": "ARGV",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 35,
"end_line": 35,
"start_col": 0,
"end_col": 27,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_ARGV_8d0e7e0e"
},
{
"content": "FILE = Path(__file__).resolve()",
"chunk_type": "variable",
"name": "FILE",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 36,
"end_line": 36,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_FILE_835ae00c"
},
{
"content": "ROOT = FILE.parents[1] # YOLO",
"chunk_type": "variable",
"name": "ROOT",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 37,
"end_line": 37,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_ROOT_ba9cd46f"
},
{
"content": "ASSETS = ROOT / \"assets\" # default images",
"chunk_type": "variable",
"name": "ASSETS",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 38,
"end_line": 38,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_ASSETS_dc99f744"
},
{
"content": "ASSETS_URL = \"https://github.com/ultralytics/assets/releases/download/v0.0.0\" # assets GitHub URL",
"chunk_type": "variable",
"name": "ASSETS_URL",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 39,
"end_line": 39,
"start_col": 0,
"end_col": 77,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_ASSETS_URL_7853b028"
},
{
"content": "DEFAULT_CFG_PATH = ROOT / \"cfg/default.yaml\"",
"chunk_type": "variable",
"name": "DEFAULT_CFG_PATH",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 40,
"end_line": 40,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_DEFAULT_CFG_PATH_2e90f729"
},
{
"content": "NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLO multiprocessing threads",
"chunk_type": "variable",
"name": "NUM_THREADS",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 41,
"end_line": 41,
"start_col": 0,
"end_col": 48,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_NUM_THREADS_640ac97b"
},
{
"content": "AUTOINSTALL = str(os.getenv(\"YOLO_AUTOINSTALL\", True)).lower() == \"true\" # global auto-install mode",
"chunk_type": "variable",
"name": "AUTOINSTALL",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 42,
"end_line": 42,
"start_col": 0,
"end_col": 72,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_AUTOINSTALL_66403f21"
},
{
"content": "VERBOSE = str(os.getenv(\"YOLO_VERBOSE\", True)).lower() == \"true\" # global verbose mode",
"chunk_type": "variable",
"name": "VERBOSE",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 43,
"end_line": 43,
"start_col": 0,
"end_col": 64,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_VERBOSE_a14f86bb"
},
{
"content": "TQDM_BAR_FORMAT = \"{l_bar}{bar:10}{r_bar}\" if VERBOSE else None # tqdm bar format",
"chunk_type": "variable",
"name": "TQDM_BAR_FORMAT",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 44,
"end_line": 44,
"start_col": 0,
"end_col": 63,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_TQDM_BAR_FORMAT_16e9892d"
},
{
"content": "LOGGING_NAME = \"ultralytics\"",
"chunk_type": "variable",
"name": "LOGGING_NAME",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 45,
"end_line": 45,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_LOGGING_NAME_c6583563"
},
{
"content": "MACOS_VERSION = platform.mac_ver()[0] if MACOS else None",
"chunk_type": "variable",
"name": "MACOS_VERSION",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 47,
"end_line": 47,
"start_col": 0,
"end_col": 56,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_MACOS_VERSION_72e4c96b"
},
{
"content": "ARM64 = platform.machine() in {\"arm64\", \"aarch64\"} # ARM64 booleans",
"chunk_type": "variable",
"name": "ARM64",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 48,
"end_line": 48,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_ARM64_f4087415"
},
{
"content": "PYTHON_VERSION = platform.python_version()",
"chunk_type": "variable",
"name": "PYTHON_VERSION",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 49,
"end_line": 49,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_PYTHON_VERSION_229c3f66"
},
{
"content": "TORCH_VERSION = torch.__version__",
"chunk_type": "variable",
"name": "TORCH_VERSION",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 50,
"end_line": 50,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_TORCH_VERSION_f77f9b49"
},
{
"content": "TORCHVISION_VERSION = importlib.metadata.version(\"torchvision\") # faster than importing torchvision",
"chunk_type": "variable",
"name": "TORCHVISION_VERSION",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 51,
"end_line": 51,
"start_col": 0,
"end_col": 63,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_TORCHVISION_VERSION_6646c490"
},
{
"content": "IS_VSCODE = os.environ.get(\"TERM_PROGRAM\", False) == \"vscode\"",
"chunk_type": "variable",
"name": "IS_VSCODE",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 52,
"end_line": 52,
"start_col": 0,
"end_col": 61,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_IS_VSCODE_78a0d751"
},
{
"content": "RKNN_CHIPS = frozenset(\n {\n \"rk3588\",\n \"rk3576\",\n \"rk3566\",\n \"rk3568\",\n \"rk3562\",\n \"rv1103\",\n \"rv1106\",\n \"rv1103b\",\n \"rv1106b\",\n \"rk2118\",\n }\n) # Rockchip processors available for export",
"chunk_type": "variable",
"name": "RKNN_CHIPS",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 53,
"end_line": 66,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_RKNN_CHIPS_7a7f5d56"
},
{
"content": "HELP_MSG = \"\"\"\n Examples for running Ultralytics:\n\n 1. Install the ultralytics package:\n\n pip install ultralytics\n\n 2. Use the Python SDK:\n\n from ultralytics import YOLO\n\n # Load a model\n model = YOLO(\"yolo11n.yaml\") # build a new model from scratch\n model = YOLO(\"yolo11n.pt\") # load a pretrained model (recommended for training)\n\n # Use the model\n results = model.train(data=\"coco8.yaml\", epochs=3) # train the model\n results = model.val() # evaluate model performance on the validation set\n results = model(\"https://ultralytics.com/images/bus.jpg\") # predict on an image\n success = model.export(format=\"onnx\") # export the model to ONNX format\n\n 3. Use the command line interface (CLI):\n\n Ultralytics 'yolo' CLI commands use the following syntax:\n\n yolo TASK MODE ARGS\n\n Where TASK (optional) is one of [detect, segment, classify, pose, obb]\n MODE (required) is one of [train, val, predict, export, track, benchmark]\n ARGS (optional) are any number of custom \"arg=value\" pairs like \"imgsz=320\" that override defaults.\n See all ARGS at https://docs.ultralytics.com/usage/cfg or with \"yolo cfg\"\n\n - Train a detection model for 10 epochs with an initial learning_rate of 0.01\n yolo detect train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01\n\n - Predict a YouTube video using a pretrained segmentation model at image size 320:\n yolo segment predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320\n\n - Val a pretrained detection model at batch-size 1 and image size 640:\n yolo detect val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640\n\n - Export a YOLO11n classification model to ONNX format at image size 224 by 128 (no TASK required)\n yolo export model=yolo11n-cls.pt format=onnx imgsz=224,128\n\n - Run special commands:\n yolo help\n yolo checks\n yolo version\n yolo settings\n yolo copy-cfg\n yolo cfg\n\n Docs: https://docs.ultralytics.com\n Community: https://community.ultralytics.com\n GitHub: https://github.com/ultralytics/ultralytics\n \"\"\"",
"chunk_type": "variable",
"name": "HELP_MSG",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 67,
"end_line": 122,
"start_col": 0,
"end_col": 7,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_HELP_MSG_5239f3ea"
},
{
"content": "class TQDM(rich.tqdm if TQDM_RICH else tqdm.tqdm):\n \"\"\"\n A custom TQDM progress bar class that extends the original tqdm functionality.\n\n This class modifies the behavior of the original tqdm progress bar based on global settings and provides\n additional customization options for Ultralytics projects. The progress bar is automatically disabled when\n VERBOSE is False or when explicitly disabled.\n\n Attributes:\n disable (bool): Whether to disable the progress bar. Determined by the global VERBOSE setting and\n any passed 'disable' argument.\n bar_format (str): The format string for the progress bar. Uses the global TQDM_BAR_FORMAT if not\n explicitly set.\n\n Methods:\n __init__: Initialize the TQDM object with custom settings.\n __iter__: Return self as iterator to satisfy Iterable interface.\n\n Examples:\n >>> from ultralytics.utils import TQDM\n >>> for i in TQDM(range(100)):\n ... # Your processing code here\n ... pass\n \"\"\"\n\n def __init__(self, *args, **kwargs):\n \"\"\"\n Initialize a custom TQDM progress bar with Ultralytics-specific settings.\n\n Args:\n *args (Any): Variable length argument list to be passed to the original tqdm constructor.\n **kwargs (Any): Arbitrary keyword arguments to be passed to the original tqdm constructor.\n\n Notes:\n - The progress bar is disabled if VERBOSE is False or if 'disable' is explicitly set to True in kwargs.\n - The default bar format is set to TQDM_BAR_FORMAT unless overridden in kwargs.\n\n Examples:\n >>> from ultralytics.utils import TQDM\n >>> for i in TQDM(range(100)):\n ... # Your code here\n ... pass\n \"\"\"\n warnings.filterwarnings(\"ignore\", category=tqdm.TqdmExperimentalWarning) # suppress tqdm.rich warning\n kwargs[\"disable\"] = not VERBOSE or kwargs.get(\"disable\", False)\n kwargs.setdefault(\"bar_format\", TQDM_BAR_FORMAT) # override default value if passed\n super().__init__(*args, **kwargs)\n\n def __iter__(self):\n \"\"\"Return self as iterator to satisfy Iterable interface.\"\"\"\n return super().__iter__()",
"chunk_type": "class",
"name": "TQDM",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 137,
"end_line": 187,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": "A custom TQDM progress bar class that extends the original tqdm functionality.\n\nThis class modifies the behavior of the original tqdm progress bar based on global settings and provides\nadditional customization options for Ultralytics projects. The progress bar is automatically disabled when\nVERBOSE is False or when explicitly disabled.\n\nAttributes:\n disable (bool): Whether to disable the progress bar. Determined by the global VERBOSE setting and\n any passed 'disable' argument.\n bar_format (str): The format string for the progress bar. Uses the global TQDM_BAR_FORMAT if not\n explicitly set.\n\nMethods:\n __init__: Initialize the TQDM object with custom settings.\n __iter__: Return self as iterator to satisfy Iterable interface.\n\nExamples:\n >>> from ultralytics.utils import TQDM\n >>> for i in TQDM(range(100)):\n ... # Your processing code here\n ... pass",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io",
"rich.tqdm if TQDM_RICH else tqdm.tqdm"
],
"chunk_id": "class_TQDM_87fbee36"
},
{
"content": "class DataExportMixin:\n \"\"\"\n Mixin class for exporting validation metrics or prediction results in various formats.\n\n This class provides utilities to export performance metrics (e.g., mAP, precision, recall) or prediction results\n from classification, object detection, segmentation, or pose estimation tasks into various formats: Pandas\n DataFrame, CSV, XML, HTML, JSON and SQLite (SQL).\n\n Methods:\n to_df: Convert summary to a Pandas DataFrame.\n to_csv: Export results as a CSV string.\n to_xml: Export results as an XML string (requires `lxml`).\n to_html: Export results as an HTML table.\n to_json: Export results as a JSON string.\n tojson: Deprecated alias for `to_json()`.\n to_sql: Export results to an SQLite database.\n\n Examples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> results = model(\"image.jpg\")\n >>> df = results.to_df()\n >>> print(df)\n >>> csv_data = results.to_csv()\n >>> results.to_sql(table_name=\"yolo_results\")\n \"\"\"\n\n def to_df(self, normalize=False, decimals=5):\n \"\"\"\n Create a pandas DataFrame from the prediction results summary or validation metrics.\n\n Args:\n normalize (bool, optional): Normalize numerical values for easier comparison.\n decimals (int, optional): Decimal places to round floats.\n\n Returns:\n (DataFrame): DataFrame containing the summary data.\n \"\"\"\n import pandas as pd # scope for faster 'import ultralytics'\n\n return pd.DataFrame(self.summary(normalize=normalize, decimals=decimals))\n\n def to_csv(self, normalize=False, decimals=5):\n \"\"\"\n Export results to CSV string format.\n\n Args:\n normalize (bool, optional): Normalize numeric values.\n decimals (int, optional): Decimal precision.\n\n Returns:\n (str): CSV content as string.\n \"\"\"\n return self.to_df(normalize=normalize, decimals=decimals).to_csv()\n\n def to_xml(self, normalize=False, decimals=5):\n \"\"\"\n Export results to XML format.\n\n Args:\n normalize (bool, optional): Normalize numeric values.\n decimals (int, optional): Decimal precision.\n\n Returns:\n (str): XML string.\n\n Notes:\n Requires `lxml` package to be installed.\n \"\"\"\n df = self.to_df(normalize=normalize, decimals=decimals)\n return '\\n' if df.empty else df.to_xml(parser=\"etree\")\n\n def to_html(self, normalize=False, decimals=5, index=False):\n \"\"\"\n Export results to HTML table format.\n\n Args:\n normalize (bool, optional): Normalize numeric values.\n decimals (int, optional): Decimal precision.\n index (bool, optional): Whether to include index column in the HTML table.\n\n Returns:\n (str): HTML representation of the results.\n \"\"\"\n df = self.to_df(normalize=normalize, decimals=decimals)\n return \"\" if df.empty else df.to_html(index=index)\n\n def tojson(self, normalize=False, decimals=5):\n \"\"\"Deprecated version of to_json().\"\"\"\n LOGGER.warning(\"'result.tojson()' is deprecated, replace with 'result.to_json()'.\")\n return self.to_json(normalize, decimals)\n\n def to_json(self, normalize=False, decimals=5):\n \"\"\"\n Export results to JSON format.\n\n Args:\n normalize (bool, optional): Normalize numeric values.\n decimals (int, optional): Decimal precision.\n\n Returns:\n (str): JSON-formatted string of the results.\n \"\"\"\n return self.to_df(normalize=normalize, decimals=decimals).to_json(orient=\"records\", indent=2)\n\n def to_sql(self, normalize=False, decimals=5, table_name=\"results\", db_path=\"results.db\"):\n \"\"\"\n Save results to an SQLite database.\n\n Args:\n normalize (bool, optional): Normalize numeric values.\n decimals 
(int, optional): Decimal precision.\n table_name (str, optional): Name of the SQL table.\n db_path (str, optional): SQLite database file path.\n \"\"\"\n df = self.to_df(normalize, decimals)\n if df.empty or df.columns.empty: # Exit if df is None or has no columns (i.e., no schema)\n return\n\n import sqlite3\n\n conn = sqlite3.connect(db_path)\n cursor = conn.cursor()\n\n # Dynamically create table schema based on summary to support prediction and validation results export\n columns = []\n for col in df.columns:\n sample_val = df[col].dropna().iloc[0] if not df[col].dropna().empty else \"\"\n if isinstance(sample_val, dict):\n col_type = \"TEXT\"\n elif isinstance(sample_val, (float, int)):\n col_type = \"REAL\"\n else:\n col_type = \"TEXT\"\n columns.append(f'\"{col}\" {col_type}') # Quote column names to handle special characters like hyphens\n\n # Create table (Drop table from db if it's already exist)\n cursor.execute(f'DROP TABLE IF EXISTS \"{table_name}\"')\n cursor.execute(f'CREATE TABLE \"{table_name}\" (id INTEGER PRIMARY KEY AUTOINCREMENT, {\", \".join(columns)})')\n\n for _, row in df.iterrows():\n values = [json.dumps(v) if isinstance(v, dict) else v for v in row]\n column_names = \", \".join(f'\"{col}\"' for col in df.columns)\n placeholders = \", \".join(\"?\" for _ in df.columns)\n cursor.execute(f'INSERT INTO \"{table_name}\" ({column_names}) VALUES ({placeholders})', values)\n\n conn.commit()\n conn.close()\n LOGGER.info(f\"Results saved to SQL table '{table_name}' in '{db_path}'.\")",
"chunk_type": "class",
"name": "DataExportMixin",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 190,
"end_line": 337,
"start_col": 0,
"end_col": 81,
"parent_name": null,
"docstring": "Mixin class for exporting validation metrics or prediction results in various formats.\n\nThis class provides utilities to export performance metrics (e.g., mAP, precision, recall) or prediction results\nfrom classification, object detection, segmentation, or pose estimation tasks into various formats: Pandas\nDataFrame, CSV, XML, HTML, JSON and SQLite (SQL).\n\nMethods:\n to_df: Convert summary to a Pandas DataFrame.\n to_csv: Export results as a CSV string.\n to_xml: Export results as an XML string (requires `lxml`).\n to_html: Export results as an HTML table.\n to_json: Export results as a JSON string.\n tojson: Deprecated alias for `to_json()`.\n to_sql: Export results to an SQLite database.\n\nExamples:\n >>> model = YOLO(\"yolo11n.pt\")\n >>> results = model(\"image.jpg\")\n >>> df = results.to_df()\n >>> print(df)\n >>> csv_data = results.to_csv()\n >>> results.to_sql(table_name=\"yolo_results\")",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "class_DataExportMixin_9a9b7f10"
},
{
"content": "class SimpleClass:\n \"\"\"\n A simple base class for creating objects with string representations of their attributes.\n\n This class provides a foundation for creating objects that can be easily printed or represented as strings,\n showing all their non-callable attributes. It's useful for debugging and introspection of object states.\n\n Methods:\n __str__: Return a human-readable string representation of the object.\n __repr__: Return a machine-readable string representation of the object.\n __getattr__: Provide a custom attribute access error message with helpful information.\n\n Examples:\n >>> class MyClass(SimpleClass):\n ... def __init__(self):\n ... self.x = 10\n ... self.y = \"hello\"\n >>> obj = MyClass()\n >>> print(obj)\n __main__.MyClass object with attributes:\n\n x: 10\n y: 'hello'\n\n Notes:\n - This class is designed to be subclassed. It provides a convenient way to inspect object attributes.\n - The string representation includes the module and class name of the object.\n - Callable attributes and attributes starting with an underscore are excluded from the string representation.\n \"\"\"\n\n def __str__(self):\n \"\"\"Return a human-readable string representation of the object.\"\"\"\n attr = []\n for a in dir(self):\n v = getattr(self, a)\n if not callable(v) and not a.startswith(\"_\"):\n if isinstance(v, SimpleClass):\n # Display only the module and class name for subclasses\n s = f\"{a}: {v.__module__}.{v.__class__.__name__} object\"\n else:\n s = f\"{a}: {repr(v)}\"\n attr.append(s)\n return f\"{self.__module__}.{self.__class__.__name__} object with attributes:\\n\\n\" + \"\\n\".join(attr)\n\n def __repr__(self):\n \"\"\"Return a machine-readable string representation of the object.\"\"\"\n return self.__str__()\n\n def __getattr__(self, attr):\n \"\"\"Provide a custom attribute access error message with helpful information.\"\"\"\n name = self.__class__.__name__\n raise AttributeError(f\"'{name}' object has no attribute '{attr}'. See valid attributes below.\\n{self.__doc__}\")",
"chunk_type": "class",
"name": "SimpleClass",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 340,
"end_line": 391,
"start_col": 0,
"end_col": 119,
"parent_name": null,
"docstring": "A simple base class for creating objects with string representations of their attributes.\n\nThis class provides a foundation for creating objects that can be easily printed or represented as strings,\nshowing all their non-callable attributes. It's useful for debugging and introspection of object states.\n\nMethods:\n __str__: Return a human-readable string representation of the object.\n __repr__: Return a machine-readable string representation of the object.\n __getattr__: Provide a custom attribute access error message with helpful information.\n\nExamples:\n >>> class MyClass(SimpleClass):\n ... def __init__(self):\n ... self.x = 10\n ... self.y = \"hello\"\n >>> obj = MyClass()\n >>> print(obj)\n __main__.MyClass object with attributes:\n\n x: 10\n y: 'hello'\n\nNotes:\n - This class is designed to be subclassed. It provides a convenient way to inspect object attributes.\n - The string representation includes the module and class name of the object.\n - Callable attributes and attributes starting with an underscore are excluded from the string representation.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "class_SimpleClass_0d66c6b0"
},
{
"content": "class IterableSimpleNamespace(SimpleNamespace):\n \"\"\"\n An iterable SimpleNamespace class that provides enhanced functionality for attribute access and iteration.\n\n This class extends the SimpleNamespace class with additional methods for iteration, string representation,\n and attribute access. It is designed to be used as a convenient container for storing and accessing\n configuration parameters.\n\n Methods:\n __iter__: Return an iterator of key-value pairs from the namespace's attributes.\n __str__: Return a human-readable string representation of the object.\n __getattr__: Provide a custom attribute access error message with helpful information.\n get: Retrieve the value of a specified key, or a default value if the key doesn't exist.\n\n Examples:\n >>> cfg = IterableSimpleNamespace(a=1, b=2, c=3)\n >>> for k, v in cfg:\n ... print(f\"{k}: {v}\")\n a: 1\n b: 2\n c: 3\n >>> print(cfg)\n a=1\n b=2\n c=3\n >>> cfg.get(\"b\")\n 2\n >>> cfg.get(\"d\", \"default\")\n 'default'\n\n Notes:\n This class is particularly useful for storing configuration parameters in a more accessible\n and iterable format compared to a standard dictionary.\n \"\"\"\n\n def __iter__(self):\n \"\"\"Return an iterator of key-value pairs from the namespace's attributes.\"\"\"\n return iter(vars(self).items())\n\n def __str__(self):\n \"\"\"Return a human-readable string representation of the object.\"\"\"\n return \"\\n\".join(f\"{k}={v}\" for k, v in vars(self).items())\n\n def __getattr__(self, attr):\n \"\"\"Provide a custom attribute access error message with helpful information.\"\"\"\n name = self.__class__.__name__\n raise AttributeError(\n f\"\"\"\n '{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics\n 'default.yaml' file.\\nPlease update your code with 'pip install -U ultralytics' and if necessary replace\n {DEFAULT_CFG_PATH} with the latest version from\n https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml\n \"\"\"\n )\n\n def get(self, key, default=None):\n \"\"\"Return the value of the specified key if it exists; otherwise, return the default value.\"\"\"\n return getattr(self, key, default)",
"chunk_type": "class",
"name": "IterableSimpleNamespace",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 394,
"end_line": 451,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": "An iterable SimpleNamespace class that provides enhanced functionality for attribute access and iteration.\n\nThis class extends the SimpleNamespace class with additional methods for iteration, string representation,\nand attribute access. It is designed to be used as a convenient container for storing and accessing\nconfiguration parameters.\n\nMethods:\n __iter__: Return an iterator of key-value pairs from the namespace's attributes.\n __str__: Return a human-readable string representation of the object.\n __getattr__: Provide a custom attribute access error message with helpful information.\n get: Retrieve the value of a specified key, or a default value if the key doesn't exist.\n\nExamples:\n >>> cfg = IterableSimpleNamespace(a=1, b=2, c=3)\n >>> for k, v in cfg:\n ... print(f\"{k}: {v}\")\n a: 1\n b: 2\n c: 3\n >>> print(cfg)\n a=1\n b=2\n c=3\n >>> cfg.get(\"b\")\n 2\n >>> cfg.get(\"d\", \"default\")\n 'default'\n\nNotes:\n This class is particularly useful for storing configuration parameters in a more accessible\n and iterable format compared to a standard dictionary.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io",
"SimpleNamespace"
],
"chunk_id": "class_IterableSimpleNamespace_68d9d1c1"
},
{
"content": "def plt_settings(rcparams=None, backend=\"Agg\"):\n \"\"\"\n Decorator to temporarily set rc parameters and the backend for a plotting function.\n\n Args:\n rcparams (dict, optional): Dictionary of rc parameters to set.\n backend (str, optional): Name of the backend to use.\n\n Returns:\n (Callable): Decorated function with temporarily set rc parameters and backend.\n\n Examples:\n >>> @plt_settings({\"font.size\": 12})\n >>> def plot_function():\n ... plt.figure()\n ... plt.plot([1, 2, 3])\n ... plt.show()\n\n >>> with plt_settings({\"font.size\": 12}):\n ... plt.figure()\n ... plt.plot([1, 2, 3])\n ... plt.show()\n \"\"\"\n if rcparams is None:\n rcparams = {\"font.size\": 11}\n\n def decorator(func):\n \"\"\"Decorator to apply temporary rc parameters and backend to a function.\"\"\"\n\n def wrapper(*args, **kwargs):\n \"\"\"Set rc parameters and backend, call the original function, and restore the settings.\"\"\"\n import matplotlib.pyplot as plt # scope for faster 'import ultralytics'\n\n original_backend = plt.get_backend()\n switch = backend.lower() != original_backend.lower()\n if switch:\n plt.close(\"all\") # auto-close()ing of figures upon backend switching is deprecated since 3.8\n plt.switch_backend(backend)\n\n # Plot with backend and always revert to original backend\n try:\n with plt.rc_context(rcparams):\n result = func(*args, **kwargs)\n finally:\n if switch:\n plt.close(\"all\")\n plt.switch_backend(original_backend)\n return result\n\n return wrapper\n\n return decorator",
"chunk_type": "function",
"name": "plt_settings",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 454,
"end_line": 505,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Decorator to temporarily set rc parameters and the backend for a plotting function.\n\nArgs:\n rcparams (dict, optional): Dictionary of rc parameters to set.\n backend (str, optional): Name of the backend to use.\n\nReturns:\n (Callable): Decorated function with temporarily set rc parameters and backend.\n\nExamples:\n >>> @plt_settings({\"font.size\": 12})\n >>> def plot_function():\n ... plt.figure()\n ... plt.plot([1, 2, 3])\n ... plt.show()\n\n >>> with plt_settings({\"font.size\": 12}):\n ... plt.figure()\n ... plt.plot([1, 2, 3])\n ... plt.show()",
"parameters": [
"rcparams",
"backend"
],
"return_type": null,
"decorators": [],
"complexity_score": 4,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_plt_settings_6860eb76"
},
{
"content": "def set_logging(name=\"LOGGING_NAME\", verbose=True):\n \"\"\"\n Set up logging with UTF-8 encoding and configurable verbosity.\n\n This function configures logging for the Ultralytics library, setting the appropriate logging level and\n formatter based on the verbosity flag and the current process rank. It handles special cases for Windows\n environments where UTF-8 encoding might not be the default.\n\n Args:\n name (str): Name of the logger.\n verbose (bool): Flag to set logging level to INFO if True, ERROR otherwise.\n\n Returns:\n (logging.Logger): Configured logger object.\n\n Examples:\n >>> set_logging(name=\"ultralytics\", verbose=True)\n >>> logger = logging.getLogger(\"ultralytics\")\n >>> logger.info(\"This is an info message\")\n\n Notes:\n - On Windows, this function attempts to reconfigure stdout to use UTF-8 encoding if possible.\n - If reconfiguration is not possible, it falls back to a custom formatter that handles non-UTF-8 environments.\n - The function sets up a StreamHandler with the appropriate formatter and level.\n - The logger's propagate flag is set to False to prevent duplicate logging in parent loggers.\n \"\"\"\n level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR # rank in world for Multi-GPU trainings\n\n class PrefixFormatter(logging.Formatter):\n def format(self, record):\n \"\"\"Format log records with prefixes based on level.\"\"\"\n # Apply prefixes based on log level\n if record.levelno == logging.WARNING:\n prefix = \"WARNING ⚠️\" if not WINDOWS else \"WARNING\"\n record.msg = f\"{prefix} {record.msg}\"\n elif record.levelno == logging.ERROR:\n prefix = \"ERROR ❌\" if not WINDOWS else \"ERROR\"\n record.msg = f\"{prefix} {record.msg}\"\n\n # Handle emojis in message based on platform\n formatted_message = super().format(record)\n return emojis(formatted_message)\n\n formatter = PrefixFormatter(\"%(message)s\")\n\n # Handle Windows UTF-8 encoding issues\n if WINDOWS and hasattr(sys.stdout, \"encoding\") and sys.stdout.encoding != \"utf-8\":\n try:\n # Attempt to reconfigure stdout to use UTF-8 encoding if possible\n if hasattr(sys.stdout, \"reconfigure\"):\n sys.stdout.reconfigure(encoding=\"utf-8\")\n # For environments where reconfigure is not available, wrap stdout in a TextIOWrapper\n elif hasattr(sys.stdout, \"buffer\"):\n import io\n\n sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding=\"utf-8\")\n except Exception:\n pass\n\n # Create and configure the StreamHandler with the appropriate formatter and level\n stream_handler = logging.StreamHandler(sys.stdout)\n stream_handler.setFormatter(formatter)\n stream_handler.setLevel(level)\n\n # Set up the logger\n logger = logging.getLogger(name)\n logger.setLevel(level)\n logger.addHandler(stream_handler)\n logger.propagate = False\n return logger",
"chunk_type": "function",
"name": "set_logging",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 508,
"end_line": 577,
"start_col": 0,
"end_col": 17,
"parent_name": null,
"docstring": "Set up logging with UTF-8 encoding and configurable verbosity.\n\nThis function configures logging for the Ultralytics library, setting the appropriate logging level and\nformatter based on the verbosity flag and the current process rank. It handles special cases for Windows\nenvironments where UTF-8 encoding might not be the default.\n\nArgs:\n name (str): Name of the logger.\n verbose (bool): Flag to set logging level to INFO if True, ERROR otherwise.\n\nReturns:\n (logging.Logger): Configured logger object.\n\nExamples:\n >>> set_logging(name=\"ultralytics\", verbose=True)\n >>> logger = logging.getLogger(\"ultralytics\")\n >>> logger.info(\"This is an info message\")\n\nNotes:\n - On Windows, this function attempts to reconfigure stdout to use UTF-8 encoding if possible.\n - If reconfiguration is not possible, it falls back to a custom formatter that handles non-UTF-8 environments.\n - The function sets up a StreamHandler with the appropriate formatter and level.\n - The logger's propagate flag is set to False to prevent duplicate logging in parent loggers.",
"parameters": [
"name",
"verbose"
],
"return_type": null,
"decorators": [],
"complexity_score": 7,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_set_logging_e62f00e3"
},
{
"content": "LOGGER = set_logging(LOGGING_NAME, verbose=VERBOSE) # define globally (used in train.py, val.py, predict.py, etc.)",
"chunk_type": "variable",
"name": "LOGGER",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 581,
"end_line": 581,
"start_col": 0,
"end_col": 51,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_LOGGER_d3f1010c"
},
{
"content": "def emojis(string=\"\"):\n \"\"\"Return platform-dependent emoji-safe version of string.\"\"\"\n return string.encode().decode(\"ascii\", \"ignore\") if WINDOWS else string",
"chunk_type": "function",
"name": "emojis",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 586,
"end_line": 588,
"start_col": 0,
"end_col": 75,
"parent_name": null,
"docstring": "Return platform-dependent emoji-safe version of string.",
"parameters": [
"string"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_emojis_37a9f4c6"
},
{
"content": "class ThreadingLocked:\n \"\"\"\n A decorator class for ensuring thread-safe execution of a function or method.\n\n This class can be used as a decorator to make sure that if the decorated function is called from multiple threads,\n only one thread at a time will be able to execute the function.\n\n Attributes:\n lock (threading.Lock): A lock object used to manage access to the decorated function.\n\n Examples:\n >>> from ultralytics.utils import ThreadingLocked\n >>> @ThreadingLocked()\n >>> def my_function():\n ... # Your code here\n \"\"\"\n\n def __init__(self):\n \"\"\"Initialize the decorator class with a threading lock.\"\"\"\n self.lock = threading.Lock()\n\n def __call__(self, f):\n \"\"\"Run thread-safe execution of function or method.\"\"\"\n from functools import wraps\n\n @wraps(f)\n def decorated(*args, **kwargs):\n \"\"\"Apply thread-safety to the decorated function or method.\"\"\"\n with self.lock:\n return f(*args, **kwargs)\n\n return decorated",
"chunk_type": "class",
"name": "ThreadingLocked",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 591,
"end_line": 622,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": "A decorator class for ensuring thread-safe execution of a function or method.\n\nThis class can be used as a decorator to make sure that if the decorated function is called from multiple threads,\nonly one thread at a time will be able to execute the function.\n\nAttributes:\n lock (threading.Lock): A lock object used to manage access to the decorated function.\n\nExamples:\n >>> from ultralytics.utils import ThreadingLocked\n >>> @ThreadingLocked()\n >>> def my_function():\n ... # Your code here",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "class_ThreadingLocked_a0571b88"
},
{
"content": "class YAML:\n \"\"\"\n YAML utility class for efficient file operations with automatic C-implementation detection.\n\n This class provides optimized YAML loading and saving operations using PyYAML's fastest available implementation\n (C-based when possible). It implements a singleton pattern with lazy initialization, allowing direct class method\n usage without explicit instantiation. The class handles file path creation, validation, and character encoding\n issues automatically.\n\n The implementation prioritizes performance through:\n - Automatic C-based loader/dumper selection when available\n - Singleton pattern to reuse the same instance\n - Lazy initialization to defer import costs until needed\n - Fallback mechanisms for handling problematic YAML content\n\n Attributes:\n _instance: Internal singleton instance storage.\n yaml: Reference to the PyYAML module.\n SafeLoader: Best available YAML loader (CSafeLoader if available).\n SafeDumper: Best available YAML dumper (CSafeDumper if available).\n\n Examples:\n >>> data = YAML.load(\"config.yaml\")\n >>> data[\"new_value\"] = 123\n >>> YAML.save(\"updated_config.yaml\", data)\n >>> YAML.print(data)\n \"\"\"\n\n _instance = None\n\n @classmethod\n def _get_instance(cls):\n \"\"\"Initialize singleton instance on first use.\"\"\"\n if cls._instance is None:\n cls._instance = cls()\n return cls._instance\n\n def __init__(self):\n \"\"\"Initialize with optimal YAML implementation (C-based when available).\"\"\"\n import yaml\n\n self.yaml = yaml\n # Use C-based implementation if available for better performance\n try:\n self.SafeLoader = yaml.CSafeLoader\n self.SafeDumper = yaml.CSafeDumper\n except (AttributeError, ImportError):\n self.SafeLoader = yaml.SafeLoader\n self.SafeDumper = yaml.SafeDumper\n\n @classmethod\n def save(cls, file=\"data.yaml\", data=None, header=\"\"):\n \"\"\"\n Save Python object as YAML file.\n\n Args:\n file (str | Path): Path to save YAML file.\n data (dict | None): Dict or compatible object to save.\n header (str): Optional string to add at file beginning.\n \"\"\"\n instance = cls._get_instance()\n if data is None:\n data = {}\n\n # Create parent directories if needed\n file = Path(file)\n file.parent.mkdir(parents=True, exist_ok=True)\n\n # Convert non-serializable objects to strings\n valid_types = int, float, str, bool, list, tuple, dict, type(None)\n for k, v in data.items():\n if not isinstance(v, valid_types):\n data[k] = str(v)\n\n # Write YAML file\n with open(file, \"w\", errors=\"ignore\", encoding=\"utf-8\") as f:\n if header:\n f.write(header)\n instance.yaml.dump(data, f, sort_keys=False, allow_unicode=True, Dumper=instance.SafeDumper)\n\n @classmethod\n def load(cls, file=\"data.yaml\", append_filename=False):\n \"\"\"\n Load YAML file to Python object with robust error handling.\n\n Args:\n file (str | Path): Path to YAML file.\n append_filename (bool): Whether to add filename to returned dict.\n\n Returns:\n (dict): Loaded YAML content.\n \"\"\"\n instance = cls._get_instance()\n assert str(file).endswith((\".yaml\", \".yml\")), f\"Not a YAML file: {file}\"\n\n # Read file content\n with open(file, errors=\"ignore\", encoding=\"utf-8\") as f:\n s = f.read()\n\n # Try loading YAML with fallback for problematic characters\n try:\n data = instance.yaml.load(s, Loader=instance.SafeLoader) or {}\n except Exception:\n # Remove problematic characters and retry\n s = re.sub(r\"[^\\x09\\x0A\\x0D\\x20-\\x7E\\x85\\xA0-\\uD7FF\\uE000-\\uFFFD\\U00010000-\\U0010ffff]+\", \"\", s)\n data = 
instance.yaml.load(s, Loader=instance.SafeLoader) or {}\n\n # Check for accidental user-error None strings (should be 'null' in YAML)\n if \"None\" in data.values():\n data = {k: None if v == \"None\" else v for k, v in data.items()}\n\n if append_filename:\n data[\"yaml_file\"] = str(file)\n return data\n\n @classmethod\n def print(cls, yaml_file):\n \"\"\"\n Pretty print YAML file or object to console.\n\n Args:\n yaml_file (str | Path | dict): Path to YAML file or dict to print.\n \"\"\"\n instance = cls._get_instance()\n\n # Load file if path provided\n yaml_dict = cls.load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file\n\n # Use -1 for unlimited width in C implementation\n dump = instance.yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True, width=-1, Dumper=instance.SafeDumper)\n\n LOGGER.info(f\"Printing '{colorstr('bold', 'black', yaml_file)}'\\n\\n{dump}\")",
"chunk_type": "class",
"name": "YAML",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 625,
"end_line": 756,
"start_col": 0,
"end_col": 83,
"parent_name": null,
"docstring": "YAML utility class for efficient file operations with automatic C-implementation detection.\n\nThis class provides optimized YAML loading and saving operations using PyYAML's fastest available implementation\n(C-based when possible). It implements a singleton pattern with lazy initialization, allowing direct class method\nusage without explicit instantiation. The class handles file path creation, validation, and character encoding\nissues automatically.\n\nThe implementation prioritizes performance through:\n - Automatic C-based loader/dumper selection when available\n - Singleton pattern to reuse the same instance\n - Lazy initialization to defer import costs until needed\n - Fallback mechanisms for handling problematic YAML content\n\nAttributes:\n _instance: Internal singleton instance storage.\n yaml: Reference to the PyYAML module.\n SafeLoader: Best available YAML loader (CSafeLoader if available).\n SafeDumper: Best available YAML dumper (CSafeDumper if available).\n\nExamples:\n >>> data = YAML.load(\"config.yaml\")\n >>> data[\"new_value\"] = 123\n >>> YAML.save(\"updated_config.yaml\", data)\n >>> YAML.print(data)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "class_YAML_8935df41"
},
{
"content": "DEFAULT_CFG_DICT = YAML.load(DEFAULT_CFG_PATH)",
"chunk_type": "variable",
"name": "DEFAULT_CFG_DICT",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 760,
"end_line": 760,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_DEFAULT_CFG_DICT_07c66c33"
},
{
"content": "DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()",
"chunk_type": "variable",
"name": "DEFAULT_CFG_KEYS",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 761,
"end_line": 761,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_DEFAULT_CFG_KEYS_971a588d"
},
{
"content": "DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)",
"chunk_type": "variable",
"name": "DEFAULT_CFG",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 762,
"end_line": 762,
"start_col": 0,
"end_col": 57,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_DEFAULT_CFG_2b455215"
},
{
"content": "def read_device_model() -> str:\n \"\"\"\n Read the device model information from the system and cache it for quick access.\n\n Returns:\n (str): Kernel release information.\n \"\"\"\n return platform.release().lower()",
"chunk_type": "function",
"name": "read_device_model",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 765,
"end_line": 772,
"start_col": 0,
"end_col": 37,
"parent_name": null,
"docstring": "Read the device model information from the system and cache it for quick access.\n\nReturns:\n (str): Kernel release information.",
"parameters": [],
"return_type": "str",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_read_device_model_958dbac8"
},
{
"content": "def is_ubuntu() -> bool:\n \"\"\"\n Check if the OS is Ubuntu.\n\n Returns:\n (bool): True if OS is Ubuntu, False otherwise.\n \"\"\"\n try:\n with open(\"/etc/os-release\") as f:\n return \"ID=ubuntu\" in f.read()\n except FileNotFoundError:\n return False",
"chunk_type": "function",
"name": "is_ubuntu",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 775,
"end_line": 786,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Check if the OS is Ubuntu.\n\nReturns:\n (bool): True if OS is Ubuntu, False otherwise.",
"parameters": [],
"return_type": "bool",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_is_ubuntu_49c5d0a0"
},
{
"content": "def is_colab():\n \"\"\"\n Check if the current script is running inside a Google Colab notebook.\n\n Returns:\n (bool): True if running inside a Colab notebook, False otherwise.\n \"\"\"\n return \"COLAB_RELEASE_TAG\" in os.environ or \"COLAB_BACKEND_VERSION\" in os.environ",
"chunk_type": "function",
"name": "is_colab",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 789,
"end_line": 796,
"start_col": 0,
"end_col": 85,
"parent_name": null,
"docstring": "Check if the current script is running inside a Google Colab notebook.\n\nReturns:\n (bool): True if running inside a Colab notebook, False otherwise.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_is_colab_625ca7e5"
},
{
"content": "def is_kaggle():\n \"\"\"\n Check if the current script is running inside a Kaggle kernel.\n\n Returns:\n (bool): True if running inside a Kaggle kernel, False otherwise.\n \"\"\"\n return os.environ.get(\"PWD\") == \"/kaggle/working\" and os.environ.get(\"KAGGLE_URL_BASE\") == \"https://www.kaggle.com\"",
"chunk_type": "function",
"name": "is_kaggle",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 799,
"end_line": 806,
"start_col": 0,
"end_col": 119,
"parent_name": null,
"docstring": "Check if the current script is running inside a Kaggle kernel.\n\nReturns:\n (bool): True if running inside a Kaggle kernel, False otherwise.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_is_kaggle_57045760"
},
{
"content": "def is_jupyter():\n \"\"\"\n Check if the current script is running inside a Jupyter Notebook.\n\n Returns:\n (bool): True if running inside a Jupyter Notebook, False otherwise.\n\n Notes:\n - Only works on Colab and Kaggle, other environments like Jupyterlab and Paperspace are not reliably detectable.\n - \"get_ipython\" in globals() method suffers false positives when IPython package installed manually.\n \"\"\"\n return IS_COLAB or IS_KAGGLE",
"chunk_type": "function",
"name": "is_jupyter",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 809,
"end_line": 820,
"start_col": 0,
"end_col": 32,
"parent_name": null,
"docstring": "Check if the current script is running inside a Jupyter Notebook.\n\nReturns:\n (bool): True if running inside a Jupyter Notebook, False otherwise.\n\nNotes:\n - Only works on Colab and Kaggle, other environments like Jupyterlab and Paperspace are not reliably detectable.\n - \"get_ipython\" in globals() method suffers false positives when IPython package installed manually.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_is_jupyter_2cf61eed"
},
{
"content": "def is_runpod():\n \"\"\"\n Check if the current script is running inside a RunPod container.\n\n Returns:\n (bool): True if running in RunPod, False otherwise.\n \"\"\"\n return \"RUNPOD_POD_ID\" in os.environ",
"chunk_type": "function",
"name": "is_runpod",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 823,
"end_line": 830,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": "Check if the current script is running inside a RunPod container.\n\nReturns:\n (bool): True if running in RunPod, False otherwise.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_is_runpod_a92259b9"
},
{
"content": "def is_docker() -> bool:\n \"\"\"\n Determine if the script is running inside a Docker container.\n\n Returns:\n (bool): True if the script is running inside a Docker container, False otherwise.\n \"\"\"\n try:\n return os.path.exists(\"/.dockerenv\")\n except Exception:\n return False",
"chunk_type": "function",
"name": "is_docker",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 833,
"end_line": 843,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Determine if the script is running inside a Docker container.\n\nReturns:\n (bool): True if the script is running inside a Docker container, False otherwise.",
"parameters": [],
"return_type": "bool",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_is_docker_dcc1b9ae"
},
{
"content": "def is_raspberrypi() -> bool:\n \"\"\"\n Determine if the Python environment is running on a Raspberry Pi.\n\n Returns:\n (bool): True if running on a Raspberry Pi, False otherwise.\n \"\"\"\n return \"rpi\" in DEVICE_MODEL",
"chunk_type": "function",
"name": "is_raspberrypi",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 846,
"end_line": 853,
"start_col": 0,
"end_col": 32,
"parent_name": null,
"docstring": "Determine if the Python environment is running on a Raspberry Pi.\n\nReturns:\n (bool): True if running on a Raspberry Pi, False otherwise.",
"parameters": [],
"return_type": "bool",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_is_raspberrypi_2a530ff0"
},
{
"content": "def is_jetson() -> bool:\n \"\"\"\n Determine if the Python environment is running on an NVIDIA Jetson device.\n\n Returns:\n (bool): True if running on an NVIDIA Jetson device, False otherwise.\n \"\"\"\n return \"tegra\" in DEVICE_MODEL",
"chunk_type": "function",
"name": "is_jetson",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 856,
"end_line": 863,
"start_col": 0,
"end_col": 34,
"parent_name": null,
"docstring": "Determine if the Python environment is running on an NVIDIA Jetson device.\n\nReturns:\n (bool): True if running on an NVIDIA Jetson device, False otherwise.",
"parameters": [],
"return_type": "bool",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_is_jetson_28a84d17"
},
{
"content": "def is_online() -> bool:\n \"\"\"\n Check internet connectivity by attempting to connect to a known online host.\n\n Returns:\n (bool): True if connection is successful, False otherwise.\n \"\"\"\n try:\n assert str(os.getenv(\"YOLO_OFFLINE\", \"\")).lower() != \"true\" # check if ENV var YOLO_OFFLINE=\"True\"\n import socket\n\n for dns in (\"1.1.1.1\", \"8.8.8.8\"): # check Cloudflare and Google DNS\n socket.create_connection(address=(dns, 80), timeout=2.0).close()\n return True\n except Exception:\n return False",
"chunk_type": "function",
"name": "is_online",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 866,
"end_line": 881,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Check internet connectivity by attempting to connect to a known online host.\n\nReturns:\n (bool): True if connection is successful, False otherwise.",
"parameters": [],
"return_type": "bool",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_is_online_9e65e651"
},
{
"content": "def is_pip_package(filepath: str = __name__) -> bool:\n \"\"\"\n Determine if the file at the given filepath is part of a pip package.\n\n Args:\n filepath (str): The filepath to check.\n\n Returns:\n (bool): True if the file is part of a pip package, False otherwise.\n \"\"\"\n import importlib.util\n\n # Get the spec for the module\n spec = importlib.util.find_spec(filepath)\n\n # Return whether the spec is not None and the origin is not None (indicating it is a package)\n return spec is not None and spec.origin is not None",
"chunk_type": "function",
"name": "is_pip_package",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 884,
"end_line": 900,
"start_col": 0,
"end_col": 55,
"parent_name": null,
"docstring": "Determine if the file at the given filepath is part of a pip package.\n\nArgs:\n filepath (str): The filepath to check.\n\nReturns:\n (bool): True if the file is part of a pip package, False otherwise.",
"parameters": [
"filepath: str"
],
"return_type": "bool",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_is_pip_package_fc1f9870"
},
{
"content": "def is_dir_writeable(dir_path: Union[str, Path]) -> bool:\n \"\"\"\n Check if a directory is writeable.\n\n Args:\n dir_path (str | Path): The path to the directory.\n\n Returns:\n (bool): True if the directory is writeable, False otherwise.\n \"\"\"\n return os.access(str(dir_path), os.W_OK)",
"chunk_type": "function",
"name": "is_dir_writeable",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 903,
"end_line": 913,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": "Check if a directory is writeable.\n\nArgs:\n dir_path (str | Path): The path to the directory.\n\nReturns:\n (bool): True if the directory is writeable, False otherwise.",
"parameters": [
"dir_path: Union[str, Path]"
],
"return_type": "bool",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_is_dir_writeable_1802308c"
},
{
"content": "def is_pytest_running():\n \"\"\"\n Determine whether pytest is currently running or not.\n\n Returns:\n (bool): True if pytest is running, False otherwise.\n \"\"\"\n return (\"PYTEST_CURRENT_TEST\" in os.environ) or (\"pytest\" in sys.modules) or (\"pytest\" in Path(ARGV[0]).stem)",
"chunk_type": "function",
"name": "is_pytest_running",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 916,
"end_line": 923,
"start_col": 0,
"end_col": 113,
"parent_name": null,
"docstring": "Determine whether pytest is currently running or not.\n\nReturns:\n (bool): True if pytest is running, False otherwise.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_is_pytest_running_b6b77f74"
},
{
"content": "def is_github_action_running() -> bool:\n \"\"\"\n Determine if the current environment is a GitHub Actions runner.\n\n Returns:\n (bool): True if the current environment is a GitHub Actions runner, False otherwise.\n \"\"\"\n return \"GITHUB_ACTIONS\" in os.environ and \"GITHUB_WORKFLOW\" in os.environ and \"RUNNER_OS\" in os.environ",
"chunk_type": "function",
"name": "is_github_action_running",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 926,
"end_line": 933,
"start_col": 0,
"end_col": 107,
"parent_name": null,
"docstring": "Determine if the current environment is a GitHub Actions runner.\n\nReturns:\n (bool): True if the current environment is a GitHub Actions runner, False otherwise.",
"parameters": [],
"return_type": "bool",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_is_github_action_running_689e73f5"
},
{
"content": "def get_git_dir():\n \"\"\"\n Determine whether the current file is part of a git repository and if so, return the repository root directory.\n\n Returns:\n (Path | None): Git root directory if found or None if not found.\n \"\"\"\n for d in Path(__file__).parents:\n if (d / \".git\").is_dir():\n return d",
"chunk_type": "function",
"name": "get_git_dir",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 936,
"end_line": 945,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Determine whether the current file is part of a git repository and if so, return the repository root directory.\n\nReturns:\n (Path | None): Git root directory if found or None if not found.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 3,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_get_git_dir_e37bf49c"
},
{
"content": "def is_git_dir():\n \"\"\"\n Determine whether the current file is part of a git repository.\n\n Returns:\n (bool): True if current file is part of a git repository.\n \"\"\"\n return GIT_DIR is not None",
"chunk_type": "function",
"name": "is_git_dir",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 948,
"end_line": 955,
"start_col": 0,
"end_col": 30,
"parent_name": null,
"docstring": "Determine whether the current file is part of a git repository.\n\nReturns:\n (bool): True if current file is part of a git repository.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_is_git_dir_b10e45b8"
},
{
"content": "def get_git_origin_url():\n \"\"\"\n Retrieve the origin URL of a git repository.\n\n Returns:\n (str | None): The origin URL of the git repository or None if not git directory.\n \"\"\"\n if IS_GIT_DIR:\n try:\n origin = subprocess.check_output([\"git\", \"config\", \"--get\", \"remote.origin.url\"])\n return origin.decode().strip()\n except subprocess.CalledProcessError:\n return None",
"chunk_type": "function",
"name": "get_git_origin_url",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 958,
"end_line": 970,
"start_col": 0,
"end_col": 23,
"parent_name": null,
"docstring": "Retrieve the origin URL of a git repository.\n\nReturns:\n (str | None): The origin URL of the git repository or None if not git directory.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 3,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_get_git_origin_url_75508ee5"
},
{
"content": "def get_git_branch():\n \"\"\"\n Return the current git branch name. If not in a git repository, return None.\n\n Returns:\n (str | None): The current git branch name or None if not a git directory.\n \"\"\"\n if IS_GIT_DIR:\n try:\n origin = subprocess.check_output([\"git\", \"rev-parse\", \"--abbrev-ref\", \"HEAD\"])\n return origin.decode().strip()\n except subprocess.CalledProcessError:\n return None",
"chunk_type": "function",
"name": "get_git_branch",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 973,
"end_line": 985,
"start_col": 0,
"end_col": 23,
"parent_name": null,
"docstring": "Return the current git branch name. If not in a git repository, return None.\n\nReturns:\n (str | None): The current git branch name or None if not a git directory.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 3,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_get_git_branch_63f076d6"
},
{
"content": "def get_default_args(func):\n \"\"\"\n Return a dictionary of default arguments for a function.\n\n Args:\n func (callable): The function to inspect.\n\n Returns:\n (dict): A dictionary where each key is a parameter name, and each value is the default value of that parameter.\n \"\"\"\n signature = inspect.signature(func)\n return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}",
"chunk_type": "function",
"name": "get_default_args",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 988,
"end_line": 999,
"start_col": 0,
"end_col": 110,
"parent_name": null,
"docstring": "Return a dictionary of default arguments for a function.\n\nArgs:\n func (callable): The function to inspect.\n\nReturns:\n (dict): A dictionary where each key is a parameter name, and each value is the default value of that parameter.",
"parameters": [
"func"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_get_default_args_22819343"
},
{
"content": "def get_ubuntu_version():\n \"\"\"\n Retrieve the Ubuntu version if the OS is Ubuntu.\n\n Returns:\n (str): Ubuntu version or None if not an Ubuntu OS.\n \"\"\"\n if is_ubuntu():\n try:\n with open(\"/etc/os-release\") as f:\n return re.search(r'VERSION_ID=\"(\\d+\\.\\d+)\"', f.read())[1]\n except (FileNotFoundError, AttributeError):\n return None",
"chunk_type": "function",
"name": "get_ubuntu_version",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1002,
"end_line": 1014,
"start_col": 0,
"end_col": 23,
"parent_name": null,
"docstring": "Retrieve the Ubuntu version if the OS is Ubuntu.\n\nReturns:\n (str): Ubuntu version or None if not an Ubuntu OS.",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 3,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_get_ubuntu_version_11ddee4f"
},
{
"content": "def get_user_config_dir(sub_dir=\"Ultralytics\"):\n \"\"\"\n Return the appropriate config directory based on the environment operating system.\n\n Args:\n sub_dir (str): The name of the subdirectory to create.\n\n Returns:\n (Path): The path to the user config directory.\n \"\"\"\n if WINDOWS:\n path = Path.home() / \"AppData\" / \"Roaming\" / sub_dir\n elif MACOS: # macOS\n path = Path.home() / \"Library\" / \"Application Support\" / sub_dir\n elif LINUX:\n path = Path.home() / \".config\" / sub_dir\n else:\n raise ValueError(f\"Unsupported operating system: {platform.system()}\")\n\n # GCP and AWS lambda fix, only /tmp is writeable\n if not is_dir_writeable(path.parent):\n LOGGER.warning(\n f\"user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD.\"\n \"Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path.\"\n )\n path = Path(\"/tmp\") / sub_dir if is_dir_writeable(\"/tmp\") else Path().cwd() / sub_dir\n\n # Create the subdirectory if it does not exist\n path.mkdir(parents=True, exist_ok=True)\n\n return path",
"chunk_type": "function",
"name": "get_user_config_dir",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1017,
"end_line": 1047,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": "Return the appropriate config directory based on the environment operating system.\n\nArgs:\n sub_dir (str): The name of the subdirectory to create.\n\nReturns:\n (Path): The path to the user config directory.",
"parameters": [
"sub_dir"
],
"return_type": null,
"decorators": [],
"complexity_score": 5,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_get_user_config_dir_f324c308"
},
{
"content": "DEVICE_MODEL = read_device_model() # is_jetson() and is_raspberrypi() depend on this constant",
"chunk_type": "variable",
"name": "DEVICE_MODEL",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1051,
"end_line": 1051,
"start_col": 0,
"end_col": 34,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_DEVICE_MODEL_c1957c1a"
},
{
"content": "ONLINE = is_online()",
"chunk_type": "variable",
"name": "ONLINE",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1052,
"end_line": 1052,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_ONLINE_bb1eb81e"
},
{
"content": "IS_COLAB = is_colab()",
"chunk_type": "variable",
"name": "IS_COLAB",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1053,
"end_line": 1053,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_IS_COLAB_fc125f59"
},
{
"content": "IS_KAGGLE = is_kaggle()",
"chunk_type": "variable",
"name": "IS_KAGGLE",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1054,
"end_line": 1054,
"start_col": 0,
"end_col": 23,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_IS_KAGGLE_9bd55599"
},
{
"content": "IS_DOCKER = is_docker()",
"chunk_type": "variable",
"name": "IS_DOCKER",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1055,
"end_line": 1055,
"start_col": 0,
"end_col": 23,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_IS_DOCKER_2f339b0c"
},
{
"content": "IS_JETSON = is_jetson()",
"chunk_type": "variable",
"name": "IS_JETSON",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1056,
"end_line": 1056,
"start_col": 0,
"end_col": 23,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_IS_JETSON_5d927d6c"
},
{
"content": "IS_JUPYTER = is_jupyter()",
"chunk_type": "variable",
"name": "IS_JUPYTER",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1057,
"end_line": 1057,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_IS_JUPYTER_563ad653"
},
{
"content": "IS_PIP_PACKAGE = is_pip_package()",
"chunk_type": "variable",
"name": "IS_PIP_PACKAGE",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1058,
"end_line": 1058,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_IS_PIP_PACKAGE_4deb7236"
},
{
"content": "IS_RASPBERRYPI = is_raspberrypi()",
"chunk_type": "variable",
"name": "IS_RASPBERRYPI",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1059,
"end_line": 1059,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_IS_RASPBERRYPI_1c4fce70"
},
{
"content": "GIT_DIR = get_git_dir()",
"chunk_type": "variable",
"name": "GIT_DIR",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1060,
"end_line": 1060,
"start_col": 0,
"end_col": 23,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_GIT_DIR_041c7857"
},
{
"content": "IS_GIT_DIR = is_git_dir()",
"chunk_type": "variable",
"name": "IS_GIT_DIR",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1061,
"end_line": 1061,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_IS_GIT_DIR_53796298"
},
{
"content": "USER_CONFIG_DIR = Path(os.getenv(\"YOLO_CONFIG_DIR\") or get_user_config_dir()) # Ultralytics settings dir",
"chunk_type": "variable",
"name": "USER_CONFIG_DIR",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1062,
"end_line": 1062,
"start_col": 0,
"end_col": 77,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_USER_CONFIG_DIR_fe8daf7c"
},
{
"content": "SETTINGS_FILE = USER_CONFIG_DIR / \"settings.json\"",
"chunk_type": "variable",
"name": "SETTINGS_FILE",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1063,
"end_line": 1063,
"start_col": 0,
"end_col": 49,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_SETTINGS_FILE_0070b1d4"
},
{
"content": "def colorstr(*input):\n r\"\"\"\n Color a string based on the provided color and style arguments using ANSI escape codes.\n\n This function can be called in two ways:\n - colorstr('color', 'style', 'your string')\n - colorstr('your string')\n\n In the second form, 'blue' and 'bold' will be applied by default.\n\n Args:\n *input (str | Path): A sequence of strings where the first n-1 strings are color and style arguments,\n and the last string is the one to be colored.\n\n Returns:\n (str): The input string wrapped with ANSI escape codes for the specified color and style.\n\n Notes:\n Supported Colors and Styles:\n - Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'\n - Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow',\n 'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white'\n - Misc: 'end', 'bold', 'underline'\n\n Examples:\n >>> colorstr(\"blue\", \"bold\", \"hello world\")\n >>> \"\\033[34m\\033[1mhello world\\033[0m\"\n\n References:\n https://en.wikipedia.org/wiki/ANSI_escape_code\n \"\"\"\n *args, string = input if len(input) > 1 else (\"blue\", \"bold\", input[0]) # color arguments, string\n colors = {\n \"black\": \"\\033[30m\", # basic colors\n \"red\": \"\\033[31m\",\n \"green\": \"\\033[32m\",\n \"yellow\": \"\\033[33m\",\n \"blue\": \"\\033[34m\",\n \"magenta\": \"\\033[35m\",\n \"cyan\": \"\\033[36m\",\n \"white\": \"\\033[37m\",\n \"bright_black\": \"\\033[90m\", # bright colors\n \"bright_red\": \"\\033[91m\",\n \"bright_green\": \"\\033[92m\",\n \"bright_yellow\": \"\\033[93m\",\n \"bright_blue\": \"\\033[94m\",\n \"bright_magenta\": \"\\033[95m\",\n \"bright_cyan\": \"\\033[96m\",\n \"bright_white\": \"\\033[97m\",\n \"end\": \"\\033[0m\", # misc\n \"bold\": \"\\033[1m\",\n \"underline\": \"\\033[4m\",\n }\n return \"\".join(colors[x] for x in args) + f\"{string}\" + colors[\"end\"]",
"chunk_type": "function",
"name": "colorstr",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1066,
"end_line": 1119,
"start_col": 0,
"end_col": 73,
"parent_name": null,
"docstring": "Color a string based on the provided color and style arguments using ANSI escape codes.\n\nThis function can be called in two ways:\n - colorstr('color', 'style', 'your string')\n - colorstr('your string')\n\nIn the second form, 'blue' and 'bold' will be applied by default.\n\nArgs:\n *input (str | Path): A sequence of strings where the first n-1 strings are color and style arguments,\n and the last string is the one to be colored.\n\nReturns:\n (str): The input string wrapped with ANSI escape codes for the specified color and style.\n\nNotes:\n Supported Colors and Styles:\n - Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'\n - Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow',\n 'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white'\n - Misc: 'end', 'bold', 'underline'\n\nExamples:\n >>> colorstr(\"blue\", \"bold\", \"hello world\")\n >>> \"\\033[34m\\033[1mhello world\\033[0m\"\n\nReferences:\n https://en.wikipedia.org/wiki/ANSI_escape_code",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_colorstr_64f5b536"
},
{
"content": "def remove_colorstr(input_string):\n \"\"\"\n Remove ANSI escape codes from a string, effectively un-coloring it.\n\n Args:\n input_string (str): The string to remove color and style from.\n\n Returns:\n (str): A new string with all ANSI escape codes removed.\n\n Examples:\n >>> remove_colorstr(colorstr(\"blue\", \"bold\", \"hello world\"))\n >>> \"hello world\"\n \"\"\"\n ansi_escape = re.compile(r\"\\x1B\\[[0-9;]*[A-Za-z]\")\n return ansi_escape.sub(\"\", input_string)",
"chunk_type": "function",
"name": "remove_colorstr",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1122,
"end_line": 1137,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": "Remove ANSI escape codes from a string, effectively un-coloring it.\n\nArgs:\n input_string (str): The string to remove color and style from.\n\nReturns:\n (str): A new string with all ANSI escape codes removed.\n\nExamples:\n >>> remove_colorstr(colorstr(\"blue\", \"bold\", \"hello world\"))\n >>> \"hello world\"",
"parameters": [
"input_string"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_remove_colorstr_c31acf20"
},
{
"content": "class TryExcept(contextlib.ContextDecorator):\n \"\"\"\n Ultralytics TryExcept class for handling exceptions gracefully.\n\n This class can be used as a decorator or context manager to catch exceptions and optionally print warning messages.\n It allows code to continue execution even when exceptions occur, which is useful for non-critical operations.\n\n Attributes:\n msg (str): Optional message to display when an exception occurs.\n verbose (bool): Whether to print the exception message.\n\n Examples:\n As a decorator:\n >>> @TryExcept(msg=\"Error occurred in func\", verbose=True)\n >>> def func():\n >>> # Function logic here\n >>> pass\n\n As a context manager:\n >>> with TryExcept(msg=\"Error occurred in block\", verbose=True):\n >>> # Code block here\n >>> pass\n \"\"\"\n\n def __init__(self, msg=\"\", verbose=True):\n \"\"\"Initialize TryExcept class with optional message and verbosity settings.\"\"\"\n self.msg = msg\n self.verbose = verbose\n\n def __enter__(self):\n \"\"\"Execute when entering TryExcept context, initialize instance.\"\"\"\n pass\n\n def __exit__(self, exc_type, value, traceback):\n \"\"\"Define behavior when exiting a 'with' block, print error message if necessary.\"\"\"\n if self.verbose and value:\n LOGGER.warning(f\"{self.msg}{': ' if self.msg else ''}{value}\")\n return True",
"chunk_type": "class",
"name": "TryExcept",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1140,
"end_line": 1177,
"start_col": 0,
"end_col": 19,
"parent_name": null,
"docstring": "Ultralytics TryExcept class for handling exceptions gracefully.\n\nThis class can be used as a decorator or context manager to catch exceptions and optionally print warning messages.\nIt allows code to continue execution even when exceptions occur, which is useful for non-critical operations.\n\nAttributes:\n msg (str): Optional message to display when an exception occurs.\n verbose (bool): Whether to print the exception message.\n\nExamples:\n As a decorator:\n >>> @TryExcept(msg=\"Error occurred in func\", verbose=True)\n >>> def func():\n >>> # Function logic here\n >>> pass\n\n As a context manager:\n >>> with TryExcept(msg=\"Error occurred in block\", verbose=True):\n >>> # Code block here\n >>> pass",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io",
"contextlib.ContextDecorator"
],
"chunk_id": "class_TryExcept_fe08fe13"
},
{
"content": "class Retry(contextlib.ContextDecorator):\n \"\"\"\n Retry class for function execution with exponential backoff.\n\n This decorator can be used to retry a function on exceptions, up to a specified number of times with an\n exponentially increasing delay between retries. It's useful for handling transient failures in network\n operations or other unreliable processes.\n\n Attributes:\n times (int): Maximum number of retry attempts.\n delay (int): Initial delay between retries in seconds.\n\n Examples:\n Example usage as a decorator:\n >>> @Retry(times=3, delay=2)\n >>> def test_func():\n >>> # Replace with function logic that may raise exceptions\n >>> return True\n \"\"\"\n\n def __init__(self, times=3, delay=2):\n \"\"\"Initialize Retry class with specified number of retries and delay.\"\"\"\n self.times = times\n self.delay = delay\n self._attempts = 0\n\n def __call__(self, func):\n \"\"\"Decorator implementation for Retry with exponential backoff.\"\"\"\n\n def wrapped_func(*args, **kwargs):\n \"\"\"Apply retries to the decorated function or method.\"\"\"\n self._attempts = 0\n while self._attempts < self.times:\n try:\n return func(*args, **kwargs)\n except Exception as e:\n self._attempts += 1\n LOGGER.warning(f\"Retry {self._attempts}/{self.times} failed: {e}\")\n if self._attempts >= self.times:\n raise e\n time.sleep(self.delay * (2**self._attempts)) # exponential backoff delay\n\n return wrapped_func",
"chunk_type": "class",
"name": "Retry",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1180,
"end_line": 1222,
"start_col": 0,
"end_col": 27,
"parent_name": null,
"docstring": "Retry class for function execution with exponential backoff.\n\nThis decorator can be used to retry a function on exceptions, up to a specified number of times with an\nexponentially increasing delay between retries. It's useful for handling transient failures in network\noperations or other unreliable processes.\n\nAttributes:\n times (int): Maximum number of retry attempts.\n delay (int): Initial delay between retries in seconds.\n\nExamples:\n Example usage as a decorator:\n >>> @Retry(times=3, delay=2)\n >>> def test_func():\n >>> # Replace with function logic that may raise exceptions\n >>> return True",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io",
"contextlib.ContextDecorator"
],
"chunk_id": "class_Retry_abfaf7c4"
},
{
"content": "def threaded(func):\n \"\"\"\n Multi-thread a target function by default and return the thread or function result.\n\n This decorator provides flexible execution of the target function, either in a separate thread or synchronously.\n By default, the function runs in a thread, but this can be controlled via the 'threaded=False' keyword argument\n which is removed from kwargs before calling the function.\n\n Args:\n func (callable): The function to be potentially executed in a separate thread.\n\n Returns:\n (callable): A wrapper function that either returns a daemon thread or the direct function result.\n\n Examples:\n >>> @threaded\n ... def process_data(data):\n ... return data\n >>>\n >>> thread = process_data(my_data) # Runs in background thread\n >>> result = process_data(my_data, threaded=False) # Runs synchronously, returns function result\n \"\"\"\n\n def wrapper(*args, **kwargs):\n \"\"\"Multi-thread a given function based on 'threaded' kwarg and return the thread or function result.\"\"\"\n if kwargs.pop(\"threaded\", True): # run in thread\n thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)\n thread.start()\n return thread\n else:\n return func(*args, **kwargs)\n\n return wrapper",
"chunk_type": "function",
"name": "threaded",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1225,
"end_line": 1257,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": "Multi-thread a target function by default and return the thread or function result.\n\nThis decorator provides flexible execution of the target function, either in a separate thread or synchronously.\nBy default, the function runs in a thread, but this can be controlled via the 'threaded=False' keyword argument\nwhich is removed from kwargs before calling the function.\n\nArgs:\n func (callable): The function to be potentially executed in a separate thread.\n\nReturns:\n (callable): A wrapper function that either returns a daemon thread or the direct function result.\n\nExamples:\n >>> @threaded\n ... def process_data(data):\n ... return data\n >>>\n >>> thread = process_data(my_data) # Runs in background thread\n >>> result = process_data(my_data, threaded=False) # Runs synchronously, returns function result",
"parameters": [
"func"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_threaded_82043875"
},
{
"content": "def set_sentry():\n \"\"\"\n Initialize the Sentry SDK for error tracking and reporting.\n\n Only used if sentry_sdk package is installed and sync=True in settings. Run 'yolo settings' to see and update\n settings.\n\n Conditions required to send errors (ALL conditions must be met or no errors will be reported):\n - sentry_sdk package is installed\n - sync=True in YOLO settings\n - pytest is not running\n - running in a pip package installation\n - running in a non-git directory\n - running with rank -1 or 0\n - online environment\n - CLI used to run package (checked with 'yolo' as the name of the main CLI command)\n \"\"\"\n if (\n not SETTINGS[\"sync\"]\n or RANK not in {-1, 0}\n or Path(ARGV[0]).name != \"yolo\"\n or TESTS_RUNNING\n or not ONLINE\n or not IS_PIP_PACKAGE\n or IS_GIT_DIR\n ):\n return\n # If sentry_sdk package is not installed then return and do not use Sentry\n try:\n import sentry_sdk # noqa\n except ImportError:\n return\n\n def before_send(event, hint):\n \"\"\"\n Modify the event before sending it to Sentry based on specific exception types and messages.\n\n Args:\n event (dict): The event dictionary containing information about the error.\n hint (dict): A dictionary containing additional information about the error.\n\n Returns:\n (dict | None): The modified event or None if the event should not be sent to Sentry.\n \"\"\"\n if \"exc_info\" in hint:\n exc_type, exc_value, _ = hint[\"exc_info\"]\n if exc_type in {KeyboardInterrupt, FileNotFoundError} or \"out of memory\" in str(exc_value):\n return None # do not send event\n\n event[\"tags\"] = {\n \"sys_argv\": ARGV[0],\n \"sys_argv_name\": Path(ARGV[0]).name,\n \"install\": \"git\" if IS_GIT_DIR else \"pip\" if IS_PIP_PACKAGE else \"other\",\n \"os\": ENVIRONMENT,\n }\n return event\n\n sentry_sdk.init(\n dsn=\"https://888e5a0778212e1d0314c37d4b9aae5d@o4504521589325824.ingest.us.sentry.io/4504521592406016\",\n debug=False,\n auto_enabling_integrations=False,\n traces_sample_rate=1.0,\n release=__version__,\n environment=\"runpod\" if is_runpod() else \"production\",\n before_send=before_send,\n ignore_errors=[KeyboardInterrupt, FileNotFoundError],\n )\n sentry_sdk.set_user({\"id\": SETTINGS[\"uuid\"]}) # SHA-256 anonymized UUID hash",
"chunk_type": "function",
"name": "set_sentry",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1260,
"end_line": 1327,
"start_col": 0,
"end_col": 49,
"parent_name": null,
"docstring": "Initialize the Sentry SDK for error tracking and reporting.\n\nOnly used if sentry_sdk package is installed and sync=True in settings. Run 'yolo settings' to see and update\nsettings.\n\nConditions required to send errors (ALL conditions must be met or no errors will be reported):\n - sentry_sdk package is installed\n - sync=True in YOLO settings\n - pytest is not running\n - running in a pip package installation\n - running in a non-git directory\n - running with rank -1 or 0\n - online environment\n - CLI used to run package (checked with 'yolo' as the name of the main CLI command)",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 5,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_set_sentry_9d4803f9"
},
{
"content": "class JSONDict(dict):\n \"\"\"\n A dictionary-like class that provides JSON persistence for its contents.\n\n This class extends the built-in dictionary to automatically save its contents to a JSON file whenever they are\n modified. It ensures thread-safe operations using a lock and handles JSON serialization of Path objects.\n\n Attributes:\n file_path (Path): The path to the JSON file used for persistence.\n lock (threading.Lock): A lock object to ensure thread-safe operations.\n\n Methods:\n _load: Load the data from the JSON file into the dictionary.\n _save: Save the current state of the dictionary to the JSON file.\n __setitem__: Store a key-value pair and persist it to disk.\n __delitem__: Remove an item and update the persistent storage.\n update: Update the dictionary and persist changes.\n clear: Clear all entries and update the persistent storage.\n\n Examples:\n >>> json_dict = JSONDict(\"data.json\")\n >>> json_dict[\"key\"] = \"value\"\n >>> print(json_dict[\"key\"])\n value\n >>> del json_dict[\"key\"]\n >>> json_dict.update({\"new_key\": \"new_value\"})\n >>> json_dict.clear()\n \"\"\"\n\n def __init__(self, file_path: Union[str, Path] = \"data.json\"):\n \"\"\"Initialize a JSONDict object with a specified file path for JSON persistence.\"\"\"\n super().__init__()\n self.file_path = Path(file_path)\n self.lock = Lock()\n self._load()\n\n def _load(self):\n \"\"\"Load the data from the JSON file into the dictionary.\"\"\"\n try:\n if self.file_path.exists():\n with open(self.file_path) as f:\n self.update(json.load(f))\n except json.JSONDecodeError:\n LOGGER.warning(f\"Error decoding JSON from {self.file_path}. Starting with an empty dictionary.\")\n except Exception as e:\n LOGGER.error(f\"Error reading from {self.file_path}: {e}\")\n\n def _save(self):\n \"\"\"Save the current state of the dictionary to the JSON file.\"\"\"\n try:\n self.file_path.parent.mkdir(parents=True, exist_ok=True)\n with open(self.file_path, \"w\", encoding=\"utf-8\") as f:\n json.dump(dict(self), f, indent=2, default=self._json_default)\n except Exception as e:\n LOGGER.error(f\"Error writing to {self.file_path}: {e}\")\n\n @staticmethod\n def _json_default(obj):\n \"\"\"Handle JSON serialization of Path objects.\"\"\"\n if isinstance(obj, Path):\n return str(obj)\n raise TypeError(f\"Object of type {type(obj).__name__} is not JSON serializable\")\n\n def __setitem__(self, key, value):\n \"\"\"Store a key-value pair and persist to disk.\"\"\"\n with self.lock:\n super().__setitem__(key, value)\n self._save()\n\n def __delitem__(self, key):\n \"\"\"Remove an item and update the persistent storage.\"\"\"\n with self.lock:\n super().__delitem__(key)\n self._save()\n\n def __str__(self):\n \"\"\"Return a pretty-printed JSON string representation of the dictionary.\"\"\"\n contents = json.dumps(dict(self), indent=2, ensure_ascii=False, default=self._json_default)\n return f'JSONDict(\"{self.file_path}\"):\\n{contents}'\n\n def update(self, *args, **kwargs):\n \"\"\"Update the dictionary and persist changes.\"\"\"\n with self.lock:\n super().update(*args, **kwargs)\n self._save()\n\n def clear(self):\n \"\"\"Clear all entries and update the persistent storage.\"\"\"\n with self.lock:\n super().clear()\n self._save()",
"chunk_type": "class",
"name": "JSONDict",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1330,
"end_line": 1420,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": "A dictionary-like class that provides JSON persistence for its contents.\n\nThis class extends the built-in dictionary to automatically save its contents to a JSON file whenever they are\nmodified. It ensures thread-safe operations using a lock and handles JSON serialization of Path objects.\n\nAttributes:\n file_path (Path): The path to the JSON file used for persistence.\n lock (threading.Lock): A lock object to ensure thread-safe operations.\n\nMethods:\n _load: Load the data from the JSON file into the dictionary.\n _save: Save the current state of the dictionary to the JSON file.\n __setitem__: Store a key-value pair and persist it to disk.\n __delitem__: Remove an item and update the persistent storage.\n update: Update the dictionary and persist changes.\n clear: Clear all entries and update the persistent storage.\n\nExamples:\n >>> json_dict = JSONDict(\"data.json\")\n >>> json_dict[\"key\"] = \"value\"\n >>> print(json_dict[\"key\"])\n value\n >>> del json_dict[\"key\"]\n >>> json_dict.update({\"new_key\": \"new_value\"})\n >>> json_dict.clear()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io",
"dict"
],
"chunk_id": "class_JSONDict_c1d7c6d5"
},
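The JSONDict chunk above documents a dict subclass whose every mutation is written back to a JSON file. A minimal usage sketch based only on the documented behaviour (the file name below is illustrative, and it assumes the ultralytics package is importable):

```python
from pathlib import Path

from ultralytics.utils import JSONDict  # defined in ultralytics/utils/__init__.py per the chunk above

cache = JSONDict("demo_cache.json")             # existing contents, if any, are loaded on init
cache["last_model"] = "model.pt"                # __setitem__ persists the change immediately
cache.update({"runs": 3, "home": Path.home()})  # Path values are serialized via _json_default
print(cache)                                    # __str__ pretty-prints the JSON contents
del cache["runs"]                               # __delitem__ rewrites the file as well
cache.clear()                                   # empties both the dict and the file
```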
{
"content": "class SettingsManager(JSONDict):\n \"\"\"\n SettingsManager class for managing and persisting Ultralytics settings.\n\n This class extends JSONDict to provide JSON persistence for settings, ensuring thread-safe operations and default\n values. It validates settings on initialization and provides methods to update or reset settings. The settings\n include directories for datasets, weights, and runs, as well as various integration flags.\n\n Attributes:\n file (Path): The path to the JSON file used for persistence.\n version (str): The version of the settings schema.\n defaults (dict): A dictionary containing default settings.\n help_msg (str): A help message for users on how to view and update settings.\n\n Methods:\n _validate_settings: Validate the current settings and reset if necessary.\n update: Update settings, validating keys and types.\n reset: Reset the settings to default and save them.\n\n Examples:\n Initialize and update settings:\n >>> settings = SettingsManager()\n >>> settings.update(runs_dir=\"/new/runs/dir\")\n >>> print(settings[\"runs_dir\"])\n /new/runs/dir\n \"\"\"\n\n def __init__(self, file=SETTINGS_FILE, version=\"0.0.6\"):\n \"\"\"Initialize the SettingsManager with default settings and load user settings.\"\"\"\n import hashlib\n import uuid\n\n from ultralytics.utils.torch_utils import torch_distributed_zero_first\n\n root = GIT_DIR or Path()\n datasets_root = (root.parent if GIT_DIR and is_dir_writeable(root.parent) else root).resolve()\n\n self.file = Path(file)\n self.version = version\n self.defaults = {\n \"settings_version\": version, # Settings schema version\n \"datasets_dir\": str(datasets_root / \"datasets\"), # Datasets directory\n \"weights_dir\": str(root / \"weights\"), # Model weights directory\n \"runs_dir\": str(root / \"runs\"), # Experiment runs directory\n \"uuid\": hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(), # SHA-256 anonymized UUID hash\n \"sync\": True, # Enable synchronization\n \"api_key\": \"\", # Ultralytics API Key\n \"openai_api_key\": \"\", # OpenAI API Key\n \"clearml\": True, # ClearML integration\n \"comet\": True, # Comet integration\n \"dvc\": True, # DVC integration\n \"hub\": True, # Ultralytics HUB integration\n \"mlflow\": True, # MLflow integration\n \"neptune\": True, # Neptune integration\n \"raytune\": True, # Ray Tune integration\n \"tensorboard\": False, # TensorBoard logging\n \"wandb\": False, # Weights & Biases logging\n \"vscode_msg\": True, # VSCode message\n \"openvino_msg\": True, # OpenVINO export on Intel CPU message\n }\n\n self.help_msg = (\n f\"\\nView Ultralytics Settings with 'yolo settings' or at '{self.file}'\"\n \"\\nUpdate Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. 
\"\n \"For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.\"\n )\n\n with torch_distributed_zero_first(LOCAL_RANK):\n super().__init__(self.file)\n\n if not self.file.exists() or not self: # Check if file doesn't exist or is empty\n LOGGER.info(f\"Creating new Ultralytics Settings v{version} file ✅ {self.help_msg}\")\n self.reset()\n\n self._validate_settings()\n\n def _validate_settings(self):\n \"\"\"Validate the current settings and reset if necessary.\"\"\"\n correct_keys = frozenset(self.keys()) == frozenset(self.defaults.keys())\n correct_types = all(isinstance(self.get(k), type(v)) for k, v in self.defaults.items())\n correct_version = self.get(\"settings_version\", \"\") == self.version\n\n if not (correct_keys and correct_types and correct_version):\n LOGGER.warning(\n \"Ultralytics settings reset to default values. This may be due to a possible problem \"\n f\"with your settings or a recent ultralytics package update. {self.help_msg}\"\n )\n self.reset()\n\n if self.get(\"datasets_dir\") == self.get(\"runs_dir\"):\n LOGGER.warning(\n f\"Ultralytics setting 'datasets_dir: {self.get('datasets_dir')}' \"\n f\"must be different than 'runs_dir: {self.get('runs_dir')}'. \"\n f\"Please change one to avoid possible issues during training. {self.help_msg}\"\n )\n\n def __setitem__(self, key, value):\n \"\"\"Update one key: value pair.\"\"\"\n self.update({key: value})\n\n def update(self, *args, **kwargs):\n \"\"\"Update settings, validating keys and types.\"\"\"\n for arg in args:\n if isinstance(arg, dict):\n kwargs.update(arg)\n for k, v in kwargs.items():\n if k not in self.defaults:\n raise KeyError(f\"No Ultralytics setting '{k}'. {self.help_msg}\")\n t = type(self.defaults[k])\n if not isinstance(v, t):\n raise TypeError(\n f\"Ultralytics setting '{k}' must be '{t.__name__}' type, not '{type(v).__name__}'. {self.help_msg}\"\n )\n super().update(*args, **kwargs)\n\n def reset(self):\n \"\"\"Reset the settings to default and save them.\"\"\"\n self.clear()\n self.update(self.defaults)",
"chunk_type": "class",
"name": "SettingsManager",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1423,
"end_line": 1541,
"start_col": 0,
"end_col": 34,
"parent_name": null,
"docstring": "SettingsManager class for managing and persisting Ultralytics settings.\n\nThis class extends JSONDict to provide JSON persistence for settings, ensuring thread-safe operations and default\nvalues. It validates settings on initialization and provides methods to update or reset settings. The settings\ninclude directories for datasets, weights, and runs, as well as various integration flags.\n\nAttributes:\n file (Path): The path to the JSON file used for persistence.\n version (str): The version of the settings schema.\n defaults (dict): A dictionary containing default settings.\n help_msg (str): A help message for users on how to view and update settings.\n\nMethods:\n _validate_settings: Validate the current settings and reset if necessary.\n update: Update settings, validating keys and types.\n reset: Reset the settings to default and save them.\n\nExamples:\n Initialize and update settings:\n >>> settings = SettingsManager()\n >>> settings.update(runs_dir=\"/new/runs/dir\")\n >>> print(settings[\"runs_dir\"])\n /new/runs/dir",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io",
"JSONDict"
],
"chunk_id": "class_SettingsManager_4b018472"
},
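A short sketch of the documented SettingsManager workflow, using the module-level SETTINGS instance created later in the same file; keys and value types are validated exactly as the update method above describes:

```python
from ultralytics.utils import SETTINGS  # module-level SettingsManager instance

print(SETTINGS["datasets_dir"])          # read a setting
SETTINGS.update(sync=False)              # persists and type-checks against the defaults
try:
    SETTINGS.update(not_a_setting=True)  # unknown keys raise KeyError
except KeyError as e:
    print(e)
try:
    SETTINGS.update(sync="yes")          # wrong value types raise TypeError
except TypeError as e:
    print(e)
# SETTINGS.reset()                       # restores all defaults and saves them (uncomment to apply)
```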
{
"content": "def deprecation_warn(arg, new_arg=None):\n \"\"\"Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument.\"\"\"\n msg = f\"'{arg}' is deprecated and will be removed in in the future.\"\n if new_arg is not None:\n msg += f\" Use '{new_arg}' instead.\"\n LOGGER.warning(msg)",
"chunk_type": "function",
"name": "deprecation_warn",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1544,
"end_line": 1549,
"start_col": 0,
"end_col": 23,
"parent_name": null,
"docstring": "Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument.",
"parameters": [
"arg",
"new_arg"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_deprecation_warn_6d27243b"
},
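A hedged sketch of how deprecation_warn is intended to be used when a caller passes a legacy argument name; the img_size/imgsz pairing here is only an illustrative, hypothetical example:

```python
from ultralytics.utils import deprecation_warn

def train(imgsz=640, **kwargs):
    if "img_size" in kwargs:  # hypothetical legacy keyword
        deprecation_warn("img_size", new_arg="imgsz")  # logs a warning suggesting the new name
        imgsz = kwargs.pop("img_size")
    return imgsz

print(train(img_size=512))  # warns once, then behaves like train(imgsz=512)
```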
{
"content": "def clean_url(url):\n \"\"\"Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt.\"\"\"\n url = Path(url).as_posix().replace(\":/\", \"://\") # Pathlib turns :// -> :/, as_posix() for Windows\n return unquote(url).split(\"?\", 1)[0] # '%2F' to '/', split https://url.com/file.txt?auth",
"chunk_type": "function",
"name": "clean_url",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1552,
"end_line": 1555,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": "Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt.",
"parameters": [
"url"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_clean_url_c74078e5"
},
{
"content": "def url2file(url):\n \"\"\"Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt.\"\"\"\n return Path(clean_url(url)).name",
"chunk_type": "function",
"name": "url2file",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1558,
"end_line": 1560,
"start_col": 0,
"end_col": 36,
"parent_name": null,
"docstring": "Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt.",
"parameters": [
"url"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_url2file_fafbdcda"
},
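The expected behaviour of clean_url and url2file, taken directly from the docstrings above:

```python
from ultralytics.utils import clean_url, url2file

url = "https://url.com/file.txt?auth=token%20123"
print(clean_url(url))  # https://url.com/file.txt  (auth query stripped, percent-encoding unquoted)
print(url2file(url))   # file.txt
```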
{
"content": "def vscode_msg(ext=\"ultralytics.ultralytics-snippets\") -> str:\n \"\"\"Display a message to install Ultralytics-Snippets for VS Code if not already installed.\"\"\"\n path = (USER_CONFIG_DIR.parents[2] if WINDOWS else USER_CONFIG_DIR.parents[1]) / \".vscode/extensions\"\n obs_file = path / \".obsolete\" # file tracks uninstalled extensions, while source directory remains\n installed = any(path.glob(f\"{ext}*\")) and ext not in (obs_file.read_text(\"utf-8\") if obs_file.exists() else \"\")\n url = \"https://docs.ultralytics.com/integrations/vscode\"\n return \"\" if installed else f\"{colorstr('VS Code:')} view Ultralytics VS Code Extension ⚡ at {url}\"",
"chunk_type": "function",
"name": "vscode_msg",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1563,
"end_line": 1569,
"start_col": 0,
"end_col": 105,
"parent_name": null,
"docstring": "Display a message to install Ultralytics-Snippets for VS Code if not already installed.",
"parameters": [
"ext"
],
"return_type": "str",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"contextlib",
"importlib.metadata",
"inspect",
"json",
"logging",
"os",
"platform",
"re",
"subprocess",
"sys",
"threading",
"time",
"warnings",
"pathlib.Path",
"threading.Lock",
"types.SimpleNamespace",
"typing.Union",
"urllib.parse.unquote",
"cv2",
"numpy",
"torch",
"tqdm",
"ultralytics.__version__",
"ultralytics.utils.patches.imread",
"ultralytics.utils.patches.imshow",
"ultralytics.utils.patches.imwrite",
"ultralytics.utils.patches.torch_save",
"tqdm.rich",
"importlib.util",
"pandas",
"sqlite3",
"functools.wraps",
"yaml",
"socket",
"sentry_sdk",
"hashlib",
"uuid",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"matplotlib.pyplot",
"io"
],
"chunk_id": "function_vscode_msg_a4abd200"
},
{
"content": "PREFIX = colorstr(\"Ultralytics: \")",
"chunk_type": "variable",
"name": "PREFIX",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1575,
"end_line": 1575,
"start_col": 0,
"end_col": 34,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_PREFIX_edda2740"
},
{
"content": "SETTINGS = SettingsManager() # initialize settings",
"chunk_type": "variable",
"name": "SETTINGS",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1576,
"end_line": 1576,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_SETTINGS_50e0f203"
},
{
"content": "PERSISTENT_CACHE = JSONDict(USER_CONFIG_DIR / \"persistent_cache.json\") # initialize persistent cache",
"chunk_type": "variable",
"name": "PERSISTENT_CACHE",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1577,
"end_line": 1577,
"start_col": 0,
"end_col": 70,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_PERSISTENT_CACHE_48cfe187"
},
{
"content": "DATASETS_DIR = Path(SETTINGS[\"datasets_dir\"]) # global datasets directory",
"chunk_type": "variable",
"name": "DATASETS_DIR",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1578,
"end_line": 1578,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_DATASETS_DIR_c43da5f2"
},
{
"content": "WEIGHTS_DIR = Path(SETTINGS[\"weights_dir\"]) # global weights directory",
"chunk_type": "variable",
"name": "WEIGHTS_DIR",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1579,
"end_line": 1579,
"start_col": 0,
"end_col": 43,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_WEIGHTS_DIR_f0a0ecc2"
},
{
"content": "RUNS_DIR = Path(SETTINGS[\"runs_dir\"]) # global runs directory",
"chunk_type": "variable",
"name": "RUNS_DIR",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1580,
"end_line": 1580,
"start_col": 0,
"end_col": 37,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_RUNS_DIR_fe066ddb"
},
{
"content": "ENVIRONMENT = (\n \"Colab\"\n if IS_COLAB\n else \"Kaggle\"\n if IS_KAGGLE\n else \"Jupyter\"\n if IS_JUPYTER\n else \"Docker\"\n if IS_DOCKER\n else platform.system()\n)",
"chunk_type": "variable",
"name": "ENVIRONMENT",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1581,
"end_line": 1591,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_ENVIRONMENT_ff00e629"
},
{
"content": "TESTS_RUNNING = is_pytest_running() or is_github_action_running()",
"chunk_type": "variable",
"name": "TESTS_RUNNING",
"file_path": "ultralytics\\ultralytics\\utils\\__init__.py",
"start_line": 1592,
"end_line": 1592,
"start_col": 0,
"end_col": 65,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_TESTS_RUNNING_2f8be4b0"
},
{
"content": "import concurrent.futures",
"chunk_type": "import",
"name": "concurrent.futures",
"file_path": "ultralytics\\ultralytics\\hub\\google\\__init__.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_concurrent.futures_b27c7ccd"
},
{
"content": "import statistics",
"chunk_type": "import",
"name": "statistics",
"file_path": "ultralytics\\ultralytics\\hub\\google\\__init__.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 17,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_statistics_2b7efb0d"
},
{
"content": "import time",
"chunk_type": "import",
"name": "time",
"file_path": "ultralytics\\ultralytics\\hub\\google\\__init__.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_time_cacbb5ec"
},
{
"content": "from typing import List, Optional, Tuple",
"chunk_type": "import",
"name": "List, Optional, Tuple",
"file_path": "ultralytics\\ultralytics\\hub\\google\\__init__.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_List, Optional, Tuple_ceea995b"
},
{
"content": "import requests",
"chunk_type": "import",
"name": "requests",
"file_path": "ultralytics\\ultralytics\\hub\\google\\__init__.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_requests_c171e7cd"
},
{
"content": "class GCPRegions:\n \"\"\"\n A class for managing and analyzing Google Cloud Platform (GCP) regions.\n\n This class provides functionality to initialize, categorize, and analyze GCP regions based on their\n geographical location, tier classification, and network latency.\n\n Attributes:\n regions (Dict[str, Tuple[int, str, str]]): A dictionary of GCP regions with their tier, city, and country.\n\n Methods:\n tier1: Returns a list of tier 1 GCP regions.\n tier2: Returns a list of tier 2 GCP regions.\n lowest_latency: Determines the GCP region(s) with the lowest network latency.\n\n Examples:\n >>> from ultralytics.hub.google import GCPRegions\n >>> regions = GCPRegions()\n >>> lowest_latency_region = regions.lowest_latency(verbose=True, attempts=3)\n >>> print(f\"Lowest latency region: {lowest_latency_region[0][0]}\")\n \"\"\"\n\n def __init__(self):\n \"\"\"Initialize the GCPRegions class with predefined Google Cloud Platform regions and their details.\"\"\"\n self.regions = {\n \"asia-east1\": (1, \"Taiwan\", \"China\"),\n \"asia-east2\": (2, \"Hong Kong\", \"China\"),\n \"asia-northeast1\": (1, \"Tokyo\", \"Japan\"),\n \"asia-northeast2\": (1, \"Osaka\", \"Japan\"),\n \"asia-northeast3\": (2, \"Seoul\", \"South Korea\"),\n \"asia-south1\": (2, \"Mumbai\", \"India\"),\n \"asia-south2\": (2, \"Delhi\", \"India\"),\n \"asia-southeast1\": (2, \"Jurong West\", \"Singapore\"),\n \"asia-southeast2\": (2, \"Jakarta\", \"Indonesia\"),\n \"australia-southeast1\": (2, \"Sydney\", \"Australia\"),\n \"australia-southeast2\": (2, \"Melbourne\", \"Australia\"),\n \"europe-central2\": (2, \"Warsaw\", \"Poland\"),\n \"europe-north1\": (1, \"Hamina\", \"Finland\"),\n \"europe-southwest1\": (1, \"Madrid\", \"Spain\"),\n \"europe-west1\": (1, \"St. Ghislain\", \"Belgium\"),\n \"europe-west10\": (2, \"Berlin\", \"Germany\"),\n \"europe-west12\": (2, \"Turin\", \"Italy\"),\n \"europe-west2\": (2, \"London\", \"United Kingdom\"),\n \"europe-west3\": (2, \"Frankfurt\", \"Germany\"),\n \"europe-west4\": (1, \"Eemshaven\", \"Netherlands\"),\n \"europe-west6\": (2, \"Zurich\", \"Switzerland\"),\n \"europe-west8\": (1, \"Milan\", \"Italy\"),\n \"europe-west9\": (1, \"Paris\", \"France\"),\n \"me-central1\": (2, \"Doha\", \"Qatar\"),\n \"me-west1\": (1, \"Tel Aviv\", \"Israel\"),\n \"northamerica-northeast1\": (2, \"Montreal\", \"Canada\"),\n \"northamerica-northeast2\": (2, \"Toronto\", \"Canada\"),\n \"southamerica-east1\": (2, \"São Paulo\", \"Brazil\"),\n \"southamerica-west1\": (2, \"Santiago\", \"Chile\"),\n \"us-central1\": (1, \"Iowa\", \"United States\"),\n \"us-east1\": (1, \"South Carolina\", \"United States\"),\n \"us-east4\": (1, \"Northern Virginia\", \"United States\"),\n \"us-east5\": (1, \"Columbus\", \"United States\"),\n \"us-south1\": (1, \"Dallas\", \"United States\"),\n \"us-west1\": (1, \"Oregon\", \"United States\"),\n \"us-west2\": (2, \"Los Angeles\", \"United States\"),\n \"us-west3\": (2, \"Salt Lake City\", \"United States\"),\n \"us-west4\": (2, \"Las Vegas\", \"United States\"),\n }\n\n def tier1(self) -> List[str]:\n \"\"\"Return a list of GCP regions classified as tier 1 based on predefined criteria.\"\"\"\n return [region for region, info in self.regions.items() if info[0] == 1]\n\n def tier2(self) -> List[str]:\n \"\"\"Return a list of GCP regions classified as tier 2 based on predefined criteria.\"\"\"\n return [region for region, info in self.regions.items() if info[0] == 2]\n\n @staticmethod\n def _ping_region(region: str, attempts: int = 1) -> Tuple[str, float, 
float, float, float]:\n \"\"\"\n Ping a specified GCP region and measure network latency statistics.\n\n Args:\n region (str): The GCP region identifier to ping (e.g., 'us-central1').\n attempts (int, optional): Number of ping attempts to make for calculating statistics.\n\n Returns:\n region (str): The GCP region identifier that was pinged.\n mean_latency (float): Mean latency in milliseconds, or infinity if all pings failed.\n std_dev (float): Standard deviation of latencies in milliseconds, or infinity if all pings failed.\n min_latency (float): Minimum latency in milliseconds, or infinity if all pings failed.\n max_latency (float): Maximum latency in milliseconds, or infinity if all pings failed.\n\n Examples:\n >>> region, mean, std, min_lat, max_lat = GCPRegions._ping_region(\"us-central1\", attempts=3)\n >>> print(f\"Region {region} has mean latency: {mean:.2f}ms\")\n \"\"\"\n url = f\"https://{region}-docker.pkg.dev\"\n latencies = []\n for _ in range(attempts):\n try:\n start_time = time.time()\n _ = requests.head(url, timeout=5)\n latency = (time.time() - start_time) * 1000 # Convert latency to milliseconds\n if latency != float(\"inf\"):\n latencies.append(latency)\n except requests.RequestException:\n pass\n if not latencies:\n return region, float(\"inf\"), float(\"inf\"), float(\"inf\"), float(\"inf\")\n\n std_dev = statistics.stdev(latencies) if len(latencies) > 1 else 0\n return region, statistics.mean(latencies), std_dev, min(latencies), max(latencies)\n\n def lowest_latency(\n self,\n top: int = 1,\n verbose: bool = False,\n tier: Optional[int] = None,\n attempts: int = 1,\n ) -> List[Tuple[str, float, float, float, float]]:\n \"\"\"\n Determine the GCP regions with the lowest latency based on ping tests.\n\n Args:\n top (int, optional): Number of top regions to return.\n verbose (bool, optional): If True, prints detailed latency information for all tested regions.\n tier (int | None, optional): Filter regions by tier (1 or 2). If None, all regions are tested.\n attempts (int, optional): Number of ping attempts per region.\n\n Returns:\n (List[Tuple[str, float, float, float, float]]): List of tuples containing region information and\n latency statistics. 
Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).\n\n Examples:\n >>> regions = GCPRegions()\n >>> results = regions.lowest_latency(top=3, verbose=True, tier=1, attempts=2)\n >>> print(results[0][0]) # Print the name of the lowest latency region\n \"\"\"\n if verbose:\n print(f\"Testing GCP regions for latency (with {attempts} {'retry' if attempts == 1 else 'attempts'})...\")\n\n regions_to_test = [k for k, v in self.regions.items() if v[0] == tier] if tier else list(self.regions.keys())\n with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:\n results = list(executor.map(lambda r: self._ping_region(r, attempts), regions_to_test))\n\n sorted_results = sorted(results, key=lambda x: x[1])\n\n if verbose:\n print(f\"{'Region':<25} {'Location':<35} {'Tier':<5} Latency (ms)\")\n for region, mean, std, min_, max_ in sorted_results:\n tier, city, country = self.regions[region]\n location = f\"{city}, {country}\"\n if mean == float(\"inf\"):\n print(f\"{region:<25} {location:<35} {tier:<5} Timeout\")\n else:\n print(f\"{region:<25} {location:<35} {tier:<5} {mean:.0f} ± {std:.0f} ({min_:.0f} - {max_:.0f})\")\n print(f\"\\nLowest latency region{'s' if top > 1 else ''}:\")\n for region, mean, std, min_, max_ in sorted_results[:top]:\n tier, city, country = self.regions[region]\n location = f\"{city}, {country}\"\n print(f\"{region} ({location}, {mean:.0f} ± {std:.0f} ms ({min_:.0f} - {max_:.0f}))\")\n\n return sorted_results[:top]",
"chunk_type": "class",
"name": "GCPRegions",
"file_path": "ultralytics\\ultralytics\\hub\\google\\__init__.py",
"start_line": 11,
"end_line": 170,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": "A class for managing and analyzing Google Cloud Platform (GCP) regions.\n\nThis class provides functionality to initialize, categorize, and analyze GCP regions based on their\ngeographical location, tier classification, and network latency.\n\nAttributes:\n regions (Dict[str, Tuple[int, str, str]]): A dictionary of GCP regions with their tier, city, and country.\n\nMethods:\n tier1: Returns a list of tier 1 GCP regions.\n tier2: Returns a list of tier 2 GCP regions.\n lowest_latency: Determines the GCP region(s) with the lowest network latency.\n\nExamples:\n >>> from ultralytics.hub.google import GCPRegions\n >>> regions = GCPRegions()\n >>> lowest_latency_region = regions.lowest_latency(verbose=True, attempts=3)\n >>> print(f\"Lowest latency region: {lowest_latency_region[0][0]}\")",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"concurrent.futures",
"statistics",
"time",
"typing.List",
"typing.Optional",
"typing.Tuple",
"requests"
],
"chunk_id": "class_GCPRegions_792880b7"
},
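A sketch of the documented GCPRegions latency workflow; it needs outbound network access, and the timings depend entirely on where it runs:

```python
from ultralytics.hub.google import GCPRegions

regions = GCPRegions()
print(len(regions.tier1()), "tier-1 regions")  # predefined region metadata, no network needed
# Ping the per-region Artifact Registry endpoints concurrently and keep the three fastest tier-1 regions:
best = regions.lowest_latency(top=3, tier=1, attempts=2, verbose=True)
for name, mean, std, lo, hi in best:
    print(f"{name}: {mean:.0f} ms (min {lo:.0f}, max {hi:.0f})")
```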
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\model.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_88442790"
},
{
"content": "from typing import Any, Dict, List, Optional",
"chunk_type": "import",
"name": "Any, Dict, List, Optional",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\model.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, List, Optional_ffa21a54"
},
{
"content": "from ultralytics.engine.model import Model",
"chunk_type": "import",
"name": "Model",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\model.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Model_932838d8"
},
{
"content": "from .predict import FastSAMPredictor",
"chunk_type": "import",
"name": "FastSAMPredictor",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\model.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 37,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_FastSAMPredictor_26ba58f9"
},
{
"content": "from .val import FastSAMValidator",
"chunk_type": "import",
"name": "FastSAMValidator",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\model.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_FastSAMValidator_ab2f29b8"
},
{
"content": "class FastSAM(Model):\n \"\"\"\n FastSAM model interface for segment anything tasks.\n\n This class extends the base Model class to provide specific functionality for the FastSAM (Fast Segment Anything\n Model) implementation, allowing for efficient and accurate image segmentation with optional prompting support.\n\n Attributes:\n model (str): Path to the pre-trained FastSAM model file.\n task (str): The task type, set to \"segment\" for FastSAM models.\n\n Methods:\n predict: Perform segmentation prediction on image or video source with optional prompts.\n task_map: Returns mapping of segment task to predictor and validator classes.\n\n Examples:\n Initialize FastSAM model and run prediction\n >>> from ultralytics import FastSAM\n >>> model = FastSAM(\"FastSAM-x.pt\")\n >>> results = model.predict(\"ultralytics/assets/bus.jpg\")\n\n Run prediction with bounding box prompts\n >>> results = model.predict(\"image.jpg\", bboxes=[[100, 100, 200, 200]])\n \"\"\"\n\n def __init__(self, model: str = \"FastSAM-x.pt\"):\n \"\"\"Initialize the FastSAM model with the specified pre-trained weights.\"\"\"\n if str(model) == \"FastSAM.pt\":\n model = \"FastSAM-x.pt\"\n assert Path(model).suffix not in {\".yaml\", \".yml\"}, \"FastSAM models only support pre-trained models.\"\n super().__init__(model=model, task=\"segment\")\n\n def predict(\n self,\n source,\n stream: bool = False,\n bboxes: Optional[List] = None,\n points: Optional[List] = None,\n labels: Optional[List] = None,\n texts: Optional[List] = None,\n **kwargs: Any,\n ):\n \"\"\"\n Perform segmentation prediction on image or video source.\n\n Supports prompted segmentation with bounding boxes, points, labels, and texts. The method packages these\n prompts and passes them to the parent class predict method for processing.\n\n Args:\n source (str | PIL.Image | np.ndarray): Input source for prediction, can be a file path, URL, PIL image,\n or numpy array.\n stream (bool): Whether to enable real-time streaming mode for video inputs.\n bboxes (List, optional): Bounding box coordinates for prompted segmentation in format [[x1, y1, x2, y2]].\n points (List, optional): Point coordinates for prompted segmentation in format [[x, y]].\n labels (List, optional): Class labels for prompted segmentation.\n texts (List, optional): Text prompts for segmentation guidance.\n **kwargs (Any): Additional keyword arguments passed to the predictor.\n\n Returns:\n (List): List of Results objects containing the prediction results.\n \"\"\"\n prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)\n return super().predict(source, stream, prompts=prompts, **kwargs)\n\n @property\n def task_map(self) -> Dict[str, Dict[str, Any]]:\n \"\"\"Returns a dictionary mapping segment task to corresponding predictor and validator classes.\"\"\"\n return {\"segment\": {\"predictor\": FastSAMPredictor, \"validator\": FastSAMValidator}}",
"chunk_type": "class",
"name": "FastSAM",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\model.py",
"start_line": 12,
"end_line": 79,
"start_col": 0,
"end_col": 90,
"parent_name": null,
"docstring": "FastSAM model interface for segment anything tasks.\n\nThis class extends the base Model class to provide specific functionality for the FastSAM (Fast Segment Anything\nModel) implementation, allowing for efficient and accurate image segmentation with optional prompting support.\n\nAttributes:\n model (str): Path to the pre-trained FastSAM model file.\n task (str): The task type, set to \"segment\" for FastSAM models.\n\nMethods:\n predict: Perform segmentation prediction on image or video source with optional prompts.\n task_map: Returns mapping of segment task to predictor and validator classes.\n\nExamples:\n Initialize FastSAM model and run prediction\n >>> from ultralytics import FastSAM\n >>> model = FastSAM(\"FastSAM-x.pt\")\n >>> results = model.predict(\"ultralytics/assets/bus.jpg\")\n\n Run prediction with bounding box prompts\n >>> results = model.predict(\"image.jpg\", bboxes=[[100, 100, 200, 200]])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"ultralytics.engine.model.Model",
"predict.FastSAMPredictor",
"val.FastSAMValidator",
"Model"
],
"chunk_id": "class_FastSAM_eabbd80a"
},
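A sketch of the FastSAM usage documented above; weights are fetched on first use and the asset path mirrors the class docstring:

```python
from ultralytics import FastSAM

model = FastSAM("FastSAM-x.pt")  # "FastSAM.pt" would be remapped to "FastSAM-x.pt" per __init__
# Segment everything in the image:
results = model.predict("ultralytics/assets/bus.jpg")
# Prompted segmentation with a bounding box (xyxy pixels); points/labels/texts work the same way:
results = model.predict("ultralytics/assets/bus.jpg", bboxes=[[100, 100, 400, 500]])
print(results[0])
```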
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\predict.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_36591e18"
},
{
"content": "from PIL import Image",
"chunk_type": "import",
"name": "Image",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\predict.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Image_40761ca2"
},
{
"content": "from ultralytics.models.yolo.segment import SegmentationPredictor",
"chunk_type": "import",
"name": "SegmentationPredictor",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\predict.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 65,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SegmentationPredictor_1b76dc07"
},
{
"content": "from ultralytics.utils import DEFAULT_CFG, checks",
"chunk_type": "import",
"name": "DEFAULT_CFG, checks",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\predict.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 49,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DEFAULT_CFG, checks_7e1856ec"
},
{
"content": "from ultralytics.utils.metrics import box_iou",
"chunk_type": "import",
"name": "box_iou",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\predict.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_box_iou_3c40557b"
},
{
"content": "from ultralytics.utils.ops import scale_masks",
"chunk_type": "import",
"name": "scale_masks",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\predict.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_scale_masks_73ead197"
},
{
"content": "from .utils import adjust_bboxes_to_image_border",
"chunk_type": "import",
"name": "adjust_bboxes_to_image_border",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\predict.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 48,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_adjust_bboxes_to_image_border_8833e2ef"
},
{
"content": "class FastSAMPredictor(SegmentationPredictor):\n \"\"\"\n FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.\n\n This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It\n adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for\n single-class segmentation.\n\n Attributes:\n prompts (dict): Dictionary containing prompt information for segmentation (bboxes, points, labels, texts).\n device (torch.device): Device on which model and tensors are processed.\n clip_model (Any, optional): CLIP model for text-based prompting, loaded on demand.\n clip_preprocess (Any, optional): CLIP preprocessing function for images, loaded on demand.\n\n Methods:\n postprocess: Apply postprocessing to FastSAM predictions and handle prompts.\n prompt: Perform image segmentation inference based on various prompt types.\n set_prompts: Set prompts to be used during inference.\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):\n \"\"\"\n Initialize the FastSAMPredictor with configuration and callbacks.\n\n This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The predictor\n extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression\n optimized for single-class segmentation.\n\n Args:\n cfg (dict): Configuration for the predictor.\n overrides (dict, optional): Configuration overrides.\n _callbacks (list, optional): List of callback functions.\n \"\"\"\n super().__init__(cfg, overrides, _callbacks)\n self.prompts = {}\n\n def postprocess(self, preds, img, orig_imgs):\n \"\"\"\n Apply postprocessing to FastSAM predictions and handle prompts.\n\n Args:\n preds (List[torch.Tensor]): Raw predictions from the model.\n img (torch.Tensor): Input image tensor that was fed to the model.\n orig_imgs (List[np.ndarray]): Original images before preprocessing.\n\n Returns:\n (List[Results]): Processed results with prompts applied.\n \"\"\"\n bboxes = self.prompts.pop(\"bboxes\", None)\n points = self.prompts.pop(\"points\", None)\n labels = self.prompts.pop(\"labels\", None)\n texts = self.prompts.pop(\"texts\", None)\n results = super().postprocess(preds, img, orig_imgs)\n for result in results:\n full_box = torch.tensor(\n [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32\n )\n boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)\n idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()\n if idx.numel() != 0:\n result.boxes.xyxy[idx] = full_box\n\n return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)\n\n def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):\n \"\"\"\n Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.\n\n Args:\n results (Results | List[Results]): Original inference results from FastSAM models without any prompts.\n bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.\n points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.\n labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 
1 = foreground, 0 = background.\n texts (str | List[str], optional): Textual prompts, a list containing string objects.\n\n Returns:\n (List[Results]): Output results filtered and determined by the provided prompts.\n \"\"\"\n if bboxes is None and points is None and texts is None:\n return results\n prompt_results = []\n if not isinstance(results, list):\n results = [results]\n for result in results:\n if len(result) == 0:\n prompt_results.append(result)\n continue\n masks = result.masks.data\n if masks.shape[1:] != result.orig_shape:\n masks = scale_masks(masks[None], result.orig_shape)[0]\n # bboxes prompt\n idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)\n if bboxes is not None:\n bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)\n bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes\n bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])\n mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])\n full_mask_areas = torch.sum(masks, dim=(1, 2))\n\n union = bbox_areas[:, None] + full_mask_areas - mask_areas\n idx[torch.argmax(mask_areas / union, dim=1)] = True\n if points is not None:\n points = torch.as_tensor(points, dtype=torch.int32, device=self.device)\n points = points[None] if points.ndim == 1 else points\n if labels is None:\n labels = torch.ones(points.shape[0])\n labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)\n assert len(labels) == len(points), (\n f\"Expected `labels` with same size as `point`, but got {len(labels)} and {len(points)}\"\n )\n point_idx = (\n torch.ones(len(result), dtype=torch.bool, device=self.device)\n if labels.sum() == 0 # all negative points\n else torch.zeros(len(result), dtype=torch.bool, device=self.device)\n )\n for point, label in zip(points, labels):\n point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)\n idx |= point_idx\n if texts is not None:\n if isinstance(texts, str):\n texts = [texts]\n crop_ims, filter_idx = [], []\n for i, b in enumerate(result.boxes.xyxy.tolist()):\n x1, y1, x2, y2 = (int(x) for x in b)\n if masks[i].sum() <= 100:\n filter_idx.append(i)\n continue\n crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))\n similarity = self._clip_inference(crop_ims, texts)\n text_idx = torch.argmax(similarity, dim=-1) # (M, )\n if len(filter_idx):\n text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)\n idx[text_idx] = True\n\n prompt_results.append(result[idx])\n\n return prompt_results\n\n def _clip_inference(self, images, texts):\n \"\"\"\n Perform CLIP inference to calculate similarity between images and text prompts.\n\n Args:\n images (List[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.\n texts (List[str]): List of prompt texts, each should be a string object.\n\n Returns:\n (torch.Tensor): Similarity matrix between given images and texts with shape (M, N).\n \"\"\"\n try:\n import clip\n except ImportError:\n checks.check_requirements(\"git+https://github.com/ultralytics/CLIP.git\")\n import clip\n if (not hasattr(self, \"clip_model\")) or (not hasattr(self, \"clip_preprocess\")):\n self.clip_model, self.clip_preprocess = clip.load(\"ViT-B/32\", device=self.device)\n images = torch.stack([self.clip_preprocess(image).to(self.device) for image in images])\n tokenized_text = clip.tokenize(texts).to(self.device)\n image_features = self.clip_model.encode_image(images)\n text_features = 
self.clip_model.encode_text(tokenized_text)\n image_features /= image_features.norm(dim=-1, keepdim=True) # (N, 512)\n text_features /= text_features.norm(dim=-1, keepdim=True) # (M, 512)\n return (image_features * text_features[:, None]).sum(-1) # (M, N)\n\n def set_prompts(self, prompts):\n \"\"\"Set prompts to be used during inference.\"\"\"\n self.prompts = prompts",
"chunk_type": "class",
"name": "FastSAMPredictor",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\predict.py",
"start_line": 14,
"end_line": 180,
"start_col": 0,
"end_col": 30,
"parent_name": null,
"docstring": "FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.\n\nThis class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It\nadjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for\nsingle-class segmentation.\n\nAttributes:\n prompts (dict): Dictionary containing prompt information for segmentation (bboxes, points, labels, texts).\n device (torch.device): Device on which model and tensors are processed.\n clip_model (Any, optional): CLIP model for text-based prompting, loaded on demand.\n clip_preprocess (Any, optional): CLIP preprocessing function for images, loaded on demand.\n\nMethods:\n postprocess: Apply postprocessing to FastSAM predictions and handle prompts.\n prompt: Perform image segmentation inference based on various prompt types.\n set_prompts: Set prompts to be used during inference.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"torch",
"PIL.Image",
"ultralytics.models.yolo.segment.SegmentationPredictor",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.checks",
"ultralytics.utils.metrics.box_iou",
"ultralytics.utils.ops.scale_masks",
"utils.adjust_bboxes_to_image_border",
"clip",
"clip",
"SegmentationPredictor"
],
"chunk_id": "class_FastSAMPredictor_16f795d4"
},
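In normal use the prompts never touch FastSAMPredictor directly: FastSAM.predict() packs them into a dict, the predictor's set_prompts() stores it, and postprocess() pops the entries and forwards them to prompt(). A hedged sketch of that flow:

```python
from ultralytics import FastSAM

model = FastSAM("FastSAM-x.pt")
# High-level call; internally equivalent to predictor.set_prompts({"points": ..., "labels": ...})
results = model.predict("ultralytics/assets/bus.jpg", points=[[200, 300]], labels=[1])
print(results[0].boxes.xyxy.shape)  # boxes that survive the point prompt
```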
{
"content": "def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):\n \"\"\"\n Adjust bounding boxes to stick to image border if they are within a certain threshold.\n\n Args:\n boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.\n image_shape (tuple): Image dimensions as (height, width).\n threshold (int): Pixel threshold for considering a box close to the border.\n\n Returns:\n (torch.Tensor): Adjusted bounding boxes with shape (N, 4).\n \"\"\"\n # Image dimensions\n h, w = image_shape\n\n # Adjust boxes that are close to image borders\n boxes[boxes[:, 0] < threshold, 0] = 0 # x1\n boxes[boxes[:, 1] < threshold, 1] = 0 # y1\n boxes[boxes[:, 2] > w - threshold, 2] = w # x2\n boxes[boxes[:, 3] > h - threshold, 3] = h # y2\n return boxes",
"chunk_type": "function",
"name": "adjust_bboxes_to_image_border",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\utils.py",
"start_line": 4,
"end_line": 24,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "Adjust bounding boxes to stick to image border if they are within a certain threshold.\n\nArgs:\n boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.\n image_shape (tuple): Image dimensions as (height, width).\n threshold (int): Pixel threshold for considering a box close to the border.\n\nReturns:\n (torch.Tensor): Adjusted bounding boxes with shape (N, 4).",
"parameters": [
"boxes",
"image_shape",
"threshold"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [],
"chunk_id": "function_adjust_bboxes_to_image_border_775308ef"
},
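A quick check of the documented snapping behaviour of adjust_bboxes_to_image_border (values chosen for illustration):

```python
import torch

from ultralytics.models.fastsam.utils import adjust_bboxes_to_image_border

boxes = torch.tensor([[5.0, 12.0, 630.0, 470.0],      # within 20 px of every border
                      [100.0, 100.0, 200.0, 200.0]])  # well inside the image
out = adjust_bboxes_to_image_border(boxes, image_shape=(480, 640), threshold=20)
print(out)  # first box snaps to [0, 0, 640, 480]; second box is unchanged (edit is in place)
```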
{
"content": "from ultralytics.models.yolo.segment import SegmentationValidator",
"chunk_type": "import",
"name": "SegmentationValidator",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\val.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 65,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SegmentationValidator_3adc7638"
},
{
"content": "class FastSAMValidator(SegmentationValidator):\n \"\"\"\n Custom validation class for Fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.\n\n Extends the SegmentationValidator class, customizing the validation process specifically for Fast SAM. This class\n sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled\n to avoid errors during validation.\n\n Attributes:\n dataloader (torch.utils.data.DataLoader): The data loader object used for validation.\n save_dir (Path): The directory where validation results will be saved.\n args (SimpleNamespace): Additional arguments for customization of the validation process.\n _callbacks (list): List of callback functions to be invoked during validation.\n metrics (SegmentMetrics): Segmentation metrics calculator for evaluation.\n\n Methods:\n __init__: Initialize the FastSAMValidator with custom settings for Fast SAM.\n \"\"\"\n\n def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):\n \"\"\"\n Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.\n\n Args:\n dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.\n save_dir (Path, optional): Directory to save results.\n args (SimpleNamespace, optional): Configuration for the validator.\n _callbacks (list, optional): List of callback functions to be invoked during validation.\n\n Notes:\n Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.\n \"\"\"\n super().__init__(dataloader, save_dir, args, _callbacks)\n self.args.task = \"segment\"\n self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors",
"chunk_type": "class",
"name": "FastSAMValidator",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\val.py",
"start_line": 6,
"end_line": 40,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": "Custom validation class for Fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.\n\nExtends the SegmentationValidator class, customizing the validation process specifically for Fast SAM. This class\nsets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled\nto avoid errors during validation.\n\nAttributes:\n dataloader (torch.utils.data.DataLoader): The data loader object used for validation.\n save_dir (Path): The directory where validation results will be saved.\n args (SimpleNamespace): Additional arguments for customization of the validation process.\n _callbacks (list): List of callback functions to be invoked during validation.\n metrics (SegmentMetrics): Segmentation metrics calculator for evaluation.\n\nMethods:\n __init__: Initialize the FastSAMValidator with custom settings for Fast SAM.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"ultralytics.models.yolo.segment.SegmentationValidator",
"SegmentationValidator"
],
"chunk_id": "class_FastSAMValidator_f79634ac"
},
{
"content": "from .model import FastSAM",
"chunk_type": "import",
"name": "FastSAM",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\__init__.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 26,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_FastSAM_596b8d4c"
},
{
"content": "from .predict import FastSAMPredictor",
"chunk_type": "import",
"name": "FastSAMPredictor",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\__init__.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 37,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_FastSAMPredictor_a37ddd48"
},
{
"content": "from .val import FastSAMValidator",
"chunk_type": "import",
"name": "FastSAMValidator",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\__init__.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_FastSAMValidator_dce83088"
},
{
"content": "__all__ = \"FastSAMPredictor\", \"FastSAM\", \"FastSAMValidator\"",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\models\\fastsam\\__init__.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 59,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___caf38e02"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\models\\nas\\model.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_63a1b45c"
},
{
"content": "from typing import Any, Dict",
"chunk_type": "import",
"name": "Any, Dict",
"file_path": "ultralytics\\ultralytics\\models\\nas\\model.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict_867260ee"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\nas\\model.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_d26b0c34"
},
{
"content": "from ultralytics.engine.model import Model",
"chunk_type": "import",
"name": "Model",
"file_path": "ultralytics\\ultralytics\\models\\nas\\model.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Model_16a03cde"
},
{
"content": "from ultralytics.utils import DEFAULT_CFG_DICT",
"chunk_type": "import",
"name": "DEFAULT_CFG_DICT",
"file_path": "ultralytics\\ultralytics\\models\\nas\\model.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DEFAULT_CFG_DICT_d978629b"
},
{
"content": "from ultralytics.utils.downloads import attempt_download_asset",
"chunk_type": "import",
"name": "attempt_download_asset",
"file_path": "ultralytics\\ultralytics\\models\\nas\\model.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 62,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_attempt_download_asset_cfb43222"
},
{
"content": "from ultralytics.utils.patches import torch_load",
"chunk_type": "import",
"name": "torch_load",
"file_path": "ultralytics\\ultralytics\\models\\nas\\model.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 48,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_load_3c9fda17"
},
{
"content": "from ultralytics.utils.torch_utils import model_info",
"chunk_type": "import",
"name": "model_info",
"file_path": "ultralytics\\ultralytics\\models\\nas\\model.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 52,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_model_info_c8b686cd"
},
{
"content": "from .predict import NASPredictor",
"chunk_type": "import",
"name": "NASPredictor",
"file_path": "ultralytics\\ultralytics\\models\\nas\\model.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_NASPredictor_30757246"
},
{
"content": "from .val import NASValidator",
"chunk_type": "import",
"name": "NASValidator",
"file_path": "ultralytics\\ultralytics\\models\\nas\\model.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_NASValidator_60ddb312"
},
{
"content": "class NAS(Model):\n \"\"\"\n YOLO-NAS model for object detection.\n\n This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.\n It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.\n\n Attributes:\n model (torch.nn.Module): The loaded YOLO-NAS model.\n task (str): The task type for the model, defaults to 'detect'.\n predictor (NASPredictor): The predictor instance for making predictions.\n validator (NASValidator): The validator instance for model validation.\n\n Methods:\n info: Log model information and return model details.\n\n Examples:\n >>> from ultralytics import NAS\n >>> model = NAS(\"yolo_nas_s\")\n >>> results = model.predict(\"ultralytics/assets/bus.jpg\")\n\n Notes:\n YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.\n \"\"\"\n\n def __init__(self, model: str = \"yolo_nas_s.pt\") -> None:\n \"\"\"Initialize the NAS model with the provided or default model.\"\"\"\n assert Path(model).suffix not in {\".yaml\", \".yml\"}, \"YOLO-NAS models only support pre-trained models.\"\n super().__init__(model, task=\"detect\")\n\n def _load(self, weights: str, task=None) -> None:\n \"\"\"\n Load an existing NAS model weights or create a new NAS model with pretrained weights.\n\n Args:\n weights (str): Path to the model weights file or model name.\n task (str, optional): Task type for the model.\n \"\"\"\n import super_gradients\n\n suffix = Path(weights).suffix\n if suffix == \".pt\":\n self.model = torch_load(attempt_download_asset(weights))\n elif suffix == \"\":\n self.model = super_gradients.training.models.get(weights, pretrained_weights=\"coco\")\n\n # Override the forward method to ignore additional arguments\n def new_forward(x, *args, **kwargs):\n \"\"\"Ignore additional __call__ arguments.\"\"\"\n return self.model._original_forward(x)\n\n self.model._original_forward = self.model.forward\n self.model.forward = new_forward\n\n # Standardize model attributes for compatibility\n self.model.fuse = lambda verbose=True: self.model\n self.model.stride = torch.tensor([32])\n self.model.names = dict(enumerate(self.model._class_names))\n self.model.is_fused = lambda: False # for info()\n self.model.yaml = {} # for info()\n self.model.pt_path = weights # for export()\n self.model.task = \"detect\" # for export()\n self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # for export()\n self.model.eval()\n\n def info(self, detailed: bool = False, verbose: bool = True) -> Dict[str, Any]:\n \"\"\"\n Log model information.\n\n Args:\n detailed (bool): Show detailed information about model.\n verbose (bool): Controls verbosity.\n\n Returns:\n (Dict[str, Any]): Model information dictionary.\n \"\"\"\n return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)\n\n @property\n def task_map(self) -> Dict[str, Dict[str, Any]]:\n \"\"\"Return a dictionary mapping tasks to respective predictor and validator classes.\"\"\"\n return {\"detect\": {\"predictor\": NASPredictor, \"validator\": NASValidator}}",
"chunk_type": "class",
"name": "NAS",
"file_path": "ultralytics\\ultralytics\\models\\nas\\model.py",
"start_line": 18,
"end_line": 99,
"start_col": 0,
"end_col": 81,
"parent_name": null,
"docstring": "YOLO-NAS model for object detection.\n\nThis class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.\nIt is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.\n\nAttributes:\n model (torch.nn.Module): The loaded YOLO-NAS model.\n task (str): The task type for the model, defaults to 'detect'.\n predictor (NASPredictor): The predictor instance for making predictions.\n validator (NASValidator): The validator instance for model validation.\n\nMethods:\n info: Log model information and return model details.\n\nExamples:\n >>> from ultralytics import NAS\n >>> model = NAS(\"yolo_nas_s\")\n >>> results = model.predict(\"ultralytics/assets/bus.jpg\")\n\nNotes:\n YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"pathlib.Path",
"typing.Any",
"typing.Dict",
"torch",
"ultralytics.engine.model.Model",
"ultralytics.utils.DEFAULT_CFG_DICT",
"ultralytics.utils.downloads.attempt_download_asset",
"ultralytics.utils.patches.torch_load",
"ultralytics.utils.torch_utils.model_info",
"predict.NASPredictor",
"val.NASValidator",
"super_gradients",
"Model"
],
"chunk_id": "class_NAS_ce1430e4"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\nas\\predict.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_cb8909b1"
},
{
"content": "from ultralytics.models.yolo.detect.predict import DetectionPredictor",
"chunk_type": "import",
"name": "DetectionPredictor",
"file_path": "ultralytics\\ultralytics\\models\\nas\\predict.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 69,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DetectionPredictor_8fa7d599"
},
{
"content": "from ultralytics.utils import ops",
"chunk_type": "import",
"name": "ops",
"file_path": "ultralytics\\ultralytics\\models\\nas\\predict.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ops_7a866dcc"
},
{
"content": "class NASPredictor(DetectionPredictor):\n \"\"\"\n Ultralytics YOLO NAS Predictor for object detection.\n\n This class extends the DetectionPredictor from Ultralytics engine and is responsible for post-processing the\n raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and\n scaling the bounding boxes to fit the original image dimensions.\n\n Attributes:\n args (Namespace): Namespace containing various configurations for post-processing including confidence\n threshold, IoU threshold, agnostic NMS flag, maximum detections, and class filtering options.\n model (torch.nn.Module): The YOLO NAS model used for inference.\n batch (list): Batch of inputs for processing.\n\n Examples:\n >>> from ultralytics import NAS\n >>> model = NAS(\"yolo_nas_s\")\n >>> predictor = model.predictor\n\n Assume that raw_preds, img, orig_imgs are available\n >>> results = predictor.postprocess(raw_preds, img, orig_imgs)\n\n Notes:\n Typically, this class is not instantiated directly. It is used internally within the NAS class.\n \"\"\"\n\n def postprocess(self, preds_in, img, orig_imgs):\n \"\"\"\n Postprocess NAS model predictions to generate final detection results.\n\n This method takes raw predictions from a YOLO NAS model, converts bounding box formats, and applies\n post-processing operations to generate the final detection results compatible with Ultralytics\n result visualization and analysis tools.\n\n Args:\n preds_in (list): Raw predictions from the NAS model, typically containing bounding boxes and class scores.\n img (torch.Tensor): Input image tensor that was fed to the model, with shape (B, C, H, W).\n orig_imgs (list | torch.Tensor | np.ndarray): Original images before preprocessing, used for scaling\n coordinates back to original dimensions.\n\n Returns:\n (list): List of Results objects containing the processed predictions for each image in the batch.\n\n Examples:\n >>> predictor = NAS(\"yolo_nas_s\").predictor\n >>> results = predictor.postprocess(raw_preds, img, orig_imgs)\n \"\"\"\n boxes = ops.xyxy2xywh(preds_in[0][0]) # Convert bounding boxes from xyxy to xywh format\n preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) # Concatenate boxes with class scores\n return super().postprocess(preds, img, orig_imgs)",
"chunk_type": "class",
"name": "NASPredictor",
"file_path": "ultralytics\\ultralytics\\models\\nas\\predict.py",
"start_line": 9,
"end_line": 58,
"start_col": 0,
"end_col": 57,
"parent_name": null,
"docstring": "Ultralytics YOLO NAS Predictor for object detection.\n\nThis class extends the DetectionPredictor from Ultralytics engine and is responsible for post-processing the\nraw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and\nscaling the bounding boxes to fit the original image dimensions.\n\nAttributes:\n args (Namespace): Namespace containing various configurations for post-processing including confidence\n threshold, IoU threshold, agnostic NMS flag, maximum detections, and class filtering options.\n model (torch.nn.Module): The YOLO NAS model used for inference.\n batch (list): Batch of inputs for processing.\n\nExamples:\n >>> from ultralytics import NAS\n >>> model = NAS(\"yolo_nas_s\")\n >>> predictor = model.predictor\n\n Assume that raw_preds, img, orig_imgs are available\n >>> results = predictor.postprocess(raw_preds, img, orig_imgs)\n\nNotes:\n Typically, this class is not instantiated directly. It is used internally within the NAS class.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"torch",
"ultralytics.models.yolo.detect.predict.DetectionPredictor",
"ultralytics.utils.ops",
"DetectionPredictor"
],
"chunk_id": "class_NASPredictor_68ba2ed5"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\nas\\val.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_e167895d"
},
{
"content": "from ultralytics.models.yolo.detect import DetectionValidator",
"chunk_type": "import",
"name": "DetectionValidator",
"file_path": "ultralytics\\ultralytics\\models\\nas\\val.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 61,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DetectionValidator_3c78c9ab"
},
{
"content": "from ultralytics.utils import ops",
"chunk_type": "import",
"name": "ops",
"file_path": "ultralytics\\ultralytics\\models\\nas\\val.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ops_1d91d0fa"
},
{
"content": "__all__ = [\"NASValidator\"]",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\models\\nas\\val.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 26,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___72cf61b7"
},
{
"content": "class NASValidator(DetectionValidator):\n \"\"\"\n Ultralytics YOLO NAS Validator for object detection.\n\n Extends DetectionValidator from the Ultralytics models package and is designed to post-process the raw predictions\n generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,\n ultimately producing the final detections.\n\n Attributes:\n args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU\n thresholds.\n lb (torch.Tensor): Optional tensor for multilabel NMS.\n\n Examples:\n >>> from ultralytics import NAS\n >>> model = NAS(\"yolo_nas_s\")\n >>> validator = model.validator\n >>> # Assumes that raw_preds are available\n >>> final_preds = validator.postprocess(raw_preds)\n\n Notes:\n This class is generally not instantiated directly but is used internally within the NAS class.\n \"\"\"\n\n def postprocess(self, preds_in):\n \"\"\"Apply Non-maximum suppression to prediction outputs.\"\"\"\n boxes = ops.xyxy2xywh(preds_in[0][0]) # Convert bounding box format from xyxy to xywh\n preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) # Concatenate boxes with scores and permute\n return super().postprocess(preds)",
"chunk_type": "class",
"name": "NASValidator",
"file_path": "ultralytics\\ultralytics\\models\\nas\\val.py",
"start_line": 11,
"end_line": 39,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": "Ultralytics YOLO NAS Validator for object detection.\n\nExtends DetectionValidator from the Ultralytics models package and is designed to post-process the raw predictions\ngenerated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,\nultimately producing the final detections.\n\nAttributes:\n args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU\n thresholds.\n lb (torch.Tensor): Optional tensor for multilabel NMS.\n\nExamples:\n >>> from ultralytics import NAS\n >>> model = NAS(\"yolo_nas_s\")\n >>> validator = model.validator\n >>> # Assumes that raw_preds are available\n >>> final_preds = validator.postprocess(raw_preds)\n\nNotes:\n This class is generally not instantiated directly but is used internally within the NAS class.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"torch",
"ultralytics.models.yolo.detect.DetectionValidator",
"ultralytics.utils.ops",
"DetectionValidator"
],
"chunk_id": "class_NASValidator_c323d399"
},
{
"content": "from .model import NAS",
"chunk_type": "import",
"name": "NAS",
"file_path": "ultralytics\\ultralytics\\models\\nas\\__init__.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_NAS_f3103ce8"
},
{
"content": "from .predict import NASPredictor",
"chunk_type": "import",
"name": "NASPredictor",
"file_path": "ultralytics\\ultralytics\\models\\nas\\__init__.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_NASPredictor_e55f4165"
},
{
"content": "from .val import NASValidator",
"chunk_type": "import",
"name": "NASValidator",
"file_path": "ultralytics\\ultralytics\\models\\nas\\__init__.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_NASValidator_38a1f670"
},
{
"content": "__all__ = \"NASPredictor\", \"NASValidator\", \"NAS\"",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\models\\nas\\__init__.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___080d624a"
},
{
"content": "from ultralytics.engine.model import Model",
"chunk_type": "import",
"name": "Model",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\model.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Model_70dd2088"
},
{
"content": "from ultralytics.nn.tasks import RTDETRDetectionModel",
"chunk_type": "import",
"name": "RTDETRDetectionModel",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\model.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_RTDETRDetectionModel_df62a2a5"
},
{
"content": "from .predict import RTDETRPredictor",
"chunk_type": "import",
"name": "RTDETRPredictor",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\model.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 36,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_RTDETRPredictor_f1d87ecc"
},
{
"content": "from .train import RTDETRTrainer",
"chunk_type": "import",
"name": "RTDETRTrainer",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\model.py",
"start_line": 16,
"end_line": 16,
"start_col": 0,
"end_col": 32,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_RTDETRTrainer_a24dc92e"
},
{
"content": "from .val import RTDETRValidator",
"chunk_type": "import",
"name": "RTDETRValidator",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\model.py",
"start_line": 17,
"end_line": 17,
"start_col": 0,
"end_col": 32,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_RTDETRValidator_c77a4be2"
},
{
"content": "class RTDETR(Model):\n \"\"\"\n Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.\n\n This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware\n query selection, and adaptable inference speed.\n\n Attributes:\n model (str): Path to the pre-trained model.\n\n Methods:\n task_map: Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.\n\n Examples:\n Initialize RT-DETR with a pre-trained model\n >>> from ultralytics import RTDETR\n >>> model = RTDETR(\"rtdetr-l.pt\")\n >>> results = model(\"image.jpg\")\n \"\"\"\n\n def __init__(self, model: str = \"rtdetr-l.pt\") -> None:\n \"\"\"\n Initialize the RT-DETR model with the given pre-trained model file.\n\n Args:\n model (str): Path to the pre-trained model. Supports .pt, .yaml, and .yml formats.\n \"\"\"\n super().__init__(model=model, task=\"detect\")\n\n @property\n def task_map(self) -> dict:\n \"\"\"\n Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.\n\n Returns:\n (dict): A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.\n \"\"\"\n return {\n \"detect\": {\n \"predictor\": RTDETRPredictor,\n \"validator\": RTDETRValidator,\n \"trainer\": RTDETRTrainer,\n \"model\": RTDETRDetectionModel,\n }\n }",
"chunk_type": "class",
"name": "RTDETR",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\model.py",
"start_line": 20,
"end_line": 64,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.\n\nThis model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware\nquery selection, and adaptable inference speed.\n\nAttributes:\n model (str): Path to the pre-trained model.\n\nMethods:\n task_map: Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.\n\nExamples:\n Initialize RT-DETR with a pre-trained model\n >>> from ultralytics import RTDETR\n >>> model = RTDETR(\"rtdetr-l.pt\")\n >>> results = model(\"image.jpg\")",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"ultralytics.engine.model.Model",
"ultralytics.nn.tasks.RTDETRDetectionModel",
"predict.RTDETRPredictor",
"train.RTDETRTrainer",
"val.RTDETRValidator",
"Model"
],
"chunk_id": "class_RTDETR_3dbd933c"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\predict.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_bdcfae9f"
},
{
"content": "from ultralytics.data.augment import LetterBox",
"chunk_type": "import",
"name": "LetterBox",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\predict.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LetterBox_58b6ab82"
},
{
"content": "from ultralytics.engine.predictor import BasePredictor",
"chunk_type": "import",
"name": "BasePredictor",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\predict.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 54,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BasePredictor_c18cc8ab"
},
{
"content": "from ultralytics.engine.results import Results",
"chunk_type": "import",
"name": "Results",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\predict.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Results_c349127c"
},
{
"content": "from ultralytics.utils import ops",
"chunk_type": "import",
"name": "ops",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\predict.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ops_6f996d74"
},
{
"content": "class RTDETRPredictor(BasePredictor):\n \"\"\"\n RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.\n\n This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy.\n It supports key features like efficient hybrid encoding and IoU-aware query selection.\n\n Attributes:\n imgsz (int): Image size for inference (must be square and scale-filled).\n args (dict): Argument overrides for the predictor.\n model (torch.nn.Module): The loaded RT-DETR model.\n batch (list): Current batch of processed inputs.\n\n Methods:\n postprocess: Postprocess raw model predictions to generate bounding boxes and confidence scores.\n pre_transform: Pre-transform input images before feeding them into the model for inference.\n\n Examples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.rtdetr import RTDETRPredictor\n >>> args = dict(model=\"rtdetr-l.pt\", source=ASSETS)\n >>> predictor = RTDETRPredictor(overrides=args)\n >>> predictor.predict_cli()\n \"\"\"\n\n def postprocess(self, preds, img, orig_imgs):\n \"\"\"\n Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.\n\n The method filters detections based on confidence and class if specified in `self.args`. It converts\n model predictions to Results objects containing properly scaled bounding boxes.\n\n Args:\n preds (list | tuple): List of [predictions, extra] from the model, where predictions contain\n bounding boxes and scores.\n img (torch.Tensor): Processed input images with shape (N, 3, H, W).\n orig_imgs (list | torch.Tensor): Original, unprocessed images.\n\n Returns:\n results (List[Results]): A list of Results objects containing the post-processed bounding boxes,\n confidence scores, and class labels.\n \"\"\"\n if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference\n preds = [preds, None]\n\n nd = preds[0].shape[-1]\n bboxes, scores = preds[0].split((4, nd - 4), dim=-1)\n\n if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list\n orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)\n\n results = []\n for bbox, score, orig_img, img_path in zip(bboxes, scores, orig_imgs, self.batch[0]): # (300, 4)\n bbox = ops.xywh2xyxy(bbox)\n max_score, cls = score.max(-1, keepdim=True) # (300, 1)\n idx = max_score.squeeze(-1) > self.args.conf # (300, )\n if self.args.classes is not None:\n idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx\n pred = torch.cat([bbox, max_score, cls], dim=-1)[idx] # filter\n oh, ow = orig_img.shape[:2]\n pred[..., [0, 2]] *= ow # scale x coordinates to original width\n pred[..., [1, 3]] *= oh # scale y coordinates to original height\n results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))\n return results\n\n def pre_transform(self, im):\n \"\"\"\n Pre-transform input images before feeding them into the model for inference.\n\n The input images are letterboxed to ensure a square aspect ratio and scale-filled. The size must be square\n (640) and scale_filled.\n\n Args:\n im (List[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor,\n [(H, W, 3) x N] for list.\n\n Returns:\n (list): List of pre-transformed images ready for model inference.\n \"\"\"\n letterbox = LetterBox(self.imgsz, auto=False, scale_fill=True)\n return [letterbox(image=x) for x in im]",
"chunk_type": "class",
"name": "RTDETRPredictor",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\predict.py",
"start_line": 11,
"end_line": 91,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": "RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.\n\nThis class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy.\nIt supports key features like efficient hybrid encoding and IoU-aware query selection.\n\nAttributes:\n imgsz (int): Image size for inference (must be square and scale-filled).\n args (dict): Argument overrides for the predictor.\n model (torch.nn.Module): The loaded RT-DETR model.\n batch (list): Current batch of processed inputs.\n\nMethods:\n postprocess: Postprocess raw model predictions to generate bounding boxes and confidence scores.\n pre_transform: Pre-transform input images before feeding them into the model for inference.\n\nExamples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.rtdetr import RTDETRPredictor\n >>> args = dict(model=\"rtdetr-l.pt\", source=ASSETS)\n >>> predictor = RTDETRPredictor(overrides=args)\n >>> predictor.predict_cli()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"torch",
"ultralytics.data.augment.LetterBox",
"ultralytics.engine.predictor.BasePredictor",
"ultralytics.engine.results.Results",
"ultralytics.utils.ops",
"BasePredictor"
],
"chunk_id": "class_RTDETRPredictor_781515ed"
},
{
"content": "from copy import copy",
"chunk_type": "import",
"name": "copy",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\train.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_copy_8c82df20"
},
{
"content": "from typing import Optional",
"chunk_type": "import",
"name": "Optional",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\train.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 27,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Optional_818d31a9"
},
{
"content": "from ultralytics.models.yolo.detect import DetectionTrainer",
"chunk_type": "import",
"name": "DetectionTrainer",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\train.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 59,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DetectionTrainer_51568663"
},
{
"content": "from ultralytics.nn.tasks import RTDETRDetectionModel",
"chunk_type": "import",
"name": "RTDETRDetectionModel",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\train.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_RTDETRDetectionModel_8a37c91d"
},
{
"content": "from ultralytics.utils import RANK, colorstr",
"chunk_type": "import",
"name": "RANK, colorstr",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\train.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_RANK, colorstr_0b78d91a"
},
{
"content": "from .val import RTDETRDataset, RTDETRValidator",
"chunk_type": "import",
"name": "RTDETRDataset, RTDETRValidator",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\train.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_RTDETRDataset, RTDETRValidator_a803f6b9"
},
{
"content": "class RTDETRTrainer(DetectionTrainer):\n \"\"\"\n Trainer class for the RT-DETR model developed by Baidu for real-time object detection.\n\n This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of RT-DETR.\n The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable inference\n speed.\n\n Attributes:\n loss_names (tuple): Names of the loss components used for training.\n data (dict): Dataset configuration containing class count and other parameters.\n args (dict): Training arguments and hyperparameters.\n save_dir (Path): Directory to save training results.\n test_loader (DataLoader): DataLoader for validation/testing data.\n\n Methods:\n get_model: Initialize and return an RT-DETR model for object detection tasks.\n build_dataset: Build and return an RT-DETR dataset for training or validation.\n get_validator: Return a DetectionValidator suitable for RT-DETR model validation.\n\n Notes:\n - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.\n - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.\n\n Examples:\n >>> from ultralytics.models.rtdetr.train import RTDETRTrainer\n >>> args = dict(model=\"rtdetr-l.yaml\", data=\"coco8.yaml\", imgsz=640, epochs=3)\n >>> trainer = RTDETRTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n\n def get_model(self, cfg: Optional[dict] = None, weights: Optional[str] = None, verbose: bool = True):\n \"\"\"\n Initialize and return an RT-DETR model for object detection tasks.\n\n Args:\n cfg (dict, optional): Model configuration.\n weights (str, optional): Path to pre-trained model weights.\n verbose (bool): Verbose logging if True.\n\n Returns:\n (RTDETRDetectionModel): Initialized model.\n \"\"\"\n model = RTDETRDetectionModel(cfg, nc=self.data[\"nc\"], ch=self.data[\"channels\"], verbose=verbose and RANK == -1)\n if weights:\n model.load(weights)\n return model\n\n def build_dataset(self, img_path: str, mode: str = \"val\", batch: Optional[int] = None):\n \"\"\"\n Build and return an RT-DETR dataset for training or validation.\n\n Args:\n img_path (str): Path to the folder containing images.\n mode (str): Dataset mode, either 'train' or 'val'.\n batch (int, optional): Batch size for rectangle training.\n\n Returns:\n (RTDETRDataset): Dataset object for the specific mode.\n \"\"\"\n return RTDETRDataset(\n img_path=img_path,\n imgsz=self.args.imgsz,\n batch_size=batch,\n augment=mode == \"train\",\n hyp=self.args,\n rect=False,\n cache=self.args.cache or None,\n single_cls=self.args.single_cls or False,\n prefix=colorstr(f\"{mode}: \"),\n classes=self.args.classes,\n data=self.data,\n fraction=self.args.fraction if mode == \"train\" else 1.0,\n )\n\n def get_validator(self):\n \"\"\"Return a DetectionValidator suitable for RT-DETR model validation.\"\"\"\n self.loss_names = \"giou_loss\", \"cls_loss\", \"l1_loss\"\n return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))",
"chunk_type": "class",
"name": "RTDETRTrainer",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\train.py",
"start_line": 13,
"end_line": 91,
"start_col": 0,
"end_col": 94,
"parent_name": null,
"docstring": "Trainer class for the RT-DETR model developed by Baidu for real-time object detection.\n\nThis class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of RT-DETR.\nThe model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable inference\nspeed.\n\nAttributes:\n loss_names (tuple): Names of the loss components used for training.\n data (dict): Dataset configuration containing class count and other parameters.\n args (dict): Training arguments and hyperparameters.\n save_dir (Path): Directory to save training results.\n test_loader (DataLoader): DataLoader for validation/testing data.\n\nMethods:\n get_model: Initialize and return an RT-DETR model for object detection tasks.\n build_dataset: Build and return an RT-DETR dataset for training or validation.\n get_validator: Return a DetectionValidator suitable for RT-DETR model validation.\n\nNotes:\n - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.\n - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.\n\nExamples:\n >>> from ultralytics.models.rtdetr.train import RTDETRTrainer\n >>> args = dict(model=\"rtdetr-l.yaml\", data=\"coco8.yaml\", imgsz=640, epochs=3)\n >>> trainer = RTDETRTrainer(overrides=args)\n >>> trainer.train()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy.copy",
"typing.Optional",
"ultralytics.models.yolo.detect.DetectionTrainer",
"ultralytics.nn.tasks.RTDETRDetectionModel",
"ultralytics.utils.RANK",
"ultralytics.utils.colorstr",
"val.RTDETRDataset",
"val.RTDETRValidator",
"DetectionTrainer"
],
"chunk_id": "class_RTDETRTrainer_68d1f2ef"
},
{
"content": "from typing import Any, Dict, List, Tuple, Union",
"chunk_type": "import",
"name": "Any, Dict, List, Tuple, Union",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\val.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 48,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, List, Tuple, Union_ddc2db15"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\val.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_d7d425c3"
},
{
"content": "from ultralytics.data import YOLODataset",
"chunk_type": "import",
"name": "YOLODataset",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\val.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLODataset_b73f5263"
},
{
"content": "from ultralytics.data.augment import Compose, Format, v8_transforms",
"chunk_type": "import",
"name": "Compose, Format, v8_transforms",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\val.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 67,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Compose, Format, v8_transforms_82ece729"
},
{
"content": "from ultralytics.models.yolo.detect import DetectionValidator",
"chunk_type": "import",
"name": "DetectionValidator",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\val.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 61,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DetectionValidator_ac673e65"
},
{
"content": "from ultralytics.utils import colorstr, ops",
"chunk_type": "import",
"name": "colorstr, ops",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\val.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 43,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_colorstr, ops_2d167c92"
},
{
"content": "__all__ = (\"RTDETRValidator\",) # tuple or list",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\val.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 30,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___376b9aa0"
},
{
"content": "class RTDETRDataset(YOLODataset):\n \"\"\"\n Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.\n\n This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for\n real-time detection and tracking tasks.\n\n Attributes:\n augment (bool): Whether to apply data augmentation.\n rect (bool): Whether to use rectangular training.\n use_segments (bool): Whether to use segmentation masks.\n use_keypoints (bool): Whether to use keypoint annotations.\n imgsz (int): Target image size for training.\n\n Methods:\n load_image: Load one image from dataset index.\n build_transforms: Build transformation pipeline for the dataset.\n\n Examples:\n Initialize an RT-DETR dataset\n >>> dataset = RTDETRDataset(img_path=\"path/to/images\", imgsz=640)\n >>> image, hw = dataset.load_image(0)\n \"\"\"\n\n def __init__(self, *args, data=None, **kwargs):\n \"\"\"\n Initialize the RTDETRDataset class by inheriting from the YOLODataset class.\n\n This constructor sets up a dataset specifically optimized for the RT-DETR (Real-Time DEtection and TRacking)\n model, building upon the base YOLODataset functionality.\n\n Args:\n *args (Any): Variable length argument list passed to the parent YOLODataset class.\n data (dict | None): Dictionary containing dataset information. If None, default values will be used.\n **kwargs (Any): Additional keyword arguments passed to the parent YOLODataset class.\n \"\"\"\n super().__init__(*args, data=data, **kwargs)\n\n def load_image(self, i, rect_mode=False):\n \"\"\"\n Load one image from dataset index 'i'.\n\n Args:\n i (int): Index of the image to load.\n rect_mode (bool, optional): Whether to use rectangular mode for batch inference.\n\n Returns:\n im (torch.Tensor): The loaded image.\n resized_hw (tuple): Height and width of the resized image with shape (2,).\n\n Examples:\n Load an image from the dataset\n >>> dataset = RTDETRDataset(img_path=\"path/to/images\")\n >>> image, hw = dataset.load_image(0)\n \"\"\"\n return super().load_image(i=i, rect_mode=rect_mode)\n\n def build_transforms(self, hyp=None):\n \"\"\"\n Build transformation pipeline for the dataset.\n\n Args:\n hyp (dict, optional): Hyperparameters for transformations.\n\n Returns:\n (Compose): Composition of transformation functions.\n \"\"\"\n if self.augment:\n hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0\n hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0\n hyp.cutmix = hyp.cutmix if self.augment and not self.rect else 0.0\n transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)\n else:\n # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scale_fill=True)])\n transforms = Compose([])\n transforms.append(\n Format(\n bbox_format=\"xywh\",\n normalize=True,\n return_mask=self.use_segments,\n return_keypoint=self.use_keypoints,\n batch_idx=True,\n mask_ratio=hyp.mask_ratio,\n mask_overlap=hyp.overlap_mask,\n )\n )\n return transforms",
"chunk_type": "class",
"name": "RTDETRDataset",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\val.py",
"start_line": 15,
"end_line": 101,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": "Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.\n\nThis specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for\nreal-time detection and tracking tasks.\n\nAttributes:\n augment (bool): Whether to apply data augmentation.\n rect (bool): Whether to use rectangular training.\n use_segments (bool): Whether to use segmentation masks.\n use_keypoints (bool): Whether to use keypoint annotations.\n imgsz (int): Target image size for training.\n\nMethods:\n load_image: Load one image from dataset index.\n build_transforms: Build transformation pipeline for the dataset.\n\nExamples:\n Initialize an RT-DETR dataset\n >>> dataset = RTDETRDataset(img_path=\"path/to/images\", imgsz=640)\n >>> image, hw = dataset.load_image(0)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"torch",
"ultralytics.data.YOLODataset",
"ultralytics.data.augment.Compose",
"ultralytics.data.augment.Format",
"ultralytics.data.augment.v8_transforms",
"ultralytics.models.yolo.detect.DetectionValidator",
"ultralytics.utils.colorstr",
"ultralytics.utils.ops",
"YOLODataset"
],
"chunk_id": "class_RTDETRDataset_6b7a2796"
},
{
"content": "class RTDETRValidator(DetectionValidator):\n \"\"\"\n RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for\n the RT-DETR (Real-Time DETR) object detection model.\n\n The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for\n post-processing, and updates evaluation metrics accordingly.\n\n Attributes:\n args (Namespace): Configuration arguments for validation.\n data (dict): Dataset configuration dictionary.\n\n Methods:\n build_dataset: Build an RTDETR Dataset for validation.\n postprocess: Apply Non-maximum suppression to prediction outputs.\n\n Examples:\n Initialize and run RT-DETR validation\n >>> from ultralytics.models.rtdetr import RTDETRValidator\n >>> args = dict(model=\"rtdetr-l.pt\", data=\"coco8.yaml\")\n >>> validator = RTDETRValidator(args=args)\n >>> validator()\n\n Notes:\n For further details on the attributes and methods, refer to the parent DetectionValidator class.\n \"\"\"\n\n def build_dataset(self, img_path, mode=\"val\", batch=None):\n \"\"\"\n Build an RTDETR Dataset.\n\n Args:\n img_path (str): Path to the folder containing images.\n mode (str, optional): `train` mode or `val` mode, users are able to customize different augmentations for\n each mode.\n batch (int, optional): Size of batches, this is for `rect`.\n\n Returns:\n (RTDETRDataset): Dataset configured for RT-DETR validation.\n \"\"\"\n return RTDETRDataset(\n img_path=img_path,\n imgsz=self.args.imgsz,\n batch_size=batch,\n augment=False, # no augmentation\n hyp=self.args,\n rect=False, # no rect\n cache=self.args.cache or None,\n prefix=colorstr(f\"{mode}: \"),\n data=self.data,\n )\n\n def postprocess(\n self, preds: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]\n ) -> List[Dict[str, torch.Tensor]]:\n \"\"\"\n Apply Non-maximum suppression to prediction outputs.\n\n Args:\n preds (torch.Tensor | List | Tuple): Raw predictions from the model. 
If tensor, should have shape\n (batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and class scores.\n\n Returns:\n (List[Dict[str, torch.Tensor]]): List of dictionaries for each image, each containing:\n - 'bboxes': Tensor of shape (N, 4) with bounding box coordinates\n - 'conf': Tensor of shape (N,) with confidence scores\n - 'cls': Tensor of shape (N,) with class indices\n \"\"\"\n if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference\n preds = [preds, None]\n\n bs, _, nd = preds[0].shape\n bboxes, scores = preds[0].split((4, nd - 4), dim=-1)\n bboxes *= self.args.imgsz\n outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs\n for i, bbox in enumerate(bboxes): # (300, 4)\n bbox = ops.xywh2xyxy(bbox)\n score, cls = scores[i].max(-1) # (300, )\n pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1) # filter\n # Sort by confidence to correctly get internal metrics\n pred = pred[score.argsort(descending=True)]\n outputs[i] = pred[score > self.args.conf]\n\n return [{\"bboxes\": x[:, :4], \"conf\": x[:, 4], \"cls\": x[:, 5]} for x in outputs]\n\n def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Prepare a batch for validation by applying necessary transformations.\n\n Args:\n si (int): Batch index.\n batch (Dict[str, Any]): Batch data containing images and annotations.\n\n Returns:\n (Dict[str, Any]): Prepared batch with transformed annotations containing cls, bboxes,\n ori_shape, imgsz, and ratio_pad.\n \"\"\"\n idx = batch[\"batch_idx\"] == si\n cls = batch[\"cls\"][idx].squeeze(-1)\n bbox = batch[\"bboxes\"][idx]\n ori_shape = batch[\"ori_shape\"][si]\n imgsz = batch[\"img\"].shape[2:]\n ratio_pad = batch[\"ratio_pad\"][si]\n if len(cls):\n bbox = ops.xywh2xyxy(bbox) # target boxes\n bbox[..., [0, 2]] *= ori_shape[1] # native-space pred\n bbox[..., [1, 3]] *= ori_shape[0] # native-space pred\n return {\"cls\": cls, \"bboxes\": bbox, \"ori_shape\": ori_shape, \"imgsz\": imgsz, \"ratio_pad\": ratio_pad}\n\n def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:\n \"\"\"\n Prepare predictions by scaling bounding boxes to original image dimensions.\n\n Args:\n pred (Dict[str, torch.Tensor]): Raw predictions containing 'cls', 'bboxes', and 'conf'.\n pbatch (Dict[str, torch.Tensor]): Prepared batch information containing 'ori_shape' and other metadata.\n\n Returns:\n (Dict[str, torch.Tensor]): Predictions scaled to original image dimensions.\n \"\"\"\n cls = pred[\"cls\"]\n if self.args.single_cls:\n cls *= 0\n bboxes = pred[\"bboxes\"].clone()\n bboxes[..., [0, 2]] *= pbatch[\"ori_shape\"][1] / self.args.imgsz # native-space pred\n bboxes[..., [1, 3]] *= pbatch[\"ori_shape\"][0] / self.args.imgsz # native-space pred\n return {\"bboxes\": bboxes, \"conf\": pred[\"conf\"], \"cls\": cls}",
"chunk_type": "class",
"name": "RTDETRValidator",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\val.py",
"start_line": 104,
"end_line": 230,
"start_col": 0,
"end_col": 67,
"parent_name": null,
"docstring": "RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for\nthe RT-DETR (Real-Time DETR) object detection model.\n\nThe class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for\npost-processing, and updates evaluation metrics accordingly.\n\nAttributes:\n args (Namespace): Configuration arguments for validation.\n data (dict): Dataset configuration dictionary.\n\nMethods:\n build_dataset: Build an RTDETR Dataset for validation.\n postprocess: Apply Non-maximum suppression to prediction outputs.\n\nExamples:\n Initialize and run RT-DETR validation\n >>> from ultralytics.models.rtdetr import RTDETRValidator\n >>> args = dict(model=\"rtdetr-l.pt\", data=\"coco8.yaml\")\n >>> validator = RTDETRValidator(args=args)\n >>> validator()\n\nNotes:\n For further details on the attributes and methods, refer to the parent DetectionValidator class.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"torch",
"ultralytics.data.YOLODataset",
"ultralytics.data.augment.Compose",
"ultralytics.data.augment.Format",
"ultralytics.data.augment.v8_transforms",
"ultralytics.models.yolo.detect.DetectionValidator",
"ultralytics.utils.colorstr",
"ultralytics.utils.ops",
"DetectionValidator"
],
"chunk_id": "class_RTDETRValidator_cb87324e"
},
{
"content": "from .model import RTDETR",
"chunk_type": "import",
"name": "RTDETR",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\__init__.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_RTDETR_a9b0b4d3"
},
{
"content": "from .predict import RTDETRPredictor",
"chunk_type": "import",
"name": "RTDETRPredictor",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\__init__.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 36,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_RTDETRPredictor_47e1edf7"
},
{
"content": "from .val import RTDETRValidator",
"chunk_type": "import",
"name": "RTDETRValidator",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\__init__.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 32,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_RTDETRValidator_8974b27e"
},
{
"content": "__all__ = \"RTDETRPredictor\", \"RTDETRValidator\", \"RTDETR\"",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\models\\rtdetr\\__init__.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 56,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___cb0272bf"
},
{
"content": "import math",
"chunk_type": "import",
"name": "math",
"file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_math_de6c83bb"
},
{
"content": "from itertools import product",
"chunk_type": "import",
"name": "product",
"file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_product_17266aff"
},
{
"content": "from typing import Any, Generator, List, Tuple",
"chunk_type": "import",
"name": "Any, Generator, List, Tuple",
"file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Generator, List, Tuple_49540f0a"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_54d8819c"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_b4c550d5"
},
{
"content": "def is_box_near_crop_edge(\n boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0\n) -> torch.Tensor:\n \"\"\"\n Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.\n\n Args:\n boxes (torch.Tensor): Bounding boxes in XYXY format.\n crop_box (List[int]): Crop box coordinates in [x0, y0, x1, y1] format.\n orig_box (List[int]): Original image box coordinates in [x0, y0, x1, y1] format.\n atol (float, optional): Absolute tolerance for edge proximity detection.\n\n Returns:\n (torch.Tensor): Boolean tensor indicating which boxes are near crop edges.\n\n Examples:\n >>> boxes = torch.tensor([[10, 10, 50, 50], [100, 100, 150, 150]])\n >>> crop_box = [0, 0, 200, 200]\n >>> orig_box = [0, 0, 300, 300]\n >>> near_edge = is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0)\n \"\"\"\n crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)\n orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)\n boxes = uncrop_boxes_xyxy(boxes, crop_box).float()\n near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)\n near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)\n near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)\n return torch.any(near_crop_edge, dim=1)",
"chunk_type": "function",
"name": "is_box_near_crop_edge",
"file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py",
"start_line": 11,
"end_line": 38,
"start_col": 0,
"end_col": 43,
"parent_name": null,
"docstring": "Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.\n\nArgs:\n boxes (torch.Tensor): Bounding boxes in XYXY format.\n crop_box (List[int]): Crop box coordinates in [x0, y0, x1, y1] format.\n orig_box (List[int]): Original image box coordinates in [x0, y0, x1, y1] format.\n atol (float, optional): Absolute tolerance for edge proximity detection.\n\nReturns:\n (torch.Tensor): Boolean tensor indicating which boxes are near crop edges.\n\nExamples:\n >>> boxes = torch.tensor([[10, 10, 50, 50], [100, 100, 150, 150]])\n >>> crop_box = [0, 0, 200, 200]\n >>> orig_box = [0, 0, 300, 300]\n >>> near_edge = is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0)",
"parameters": [
"boxes: torch.Tensor",
"crop_box: List[int]",
"orig_box: List[int]",
"atol: float"
],
"return_type": "torch.Tensor",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"math",
"itertools.product",
"typing.Any",
"typing.Generator",
"typing.List",
"typing.Tuple",
"numpy",
"torch",
"cv2"
],
"chunk_id": "function_is_box_near_crop_edge_d6d98eba"
},
{
"content": "def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:\n \"\"\"\n Yield batches of data from input arguments with specified batch size for efficient processing.\n\n This function takes a batch size and any number of iterables, then yields batches of elements from those\n iterables. All input iterables must have the same length.\n\n Args:\n batch_size (int): Size of each batch to yield.\n *args (Any): Variable length input iterables to batch. All iterables must have the same length.\n\n Yields:\n (List[Any]): A list of batched elements from each input iterable.\n\n Examples:\n >>> data = [1, 2, 3, 4, 5]\n >>> labels = [\"a\", \"b\", \"c\", \"d\", \"e\"]\n >>> for batch in batch_iterator(2, data, labels):\n ... print(batch)\n [[1, 2], ['a', 'b']]\n [[3, 4], ['c', 'd']]\n [[5], ['e']]\n \"\"\"\n assert args and all(len(a) == len(args[0]) for a in args), \"Batched iteration must have same-size inputs.\"\n n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)\n for b in range(n_batches):\n yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]",
"chunk_type": "function",
"name": "batch_iterator",
"file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py",
"start_line": 41,
"end_line": 67,
"start_col": 0,
"end_col": 74,
"parent_name": null,
"docstring": "Yield batches of data from input arguments with specified batch size for efficient processing.\n\nThis function takes a batch size and any number of iterables, then yields batches of elements from those\niterables. All input iterables must have the same length.\n\nArgs:\n batch_size (int): Size of each batch to yield.\n *args (Any): Variable length input iterables to batch. All iterables must have the same length.\n\nYields:\n (List[Any]): A list of batched elements from each input iterable.\n\nExamples:\n >>> data = [1, 2, 3, 4, 5]\n >>> labels = [\"a\", \"b\", \"c\", \"d\", \"e\"]\n >>> for batch in batch_iterator(2, data, labels):\n ... print(batch)\n [[1, 2], ['a', 'b']]\n [[3, 4], ['c', 'd']]\n [[5], ['e']]",
"parameters": [
"batch_size: int"
],
"return_type": "Generator[List[Any], None, None]",
"decorators": [],
"complexity_score": 4,
"dependencies": [
"math",
"itertools.product",
"typing.Any",
"typing.Generator",
"typing.List",
"typing.Tuple",
"numpy",
"torch",
"cv2"
],
"chunk_id": "function_batch_iterator_e43338cb"
},
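A hedged usage sketch for `batch_iterator`, mirroring the docstring example recorded above (illustrative, not part of the indexed source):

```python
from ultralytics.models.sam.amg import batch_iterator

data = [1, 2, 3, 4, 5]
labels = ["a", "b", "c", "d", "e"]

# Each yielded batch is a list holding one same-sized slice per input iterable.
for batch_data, batch_labels in batch_iterator(2, data, labels):
    print(batch_data, batch_labels)
# [1, 2] ['a', 'b']
# [3, 4] ['c', 'd']
# [5] ['e']
```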
{
"content": "def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:\n \"\"\"\n Compute the stability score for a batch of masks.\n\n The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at\n high and low values.\n\n Args:\n masks (torch.Tensor): Batch of predicted mask logits.\n mask_threshold (float): Threshold value for creating binary masks.\n threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.\n\n Returns:\n (torch.Tensor): Stability scores for each mask in the batch.\n\n Notes:\n - One mask is always contained inside the other.\n - Memory is saved by preventing unnecessary cast to torch.int64.\n\n Examples:\n >>> masks = torch.rand(10, 256, 256) # Batch of 10 masks\n >>> mask_threshold = 0.5\n >>> threshold_offset = 0.1\n >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)\n \"\"\"\n intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)\n unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)\n return intersections / unions",
"chunk_type": "function",
"name": "calculate_stability_score",
"file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py",
"start_line": 70,
"end_line": 97,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": "Compute the stability score for a batch of masks.\n\nThe stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at\nhigh and low values.\n\nArgs:\n masks (torch.Tensor): Batch of predicted mask logits.\n mask_threshold (float): Threshold value for creating binary masks.\n threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.\n\nReturns:\n (torch.Tensor): Stability scores for each mask in the batch.\n\nNotes:\n - One mask is always contained inside the other.\n - Memory is saved by preventing unnecessary cast to torch.int64.\n\nExamples:\n >>> masks = torch.rand(10, 256, 256) # Batch of 10 masks\n >>> mask_threshold = 0.5\n >>> threshold_offset = 0.1\n >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)",
"parameters": [
"masks: torch.Tensor",
"mask_threshold: float",
"threshold_offset: float"
],
"return_type": "torch.Tensor",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"math",
"itertools.product",
"typing.Any",
"typing.Generator",
"typing.List",
"typing.Tuple",
"numpy",
"torch",
"cv2"
],
"chunk_id": "function_calculate_stability_score_eb844f5d"
},
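A hedged usage sketch for `calculate_stability_score` (illustrative, not part of the indexed source); the random logits and threshold values below are placeholders chosen only to show shapes and the IoU interpretation.

```python
import torch
from ultralytics.models.sam.amg import calculate_stability_score

mask_logits = torch.randn(4, 256, 256)  # placeholder batch of predicted mask logits
# IoU between the tighter (> +1.0) and looser (> -1.0) binarizations around threshold 0.0
scores = calculate_stability_score(mask_logits, mask_threshold=0.0, threshold_offset=1.0)
print(scores.shape)  # torch.Size([4]) -- one stability score per mask, higher is more stable
```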
{
"content": "def build_point_grid(n_per_side: int) -> np.ndarray:\n \"\"\"Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks.\"\"\"\n offset = 1 / (2 * n_per_side)\n points_one_side = np.linspace(offset, 1 - offset, n_per_side)\n points_x = np.tile(points_one_side[None, :], (n_per_side, 1))\n points_y = np.tile(points_one_side[:, None], (1, n_per_side))\n return np.stack([points_x, points_y], axis=-1).reshape(-1, 2)",
"chunk_type": "function",
"name": "build_point_grid",
"file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py",
"start_line": 100,
"end_line": 106,
"start_col": 0,
"end_col": 65,
"parent_name": null,
"docstring": "Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks.",
"parameters": [
"n_per_side: int"
],
"return_type": "np.ndarray",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"math",
"itertools.product",
"typing.Any",
"typing.Generator",
"typing.List",
"typing.Tuple",
"numpy",
"torch",
"cv2"
],
"chunk_id": "function_build_point_grid_0b97c27c"
},
{
"content": "def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:\n \"\"\"Generate point grids for multiple crop layers with varying scales and densities.\"\"\"\n return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]",
"chunk_type": "function",
"name": "build_all_layer_point_grids",
"file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py",
"start_line": 109,
"end_line": 111,
"start_col": 0,
"end_col": 98,
"parent_name": null,
"docstring": "Generate point grids for multiple crop layers with varying scales and densities.",
"parameters": [
"n_per_side: int",
"n_layers: int",
"scale_per_layer: int"
],
"return_type": "List[np.ndarray]",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"math",
"itertools.product",
"typing.Any",
"typing.Generator",
"typing.List",
"typing.Tuple",
"numpy",
"torch",
"cv2"
],
"chunk_id": "function_build_all_layer_point_grids_e3ffef29"
},
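A hedged usage sketch for the two point-grid helpers recorded above (illustrative, not part of the indexed source); the grid sizes follow directly from the `n_per_side` and `scale_per_layer` arguments.

```python
from ultralytics.models.sam.amg import build_all_layer_point_grids, build_point_grid

grid = build_point_grid(32)  # (1024, 2) normalized points in [0, 1] x [0, 1]
layer_grids = build_all_layer_point_grids(32, n_layers=2, scale_per_layer=2)
print(grid.shape, [g.shape for g in layer_grids])
# (1024, 2) [(1024, 2), (256, 2), (64, 2)] -- one grid per crop layer, coarser at deeper layers
```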
{
"content": "def generate_crop_boxes(\n im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float\n) -> Tuple[List[List[int]], List[int]]:\n \"\"\"\n Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.\n\n Args:\n im_size (Tuple[int, ...]): Height and width of the input image.\n n_layers (int): Number of layers to generate crop boxes for.\n overlap_ratio (float): Ratio of overlap between adjacent crop boxes.\n\n Returns:\n crop_boxes (List[List[int]]): List of crop boxes in [x0, y0, x1, y1] format.\n layer_idxs (List[int]): List of layer indices corresponding to each crop box.\n\n Examples:\n >>> im_size = (800, 1200) # Height, width\n >>> n_layers = 3\n >>> overlap_ratio = 0.25\n >>> crop_boxes, layer_idxs = generate_crop_boxes(im_size, n_layers, overlap_ratio)\n \"\"\"\n crop_boxes, layer_idxs = [], []\n im_h, im_w = im_size\n short_side = min(im_h, im_w)\n\n # Original image\n crop_boxes.append([0, 0, im_w, im_h])\n layer_idxs.append(0)\n\n def crop_len(orig_len, n_crops, overlap):\n \"\"\"Calculate the length of each crop given the original length, number of crops, and overlap.\"\"\"\n return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))\n\n for i_layer in range(n_layers):\n n_crops_per_side = 2 ** (i_layer + 1)\n overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))\n\n crop_w = crop_len(im_w, n_crops_per_side, overlap)\n crop_h = crop_len(im_h, n_crops_per_side, overlap)\n\n crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]\n crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]\n\n # Crops in XYWH format\n for x0, y0 in product(crop_box_x0, crop_box_y0):\n box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]\n crop_boxes.append(box)\n layer_idxs.append(i_layer + 1)\n\n return crop_boxes, layer_idxs",
"chunk_type": "function",
"name": "generate_crop_boxes",
"file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py",
"start_line": 114,
"end_line": 163,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": "Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.\n\nArgs:\n im_size (Tuple[int, ...]): Height and width of the input image.\n n_layers (int): Number of layers to generate crop boxes for.\n overlap_ratio (float): Ratio of overlap between adjacent crop boxes.\n\nReturns:\n crop_boxes (List[List[int]]): List of crop boxes in [x0, y0, x1, y1] format.\n layer_idxs (List[int]): List of layer indices corresponding to each crop box.\n\nExamples:\n >>> im_size = (800, 1200) # Height, width\n >>> n_layers = 3\n >>> overlap_ratio = 0.25\n >>> crop_boxes, layer_idxs = generate_crop_boxes(im_size, n_layers, overlap_ratio)",
"parameters": [
"im_size: Tuple[int, ...]",
"n_layers: int",
"overlap_ratio: float"
],
"return_type": "Tuple[List[List[int]], List[int]]",
"decorators": [],
"complexity_score": 5,
"dependencies": [
"math",
"itertools.product",
"typing.Any",
"typing.Generator",
"typing.List",
"typing.Tuple",
"numpy",
"torch",
"cv2"
],
"chunk_id": "function_generate_crop_boxes_9af763b4"
},
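A hedged usage sketch for `generate_crop_boxes` (illustrative, not part of the indexed source): layer 0 is always the full image, and each additional layer contributes a 2^layer x 2^layer grid of overlapping crops.

```python
from ultralytics.models.sam.amg import generate_crop_boxes

crop_boxes, layer_idxs = generate_crop_boxes((800, 1200), n_layers=1, overlap_ratio=0.25)
print(len(crop_boxes), layer_idxs)  # 5 [0, 1, 1, 1, 1] -- full image plus four layer-1 crops
for box in crop_boxes:
    print(box)  # [x0, y0, x1, y1] crop coordinates clipped to the image bounds
```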
{
"content": "def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:\n \"\"\"Uncrop bounding boxes by adding the crop box offset to their coordinates.\"\"\"\n x0, y0, _, _ = crop_box\n offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)\n # Check if boxes has a channel dimension\n if len(boxes.shape) == 3:\n offset = offset.unsqueeze(1)\n return boxes + offset",
"chunk_type": "function",
"name": "uncrop_boxes_xyxy",
"file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py",
"start_line": 166,
"end_line": 173,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": "Uncrop bounding boxes by adding the crop box offset to their coordinates.",
"parameters": [
"boxes: torch.Tensor",
"crop_box: List[int]"
],
"return_type": "torch.Tensor",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"math",
"itertools.product",
"typing.Any",
"typing.Generator",
"typing.List",
"typing.Tuple",
"numpy",
"torch",
"cv2"
],
"chunk_id": "function_uncrop_boxes_xyxy_a4e9fac7"
},
{
"content": "def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:\n \"\"\"Uncrop points by adding the crop box offset to their coordinates.\"\"\"\n x0, y0, _, _ = crop_box\n offset = torch.tensor([[x0, y0]], device=points.device)\n # Check if points has a channel dimension\n if len(points.shape) == 3:\n offset = offset.unsqueeze(1)\n return points + offset",
"chunk_type": "function",
"name": "uncrop_points",
"file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py",
"start_line": 176,
"end_line": 183,
"start_col": 0,
"end_col": 26,
"parent_name": null,
"docstring": "Uncrop points by adding the crop box offset to their coordinates.",
"parameters": [
"points: torch.Tensor",
"crop_box: List[int]"
],
"return_type": "torch.Tensor",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"math",
"itertools.product",
"typing.Any",
"typing.Generator",
"typing.List",
"typing.Tuple",
"numpy",
"torch",
"cv2"
],
"chunk_id": "function_uncrop_points_eac1907e"
},
{
"content": "def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:\n \"\"\"Uncrop masks by padding them to the original image size, handling coordinate transformations.\"\"\"\n x0, y0, x1, y1 = crop_box\n if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:\n return masks\n # Coordinate transform masks\n pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)\n pad = (x0, pad_x - x0, y0, pad_y - y0)\n return torch.nn.functional.pad(masks, pad, value=0)",
"chunk_type": "function",
"name": "uncrop_masks",
"file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py",
"start_line": 186,
"end_line": 194,
"start_col": 0,
"end_col": 55,
"parent_name": null,
"docstring": "Uncrop masks by padding them to the original image size, handling coordinate transformations.",
"parameters": [
"masks: torch.Tensor",
"crop_box: List[int]",
"orig_h: int",
"orig_w: int"
],
"return_type": "torch.Tensor",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"math",
"itertools.product",
"typing.Any",
"typing.Generator",
"typing.List",
"typing.Tuple",
"numpy",
"torch",
"cv2"
],
"chunk_id": "function_uncrop_masks_5b18c7a2"
},
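A hedged usage sketch tying the three uncrop helpers recorded above together (illustrative, not part of the indexed source); the crop box and tensor shapes are placeholders.

```python
import torch
from ultralytics.models.sam.amg import uncrop_boxes_xyxy, uncrop_masks, uncrop_points

crop_box = [100, 50, 400, 350]             # 300 x 300 crop inside a 640 x 480 image
boxes = torch.tensor([[10, 20, 60, 80]])   # crop-local XYXY
points = torch.tensor([[30.0, 40.0]])      # crop-local XY
masks = torch.zeros(1, 300, 300, dtype=torch.bool)

print(uncrop_boxes_xyxy(boxes, crop_box))  # tensor([[110,  70, 160, 130]])
print(uncrop_points(points, crop_box))     # tensor([[130.,  90.]])
print(uncrop_masks(masks, crop_box, orig_h=480, orig_w=640).shape)  # torch.Size([1, 480, 640])
```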
{
"content": "def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:\n \"\"\"\n Remove small disconnected regions or holes in a mask based on area threshold and mode.\n\n Args:\n mask (np.ndarray): Binary mask to process.\n area_thresh (float): Area threshold below which regions will be removed.\n mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected\n regions.\n\n Returns:\n processed_mask (np.ndarray): Processed binary mask with small regions removed.\n modified (bool): Whether any regions were modified.\n\n Examples:\n >>> mask = np.zeros((100, 100), dtype=np.bool_)\n >>> mask[40:60, 40:60] = True # Create a square\n >>> mask[45:55, 45:55] = False # Create a hole\n >>> processed_mask, modified = remove_small_regions(mask, 50, \"holes\")\n \"\"\"\n import cv2 # type: ignore\n\n assert mode in {\"holes\", \"islands\"}, f\"Provided mode {mode} is invalid\"\n correct_holes = mode == \"holes\"\n working_mask = (correct_holes ^ mask).astype(np.uint8)\n n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)\n sizes = stats[:, -1][1:] # Row 0 is background label\n small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]\n if not small_regions:\n return mask, False\n fill_labels = [0] + small_regions\n if not correct_holes:\n # If every region is below threshold, keep largest\n fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]\n mask = np.isin(regions, fill_labels)\n return mask, True",
"chunk_type": "function",
"name": "remove_small_regions",
"file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py",
"start_line": 197,
"end_line": 232,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": "Remove small disconnected regions or holes in a mask based on area threshold and mode.\n\nArgs:\n mask (np.ndarray): Binary mask to process.\n area_thresh (float): Area threshold below which regions will be removed.\n mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected\n regions.\n\nReturns:\n processed_mask (np.ndarray): Processed binary mask with small regions removed.\n modified (bool): Whether any regions were modified.\n\nExamples:\n >>> mask = np.zeros((100, 100), dtype=np.bool_)\n >>> mask[40:60, 40:60] = True # Create a square\n >>> mask[45:55, 45:55] = False # Create a hole\n >>> processed_mask, modified = remove_small_regions(mask, 50, \"holes\")",
"parameters": [
"mask: np.ndarray",
"area_thresh: float",
"mode: str"
],
"return_type": "Tuple[np.ndarray, bool]",
"decorators": [],
"complexity_score": 5,
"dependencies": [
"math",
"itertools.product",
"typing.Any",
"typing.Generator",
"typing.List",
"typing.Tuple",
"numpy",
"torch",
"cv2"
],
"chunk_id": "function_remove_small_regions_8f6fa57b"
},
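A hedged usage sketch for `remove_small_regions` in "holes" mode (illustrative, not part of the indexed source); the 4 x 4 hole is below the area threshold, so it is filled and the modified flag is set.

```python
import numpy as np
from ultralytics.models.sam.amg import remove_small_regions

mask = np.zeros((100, 100), dtype=bool)
mask[40:60, 40:60] = True    # 20 x 20 square
mask[48:52, 48:52] = False   # 4 x 4 hole (area 16 < area_thresh)

filled, modified = remove_small_regions(mask, area_thresh=50, mode="holes")
print(modified, filled[48:52, 48:52].all())  # True True -- the small hole was filled
# mode="islands" would instead delete disconnected foreground blobs smaller than area_thresh.
```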
{
"content": "def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Calculate bounding boxes in XYXY format around binary masks.\n\n Args:\n masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).\n\n Returns:\n (torch.Tensor): Bounding boxes in XYXY format with shape (B, 4) or (B, C, 4).\n\n Notes:\n - Handles empty masks by returning zero boxes.\n - Preserves input tensor dimensions in the output.\n \"\"\"\n # torch.max below raises an error on empty inputs, just skip in this case\n if torch.numel(masks) == 0:\n return torch.zeros(*masks.shape[:-2], 4, device=masks.device)\n\n # Normalize shape to CxHxW\n shape = masks.shape\n h, w = shape[-2:]\n masks = masks.flatten(0, -3) if len(shape) > 2 else masks.unsqueeze(0)\n # Get top and bottom edges\n in_height, _ = torch.max(masks, dim=-1)\n in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]\n bottom_edges, _ = torch.max(in_height_coords, dim=-1)\n in_height_coords = in_height_coords + h * (~in_height)\n top_edges, _ = torch.min(in_height_coords, dim=-1)\n\n # Get left and right edges\n in_width, _ = torch.max(masks, dim=-2)\n in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]\n right_edges, _ = torch.max(in_width_coords, dim=-1)\n in_width_coords = in_width_coords + w * (~in_width)\n left_edges, _ = torch.min(in_width_coords, dim=-1)\n\n # If the mask is empty the right edge will be to the left of the left edge.\n # Replace these boxes with [0, 0, 0, 0]\n empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)\n out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)\n out = out * (~empty_filter).unsqueeze(-1)\n\n # Return to original shape\n return out.reshape(*shape[:-2], 4) if len(shape) > 2 else out[0]",
"chunk_type": "function",
"name": "batched_mask_to_box",
"file_path": "ultralytics\\ultralytics\\models\\sam\\amg.py",
"start_line": 235,
"end_line": 278,
"start_col": 0,
"end_col": 68,
"parent_name": null,
"docstring": "Calculate bounding boxes in XYXY format around binary masks.\n\nArgs:\n masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).\n\nReturns:\n (torch.Tensor): Bounding boxes in XYXY format with shape (B, 4) or (B, C, 4).\n\nNotes:\n - Handles empty masks by returning zero boxes.\n - Preserves input tensor dimensions in the output.",
"parameters": [
"masks: torch.Tensor"
],
"return_type": "torch.Tensor",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"math",
"itertools.product",
"typing.Any",
"typing.Generator",
"typing.List",
"typing.Tuple",
"numpy",
"torch",
"cv2"
],
"chunk_id": "function_batched_mask_to_box_e2200aae"
},
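A hedged usage sketch for `batched_mask_to_box` (illustrative, not part of the indexed source); one mask contains a filled rectangle and the other is left empty to show the zero-box fallback described in the Notes above.

```python
import torch
from ultralytics.models.sam.amg import batched_mask_to_box

masks = torch.zeros(2, 100, 100, dtype=torch.bool)
masks[0, 20:40, 30:60] = True  # rows 20-39, columns 30-59

boxes = batched_mask_to_box(masks)
print(boxes)
# tensor([[30, 20, 59, 39],   -- tight XYXY box around the rectangle
#         [ 0,  0,  0,  0]])  -- empty mask maps to an all-zero box
```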
{
"content": "from functools import partial",
"chunk_type": "import",
"name": "partial",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_partial_8b42cab9"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_91b55564"
},
{
"content": "from ultralytics.utils.downloads import attempt_download_asset",
"chunk_type": "import",
"name": "attempt_download_asset",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 62,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_attempt_download_asset_1a278421"
},
{
"content": "from .modules.decoders import MaskDecoder",
"chunk_type": "import",
"name": "MaskDecoder",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_MaskDecoder_37076adc"
},
{
"content": "from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder",
"chunk_type": "import",
"name": "FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 16,
"end_line": 16,
"start_col": 0,
"end_col": 105,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder_09f926dc"
},
{
"content": "from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer",
"chunk_type": "import",
"name": "MemoryAttention, MemoryAttentionLayer",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 17,
"end_line": 17,
"start_col": 0,
"end_col": 75,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_MemoryAttention, MemoryAttentionLayer_748598b9"
},
{
"content": "from .modules.sam import SAM2Model, SAMModel",
"chunk_type": "import",
"name": "SAM2Model, SAMModel",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 18,
"end_line": 18,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SAM2Model, SAMModel_41e150c1"
},
{
"content": "from .modules.tiny_encoder import TinyViT",
"chunk_type": "import",
"name": "TinyViT",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 19,
"end_line": 19,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_TinyViT_d67c140d"
},
{
"content": "from .modules.transformer import TwoWayTransformer",
"chunk_type": "import",
"name": "TwoWayTransformer",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 20,
"end_line": 20,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_TwoWayTransformer_4299f81a"
},
{
"content": "def build_sam_vit_h(checkpoint=None):\n \"\"\"Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters.\"\"\"\n return _build_sam(\n encoder_embed_dim=1280,\n encoder_depth=32,\n encoder_num_heads=16,\n encoder_global_attn_indexes=[7, 15, 23, 31],\n checkpoint=checkpoint,\n )",
"chunk_type": "function",
"name": "build_sam_vit_h",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 23,
"end_line": 31,
"start_col": 0,
"end_col": 5,
"parent_name": null,
"docstring": "Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters.",
"parameters": [
"checkpoint"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools.partial",
"torch",
"ultralytics.utils.downloads.attempt_download_asset",
"modules.decoders.MaskDecoder",
"modules.encoders.FpnNeck",
"modules.encoders.Hiera",
"modules.encoders.ImageEncoder",
"modules.encoders.ImageEncoderViT",
"modules.encoders.MemoryEncoder",
"modules.encoders.PromptEncoder",
"modules.memory_attention.MemoryAttention",
"modules.memory_attention.MemoryAttentionLayer",
"modules.sam.SAM2Model",
"modules.sam.SAMModel",
"modules.tiny_encoder.TinyViT",
"modules.transformer.TwoWayTransformer"
],
"chunk_id": "function_build_sam_vit_h_b44b4a26"
},
{
"content": "def build_sam_vit_l(checkpoint=None):\n \"\"\"Build and return a Segment Anything Model (SAM) l-size model with specified encoder parameters.\"\"\"\n return _build_sam(\n encoder_embed_dim=1024,\n encoder_depth=24,\n encoder_num_heads=16,\n encoder_global_attn_indexes=[5, 11, 17, 23],\n checkpoint=checkpoint,\n )",
"chunk_type": "function",
"name": "build_sam_vit_l",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 34,
"end_line": 42,
"start_col": 0,
"end_col": 5,
"parent_name": null,
"docstring": "Build and return a Segment Anything Model (SAM) l-size model with specified encoder parameters.",
"parameters": [
"checkpoint"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools.partial",
"torch",
"ultralytics.utils.downloads.attempt_download_asset",
"modules.decoders.MaskDecoder",
"modules.encoders.FpnNeck",
"modules.encoders.Hiera",
"modules.encoders.ImageEncoder",
"modules.encoders.ImageEncoderViT",
"modules.encoders.MemoryEncoder",
"modules.encoders.PromptEncoder",
"modules.memory_attention.MemoryAttention",
"modules.memory_attention.MemoryAttentionLayer",
"modules.sam.SAM2Model",
"modules.sam.SAMModel",
"modules.tiny_encoder.TinyViT",
"modules.transformer.TwoWayTransformer"
],
"chunk_id": "function_build_sam_vit_l_468e3fc6"
},
{
"content": "def build_sam_vit_b(checkpoint=None):\n \"\"\"Build and return a Segment Anything Model (SAM) b-size model with specified encoder parameters.\"\"\"\n return _build_sam(\n encoder_embed_dim=768,\n encoder_depth=12,\n encoder_num_heads=12,\n encoder_global_attn_indexes=[2, 5, 8, 11],\n checkpoint=checkpoint,\n )",
"chunk_type": "function",
"name": "build_sam_vit_b",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 45,
"end_line": 53,
"start_col": 0,
"end_col": 5,
"parent_name": null,
"docstring": "Build and return a Segment Anything Model (SAM) b-size model with specified encoder parameters.",
"parameters": [
"checkpoint"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools.partial",
"torch",
"ultralytics.utils.downloads.attempt_download_asset",
"modules.decoders.MaskDecoder",
"modules.encoders.FpnNeck",
"modules.encoders.Hiera",
"modules.encoders.ImageEncoder",
"modules.encoders.ImageEncoderViT",
"modules.encoders.MemoryEncoder",
"modules.encoders.PromptEncoder",
"modules.memory_attention.MemoryAttention",
"modules.memory_attention.MemoryAttentionLayer",
"modules.sam.SAM2Model",
"modules.sam.SAMModel",
"modules.tiny_encoder.TinyViT",
"modules.transformer.TwoWayTransformer"
],
"chunk_id": "function_build_sam_vit_b_1f501484"
},
{
"content": "def build_mobile_sam(checkpoint=None):\n \"\"\"Build and return a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation.\"\"\"\n return _build_sam(\n encoder_embed_dim=[64, 128, 160, 320],\n encoder_depth=[2, 2, 6, 2],\n encoder_num_heads=[2, 4, 5, 10],\n encoder_global_attn_indexes=None,\n mobile_sam=True,\n checkpoint=checkpoint,\n )",
"chunk_type": "function",
"name": "build_mobile_sam",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 56,
"end_line": 65,
"start_col": 0,
"end_col": 5,
"parent_name": null,
"docstring": "Build and return a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation.",
"parameters": [
"checkpoint"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools.partial",
"torch",
"ultralytics.utils.downloads.attempt_download_asset",
"modules.decoders.MaskDecoder",
"modules.encoders.FpnNeck",
"modules.encoders.Hiera",
"modules.encoders.ImageEncoder",
"modules.encoders.ImageEncoderViT",
"modules.encoders.MemoryEncoder",
"modules.encoders.PromptEncoder",
"modules.memory_attention.MemoryAttention",
"modules.memory_attention.MemoryAttentionLayer",
"modules.sam.SAM2Model",
"modules.sam.SAMModel",
"modules.tiny_encoder.TinyViT",
"modules.transformer.TwoWayTransformer"
],
"chunk_id": "function_build_mobile_sam_f102aafa"
},
{
"content": "def build_sam2_t(checkpoint=None):\n \"\"\"Build and return a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters.\"\"\"\n return _build_sam2(\n encoder_embed_dim=96,\n encoder_stages=[1, 2, 7, 2],\n encoder_num_heads=1,\n encoder_global_att_blocks=[5, 7, 9],\n encoder_window_spec=[8, 4, 14, 7],\n encoder_backbone_channel_list=[768, 384, 192, 96],\n checkpoint=checkpoint,\n )",
"chunk_type": "function",
"name": "build_sam2_t",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 68,
"end_line": 78,
"start_col": 0,
"end_col": 5,
"parent_name": null,
"docstring": "Build and return a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters.",
"parameters": [
"checkpoint"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools.partial",
"torch",
"ultralytics.utils.downloads.attempt_download_asset",
"modules.decoders.MaskDecoder",
"modules.encoders.FpnNeck",
"modules.encoders.Hiera",
"modules.encoders.ImageEncoder",
"modules.encoders.ImageEncoderViT",
"modules.encoders.MemoryEncoder",
"modules.encoders.PromptEncoder",
"modules.memory_attention.MemoryAttention",
"modules.memory_attention.MemoryAttentionLayer",
"modules.sam.SAM2Model",
"modules.sam.SAMModel",
"modules.tiny_encoder.TinyViT",
"modules.transformer.TwoWayTransformer"
],
"chunk_id": "function_build_sam2_t_9c0c8531"
},
{
"content": "def build_sam2_s(checkpoint=None):\n \"\"\"Build and return a small-size Segment Anything Model 2 (SAM2) with specified architecture parameters.\"\"\"\n return _build_sam2(\n encoder_embed_dim=96,\n encoder_stages=[1, 2, 11, 2],\n encoder_num_heads=1,\n encoder_global_att_blocks=[7, 10, 13],\n encoder_window_spec=[8, 4, 14, 7],\n encoder_backbone_channel_list=[768, 384, 192, 96],\n checkpoint=checkpoint,\n )",
"chunk_type": "function",
"name": "build_sam2_s",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 81,
"end_line": 91,
"start_col": 0,
"end_col": 5,
"parent_name": null,
"docstring": "Build and return a small-size Segment Anything Model 2 (SAM2) with specified architecture parameters.",
"parameters": [
"checkpoint"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools.partial",
"torch",
"ultralytics.utils.downloads.attempt_download_asset",
"modules.decoders.MaskDecoder",
"modules.encoders.FpnNeck",
"modules.encoders.Hiera",
"modules.encoders.ImageEncoder",
"modules.encoders.ImageEncoderViT",
"modules.encoders.MemoryEncoder",
"modules.encoders.PromptEncoder",
"modules.memory_attention.MemoryAttention",
"modules.memory_attention.MemoryAttentionLayer",
"modules.sam.SAM2Model",
"modules.sam.SAMModel",
"modules.tiny_encoder.TinyViT",
"modules.transformer.TwoWayTransformer"
],
"chunk_id": "function_build_sam2_s_1919150f"
},
{
"content": "def build_sam2_b(checkpoint=None):\n \"\"\"Build and return a Segment Anything Model 2 (SAM2) base-size model with specified architecture parameters.\"\"\"\n return _build_sam2(\n encoder_embed_dim=112,\n encoder_stages=[2, 3, 16, 3],\n encoder_num_heads=2,\n encoder_global_att_blocks=[12, 16, 20],\n encoder_window_spec=[8, 4, 14, 7],\n encoder_window_spatial_size=[14, 14],\n encoder_backbone_channel_list=[896, 448, 224, 112],\n checkpoint=checkpoint,\n )",
"chunk_type": "function",
"name": "build_sam2_b",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 94,
"end_line": 105,
"start_col": 0,
"end_col": 5,
"parent_name": null,
"docstring": "Build and return a Segment Anything Model 2 (SAM2) base-size model with specified architecture parameters.",
"parameters": [
"checkpoint"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools.partial",
"torch",
"ultralytics.utils.downloads.attempt_download_asset",
"modules.decoders.MaskDecoder",
"modules.encoders.FpnNeck",
"modules.encoders.Hiera",
"modules.encoders.ImageEncoder",
"modules.encoders.ImageEncoderViT",
"modules.encoders.MemoryEncoder",
"modules.encoders.PromptEncoder",
"modules.memory_attention.MemoryAttention",
"modules.memory_attention.MemoryAttentionLayer",
"modules.sam.SAM2Model",
"modules.sam.SAMModel",
"modules.tiny_encoder.TinyViT",
"modules.transformer.TwoWayTransformer"
],
"chunk_id": "function_build_sam2_b_ad8831d6"
},
{
"content": "def build_sam2_l(checkpoint=None):\n \"\"\"Build and return a large-size Segment Anything Model 2 (SAM2) with specified architecture parameters.\"\"\"\n return _build_sam2(\n encoder_embed_dim=144,\n encoder_stages=[2, 6, 36, 4],\n encoder_num_heads=2,\n encoder_global_att_blocks=[23, 33, 43],\n encoder_window_spec=[8, 4, 16, 8],\n encoder_backbone_channel_list=[1152, 576, 288, 144],\n checkpoint=checkpoint,\n )",
"chunk_type": "function",
"name": "build_sam2_l",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 108,
"end_line": 118,
"start_col": 0,
"end_col": 5,
"parent_name": null,
"docstring": "Build and return a large-size Segment Anything Model 2 (SAM2) with specified architecture parameters.",
"parameters": [
"checkpoint"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"functools.partial",
"torch",
"ultralytics.utils.downloads.attempt_download_asset",
"modules.decoders.MaskDecoder",
"modules.encoders.FpnNeck",
"modules.encoders.Hiera",
"modules.encoders.ImageEncoder",
"modules.encoders.ImageEncoderViT",
"modules.encoders.MemoryEncoder",
"modules.encoders.PromptEncoder",
"modules.memory_attention.MemoryAttention",
"modules.memory_attention.MemoryAttentionLayer",
"modules.sam.SAM2Model",
"modules.sam.SAMModel",
"modules.tiny_encoder.TinyViT",
"modules.transformer.TwoWayTransformer"
],
"chunk_id": "function_build_sam2_l_b5bf2557"
},
{
"content": "def _build_sam(\n encoder_embed_dim,\n encoder_depth,\n encoder_num_heads,\n encoder_global_attn_indexes,\n checkpoint=None,\n mobile_sam=False,\n):\n \"\"\"\n Build a Segment Anything Model (SAM) with specified encoder parameters.\n\n Args:\n encoder_embed_dim (int | List[int]): Embedding dimension for the encoder.\n encoder_depth (int | List[int]): Depth of the encoder.\n encoder_num_heads (int | List[int]): Number of attention heads in the encoder.\n encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder.\n checkpoint (str | None, optional): Path to the model checkpoint file.\n mobile_sam (bool, optional): Whether to build a Mobile-SAM model.\n\n Returns:\n (SAMModel): A Segment Anything Model instance with the specified architecture.\n\n Examples:\n >>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11])\n >>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True)\n \"\"\"\n prompt_embed_dim = 256\n image_size = 1024\n vit_patch_size = 16\n image_embedding_size = image_size // vit_patch_size\n image_encoder = (\n TinyViT(\n img_size=1024,\n in_chans=3,\n num_classes=1000,\n embed_dims=encoder_embed_dim,\n depths=encoder_depth,\n num_heads=encoder_num_heads,\n window_sizes=[7, 7, 14, 7],\n mlp_ratio=4.0,\n drop_rate=0.0,\n drop_path_rate=0.0,\n use_checkpoint=False,\n mbconv_expand_ratio=4.0,\n local_conv_size=3,\n layer_lr_decay=0.8,\n )\n if mobile_sam\n else ImageEncoderViT(\n depth=encoder_depth,\n embed_dim=encoder_embed_dim,\n img_size=image_size,\n mlp_ratio=4,\n norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),\n num_heads=encoder_num_heads,\n patch_size=vit_patch_size,\n qkv_bias=True,\n use_rel_pos=True,\n global_attn_indexes=encoder_global_attn_indexes,\n window_size=14,\n out_chans=prompt_embed_dim,\n )\n )\n sam = SAMModel(\n image_encoder=image_encoder,\n prompt_encoder=PromptEncoder(\n embed_dim=prompt_embed_dim,\n image_embedding_size=(image_embedding_size, image_embedding_size),\n input_image_size=(image_size, image_size),\n mask_in_chans=16,\n ),\n mask_decoder=MaskDecoder(\n num_multimask_outputs=3,\n transformer=TwoWayTransformer(\n depth=2,\n embedding_dim=prompt_embed_dim,\n mlp_dim=2048,\n num_heads=8,\n ),\n transformer_dim=prompt_embed_dim,\n iou_head_depth=3,\n iou_head_hidden_dim=256,\n ),\n pixel_mean=[123.675, 116.28, 103.53],\n pixel_std=[58.395, 57.12, 57.375],\n )\n if checkpoint is not None:\n checkpoint = attempt_download_asset(checkpoint)\n with open(checkpoint, \"rb\") as f:\n state_dict = torch.load(f)\n sam.load_state_dict(state_dict)\n sam.eval()\n return sam",
"chunk_type": "function",
"name": "_build_sam",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 121,
"end_line": 213,
"start_col": 0,
"end_col": 14,
"parent_name": null,
"docstring": "Build a Segment Anything Model (SAM) with specified encoder parameters.\n\nArgs:\n encoder_embed_dim (int | List[int]): Embedding dimension for the encoder.\n encoder_depth (int | List[int]): Depth of the encoder.\n encoder_num_heads (int | List[int]): Number of attention heads in the encoder.\n encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder.\n checkpoint (str | None, optional): Path to the model checkpoint file.\n mobile_sam (bool, optional): Whether to build a Mobile-SAM model.\n\nReturns:\n (SAMModel): A Segment Anything Model instance with the specified architecture.\n\nExamples:\n >>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11])\n >>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True)",
"parameters": [
"encoder_embed_dim",
"encoder_depth",
"encoder_num_heads",
"encoder_global_attn_indexes",
"checkpoint",
"mobile_sam"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"functools.partial",
"torch",
"ultralytics.utils.downloads.attempt_download_asset",
"modules.decoders.MaskDecoder",
"modules.encoders.FpnNeck",
"modules.encoders.Hiera",
"modules.encoders.ImageEncoder",
"modules.encoders.ImageEncoderViT",
"modules.encoders.MemoryEncoder",
"modules.encoders.PromptEncoder",
"modules.memory_attention.MemoryAttention",
"modules.memory_attention.MemoryAttentionLayer",
"modules.sam.SAM2Model",
"modules.sam.SAMModel",
"modules.tiny_encoder.TinyViT",
"modules.transformer.TwoWayTransformer"
],
"chunk_id": "function__build_sam_ff65f01d"
},
{
"content": "def _build_sam2(\n encoder_embed_dim=1280,\n encoder_stages=[2, 6, 36, 4],\n encoder_num_heads=2,\n encoder_global_att_blocks=[7, 15, 23, 31],\n encoder_backbone_channel_list=[1152, 576, 288, 144],\n encoder_window_spatial_size=[7, 7],\n encoder_window_spec=[8, 4, 16, 8],\n checkpoint=None,\n):\n \"\"\"\n Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.\n\n Args:\n encoder_embed_dim (int, optional): Embedding dimension for the encoder.\n encoder_stages (List[int], optional): Number of blocks in each stage of the encoder.\n encoder_num_heads (int, optional): Number of attention heads in the encoder.\n encoder_global_att_blocks (List[int], optional): Indices of global attention blocks in the encoder.\n encoder_backbone_channel_list (List[int], optional): Channel dimensions for each level of the encoder backbone.\n encoder_window_spatial_size (List[int], optional): Spatial size of the window for position embeddings.\n encoder_window_spec (List[int], optional): Window specifications for each stage of the encoder.\n checkpoint (str | None, optional): Path to the checkpoint file for loading pre-trained weights.\n\n Returns:\n (SAM2Model): A configured and initialized SAM2 model.\n\n Examples:\n >>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])\n >>> sam2_model.eval()\n \"\"\"\n image_encoder = ImageEncoder(\n trunk=Hiera(\n embed_dim=encoder_embed_dim,\n num_heads=encoder_num_heads,\n stages=encoder_stages,\n global_att_blocks=encoder_global_att_blocks,\n window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,\n window_spec=encoder_window_spec,\n ),\n neck=FpnNeck(\n d_model=256,\n backbone_channel_list=encoder_backbone_channel_list,\n fpn_top_down_levels=[2, 3],\n fpn_interp_model=\"nearest\",\n ),\n scalp=1,\n )\n memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())\n memory_encoder = MemoryEncoder(out_dim=64)\n\n is_sam2_1 = checkpoint is not None and \"sam2.1\" in checkpoint\n sam2 = SAM2Model(\n image_encoder=image_encoder,\n memory_attention=memory_attention,\n memory_encoder=memory_encoder,\n num_maskmem=7,\n image_size=1024,\n sigmoid_scale_for_mem_enc=20.0,\n sigmoid_bias_for_mem_enc=-10.0,\n use_mask_input_as_output_without_sam=True,\n directly_add_no_mem_embed=True,\n use_high_res_features_in_sam=True,\n multimask_output_in_sam=True,\n iou_prediction_use_sigmoid=True,\n use_obj_ptrs_in_encoder=True,\n add_tpos_enc_to_obj_ptrs=True,\n only_obj_ptrs_in_the_past_for_eval=True,\n pred_obj_scores=True,\n pred_obj_scores_mlp=True,\n fixed_no_obj_ptr=True,\n multimask_output_for_tracking=True,\n use_multimask_token_for_obj_ptr=True,\n multimask_min_pt_num=0,\n multimask_max_pt_num=1,\n use_mlp_for_obj_ptr_proj=True,\n compile_image_encoder=False,\n no_obj_embed_spatial=is_sam2_1,\n proj_tpos_enc_in_obj_ptrs=is_sam2_1,\n use_signed_tpos_enc_to_obj_ptrs=is_sam2_1,\n sam_mask_decoder_extra_args=dict(\n dynamic_multimask_via_stability=True,\n dynamic_multimask_stability_delta=0.05,\n dynamic_multimask_stability_thresh=0.98,\n ),\n )\n\n if checkpoint is not None:\n checkpoint = attempt_download_asset(checkpoint)\n with open(checkpoint, \"rb\") as f:\n state_dict = torch.load(f)[\"model\"]\n sam2.load_state_dict(state_dict)\n sam2.eval()\n return sam2",
"chunk_type": "function",
"name": "_build_sam2",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 216,
"end_line": 308,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": "Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.\n\nArgs:\n encoder_embed_dim (int, optional): Embedding dimension for the encoder.\n encoder_stages (List[int], optional): Number of blocks in each stage of the encoder.\n encoder_num_heads (int, optional): Number of attention heads in the encoder.\n encoder_global_att_blocks (List[int], optional): Indices of global attention blocks in the encoder.\n encoder_backbone_channel_list (List[int], optional): Channel dimensions for each level of the encoder backbone.\n encoder_window_spatial_size (List[int], optional): Spatial size of the window for position embeddings.\n encoder_window_spec (List[int], optional): Window specifications for each stage of the encoder.\n checkpoint (str | None, optional): Path to the checkpoint file for loading pre-trained weights.\n\nReturns:\n (SAM2Model): A configured and initialized SAM2 model.\n\nExamples:\n >>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])\n >>> sam2_model.eval()",
"parameters": [
"encoder_embed_dim",
"encoder_stages",
"encoder_num_heads",
"encoder_global_att_blocks",
"encoder_backbone_channel_list",
"encoder_window_spatial_size",
"encoder_window_spec",
"checkpoint"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"functools.partial",
"torch",
"ultralytics.utils.downloads.attempt_download_asset",
"modules.decoders.MaskDecoder",
"modules.encoders.FpnNeck",
"modules.encoders.Hiera",
"modules.encoders.ImageEncoder",
"modules.encoders.ImageEncoderViT",
"modules.encoders.MemoryEncoder",
"modules.encoders.PromptEncoder",
"modules.memory_attention.MemoryAttention",
"modules.memory_attention.MemoryAttentionLayer",
"modules.sam.SAM2Model",
"modules.sam.SAMModel",
"modules.tiny_encoder.TinyViT",
"modules.transformer.TwoWayTransformer"
],
"chunk_id": "function__build_sam2_f613f9c0"
},
{
"content": "sam_model_map = {\n \"sam_h.pt\": build_sam_vit_h,\n \"sam_l.pt\": build_sam_vit_l,\n \"sam_b.pt\": build_sam_vit_b,\n \"mobile_sam.pt\": build_mobile_sam,\n \"sam2_t.pt\": build_sam2_t,\n \"sam2_s.pt\": build_sam2_s,\n \"sam2_b.pt\": build_sam2_b,\n \"sam2_l.pt\": build_sam2_l,\n \"sam2.1_t.pt\": build_sam2_t,\n \"sam2.1_s.pt\": build_sam2_s,\n \"sam2.1_b.pt\": build_sam2_b,\n \"sam2.1_l.pt\": build_sam2_l,\n}",
"chunk_type": "variable",
"name": "sam_model_map",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 311,
"end_line": 324,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_sam_model_map_98550379"
},
{
"content": "def build_sam(ckpt=\"sam_b.pt\"):\n \"\"\"\n Build and return a Segment Anything Model (SAM) based on the provided checkpoint.\n\n Args:\n ckpt (str | Path, optional): Path to the checkpoint file or name of a pre-defined SAM model.\n\n Returns:\n (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.\n\n Raises:\n FileNotFoundError: If the provided checkpoint is not a supported SAM model.\n\n Examples:\n >>> sam_model = build_sam(\"sam_b.pt\")\n >>> sam_model = build_sam(\"path/to/custom_checkpoint.pt\")\n\n Notes:\n Supported pre-defined models include:\n - SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt'\n - SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt'\n \"\"\"\n model_builder = None\n ckpt = str(ckpt) # to allow Path ckpt types\n for k in sam_model_map.keys():\n if ckpt.endswith(k):\n model_builder = sam_model_map.get(k)\n\n if not model_builder:\n raise FileNotFoundError(f\"{ckpt} is not a supported SAM model. Available models are: \\n {sam_model_map.keys()}\")\n\n return model_builder(ckpt)",
"chunk_type": "function",
"name": "build_sam",
"file_path": "ultralytics\\ultralytics\\models\\sam\\build.py",
"start_line": 327,
"end_line": 358,
"start_col": 0,
"end_col": 30,
"parent_name": null,
"docstring": "Build and return a Segment Anything Model (SAM) based on the provided checkpoint.\n\nArgs:\n ckpt (str | Path, optional): Path to the checkpoint file or name of a pre-defined SAM model.\n\nReturns:\n (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.\n\nRaises:\n FileNotFoundError: If the provided checkpoint is not a supported SAM model.\n\nExamples:\n >>> sam_model = build_sam(\"sam_b.pt\")\n >>> sam_model = build_sam(\"path/to/custom_checkpoint.pt\")\n\nNotes:\n Supported pre-defined models include:\n - SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt'\n - SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt'",
"parameters": [
"ckpt"
],
"return_type": null,
"decorators": [],
"complexity_score": 4,
"dependencies": [
"functools.partial",
"torch",
"ultralytics.utils.downloads.attempt_download_asset",
"modules.decoders.MaskDecoder",
"modules.encoders.FpnNeck",
"modules.encoders.Hiera",
"modules.encoders.ImageEncoder",
"modules.encoders.ImageEncoderViT",
"modules.encoders.MemoryEncoder",
"modules.encoders.PromptEncoder",
"modules.memory_attention.MemoryAttention",
"modules.memory_attention.MemoryAttentionLayer",
"modules.sam.SAM2Model",
"modules.sam.SAMModel",
"modules.tiny_encoder.TinyViT",
"modules.transformer.TwoWayTransformer"
],
"chunk_id": "function_build_sam_898d7f08"
},
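A hedged usage sketch for `build_sam` (illustrative, not part of the indexed source): the checkpoint name is matched against `sam_model_map` and dispatched to the corresponding builder; missing checkpoint files are fetched by `attempt_download_asset` before the state dict is loaded, so running this downloads weights.

```python
from ultralytics.models.sam.build import build_sam

model = build_sam("mobile_sam.pt")   # dispatches to build_mobile_sam -> SAMModel
print(type(model).__name__)

# SAM2 checkpoints such as "sam2.1_b.pt" dispatch to build_sam2_b and return a SAM2Model;
# an unsupported name raises FileNotFoundError listing the supported keys.
```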
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\models\\sam\\model.py",
"start_line": 17,
"end_line": 17,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_dce6a1cc"
},
{
"content": "from typing import Dict, Type",
"chunk_type": "import",
"name": "Dict, Type",
"file_path": "ultralytics\\ultralytics\\models\\sam\\model.py",
"start_line": 18,
"end_line": 18,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Dict, Type_280e76d2"
},
{
"content": "from ultralytics.engine.model import Model",
"chunk_type": "import",
"name": "Model",
"file_path": "ultralytics\\ultralytics\\models\\sam\\model.py",
"start_line": 20,
"end_line": 20,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Model_df6ac9de"
},
{
"content": "from ultralytics.utils.torch_utils import model_info",
"chunk_type": "import",
"name": "model_info",
"file_path": "ultralytics\\ultralytics\\models\\sam\\model.py",
"start_line": 21,
"end_line": 21,
"start_col": 0,
"end_col": 52,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_model_info_9bba8c5a"
},
{
"content": "from .predict import Predictor, SAM2Predictor",
"chunk_type": "import",
"name": "Predictor, SAM2Predictor",
"file_path": "ultralytics\\ultralytics\\models\\sam\\model.py",
"start_line": 23,
"end_line": 23,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Predictor, SAM2Predictor_dcfb9e46"
},
{
"content": "class SAM(Model):\n \"\"\"\n SAM (Segment Anything Model) interface class for real-time image segmentation tasks.\n\n This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for\n promptable segmentation with versatility in image analysis. It supports various prompts such as bounding\n boxes, points, or labels, and features zero-shot performance capabilities.\n\n Attributes:\n model (torch.nn.Module): The loaded SAM model.\n is_sam2 (bool): Indicates whether the model is SAM2 variant.\n task (str): The task type, set to \"segment\" for SAM models.\n\n Methods:\n predict: Perform segmentation prediction on the given image or video source.\n info: Log information about the SAM model.\n\n Examples:\n >>> sam = SAM(\"sam_b.pt\")\n >>> results = sam.predict(\"image.jpg\", points=[[500, 375]])\n >>> for r in results:\n >>> print(f\"Detected {len(r.masks)} masks\")\n \"\"\"\n\n def __init__(self, model: str = \"sam_b.pt\") -> None:\n \"\"\"\n Initialize the SAM (Segment Anything Model) instance.\n\n Args:\n model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.\n\n Raises:\n NotImplementedError: If the model file extension is not .pt or .pth.\n\n Examples:\n >>> sam = SAM(\"sam_b.pt\")\n >>> print(sam.is_sam2)\n \"\"\"\n if model and Path(model).suffix not in {\".pt\", \".pth\"}:\n raise NotImplementedError(\"SAM prediction requires pre-trained *.pt or *.pth model.\")\n self.is_sam2 = \"sam2\" in Path(model).stem\n super().__init__(model=model, task=\"segment\")\n\n def _load(self, weights: str, task=None):\n \"\"\"\n Load the specified weights into the SAM model.\n\n Args:\n weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.\n task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.\n\n Examples:\n >>> sam = SAM(\"sam_b.pt\")\n >>> sam._load(\"path/to/custom_weights.pt\")\n \"\"\"\n from .build import build_sam # slow import\n\n self.model = build_sam(weights)\n\n def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):\n \"\"\"\n Perform segmentation prediction on the given image or video source.\n\n Args:\n source (str | PIL.Image | np.ndarray): Path to the image or video file, or a PIL.Image object, or\n a np.ndarray object.\n stream (bool): If True, enables real-time streaming.\n bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.\n points (List[List[float]] | None): List of points for prompted segmentation.\n labels (List[int] | None): List of labels for prompted segmentation.\n **kwargs (Any): Additional keyword arguments for prediction.\n\n Returns:\n (list): The model predictions.\n\n Examples:\n >>> sam = SAM(\"sam_b.pt\")\n >>> results = sam.predict(\"image.jpg\", points=[[500, 375]])\n >>> for r in results:\n ... 
print(f\"Detected {len(r.masks)} masks\")\n \"\"\"\n overrides = dict(conf=0.25, task=\"segment\", mode=\"predict\", imgsz=1024)\n kwargs = {**overrides, **kwargs}\n prompts = dict(bboxes=bboxes, points=points, labels=labels)\n return super().predict(source, stream, prompts=prompts, **kwargs)\n\n def __call__(self, source=None, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):\n \"\"\"\n Perform segmentation prediction on the given image or video source.\n\n This method is an alias for the 'predict' method, providing a convenient way to call the SAM model\n for segmentation tasks.\n\n Args:\n source (str | PIL.Image | np.ndarray | None): Path to the image or video file, or a PIL.Image\n object, or a np.ndarray object.\n stream (bool): If True, enables real-time streaming.\n bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.\n points (List[List[float]] | None): List of points for prompted segmentation.\n labels (List[int] | None): List of labels for prompted segmentation.\n **kwargs (Any): Additional keyword arguments to be passed to the predict method.\n\n Returns:\n (list): The model predictions, typically containing segmentation masks and other relevant information.\n\n Examples:\n >>> sam = SAM(\"sam_b.pt\")\n >>> results = sam(\"image.jpg\", points=[[500, 375]])\n >>> print(f\"Detected {len(results[0].masks)} masks\")\n \"\"\"\n return self.predict(source, stream, bboxes, points, labels, **kwargs)\n\n def info(self, detailed: bool = False, verbose: bool = True):\n \"\"\"\n Log information about the SAM model.\n\n Args:\n detailed (bool): If True, displays detailed information about the model layers and operations.\n verbose (bool): If True, prints the information to the console.\n\n Returns:\n (tuple): A tuple containing the model's information (string representations of the model).\n\n Examples:\n >>> sam = SAM(\"sam_b.pt\")\n >>> info = sam.info()\n >>> print(info[0]) # Print summary information\n \"\"\"\n return model_info(self.model, detailed=detailed, verbose=verbose)\n\n @property\n def task_map(self) -> Dict[str, Dict[str, Type[Predictor]]]:\n \"\"\"\n Provide a mapping from the 'segment' task to its corresponding 'Predictor'.\n\n Returns:\n (Dict[str, Dict[str, Type[Predictor]]]): A dictionary mapping the 'segment' task to its corresponding\n Predictor class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.\n\n Examples:\n >>> sam = SAM(\"sam_b.pt\")\n >>> task_map = sam.task_map\n >>> print(task_map)\n {'segment': {'predictor': }}\n \"\"\"\n return {\"segment\": {\"predictor\": SAM2Predictor if self.is_sam2 else Predictor}}",
"chunk_type": "class",
"name": "SAM",
"file_path": "ultralytics\\ultralytics\\models\\sam\\model.py",
"start_line": 26,
"end_line": 171,
"start_col": 0,
"end_col": 87,
"parent_name": null,
"docstring": "SAM (Segment Anything Model) interface class for real-time image segmentation tasks.\n\nThis class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for\npromptable segmentation with versatility in image analysis. It supports various prompts such as bounding\nboxes, points, or labels, and features zero-shot performance capabilities.\n\nAttributes:\n model (torch.nn.Module): The loaded SAM model.\n is_sam2 (bool): Indicates whether the model is SAM2 variant.\n task (str): The task type, set to \"segment\" for SAM models.\n\nMethods:\n predict: Perform segmentation prediction on the given image or video source.\n info: Log information about the SAM model.\n\nExamples:\n >>> sam = SAM(\"sam_b.pt\")\n >>> results = sam.predict(\"image.jpg\", points=[[500, 375]])\n >>> for r in results:\n >>> print(f\"Detected {len(r.masks)} masks\")",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"pathlib.Path",
"typing.Dict",
"typing.Type",
"ultralytics.engine.model.Model",
"ultralytics.utils.torch_utils.model_info",
"predict.Predictor",
"predict.SAM2Predictor",
"build.build_sam",
"Model"
],
"chunk_id": "class_SAM_9775bc2e"
},
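The class record above is the user-facing wrapper from model.py. A hedged usage sketch follows (illustrative, not part of the indexed source), assuming the usual top-level `SAM` export of the Ultralytics package and using "image.jpg" as a placeholder path.

```python
# Hedged usage sketch: point-prompted segmentation with the SAM wrapper indexed above.
from ultralytics import SAM  # assumed public export of ultralytics.models.sam.model.SAM

sam = SAM("sam_b.pt")        # a "sam2*" checkpoint name switches task_map to SAM2Predictor
sam.info()                   # log a model summary via model_info
results = sam("image.jpg", points=[[500, 375]])  # __call__ forwards to predict()
for r in results:
    print(f"Detected {len(r.masks)} masks")
```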
{
"content": "from collections import OrderedDict",
"chunk_type": "import",
"name": "OrderedDict",
"file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_OrderedDict_ac6d9e8b"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_0942c671"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_ffeb0bc7"
},
{
"content": "import torch.nn.functional as F",
"chunk_type": "import",
"name": "torch.nn.functional",
"file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn.functional_4464cd3c"
},
{
"content": "from ultralytics.data.augment import LetterBox",
"chunk_type": "import",
"name": "LetterBox",
"file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py",
"start_line": 17,
"end_line": 17,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LetterBox_76b72a38"
},
{
"content": "from ultralytics.engine.predictor import BasePredictor",
"chunk_type": "import",
"name": "BasePredictor",
"file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py",
"start_line": 18,
"end_line": 18,
"start_col": 0,
"end_col": 54,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BasePredictor_5265f07c"
},
{
"content": "from ultralytics.engine.results import Results",
"chunk_type": "import",
"name": "Results",
"file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py",
"start_line": 19,
"end_line": 19,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Results_90d1e37f"
},
{
"content": "from ultralytics.utils import DEFAULT_CFG, ops",
"chunk_type": "import",
"name": "DEFAULT_CFG, ops",
"file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py",
"start_line": 20,
"end_line": 20,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DEFAULT_CFG, ops_37bd5d37"
},
{
"content": "from ultralytics.utils.torch_utils import select_device, smart_inference_mode",
"chunk_type": "import",
"name": "select_device, smart_inference_mode",
"file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py",
"start_line": 21,
"end_line": 21,
"start_col": 0,
"end_col": 77,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_select_device, smart_inference_mode_c00f2f8b"
},
{
"content": "from .amg import (\n batch_iterator,\n batched_mask_to_box,\n build_all_layer_point_grids,\n calculate_stability_score,\n generate_crop_boxes,\n is_box_near_crop_edge,\n remove_small_regions,\n uncrop_boxes_xyxy,\n uncrop_masks,\n)",
"chunk_type": "import",
"name": "batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score, generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks",
"file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py",
"start_line": 23,
"end_line": 33,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score, generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks_682436ba"
},
{
"content": "class Predictor(BasePredictor):\n \"\"\"\n Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.\n\n This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image\n segmentation tasks. It supports various input prompts like points, bounding boxes, and masks for\n fine-grained control over segmentation results.\n\n Attributes:\n args (SimpleNamespace): Configuration arguments for the predictor.\n model (torch.nn.Module): The loaded SAM model.\n device (torch.device): The device (CPU or GPU) on which the model is loaded.\n im (torch.Tensor): The preprocessed input image.\n features (torch.Tensor): Extracted image features.\n prompts (Dict[str, Any]): Dictionary to store various types of prompts (e.g., bboxes, points, masks).\n segment_all (bool): Flag to indicate if full image segmentation should be performed.\n mean (torch.Tensor): Mean values for image normalization.\n std (torch.Tensor): Standard deviation values for image normalization.\n\n Methods:\n preprocess: Prepare input images for model inference.\n pre_transform: Perform initial transformations on the input image.\n inference: Perform segmentation inference based on input prompts.\n prompt_inference: Internal function for prompt-based segmentation inference.\n generate: Generate segmentation masks for an entire image.\n setup_model: Initialize the SAM model for inference.\n get_model: Build and return a SAM model.\n postprocess: Post-process model outputs to generate final results.\n setup_source: Set up the data source for inference.\n set_image: Set and preprocess a single image for inference.\n get_im_features: Extract image features using the SAM image encoder.\n set_prompts: Set prompts for subsequent inference.\n reset_image: Reset the current image and its features.\n remove_small_regions: Remove small disconnected regions and holes from masks.\n\n Examples:\n >>> predictor = Predictor()\n >>> predictor.setup_model(model_path=\"sam_model.pt\")\n >>> predictor.set_image(\"image.jpg\")\n >>> bboxes = [[100, 100, 200, 200]]\n >>> results = predictor(bboxes=bboxes)\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):\n \"\"\"\n Initialize the Predictor with configuration, overrides, and callbacks.\n\n Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or\n callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True\n for optimal results.\n\n Args:\n cfg (dict): Configuration dictionary containing default settings.\n overrides (dict | None): Dictionary of values to override default configuration.\n _callbacks (dict | None): Dictionary of callback functions to customize behavior.\n\n Examples:\n >>> predictor_example = Predictor(cfg=DEFAULT_CFG)\n >>> predictor_example_with_imgsz = Predictor(overrides={\"imgsz\": 640})\n >>> predictor_example_with_callback = Predictor(_callbacks={\"on_predict_start\": custom_callback})\n \"\"\"\n if overrides is None:\n overrides = {}\n overrides.update(dict(task=\"segment\", mode=\"predict\", batch=1))\n super().__init__(cfg, overrides, _callbacks)\n self.args.retina_masks = True\n self.im = None\n self.features = None\n self.prompts = {}\n self.segment_all = False\n\n def preprocess(self, im):\n \"\"\"\n Preprocess the input image for model inference.\n\n This method prepares the input image by applying transformations and normalization. 
It supports both\n torch.Tensor and list of np.ndarray as input formats.\n\n Args:\n im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.\n\n Returns:\n (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.\n\n Examples:\n >>> predictor = Predictor()\n >>> image = torch.rand(1, 3, 640, 640)\n >>> preprocessed_image = predictor.preprocess(image)\n \"\"\"\n if self.im is not None:\n return self.im\n not_tensor = not isinstance(im, torch.Tensor)\n if not_tensor:\n im = np.stack(self.pre_transform(im))\n im = im[..., ::-1].transpose((0, 3, 1, 2))\n im = np.ascontiguousarray(im)\n im = torch.from_numpy(im)\n\n im = im.to(self.device)\n im = im.half() if self.model.fp16 else im.float()\n if not_tensor:\n im = (im - self.mean) / self.std\n return im\n\n def pre_transform(self, im):\n \"\"\"\n Perform initial transformations on the input image for preprocessing.\n\n This method applies transformations such as resizing to prepare the image for further preprocessing.\n Currently, batched inference is not supported; hence the list length should be 1.\n\n Args:\n im (List[np.ndarray]): List containing a single image in HWC numpy array format.\n\n Returns:\n (List[np.ndarray]): List containing the transformed image.\n\n Raises:\n AssertionError: If the input list contains more than one image.\n\n Examples:\n >>> predictor = Predictor()\n >>> image = np.random.rand(480, 640, 3) # Single HWC image\n >>> transformed = predictor.pre_transform([image])\n >>> print(len(transformed))\n 1\n \"\"\"\n assert len(im) == 1, \"SAM model does not currently support batched inference\"\n letterbox = LetterBox(self.args.imgsz, auto=False, center=False)\n return [letterbox(image=x) for x in im]\n\n def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):\n \"\"\"\n Perform image segmentation inference based on the given input cues, using the currently loaded image.\n\n This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt\n encoder, and mask decoder for real-time and promptable segmentation tasks.\n\n Args:\n im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).\n bboxes (np.ndarray | List | None): Bounding boxes with shape (N, 4), in XYXY format.\n points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.\n labels (np.ndarray | List | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background.\n masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W). For SAM H=W=256.\n multimask_output (bool): Flag to return multiple masks. 
Helpful for ambiguous prompts.\n *args (Any): Additional positional arguments.\n **kwargs (Any): Additional keyword arguments.\n\n Returns:\n pred_masks (np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks.\n pred_scores (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.\n pred_logits (np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.\n\n Examples:\n >>> predictor = Predictor()\n >>> predictor.setup_model(model_path=\"sam_model.pt\")\n >>> predictor.set_image(\"image.jpg\")\n >>> results = predictor(bboxes=[[0, 0, 100, 100]])\n \"\"\"\n # Override prompts if any stored in self.prompts\n bboxes = self.prompts.pop(\"bboxes\", bboxes)\n points = self.prompts.pop(\"points\", points)\n masks = self.prompts.pop(\"masks\", masks)\n labels = self.prompts.pop(\"labels\", labels)\n\n if all(i is None for i in [bboxes, points, masks]):\n return self.generate(im, *args, **kwargs)\n\n return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)\n\n def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):\n \"\"\"\n Perform image segmentation inference based on input cues using SAM's specialized architecture.\n\n This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation.\n It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks.\n\n Args:\n im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).\n bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).\n points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.\n labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background.\n masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). 
For SAM, H=W=256.\n multimask_output (bool): Flag to return multiple masks for ambiguous prompts.\n\n Returns:\n pred_masks (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.\n pred_scores (np.ndarray): Quality scores predicted by the model for each mask, with length C.\n\n Examples:\n >>> predictor = Predictor()\n >>> im = torch.rand(1, 3, 1024, 1024)\n >>> bboxes = [[100, 100, 200, 200]]\n >>> masks, scores, logits = predictor.prompt_inference(im, bboxes=bboxes)\n \"\"\"\n features = self.get_im_features(im) if self.features is None else self.features\n\n bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)\n points = (points, labels) if points is not None else None\n # Embed prompts\n sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)\n\n # Predict masks\n pred_masks, pred_scores = self.model.mask_decoder(\n image_embeddings=features,\n image_pe=self.model.prompt_encoder.get_dense_pe(),\n sparse_prompt_embeddings=sparse_embeddings,\n dense_prompt_embeddings=dense_embeddings,\n multimask_output=multimask_output,\n )\n\n # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )\n # `d` could be 1 or 3 depends on `multimask_output`.\n return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)\n\n def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):\n \"\"\"\n Prepare and transform the input prompts for processing based on the destination shape.\n\n Args:\n dst_shape (tuple): The target shape (height, width) for the prompts.\n bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).\n points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.\n labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 
1 for foreground, 0 for background.\n masks (List | np.ndarray | None): Masks for the objects, where each mask is a 2D array.\n\n Returns:\n bboxes (torch.Tensor | None): Transformed bounding boxes.\n points (torch.Tensor | None): Transformed points.\n labels (torch.Tensor | None): Transformed labels.\n masks (torch.Tensor | None): Transformed masks.\n\n Raises:\n AssertionError: If the number of points don't match the number of labels, in case labels were passed.\n \"\"\"\n src_shape = self.batch[1][0].shape[:2]\n r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])\n # Transform input prompts\n if points is not None:\n points = torch.as_tensor(points, dtype=torch.float32, device=self.device)\n points = points[None] if points.ndim == 1 else points\n # Assuming labels are all positive if users don't pass labels.\n if labels is None:\n labels = np.ones(points.shape[:-1])\n labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)\n assert points.shape[-2] == labels.shape[-1], (\n f\"Number of points {points.shape[-2]} should match number of labels {labels.shape[-1]}.\"\n )\n points *= r\n if points.ndim == 2:\n # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)\n points, labels = points[:, None, :], labels[:, None]\n if bboxes is not None:\n bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)\n bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes\n bboxes *= r\n if masks is not None:\n masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)\n return bboxes, points, labels, masks\n\n def generate(\n self,\n im,\n crop_n_layers=0,\n crop_overlap_ratio=512 / 1500,\n crop_downscale_factor=1,\n point_grids=None,\n points_stride=32,\n points_batch_size=64,\n conf_thres=0.88,\n stability_score_thresh=0.95,\n stability_score_offset=0.95,\n crop_nms_thresh=0.7,\n ):\n \"\"\"\n Perform image segmentation using the Segment Anything Model (SAM).\n\n This method segments an entire image into constituent parts by leveraging SAM's advanced architecture\n and real-time performance capabilities. 
It can optionally work on image crops for finer segmentation.\n\n Args:\n im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W).\n crop_n_layers (int): Number of layers for additional mask predictions on image crops.\n crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers.\n crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer.\n point_grids (List[np.ndarray] | None): Custom grids for point sampling normalized to [0,1].\n points_stride (int): Number of points to sample along each side of the image.\n points_batch_size (int): Batch size for the number of points processed simultaneously.\n conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction.\n stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on stability.\n stability_score_offset (float): Offset value for calculating stability score.\n crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops.\n\n Returns:\n pred_masks (torch.Tensor): Segmented masks with shape (N, H, W).\n pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,).\n pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4).\n\n Examples:\n >>> predictor = Predictor()\n >>> im = torch.rand(1, 3, 1024, 1024) # Example input image\n >>> masks, scores, boxes = predictor.generate(im)\n \"\"\"\n import torchvision # scope for faster 'import ultralytics'\n\n self.segment_all = True\n ih, iw = im.shape[2:]\n crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)\n if point_grids is None:\n point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor)\n pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], []\n for crop_region, layer_idx in zip(crop_regions, layer_idxs):\n x1, y1, x2, y2 = crop_region\n w, h = x2 - x1, y2 - y1\n area = torch.tensor(w * h, device=im.device)\n points_scale = np.array([[w, h]]) # w, h\n # Crop image and interpolate to input size\n crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode=\"bilinear\", align_corners=False)\n # (num_points, 2)\n points_for_image = point_grids[layer_idx] * points_scale\n crop_masks, crop_scores, crop_bboxes = [], [], []\n for (points,) in batch_iterator(points_batch_size, points_for_image):\n pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)\n # Interpolate predicted masks to input size\n pred_mask = F.interpolate(pred_mask[None], (h, w), mode=\"bilinear\", align_corners=False)[0]\n idx = pred_score > conf_thres\n pred_mask, pred_score = pred_mask[idx], pred_score[idx]\n\n stability_score = calculate_stability_score(\n pred_mask, self.model.mask_threshold, stability_score_offset\n )\n idx = stability_score > stability_score_thresh\n pred_mask, pred_score = pred_mask[idx], pred_score[idx]\n # Bool type is much more memory-efficient.\n pred_mask = pred_mask > self.model.mask_threshold\n # (N, 4)\n pred_bbox = batched_mask_to_box(pred_mask).float()\n keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih])\n if not torch.all(keep_mask):\n pred_bbox, pred_mask, pred_score = pred_bbox[keep_mask], pred_mask[keep_mask], pred_score[keep_mask]\n\n crop_masks.append(pred_mask)\n crop_bboxes.append(pred_bbox)\n crop_scores.append(pred_score)\n\n # Do nms within this crop\n crop_masks = torch.cat(crop_masks)\n crop_bboxes = torch.cat(crop_bboxes)\n 
crop_scores = torch.cat(crop_scores)\n keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou) # NMS\n crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region)\n crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw)\n crop_scores = crop_scores[keep]\n\n pred_masks.append(crop_masks)\n pred_bboxes.append(crop_bboxes)\n pred_scores.append(crop_scores)\n region_areas.append(area.expand(len(crop_masks)))\n\n pred_masks = torch.cat(pred_masks)\n pred_bboxes = torch.cat(pred_bboxes)\n pred_scores = torch.cat(pred_scores)\n region_areas = torch.cat(region_areas)\n\n # Remove duplicate masks between crops\n if len(crop_regions) > 1:\n scores = 1 / region_areas\n keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh)\n pred_masks, pred_bboxes, pred_scores = pred_masks[keep], pred_bboxes[keep], pred_scores[keep]\n\n return pred_masks, pred_scores, pred_bboxes\n\n def setup_model(self, model=None, verbose=True):\n \"\"\"\n Initialize the Segment Anything Model (SAM) for inference.\n\n This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary\n parameters for image normalization and other Ultralytics compatibility settings.\n\n Args:\n model (torch.nn.Module | None): A pretrained SAM model. If None, a new model is built based on config.\n verbose (bool): If True, prints selected device information.\n\n Examples:\n >>> predictor = Predictor()\n >>> predictor.setup_model(model=sam_model, verbose=True)\n \"\"\"\n device = select_device(self.args.device, verbose=verbose)\n if model is None:\n model = self.get_model()\n model.eval()\n self.model = model.to(device)\n self.device = device\n self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)\n self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)\n\n # Ultralytics compatibility settings\n self.model.pt = False\n self.model.triton = False\n self.model.stride = 32\n self.model.fp16 = False\n self.done_warmup = True\n\n def get_model(self):\n \"\"\"Retrieve or build the Segment Anything Model (SAM) for image segmentation tasks.\"\"\"\n from .build import build_sam # slow import\n\n return build_sam(self.args.model)\n\n def postprocess(self, preds, img, orig_imgs):\n \"\"\"\n Post-process SAM's inference outputs to generate object detection masks and bounding boxes.\n\n This method scales masks and boxes to the original image size and applies a threshold to the mask\n predictions. 
It leverages SAM's advanced architecture for real-time, promptable segmentation tasks.\n\n Args:\n preds (tuple): The output from SAM model inference, containing:\n - pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W).\n - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1).\n - pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True.\n img (torch.Tensor): The processed input image tensor with shape (C, H, W).\n orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images.\n\n Returns:\n (List[Results]): List of Results objects containing detection masks, bounding boxes, and other\n metadata for each processed image.\n\n Examples:\n >>> predictor = Predictor()\n >>> preds = predictor.inference(img)\n >>> results = predictor.postprocess(preds, img, orig_imgs)\n \"\"\"\n # (N, 1, H, W), (N, 1)\n pred_masks, pred_scores = preds[:2]\n pred_bboxes = preds[2] if self.segment_all else None\n names = dict(enumerate(str(i) for i in range(len(pred_masks))))\n\n if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list\n orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)\n\n results = []\n for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]):\n if len(masks) == 0:\n masks, pred_bboxes = None, torch.zeros((0, 6), device=pred_masks.device)\n else:\n masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]\n masks = masks > self.model.mask_threshold # to bool\n if pred_bboxes is not None:\n pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)\n else:\n pred_bboxes = batched_mask_to_box(masks)\n # NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency.\n cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)\n pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)\n results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))\n # Reset segment-all mode.\n self.segment_all = False\n return results\n\n def setup_source(self, source):\n \"\"\"\n Set up the data source for inference.\n\n This method configures the data source from which images will be fetched for inference. It supports\n various input types such as image files, directories, video files, and other compatible data sources.\n\n Args:\n source (str | Path | None): The path or identifier for the image data source. Can be a file path,\n directory path, URL, or other supported source types.\n\n Examples:\n >>> predictor = Predictor()\n >>> predictor.setup_source(\"path/to/images\")\n >>> predictor.setup_source(\"video.mp4\")\n >>> predictor.setup_source(None) # Uses default source if available\n\n Notes:\n - If source is None, the method may use a default source if configured.\n - The method adapts to different source types and prepares them for subsequent inference steps.\n - Supported source types may include local files, directories, URLs, and video streams.\n \"\"\"\n if source is not None:\n super().setup_source(source)\n\n def set_image(self, image):\n \"\"\"\n Preprocess and set a single image for inference.\n\n This method prepares the model for inference on a single image by setting up the model if not already\n initialized, configuring the data source, and preprocessing the image for feature extraction. 
It\n ensures that only one image is set at a time and extracts image features for subsequent use.\n\n Args:\n image (str | np.ndarray): Path to the image file as a string, or a numpy array representing\n an image read by cv2.\n\n Examples:\n >>> predictor = Predictor()\n >>> predictor.set_image(\"path/to/image.jpg\")\n >>> predictor.set_image(cv2.imread(\"path/to/image.jpg\"))\n\n Raises:\n AssertionError: If more than one image is attempted to be set.\n\n Notes:\n - This method should be called before performing inference on a new image.\n - The extracted features are stored in the `self.features` attribute for later use.\n \"\"\"\n if self.model is None:\n self.setup_model(model=None)\n self.setup_source(image)\n assert len(self.dataset) == 1, \"`set_image` only supports setting one image!\"\n for batch in self.dataset:\n im = self.preprocess(batch[1])\n self.features = self.get_im_features(im)\n break\n\n def get_im_features(self, im):\n \"\"\"Extract image features using the SAM model's image encoder for subsequent mask prediction.\"\"\"\n assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (\n f\"SAM models only support square image size, but got {self.imgsz}.\"\n )\n self.model.set_imgsz(self.imgsz)\n return self.model.image_encoder(im)\n\n def set_prompts(self, prompts):\n \"\"\"Set prompts for subsequent inference operations.\"\"\"\n self.prompts = prompts\n\n def reset_image(self):\n \"\"\"Reset the current image and its features, clearing them for subsequent inference.\"\"\"\n self.im = None\n self.features = None\n\n @staticmethod\n def remove_small_regions(masks, min_area=0, nms_thresh=0.7):\n \"\"\"\n Remove small disconnected regions and holes from segmentation masks.\n\n This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM).\n It removes small disconnected regions and holes from the input masks, and then performs Non-Maximum\n Suppression (NMS) to eliminate any newly created duplicate boxes.\n\n Args:\n masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of\n masks, H is height, and W is width.\n min_area (int): Minimum area threshold for removing disconnected regions and holes. 
Regions smaller than\n this will be removed.\n nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes.\n\n Returns:\n new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).\n keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.\n\n Examples:\n >>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks\n >>> new_masks, keep = remove_small_regions(masks, min_area=100, nms_thresh=0.7)\n >>> print(f\"Original masks: {masks.shape}, Processed masks: {new_masks.shape}\")\n >>> print(f\"Indices of kept masks: {keep}\")\n \"\"\"\n import torchvision # scope for faster 'import ultralytics'\n\n if len(masks) == 0:\n return masks\n\n # Filter small disconnected regions and holes\n new_masks = []\n scores = []\n for mask in masks:\n mask = mask.cpu().numpy().astype(np.uint8)\n mask, changed = remove_small_regions(mask, min_area, mode=\"holes\")\n unchanged = not changed\n mask, changed = remove_small_regions(mask, min_area, mode=\"islands\")\n unchanged = unchanged and not changed\n\n new_masks.append(torch.as_tensor(mask).unsqueeze(0))\n # Give score=0 to changed masks and 1 to unchanged masks so NMS prefers masks not needing postprocessing\n scores.append(float(unchanged))\n\n # Recalculate boxes and remove any new duplicates\n new_masks = torch.cat(new_masks, dim=0)\n boxes = batched_mask_to_box(new_masks)\n keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh)\n\n return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep",
"chunk_type": "class",
"name": "Predictor",
"file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py",
"start_line": 36,
"end_line": 621,
"start_col": 0,
"end_col": 79,
"parent_name": null,
"docstring": "Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.\n\nThis class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image\nsegmentation tasks. It supports various input prompts like points, bounding boxes, and masks for\nfine-grained control over segmentation results.\n\nAttributes:\n args (SimpleNamespace): Configuration arguments for the predictor.\n model (torch.nn.Module): The loaded SAM model.\n device (torch.device): The device (CPU or GPU) on which the model is loaded.\n im (torch.Tensor): The preprocessed input image.\n features (torch.Tensor): Extracted image features.\n prompts (Dict[str, Any]): Dictionary to store various types of prompts (e.g., bboxes, points, masks).\n segment_all (bool): Flag to indicate if full image segmentation should be performed.\n mean (torch.Tensor): Mean values for image normalization.\n std (torch.Tensor): Standard deviation values for image normalization.\n\nMethods:\n preprocess: Prepare input images for model inference.\n pre_transform: Perform initial transformations on the input image.\n inference: Perform segmentation inference based on input prompts.\n prompt_inference: Internal function for prompt-based segmentation inference.\n generate: Generate segmentation masks for an entire image.\n setup_model: Initialize the SAM model for inference.\n get_model: Build and return a SAM model.\n postprocess: Post-process model outputs to generate final results.\n setup_source: Set up the data source for inference.\n set_image: Set and preprocess a single image for inference.\n get_im_features: Extract image features using the SAM image encoder.\n set_prompts: Set prompts for subsequent inference.\n reset_image: Reset the current image and its features.\n remove_small_regions: Remove small disconnected regions and holes from masks.\n\nExamples:\n >>> predictor = Predictor()\n >>> predictor.setup_model(model_path=\"sam_model.pt\")\n >>> predictor.set_image(\"image.jpg\")\n >>> bboxes = [[100, 100, 200, 200]]\n >>> results = predictor(bboxes=bboxes)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"collections.OrderedDict",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.data.augment.LetterBox",
"ultralytics.engine.predictor.BasePredictor",
"ultralytics.engine.results.Results",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.ops",
"ultralytics.utils.torch_utils.select_device",
"ultralytics.utils.torch_utils.smart_inference_mode",
"amg.batch_iterator",
"amg.batched_mask_to_box",
"amg.build_all_layer_point_grids",
"amg.calculate_stability_score",
"amg.generate_crop_boxes",
"amg.is_box_near_crop_edge",
"amg.remove_small_regions",
"amg.uncrop_boxes_xyxy",
"amg.uncrop_masks",
"torchvision",
"build.build_sam",
"torchvision",
"build.build_sam",
"BasePredictor"
],
"chunk_id": "class_Predictor_3cfa213f"
},
{
"content": "class SAM2Predictor(Predictor):\n \"\"\"\n SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.\n\n This class extends the base Predictor class to implement SAM2-specific functionality for image\n segmentation tasks. It provides methods for model initialization, feature extraction, and\n prompt-based inference.\n\n Attributes:\n _bb_feat_sizes (List[tuple]): Feature sizes for different backbone levels.\n model (torch.nn.Module): The loaded SAM2 model.\n device (torch.device): The device (CPU or GPU) on which the model is loaded.\n features (dict): Cached image features for efficient inference.\n segment_all (bool): Flag to indicate if all segments should be predicted.\n prompts (Dict[str, Any]): Dictionary to store various types of prompts for inference.\n\n Methods:\n get_model: Retrieve and initialize the SAM2 model.\n prompt_inference: Perform image segmentation inference based on various prompts.\n set_image: Preprocess and set a single image for inference.\n get_im_features: Extract and process image features using SAM2's image encoder.\n\n Examples:\n >>> predictor = SAM2Predictor(cfg)\n >>> predictor.set_image(\"path/to/image.jpg\")\n >>> bboxes = [[100, 100, 200, 200]]\n >>> result = predictor(bboxes=bboxes)[0]\n >>> print(f\"Predicted {len(result.masks)} masks with average score {result.boxes.conf.mean():.2f}\")\n \"\"\"\n\n _bb_feat_sizes = [\n (256, 256),\n (128, 128),\n (64, 64),\n ]\n\n def get_model(self):\n \"\"\"Retrieve and initialize the Segment Anything Model 2 (SAM2) for image segmentation tasks.\"\"\"\n from .build import build_sam # slow import\n\n return build_sam(self.args.model)\n\n def prompt_inference(\n self,\n im,\n bboxes=None,\n points=None,\n labels=None,\n masks=None,\n multimask_output=False,\n img_idx=-1,\n ):\n \"\"\"\n Perform image segmentation inference based on various prompts using SAM2 architecture.\n\n This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images\n based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and\n multi-object prediction scenarios.\n\n Args:\n im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).\n bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4).\n points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels.\n labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 
1 = foreground, 0 = background.\n masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W).\n multimask_output (bool): Flag to return multiple masks for ambiguous prompts.\n img_idx (int): Index of the image in the batch to process.\n\n Returns:\n pred_masks (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.\n pred_scores (np.ndarray): Quality scores for each mask, with length C.\n\n Examples:\n >>> predictor = SAM2Predictor(cfg)\n >>> image = torch.rand(1, 3, 640, 640)\n >>> bboxes = [[100, 100, 200, 200]]\n >>> result = predictor(image, bboxes=bboxes)[0]\n >>> print(f\"Generated {result.masks.shape[0]} masks with average score {result.boxes.conf.mean():.2f}\")\n\n Notes:\n - The method supports batched inference for multiple objects when points or bboxes are provided.\n - Input prompts (bboxes, points) are automatically scaled to match the input image dimensions.\n - When both bboxes and points are provided, they are merged into a single 'points' input for the model.\n \"\"\"\n features = self.get_im_features(im) if self.features is None else self.features\n\n points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)\n points = (points, labels) if points is not None else None\n\n sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(\n points=points,\n boxes=None,\n masks=masks,\n )\n # Predict masks\n batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction\n high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features[\"high_res_feats\"]]\n pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(\n image_embeddings=features[\"image_embed\"][img_idx].unsqueeze(0),\n image_pe=self.model.sam_prompt_encoder.get_dense_pe(),\n sparse_prompt_embeddings=sparse_embeddings,\n dense_prompt_embeddings=dense_embeddings,\n multimask_output=multimask_output,\n repeat_image=batched_mode,\n high_res_features=high_res_features,\n )\n # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )\n # `d` could be 1 or 3 depends on `multimask_output`.\n return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)\n\n def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):\n \"\"\"\n Prepare and transform the input prompts for processing based on the destination shape.\n\n Args:\n dst_shape (tuple): The target shape (height, width) for the prompts.\n bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).\n points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.\n labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 
1 for foreground, 0 for background.\n masks (List | np.ndarray | None): Masks for the objects, where each mask is a 2D array.\n\n Returns:\n points (torch.Tensor | None): Transformed points.\n labels (torch.Tensor | None): Transformed labels.\n masks (torch.Tensor | None): Transformed masks.\n\n Raises:\n AssertionError: If the number of points don't match the number of labels, in case labels were passed.\n \"\"\"\n bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)\n if bboxes is not None:\n bboxes = bboxes.view(-1, 2, 2)\n bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)\n # NOTE: merge \"boxes\" and \"points\" into a single \"points\" input\n # (where boxes are added at the beginning) to model.sam_prompt_encoder\n if points is not None:\n points = torch.cat([bboxes, points], dim=1)\n labels = torch.cat([bbox_labels, labels], dim=1)\n else:\n points, labels = bboxes, bbox_labels\n return points, labels, masks\n\n def set_image(self, image):\n \"\"\"\n Preprocess and set a single image for inference using the SAM2 model.\n\n This method initializes the model if not already done, configures the data source to the specified image,\n and preprocesses the image for feature extraction. It supports setting only one image at a time.\n\n Args:\n image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image.\n\n Examples:\n >>> predictor = SAM2Predictor()\n >>> predictor.set_image(\"path/to/image.jpg\")\n >>> predictor.set_image(np.array([...])) # Using a numpy array\n\n Raises:\n AssertionError: If more than one image is attempted to be set.\n\n Notes:\n - This method must be called before performing any inference on a new image.\n - The method caches the extracted features for efficient subsequent inferences on the same image.\n - Only one image can be set at a time. To process multiple images, call this method for each new image.\n \"\"\"\n if self.model is None:\n self.setup_model(model=None)\n self.setup_source(image)\n assert len(self.dataset) == 1, \"`set_image` only supports setting one image!\"\n for batch in self.dataset:\n im = self.preprocess(batch[1])\n self.features = self.get_im_features(im)\n break\n\n def get_im_features(self, im):\n \"\"\"Extract image features from the SAM image encoder for subsequent processing.\"\"\"\n assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (\n f\"SAM 2 models only support square image size, but got {self.imgsz}.\"\n )\n self.model.set_imgsz(self.imgsz)\n self._bb_feat_sizes = [[x // (4 * i) for x in self.imgsz] for i in [1, 2, 4]]\n\n backbone_out = self.model.forward_image(im)\n _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)\n if self.model.directly_add_no_mem_embed:\n vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed\n feats = [\n feat.permute(1, 2, 0).view(1, -1, *feat_size)\n for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])\n ][::-1]\n return {\"image_embed\": feats[-1], \"high_res_feats\": feats[:-1]}",
"chunk_type": "class",
"name": "SAM2Predictor",
"file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py",
"start_line": 624,
"end_line": 814,
"start_col": 0,
"end_col": 71,
"parent_name": null,
"docstring": "SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.\n\nThis class extends the base Predictor class to implement SAM2-specific functionality for image\nsegmentation tasks. It provides methods for model initialization, feature extraction, and\nprompt-based inference.\n\nAttributes:\n _bb_feat_sizes (List[tuple]): Feature sizes for different backbone levels.\n model (torch.nn.Module): The loaded SAM2 model.\n device (torch.device): The device (CPU or GPU) on which the model is loaded.\n features (dict): Cached image features for efficient inference.\n segment_all (bool): Flag to indicate if all segments should be predicted.\n prompts (Dict[str, Any]): Dictionary to store various types of prompts for inference.\n\nMethods:\n get_model: Retrieve and initialize the SAM2 model.\n prompt_inference: Perform image segmentation inference based on various prompts.\n set_image: Preprocess and set a single image for inference.\n get_im_features: Extract and process image features using SAM2's image encoder.\n\nExamples:\n >>> predictor = SAM2Predictor(cfg)\n >>> predictor.set_image(\"path/to/image.jpg\")\n >>> bboxes = [[100, 100, 200, 200]]\n >>> result = predictor(bboxes=bboxes)[0]\n >>> print(f\"Predicted {len(result.masks)} masks with average score {result.boxes.conf.mean():.2f}\")",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"collections.OrderedDict",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.data.augment.LetterBox",
"ultralytics.engine.predictor.BasePredictor",
"ultralytics.engine.results.Results",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.ops",
"ultralytics.utils.torch_utils.select_device",
"ultralytics.utils.torch_utils.smart_inference_mode",
"amg.batch_iterator",
"amg.batched_mask_to_box",
"amg.build_all_layer_point_grids",
"amg.calculate_stability_score",
"amg.generate_crop_boxes",
"amg.is_box_near_crop_edge",
"amg.remove_small_regions",
"amg.uncrop_boxes_xyxy",
"amg.uncrop_masks",
"torchvision",
"build.build_sam",
"torchvision",
"build.build_sam",
"Predictor"
],
"chunk_id": "class_SAM2Predictor_0545c371"
},
{
"content": "class SAM2VideoPredictor(SAM2Predictor):\n \"\"\"\n SAM2VideoPredictor to handle user interactions with videos and manage inference states.\n\n This class extends the functionality of SAM2Predictor to support video processing and maintains\n the state of inference operations. It includes configurations for managing non-overlapping masks,\n clearing memory for non-conditional inputs, and setting up callbacks for prediction events.\n\n Attributes:\n inference_state (dict): A dictionary to store the current state of inference operations.\n non_overlap_masks (bool): A flag indicating whether masks should be non-overlapping.\n clear_non_cond_mem_around_input (bool): A flag to control clearing non-conditional memory around inputs.\n clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios.\n callbacks (dict): A dictionary of callbacks for various prediction lifecycle events.\n\n Methods:\n get_model: Retrieve and configure the model with binarization enabled.\n inference: Perform image segmentation inference based on the given input cues.\n postprocess: Post-process the predictions to apply non-overlapping constraints if required.\n add_new_prompts: Add new points or masks to a specific frame for a given object ID.\n propagate_in_video_preflight: Prepare inference_state and consolidate temporary outputs before tracking.\n init_state: Initialize an inference state for the predictor.\n get_im_features: Extract and process image features using SAM2's image encoder for subsequent segmentation tasks.\n\n Examples:\n >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)\n >>> predictor.set_image(\"path/to/video_frame.jpg\")\n >>> bboxes = [[100, 100, 200, 200]]\n >>> results = predictor(bboxes=bboxes)\n\n Note:\n The `fill_hole_area` attribute is defined but not used in the current implementation.\n \"\"\"\n\n # fill_hole_area = 8 # not used\n\n def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):\n \"\"\"\n Initialize the predictor with configuration and optional overrides.\n\n This constructor initializes the SAM2VideoPredictor with a given configuration, applies any\n specified overrides, and sets up the inference state along with certain flags\n that control the behavior of the predictor.\n\n Args:\n cfg (dict): Configuration dictionary containing default settings.\n overrides (dict | None): Dictionary of values to override default configuration.\n _callbacks (dict | None): Dictionary of callback functions to customize behavior.\n\n Examples:\n >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)\n >>> predictor_example_with_imgsz = SAM2VideoPredictor(overrides={\"imgsz\": 640})\n >>> predictor_example_with_callback = SAM2VideoPredictor(_callbacks={\"on_predict_start\": custom_callback})\n \"\"\"\n super().__init__(cfg, overrides, _callbacks)\n self.inference_state = {}\n self.non_overlap_masks = True\n self.clear_non_cond_mem_around_input = False\n self.clear_non_cond_mem_for_multi_obj = False\n self.callbacks[\"on_predict_start\"].append(self.init_state)\n\n def get_model(self):\n \"\"\"\n Retrieve and configure the model with binarization enabled.\n\n Note:\n This method overrides the base class implementation to set the binarize flag to True.\n \"\"\"\n model = super().get_model()\n model.set_binarize(True)\n return model\n\n def inference(self, im, bboxes=None, points=None, labels=None, masks=None):\n \"\"\"\n Perform image segmentation inference based on the given input cues, using the currently loaded 
image. This\n method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and\n mask decoder for real-time and promptable segmentation tasks.\n\n Args:\n im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).\n bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.\n points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.\n labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.\n masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.\n\n Returns:\n pred_masks (np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks.\n pred_scores (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.\n \"\"\"\n # Override prompts if any stored in self.prompts\n bboxes = self.prompts.pop(\"bboxes\", bboxes)\n points = self.prompts.pop(\"points\", points)\n masks = self.prompts.pop(\"masks\", masks)\n\n frame = self.dataset.frame\n self.inference_state[\"im\"] = im\n output_dict = self.inference_state[\"output_dict\"]\n if len(output_dict[\"cond_frame_outputs\"]) == 0: # initialize prompts\n points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)\n if points is not None:\n for i in range(len(points)):\n self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame)\n elif masks is not None:\n for i in range(len(masks)):\n self.add_new_prompts(obj_id=i, masks=masks[[i]], frame_idx=frame)\n self.propagate_in_video_preflight()\n\n consolidated_frame_inds = self.inference_state[\"consolidated_frame_inds\"]\n batch_size = len(self.inference_state[\"obj_idx_to_id\"])\n if len(output_dict[\"cond_frame_outputs\"]) == 0:\n raise RuntimeError(\"No points are provided; please add points first\")\n\n if frame in consolidated_frame_inds[\"cond_frame_outputs\"]:\n storage_key = \"cond_frame_outputs\"\n current_out = output_dict[storage_key][frame]\n if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):\n # clear non-conditioning memory of the surrounding frames\n self._clear_non_cond_mem_around_input(frame)\n elif frame in consolidated_frame_inds[\"non_cond_frame_outputs\"]:\n storage_key = \"non_cond_frame_outputs\"\n current_out = output_dict[storage_key][frame]\n else:\n storage_key = \"non_cond_frame_outputs\"\n current_out = self._run_single_frame_inference(\n output_dict=output_dict,\n frame_idx=frame,\n batch_size=batch_size,\n is_init_cond_frame=False,\n point_inputs=None,\n mask_inputs=None,\n reverse=False,\n run_mem_encoder=True,\n )\n output_dict[storage_key][frame] = current_out\n # Create slices of per-object outputs for subsequent interaction with each\n # individual object after tracking.\n self._add_output_per_object(frame, current_out, storage_key)\n self.inference_state[\"frames_already_tracked\"].append(frame)\n pred_masks = current_out[\"pred_masks\"].flatten(0, 1)\n pred_masks = pred_masks[(pred_masks > self.model.mask_threshold).sum((1, 2)) > 0] # filter blank masks\n\n return pred_masks, torch.ones(len(pred_masks), dtype=pred_masks.dtype, device=pred_masks.device)\n\n def postprocess(self, preds, img, orig_imgs):\n \"\"\"\n Post-process the predictions to apply non-overlapping constraints if required.\n\n This method extends the post-processing 
functionality by applying non-overlapping constraints\n to the predicted masks if the `non_overlap_masks` flag is set to True. This ensures that\n the masks do not overlap, which can be useful for certain applications.\n\n Args:\n preds (tuple): The predictions from the model.\n img (torch.Tensor): The processed image tensor.\n orig_imgs (List[np.ndarray]): The original images before processing.\n\n Returns:\n (list): The post-processed predictions.\n\n Note:\n If `non_overlap_masks` is True, the method applies constraints to ensure non-overlapping masks.\n \"\"\"\n results = super().postprocess(preds, img, orig_imgs)\n if self.non_overlap_masks:\n for result in results:\n if result.masks is None or len(result.masks) == 0:\n continue\n result.masks.data = self.model._apply_non_overlapping_constraints(result.masks.data.unsqueeze(0))[0]\n return results\n\n @smart_inference_mode()\n def add_new_prompts(\n self,\n obj_id,\n points=None,\n labels=None,\n masks=None,\n frame_idx=0,\n ):\n \"\"\"\n Add new points or masks to a specific frame for a given object ID.\n\n This method updates the inference state with new prompts (points or masks) for a specified\n object and frame index. It ensures that the prompts are either points or masks, but not both,\n and updates the internal state accordingly. It also handles the generation of new segmentations\n based on the provided prompts and the existing state.\n\n Args:\n obj_id (int): The ID of the object to which the prompts are associated.\n points (torch.Tensor, optional): The coordinates of the points of interest.\n labels (torch.Tensor, optional): The labels corresponding to the points.\n masks (torch.Tensor, optional): Binary masks for the object.\n frame_idx (int, optional): The index of the frame to which the prompts are applied.\n\n Returns:\n pred_masks (torch.Tensor): The flattened predicted masks.\n pred_scores (torch.Tensor): A tensor of ones indicating the number of objects.\n\n Raises:\n AssertionError: If both `masks` and `points` are provided, or neither is provided.\n\n Note:\n - Only one type of prompt (either points or masks) can be added per call.\n - If the frame is being tracked for the first time, it is treated as an initial conditioning frame.\n - The method handles the consolidation of outputs and resizing of masks to the original video resolution.\n \"\"\"\n assert (masks is None) ^ (points is None), \"'masks' and 'points' prompts are not compatible with each other.\"\n obj_idx = self._obj_id_to_idx(obj_id)\n\n point_inputs = None\n pop_key = \"point_inputs_per_obj\"\n if points is not None:\n point_inputs = {\"point_coords\": points, \"point_labels\": labels}\n self.inference_state[\"point_inputs_per_obj\"][obj_idx][frame_idx] = point_inputs\n pop_key = \"mask_inputs_per_obj\"\n self.inference_state[\"mask_inputs_per_obj\"][obj_idx][frame_idx] = masks\n self.inference_state[pop_key][obj_idx].pop(frame_idx, None)\n # If this frame hasn't been tracked before, we treat it as an initial conditioning\n # frame, meaning that the inputs points are to generate segments on this frame without\n # using any memory from other frames, like in SAM. 
Otherwise (if it has been tracked),\n # the input points will be used to correct the already tracked masks.\n is_init_cond_frame = frame_idx not in self.inference_state[\"frames_already_tracked\"]\n obj_output_dict = self.inference_state[\"output_dict_per_obj\"][obj_idx]\n obj_temp_output_dict = self.inference_state[\"temp_output_dict_per_obj\"][obj_idx]\n # Add a frame to conditioning output if it's an initial conditioning frame or\n # if the model sees all frames receiving clicks/mask as conditioning frames.\n is_cond = is_init_cond_frame or self.model.add_all_frames_to_correct_as_cond\n storage_key = \"cond_frame_outputs\" if is_cond else \"non_cond_frame_outputs\"\n\n # Get any previously predicted mask logits on this object and feed it along with\n # the new clicks into the SAM mask decoder.\n prev_sam_mask_logits = None\n # lookup temporary output dict first, which contains the most recent output\n # (if not found, then lookup conditioning and non-conditioning frame output)\n if point_inputs is not None:\n prev_out = (\n obj_temp_output_dict[storage_key].get(frame_idx)\n or obj_output_dict[\"cond_frame_outputs\"].get(frame_idx)\n or obj_output_dict[\"non_cond_frame_outputs\"].get(frame_idx)\n )\n\n if prev_out is not None and prev_out.get(\"pred_masks\") is not None:\n prev_sam_mask_logits = prev_out[\"pred_masks\"].to(device=self.device, non_blocking=True)\n # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.\n prev_sam_mask_logits.clamp_(-32.0, 32.0)\n current_out = self._run_single_frame_inference(\n output_dict=obj_output_dict, # run on the slice of a single object\n frame_idx=frame_idx,\n batch_size=1, # run on the slice of a single object\n is_init_cond_frame=is_init_cond_frame,\n point_inputs=point_inputs,\n mask_inputs=masks,\n reverse=False,\n # Skip the memory encoder when adding clicks or mask. We execute the memory encoder\n # at the beginning of `propagate_in_video` (after user finalize their clicks). 
This\n # allows us to enforce non-overlapping constraints on all objects before encoding\n # them into memory.\n run_mem_encoder=False,\n prev_sam_mask_logits=prev_sam_mask_logits,\n )\n # Add the output to the output dict (to be used as future memory)\n obj_temp_output_dict[storage_key][frame_idx] = current_out\n\n # Resize the output mask to the original video resolution\n consolidated_out = self._consolidate_temp_output_across_obj(\n frame_idx,\n is_cond=is_cond,\n run_mem_encoder=False,\n )\n pred_masks = consolidated_out[\"pred_masks\"].flatten(0, 1)\n return pred_masks.flatten(0, 1), torch.ones(1, dtype=pred_masks.dtype, device=pred_masks.device)\n\n @smart_inference_mode()\n def propagate_in_video_preflight(self):\n \"\"\"\n Prepare inference_state and consolidate temporary outputs before tracking.\n\n This method marks the start of tracking, disallowing the addition of new objects until the session is reset.\n It consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`.\n Additionally, it clears non-conditioning memory around input frames and ensures that the state is consistent\n with the provided inputs.\n \"\"\"\n # Tracking has started and we don't allow adding new objects until session is reset.\n self.inference_state[\"tracking_has_started\"] = True\n batch_size = len(self.inference_state[\"obj_idx_to_id\"])\n\n # Consolidate per-object temporary outputs in \"temp_output_dict_per_obj\" and\n # add them into \"output_dict\".\n temp_output_dict_per_obj = self.inference_state[\"temp_output_dict_per_obj\"]\n output_dict = self.inference_state[\"output_dict\"]\n # \"consolidated_frame_inds\" contains indices of those frames where consolidated\n # temporary outputs have been added (either in this call or any previous calls\n # to `propagate_in_video_preflight`).\n consolidated_frame_inds = self.inference_state[\"consolidated_frame_inds\"]\n for is_cond in {False, True}:\n # Separately consolidate conditioning and non-conditioning temp outputs\n storage_key = \"cond_frame_outputs\" if is_cond else \"non_cond_frame_outputs\"\n # Find all the frames that contain temporary outputs for any objects\n # (these should be the frames that have just received clicks for mask inputs\n # via `add_new_points` or `add_new_mask`)\n temp_frame_inds = set()\n for obj_temp_output_dict in temp_output_dict_per_obj.values():\n temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())\n consolidated_frame_inds[storage_key].update(temp_frame_inds)\n # consolidate the temporary output across all objects on this frame\n for frame_idx in temp_frame_inds:\n consolidated_out = self._consolidate_temp_output_across_obj(\n frame_idx, is_cond=is_cond, run_mem_encoder=True\n )\n # merge them into \"output_dict\" and also create per-object slices\n output_dict[storage_key][frame_idx] = consolidated_out\n self._add_output_per_object(frame_idx, consolidated_out, storage_key)\n if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):\n # clear non-conditioning memory of the surrounding frames\n self._clear_non_cond_mem_around_input(frame_idx)\n\n # clear temporary outputs in `temp_output_dict_per_obj`\n for obj_temp_output_dict in temp_output_dict_per_obj.values():\n obj_temp_output_dict[storage_key].clear()\n\n # edge case: if an output is added to \"cond_frame_outputs\", we remove any prior\n # output on the same frame in \"non_cond_frame_outputs\"\n for frame_idx in output_dict[\"cond_frame_outputs\"]:\n 
output_dict[\"non_cond_frame_outputs\"].pop(frame_idx, None)\n for obj_output_dict in self.inference_state[\"output_dict_per_obj\"].values():\n for frame_idx in obj_output_dict[\"cond_frame_outputs\"]:\n obj_output_dict[\"non_cond_frame_outputs\"].pop(frame_idx, None)\n for frame_idx in consolidated_frame_inds[\"cond_frame_outputs\"]:\n assert frame_idx in output_dict[\"cond_frame_outputs\"]\n consolidated_frame_inds[\"non_cond_frame_outputs\"].discard(frame_idx)\n\n # Make sure that the frame indices in \"consolidated_frame_inds\" are exactly those frames\n # with either points or mask inputs (which should be true under a correct workflow).\n all_consolidated_frame_inds = (\n consolidated_frame_inds[\"cond_frame_outputs\"] | consolidated_frame_inds[\"non_cond_frame_outputs\"]\n )\n input_frames_inds = set()\n for point_inputs_per_frame in self.inference_state[\"point_inputs_per_obj\"].values():\n input_frames_inds.update(point_inputs_per_frame.keys())\n for mask_inputs_per_frame in self.inference_state[\"mask_inputs_per_obj\"].values():\n input_frames_inds.update(mask_inputs_per_frame.keys())\n assert all_consolidated_frame_inds == input_frames_inds\n\n @staticmethod\n def init_state(predictor):\n \"\"\"\n Initialize an inference state for the predictor.\n\n This function sets up the initial state required for performing inference on video data.\n It includes initializing various dictionaries and ordered dictionaries that will store\n inputs, outputs, and other metadata relevant to the tracking process.\n\n Args:\n predictor (SAM2VideoPredictor): The predictor object for which to initialize the state.\n \"\"\"\n if len(predictor.inference_state) > 0: # means initialized\n return\n assert predictor.dataset is not None\n assert predictor.dataset.mode == \"video\"\n\n inference_state = {\n \"num_frames\": predictor.dataset.frames,\n \"point_inputs_per_obj\": {}, # inputs points on each frame\n \"mask_inputs_per_obj\": {}, # inputs mask on each frame\n \"constants\": {}, # values that don't change across frames (so we only need to hold one copy of them)\n # mapping between client-side object id and model-side object index\n \"obj_id_to_idx\": OrderedDict(),\n \"obj_idx_to_id\": OrderedDict(),\n \"obj_ids\": [],\n # A storage to hold the model's tracking results and states on each frame\n \"output_dict\": {\n \"cond_frame_outputs\": {}, # dict containing {frame_idx: }\n \"non_cond_frame_outputs\": {}, # dict containing {frame_idx: }\n },\n # Slice (view) of each object tracking results, sharing the same memory with \"output_dict\"\n \"output_dict_per_obj\": {},\n # A temporary storage to hold new outputs when user interact with a frame\n # to add clicks or mask (it's merged into \"output_dict\" before propagation starts)\n \"temp_output_dict_per_obj\": {},\n # Frames that already holds consolidated outputs from click or mask inputs\n # (we directly use their consolidated outputs during tracking)\n \"consolidated_frame_inds\": {\n \"cond_frame_outputs\": set(), # set containing frame indices\n \"non_cond_frame_outputs\": set(), # set containing frame indices\n },\n # metadata for each tracking frame (e.g. 
which direction it's tracked)\n \"tracking_has_started\": False,\n \"frames_already_tracked\": [],\n }\n predictor.inference_state = inference_state\n\n def get_im_features(self, im, batch=1):\n \"\"\"\n Extract and process image features using SAM2's image encoder for subsequent segmentation tasks.\n\n Args:\n im (torch.Tensor): The input image tensor.\n batch (int, optional): The batch size for expanding features if there are multiple prompts.\n\n Returns:\n vis_feats (torch.Tensor): The visual features extracted from the image.\n vis_pos_embed (torch.Tensor): The positional embeddings for the visual features.\n feat_sizes (List[tuple]): A list containing the sizes of the extracted features.\n\n Note:\n - If `batch` is greater than 1, the features are expanded to fit the batch size.\n - The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features.\n \"\"\"\n backbone_out = self.model.forward_image(im)\n if batch > 1: # expand features if there's more than one prompt\n for i, feat in enumerate(backbone_out[\"backbone_fpn\"]):\n backbone_out[\"backbone_fpn\"][i] = feat.expand(batch, -1, -1, -1)\n for i, pos in enumerate(backbone_out[\"vision_pos_enc\"]):\n pos = pos.expand(batch, -1, -1, -1)\n backbone_out[\"vision_pos_enc\"][i] = pos\n _, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out)\n return vis_feats, vis_pos_embed, feat_sizes\n\n def _obj_id_to_idx(self, obj_id):\n \"\"\"\n Map client-side object id to model-side object index.\n\n Args:\n obj_id (int): The unique identifier of the object provided by the client side.\n\n Returns:\n (int): The index of the object on the model side.\n\n Raises:\n RuntimeError: If an attempt is made to add a new object after tracking has started.\n\n Note:\n - The method updates or retrieves mappings between object IDs and indices stored in\n `inference_state`.\n - It ensures that new objects can only be added before tracking commences.\n - It maintains two-way mappings between IDs and indices (`obj_id_to_idx` and `obj_idx_to_id`).\n - Additional data structures are initialized for the new object to store inputs and outputs.\n \"\"\"\n obj_idx = self.inference_state[\"obj_id_to_idx\"].get(obj_id, None)\n if obj_idx is not None:\n return obj_idx\n\n # This is a new object id not sent to the server before. We only allow adding\n # new objects *before* the tracking starts.\n allow_new_object = not self.inference_state[\"tracking_has_started\"]\n if allow_new_object:\n # get the next object slot\n obj_idx = len(self.inference_state[\"obj_id_to_idx\"])\n self.inference_state[\"obj_id_to_idx\"][obj_id] = obj_idx\n self.inference_state[\"obj_idx_to_id\"][obj_idx] = obj_id\n self.inference_state[\"obj_ids\"] = list(self.inference_state[\"obj_id_to_idx\"])\n # set up input and output structures for this object\n self.inference_state[\"point_inputs_per_obj\"][obj_idx] = {}\n self.inference_state[\"mask_inputs_per_obj\"][obj_idx] = {}\n self.inference_state[\"output_dict_per_obj\"][obj_idx] = {\n \"cond_frame_outputs\": {}, # dict containing {frame_idx: }\n \"non_cond_frame_outputs\": {}, # dict containing {frame_idx: }\n }\n self.inference_state[\"temp_output_dict_per_obj\"][obj_idx] = {\n \"cond_frame_outputs\": {}, # dict containing {frame_idx: }\n \"non_cond_frame_outputs\": {}, # dict containing {frame_idx: }\n }\n return obj_idx\n else:\n raise RuntimeError(\n f\"Cannot add new object id {obj_id} after tracking starts. 
\"\n f\"All existing object ids: {self.inference_state['obj_ids']}. \"\n f\"Please call 'reset_state' to restart from scratch.\"\n )\n\n def _run_single_frame_inference(\n self,\n output_dict,\n frame_idx,\n batch_size,\n is_init_cond_frame,\n point_inputs,\n mask_inputs,\n reverse,\n run_mem_encoder,\n prev_sam_mask_logits=None,\n ):\n \"\"\"\n Run tracking on a single frame based on current inputs and previous memory.\n\n Args:\n output_dict (dict): The dictionary containing the output states of the tracking process.\n frame_idx (int): The index of the current frame.\n batch_size (int): The batch size for processing the frame.\n is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame.\n point_inputs (dict | None): Input points and their labels.\n mask_inputs (torch.Tensor | None): Input binary masks.\n reverse (bool): Indicates if the tracking should be performed in reverse order.\n run_mem_encoder (bool): Indicates if the memory encoder should be executed.\n prev_sam_mask_logits (torch.Tensor | None): Previous mask logits for the current object.\n\n Returns:\n (dict): A dictionary containing the output of the tracking step, including updated features and predictions.\n\n Raises:\n AssertionError: If both `point_inputs` and `mask_inputs` are provided, or neither is provided.\n\n Note:\n - The method assumes that `point_inputs` and `mask_inputs` are mutually exclusive.\n - The method retrieves image features using the `get_im_features` method.\n - The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored.\n - The `fill_holes_in_mask_scores` function is commented out and currently unsupported due to CUDA extension requirements.\n \"\"\"\n # Retrieve correct image features\n current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(\n self.inference_state[\"im\"], batch_size\n )\n\n # point and mask should not appear as input simultaneously on the same frame\n assert point_inputs is None or mask_inputs is None\n current_out = self.model.track_step(\n frame_idx=frame_idx,\n is_init_cond_frame=is_init_cond_frame,\n current_vision_feats=current_vision_feats,\n current_vision_pos_embeds=current_vision_pos_embeds,\n feat_sizes=feat_sizes,\n point_inputs=point_inputs,\n mask_inputs=mask_inputs,\n output_dict=output_dict,\n num_frames=self.inference_state[\"num_frames\"],\n track_in_reverse=reverse,\n run_mem_encoder=run_mem_encoder,\n prev_sam_mask_logits=prev_sam_mask_logits,\n )\n\n maskmem_features = current_out[\"maskmem_features\"]\n if maskmem_features is not None:\n current_out[\"maskmem_features\"] = maskmem_features.to(\n dtype=torch.float16, device=self.device, non_blocking=True\n )\n # NOTE: Do not support the `fill_holes_in_mask_scores` function since it needs cuda extensions\n # potentially fill holes in the predicted masks\n # if self.fill_hole_area > 0:\n # pred_masks = current_out[\"pred_masks\"].to(self.device, non_blocking=True)\n # pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area)\n\n # \"maskmem_pos_enc\" is the same across frames, so we only need to store one copy of it\n current_out[\"maskmem_pos_enc\"] = self._get_maskmem_pos_enc(current_out[\"maskmem_pos_enc\"])\n return current_out\n\n def _get_maskmem_pos_enc(self, out_maskmem_pos_enc):\n \"\"\"\n Cache and manage the positional encoding for mask memory across frames and objects.\n\n This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for\n mask memory, which is constant 
across frames and objects, thus reducing the amount of\n redundant information stored during an inference session. It checks if the positional\n encoding has already been cached; if not, it caches a slice of the provided encoding.\n If the batch size is greater than one, it expands the cached positional encoding to match\n the current batch size.\n\n Args:\n out_maskmem_pos_enc (List[torch.Tensor] | None): The positional encoding for mask memory.\n Should be a list of tensors or None.\n\n Returns:\n (List[torch.Tensor]): The positional encoding for mask memory, either cached or expanded.\n\n Note:\n - The method assumes that `out_maskmem_pos_enc` is a list of tensors or None.\n - Only a single object's slice is cached since the encoding is the same across objects.\n - The method checks if the positional encoding has already been cached in the session's constants.\n - If the batch size is greater than one, the cached encoding is expanded to fit the batch size.\n \"\"\"\n model_constants = self.inference_state[\"constants\"]\n # \"out_maskmem_pos_enc\" should be either a list of tensors or None\n if out_maskmem_pos_enc is not None:\n if \"maskmem_pos_enc\" not in model_constants:\n assert isinstance(out_maskmem_pos_enc, list)\n # only take the slice for one object, since it's same across objects\n maskmem_pos_enc = [x[:1].clone() for x in out_maskmem_pos_enc]\n model_constants[\"maskmem_pos_enc\"] = maskmem_pos_enc\n else:\n maskmem_pos_enc = model_constants[\"maskmem_pos_enc\"]\n # expand the cached maskmem_pos_enc to the actual batch size\n batch_size = out_maskmem_pos_enc[0].size(0)\n if batch_size > 1:\n out_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc]\n return out_maskmem_pos_enc\n\n def _consolidate_temp_output_across_obj(\n self,\n frame_idx,\n is_cond=False,\n run_mem_encoder=False,\n ):\n \"\"\"\n Consolidate per-object temporary outputs into a single output for all objects.\n\n This method combines the temporary outputs for each object on a given frame into a unified\n output. It fills in any missing objects either from the main output dictionary or leaves\n placeholders if they do not exist in the main output. Optionally, it can re-run the memory\n encoder after applying non-overlapping constraints to the object scores.\n\n Args:\n frame_idx (int): The index of the frame for which to consolidate outputs.\n is_cond (bool, optional): Indicates if the frame is considered a conditioning frame.\n run_mem_encoder (bool, optional): Specifies whether to run the memory encoder after\n consolidating the outputs.\n\n Returns:\n (dict): A consolidated output dictionary containing the combined results for all objects.\n\n Note:\n - The method initializes the consolidated output with placeholder values for missing objects.\n - It searches for outputs in both the temporary and main output dictionaries.\n - If `run_mem_encoder` is True, it applies non-overlapping constraints and re-runs the memory encoder.\n - The `maskmem_features` and `maskmem_pos_enc` are only populated when `run_mem_encoder` is True.\n \"\"\"\n batch_size = len(self.inference_state[\"obj_idx_to_id\"])\n storage_key = \"cond_frame_outputs\" if is_cond else \"non_cond_frame_outputs\"\n\n # Initialize `consolidated_out`. Its \"maskmem_features\" and \"maskmem_pos_enc\"\n # will be added when rerunning the memory encoder after applying non-overlapping\n # constraints to object scores. 
Its \"pred_masks\" are prefilled with a large\n # negative value (NO_OBJ_SCORE) to represent missing objects.\n consolidated_out = {\n \"maskmem_features\": None,\n \"maskmem_pos_enc\": None,\n \"pred_masks\": torch.full(\n size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),\n fill_value=-1024.0,\n dtype=torch.float32,\n device=self.device,\n ),\n \"obj_ptr\": torch.full(\n size=(batch_size, self.model.hidden_dim),\n fill_value=-1024.0,\n dtype=torch.float32,\n device=self.device,\n ),\n \"object_score_logits\": torch.full(\n size=(batch_size, 1),\n # default to 10.0 for object_score_logits, i.e. assuming the object is\n # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`\n fill_value=10.0,\n dtype=torch.float32,\n device=self.device,\n ),\n }\n for obj_idx in range(batch_size):\n obj_temp_output_dict = self.inference_state[\"temp_output_dict_per_obj\"][obj_idx]\n obj_output_dict = self.inference_state[\"output_dict_per_obj\"][obj_idx]\n out = (\n obj_temp_output_dict[storage_key].get(frame_idx)\n # If the object doesn't appear in \"temp_output_dict_per_obj\" on this frame,\n # we fall back and look up its previous output in \"output_dict_per_obj\".\n # We look up both \"cond_frame_outputs\" and \"non_cond_frame_outputs\" in\n # \"output_dict_per_obj\" to find a previous output for this object.\n or obj_output_dict[\"cond_frame_outputs\"].get(frame_idx)\n or obj_output_dict[\"non_cond_frame_outputs\"].get(frame_idx)\n )\n # If the object doesn't appear in \"output_dict_per_obj\" either, we skip it\n # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE\n # placeholder above) and set its object pointer to be a dummy pointer.\n if out is None:\n # Fill in dummy object pointers for those objects without any inputs or\n # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,\n # i.e. 
when we need to build the memory for tracking).\n if run_mem_encoder:\n # fill object pointer with a dummy pointer (based on an empty mask)\n consolidated_out[\"obj_ptr\"][obj_idx : obj_idx + 1] = self._get_empty_mask_ptr(frame_idx)\n continue\n # Add the temporary object output mask to consolidated output mask\n consolidated_out[\"pred_masks\"][obj_idx : obj_idx + 1] = out[\"pred_masks\"]\n consolidated_out[\"obj_ptr\"][obj_idx : obj_idx + 1] = out[\"obj_ptr\"]\n\n # Optionally, apply non-overlapping constraints on the consolidated scores and rerun the memory encoder\n if run_mem_encoder:\n high_res_masks = F.interpolate(\n consolidated_out[\"pred_masks\"],\n size=self.imgsz,\n mode=\"bilinear\",\n align_corners=False,\n )\n if self.model.non_overlap_masks_for_mem_enc:\n high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks)\n consolidated_out[\"maskmem_features\"], consolidated_out[\"maskmem_pos_enc\"] = self._run_memory_encoder(\n batch_size=batch_size,\n high_res_masks=high_res_masks,\n is_mask_from_pts=True, # these frames are what the user interacted with\n object_score_logits=consolidated_out[\"object_score_logits\"],\n )\n\n return consolidated_out\n\n def _get_empty_mask_ptr(self, frame_idx):\n \"\"\"\n Get a dummy object pointer based on an empty mask on the current frame.\n\n Args:\n frame_idx (int): The index of the current frame for which to generate the dummy object pointer.\n\n Returns:\n (torch.Tensor): A tensor representing the dummy object pointer generated from the empty mask.\n \"\"\"\n # Retrieve correct image features\n current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(self.inference_state[\"im\"])\n\n # Feed the empty mask and image feature above to get a dummy object pointer\n current_out = self.model.track_step(\n frame_idx=frame_idx,\n is_init_cond_frame=True,\n current_vision_feats=current_vision_feats,\n current_vision_pos_embeds=current_vision_pos_embeds,\n feat_sizes=feat_sizes,\n point_inputs=None,\n # A dummy (empty) mask with a single object\n mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=torch.float32, device=self.device),\n output_dict={},\n num_frames=self.inference_state[\"num_frames\"],\n track_in_reverse=False,\n run_mem_encoder=False,\n prev_sam_mask_logits=None,\n )\n return current_out[\"obj_ptr\"]\n\n def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts):\n \"\"\"\n Run the memory encoder on masks.\n\n This is usually after applying non-overlapping constraints to object scores. 
Since their scores changed, their\n memory also needs to be computed again with the memory encoder.\n\n Args:\n batch_size (int): The batch size for processing the frame.\n high_res_masks (torch.Tensor): High-resolution masks for which to compute the memory.\n object_score_logits (torch.Tensor): Logits representing the object scores.\n is_mask_from_pts (bool): Indicates if the mask is derived from point interactions.\n\n Returns:\n maskmem_features (torch.Tensor): The encoded mask features.\n maskmem_pos_enc (torch.Tensor): The positional encoding.\n \"\"\"\n # Retrieve correct image features\n current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state[\"im\"], batch_size)\n maskmem_features, maskmem_pos_enc = self.model._encode_new_memory(\n current_vision_feats=current_vision_feats,\n feat_sizes=feat_sizes,\n pred_masks_high_res=high_res_masks,\n is_mask_from_pts=is_mask_from_pts,\n object_score_logits=object_score_logits,\n )\n\n # \"maskmem_pos_enc\" is the same across frames, so we only need to store one copy of it\n maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc)\n return maskmem_features.to(dtype=torch.float16, device=self.device, non_blocking=True), maskmem_pos_enc\n\n def _add_output_per_object(self, frame_idx, current_out, storage_key):\n \"\"\"\n Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj.\n\n The resulting slices share the same tensor storage.\n\n Args:\n frame_idx (int): The index of the current frame.\n current_out (dict): The current output dictionary containing multi-object outputs.\n storage_key (str): The key used to store the output in the per-object output dictionary.\n \"\"\"\n maskmem_features = current_out[\"maskmem_features\"]\n assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)\n\n maskmem_pos_enc = current_out[\"maskmem_pos_enc\"]\n assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)\n\n for obj_idx, obj_output_dict in self.inference_state[\"output_dict_per_obj\"].items():\n obj_slice = slice(obj_idx, obj_idx + 1)\n obj_out = {\n \"maskmem_features\": None,\n \"maskmem_pos_enc\": None,\n \"pred_masks\": current_out[\"pred_masks\"][obj_slice],\n \"obj_ptr\": current_out[\"obj_ptr\"][obj_slice],\n }\n if maskmem_features is not None:\n obj_out[\"maskmem_features\"] = maskmem_features[obj_slice]\n if maskmem_pos_enc is not None:\n obj_out[\"maskmem_pos_enc\"] = [x[obj_slice] for x in maskmem_pos_enc]\n obj_output_dict[storage_key][frame_idx] = obj_out\n\n def _clear_non_cond_mem_around_input(self, frame_idx):\n \"\"\"\n Remove the non-conditioning memory around the input frame.\n\n When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain outdated\n object appearance information and could confuse the model. 
This method clears those non-conditioning memories\n surrounding the interacted frame to avoid giving the model both old and new information about the object.\n\n Args:\n frame_idx (int): The index of the current frame where user interaction occurred.\n \"\"\"\n r = self.model.memory_temporal_stride_for_eval\n frame_idx_begin = frame_idx - r * self.model.num_maskmem\n frame_idx_end = frame_idx + r * self.model.num_maskmem\n for t in range(frame_idx_begin, frame_idx_end + 1):\n self.inference_state[\"output_dict\"][\"non_cond_frame_outputs\"].pop(t, None)\n for obj_output_dict in self.inference_state[\"output_dict_per_obj\"].values():\n obj_output_dict[\"non_cond_frame_outputs\"].pop(t, None)",
"chunk_type": "class",
"name": "SAM2VideoPredictor",
"file_path": "ultralytics\\ultralytics\\models\\sam\\predict.py",
"start_line": 817,
"end_line": 1618,
"start_col": 0,
"end_col": 70,
"parent_name": null,
"docstring": "SAM2VideoPredictor to handle user interactions with videos and manage inference states.\n\nThis class extends the functionality of SAM2Predictor to support video processing and maintains\nthe state of inference operations. It includes configurations for managing non-overlapping masks,\nclearing memory for non-conditional inputs, and setting up callbacks for prediction events.\n\nAttributes:\n inference_state (dict): A dictionary to store the current state of inference operations.\n non_overlap_masks (bool): A flag indicating whether masks should be non-overlapping.\n clear_non_cond_mem_around_input (bool): A flag to control clearing non-conditional memory around inputs.\n clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios.\n callbacks (dict): A dictionary of callbacks for various prediction lifecycle events.\n\nMethods:\n get_model: Retrieve and configure the model with binarization enabled.\n inference: Perform image segmentation inference based on the given input cues.\n postprocess: Post-process the predictions to apply non-overlapping constraints if required.\n add_new_prompts: Add new points or masks to a specific frame for a given object ID.\n propagate_in_video_preflight: Prepare inference_state and consolidate temporary outputs before tracking.\n init_state: Initialize an inference state for the predictor.\n get_im_features: Extract and process image features using SAM2's image encoder for subsequent segmentation tasks.\n\nExamples:\n >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)\n >>> predictor.set_image(\"path/to/video_frame.jpg\")\n >>> bboxes = [[100, 100, 200, 200]]\n >>> results = predictor(bboxes=bboxes)\n\nNote:\n The `fill_hole_area` attribute is defined but not used in the current implementation.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"collections.OrderedDict",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.data.augment.LetterBox",
"ultralytics.engine.predictor.BasePredictor",
"ultralytics.engine.results.Results",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.ops",
"ultralytics.utils.torch_utils.select_device",
"ultralytics.utils.torch_utils.smart_inference_mode",
"amg.batch_iterator",
"amg.batched_mask_to_box",
"amg.build_all_layer_point_grids",
"amg.calculate_stability_score",
"amg.generate_crop_boxes",
"amg.is_box_near_crop_edge",
"amg.remove_small_regions",
"amg.uncrop_boxes_xyxy",
"amg.uncrop_masks",
"torchvision",
"build.build_sam",
"torchvision",
"build.build_sam",
"SAM2Predictor"
],
"chunk_id": "class_SAM2VideoPredictor_4656fcbd"
},
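The chunk above indexes SAM2VideoPredictor and its interactive-video methods. Below is a minimal, hedged usage sketch that mirrors the Examples block of the recorded docstring; the frame path and box coordinates are illustrative only, and a real video session additionally needs the predictor's dataset wired to a video source before `init_state`/propagation can run.

```python
# Minimal sketch mirroring the SAM2VideoPredictor docstring example indexed above.
# The image path and box are placeholders; DEFAULT_CFG is the stock Ultralytics config.
from ultralytics.models.sam import SAM2VideoPredictor
from ultralytics.utils import DEFAULT_CFG

predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
predictor.set_image("path/to/video_frame.jpg")      # load a frame to prompt on
results = predictor(bboxes=[[100, 100, 200, 200]])  # one box prompt for one object

# Per the methods indexed above, interactive prompts are staged per object in
# temp_output_dict_per_obj and only merged into output_dict by
# propagate_in_video_preflight() once tracking starts.
```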
{
"content": "from .model import SAM",
"chunk_type": "import",
"name": "SAM",
"file_path": "ultralytics\\ultralytics\\models\\sam\\__init__.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SAM_fac33890"
},
{
"content": "from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor",
"chunk_type": "import",
"name": "Predictor, SAM2Predictor, SAM2VideoPredictor",
"file_path": "ultralytics\\ultralytics\\models\\sam\\__init__.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 65,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Predictor, SAM2Predictor, SAM2VideoPredictor_61c8aadd"
},
{
"content": "__all__ = \"SAM\", \"Predictor\", \"SAM2Predictor\", \"SAM2VideoPredictor\" # tuple or list of exportable items",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\models\\sam\\__init__.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 67,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___9af0d37e"
},
{
"content": "from typing import Any, Dict, List, Optional, Tuple",
"chunk_type": "import",
"name": "Any, Dict, List, Optional, Tuple",
"file_path": "ultralytics\\ultralytics\\models\\utils\\loss.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 51,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, List, Optional, Tuple_2395c2be"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\utils\\loss.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_49f8c258"
},
{
"content": "import torch.nn as nn",
"chunk_type": "import",
"name": "torch.nn",
"file_path": "ultralytics\\ultralytics\\models\\utils\\loss.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn_596b42cc"
},
{
"content": "import torch.nn.functional as F",
"chunk_type": "import",
"name": "torch.nn.functional",
"file_path": "ultralytics\\ultralytics\\models\\utils\\loss.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn.functional_dc48fad9"
},
{
"content": "from ultralytics.utils.loss import FocalLoss, VarifocalLoss",
"chunk_type": "import",
"name": "FocalLoss, VarifocalLoss",
"file_path": "ultralytics\\ultralytics\\models\\utils\\loss.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 59,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_FocalLoss, VarifocalLoss_a255af35"
},
{
"content": "from ultralytics.utils.metrics import bbox_iou",
"chunk_type": "import",
"name": "bbox_iou",
"file_path": "ultralytics\\ultralytics\\models\\utils\\loss.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_bbox_iou_3e0995ff"
},
{
"content": "from .ops import HungarianMatcher",
"chunk_type": "import",
"name": "HungarianMatcher",
"file_path": "ultralytics\\ultralytics\\models\\utils\\loss.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_HungarianMatcher_a687dd40"
},
{
"content": "class DETRLoss(nn.Module):\n \"\"\"\n DETR (DEtection TRansformer) Loss class for calculating various loss components.\n\n This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the\n DETR object detection model.\n\n Attributes:\n nc (int): Number of classes.\n loss_gain (Dict[str, float]): Coefficients for different loss components.\n aux_loss (bool): Whether to compute auxiliary losses.\n use_fl (bool): Whether to use FocalLoss.\n use_vfl (bool): Whether to use VarifocalLoss.\n use_uni_match (bool): Whether to use a fixed layer for auxiliary branch label assignment.\n uni_match_ind (int): Index of fixed layer to use if use_uni_match is True.\n matcher (HungarianMatcher): Object to compute matching cost and indices.\n fl (FocalLoss | None): Focal Loss object if use_fl is True, otherwise None.\n vfl (VarifocalLoss | None): Varifocal Loss object if use_vfl is True, otherwise None.\n device (torch.device): Device on which tensors are stored.\n \"\"\"\n\n def __init__(\n self,\n nc: int = 80,\n loss_gain: Optional[Dict[str, float]] = None,\n aux_loss: bool = True,\n use_fl: bool = True,\n use_vfl: bool = False,\n use_uni_match: bool = False,\n uni_match_ind: int = 0,\n gamma: float = 1.5,\n alpha: float = 0.25,\n ):\n \"\"\"\n Initialize DETR loss function with customizable components and gains.\n\n Uses default loss_gain if not provided. Initializes HungarianMatcher with preset cost gains. Supports auxiliary\n losses and various loss types.\n\n Args:\n nc (int): Number of classes.\n loss_gain (Dict[str, float], optional): Coefficients for different loss components.\n aux_loss (bool): Whether to use auxiliary losses from each decoder layer.\n use_fl (bool): Whether to use FocalLoss.\n use_vfl (bool): Whether to use VarifocalLoss.\n use_uni_match (bool): Whether to use fixed layer for auxiliary branch label assignment.\n uni_match_ind (int): Index of fixed layer for uni_match.\n gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.\n alpha (float): The balancing factor used to address class imbalance.\n \"\"\"\n super().__init__()\n\n if loss_gain is None:\n loss_gain = {\"class\": 1, \"bbox\": 5, \"giou\": 2, \"no_object\": 0.1, \"mask\": 1, \"dice\": 1}\n self.nc = nc\n self.matcher = HungarianMatcher(cost_gain={\"class\": 2, \"bbox\": 5, \"giou\": 2})\n self.loss_gain = loss_gain\n self.aux_loss = aux_loss\n self.fl = FocalLoss(gamma, alpha) if use_fl else None\n self.vfl = VarifocalLoss(gamma, alpha) if use_vfl else None\n\n self.use_uni_match = use_uni_match\n self.uni_match_ind = uni_match_ind\n self.device = None\n\n def _get_loss_class(\n self, pred_scores: torch.Tensor, targets: torch.Tensor, gt_scores: torch.Tensor, num_gts: int, postfix: str = \"\"\n ) -> Dict[str, torch.Tensor]:\n \"\"\"\n Compute classification loss based on predictions, target values, and ground truth scores.\n\n Args:\n pred_scores (torch.Tensor): Predicted class scores with shape (B, N, C).\n targets (torch.Tensor): Target class indices with shape (B, N).\n gt_scores (torch.Tensor): Ground truth confidence scores with shape (B, N).\n num_gts (int): Number of ground truth objects.\n postfix (str, optional): String to append to the loss name for identification in multi-loss scenarios.\n\n Returns:\n (Dict[str, torch.Tensor]): Dictionary containing classification loss value.\n\n Notes:\n The function supports different classification loss types:\n - Varifocal Loss (if self.vfl is True and num_gts > 
0)\n - Focal Loss (if self.fl is True)\n - BCE Loss (default fallback)\n \"\"\"\n # Logits: [b, query, num_classes], gt_class: list[[n, 1]]\n name_class = f\"loss_class{postfix}\"\n bs, nq = pred_scores.shape[:2]\n # one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)\n one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)\n one_hot.scatter_(2, targets.unsqueeze(-1), 1)\n one_hot = one_hot[..., :-1]\n gt_scores = gt_scores.view(bs, nq, 1) * one_hot\n\n if self.fl:\n if num_gts and self.vfl:\n loss_cls = self.vfl(pred_scores, gt_scores, one_hot)\n else:\n loss_cls = self.fl(pred_scores, one_hot.float())\n loss_cls /= max(num_gts, 1) / nq\n else:\n loss_cls = nn.BCEWithLogitsLoss(reduction=\"none\")(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss\n\n return {name_class: loss_cls.squeeze() * self.loss_gain[\"class\"]}\n\n def _get_loss_bbox(\n self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, postfix: str = \"\"\n ) -> Dict[str, torch.Tensor]:\n \"\"\"\n Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.\n\n Args:\n pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (N, 4).\n gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (N, 4).\n postfix (str, optional): String to append to the loss names for identification in multi-loss scenarios.\n\n Returns:\n (Dict[str, torch.Tensor]): Dictionary containing:\n - loss_bbox{postfix}: L1 loss between predicted and ground truth boxes, scaled by the bbox loss gain.\n - loss_giou{postfix}: GIoU loss between predicted and ground truth boxes, scaled by the giou loss gain.\n\n Notes:\n If no ground truth boxes are provided (empty list), zero-valued tensors are returned for both losses.\n \"\"\"\n # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]\n name_bbox = f\"loss_bbox{postfix}\"\n name_giou = f\"loss_giou{postfix}\"\n\n loss = {}\n if len(gt_bboxes) == 0:\n loss[name_bbox] = torch.tensor(0.0, device=self.device)\n loss[name_giou] = torch.tensor(0.0, device=self.device)\n return loss\n\n loss[name_bbox] = self.loss_gain[\"bbox\"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction=\"sum\") / len(gt_bboxes)\n loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)\n loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)\n loss[name_giou] = self.loss_gain[\"giou\"] * loss[name_giou]\n return {k: v.squeeze() for k, v in loss.items()}\n\n # This function is for future RT-DETR Segment models\n # def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):\n # # masks: [b, query, h, w], gt_mask: list[[n, H, W]]\n # name_mask = f'loss_mask{postfix}'\n # name_dice = f'loss_dice{postfix}'\n #\n # loss = {}\n # if sum(len(a) for a in gt_mask) == 0:\n # loss[name_mask] = torch.tensor(0., device=self.device)\n # loss[name_dice] = torch.tensor(0., device=self.device)\n # return loss\n #\n # num_gts = len(gt_mask)\n # src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)\n # src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]\n # # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.\n # loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,\n # torch.tensor([num_gts], dtype=torch.float32))\n # loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)\n # return loss\n\n # This function is for future RT-DETR 
Segment models\n # @staticmethod\n # def _dice_loss(inputs, targets, num_gts):\n # inputs = F.sigmoid(inputs).flatten(1)\n # targets = targets.flatten(1)\n # numerator = 2 * (inputs * targets).sum(1)\n # denominator = inputs.sum(-1) + targets.sum(-1)\n # loss = 1 - (numerator + 1) / (denominator + 1)\n # return loss.sum() / num_gts\n\n def _get_loss_aux(\n self,\n pred_bboxes: torch.Tensor,\n pred_scores: torch.Tensor,\n gt_bboxes: torch.Tensor,\n gt_cls: torch.Tensor,\n gt_groups: List[int],\n match_indices: Optional[List[Tuple]] = None,\n postfix: str = \"\",\n masks: Optional[torch.Tensor] = None,\n gt_mask: Optional[torch.Tensor] = None,\n ) -> Dict[str, torch.Tensor]:\n \"\"\"\n Get auxiliary losses for intermediate decoder layers.\n\n Args:\n pred_bboxes (torch.Tensor): Predicted bounding boxes from auxiliary layers.\n pred_scores (torch.Tensor): Predicted scores from auxiliary layers.\n gt_bboxes (torch.Tensor): Ground truth bounding boxes.\n gt_cls (torch.Tensor): Ground truth classes.\n gt_groups (List[int]): Number of ground truths per image.\n match_indices (List[Tuple], optional): Pre-computed matching indices.\n postfix (str, optional): String to append to loss names.\n masks (torch.Tensor, optional): Predicted masks if using segmentation.\n gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.\n\n Returns:\n (Dict[str, torch.Tensor]): Dictionary of auxiliary losses.\n \"\"\"\n # NOTE: loss class, bbox, giou, mask, dice\n loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)\n if match_indices is None and self.use_uni_match:\n match_indices = self.matcher(\n pred_bboxes[self.uni_match_ind],\n pred_scores[self.uni_match_ind],\n gt_bboxes,\n gt_cls,\n gt_groups,\n masks=masks[self.uni_match_ind] if masks is not None else None,\n gt_mask=gt_mask,\n )\n for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):\n aux_masks = masks[i] if masks is not None else None\n loss_ = self._get_loss(\n aux_bboxes,\n aux_scores,\n gt_bboxes,\n gt_cls,\n gt_groups,\n masks=aux_masks,\n gt_mask=gt_mask,\n postfix=postfix,\n match_indices=match_indices,\n )\n loss[0] += loss_[f\"loss_class{postfix}\"]\n loss[1] += loss_[f\"loss_bbox{postfix}\"]\n loss[2] += loss_[f\"loss_giou{postfix}\"]\n # if masks is not None and gt_mask is not None:\n # loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)\n # loss[3] += loss_[f'loss_mask{postfix}']\n # loss[4] += loss_[f'loss_dice{postfix}']\n\n loss = {\n f\"loss_class_aux{postfix}\": loss[0],\n f\"loss_bbox_aux{postfix}\": loss[1],\n f\"loss_giou_aux{postfix}\": loss[2],\n }\n # if masks is not None and gt_mask is not None:\n # loss[f'loss_mask_aux{postfix}'] = loss[3]\n # loss[f'loss_dice_aux{postfix}'] = loss[4]\n return loss\n\n @staticmethod\n def _get_index(match_indices: List[Tuple]) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:\n \"\"\"\n Extract batch indices, source indices, and destination indices from match indices.\n\n Args:\n match_indices (List[Tuple]): List of tuples containing matched indices.\n\n Returns:\n batch_idx (Tuple[torch.Tensor, torch.Tensor]): Tuple containing (batch_idx, src_idx).\n dst_idx (torch.Tensor): Destination indices.\n \"\"\"\n batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])\n src_idx = torch.cat([src for (src, _) in match_indices])\n dst_idx = torch.cat([dst for (_, dst) in match_indices])\n return (batch_idx, src_idx), dst_idx\n\n def _get_assigned_bboxes(\n self, pred_bboxes: 
torch.Tensor, gt_bboxes: torch.Tensor, match_indices: List[Tuple]\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"\n Assign predicted bounding boxes to ground truth bounding boxes based on match indices.\n\n Args:\n pred_bboxes (torch.Tensor): Predicted bounding boxes.\n gt_bboxes (torch.Tensor): Ground truth bounding boxes.\n match_indices (List[Tuple]): List of tuples containing matched indices.\n\n Returns:\n pred_assigned (torch.Tensor): Assigned predicted bounding boxes.\n gt_assigned (torch.Tensor): Assigned ground truth bounding boxes.\n \"\"\"\n pred_assigned = torch.cat(\n [\n t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device)\n for t, (i, _) in zip(pred_bboxes, match_indices)\n ]\n )\n gt_assigned = torch.cat(\n [\n t[j] if len(j) > 0 else torch.zeros(0, t.shape[-1], device=self.device)\n for t, (_, j) in zip(gt_bboxes, match_indices)\n ]\n )\n return pred_assigned, gt_assigned\n\n def _get_loss(\n self,\n pred_bboxes: torch.Tensor,\n pred_scores: torch.Tensor,\n gt_bboxes: torch.Tensor,\n gt_cls: torch.Tensor,\n gt_groups: List[int],\n masks: Optional[torch.Tensor] = None,\n gt_mask: Optional[torch.Tensor] = None,\n postfix: str = \"\",\n match_indices: Optional[List[Tuple]] = None,\n ) -> Dict[str, torch.Tensor]:\n \"\"\"\n Calculate losses for a single prediction layer.\n\n Args:\n pred_bboxes (torch.Tensor): Predicted bounding boxes.\n pred_scores (torch.Tensor): Predicted class scores.\n gt_bboxes (torch.Tensor): Ground truth bounding boxes.\n gt_cls (torch.Tensor): Ground truth classes.\n gt_groups (List[int]): Number of ground truths per image.\n masks (torch.Tensor, optional): Predicted masks if using segmentation.\n gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.\n postfix (str, optional): String to append to loss names.\n match_indices (List[Tuple], optional): Pre-computed matching indices.\n\n Returns:\n (Dict[str, torch.Tensor]): Dictionary of losses.\n \"\"\"\n if match_indices is None:\n match_indices = self.matcher(\n pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask\n )\n\n idx, gt_idx = self._get_index(match_indices)\n pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]\n\n bs, nq = pred_scores.shape[:2]\n targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)\n targets[idx] = gt_cls[gt_idx]\n\n gt_scores = torch.zeros([bs, nq], device=pred_scores.device)\n if len(gt_bboxes):\n gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)\n\n return {\n **self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix),\n **self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix),\n # **(self._get_loss_mask(masks, gt_mask, match_indices, postfix) if masks is not None and gt_mask is not None else {})\n }\n\n def forward(\n self,\n pred_bboxes: torch.Tensor,\n pred_scores: torch.Tensor,\n batch: Dict[str, Any],\n postfix: str = \"\",\n **kwargs: Any,\n ) -> Dict[str, torch.Tensor]:\n \"\"\"\n Calculate loss for predicted bounding boxes and scores.\n\n Args:\n pred_bboxes (torch.Tensor): Predicted bounding boxes, shape (L, B, N, 4).\n pred_scores (torch.Tensor): Predicted class scores, shape (L, B, N, C).\n batch (Dict[str, Any]): Batch information containing cls, bboxes, and gt_groups.\n postfix (str, optional): Postfix for loss names.\n **kwargs (Any): Additional arguments, may include 'match_indices'.\n\n Returns:\n (Dict[str, torch.Tensor]): Computed losses, including main and auxiliary (if enabled).\n\n 
Notes:\n Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if\n self.aux_loss is True.\n \"\"\"\n self.device = pred_bboxes.device\n match_indices = kwargs.get(\"match_indices\", None)\n gt_cls, gt_bboxes, gt_groups = batch[\"cls\"], batch[\"bboxes\"], batch[\"gt_groups\"]\n\n total_loss = self._get_loss(\n pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices\n )\n\n if self.aux_loss:\n total_loss.update(\n self._get_loss_aux(\n pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix\n )\n )\n\n return total_loss",
"chunk_type": "class",
"name": "DETRLoss",
"file_path": "ultralytics\\ultralytics\\models\\utils\\loss.py",
"start_line": 15,
"end_line": 397,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": "DETR (DEtection TRansformer) Loss class for calculating various loss components.\n\nThis class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the\nDETR object detection model.\n\nAttributes:\n nc (int): Number of classes.\n loss_gain (Dict[str, float]): Coefficients for different loss components.\n aux_loss (bool): Whether to compute auxiliary losses.\n use_fl (bool): Whether to use FocalLoss.\n use_vfl (bool): Whether to use VarifocalLoss.\n use_uni_match (bool): Whether to use a fixed layer for auxiliary branch label assignment.\n uni_match_ind (int): Index of fixed layer to use if use_uni_match is True.\n matcher (HungarianMatcher): Object to compute matching cost and indices.\n fl (FocalLoss | None): Focal Loss object if use_fl is True, otherwise None.\n vfl (VarifocalLoss | None): Varifocal Loss object if use_vfl is True, otherwise None.\n device (torch.device): Device on which tensors are stored.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.loss.FocalLoss",
"ultralytics.utils.loss.VarifocalLoss",
"ultralytics.utils.metrics.bbox_iou",
"ops.HungarianMatcher",
"nn.Module"
],
"chunk_id": "class_DETRLoss_5541987f"
},
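Since the DETRLoss chunk above documents the expected tensor shapes, here is a small sketch of calling it on random predictions. The shapes follow the `forward()` docstring (layers x batch x queries x classes), the batch keys match those read inside `forward`, and the import path follows the file_path recorded in the chunk; the ground-truth counts are made up for illustration.

```python
# Hedged usage sketch for DETRLoss with random tensors; shapes per the forward() docstring.
import torch
from ultralytics.models.utils.loss import DETRLoss

criterion = DETRLoss(nc=80, aux_loss=True)

L, B, N, C = 3, 2, 100, 80              # decoder layers, batch, queries, classes
pred_bboxes = torch.rand(L, B, N, 4)    # normalized xywh boxes
pred_scores = torch.randn(L, B, N, C)   # raw class logits
batch = {
    "cls": torch.randint(0, C, (7,)),   # 7 ground-truth objects across the batch
    "bboxes": torch.rand(7, 4),
    "gt_groups": [4, 3],                # per-image GT counts, summing to 7
}

losses = criterion(pred_bboxes, pred_scores, batch)
# keys: loss_class, loss_bbox, loss_giou (+ *_aux variants because aux_loss=True)
```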
{
"content": "class RTDETRDetectionLoss(DETRLoss):\n \"\"\"\n Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.\n\n This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as\n an additional denoising training loss when provided with denoising metadata.\n \"\"\"\n\n def forward(\n self,\n preds: Tuple[torch.Tensor, torch.Tensor],\n batch: Dict[str, Any],\n dn_bboxes: Optional[torch.Tensor] = None,\n dn_scores: Optional[torch.Tensor] = None,\n dn_meta: Optional[Dict[str, Any]] = None,\n ) -> Dict[str, torch.Tensor]:\n \"\"\"\n Forward pass to compute detection loss with optional denoising loss.\n\n Args:\n preds (Tuple[torch.Tensor, torch.Tensor]): Tuple containing predicted bounding boxes and scores.\n batch (Dict[str, Any]): Batch data containing ground truth information.\n dn_bboxes (torch.Tensor, optional): Denoising bounding boxes.\n dn_scores (torch.Tensor, optional): Denoising scores.\n dn_meta (Dict[str, Any], optional): Metadata for denoising.\n\n Returns:\n (Dict[str, torch.Tensor]): Dictionary containing total loss and denoising loss if applicable.\n \"\"\"\n pred_bboxes, pred_scores = preds\n total_loss = super().forward(pred_bboxes, pred_scores, batch)\n\n # Check for denoising metadata to compute denoising training loss\n if dn_meta is not None:\n dn_pos_idx, dn_num_group = dn_meta[\"dn_pos_idx\"], dn_meta[\"dn_num_group\"]\n assert len(batch[\"gt_groups\"]) == len(dn_pos_idx)\n\n # Get the match indices for denoising\n match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch[\"gt_groups\"])\n\n # Compute the denoising training loss\n dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix=\"_dn\", match_indices=match_indices)\n total_loss.update(dn_loss)\n else:\n # If no denoising metadata is provided, set denoising loss to zero\n total_loss.update({f\"{k}_dn\": torch.tensor(0.0, device=self.device) for k in total_loss.keys()})\n\n return total_loss\n\n @staticmethod\n def get_dn_match_indices(\n dn_pos_idx: List[torch.Tensor], dn_num_group: int, gt_groups: List[int]\n ) -> List[Tuple[torch.Tensor, torch.Tensor]]:\n \"\"\"\n Get match indices for denoising.\n\n Args:\n dn_pos_idx (List[torch.Tensor]): List of tensors containing positive indices for denoising.\n dn_num_group (int): Number of denoising groups.\n gt_groups (List[int]): List of integers representing number of ground truths per image.\n\n Returns:\n (List[Tuple[torch.Tensor, torch.Tensor]]): List of tuples containing matched indices for denoising.\n \"\"\"\n dn_match_indices = []\n idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)\n for i, num_gt in enumerate(gt_groups):\n if num_gt > 0:\n gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]\n gt_idx = gt_idx.repeat(dn_num_group)\n assert len(dn_pos_idx[i]) == len(gt_idx), (\n f\"Expected the same length, but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.\"\n )\n dn_match_indices.append((dn_pos_idx[i], gt_idx))\n else:\n dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))\n return dn_match_indices",
"chunk_type": "class",
"name": "RTDETRDetectionLoss",
"file_path": "ultralytics\\ultralytics\\models\\utils\\loss.py",
"start_line": 400,
"end_line": 476,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": "Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.\n\nThis class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as\nan additional denoising training loss when provided with denoising metadata.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.loss.FocalLoss",
"ultralytics.utils.loss.VarifocalLoss",
"ultralytics.utils.metrics.bbox_iou",
"ops.HungarianMatcher",
"DETRLoss"
],
"chunk_id": "class_RTDETRDetectionLoss_bddeaaab"
},
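The denoising branch indexed above hinges on the static `get_dn_match_indices` helper, which tiles each image's ground-truth indices once per denoising group and offsets them into the flattened GT tensor. The sketch below uses made-up group sizes to show the pairing; in real use, `dn_pos_idx` and `dn_num_group` come from the `dn_meta` produced by `get_cdn_group`.

```python
# Sketch of how get_dn_match_indices pairs denoising queries with repeated GT indices.
# All values are illustrative.
import torch
from ultralytics.models.utils.loss import RTDETRDetectionLoss

gt_groups = [2, 1]        # 2 GTs in image 0, 1 GT in image 1
dn_num_group = 3          # each GT is repeated across 3 denoising groups
dn_pos_idx = [torch.arange(6), torch.arange(3)]  # positive query indices per image

indices = RTDETRDetectionLoss.get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups)
# image 0: GT side is tensor([0, 1, 0, 1, 0, 1])  -> per-image GTs tiled once per group
# image 1: GT side is tensor([2, 2, 2])           -> offset by the cumulative GT count (2)
```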
{
"content": "from typing import Any, Dict, List, Optional, Tuple",
"chunk_type": "import",
"name": "Any, Dict, List, Optional, Tuple",
"file_path": "ultralytics\\ultralytics\\models\\utils\\ops.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 51,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, List, Optional, Tuple_2a5a3362"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\utils\\ops.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_337a1d94"
},
{
"content": "import torch.nn as nn",
"chunk_type": "import",
"name": "torch.nn",
"file_path": "ultralytics\\ultralytics\\models\\utils\\ops.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn_8ceb190b"
},
{
"content": "import torch.nn.functional as F",
"chunk_type": "import",
"name": "torch.nn.functional",
"file_path": "ultralytics\\ultralytics\\models\\utils\\ops.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn.functional_6482ad80"
},
{
"content": "from scipy.optimize import linear_sum_assignment",
"chunk_type": "import",
"name": "linear_sum_assignment",
"file_path": "ultralytics\\ultralytics\\models\\utils\\ops.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 48,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_linear_sum_assignment_e319a284"
},
{
"content": "from ultralytics.utils.metrics import bbox_iou",
"chunk_type": "import",
"name": "bbox_iou",
"file_path": "ultralytics\\ultralytics\\models\\utils\\ops.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_bbox_iou_a1d5a92c"
},
{
"content": "from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh",
"chunk_type": "import",
"name": "xywh2xyxy, xyxy2xywh",
"file_path": "ultralytics\\ultralytics\\models\\utils\\ops.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 54,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_xywh2xyxy, xyxy2xywh_74ca1074"
},
{
"content": "class HungarianMatcher(nn.Module):\n \"\"\"\n A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.\n\n HungarianMatcher performs optimal bipartite assignment over predicted and ground truth bounding boxes using a cost\n function that considers classification scores, bounding box coordinates, and optionally mask predictions. This is\n used in end-to-end object detection models like DETR.\n\n Attributes:\n cost_gain (Dict[str, float]): Dictionary of cost coefficients for 'class', 'bbox', 'giou', 'mask', and 'dice'\n components.\n use_fl (bool): Whether to use Focal Loss for classification cost calculation.\n with_mask (bool): Whether the model makes mask predictions.\n num_sample_points (int): Number of sample points used in mask cost calculation.\n alpha (float): Alpha factor in Focal Loss calculation.\n gamma (float): Gamma factor in Focal Loss calculation.\n\n Methods:\n forward: Compute optimal assignment between predictions and ground truths for a batch.\n _cost_mask: Compute mask cost and dice cost if masks are predicted.\n\n Examples:\n Initialize a HungarianMatcher with custom cost gains\n >>> matcher = HungarianMatcher(cost_gain={\"class\": 2, \"bbox\": 5, \"giou\": 2})\n\n Perform matching between predictions and ground truth\n >>> pred_boxes = torch.rand(2, 100, 4) # batch_size=2, num_queries=100\n >>> pred_scores = torch.rand(2, 100, 80) # 80 classes\n >>> gt_boxes = torch.rand(10, 4) # 10 ground truth boxes\n >>> gt_classes = torch.randint(0, 80, (10,))\n >>> gt_groups = [5, 5] # 5 GT boxes per image\n >>> indices = matcher(pred_boxes, pred_scores, gt_boxes, gt_classes, gt_groups)\n \"\"\"\n\n def __init__(\n self,\n cost_gain: Optional[Dict[str, float]] = None,\n use_fl: bool = True,\n with_mask: bool = False,\n num_sample_points: int = 12544,\n alpha: float = 0.25,\n gamma: float = 2.0,\n ):\n \"\"\"\n Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.\n\n Args:\n cost_gain (Dict[str, float], optional): Dictionary of cost coefficients for different matching cost\n components. 
Should contain keys 'class', 'bbox', 'giou', 'mask', and 'dice'.\n use_fl (bool): Whether to use Focal Loss for classification cost calculation.\n with_mask (bool): Whether the model makes mask predictions.\n num_sample_points (int): Number of sample points used in mask cost calculation.\n alpha (float): Alpha factor in Focal Loss calculation.\n gamma (float): Gamma factor in Focal Loss calculation.\n \"\"\"\n super().__init__()\n if cost_gain is None:\n cost_gain = {\"class\": 1, \"bbox\": 5, \"giou\": 2, \"mask\": 1, \"dice\": 1}\n self.cost_gain = cost_gain\n self.use_fl = use_fl\n self.with_mask = with_mask\n self.num_sample_points = num_sample_points\n self.alpha = alpha\n self.gamma = gamma\n\n def forward(\n self,\n pred_bboxes: torch.Tensor,\n pred_scores: torch.Tensor,\n gt_bboxes: torch.Tensor,\n gt_cls: torch.Tensor,\n gt_groups: List[int],\n masks: Optional[torch.Tensor] = None,\n gt_mask: Optional[List[torch.Tensor]] = None,\n ) -> List[Tuple[torch.Tensor, torch.Tensor]]:\n \"\"\"\n Compute optimal assignment between predictions and ground truth using Hungarian algorithm.\n\n This method calculates matching costs based on classification scores, bounding box coordinates, and optionally\n mask predictions, then finds the optimal bipartite assignment between predictions and ground truth.\n\n Args:\n pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4).\n pred_scores (torch.Tensor): Predicted classification scores with shape (batch_size, num_queries,\n num_classes).\n gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (num_gts, 4).\n gt_cls (torch.Tensor): Ground truth class labels with shape (num_gts,).\n gt_groups (List[int]): Number of ground truth boxes for each image in the batch.\n masks (torch.Tensor, optional): Predicted masks with shape (batch_size, num_queries, height, width).\n gt_mask (List[torch.Tensor], optional): Ground truth masks, each with shape (num_masks, Height, Width).\n\n Returns:\n (List[Tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple\n (index_i, index_j), where index_i is the tensor of indices of the selected predictions (in order)\n and index_j is the tensor of indices of the corresponding selected ground truth targets (in order).\n For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).\n \"\"\"\n bs, nq, nc = pred_scores.shape\n\n if sum(gt_groups) == 0:\n return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]\n\n # Flatten to compute cost matrices in batch format\n pred_scores = pred_scores.detach().view(-1, nc)\n pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)\n pred_bboxes = pred_bboxes.detach().view(-1, 4)\n\n # Compute classification cost\n pred_scores = pred_scores[:, gt_cls]\n if self.use_fl:\n neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())\n pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())\n cost_class = pos_cost_class - neg_cost_class\n else:\n cost_class = -pred_scores\n\n # Compute L1 cost between boxes\n cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt)\n\n # Compute GIoU cost between boxes, (bs*num_queries, num_gt)\n cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)\n\n # Combine costs into final cost matrix\n 
C = (\n self.cost_gain[\"class\"] * cost_class\n + self.cost_gain[\"bbox\"] * cost_bbox\n + self.cost_gain[\"giou\"] * cost_giou\n )\n\n # Add mask costs if available\n if self.with_mask:\n C += self._cost_mask(bs, gt_groups, masks, gt_mask)\n\n # Set invalid values (NaNs and infinities) to 0\n C[C.isnan() | C.isinf()] = 0.0\n\n C = C.view(bs, nq, -1).cpu()\n indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]\n gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) # (idx for queries, idx for gt)\n return [\n (torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])\n for k, (i, j) in enumerate(indices)\n ]",
"chunk_type": "class",
"name": "HungarianMatcher",
"file_path": "ultralytics\\ultralytics\\models\\utils\\ops.py",
"start_line": 14,
"end_line": 156,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.\n\nHungarianMatcher performs optimal bipartite assignment over predicted and ground truth bounding boxes using a cost\nfunction that considers classification scores, bounding box coordinates, and optionally mask predictions. This is\nused in end-to-end object detection models like DETR.\n\nAttributes:\n cost_gain (Dict[str, float]): Dictionary of cost coefficients for 'class', 'bbox', 'giou', 'mask', and 'dice'\n components.\n use_fl (bool): Whether to use Focal Loss for classification cost calculation.\n with_mask (bool): Whether the model makes mask predictions.\n num_sample_points (int): Number of sample points used in mask cost calculation.\n alpha (float): Alpha factor in Focal Loss calculation.\n gamma (float): Gamma factor in Focal Loss calculation.\n\nMethods:\n forward: Compute optimal assignment between predictions and ground truths for a batch.\n _cost_mask: Compute mask cost and dice cost if masks are predicted.\n\nExamples:\n Initialize a HungarianMatcher with custom cost gains\n >>> matcher = HungarianMatcher(cost_gain={\"class\": 2, \"bbox\": 5, \"giou\": 2})\n\n Perform matching between predictions and ground truth\n >>> pred_boxes = torch.rand(2, 100, 4) # batch_size=2, num_queries=100\n >>> pred_scores = torch.rand(2, 100, 80) # 80 classes\n >>> gt_boxes = torch.rand(10, 4) # 10 ground truth boxes\n >>> gt_classes = torch.randint(0, 80, (10,))\n >>> gt_groups = [5, 5] # 5 GT boxes per image\n >>> indices = matcher(pred_boxes, pred_scores, gt_boxes, gt_classes, gt_groups)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"scipy.optimize.linear_sum_assignment",
"ultralytics.utils.metrics.bbox_iou",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.ops.xyxy2xywh",
"nn.Module"
],
"chunk_id": "class_HungarianMatcher_ec3aa111"
},
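A minimal sketch of the assignment step that HungarianMatcher builds on may help here: combining weighted component costs and solving the bipartite matching with SciPy's `linear_sum_assignment`, exactly the call listed in the chunk's dependencies. The cost tensors below are random stand-ins rather than real model outputs, and the weights simply mirror the default `cost_gain` values shown above.

```python
# Toy version of the cost combination + Hungarian assignment used by HungarianMatcher.
import torch
from scipy.optimize import linear_sum_assignment

num_queries, num_gt = 5, 3
cost_class = torch.rand(num_queries, num_gt)  # stand-in classification cost
cost_bbox = torch.rand(num_queries, num_gt)   # stand-in L1 box cost
cost_giou = torch.rand(num_queries, num_gt)   # stand-in GIoU cost

# Default gains from the chunk above: {"class": 1, "bbox": 5, "giou": 2}
C = 1 * cost_class + 5 * cost_bbox + 2 * cost_giou

pred_idx, gt_idx = linear_sum_assignment(C.numpy())  # optimal one-to-one matching
print(pred_idx, gt_idx)  # each selected prediction index paired with a ground-truth index
```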
{
"content": "def get_cdn_group(\n batch: Dict[str, Any],\n num_classes: int,\n num_queries: int,\n class_embed: torch.Tensor,\n num_dn: int = 100,\n cls_noise_ratio: float = 0.5,\n box_noise_scale: float = 1.0,\n training: bool = False,\n) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[Dict[str, Any]]]:\n \"\"\"\n Generate contrastive denoising training group with positive and negative samples from ground truths.\n\n This function creates denoising queries for contrastive denoising training by adding noise to ground truth\n bounding boxes and class labels. It generates both positive and negative samples to improve model robustness.\n\n Args:\n batch (Dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)),\n 'gt_bboxes' (torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (List[int]) indicating number of\n ground truths per image.\n num_classes (int): Total number of object classes.\n num_queries (int): Number of object queries.\n class_embed (torch.Tensor): Class embedding weights to map labels to embedding space.\n num_dn (int): Number of denoising queries to generate.\n cls_noise_ratio (float): Noise ratio for class labels.\n box_noise_scale (float): Noise scale for bounding box coordinates.\n training (bool): Whether model is in training mode.\n\n Returns:\n padding_cls (torch.Tensor | None): Modified class embeddings for denoising with shape (bs, num_dn, embed_dim).\n padding_bbox (torch.Tensor | None): Modified bounding boxes for denoising with shape (bs, num_dn, 4).\n attn_mask (torch.Tensor | None): Attention mask for denoising with shape (tgt_size, tgt_size).\n dn_meta (Dict[str, Any] | None): Meta information dictionary containing denoising parameters.\n\n Examples:\n Generate denoising group for training\n >>> batch = {\n ... \"cls\": torch.tensor([0, 1, 2]),\n ... \"bboxes\": torch.rand(3, 4),\n ... \"batch_idx\": torch.tensor([0, 0, 1]),\n ... \"gt_groups\": [2, 1],\n ... 
}\n >>> class_embed = torch.rand(80, 256) # 80 classes, 256 embedding dim\n >>> cdn_outputs = get_cdn_group(batch, 80, 100, class_embed, training=True)\n \"\"\"\n if (not training) or num_dn <= 0 or batch is None:\n return None, None, None, None\n gt_groups = batch[\"gt_groups\"]\n total_num = sum(gt_groups)\n max_nums = max(gt_groups)\n if max_nums == 0:\n return None, None, None, None\n\n num_group = num_dn // max_nums\n num_group = 1 if num_group == 0 else num_group\n # Pad gt to max_num of a batch\n bs = len(gt_groups)\n gt_cls = batch[\"cls\"] # (bs*num, )\n gt_bbox = batch[\"bboxes\"] # bs*num, 4\n b_idx = batch[\"batch_idx\"]\n\n # Each group has positive and negative queries\n dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )\n dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4\n dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, )\n\n # Positive and negative mask\n # (bs*num*num_group, ), the second total_num*num_group part as negative samples\n neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num\n\n if cls_noise_ratio > 0:\n # Apply class label noise to half of the samples\n mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)\n idx = torch.nonzero(mask).squeeze(-1)\n # Randomly assign new class labels\n new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)\n dn_cls[idx] = new_label\n\n if box_noise_scale > 0:\n known_bbox = xywh2xyxy(dn_bbox)\n\n diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale # 2*num_group*bs*num, 4\n\n rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0\n rand_part = torch.rand_like(dn_bbox)\n rand_part[neg_idx] += 1.0\n rand_part *= rand_sign\n known_bbox += rand_part * diff\n known_bbox.clip_(min=0.0, max=1.0)\n dn_bbox = xyxy2xywh(known_bbox)\n dn_bbox = torch.logit(dn_bbox, eps=1e-6) # inverse sigmoid\n\n num_dn = int(max_nums * 2 * num_group) # total denoising queries\n dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256\n padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)\n padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)\n\n map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])\n pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)\n\n map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])\n padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed\n padding_bbox[(dn_b_idx, map_indices)] = dn_bbox\n\n tgt_size = num_dn + num_queries\n attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)\n # Match query cannot see the reconstruct\n attn_mask[num_dn:, :num_dn] = True\n # Reconstruct cannot see each other\n for i in range(num_group):\n if i == 0:\n attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True\n if i == num_group - 1:\n attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True\n else:\n attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True\n attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True\n dn_meta = {\n \"dn_pos_idx\": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],\n \"dn_num_group\": num_group,\n \"dn_num_split\": [num_dn, num_queries],\n }\n\n return (\n padding_cls.to(class_embed.device),\n padding_bbox.to(class_embed.device),\n 
attn_mask.to(class_embed.device),\n dn_meta,\n )",
"chunk_type": "function",
"name": "get_cdn_group",
"file_path": "ultralytics\\ultralytics\\models\\utils\\ops.py",
"start_line": 189,
"end_line": 317,
"start_col": 0,
"end_col": 5,
"parent_name": null,
"docstring": "Generate contrastive denoising training group with positive and negative samples from ground truths.\n\nThis function creates denoising queries for contrastive denoising training by adding noise to ground truth\nbounding boxes and class labels. It generates both positive and negative samples to improve model robustness.\n\nArgs:\n batch (Dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)),\n 'gt_bboxes' (torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (List[int]) indicating number of\n ground truths per image.\n num_classes (int): Total number of object classes.\n num_queries (int): Number of object queries.\n class_embed (torch.Tensor): Class embedding weights to map labels to embedding space.\n num_dn (int): Number of denoising queries to generate.\n cls_noise_ratio (float): Noise ratio for class labels.\n box_noise_scale (float): Noise scale for bounding box coordinates.\n training (bool): Whether model is in training mode.\n\nReturns:\n padding_cls (torch.Tensor | None): Modified class embeddings for denoising with shape (bs, num_dn, embed_dim).\n padding_bbox (torch.Tensor | None): Modified bounding boxes for denoising with shape (bs, num_dn, 4).\n attn_mask (torch.Tensor | None): Attention mask for denoising with shape (tgt_size, tgt_size).\n dn_meta (Dict[str, Any] | None): Meta information dictionary containing denoising parameters.\n\nExamples:\n Generate denoising group for training\n >>> batch = {\n ... \"cls\": torch.tensor([0, 1, 2]),\n ... \"bboxes\": torch.rand(3, 4),\n ... \"batch_idx\": torch.tensor([0, 0, 1]),\n ... \"gt_groups\": [2, 1],\n ... }\n >>> class_embed = torch.rand(80, 256) # 80 classes, 256 embedding dim\n >>> cdn_outputs = get_cdn_group(batch, 80, 100, class_embed, training=True)",
"parameters": [
"batch: Dict[str, Any]",
"num_classes: int",
"num_queries: int",
"class_embed: torch.Tensor",
"num_dn: int",
"cls_noise_ratio: float",
"box_noise_scale: float",
"training: bool"
],
"return_type": "Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[Dict[str, Any]]]",
"decorators": [],
"complexity_score": 12,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"scipy.optimize.linear_sum_assignment",
"ultralytics.utils.metrics.bbox_iou",
"ultralytics.utils.ops.xywh2xyxy",
"ultralytics.utils.ops.xyxy2xywh"
],
"chunk_id": "function_get_cdn_group_2042d0e1"
},
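A standalone sketch of the box-noising step inside `get_cdn_group` is shown below, assuming ultralytics is installed so that `xywh2xyxy`/`xyxy2xywh` (listed in the dependencies) are available. The single normalized box is illustrative only.

```python
# Jitter a normalized xywh ground-truth box, clip to the image, and map back
# through an inverse sigmoid, following the logic recorded in the chunk above.
import torch
from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh

gt_bbox = torch.tensor([[0.5, 0.5, 0.2, 0.3]])  # one normalized (cx, cy, w, h) box
box_noise_scale = 1.0

known = xywh2xyxy(gt_bbox)                                     # to corner format
diff = (gt_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale  # up to half of w/h
rand_sign = torch.randint_like(gt_bbox, 0, 2) * 2.0 - 1.0       # +/-1 per coordinate
known += torch.rand_like(gt_bbox) * rand_sign * diff            # jitter the corners
known.clip_(min=0.0, max=1.0)                                   # keep inside the image

dn_bbox = torch.logit(xyxy2xywh(known), eps=1e-6)  # inverse sigmoid, as in the source
print(dn_bbox)
```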
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_dd98d879"
},
{
"content": "from typing import Any, Dict, List, Optional, Union",
"chunk_type": "import",
"name": "Any, Dict, List, Optional, Union",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 51,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, List, Optional, Union_b1bbec1d"
},
{
"content": "from ultralytics.data.build import load_inference_source",
"chunk_type": "import",
"name": "load_inference_source",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 56,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_load_inference_source_d794da92"
},
{
"content": "from ultralytics.engine.model import Model",
"chunk_type": "import",
"name": "Model",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Model_4c74fc3d"
},
{
"content": "from ultralytics.models import yolo",
"chunk_type": "import",
"name": "yolo",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_yolo_c5d758a9"
},
{
"content": "from ultralytics.nn.tasks import (\n ClassificationModel,\n DetectionModel,\n OBBModel,\n PoseModel,\n SegmentationModel,\n WorldModel,\n YOLOEModel,\n YOLOESegModel,\n)",
"chunk_type": "import",
"name": "ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel, YOLOEModel, YOLOESegModel",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py",
"start_line": 9,
"end_line": 18,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel, YOLOEModel, YOLOESegModel_27430974"
},
{
"content": "from ultralytics.utils import ROOT, YAML",
"chunk_type": "import",
"name": "ROOT, YAML",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py",
"start_line": 19,
"end_line": 19,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ROOT, YAML_b0400d76"
},
{
"content": "class YOLO(Model):\n \"\"\"\n YOLO (You Only Look Once) object detection model.\n\n This class provides a unified interface for YOLO models, automatically switching to specialized model types\n (YOLOWorld or YOLOE) based on the model filename. It supports various computer vision tasks including object\n detection, segmentation, classification, pose estimation, and oriented bounding box detection.\n\n Attributes:\n model: The loaded YOLO model instance.\n task: The task type (detect, segment, classify, pose, obb).\n overrides: Configuration overrides for the model.\n\n Methods:\n __init__: Initialize a YOLO model with automatic type detection.\n task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.\n\n Examples:\n Load a pretrained YOLOv11n detection model\n >>> model = YOLO(\"yolo11n.pt\")\n\n Load a pretrained YOLO11n segmentation model\n >>> model = YOLO(\"yolo11n-seg.pt\")\n\n Initialize from a YAML configuration\n >>> model = YOLO(\"yolo11n.yaml\")\n \"\"\"\n\n def __init__(self, model: Union[str, Path] = \"yolo11n.pt\", task: Optional[str] = None, verbose: bool = False):\n \"\"\"\n Initialize a YOLO model.\n\n This constructor initializes a YOLO model, automatically switching to specialized model types\n (YOLOWorld or YOLOE) based on the model filename.\n\n Args:\n model (str | Path): Model name or path to model file, i.e. 'yolo11n.pt', 'yolo11n.yaml'.\n task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'.\n Defaults to auto-detection based on model.\n verbose (bool): Display model info on load.\n\n Examples:\n >>> from ultralytics import YOLO\n >>> model = YOLO(\"yolo11n.pt\") # load a pretrained YOLOv11n detection model\n >>> model = YOLO(\"yolo11n-seg.pt\") # load a pretrained YOLO11n segmentation model\n \"\"\"\n path = Path(model if isinstance(model, (str, Path)) else \"\")\n if \"-world\" in path.stem and path.suffix in {\".pt\", \".yaml\", \".yml\"}: # if YOLOWorld PyTorch model\n new_instance = YOLOWorld(path, verbose=verbose)\n self.__class__ = type(new_instance)\n self.__dict__ = new_instance.__dict__\n elif \"yoloe\" in path.stem and path.suffix in {\".pt\", \".yaml\", \".yml\"}: # if YOLOE PyTorch model\n new_instance = YOLOE(path, task=task, verbose=verbose)\n self.__class__ = type(new_instance)\n self.__dict__ = new_instance.__dict__\n else:\n # Continue with default YOLO initialization\n super().__init__(model=model, task=task, verbose=verbose)\n if hasattr(self.model, \"model\") and \"RTDETR\" in self.model.model[-1]._get_name(): # if RTDETR head\n from ultralytics import RTDETR\n\n new_instance = RTDETR(self)\n self.__class__ = type(new_instance)\n self.__dict__ = new_instance.__dict__\n\n @property\n def task_map(self) -> Dict[str, Dict[str, Any]]:\n \"\"\"Map head to model, trainer, validator, and predictor classes.\"\"\"\n return {\n \"classify\": {\n \"model\": ClassificationModel,\n \"trainer\": yolo.classify.ClassificationTrainer,\n \"validator\": yolo.classify.ClassificationValidator,\n \"predictor\": yolo.classify.ClassificationPredictor,\n },\n \"detect\": {\n \"model\": DetectionModel,\n \"trainer\": yolo.detect.DetectionTrainer,\n \"validator\": yolo.detect.DetectionValidator,\n \"predictor\": yolo.detect.DetectionPredictor,\n },\n \"segment\": {\n \"model\": SegmentationModel,\n \"trainer\": yolo.segment.SegmentationTrainer,\n \"validator\": yolo.segment.SegmentationValidator,\n \"predictor\": yolo.segment.SegmentationPredictor,\n },\n \"pose\": {\n \"model\": 
PoseModel,\n \"trainer\": yolo.pose.PoseTrainer,\n \"validator\": yolo.pose.PoseValidator,\n \"predictor\": yolo.pose.PosePredictor,\n },\n \"obb\": {\n \"model\": OBBModel,\n \"trainer\": yolo.obb.OBBTrainer,\n \"validator\": yolo.obb.OBBValidator,\n \"predictor\": yolo.obb.OBBPredictor,\n },\n }",
"chunk_type": "class",
"name": "YOLO",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py",
"start_line": 22,
"end_line": 121,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "YOLO (You Only Look Once) object detection model.\n\nThis class provides a unified interface for YOLO models, automatically switching to specialized model types\n(YOLOWorld or YOLOE) based on the model filename. It supports various computer vision tasks including object\ndetection, segmentation, classification, pose estimation, and oriented bounding box detection.\n\nAttributes:\n model: The loaded YOLO model instance.\n task: The task type (detect, segment, classify, pose, obb).\n overrides: Configuration overrides for the model.\n\nMethods:\n __init__: Initialize a YOLO model with automatic type detection.\n task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.\n\nExamples:\n Load a pretrained YOLOv11n detection model\n >>> model = YOLO(\"yolo11n.pt\")\n\n Load a pretrained YOLO11n segmentation model\n >>> model = YOLO(\"yolo11n-seg.pt\")\n\n Initialize from a YAML configuration\n >>> model = YOLO(\"yolo11n.yaml\")",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Union",
"ultralytics.data.build.load_inference_source",
"ultralytics.engine.model.Model",
"ultralytics.models.yolo",
"ultralytics.nn.tasks.ClassificationModel",
"ultralytics.nn.tasks.DetectionModel",
"ultralytics.nn.tasks.OBBModel",
"ultralytics.nn.tasks.PoseModel",
"ultralytics.nn.tasks.SegmentationModel",
"ultralytics.nn.tasks.WorldModel",
"ultralytics.nn.tasks.YOLOEModel",
"ultralytics.nn.tasks.YOLOESegModel",
"ultralytics.utils.ROOT",
"ultralytics.utils.YAML",
"ultralytics.RTDETR",
"Model"
],
"chunk_id": "class_YOLO_33c1b65a"
},
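A short usage sketch drawn from the YOLO docstring examples above; the weight file name and image path are placeholders, not files shipped with this document.

```python
# Load a detection model and run prediction, per the docstring examples.
from ultralytics import YOLO

model = YOLO("yolo11n.pt")                      # task auto-detected as 'detect'
results = model.predict("image.jpg", conf=0.25)  # placeholder image path
for r in results:
    print(r.boxes.xyxy, r.boxes.cls)             # box coordinates and class ids
```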
{
"content": "class YOLOWorld(Model):\n \"\"\"\n YOLO-World object detection model.\n\n YOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions\n without requiring training on specific classes. It extends the YOLO architecture to support real-time\n open-vocabulary detection.\n\n Attributes:\n model: The loaded YOLO-World model instance.\n task: Always set to 'detect' for object detection.\n overrides: Configuration overrides for the model.\n\n Methods:\n __init__: Initialize YOLOv8-World model with a pre-trained model file.\n task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.\n set_classes: Set the model's class names for detection.\n\n Examples:\n Load a YOLOv8-World model\n >>> model = YOLOWorld(\"yolov8s-world.pt\")\n\n Set custom classes for detection\n >>> model.set_classes([\"person\", \"car\", \"bicycle\"])\n \"\"\"\n\n def __init__(self, model: Union[str, Path] = \"yolov8s-world.pt\", verbose: bool = False) -> None:\n \"\"\"\n Initialize YOLOv8-World model with a pre-trained model file.\n\n Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default\n COCO class names.\n\n Args:\n model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.\n verbose (bool): If True, prints additional information during initialization.\n \"\"\"\n super().__init__(model=model, task=\"detect\", verbose=verbose)\n\n # Assign default COCO class names when there are no custom names\n if not hasattr(self.model, \"names\"):\n self.model.names = YAML.load(ROOT / \"cfg/datasets/coco8.yaml\").get(\"names\")\n\n @property\n def task_map(self) -> Dict[str, Dict[str, Any]]:\n \"\"\"Map head to model, validator, and predictor classes.\"\"\"\n return {\n \"detect\": {\n \"model\": WorldModel,\n \"validator\": yolo.detect.DetectionValidator,\n \"predictor\": yolo.detect.DetectionPredictor,\n \"trainer\": yolo.world.WorldTrainer,\n }\n }\n\n def set_classes(self, classes: List[str]) -> None:\n \"\"\"\n Set the model's class names for detection.\n\n Args:\n classes (List[str]): A list of categories i.e. [\"person\"].\n \"\"\"\n self.model.set_classes(classes)\n # Remove background if it's given\n background = \" \"\n if background in classes:\n classes.remove(background)\n self.model.names = classes\n\n # Reset method class names\n if self.predictor:\n self.predictor.model.names = classes",
"chunk_type": "class",
"name": "YOLOWorld",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py",
"start_line": 124,
"end_line": 195,
"start_col": 0,
"end_col": 48,
"parent_name": null,
"docstring": "YOLO-World object detection model.\n\nYOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions\nwithout requiring training on specific classes. It extends the YOLO architecture to support real-time\nopen-vocabulary detection.\n\nAttributes:\n model: The loaded YOLO-World model instance.\n task: Always set to 'detect' for object detection.\n overrides: Configuration overrides for the model.\n\nMethods:\n __init__: Initialize YOLOv8-World model with a pre-trained model file.\n task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.\n set_classes: Set the model's class names for detection.\n\nExamples:\n Load a YOLOv8-World model\n >>> model = YOLOWorld(\"yolov8s-world.pt\")\n\n Set custom classes for detection\n >>> model.set_classes([\"person\", \"car\", \"bicycle\"])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Union",
"ultralytics.data.build.load_inference_source",
"ultralytics.engine.model.Model",
"ultralytics.models.yolo",
"ultralytics.nn.tasks.ClassificationModel",
"ultralytics.nn.tasks.DetectionModel",
"ultralytics.nn.tasks.OBBModel",
"ultralytics.nn.tasks.PoseModel",
"ultralytics.nn.tasks.SegmentationModel",
"ultralytics.nn.tasks.WorldModel",
"ultralytics.nn.tasks.YOLOEModel",
"ultralytics.nn.tasks.YOLOESegModel",
"ultralytics.utils.ROOT",
"ultralytics.utils.YAML",
"ultralytics.RTDETR",
"Model"
],
"chunk_id": "class_YOLOWorld_17b0b132"
},
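An open-vocabulary sketch mirroring the YOLOWorld docstring examples; the weight file and image path are placeholders.

```python
# Restrict the detection vocabulary with set_classes, then predict.
from ultralytics import YOLOWorld

model = YOLOWorld("yolov8s-world.pt")
model.set_classes(["person", "car", "bicycle"])  # custom open-vocabulary classes
results = model.predict("street.jpg")            # placeholder image path
```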
{
"content": "class YOLOE(Model):\n \"\"\"\n YOLOE object detection and segmentation model.\n\n YOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with\n improved performance and additional features like visual and text positional embeddings.\n\n Attributes:\n model: The loaded YOLOE model instance.\n task: The task type (detect or segment).\n overrides: Configuration overrides for the model.\n\n Methods:\n __init__: Initialize YOLOE model with a pre-trained model file.\n task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.\n get_text_pe: Get text positional embeddings for the given texts.\n get_visual_pe: Get visual positional embeddings for the given image and visual features.\n set_vocab: Set vocabulary and class names for the YOLOE model.\n get_vocab: Get vocabulary for the given class names.\n set_classes: Set the model's class names and embeddings for detection.\n val: Validate the model using text or visual prompts.\n predict: Run prediction on images, videos, directories, streams, etc.\n\n Examples:\n Load a YOLOE detection model\n >>> model = YOLOE(\"yoloe-11s-seg.pt\")\n\n Set vocabulary and class names\n >>> model.set_vocab([\"person\", \"car\", \"dog\"], [\"person\", \"car\", \"dog\"])\n\n Predict with visual prompts\n >>> prompts = {\"bboxes\": [[10, 20, 100, 200]], \"cls\": [\"person\"]}\n >>> results = model.predict(\"image.jpg\", visual_prompts=prompts)\n \"\"\"\n\n def __init__(\n self, model: Union[str, Path] = \"yoloe-11s-seg.pt\", task: Optional[str] = None, verbose: bool = False\n ) -> None:\n \"\"\"\n Initialize YOLOE model with a pre-trained model file.\n\n Args:\n model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.\n task (str, optional): Task type for the model. Auto-detected if None.\n verbose (bool): If True, prints additional information during initialization.\n \"\"\"\n super().__init__(model=model, task=task, verbose=verbose)\n\n @property\n def task_map(self) -> Dict[str, Dict[str, Any]]:\n \"\"\"Map head to model, validator, and predictor classes.\"\"\"\n return {\n \"detect\": {\n \"model\": YOLOEModel,\n \"validator\": yolo.yoloe.YOLOEDetectValidator,\n \"predictor\": yolo.detect.DetectionPredictor,\n \"trainer\": yolo.yoloe.YOLOETrainer,\n },\n \"segment\": {\n \"model\": YOLOESegModel,\n \"validator\": yolo.yoloe.YOLOESegValidator,\n \"predictor\": yolo.segment.SegmentationPredictor,\n \"trainer\": yolo.yoloe.YOLOESegTrainer,\n },\n }\n\n def get_text_pe(self, texts):\n \"\"\"Get text positional embeddings for the given texts.\"\"\"\n assert isinstance(self.model, YOLOEModel)\n return self.model.get_text_pe(texts)\n\n def get_visual_pe(self, img, visual):\n \"\"\"\n Get visual positional embeddings for the given image and visual features.\n\n This method extracts positional embeddings from visual features based on the input image. 
It requires\n that the model is an instance of YOLOEModel.\n\n Args:\n img (torch.Tensor): Input image tensor.\n visual (torch.Tensor): Visual features extracted from the image.\n\n Returns:\n (torch.Tensor): Visual positional embeddings.\n\n Examples:\n >>> model = YOLOE(\"yoloe-11s-seg.pt\")\n >>> img = torch.rand(1, 3, 640, 640)\n >>> visual_features = torch.rand(1, 1, 80, 80)\n >>> pe = model.get_visual_pe(img, visual_features)\n \"\"\"\n assert isinstance(self.model, YOLOEModel)\n return self.model.get_visual_pe(img, visual)\n\n def set_vocab(self, vocab: List[str], names: List[str]) -> None:\n \"\"\"\n Set vocabulary and class names for the YOLOE model.\n\n This method configures the vocabulary and class names used by the model for text processing and\n classification tasks. The model must be an instance of YOLOEModel.\n\n Args:\n vocab (List[str]): Vocabulary list containing tokens or words used by the model for text processing.\n names (List[str]): List of class names that the model can detect or classify.\n\n Raises:\n AssertionError: If the model is not an instance of YOLOEModel.\n\n Examples:\n >>> model = YOLOE(\"yoloe-11s-seg.pt\")\n >>> model.set_vocab([\"person\", \"car\", \"dog\"], [\"person\", \"car\", \"dog\"])\n \"\"\"\n assert isinstance(self.model, YOLOEModel)\n self.model.set_vocab(vocab, names=names)\n\n def get_vocab(self, names):\n \"\"\"Get vocabulary for the given class names.\"\"\"\n assert isinstance(self.model, YOLOEModel)\n return self.model.get_vocab(names)\n\n def set_classes(self, classes: List[str], embeddings) -> None:\n \"\"\"\n Set the model's class names and embeddings for detection.\n\n Args:\n classes (List[str]): A list of categories i.e. [\"person\"].\n embeddings (torch.Tensor): Embeddings corresponding to the classes.\n \"\"\"\n assert isinstance(self.model, YOLOEModel)\n self.model.set_classes(classes, embeddings)\n # Verify no background class is present\n assert \" \" not in classes\n self.model.names = classes\n\n # Reset method class names\n if self.predictor:\n self.predictor.model.names = classes\n\n def val(\n self,\n validator=None,\n load_vp: bool = False,\n refer_data: Optional[str] = None,\n **kwargs,\n ):\n \"\"\"\n Validate the model using text or visual prompts.\n\n Args:\n validator (callable, optional): A callable validator function. If None, a default validator is loaded.\n load_vp (bool): Whether to load visual prompts. If False, text prompts are used.\n refer_data (str, optional): Path to the reference data for visual prompts.\n **kwargs (Any): Additional keyword arguments to override default settings.\n\n Returns:\n (dict): Validation statistics containing metrics computed during validation.\n \"\"\"\n custom = {\"rect\": not load_vp} # method defaults\n args = {**self.overrides, **custom, **kwargs, \"mode\": \"val\"} # highest priority args on the right\n\n validator = (validator or self._smart_load(\"validator\"))(args=args, _callbacks=self.callbacks)\n validator(model=self.model, load_vp=load_vp, refer_data=refer_data)\n self.metrics = validator.metrics\n return validator.metrics\n\n def predict(\n self,\n source=None,\n stream: bool = False,\n visual_prompts: Dict[str, List] = {},\n refer_image=None,\n predictor=None,\n **kwargs,\n ):\n \"\"\"\n Run prediction on images, videos, directories, streams, etc.\n\n Args:\n source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. 
Accepts image paths,\n directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.\n stream (bool): Whether to stream the prediction results. If True, results are yielded as a\n generator as they are computed.\n visual_prompts (Dict[str, List]): Dictionary containing visual prompts for the model. Must include\n 'bboxes' and 'cls' keys when non-empty.\n refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.\n predictor (callable, optional): Custom predictor function. If None, a predictor is automatically\n loaded based on the task.\n **kwargs (Any): Additional keyword arguments passed to the predictor.\n\n Returns:\n (List | generator): List of Results objects or generator of Results objects if stream=True.\n\n Examples:\n >>> model = YOLOE(\"yoloe-11s-seg.pt\")\n >>> results = model.predict(\"path/to/image.jpg\")\n >>> # With visual prompts\n >>> prompts = {\"bboxes\": [[10, 20, 100, 200]], \"cls\": [\"person\"]}\n >>> results = model.predict(\"path/to/image.jpg\", visual_prompts=prompts)\n \"\"\"\n if len(visual_prompts):\n assert \"bboxes\" in visual_prompts and \"cls\" in visual_prompts, (\n f\"Expected 'bboxes' and 'cls' in visual prompts, but got {visual_prompts.keys()}\"\n )\n assert len(visual_prompts[\"bboxes\"]) == len(visual_prompts[\"cls\"]), (\n f\"Expected equal number of bounding boxes and classes, but got {len(visual_prompts['bboxes'])} and \"\n f\"{len(visual_prompts['cls'])} respectively\"\n )\n if not isinstance(self.predictor, yolo.yoloe.YOLOEVPDetectPredictor):\n self.predictor = (predictor or yolo.yoloe.YOLOEVPDetectPredictor)(\n overrides={\n \"task\": self.model.task,\n \"mode\": \"predict\",\n \"save\": False,\n \"verbose\": refer_image is None,\n \"batch\": 1,\n },\n _callbacks=self.callbacks,\n )\n\n num_cls = (\n max(len(set(c)) for c in visual_prompts[\"cls\"])\n if isinstance(source, list) and refer_image is None # means multiple images\n else len(set(visual_prompts[\"cls\"]))\n )\n self.model.model[-1].nc = num_cls\n self.model.names = [f\"object{i}\" for i in range(num_cls)]\n self.predictor.set_prompts(visual_prompts.copy())\n self.predictor.setup_model(model=self.model)\n\n if refer_image is None and source is not None:\n dataset = load_inference_source(source)\n if dataset.mode in {\"video\", \"stream\"}:\n # NOTE: set the first frame as refer image for videos/streams inference\n refer_image = next(iter(dataset))[1][0]\n if refer_image is not None:\n vpe = self.predictor.get_vpe(refer_image)\n self.model.set_classes(self.model.names, vpe)\n self.task = \"segment\" if isinstance(self.predictor, yolo.segment.SegmentationPredictor) else \"detect\"\n self.predictor = None # reset predictor\n elif isinstance(self.predictor, yolo.yoloe.YOLOEVPDetectPredictor):\n self.predictor = None # reset predictor if no visual prompts\n\n return super().predict(source, stream, **kwargs)",
"chunk_type": "class",
"name": "YOLOE",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\model.py",
"start_line": 198,
"end_line": 440,
"start_col": 0,
"end_col": 56,
"parent_name": null,
"docstring": "YOLOE object detection and segmentation model.\n\nYOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with\nimproved performance and additional features like visual and text positional embeddings.\n\nAttributes:\n model: The loaded YOLOE model instance.\n task: The task type (detect or segment).\n overrides: Configuration overrides for the model.\n\nMethods:\n __init__: Initialize YOLOE model with a pre-trained model file.\n task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.\n get_text_pe: Get text positional embeddings for the given texts.\n get_visual_pe: Get visual positional embeddings for the given image and visual features.\n set_vocab: Set vocabulary and class names for the YOLOE model.\n get_vocab: Get vocabulary for the given class names.\n set_classes: Set the model's class names and embeddings for detection.\n val: Validate the model using text or visual prompts.\n predict: Run prediction on images, videos, directories, streams, etc.\n\nExamples:\n Load a YOLOE detection model\n >>> model = YOLOE(\"yoloe-11s-seg.pt\")\n\n Set vocabulary and class names\n >>> model.set_vocab([\"person\", \"car\", \"dog\"], [\"person\", \"car\", \"dog\"])\n\n Predict with visual prompts\n >>> prompts = {\"bboxes\": [[10, 20, 100, 200]], \"cls\": [\"person\"]}\n >>> results = model.predict(\"image.jpg\", visual_prompts=prompts)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Union",
"ultralytics.data.build.load_inference_source",
"ultralytics.engine.model.Model",
"ultralytics.models.yolo",
"ultralytics.nn.tasks.ClassificationModel",
"ultralytics.nn.tasks.DetectionModel",
"ultralytics.nn.tasks.OBBModel",
"ultralytics.nn.tasks.PoseModel",
"ultralytics.nn.tasks.SegmentationModel",
"ultralytics.nn.tasks.WorldModel",
"ultralytics.nn.tasks.YOLOEModel",
"ultralytics.nn.tasks.YOLOESegModel",
"ultralytics.utils.ROOT",
"ultralytics.utils.YAML",
"ultralytics.RTDETR",
"Model"
],
"chunk_id": "class_YOLOE_237af088"
},
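A visual-prompt sketch following the YOLOE docstring example; the import path comes from the yolo package `__init__` chunk recorded below, and the image path and box coordinates are illustrative.

```python
# Prompted prediction: supply reference boxes and class names as visual prompts.
from ultralytics.models.yolo import YOLOE

model = YOLOE("yoloe-11s-seg.pt")
prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}  # must have matching lengths
results = model.predict("image.jpg", visual_prompts=prompts)
```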
{
"content": "from ultralytics.models.yolo import classify, detect, obb, pose, segment, world, yoloe",
"chunk_type": "import",
"name": "classify, detect, obb, pose, segment, world, yoloe",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\__init__.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 86,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_classify, detect, obb, pose, segment, world, yoloe_bf01e42a"
},
{
"content": "from .model import YOLO, YOLOE, YOLOWorld",
"chunk_type": "import",
"name": "YOLO, YOLOE, YOLOWorld",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\__init__.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLO, YOLOE, YOLOWorld_e96d119e"
},
{
"content": "__all__ = \"classify\", \"segment\", \"detect\", \"pose\", \"obb\", \"world\", \"yoloe\", \"YOLO\", \"YOLOWorld\", \"YOLOE\"",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\__init__.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 104,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___f3807658"
},
{
"content": "import copy",
"chunk_type": "import",
"name": "copy",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_copy_999a2db0"
},
{
"content": "import math",
"chunk_type": "import",
"name": "math",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_math_6bc7677f"
},
{
"content": "from functools import partial",
"chunk_type": "import",
"name": "partial",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_partial_43112863"
},
{
"content": "from typing import Any, Optional, Tuple, Type, Union",
"chunk_type": "import",
"name": "Any, Optional, Tuple, Type, Union",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 52,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Optional, Tuple, Type, Union_5730cfac"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_06181710"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_621563b7"
},
{
"content": "import torch.nn.functional as F",
"chunk_type": "import",
"name": "torch.nn.functional",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn.functional_48ff619f"
},
{
"content": "from torch import Tensor, nn",
"chunk_type": "import",
"name": "Tensor, nn",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Tensor, nn_82500d82"
},
{
"content": "from ultralytics.nn.modules import MLP, LayerNorm2d, MLPBlock",
"chunk_type": "import",
"name": "MLP, LayerNorm2d, MLPBlock",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 61,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_MLP, LayerNorm2d, MLPBlock_4e9f6fe4"
},
{
"content": "from .transformer import Attention, TwoWayAttentionBlock, TwoWayTransformer",
"chunk_type": "import",
"name": "Attention, TwoWayAttentionBlock, TwoWayTransformer",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 75,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Attention, TwoWayAttentionBlock, TwoWayTransformer_c782685c"
},
{
"content": "from .utils import add_decomposed_rel_pos, apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition",
"chunk_type": "import",
"name": "add_decomposed_rel_pos, apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 16,
"end_line": 16,
"start_col": 0,
"end_col": 116,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_add_decomposed_rel_pos, apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition_0195c9a6"
},
{
"content": "class DropPath(nn.Module):\n \"\"\"\n Implements stochastic depth regularization for neural networks during training.\n\n Attributes:\n drop_prob (float): Probability of dropping a path during training.\n scale_by_keep (bool): Whether to scale the output by the keep probability.\n\n Methods:\n forward: Applies stochastic depth to input tensor during training, with optional scaling.\n\n Examples:\n >>> drop_path = DropPath(drop_prob=0.2, scale_by_keep=True)\n >>> x = torch.randn(32, 64, 224, 224)\n >>> output = drop_path(x)\n \"\"\"\n\n def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):\n \"\"\"Initialize DropPath module for stochastic depth regularization during training.\"\"\"\n super().__init__()\n self.drop_prob = drop_prob\n self.scale_by_keep = scale_by_keep\n\n def forward(self, x: Tensor) -> Tensor:\n \"\"\"Apply stochastic depth to input tensor during training, with optional scaling.\"\"\"\n if self.drop_prob == 0.0 or not self.training:\n return x\n keep_prob = 1 - self.drop_prob\n shape = (x.shape[0],) + (1,) * (x.ndim - 1)\n random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n if keep_prob > 0.0 and self.scale_by_keep:\n random_tensor.div_(keep_prob)\n return x * random_tensor",
"chunk_type": "class",
"name": "DropPath",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 19,
"end_line": 51,
"start_col": 0,
"end_col": 32,
"parent_name": null,
"docstring": "Implements stochastic depth regularization for neural networks during training.\n\nAttributes:\n drop_prob (float): Probability of dropping a path during training.\n scale_by_keep (bool): Whether to scale the output by the keep probability.\n\nMethods:\n forward: Applies stochastic depth to input tensor during training, with optional scaling.\n\nExamples:\n >>> drop_path = DropPath(drop_prob=0.2, scale_by_keep=True)\n >>> x = torch.randn(32, 64, 224, 224)\n >>> output = drop_path(x)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"functools.partial",
"typing.Any",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"typing.Union",
"numpy",
"torch",
"torch.nn.functional",
"torch.Tensor",
"torch.nn",
"ultralytics.nn.modules.MLP",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.nn.modules.MLPBlock",
"transformer.Attention",
"transformer.TwoWayAttentionBlock",
"transformer.TwoWayTransformer",
"utils.add_decomposed_rel_pos",
"utils.apply_rotary_enc",
"utils.compute_axial_cis",
"utils.window_partition",
"utils.window_unpartition",
"nn.Module"
],
"chunk_id": "class_DropPath_b0d4151c"
},
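A quick behavioral check of DropPath, following its docstring example; the import path matches the `file_path` recorded for this chunk.

```python
# DropPath is an identity in eval mode; in train mode it zeroes whole samples
# with probability drop_prob and rescales survivors by the keep probability.
import torch
from ultralytics.models.sam.modules.blocks import DropPath

drop_path = DropPath(drop_prob=0.2, scale_by_keep=True)
x = torch.randn(8, 64, 16, 16)

drop_path.eval()
assert torch.equal(drop_path(x), x)  # identity at inference time

drop_path.train()
out = drop_path(x)                   # some samples zeroed, survivors scaled by 1/0.8
print(out.shape)                     # torch.Size([8, 64, 16, 16])
```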
{
"content": "class MaskDownSampler(nn.Module):\n \"\"\"\n A mask downsampling and embedding module for efficient processing of input masks.\n\n This class implements a mask downsampler that progressively reduces the spatial dimensions of input masks\n while expanding their channel dimensions using convolutional layers, layer normalization, and activation\n functions.\n\n Attributes:\n encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and\n activation functions for downsampling and embedding masks.\n\n Methods:\n forward: Downsamples and encodes input mask to embed_dim channels.\n\n Examples:\n >>> mask_downsampler = MaskDownSampler(embed_dim=256, kernel_size=4, stride=4, padding=0, total_stride=16)\n >>> input_mask = torch.randn(1, 1, 256, 256)\n >>> output = mask_downsampler(input_mask)\n >>> print(output.shape)\n torch.Size([1, 256, 16, 16])\n \"\"\"\n\n def __init__(\n self,\n embed_dim: int = 256,\n kernel_size: int = 4,\n stride: int = 4,\n padding: int = 0,\n total_stride: int = 16,\n activation: Type[nn.Module] = nn.GELU,\n ):\n \"\"\"Initialize a mask downsampler module for progressive downsampling and channel expansion.\"\"\"\n super().__init__()\n num_layers = int(math.log2(total_stride) // math.log2(stride))\n assert stride**num_layers == total_stride\n self.encoder = nn.Sequential()\n mask_in_chans, mask_out_chans = 1, 1\n for _ in range(num_layers):\n mask_out_chans = mask_in_chans * (stride**2)\n self.encoder.append(\n nn.Conv2d(\n mask_in_chans,\n mask_out_chans,\n kernel_size=kernel_size,\n stride=stride,\n padding=padding,\n )\n )\n self.encoder.append(LayerNorm2d(mask_out_chans))\n self.encoder.append(activation())\n mask_in_chans = mask_out_chans\n\n self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))\n\n def forward(self, x: Tensor) -> Tensor:\n \"\"\"Downsample and encode input mask to embed_dim channels using convolutional layers and LayerNorm2d.\"\"\"\n return self.encoder(x)",
"chunk_type": "class",
"name": "MaskDownSampler",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 54,
"end_line": 111,
"start_col": 0,
"end_col": 30,
"parent_name": null,
"docstring": "A mask downsampling and embedding module for efficient processing of input masks.\n\nThis class implements a mask downsampler that progressively reduces the spatial dimensions of input masks\nwhile expanding their channel dimensions using convolutional layers, layer normalization, and activation\nfunctions.\n\nAttributes:\n encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and\n activation functions for downsampling and embedding masks.\n\nMethods:\n forward: Downsamples and encodes input mask to embed_dim channels.\n\nExamples:\n >>> mask_downsampler = MaskDownSampler(embed_dim=256, kernel_size=4, stride=4, padding=0, total_stride=16)\n >>> input_mask = torch.randn(1, 1, 256, 256)\n >>> output = mask_downsampler(input_mask)\n >>> print(output.shape)\n torch.Size([1, 256, 16, 16])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"functools.partial",
"typing.Any",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"typing.Union",
"numpy",
"torch",
"torch.nn.functional",
"torch.Tensor",
"torch.nn",
"ultralytics.nn.modules.MLP",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.nn.modules.MLPBlock",
"transformer.Attention",
"transformer.TwoWayAttentionBlock",
"transformer.TwoWayTransformer",
"utils.add_decomposed_rel_pos",
"utils.apply_rotary_enc",
"utils.compute_axial_cis",
"utils.window_partition",
"utils.window_unpartition",
"nn.Module"
],
"chunk_id": "class_MaskDownSampler_c51bf8d6"
},
{
"content": "class CXBlock(nn.Module):\n \"\"\"\n ConvNeXt Block for efficient feature extraction in convolutional neural networks.\n\n This block implements a modified version of the ConvNeXt architecture, offering improved performance and\n flexibility in feature extraction.\n\n Attributes:\n dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer.\n norm (LayerNorm2d): Layer normalization applied to channels.\n pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.\n act (nn.GELU): GELU activation function.\n pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.\n gamma (nn.Parameter | None): Learnable scale parameter for layer scaling.\n drop_path (nn.Module): DropPath layer for stochastic depth regularization.\n\n Methods:\n forward: Processes the input tensor through the ConvNeXt block.\n\n Examples:\n >>> import torch\n >>> x = torch.randn(1, 64, 56, 56)\n >>> block = CXBlock(dim=64, kernel_size=7, padding=3)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 64, 56, 56])\n \"\"\"\n\n def __init__(\n self,\n dim: int,\n kernel_size: int = 7,\n padding: int = 3,\n drop_path: float = 0.0,\n layer_scale_init_value: float = 1e-6,\n use_dwconv: bool = True,\n ):\n \"\"\"\n Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks.\n\n This block implements a modified version of the ConvNeXt architecture, offering improved performance and\n flexibility in feature extraction.\n\n Args:\n dim (int): Number of input channels.\n kernel_size (int): Size of the convolutional kernel.\n padding (int): Padding size for the convolution.\n drop_path (float): Stochastic depth rate.\n layer_scale_init_value (float): Initial value for Layer Scale.\n use_dwconv (bool): Whether to use depthwise convolution.\n\n Examples:\n >>> block = CXBlock(dim=64, kernel_size=7, padding=3)\n >>> x = torch.randn(1, 64, 32, 32)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 64, 32, 32])\n \"\"\"\n super().__init__()\n self.dwconv = nn.Conv2d(\n dim,\n dim,\n kernel_size=kernel_size,\n padding=padding,\n groups=dim if use_dwconv else 1,\n ) # depthwise conv\n self.norm = LayerNorm2d(dim, eps=1e-6)\n self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers\n self.act = nn.GELU()\n self.pwconv2 = nn.Linear(4 * dim, dim)\n self.gamma = (\n nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)\n if layer_scale_init_value > 0\n else None\n )\n self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n\n def forward(self, x: Tensor) -> Tensor:\n \"\"\"Apply ConvNeXt block operations to input tensor, including convolutions and residual connection.\"\"\"\n input = x\n x = self.dwconv(x)\n x = self.norm(x)\n x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)\n x = self.pwconv1(x)\n x = self.act(x)\n x = self.pwconv2(x)\n if self.gamma is not None:\n x = self.gamma * x\n x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)\n\n x = input + self.drop_path(x)\n return x",
"chunk_type": "class",
"name": "CXBlock",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 114,
"end_line": 205,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "ConvNeXt Block for efficient feature extraction in convolutional neural networks.\n\nThis block implements a modified version of the ConvNeXt architecture, offering improved performance and\nflexibility in feature extraction.\n\nAttributes:\n dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer.\n norm (LayerNorm2d): Layer normalization applied to channels.\n pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.\n act (nn.GELU): GELU activation function.\n pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.\n gamma (nn.Parameter | None): Learnable scale parameter for layer scaling.\n drop_path (nn.Module): DropPath layer for stochastic depth regularization.\n\nMethods:\n forward: Processes the input tensor through the ConvNeXt block.\n\nExamples:\n >>> import torch\n >>> x = torch.randn(1, 64, 56, 56)\n >>> block = CXBlock(dim=64, kernel_size=7, padding=3)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 64, 56, 56])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"functools.partial",
"typing.Any",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"typing.Union",
"numpy",
"torch",
"torch.nn.functional",
"torch.Tensor",
"torch.nn",
"ultralytics.nn.modules.MLP",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.nn.modules.MLPBlock",
"transformer.Attention",
"transformer.TwoWayAttentionBlock",
"transformer.TwoWayTransformer",
"utils.add_decomposed_rel_pos",
"utils.apply_rotary_enc",
"utils.compute_axial_cis",
"utils.window_partition",
"utils.window_unpartition",
"nn.Module"
],
"chunk_id": "class_CXBlock_2e708231"
},
{
"content": "class Fuser(nn.Module):\n \"\"\"\n A module for fusing features through multiple layers of a neural network.\n\n This class applies a series of identical layers to an input tensor, optionally projecting the input first.\n\n Attributes:\n proj (nn.Module): An optional input projection layer. Identity if no projection is needed.\n layers (nn.ModuleList): A list of identical layers to be applied sequentially.\n\n Methods:\n forward: Applies the fuser to an input tensor.\n\n Examples:\n >>> layer = CXBlock(dim=256)\n >>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True)\n >>> x = torch.randn(1, 256, 32, 32)\n >>> output = fuser(x)\n >>> print(output.shape)\n torch.Size([1, 256, 32, 32])\n \"\"\"\n\n def __init__(self, layer: nn.Module, num_layers: int, dim: Optional[int] = None, input_projection: bool = False):\n \"\"\"\n Initialize the Fuser module for feature fusion through multiple layers.\n\n This module creates a sequence of identical layers and optionally applies an input projection.\n\n Args:\n layer (nn.Module): The layer to be replicated in the fuser.\n num_layers (int): The number of times to replicate the layer.\n dim (int | None): The dimension for input projection, if used.\n input_projection (bool): Whether to use input projection.\n\n Examples:\n >>> layer = nn.Linear(64, 64)\n >>> fuser = Fuser(layer, num_layers=3, dim=64, input_projection=True)\n >>> input_tensor = torch.randn(1, 64)\n >>> output = fuser(input_tensor)\n \"\"\"\n super().__init__()\n self.proj = nn.Identity()\n self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])\n\n if input_projection:\n assert dim is not None\n self.proj = nn.Conv2d(dim, dim, kernel_size=1)\n\n def forward(self, x: Tensor) -> Tensor:\n \"\"\"Apply a series of layers to the input tensor, optionally projecting it first.\"\"\"\n x = self.proj(x)\n for layer in self.layers:\n x = layer(x)\n return x",
"chunk_type": "class",
"name": "Fuser",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 208,
"end_line": 261,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "A module for fusing features through multiple layers of a neural network.\n\nThis class applies a series of identical layers to an input tensor, optionally projecting the input first.\n\nAttributes:\n proj (nn.Module): An optional input projection layer. Identity if no projection is needed.\n layers (nn.ModuleList): A list of identical layers to be applied sequentially.\n\nMethods:\n forward: Applies the fuser to an input tensor.\n\nExamples:\n >>> layer = CXBlock(dim=256)\n >>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True)\n >>> x = torch.randn(1, 256, 32, 32)\n >>> output = fuser(x)\n >>> print(output.shape)\n torch.Size([1, 256, 32, 32])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"functools.partial",
"typing.Any",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"typing.Union",
"numpy",
"torch",
"torch.nn.functional",
"torch.Tensor",
"torch.nn",
"ultralytics.nn.modules.MLP",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.nn.modules.MLPBlock",
"transformer.Attention",
"transformer.TwoWayAttentionBlock",
"transformer.TwoWayTransformer",
"utils.add_decomposed_rel_pos",
"utils.apply_rotary_enc",
"utils.compute_axial_cis",
"utils.window_partition",
"utils.window_unpartition",
"nn.Module"
],
"chunk_id": "class_Fuser_76f231f1"
},
{
"content": "class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):\n \"\"\"\n A two-way attention block for performing self-attention and cross-attention in both directions.\n\n This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on\n sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and\n cross-attention from dense to sparse inputs.\n\n Attributes:\n self_attn (Attention): Self-attention layer for queries.\n norm1 (nn.LayerNorm): Layer normalization after the first attention block.\n cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.\n norm2 (nn.LayerNorm): Layer normalization after the second attention block.\n mlp (MLP): MLP block for transforming query embeddings.\n norm3 (nn.LayerNorm): Layer normalization after the MLP block.\n norm4 (nn.LayerNorm): Layer normalization after the third attention block.\n cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.\n skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer.\n\n Methods:\n forward: Processes input through the attention blocks and MLP.\n\n Examples:\n >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8)\n >>> sparse_input = torch.randn(1, 100, 256)\n >>> dense_input = torch.randn(1, 256, 16, 16)\n >>> sparse_output, dense_output = block(sparse_input, dense_input)\n \"\"\"\n\n def __init__(\n self,\n embedding_dim: int,\n num_heads: int,\n mlp_dim: int = 2048,\n activation: Type[nn.Module] = nn.ReLU,\n attention_downsample_rate: int = 2,\n skip_first_layer_pe: bool = False,\n ) -> None:\n \"\"\"\n Initialize a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.\n\n This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse\n inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention\n from dense to sparse inputs.\n\n Args:\n embedding_dim (int): The channel dimension of the embeddings.\n num_heads (int): The number of heads in the attention layers.\n mlp_dim (int): The hidden dimension of the MLP block.\n activation (Type[nn.Module]): The activation function of the MLP block.\n attention_downsample_rate (int): The downsample rate for attention computations.\n skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.\n\n Examples:\n >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048)\n >>> sparse_inputs = torch.randn(1, 100, 256)\n >>> dense_inputs = torch.randn(1, 256, 32, 32)\n >>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs)\n \"\"\"\n super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe)\n self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation)",
"chunk_type": "class",
"name": "SAM2TwoWayAttentionBlock",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 264,
"end_line": 324,
"start_col": 0,
"end_col": 91,
"parent_name": null,
"docstring": "A two-way attention block for performing self-attention and cross-attention in both directions.\n\nThis block extends the TwoWayAttentionBlock and consists of four main components: self-attention on\nsparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and\ncross-attention from dense to sparse inputs.\n\nAttributes:\n self_attn (Attention): Self-attention layer for queries.\n norm1 (nn.LayerNorm): Layer normalization after the first attention block.\n cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.\n norm2 (nn.LayerNorm): Layer normalization after the second attention block.\n mlp (MLP): MLP block for transforming query embeddings.\n norm3 (nn.LayerNorm): Layer normalization after the MLP block.\n norm4 (nn.LayerNorm): Layer normalization after the third attention block.\n cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.\n skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer.\n\nMethods:\n forward: Processes input through the attention blocks and MLP.\n\nExamples:\n >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8)\n >>> sparse_input = torch.randn(1, 100, 256)\n >>> dense_input = torch.randn(1, 256, 16, 16)\n >>> sparse_output, dense_output = block(sparse_input, dense_input)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"functools.partial",
"typing.Any",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"typing.Union",
"numpy",
"torch",
"torch.nn.functional",
"torch.Tensor",
"torch.nn",
"ultralytics.nn.modules.MLP",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.nn.modules.MLPBlock",
"transformer.Attention",
"transformer.TwoWayAttentionBlock",
"transformer.TwoWayTransformer",
"utils.add_decomposed_rel_pos",
"utils.apply_rotary_enc",
"utils.compute_axial_cis",
"utils.window_partition",
"utils.window_unpartition",
"TwoWayAttentionBlock"
],
"chunk_id": "class_SAM2TwoWayAttentionBlock_a7497d58"
},
{
"content": "class SAM2TwoWayTransformer(TwoWayTransformer):\n \"\"\"\n A Two-Way Transformer module for simultaneous attention to image and query points.\n\n This class extends the TwoWayTransformer, implementing a specialized transformer decoder that attends to an\n input image using queries with supplied positional embeddings. It is particularly useful for tasks like\n object detection, image segmentation, and point cloud processing.\n\n Attributes:\n depth (int): Number of layers in the transformer.\n embedding_dim (int): Channel dimension for input embeddings.\n num_heads (int): Number of heads for multihead attention.\n mlp_dim (int): Internal channel dimension for the MLP block.\n layers (nn.ModuleList): List of SAM2TwoWayAttentionBlock layers comprising the transformer.\n final_attn_token_to_image (Attention): Final attention layer from queries to image.\n norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.\n\n Methods:\n forward: Processes input image embeddings and query embeddings through the transformer.\n\n Examples:\n >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)\n >>> image_embedding = torch.randn(1, 256, 64, 64)\n >>> query_embedding = torch.randn(1, 100, 256)\n >>> output = transformer(image_embedding, query_embedding)\n >>> print(output[0].shape, output[1].shape)\n torch.Size([1, 100, 256]) torch.Size([1, 256, 64, 64])\n \"\"\"\n\n def __init__(\n self,\n depth: int,\n embedding_dim: int,\n num_heads: int,\n mlp_dim: int,\n activation: Type[nn.Module] = nn.ReLU,\n attention_downsample_rate: int = 2,\n ) -> None:\n \"\"\"\n Initialize a SAM2TwoWayTransformer instance.\n\n This transformer decoder attends to an input image using queries with supplied positional embeddings.\n It is designed for tasks like object detection, image segmentation, and point cloud processing.\n\n Args:\n depth (int): Number of layers in the transformer.\n embedding_dim (int): Channel dimension for the input embeddings.\n num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.\n mlp_dim (int): Channel dimension internal to the MLP block.\n activation (Type[nn.Module]): Activation function to use in the MLP block.\n attention_downsample_rate (int): Downsampling rate for attention computations.\n\n Examples:\n >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)\n >>> transformer\n SAM2TwoWayTransformer(\n (layers): ModuleList(\n (0-4): 5 x SAM2TwoWayAttentionBlock(...)\n )\n (final_attn_token_to_image): Attention(...)\n (norm_final_attn): LayerNorm(...)\n )\n \"\"\"\n super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate)\n self.layers = nn.ModuleList()\n for i in range(depth):\n self.layers.append(\n SAM2TwoWayAttentionBlock(\n embedding_dim=embedding_dim,\n num_heads=num_heads,\n mlp_dim=mlp_dim,\n activation=activation,\n attention_downsample_rate=attention_downsample_rate,\n skip_first_layer_pe=(i == 0),\n )\n )",
"chunk_type": "class",
"name": "SAM2TwoWayTransformer",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 327,
"end_line": 402,
"start_col": 0,
"end_col": 13,
"parent_name": null,
"docstring": "A Two-Way Transformer module for simultaneous attention to image and query points.\n\nThis class extends the TwoWayTransformer, implementing a specialized transformer decoder that attends to an\ninput image using queries with supplied positional embeddings. It is particularly useful for tasks like\nobject detection, image segmentation, and point cloud processing.\n\nAttributes:\n depth (int): Number of layers in the transformer.\n embedding_dim (int): Channel dimension for input embeddings.\n num_heads (int): Number of heads for multihead attention.\n mlp_dim (int): Internal channel dimension for the MLP block.\n layers (nn.ModuleList): List of SAM2TwoWayAttentionBlock layers comprising the transformer.\n final_attn_token_to_image (Attention): Final attention layer from queries to image.\n norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.\n\nMethods:\n forward: Processes input image embeddings and query embeddings through the transformer.\n\nExamples:\n >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)\n >>> image_embedding = torch.randn(1, 256, 64, 64)\n >>> query_embedding = torch.randn(1, 100, 256)\n >>> output = transformer(image_embedding, query_embedding)\n >>> print(output[0].shape, output[1].shape)\n torch.Size([1, 100, 256]) torch.Size([1, 256, 64, 64])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"functools.partial",
"typing.Any",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"typing.Union",
"numpy",
"torch",
"torch.nn.functional",
"torch.Tensor",
"torch.nn",
"ultralytics.nn.modules.MLP",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.nn.modules.MLPBlock",
"transformer.Attention",
"transformer.TwoWayAttentionBlock",
"transformer.TwoWayTransformer",
"utils.add_decomposed_rel_pos",
"utils.apply_rotary_enc",
"utils.compute_axial_cis",
"utils.window_partition",
"utils.window_unpartition",
"TwoWayTransformer"
],
"chunk_id": "class_SAM2TwoWayTransformer_14edd00a"
},
{
"content": "class RoPEAttention(Attention):\n \"\"\"\n Implements rotary position encoding for attention mechanisms in transformer architectures.\n\n This class extends the base Attention class by incorporating Rotary Position Encoding (RoPE) to enhance\n the positional awareness of the attention mechanism.\n\n Attributes:\n compute_cis (Callable): Function to compute axial complex numbers for rotary encoding.\n freqs_cis (torch.Tensor): Precomputed frequency tensor for rotary encoding.\n rope_k_repeat (bool): Flag to repeat query RoPE to match key length for cross-attention to memories.\n\n Methods:\n forward: Applies rotary position encoding and computes attention between query, key, and value tensors.\n\n Examples:\n >>> rope_attn = RoPEAttention(embedding_dim=256, num_heads=8, rope_theta=10000.0, feat_sizes=(32, 32))\n >>> q = torch.randn(1, 1024, 256)\n >>> k = torch.randn(1, 1024, 256)\n >>> v = torch.randn(1, 1024, 256)\n >>> output = rope_attn(q, k, v)\n >>> print(output.shape)\n torch.Size([1, 1024, 256])\n \"\"\"\n\n def __init__(\n self,\n *args,\n rope_theta: float = 10000.0,\n rope_k_repeat: bool = False,\n feat_sizes: Tuple[int, int] = (32, 32), # [w, h] for stride 16 feats at 512 resolution\n **kwargs,\n ):\n \"\"\"Initialize RoPEAttention with rotary position encoding for enhanced positional awareness.\"\"\"\n super().__init__(*args, **kwargs)\n\n self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)\n freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])\n self.freqs_cis = freqs_cis\n self.rope_k_repeat = rope_k_repeat # repeat q rope to match k length, needed for cross-attention to memories\n\n def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_k_exclude_rope: int = 0) -> torch.Tensor:\n \"\"\"Apply rotary position encoding and compute attention between query, key, and value tensors.\"\"\"\n q = self.q_proj(q)\n k = self.k_proj(k)\n v = self.v_proj(v)\n\n # Separate into heads\n q = self._separate_heads(q, self.num_heads)\n k = self._separate_heads(k, self.num_heads)\n v = self._separate_heads(v, self.num_heads)\n\n # Apply rotary position encoding\n w = h = math.sqrt(q.shape[-2])\n self.freqs_cis = self.freqs_cis.to(q.device)\n if self.freqs_cis.shape[0] != q.shape[-2]:\n self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)\n if q.shape[-2] != k.shape[-2]:\n assert self.rope_k_repeat\n\n num_k_rope = k.size(-2) - num_k_exclude_rope\n q, k[:, :, :num_k_rope] = apply_rotary_enc(\n q,\n k[:, :, :num_k_rope],\n freqs_cis=self.freqs_cis,\n repeat_freqs_k=self.rope_k_repeat,\n )\n\n # Attention\n _, _, _, c_per_head = q.shape\n attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens\n attn = attn / math.sqrt(c_per_head)\n attn = torch.softmax(attn, dim=-1)\n\n # Get output\n out = attn @ v\n\n out = self._recombine_heads(out)\n out = self.out_proj(out)\n\n return out",
"chunk_type": "class",
"name": "RoPEAttention",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 405,
"end_line": 485,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": "Implements rotary position encoding for attention mechanisms in transformer architectures.\n\nThis class extends the base Attention class by incorporating Rotary Position Encoding (RoPE) to enhance\nthe positional awareness of the attention mechanism.\n\nAttributes:\n compute_cis (Callable): Function to compute axial complex numbers for rotary encoding.\n freqs_cis (torch.Tensor): Precomputed frequency tensor for rotary encoding.\n rope_k_repeat (bool): Flag to repeat query RoPE to match key length for cross-attention to memories.\n\nMethods:\n forward: Applies rotary position encoding and computes attention between query, key, and value tensors.\n\nExamples:\n >>> rope_attn = RoPEAttention(embedding_dim=256, num_heads=8, rope_theta=10000.0, feat_sizes=(32, 32))\n >>> q = torch.randn(1, 1024, 256)\n >>> k = torch.randn(1, 1024, 256)\n >>> v = torch.randn(1, 1024, 256)\n >>> output = rope_attn(q, k, v)\n >>> print(output.shape)\n torch.Size([1, 1024, 256])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"functools.partial",
"typing.Any",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"typing.Union",
"numpy",
"torch",
"torch.nn.functional",
"torch.Tensor",
"torch.nn",
"ultralytics.nn.modules.MLP",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.nn.modules.MLPBlock",
"transformer.Attention",
"transformer.TwoWayAttentionBlock",
"transformer.TwoWayTransformer",
"utils.add_decomposed_rel_pos",
"utils.apply_rotary_enc",
"utils.compute_axial_cis",
"utils.window_partition",
"utils.window_unpartition",
"Attention"
],
"chunk_id": "class_RoPEAttention_c89f6ba9"
},
{
"content": "def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:\n \"\"\"Apply pooling and optional normalization to a tensor, handling spatial dimension permutations.\"\"\"\n if pool is None:\n return x\n # (B, H, W, C) -> (B, C, H, W)\n x = x.permute(0, 3, 1, 2)\n x = pool(x)\n # (B, C, H', W') -> (B, H', W', C)\n x = x.permute(0, 2, 3, 1)\n if norm:\n x = norm(x)\n\n return x",
"chunk_type": "function",
"name": "do_pool",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 488,
"end_line": 500,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": "Apply pooling and optional normalization to a tensor, handling spatial dimension permutations.",
"parameters": [
"x: torch.Tensor",
"pool: nn.Module",
"norm: nn.Module"
],
"return_type": "torch.Tensor",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"copy",
"math",
"functools.partial",
"typing.Any",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"typing.Union",
"numpy",
"torch",
"torch.nn.functional",
"torch.Tensor",
"torch.nn",
"ultralytics.nn.modules.MLP",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.nn.modules.MLPBlock",
"transformer.Attention",
"transformer.TwoWayAttentionBlock",
"transformer.TwoWayTransformer",
"utils.add_decomposed_rel_pos",
"utils.apply_rotary_enc",
"utils.compute_axial_cis",
"utils.window_partition",
"utils.window_unpartition"
],
"chunk_id": "function_do_pool_669ad8d8"
},
{
"content": "class MultiScaleAttention(nn.Module):\n \"\"\"\n Implements multiscale self-attention with optional query pooling for efficient feature extraction.\n\n This class provides a flexible implementation of multiscale attention, allowing for optional\n downsampling of query features through pooling. It's designed to enhance the model's ability to\n capture multiscale information in visual tasks.\n\n Attributes:\n dim (int): Input dimension of the feature map.\n dim_out (int): Output dimension of the attention module.\n num_heads (int): Number of attention heads.\n scale (float): Scaling factor for dot-product attention.\n q_pool (nn.Module | None): Optional pooling module for query features.\n qkv (nn.Linear): Linear projection for query, key, and value.\n proj (nn.Linear): Output projection.\n\n Methods:\n forward: Applies multiscale attention to the input tensor.\n\n Examples:\n >>> import torch\n >>> from torch import nn\n >>> x = torch.randn(1, 64, 64, 256)\n >>> msa = MultiScaleAttention(dim=256, dim_out=256, num_heads=8)\n >>> output = msa(x)\n >>> print(output.shape)\n torch.Size([1, 64, 64, 256])\n \"\"\"\n\n def __init__(\n self,\n dim: int,\n dim_out: int,\n num_heads: int,\n q_pool: nn.Module = None,\n ):\n \"\"\"Initialize multiscale attention with optional query pooling for efficient feature extraction.\"\"\"\n super().__init__()\n\n self.dim = dim\n self.dim_out = dim_out\n\n self.num_heads = num_heads\n head_dim = dim_out // num_heads\n self.scale = head_dim**-0.5\n\n self.q_pool = q_pool\n self.qkv = nn.Linear(dim, dim_out * 3)\n self.proj = nn.Linear(dim_out, dim_out)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply multiscale attention with optional query pooling to extract multiscale features.\"\"\"\n B, H, W, _ = x.shape\n # qkv with shape (B, H * W, 3, nHead, C)\n qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)\n # q, k, v with shape (B, H * W, nheads, C)\n q, k, v = torch.unbind(qkv, 2)\n\n # Q pooling (for downsample at stage changes)\n if self.q_pool:\n q = do_pool(q.reshape(B, H, W, -1), self.q_pool)\n H, W = q.shape[1:3] # downsampled shape\n q = q.reshape(B, H * W, self.num_heads, -1)\n\n # Torch's SDPA expects [B, nheads, H*W, C] so we transpose\n x = F.scaled_dot_product_attention(\n q.transpose(1, 2),\n k.transpose(1, 2),\n v.transpose(1, 2),\n )\n # Transpose back\n x = x.transpose(1, 2)\n x = x.reshape(B, H, W, -1)\n\n x = self.proj(x)\n\n return x",
"chunk_type": "class",
"name": "MultiScaleAttention",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 503,
"end_line": 580,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "Implements multiscale self-attention with optional query pooling for efficient feature extraction.\n\nThis class provides a flexible implementation of multiscale attention, allowing for optional\ndownsampling of query features through pooling. It's designed to enhance the model's ability to\ncapture multiscale information in visual tasks.\n\nAttributes:\n dim (int): Input dimension of the feature map.\n dim_out (int): Output dimension of the attention module.\n num_heads (int): Number of attention heads.\n scale (float): Scaling factor for dot-product attention.\n q_pool (nn.Module | None): Optional pooling module for query features.\n qkv (nn.Linear): Linear projection for query, key, and value.\n proj (nn.Linear): Output projection.\n\nMethods:\n forward: Applies multiscale attention to the input tensor.\n\nExamples:\n >>> import torch\n >>> from torch import nn\n >>> x = torch.randn(1, 64, 64, 256)\n >>> msa = MultiScaleAttention(dim=256, dim_out=256, num_heads=8)\n >>> output = msa(x)\n >>> print(output.shape)\n torch.Size([1, 64, 64, 256])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"functools.partial",
"typing.Any",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"typing.Union",
"numpy",
"torch",
"torch.nn.functional",
"torch.Tensor",
"torch.nn",
"ultralytics.nn.modules.MLP",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.nn.modules.MLPBlock",
"transformer.Attention",
"transformer.TwoWayAttentionBlock",
"transformer.TwoWayTransformer",
"utils.add_decomposed_rel_pos",
"utils.apply_rotary_enc",
"utils.compute_axial_cis",
"utils.window_partition",
"utils.window_unpartition",
"nn.Module"
],
"chunk_id": "class_MultiScaleAttention_4d077d4a"
},
{
"content": "class MultiScaleBlock(nn.Module):\n \"\"\"\n A multiscale attention block with window partitioning and query pooling for efficient vision transformers.\n\n This class implements a multiscale attention mechanism with optional window partitioning and downsampling,\n designed for use in vision transformer architectures.\n\n Attributes:\n dim (int): Input dimension of the block.\n dim_out (int): Output dimension of the block.\n norm1 (nn.Module): First normalization layer.\n window_size (int): Size of the window for partitioning.\n pool (nn.Module | None): Pooling layer for query downsampling.\n q_stride (Tuple[int, int] | None): Stride for query pooling.\n attn (MultiScaleAttention): Multi-scale attention module.\n drop_path (nn.Module): Drop path layer for regularization.\n norm2 (nn.Module): Second normalization layer.\n mlp (MLP): Multi-layer perceptron module.\n proj (nn.Linear | None): Projection layer for dimension mismatch.\n\n Methods:\n forward: Processes input tensor through the multiscale block.\n\n Examples:\n >>> block = MultiScaleBlock(dim=256, dim_out=512, num_heads=8, window_size=7)\n >>> x = torch.randn(1, 56, 56, 256)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 28, 28, 512])\n \"\"\"\n\n def __init__(\n self,\n dim: int,\n dim_out: int,\n num_heads: int,\n mlp_ratio: float = 4.0,\n drop_path: float = 0.0,\n norm_layer: Union[nn.Module, str] = \"LayerNorm\",\n q_stride: Tuple[int, int] = None,\n act_layer: Type[nn.Module] = nn.GELU,\n window_size: int = 0,\n ):\n \"\"\"Initialize a multiscale attention block with window partitioning and optional query pooling.\"\"\"\n super().__init__()\n\n if isinstance(norm_layer, str):\n norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)\n\n self.dim = dim\n self.dim_out = dim_out\n self.norm1 = norm_layer(dim)\n\n self.window_size = window_size\n\n self.pool, self.q_stride = None, q_stride\n if self.q_stride:\n self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False)\n\n self.attn = MultiScaleAttention(\n dim,\n dim_out,\n num_heads=num_heads,\n q_pool=self.pool,\n )\n self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n\n self.norm2 = norm_layer(dim_out)\n self.mlp = MLP(\n dim_out,\n int(dim_out * mlp_ratio),\n dim_out,\n num_layers=2,\n act=act_layer,\n )\n\n if dim != dim_out:\n self.proj = nn.Linear(dim, dim_out)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Process input through multiscale attention and MLP, with optional windowing and downsampling.\"\"\"\n shortcut = x # B, H, W, C\n x = self.norm1(x)\n\n # Skip connection\n if self.dim != self.dim_out:\n shortcut = do_pool(self.proj(x), self.pool)\n\n # Window partition\n window_size = self.window_size\n if window_size > 0:\n H, W = x.shape[1], x.shape[2]\n x, pad_hw = window_partition(x, window_size)\n\n # Window Attention + Q Pooling (if stage change)\n x = self.attn(x)\n if self.q_stride:\n # Shapes have changed due to Q pooling\n window_size = self.window_size // self.q_stride[0]\n H, W = shortcut.shape[1:3]\n\n pad_h = (window_size - H % window_size) % window_size\n pad_w = (window_size - W % window_size) % window_size\n pad_hw = (H + pad_h, W + pad_w)\n\n # Reverse window partition\n if self.window_size > 0:\n x = window_unpartition(x, window_size, pad_hw, (H, W))\n\n x = shortcut + self.drop_path(x)\n # MLP\n x = x + self.drop_path(self.mlp(self.norm2(x)))\n return x",
"chunk_type": "class",
"name": "MultiScaleBlock",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 583,
"end_line": 695,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "A multiscale attention block with window partitioning and query pooling for efficient vision transformers.\n\nThis class implements a multiscale attention mechanism with optional window partitioning and downsampling,\ndesigned for use in vision transformer architectures.\n\nAttributes:\n dim (int): Input dimension of the block.\n dim_out (int): Output dimension of the block.\n norm1 (nn.Module): First normalization layer.\n window_size (int): Size of the window for partitioning.\n pool (nn.Module | None): Pooling layer for query downsampling.\n q_stride (Tuple[int, int] | None): Stride for query pooling.\n attn (MultiScaleAttention): Multi-scale attention module.\n drop_path (nn.Module): Drop path layer for regularization.\n norm2 (nn.Module): Second normalization layer.\n mlp (MLP): Multi-layer perceptron module.\n proj (nn.Linear | None): Projection layer for dimension mismatch.\n\nMethods:\n forward: Processes input tensor through the multiscale block.\n\nExamples:\n >>> block = MultiScaleBlock(dim=256, dim_out=512, num_heads=8, window_size=7)\n >>> x = torch.randn(1, 56, 56, 256)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 28, 28, 512])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"functools.partial",
"typing.Any",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"typing.Union",
"numpy",
"torch",
"torch.nn.functional",
"torch.Tensor",
"torch.nn",
"ultralytics.nn.modules.MLP",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.nn.modules.MLPBlock",
"transformer.Attention",
"transformer.TwoWayAttentionBlock",
"transformer.TwoWayTransformer",
"utils.add_decomposed_rel_pos",
"utils.apply_rotary_enc",
"utils.compute_axial_cis",
"utils.window_partition",
"utils.window_unpartition",
"nn.Module"
],
"chunk_id": "class_MultiScaleBlock_cea50200"
},
{
"content": "class PositionEmbeddingSine(nn.Module):\n \"\"\"\n A module for generating sinusoidal positional embeddings for 2D inputs like images.\n\n This class implements sinusoidal position encoding for 2D spatial positions, which can be used in\n transformer-based models for computer vision tasks.\n\n Attributes:\n num_pos_feats (int): Number of positional features (half of the embedding dimension).\n temperature (int): Temperature parameter for the sinusoidal functions.\n normalize (bool): Whether to normalize the positional embeddings.\n scale (float): Scaling factor for the embeddings when normalize is True.\n cache (dict): Cache for storing precomputed embeddings.\n\n Methods:\n _encode_xy: Encodes 2D positions using sine and cosine functions.\n encode_boxes: Encodes box coordinates and dimensions into positional embeddings.\n encode_points: Encodes 2D point coordinates with sinusoidal positional embeddings.\n forward: Generates sinusoidal position embeddings for 2D inputs.\n\n Examples:\n >>> pos_emb = PositionEmbeddingSine(num_pos_feats=128)\n >>> x = torch.randn(1, 3, 224, 224)\n >>> embeddings = pos_emb(x)\n >>> print(embeddings.shape)\n torch.Size([1, 256, 224, 224])\n \"\"\"\n\n def __init__(\n self,\n num_pos_feats: int,\n temperature: int = 10000,\n normalize: bool = True,\n scale: Optional[float] = None,\n ):\n \"\"\"Initialize sinusoidal position embeddings for 2D image inputs.\"\"\"\n super().__init__()\n assert num_pos_feats % 2 == 0, \"Expecting even model width\"\n self.num_pos_feats = num_pos_feats // 2\n self.temperature = temperature\n self.normalize = normalize\n if scale is not None and not normalize:\n raise ValueError(\"normalize should be True if scale is passed\")\n if scale is None:\n scale = 2 * math.pi\n self.scale = scale\n\n self.cache = {}\n\n def _encode_xy(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Encode 2D positions using sine/cosine functions for transformer positional embeddings.\"\"\"\n assert len(x) == len(y) and x.ndim == y.ndim == 1\n x_embed = x * self.scale\n y_embed = y * self.scale\n\n dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)\n dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)\n\n pos_x = x_embed[:, None] / dim_t\n pos_y = y_embed[:, None] / dim_t\n pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)\n pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)\n return pos_x, pos_y\n\n @torch.no_grad()\n def encode_boxes(self, x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, h: torch.Tensor) -> torch.Tensor:\n \"\"\"Encode box coordinates and dimensions into positional embeddings for detection.\"\"\"\n pos_x, pos_y = self._encode_xy(x, y)\n return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)\n\n encode = encode_boxes # Backwards compatibility\n\n @torch.no_grad()\n def encode_points(self, x: torch.Tensor, y: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:\n \"\"\"Encode 2D points with sinusoidal embeddings and append labels.\"\"\"\n (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape\n assert bx == by and nx == ny and bx == bl and nx == nl\n pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())\n pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)\n return torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)\n\n @torch.no_grad()\n def forward(self, x: torch.Tensor) -> Tensor:\n \"\"\"Generate sinusoidal position embeddings for 2D inputs 
like images.\"\"\"\n cache_key = (x.shape[-2], x.shape[-1])\n if cache_key in self.cache:\n return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)\n y_embed = (\n torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)\n .view(1, -1, 1)\n .repeat(x.shape[0], 1, x.shape[-1])\n )\n x_embed = (\n torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)\n .view(1, 1, -1)\n .repeat(x.shape[0], x.shape[-2], 1)\n )\n\n if self.normalize:\n eps = 1e-6\n y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale\n x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale\n\n dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)\n dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)\n\n pos_x = x_embed[:, :, :, None] / dim_t\n pos_y = y_embed[:, :, :, None] / dim_t\n pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)\n pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)\n pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)\n self.cache[cache_key] = pos[0]\n return pos",
"chunk_type": "class",
"name": "PositionEmbeddingSine",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 698,
"end_line": 810,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": "A module for generating sinusoidal positional embeddings for 2D inputs like images.\n\nThis class implements sinusoidal position encoding for 2D spatial positions, which can be used in\ntransformer-based models for computer vision tasks.\n\nAttributes:\n num_pos_feats (int): Number of positional features (half of the embedding dimension).\n temperature (int): Temperature parameter for the sinusoidal functions.\n normalize (bool): Whether to normalize the positional embeddings.\n scale (float): Scaling factor for the embeddings when normalize is True.\n cache (dict): Cache for storing precomputed embeddings.\n\nMethods:\n _encode_xy: Encodes 2D positions using sine and cosine functions.\n encode_boxes: Encodes box coordinates and dimensions into positional embeddings.\n encode_points: Encodes 2D point coordinates with sinusoidal positional embeddings.\n forward: Generates sinusoidal position embeddings for 2D inputs.\n\nExamples:\n >>> pos_emb = PositionEmbeddingSine(num_pos_feats=128)\n >>> x = torch.randn(1, 3, 224, 224)\n >>> embeddings = pos_emb(x)\n >>> print(embeddings.shape)\n torch.Size([1, 256, 224, 224])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"functools.partial",
"typing.Any",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"typing.Union",
"numpy",
"torch",
"torch.nn.functional",
"torch.Tensor",
"torch.nn",
"ultralytics.nn.modules.MLP",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.nn.modules.MLPBlock",
"transformer.Attention",
"transformer.TwoWayAttentionBlock",
"transformer.TwoWayTransformer",
"utils.add_decomposed_rel_pos",
"utils.apply_rotary_enc",
"utils.compute_axial_cis",
"utils.window_partition",
"utils.window_unpartition",
"nn.Module"
],
"chunk_id": "class_PositionEmbeddingSine_cb0e1827"
},
{
"content": "class PositionEmbeddingRandom(nn.Module):\n \"\"\"\n Positional encoding using random spatial frequencies.\n\n This class generates positional embeddings for input coordinates using random spatial frequencies. It is\n particularly useful for transformer-based models that require position information.\n\n Attributes:\n positional_encoding_gaussian_matrix (torch.Tensor): A buffer containing random values for encoding.\n\n Methods:\n _pe_encoding: Positionally encodes points that are normalized to [0,1].\n forward: Generates positional encoding for a grid of the specified size.\n forward_with_coords: Positionally encodes points that are not normalized to [0,1].\n\n Examples:\n >>> pe = PositionEmbeddingRandom(num_pos_feats=64)\n >>> size = (32, 32)\n >>> encoding = pe(size)\n >>> print(encoding.shape)\n torch.Size([128, 32, 32])\n \"\"\"\n\n def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:\n \"\"\"Initialize random spatial frequency position embedding for transformers.\"\"\"\n super().__init__()\n if scale is None or scale <= 0.0:\n scale = 1.0\n self.register_buffer(\"positional_encoding_gaussian_matrix\", scale * torch.randn((2, num_pos_feats)))\n\n # Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'\n torch.use_deterministic_algorithms(False)\n torch.backends.cudnn.deterministic = False\n\n def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:\n \"\"\"Encode normalized [0,1] coordinates using random spatial frequencies.\"\"\"\n # Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape\n coords = 2 * coords - 1\n coords = coords @ self.positional_encoding_gaussian_matrix\n coords = 2 * np.pi * coords\n # Outputs d_1 x ... x d_n x C shape\n return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)\n\n def forward(self, size: Tuple[int, int]) -> torch.Tensor:\n \"\"\"Generate positional encoding for a grid using random spatial frequencies.\"\"\"\n h, w = size\n device: Any = self.positional_encoding_gaussian_matrix.device\n grid = torch.ones((h, w), device=device, dtype=torch.float32)\n y_embed = grid.cumsum(dim=0) - 0.5\n x_embed = grid.cumsum(dim=1) - 0.5\n y_embed = y_embed / h\n x_embed = x_embed / w\n\n pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))\n return pe.permute(2, 0, 1) # C x H x W\n\n def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:\n \"\"\"Positionally encode input coordinates, normalizing them to [0,1] based on the given image size.\"\"\"\n coords = coords_input.clone()\n coords[:, :, 0] = coords[:, :, 0] / image_size[1]\n coords[:, :, 1] = coords[:, :, 1] / image_size[0]\n return self._pe_encoding(coords.to(torch.float)) # B x N x C",
"chunk_type": "class",
"name": "PositionEmbeddingRandom",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 813,
"end_line": 874,
"start_col": 0,
"end_col": 56,
"parent_name": null,
"docstring": "Positional encoding using random spatial frequencies.\n\nThis class generates positional embeddings for input coordinates using random spatial frequencies. It is\nparticularly useful for transformer-based models that require position information.\n\nAttributes:\n positional_encoding_gaussian_matrix (torch.Tensor): A buffer containing random values for encoding.\n\nMethods:\n _pe_encoding: Positionally encodes points that are normalized to [0,1].\n forward: Generates positional encoding for a grid of the specified size.\n forward_with_coords: Positionally encodes points that are not normalized to [0,1].\n\nExamples:\n >>> pe = PositionEmbeddingRandom(num_pos_feats=64)\n >>> size = (32, 32)\n >>> encoding = pe(size)\n >>> print(encoding.shape)\n torch.Size([128, 32, 32])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"functools.partial",
"typing.Any",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"typing.Union",
"numpy",
"torch",
"torch.nn.functional",
"torch.Tensor",
"torch.nn",
"ultralytics.nn.modules.MLP",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.nn.modules.MLPBlock",
"transformer.Attention",
"transformer.TwoWayAttentionBlock",
"transformer.TwoWayTransformer",
"utils.add_decomposed_rel_pos",
"utils.apply_rotary_enc",
"utils.compute_axial_cis",
"utils.window_partition",
"utils.window_unpartition",
"nn.Module"
],
"chunk_id": "class_PositionEmbeddingRandom_b8cb3304"
},
{
"content": "class Block(nn.Module):\n \"\"\"\n Transformer block with support for window attention and residual propagation.\n\n This class implements a transformer block that can use either global or windowed self-attention,\n followed by a feed-forward network. It supports relative positional embeddings and is designed\n for use in vision transformer architectures.\n\n Attributes:\n norm1 (nn.Module): First normalization layer.\n attn (REAttention): Self-attention layer with optional relative positional encoding.\n norm2 (nn.Module): Second normalization layer.\n mlp (MLPBlock): Multi-layer perceptron block.\n window_size (int): Size of attention window. If 0, global attention is used.\n\n Methods:\n forward: Processes input through the transformer block.\n\n Examples:\n >>> import torch\n >>> block = Block(dim=256, num_heads=8, window_size=7)\n >>> x = torch.randn(1, 56, 56, 256)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 56, 56, 256])\n \"\"\"\n\n def __init__(\n self,\n dim: int,\n num_heads: int,\n mlp_ratio: float = 4.0,\n qkv_bias: bool = True,\n norm_layer: Type[nn.Module] = nn.LayerNorm,\n act_layer: Type[nn.Module] = nn.GELU,\n use_rel_pos: bool = False,\n rel_pos_zero_init: bool = True,\n window_size: int = 0,\n input_size: Optional[Tuple[int, int]] = None,\n ) -> None:\n \"\"\"\n Initialize a transformer block with optional window attention and relative positional embeddings.\n\n This constructor sets up a transformer block that can use either global or windowed self-attention,\n followed by a feed-forward network. It supports relative positional embeddings and is designed\n for use in vision transformer architectures.\n\n Args:\n dim (int): Number of input channels.\n num_heads (int): Number of attention heads in the self-attention layer.\n mlp_ratio (float): Ratio of mlp hidden dimension to embedding dimension.\n qkv_bias (bool): If True, adds a learnable bias to query, key, value projections.\n norm_layer (Type[nn.Module]): Type of normalization layer to use.\n act_layer (Type[nn.Module]): Type of activation function to use in the MLP block.\n use_rel_pos (bool): If True, uses relative positional embeddings in attention.\n rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.\n window_size (int): Size of attention window. 
If 0, uses global attention.\n input_size (Tuple[int, int] | None): Input resolution for calculating relative positional parameter size.\n\n Examples:\n >>> block = Block(dim=256, num_heads=8, window_size=7)\n >>> x = torch.randn(1, 56, 56, 256)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 56, 56, 256])\n \"\"\"\n super().__init__()\n self.norm1 = norm_layer(dim)\n self.attn = REAttention(\n dim,\n num_heads=num_heads,\n qkv_bias=qkv_bias,\n use_rel_pos=use_rel_pos,\n rel_pos_zero_init=rel_pos_zero_init,\n input_size=input_size if window_size == 0 else (window_size, window_size),\n )\n\n self.norm2 = norm_layer(dim)\n self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)\n\n self.window_size = window_size\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Process input through transformer block with optional windowed self-attention and residual connection.\"\"\"\n shortcut = x\n x = self.norm1(x)\n # Window partition\n if self.window_size > 0:\n H, W = x.shape[1], x.shape[2]\n x, pad_hw = window_partition(x, self.window_size)\n\n x = self.attn(x)\n # Reverse window partition\n if self.window_size > 0:\n x = window_unpartition(x, self.window_size, pad_hw, (H, W))\n\n x = shortcut + x\n return x + self.mlp(self.norm2(x))",
"chunk_type": "class",
"name": "Block",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 877,
"end_line": 974,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": "Transformer block with support for window attention and residual propagation.\n\nThis class implements a transformer block that can use either global or windowed self-attention,\nfollowed by a feed-forward network. It supports relative positional embeddings and is designed\nfor use in vision transformer architectures.\n\nAttributes:\n norm1 (nn.Module): First normalization layer.\n attn (REAttention): Self-attention layer with optional relative positional encoding.\n norm2 (nn.Module): Second normalization layer.\n mlp (MLPBlock): Multi-layer perceptron block.\n window_size (int): Size of attention window. If 0, global attention is used.\n\nMethods:\n forward: Processes input through the transformer block.\n\nExamples:\n >>> import torch\n >>> block = Block(dim=256, num_heads=8, window_size=7)\n >>> x = torch.randn(1, 56, 56, 256)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 56, 56, 256])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"functools.partial",
"typing.Any",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"typing.Union",
"numpy",
"torch",
"torch.nn.functional",
"torch.Tensor",
"torch.nn",
"ultralytics.nn.modules.MLP",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.nn.modules.MLPBlock",
"transformer.Attention",
"transformer.TwoWayAttentionBlock",
"transformer.TwoWayTransformer",
"utils.add_decomposed_rel_pos",
"utils.apply_rotary_enc",
"utils.compute_axial_cis",
"utils.window_partition",
"utils.window_unpartition",
"nn.Module"
],
"chunk_id": "class_Block_3dc742de"
},
{
"content": "class REAttention(nn.Module):\n \"\"\"\n Relative Position Attention module for efficient self-attention in transformer architectures.\n\n This class implements a multi-head attention mechanism with relative positional embeddings, designed\n for use in vision transformer models. It supports optional query pooling and window partitioning\n for efficient processing of large inputs.\n\n Attributes:\n num_heads (int): Number of attention heads.\n scale (float): Scaling factor for attention computation.\n qkv (nn.Linear): Linear projection for query, key, and value.\n proj (nn.Linear): Output projection layer.\n use_rel_pos (bool): Whether to use relative positional embeddings.\n rel_pos_h (nn.Parameter): Relative positional embeddings for height dimension.\n rel_pos_w (nn.Parameter): Relative positional embeddings for width dimension.\n\n Methods:\n forward: Applies multi-head attention with optional relative positional encoding to input tensor.\n\n Examples:\n >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))\n >>> x = torch.randn(1, 32, 32, 256)\n >>> output = attention(x)\n >>> print(output.shape)\n torch.Size([1, 32, 32, 256])\n \"\"\"\n\n def __init__(\n self,\n dim: int,\n num_heads: int = 8,\n qkv_bias: bool = True,\n use_rel_pos: bool = False,\n rel_pos_zero_init: bool = True,\n input_size: Optional[Tuple[int, int]] = None,\n ) -> None:\n \"\"\"\n Initialize a Relative Position Attention module for transformer-based architectures.\n\n This module implements multi-head attention with optional relative positional encodings, designed\n specifically for vision tasks in transformer models.\n\n Args:\n dim (int): Number of input channels.\n num_heads (int): Number of attention heads.\n qkv_bias (bool): If True, adds a learnable bias to query, key, value projections.\n use_rel_pos (bool): If True, uses relative positional encodings.\n rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.\n input_size (Tuple[int, int] | None): Input resolution for calculating relative positional parameter size.\n Required if use_rel_pos is True.\n\n Examples:\n >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))\n >>> x = torch.randn(1, 32, 32, 256)\n >>> output = attention(x)\n >>> print(output.shape)\n torch.Size([1, 32, 32, 256])\n \"\"\"\n super().__init__()\n self.num_heads = num_heads\n head_dim = dim // num_heads\n self.scale = head_dim**-0.5\n\n self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n self.proj = nn.Linear(dim, dim)\n\n self.use_rel_pos = use_rel_pos\n if self.use_rel_pos:\n assert input_size is not None, \"Input size must be provided if using relative positional encoding.\"\n # Initialize relative positional embeddings\n self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))\n self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply multi-head attention with optional relative positional encoding to input tensor.\"\"\"\n B, H, W, _ = x.shape\n # qkv with shape (3, B, nHead, H * W, C)\n qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n # q, k, v with shape (B * nHead, H * W, C)\n q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)\n\n attn = (q * self.scale) @ k.transpose(-2, -1)\n\n if self.use_rel_pos:\n attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))\n\n attn = attn.softmax(dim=-1)\n x = (attn @ v).view(B, 
self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)\n return self.proj(x)",
"chunk_type": "class",
"name": "REAttention",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 977,
"end_line": 1066,
"start_col": 0,
"end_col": 27,
"parent_name": null,
"docstring": "Relative Position Attention module for efficient self-attention in transformer architectures.\n\nThis class implements a multi-head attention mechanism with relative positional embeddings, designed\nfor use in vision transformer models. It supports optional query pooling and window partitioning\nfor efficient processing of large inputs.\n\nAttributes:\n num_heads (int): Number of attention heads.\n scale (float): Scaling factor for attention computation.\n qkv (nn.Linear): Linear projection for query, key, and value.\n proj (nn.Linear): Output projection layer.\n use_rel_pos (bool): Whether to use relative positional embeddings.\n rel_pos_h (nn.Parameter): Relative positional embeddings for height dimension.\n rel_pos_w (nn.Parameter): Relative positional embeddings for width dimension.\n\nMethods:\n forward: Applies multi-head attention with optional relative positional encoding to input tensor.\n\nExamples:\n >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))\n >>> x = torch.randn(1, 32, 32, 256)\n >>> output = attention(x)\n >>> print(output.shape)\n torch.Size([1, 32, 32, 256])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"functools.partial",
"typing.Any",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"typing.Union",
"numpy",
"torch",
"torch.nn.functional",
"torch.Tensor",
"torch.nn",
"ultralytics.nn.modules.MLP",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.nn.modules.MLPBlock",
"transformer.Attention",
"transformer.TwoWayAttentionBlock",
"transformer.TwoWayTransformer",
"utils.add_decomposed_rel_pos",
"utils.apply_rotary_enc",
"utils.compute_axial_cis",
"utils.window_partition",
"utils.window_unpartition",
"nn.Module"
],
"chunk_id": "class_REAttention_f78613db"
},
{
"content": "class PatchEmbed(nn.Module):\n \"\"\"\n Image to Patch Embedding module for vision transformer architectures.\n\n This module converts an input image into a sequence of patch embeddings using a convolutional layer.\n It is commonly used as the first layer in vision transformer architectures to transform image data\n into a suitable format for subsequent transformer blocks.\n\n Attributes:\n proj (nn.Conv2d): Convolutional layer for projecting image patches to embeddings.\n\n Methods:\n forward: Applies patch embedding to the input tensor.\n\n Examples:\n >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)\n >>> x = torch.randn(1, 3, 224, 224)\n >>> output = patch_embed(x)\n >>> print(output.shape)\n torch.Size([1, 768, 14, 14])\n \"\"\"\n\n def __init__(\n self,\n kernel_size: Tuple[int, int] = (16, 16),\n stride: Tuple[int, int] = (16, 16),\n padding: Tuple[int, int] = (0, 0),\n in_chans: int = 3,\n embed_dim: int = 768,\n ) -> None:\n \"\"\"\n Initialize the PatchEmbed module for converting image patches to embeddings.\n\n This module is typically used as the first layer in vision transformer architectures to transform\n image data into a suitable format for subsequent transformer blocks.\n\n Args:\n kernel_size (Tuple[int, int]): Size of the convolutional kernel for patch extraction.\n stride (Tuple[int, int]): Stride of the convolutional operation.\n padding (Tuple[int, int]): Padding applied to the input before convolution.\n in_chans (int): Number of input image channels.\n embed_dim (int): Dimensionality of the output patch embeddings.\n\n Examples:\n >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)\n >>> x = torch.randn(1, 3, 224, 224)\n >>> output = patch_embed(x)\n >>> print(output.shape)\n torch.Size([1, 768, 14, 14])\n \"\"\"\n super().__init__()\n\n self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Compute patch embedding by applying convolution and transposing resulting tensor.\"\"\"\n return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C",
"chunk_type": "class",
"name": "PatchEmbed",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\blocks.py",
"start_line": 1069,
"end_line": 1125,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": "Image to Patch Embedding module for vision transformer architectures.\n\nThis module converts an input image into a sequence of patch embeddings using a convolutional layer.\nIt is commonly used as the first layer in vision transformer architectures to transform image data\ninto a suitable format for subsequent transformer blocks.\n\nAttributes:\n proj (nn.Conv2d): Convolutional layer for projecting image patches to embeddings.\n\nMethods:\n forward: Applies patch embedding to the input tensor.\n\nExamples:\n >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)\n >>> x = torch.randn(1, 3, 224, 224)\n >>> output = patch_embed(x)\n >>> print(output.shape)\n torch.Size([1, 768, 14, 14])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"functools.partial",
"typing.Any",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"typing.Union",
"numpy",
"torch",
"torch.nn.functional",
"torch.Tensor",
"torch.nn",
"ultralytics.nn.modules.MLP",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.nn.modules.MLPBlock",
"transformer.Attention",
"transformer.TwoWayAttentionBlock",
"transformer.TwoWayTransformer",
"utils.add_decomposed_rel_pos",
"utils.apply_rotary_enc",
"utils.compute_axial_cis",
"utils.window_partition",
"utils.window_unpartition",
"nn.Module"
],
"chunk_id": "class_PatchEmbed_102cc953"
},
{
"content": "from typing import List, Optional, Tuple, Type",
"chunk_type": "import",
"name": "List, Optional, Tuple, Type",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\decoders.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_List, Optional, Tuple, Type_d508b8c2"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\decoders.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_5dca3f97"
},
{
"content": "from torch import nn",
"chunk_type": "import",
"name": "nn",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\decoders.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_nn_7328c221"
},
{
"content": "from ultralytics.nn.modules import MLP, LayerNorm2d",
"chunk_type": "import",
"name": "MLP, LayerNorm2d",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\decoders.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 51,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_MLP, LayerNorm2d_c51d52d2"
},
{
"content": "class MaskDecoder(nn.Module):\n \"\"\"\n Decoder module for generating masks and their associated quality scores using a transformer architecture.\n\n This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and\n generate mask predictions along with their quality scores.\n\n Attributes:\n transformer_dim (int): Channel dimension for the transformer module.\n transformer (nn.Module): Transformer module used for mask prediction.\n num_multimask_outputs (int): Number of masks to predict for disambiguating masks.\n iou_token (nn.Embedding): Embedding for the IoU token.\n num_mask_tokens (int): Number of mask tokens.\n mask_tokens (nn.Embedding): Embedding for the mask tokens.\n output_upscaling (nn.Sequential): Neural network sequence for upscaling the output.\n output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks.\n iou_prediction_head (nn.Module): MLP for predicting mask quality.\n\n Methods:\n forward: Predict masks given image and prompt embeddings.\n predict_masks: Internal method for mask prediction.\n\n Examples:\n >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)\n >>> masks, iou_pred = decoder(\n ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output=True\n ... )\n >>> print(f\"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}\")\n \"\"\"\n\n def __init__(\n self,\n transformer_dim: int,\n transformer: nn.Module,\n num_multimask_outputs: int = 3,\n activation: Type[nn.Module] = nn.GELU,\n iou_head_depth: int = 3,\n iou_head_hidden_dim: int = 256,\n ) -> None:\n \"\"\"\n Initialize the MaskDecoder module for generating masks and their associated quality scores.\n\n Args:\n transformer_dim (int): Channel dimension for the transformer module.\n transformer (nn.Module): Transformer module used for mask prediction.\n num_multimask_outputs (int): Number of masks to predict for disambiguating masks.\n activation (Type[nn.Module]): Type of activation to use when upscaling masks.\n iou_head_depth (int): Depth of the MLP used to predict mask quality.\n iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.\n\n Examples:\n >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)\n >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer)\n >>> print(decoder)\n \"\"\"\n super().__init__()\n self.transformer_dim = transformer_dim\n self.transformer = transformer\n\n self.num_multimask_outputs = num_multimask_outputs\n\n self.iou_token = nn.Embedding(1, transformer_dim)\n self.num_mask_tokens = num_multimask_outputs + 1\n self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)\n\n self.output_upscaling = nn.Sequential(\n nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),\n LayerNorm2d(transformer_dim // 4),\n activation(),\n nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),\n activation(),\n )\n self.output_hypernetworks_mlps = nn.ModuleList(\n [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]\n )\n\n self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)\n\n def forward(\n self,\n image_embeddings: torch.Tensor,\n image_pe: torch.Tensor,\n sparse_prompt_embeddings: torch.Tensor,\n dense_prompt_embeddings: torch.Tensor,\n 
multimask_output: bool,\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"\n Predict masks given image and prompt embeddings.\n\n Args:\n image_embeddings (torch.Tensor): Embeddings from the image encoder.\n image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings.\n sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes.\n dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs.\n multimask_output (bool): Whether to return multiple masks or a single mask.\n\n Returns:\n masks (torch.Tensor): Batched predicted masks.\n iou_pred (torch.Tensor): Batched predictions of mask quality.\n\n Examples:\n >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)\n >>> image_emb = torch.rand(1, 256, 64, 64)\n >>> image_pe = torch.rand(1, 256, 64, 64)\n >>> sparse_emb = torch.rand(1, 2, 256)\n >>> dense_emb = torch.rand(1, 256, 64, 64)\n >>> masks, iou_pred = decoder(image_emb, image_pe, sparse_emb, dense_emb, multimask_output=True)\n >>> print(f\"Masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}\")\n \"\"\"\n masks, iou_pred = self.predict_masks(\n image_embeddings=image_embeddings,\n image_pe=image_pe,\n sparse_prompt_embeddings=sparse_prompt_embeddings,\n dense_prompt_embeddings=dense_prompt_embeddings,\n )\n\n # Select the correct mask or masks for output\n mask_slice = slice(1, None) if multimask_output else slice(0, 1)\n masks = masks[:, mask_slice, :, :]\n iou_pred = iou_pred[:, mask_slice]\n\n return masks, iou_pred\n\n def predict_masks(\n self,\n image_embeddings: torch.Tensor,\n image_pe: torch.Tensor,\n sparse_prompt_embeddings: torch.Tensor,\n dense_prompt_embeddings: torch.Tensor,\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"Predict masks and quality scores using image and prompt embeddings via transformer architecture.\"\"\"\n # Concatenate output tokens\n output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)\n output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)\n tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)\n\n # Expand per-image data in batch direction to be per-mask\n src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)\n src = src + dense_prompt_embeddings\n pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)\n b, c, h, w = src.shape\n\n # Run the transformer\n hs, src = self.transformer(src, pos_src, tokens)\n iou_token_out = hs[:, 0, :]\n mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]\n\n # Upscale mask embeddings and predict masks using the mask tokens\n src = src.transpose(1, 2).view(b, c, h, w)\n upscaled_embedding = self.output_upscaling(src)\n hyper_in_list: List[torch.Tensor] = [\n self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)\n ]\n hyper_in = torch.stack(hyper_in_list, dim=1)\n b, c, h, w = upscaled_embedding.shape\n masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)\n\n # Generate mask quality predictions\n iou_pred = self.iou_prediction_head(iou_token_out)\n\n return masks, iou_pred",
"chunk_type": "class",
"name": "MaskDecoder",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\decoders.py",
"start_line": 11,
"end_line": 171,
"start_col": 0,
"end_col": 30,
"parent_name": null,
"docstring": "Decoder module for generating masks and their associated quality scores using a transformer architecture.\n\nThis class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and\ngenerate mask predictions along with their quality scores.\n\nAttributes:\n transformer_dim (int): Channel dimension for the transformer module.\n transformer (nn.Module): Transformer module used for mask prediction.\n num_multimask_outputs (int): Number of masks to predict for disambiguating masks.\n iou_token (nn.Embedding): Embedding for the IoU token.\n num_mask_tokens (int): Number of mask tokens.\n mask_tokens (nn.Embedding): Embedding for the mask tokens.\n output_upscaling (nn.Sequential): Neural network sequence for upscaling the output.\n output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks.\n iou_prediction_head (nn.Module): MLP for predicting mask quality.\n\nMethods:\n forward: Predict masks given image and prompt embeddings.\n predict_masks: Internal method for mask prediction.\n\nExamples:\n >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)\n >>> masks, iou_pred = decoder(\n ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output=True\n ... )\n >>> print(f\"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}\")",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"torch",
"torch.nn",
"ultralytics.nn.modules.MLP",
"ultralytics.nn.modules.LayerNorm2d",
"nn.Module"
],
"chunk_id": "class_MaskDecoder_43de16dc"
},
{
"content": "class SAM2MaskDecoder(nn.Module):\n \"\"\"\n Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.\n\n This class extends the functionality of the MaskDecoder, incorporating additional features such as\n high-resolution feature processing, dynamic multimask output, and object score prediction.\n\n Attributes:\n transformer_dim (int): Channel dimension of the transformer.\n transformer (nn.Module): Transformer used to predict masks.\n num_multimask_outputs (int): Number of masks to predict when disambiguating masks.\n iou_token (nn.Embedding): Embedding for IOU token.\n num_mask_tokens (int): Total number of mask tokens.\n mask_tokens (nn.Embedding): Embedding for mask tokens.\n pred_obj_scores (bool): Whether to predict object scores.\n obj_score_token (nn.Embedding): Embedding for object score token.\n use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.\n output_upscaling (nn.Sequential): Upscaling layers for output.\n use_high_res_features (bool): Whether to use high-resolution features.\n conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0).\n conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1).\n output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks.\n iou_prediction_head (MLP): MLP for IOU prediction.\n pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction.\n dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.\n dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.\n dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.\n\n Methods:\n forward: Predict masks given image and prompt embeddings.\n predict_masks: Predict instance segmentation masks from image and prompt embeddings.\n _get_stability_scores: Compute mask stability scores based on IoU between thresholds.\n _dynamic_multimask_via_stability: Dynamically select the most stable mask output.\n\n Examples:\n >>> image_embeddings = torch.rand(1, 256, 64, 64)\n >>> image_pe = torch.rand(1, 256, 64, 64)\n >>> sparse_prompt_embeddings = torch.rand(1, 2, 256)\n >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)\n >>> decoder = SAM2MaskDecoder(256, transformer)\n >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(\n ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False\n ... 
)\n \"\"\"\n\n def __init__(\n self,\n transformer_dim: int,\n transformer: nn.Module,\n num_multimask_outputs: int = 3,\n activation: Type[nn.Module] = nn.GELU,\n iou_head_depth: int = 3,\n iou_head_hidden_dim: int = 256,\n use_high_res_features: bool = False,\n iou_prediction_use_sigmoid=False,\n dynamic_multimask_via_stability=False,\n dynamic_multimask_stability_delta=0.05,\n dynamic_multimask_stability_thresh=0.98,\n pred_obj_scores: bool = False,\n pred_obj_scores_mlp: bool = False,\n use_multimask_token_for_obj_ptr: bool = False,\n ) -> None:\n \"\"\"\n Initialize the SAM2MaskDecoder module for predicting instance segmentation masks.\n\n This decoder extends the functionality of MaskDecoder, incorporating additional features such as\n high-resolution feature processing, dynamic multimask output, and object score prediction.\n\n Args:\n transformer_dim (int): Channel dimension of the transformer.\n transformer (nn.Module): Transformer used to predict masks.\n num_multimask_outputs (int): Number of masks to predict when disambiguating masks.\n activation (Type[nn.Module]): Type of activation to use when upscaling masks.\n iou_head_depth (int): Depth of the MLP used to predict mask quality.\n iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.\n use_high_res_features (bool): Whether to use high-resolution features.\n iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction.\n dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.\n dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.\n dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.\n pred_obj_scores (bool): Whether to predict object scores.\n pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction.\n use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.\n\n Examples:\n >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)\n >>> decoder = SAM2MaskDecoder(transformer_dim=256, transformer=transformer)\n >>> print(decoder)\n \"\"\"\n super().__init__()\n self.transformer_dim = transformer_dim\n self.transformer = transformer\n\n self.num_multimask_outputs = num_multimask_outputs\n\n self.iou_token = nn.Embedding(1, transformer_dim)\n self.num_mask_tokens = num_multimask_outputs + 1\n self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)\n\n self.pred_obj_scores = pred_obj_scores\n if self.pred_obj_scores:\n self.obj_score_token = nn.Embedding(1, transformer_dim)\n self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr\n\n self.output_upscaling = nn.Sequential(\n nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),\n LayerNorm2d(transformer_dim // 4),\n activation(),\n nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),\n activation(),\n )\n self.use_high_res_features = use_high_res_features\n if use_high_res_features:\n self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1)\n self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1)\n\n self.output_hypernetworks_mlps = nn.ModuleList(\n [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]\n )\n\n self.iou_prediction_head = MLP(\n transformer_dim,\n iou_head_hidden_dim,\n self.num_mask_tokens,\n iou_head_depth,\n 
sigmoid=iou_prediction_use_sigmoid,\n )\n if self.pred_obj_scores:\n self.pred_obj_score_head = nn.Linear(transformer_dim, 1)\n if pred_obj_scores_mlp:\n self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)\n\n # When outputting a single mask, optionally we can dynamically fall back to the best\n # multimask output token if the single mask output token gives low stability scores.\n self.dynamic_multimask_via_stability = dynamic_multimask_via_stability\n self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta\n self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh\n\n def forward(\n self,\n image_embeddings: torch.Tensor,\n image_pe: torch.Tensor,\n sparse_prompt_embeddings: torch.Tensor,\n dense_prompt_embeddings: torch.Tensor,\n multimask_output: bool,\n repeat_image: bool,\n high_res_features: Optional[List[torch.Tensor]] = None,\n ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n \"\"\"\n Predict masks given image and prompt embeddings.\n\n Args:\n image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W).\n image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings (B, C, H, W).\n sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes with shape (B, N, C).\n dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs with shape (B, C, H, W).\n multimask_output (bool): Whether to return multiple masks or a single mask.\n repeat_image (bool): Flag to repeat the image embeddings.\n high_res_features (List[torch.Tensor] | None, optional): Optional high-resolution features.\n\n Returns:\n masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W).\n iou_pred (torch.Tensor): Batched predictions of mask quality with shape (B, N).\n sam_tokens_out (torch.Tensor): Batched SAM token for mask output with shape (B, N, C).\n object_score_logits (torch.Tensor): Batched object score logits with shape (B, 1).\n\n Examples:\n >>> image_embeddings = torch.rand(1, 256, 64, 64)\n >>> image_pe = torch.rand(1, 256, 64, 64)\n >>> sparse_prompt_embeddings = torch.rand(1, 2, 256)\n >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)\n >>> decoder = SAM2MaskDecoder(256, transformer)\n >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(\n ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False\n ... )\n \"\"\"\n masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(\n image_embeddings=image_embeddings,\n image_pe=image_pe,\n sparse_prompt_embeddings=sparse_prompt_embeddings,\n dense_prompt_embeddings=dense_prompt_embeddings,\n repeat_image=repeat_image,\n high_res_features=high_res_features,\n )\n\n # Select the correct mask or masks for output\n if multimask_output:\n masks = masks[:, 1:, :, :]\n iou_pred = iou_pred[:, 1:]\n elif self.dynamic_multimask_via_stability and not self.training:\n masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)\n else:\n masks = masks[:, 0:1, :, :]\n iou_pred = iou_pred[:, 0:1]\n\n if multimask_output and self.use_multimask_token_for_obj_ptr:\n sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape\n else:\n # Take the mask output token. Here we *always* use the token for single mask output.\n # At test time, even if we track after 1-click (and using multimask_output=True),\n # we still take the single mask token here. 
The rationale is that we always track\n # after multiple clicks during training, so the past tokens seen during training\n # are always the single mask token (and we'll let it be the object-memory token).\n sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape\n\n return masks, iou_pred, sam_tokens_out, object_score_logits\n\n def predict_masks(\n self,\n image_embeddings: torch.Tensor,\n image_pe: torch.Tensor,\n sparse_prompt_embeddings: torch.Tensor,\n dense_prompt_embeddings: torch.Tensor,\n repeat_image: bool,\n high_res_features: Optional[List[torch.Tensor]] = None,\n ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n \"\"\"Predict instance segmentation masks from image and prompt embeddings using a transformer.\"\"\"\n # Concatenate output tokens\n s = 0\n if self.pred_obj_scores:\n output_tokens = torch.cat(\n [\n self.obj_score_token.weight,\n self.iou_token.weight,\n self.mask_tokens.weight,\n ],\n dim=0,\n )\n s = 1\n else:\n output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)\n output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)\n tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)\n\n # Expand per-image data in batch direction to be per-mask\n if repeat_image:\n src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)\n else:\n assert image_embeddings.shape[0] == tokens.shape[0]\n src = image_embeddings\n src = src + dense_prompt_embeddings\n assert image_pe.size(0) == 1, \"image_pe should have size 1 in batch dim (from `get_dense_pe()`)\"\n pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)\n b, c, h, w = src.shape\n\n # Run the transformer\n hs, src = self.transformer(src, pos_src, tokens)\n iou_token_out = hs[:, s, :]\n mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]\n\n # Upscale mask embeddings and predict masks using the mask tokens\n src = src.transpose(1, 2).view(b, c, h, w)\n if not self.use_high_res_features:\n upscaled_embedding = self.output_upscaling(src)\n else:\n dc1, ln1, act1, dc2, act2 = self.output_upscaling\n feat_s0, feat_s1 = high_res_features\n upscaled_embedding = act1(ln1(dc1(src) + feat_s1))\n upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)\n\n hyper_in_list: List[torch.Tensor] = [\n self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)\n ]\n hyper_in = torch.stack(hyper_in_list, dim=1)\n b, c, h, w = upscaled_embedding.shape\n masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)\n\n # Generate mask quality predictions\n iou_pred = self.iou_prediction_head(iou_token_out)\n if self.pred_obj_scores:\n assert s == 1\n object_score_logits = self.pred_obj_score_head(hs[:, 0, :])\n else:\n # Obj scores logits - default to 10.0, i.e. 
assuming the object is present, sigmoid(10)=1\n object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)\n\n return masks, iou_pred, mask_tokens_out, object_score_logits\n\n def _get_stability_scores(self, mask_logits):\n \"\"\"Compute mask stability scores based on IoU between upper and lower thresholds.\"\"\"\n mask_logits = mask_logits.flatten(-2)\n stability_delta = self.dynamic_multimask_stability_delta\n area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()\n area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()\n return torch.where(area_u > 0, area_i / area_u, 1.0)\n\n def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):\n \"\"\"\n Dynamically select the most stable mask output based on stability scores and IoU predictions.\n\n This method is used when outputting a single mask. If the stability score from the current single-mask\n output (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs\n (based on output tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask\n for both clicking and tracking scenarios.\n\n Args:\n all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is\n batch size, N is number of masks (typically 4), and H, W are mask dimensions.\n all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N).\n\n Returns:\n mask_logits_out (torch.Tensor): Selected mask logits, shape (B, 1, H, W).\n iou_scores_out (torch.Tensor): Selected IoU scores, shape (B, 1).\n\n Examples:\n >>> decoder = SAM2MaskDecoder(...)\n >>> all_mask_logits = torch.rand(2, 4, 256, 256) # 2 images, 4 masks each\n >>> all_iou_scores = torch.rand(2, 4)\n >>> mask_logits, iou_scores = decoder._dynamic_multimask_via_stability(all_mask_logits, all_iou_scores)\n >>> print(mask_logits.shape, iou_scores.shape)\n torch.Size([2, 1, 256, 256]) torch.Size([2, 1])\n \"\"\"\n # The best mask from multimask output tokens (1~3)\n multimask_logits = all_mask_logits[:, 1:, :, :]\n multimask_iou_scores = all_iou_scores[:, 1:]\n best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)\n batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)\n best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]\n best_multimask_logits = best_multimask_logits.unsqueeze(1)\n best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]\n best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)\n\n # The mask from singlemask output token 0 and its stability score\n singlemask_logits = all_mask_logits[:, 0:1, :, :]\n singlemask_iou_scores = all_iou_scores[:, 0:1]\n stability_scores = self._get_stability_scores(singlemask_logits)\n is_stable = stability_scores >= self.dynamic_multimask_stability_thresh\n\n # Dynamically fall back to best multimask output upon low stability scores.\n mask_logits_out = torch.where(\n is_stable[..., None, None].expand_as(singlemask_logits),\n singlemask_logits,\n best_multimask_logits,\n )\n iou_scores_out = torch.where(\n is_stable.expand_as(singlemask_iou_scores),\n singlemask_iou_scores,\n best_multimask_iou_scores,\n )\n return mask_logits_out, iou_scores_out",
"chunk_type": "class",
"name": "SAM2MaskDecoder",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\decoders.py",
"start_line": 174,
"end_line": 513,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": "Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.\n\nThis class extends the functionality of the MaskDecoder, incorporating additional features such as\nhigh-resolution feature processing, dynamic multimask output, and object score prediction.\n\nAttributes:\n transformer_dim (int): Channel dimension of the transformer.\n transformer (nn.Module): Transformer used to predict masks.\n num_multimask_outputs (int): Number of masks to predict when disambiguating masks.\n iou_token (nn.Embedding): Embedding for IOU token.\n num_mask_tokens (int): Total number of mask tokens.\n mask_tokens (nn.Embedding): Embedding for mask tokens.\n pred_obj_scores (bool): Whether to predict object scores.\n obj_score_token (nn.Embedding): Embedding for object score token.\n use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.\n output_upscaling (nn.Sequential): Upscaling layers for output.\n use_high_res_features (bool): Whether to use high-resolution features.\n conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0).\n conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1).\n output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks.\n iou_prediction_head (MLP): MLP for IOU prediction.\n pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction.\n dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.\n dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.\n dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.\n\nMethods:\n forward: Predict masks given image and prompt embeddings.\n predict_masks: Predict instance segmentation masks from image and prompt embeddings.\n _get_stability_scores: Compute mask stability scores based on IoU between thresholds.\n _dynamic_multimask_via_stability: Dynamically select the most stable mask output.\n\nExamples:\n >>> image_embeddings = torch.rand(1, 256, 64, 64)\n >>> image_pe = torch.rand(1, 256, 64, 64)\n >>> sparse_prompt_embeddings = torch.rand(1, 2, 256)\n >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)\n >>> decoder = SAM2MaskDecoder(256, transformer)\n >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(\n ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False\n ... )",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"torch",
"torch.nn",
"ultralytics.nn.modules.MLP",
"ultralytics.nn.modules.LayerNorm2d",
"nn.Module"
],
"chunk_id": "class_SAM2MaskDecoder_5d57bda5"
},
{
"content": "from typing import List, Optional, Tuple, Type",
"chunk_type": "import",
"name": "List, Optional, Tuple, Type",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_List, Optional, Tuple, Type_54fc0454"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_c5b44249"
},
{
"content": "import torch.nn as nn",
"chunk_type": "import",
"name": "torch.nn",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn_08fde6ac"
},
{
"content": "import torch.nn.functional as F",
"chunk_type": "import",
"name": "torch.nn.functional",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn.functional_da6be143"
},
{
"content": "from ultralytics.nn.modules import LayerNorm2d",
"chunk_type": "import",
"name": "LayerNorm2d",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LayerNorm2d_68f0168c"
},
{
"content": "from .blocks import (\n Block,\n CXBlock,\n Fuser,\n MaskDownSampler,\n MultiScaleBlock,\n PatchEmbed,\n PositionEmbeddingRandom,\n PositionEmbeddingSine,\n)",
"chunk_type": "import",
"name": "Block, CXBlock, Fuser, MaskDownSampler, MultiScaleBlock, PatchEmbed, PositionEmbeddingRandom, PositionEmbeddingSine",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py",
"start_line": 11,
"end_line": 20,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Block, CXBlock, Fuser, MaskDownSampler, MultiScaleBlock, PatchEmbed, PositionEmbeddingRandom, PositionEmbeddingSine_7e4f3b8e"
},
{
"content": "class ImageEncoderViT(nn.Module):\n \"\"\"\n An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.\n\n This class processes images by splitting them into patches, applying transformer blocks, and generating a final\n encoded representation through a neck module.\n\n Attributes:\n img_size (int): Dimension of input images, assumed to be square.\n patch_embed (PatchEmbed): Module for patch embedding.\n pos_embed (nn.Parameter | None): Absolute positional embedding for patches.\n blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.\n neck (nn.Sequential): Neck module to further process the output.\n\n Methods:\n forward: Process input through patch embedding, positional embedding, blocks, and neck.\n\n Examples:\n >>> import torch\n >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)\n >>> input_image = torch.randn(1, 3, 224, 224)\n >>> output = encoder(input_image)\n >>> print(output.shape)\n \"\"\"\n\n def __init__(\n self,\n img_size: int = 1024,\n patch_size: int = 16,\n in_chans: int = 3,\n embed_dim: int = 768,\n depth: int = 12,\n num_heads: int = 12,\n mlp_ratio: float = 4.0,\n out_chans: int = 256,\n qkv_bias: bool = True,\n norm_layer: Type[nn.Module] = nn.LayerNorm,\n act_layer: Type[nn.Module] = nn.GELU,\n use_abs_pos: bool = True,\n use_rel_pos: bool = False,\n rel_pos_zero_init: bool = True,\n window_size: int = 0,\n global_attn_indexes: Tuple[int, ...] = (),\n ) -> None:\n \"\"\"\n Initialize an ImageEncoderViT instance for encoding images using Vision Transformer architecture.\n\n Args:\n img_size (int): Input image size, assumed to be square.\n patch_size (int): Size of image patches.\n in_chans (int): Number of input image channels.\n embed_dim (int): Dimension of patch embeddings.\n depth (int): Number of transformer blocks.\n num_heads (int): Number of attention heads in each block.\n mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.\n out_chans (int): Number of output channels from the neck module.\n qkv_bias (bool): If True, adds learnable bias to query, key, value projections.\n norm_layer (Type[nn.Module]): Type of normalization layer to use.\n act_layer (Type[nn.Module]): Type of activation layer to use.\n use_abs_pos (bool): If True, uses absolute positional embeddings.\n use_rel_pos (bool): If True, adds relative positional embeddings to attention maps.\n rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.\n window_size (int): Size of attention window for windowed attention blocks.\n global_attn_indexes (Tuple[int, ...]): Indices of blocks that use global attention.\n\n Examples:\n >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)\n >>> input_image = torch.randn(1, 3, 224, 224)\n >>> output = encoder(input_image)\n >>> print(output.shape)\n \"\"\"\n super().__init__()\n self.img_size = img_size\n\n self.patch_embed = PatchEmbed(\n kernel_size=(patch_size, patch_size),\n stride=(patch_size, patch_size),\n in_chans=in_chans,\n embed_dim=embed_dim,\n )\n\n self.pos_embed: Optional[nn.Parameter] = None\n if use_abs_pos:\n # Initialize absolute positional embedding with pretrain image size\n self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))\n\n self.blocks = nn.ModuleList()\n for i in range(depth):\n block = Block(\n dim=embed_dim,\n num_heads=num_heads,\n 
mlp_ratio=mlp_ratio,\n qkv_bias=qkv_bias,\n norm_layer=norm_layer,\n act_layer=act_layer,\n use_rel_pos=use_rel_pos,\n rel_pos_zero_init=rel_pos_zero_init,\n window_size=window_size if i not in global_attn_indexes else 0,\n input_size=(img_size // patch_size, img_size // patch_size),\n )\n self.blocks.append(block)\n\n self.neck = nn.Sequential(\n nn.Conv2d(\n embed_dim,\n out_chans,\n kernel_size=1,\n bias=False,\n ),\n LayerNorm2d(out_chans),\n nn.Conv2d(\n out_chans,\n out_chans,\n kernel_size=3,\n padding=1,\n bias=False,\n ),\n LayerNorm2d(out_chans),\n )\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Process input through patch embedding, positional embedding, transformer blocks, and neck module.\"\"\"\n x = self.patch_embed(x)\n if self.pos_embed is not None:\n pos_embed = (\n F.interpolate(self.pos_embed.permute(0, 3, 1, 2), scale_factor=self.img_size / 1024).permute(0, 2, 3, 1)\n if self.img_size != 1024\n else self.pos_embed\n )\n x = x + pos_embed\n for blk in self.blocks:\n x = blk(x)\n return self.neck(x.permute(0, 3, 1, 2))",
"chunk_type": "class",
"name": "ImageEncoderViT",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py",
"start_line": 23,
"end_line": 155,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": "An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.\n\nThis class processes images by splitting them into patches, applying transformer blocks, and generating a final\nencoded representation through a neck module.\n\nAttributes:\n img_size (int): Dimension of input images, assumed to be square.\n patch_embed (PatchEmbed): Module for patch embedding.\n pos_embed (nn.Parameter | None): Absolute positional embedding for patches.\n blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.\n neck (nn.Sequential): Neck module to further process the output.\n\nMethods:\n forward: Process input through patch embedding, positional embedding, blocks, and neck.\n\nExamples:\n >>> import torch\n >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)\n >>> input_image = torch.randn(1, 3, 224, 224)\n >>> output = encoder(input_image)\n >>> print(output.shape)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.nn.modules.LayerNorm2d",
"blocks.Block",
"blocks.CXBlock",
"blocks.Fuser",
"blocks.MaskDownSampler",
"blocks.MultiScaleBlock",
"blocks.PatchEmbed",
"blocks.PositionEmbeddingRandom",
"blocks.PositionEmbeddingSine",
"nn.Module"
],
"chunk_id": "class_ImageEncoderViT_137a3367"
},
{
"content": "class PromptEncoder(nn.Module):\n \"\"\"\n Encode different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.\n\n Attributes:\n embed_dim (int): Dimension of the embeddings.\n input_image_size (Tuple[int, int]): Size of the input image as (H, W).\n image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).\n pe_layer (PositionEmbeddingRandom): Module for random position embedding.\n num_point_embeddings (int): Number of point embeddings for different types of points.\n point_embeddings (nn.ModuleList): List of point embeddings.\n not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.\n mask_input_size (Tuple[int, int]): Size of the input mask.\n mask_downscaling (nn.Sequential): Neural network for downscaling the mask.\n no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.\n\n Methods:\n get_dense_pe: Return the positional encoding used to encode point prompts.\n forward: Embed different types of prompts, returning both sparse and dense embeddings.\n\n Examples:\n >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)\n >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))\n >>> boxes = torch.rand(1, 2, 2)\n >>> masks = torch.rand(1, 1, 256, 256)\n >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)\n >>> print(sparse_embeddings.shape, dense_embeddings.shape)\n torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])\n \"\"\"\n\n def __init__(\n self,\n embed_dim: int,\n image_embedding_size: Tuple[int, int],\n input_image_size: Tuple[int, int],\n mask_in_chans: int,\n activation: Type[nn.Module] = nn.GELU,\n ) -> None:\n \"\"\"\n Initialize the PromptEncoder module for encoding various types of prompts.\n\n Args:\n embed_dim (int): The dimension of the embeddings.\n image_embedding_size (Tuple[int, int]): The spatial size of the image embedding as (H, W).\n input_image_size (Tuple[int, int]): The padded size of the input image as (H, W).\n mask_in_chans (int): The number of hidden channels used for encoding input masks.\n activation (Type[nn.Module]): The activation function to use when encoding input masks.\n\n Examples:\n >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)\n >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))\n >>> boxes = torch.rand(1, 2, 2)\n >>> masks = torch.rand(1, 1, 256, 256)\n >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)\n >>> print(sparse_embeddings.shape, dense_embeddings.shape)\n torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])\n \"\"\"\n super().__init__()\n self.embed_dim = embed_dim\n self.input_image_size = input_image_size\n self.image_embedding_size = image_embedding_size\n self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)\n\n self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners\n point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]\n self.point_embeddings = nn.ModuleList(point_embeddings)\n self.not_a_point_embed = nn.Embedding(1, embed_dim)\n\n self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])\n self.mask_downscaling = nn.Sequential(\n nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),\n LayerNorm2d(mask_in_chans // 4),\n activation(),\n nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),\n LayerNorm2d(mask_in_chans),\n activation(),\n nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),\n )\n 
self.no_mask_embed = nn.Embedding(1, embed_dim)\n\n def get_dense_pe(self) -> torch.Tensor:\n \"\"\"\n Return the dense positional encoding used for encoding point prompts.\n\n Generate a positional encoding for a dense set of points matching the shape of the image\n encoding. The encoding is used to provide spatial information to the model when processing point prompts.\n\n Returns:\n (torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the\n height and width of the image embedding size, respectively.\n\n Examples:\n >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)\n >>> dense_pe = prompt_encoder.get_dense_pe()\n >>> print(dense_pe.shape)\n torch.Size([1, 256, 64, 64])\n \"\"\"\n return self.pe_layer(self.image_embedding_size).unsqueeze(0)\n\n def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:\n \"\"\"Embed point prompts by applying positional encoding and label-specific embeddings.\"\"\"\n points = points + 0.5 # Shift to center of pixel\n if pad:\n padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)\n padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)\n points = torch.cat([points, padding_point], dim=1)\n labels = torch.cat([labels, padding_label], dim=1)\n point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)\n point_embedding[labels == -1] = 0.0\n point_embedding[labels == -1] += self.not_a_point_embed.weight\n point_embedding[labels == 0] += self.point_embeddings[0].weight\n point_embedding[labels == 1] += self.point_embeddings[1].weight\n point_embedding[labels == 2] += self.point_embeddings[2].weight\n point_embedding[labels == 3] += self.point_embeddings[3].weight\n return point_embedding\n\n def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:\n \"\"\"Embed box prompts by applying positional encoding and adding corner embeddings.\"\"\"\n boxes = boxes + 0.5 # Shift to center of pixel\n coords = boxes.reshape(-1, 2, 2)\n corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)\n corner_embedding[:, 0, :] += self.point_embeddings[2].weight\n corner_embedding[:, 1, :] += self.point_embeddings[3].weight\n return corner_embedding\n\n def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:\n \"\"\"Embed mask inputs by downscaling and processing through convolutional layers.\"\"\"\n return self.mask_downscaling(masks)\n\n @staticmethod\n def _get_batch_size(\n points: Optional[Tuple[torch.Tensor, torch.Tensor]],\n boxes: Optional[torch.Tensor],\n masks: Optional[torch.Tensor],\n ) -> int:\n \"\"\"Get the batch size of the output given the batch size of the input prompts.\"\"\"\n if points is not None:\n return points[0].shape[0]\n elif boxes is not None:\n return boxes.shape[0]\n elif masks is not None:\n return masks.shape[0]\n else:\n return 1\n\n def _get_device(self) -> torch.device:\n \"\"\"Return the device of the first point embedding's weight tensor.\"\"\"\n return self.point_embeddings[0].weight.device\n\n def forward(\n self,\n points: Optional[Tuple[torch.Tensor, torch.Tensor]],\n boxes: Optional[torch.Tensor],\n masks: Optional[torch.Tensor],\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"\n Embed different types of prompts, returning both sparse and dense embeddings.\n\n Args:\n points (Tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. 
The first\n tensor contains coordinates with shape (B, N, 2), and the second tensor contains labels with\n shape (B, N).\n boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes.\n masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W).\n\n Returns:\n sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim).\n dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W).\n\n Examples:\n >>> encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)\n >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))\n >>> boxes = torch.rand(1, 2, 2, 2)\n >>> masks = torch.rand(1, 1, 256, 256)\n >>> sparse_emb, dense_emb = encoder(points, boxes, masks)\n >>> print(sparse_emb.shape, dense_emb.shape)\n torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])\n \"\"\"\n bs = self._get_batch_size(points, boxes, masks)\n sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())\n if points is not None:\n coords, labels = points\n point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))\n sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)\n if boxes is not None:\n box_embeddings = self._embed_boxes(boxes)\n sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)\n\n if masks is not None:\n dense_embeddings = self._embed_masks(masks)\n else:\n dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(\n bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]\n )\n\n return sparse_embeddings, dense_embeddings",
"chunk_type": "class",
"name": "PromptEncoder",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py",
"start_line": 158,
"end_line": 353,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": "Encode different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.\n\nAttributes:\n embed_dim (int): Dimension of the embeddings.\n input_image_size (Tuple[int, int]): Size of the input image as (H, W).\n image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).\n pe_layer (PositionEmbeddingRandom): Module for random position embedding.\n num_point_embeddings (int): Number of point embeddings for different types of points.\n point_embeddings (nn.ModuleList): List of point embeddings.\n not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.\n mask_input_size (Tuple[int, int]): Size of the input mask.\n mask_downscaling (nn.Sequential): Neural network for downscaling the mask.\n no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.\n\nMethods:\n get_dense_pe: Return the positional encoding used to encode point prompts.\n forward: Embed different types of prompts, returning both sparse and dense embeddings.\n\nExamples:\n >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)\n >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))\n >>> boxes = torch.rand(1, 2, 2)\n >>> masks = torch.rand(1, 1, 256, 256)\n >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)\n >>> print(sparse_embeddings.shape, dense_embeddings.shape)\n torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.nn.modules.LayerNorm2d",
"blocks.Block",
"blocks.CXBlock",
"blocks.Fuser",
"blocks.MaskDownSampler",
"blocks.MultiScaleBlock",
"blocks.PatchEmbed",
"blocks.PositionEmbeddingRandom",
"blocks.PositionEmbeddingSine",
"nn.Module"
],
"chunk_id": "class_PromptEncoder_cef65f27"
},
{
"content": "class MemoryEncoder(nn.Module):\n \"\"\"\n Encode pixel features and masks into a memory representation for efficient image segmentation.\n\n This class processes pixel-level features and masks, fusing them to generate encoded memory representations\n suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).\n\n Attributes:\n mask_downsampler (MaskDownSampler): Module for downsampling input masks.\n pix_feat_proj (nn.Conv2d): Convolutional layer for projecting pixel features.\n fuser (Fuser): Module for fusing pixel features and masks.\n position_encoding (PositionEmbeddingSine): Module for adding positional encoding to features.\n out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d.\n\n Methods:\n forward: Process input pixel features and masks to generate encoded memory representations.\n\n Examples:\n >>> import torch\n >>> encoder = MemoryEncoder(out_dim=256, in_dim=256)\n >>> pix_feat = torch.randn(1, 256, 64, 64)\n >>> masks = torch.randn(1, 1, 64, 64)\n >>> encoded_feat, pos = encoder(pix_feat, masks)\n >>> print(encoded_feat.shape, pos.shape)\n torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])\n \"\"\"\n\n def __init__(\n self,\n out_dim,\n in_dim=256, # in_dim of pix_feats\n ):\n \"\"\"\n Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.\n\n This encoder processes pixel-level features and masks, fusing them to generate encoded memory representations\n suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).\n\n Args:\n out_dim (int): Output dimension of the encoded features.\n in_dim (int): Input dimension of the pixel features.\n\n Examples:\n >>> encoder = MemoryEncoder(out_dim=256, in_dim=256)\n >>> pix_feat = torch.randn(1, 256, 64, 64)\n >>> masks = torch.randn(1, 1, 64, 64)\n >>> encoded_feat, pos = encoder(pix_feat, masks)\n >>> print(encoded_feat.shape, pos.shape)\n torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])\n \"\"\"\n super().__init__()\n\n self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)\n\n self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)\n self.fuser = Fuser(CXBlock(dim=256), num_layers=2)\n self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)\n self.out_proj = nn.Identity()\n if out_dim != in_dim:\n self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)\n\n def forward(\n self,\n pix_feat: torch.Tensor,\n masks: torch.Tensor,\n skip_mask_sigmoid: bool = False,\n ) -> dict:\n \"\"\"Process pixel features and masks to generate encoded memory representations for segmentation.\"\"\"\n if not skip_mask_sigmoid:\n masks = F.sigmoid(masks)\n masks = self.mask_downsampler(masks)\n\n # Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA\n pix_feat = pix_feat.to(masks.device)\n\n x = self.pix_feat_proj(pix_feat)\n x = x + masks\n x = self.fuser(x)\n x = self.out_proj(x)\n\n pos = self.position_encoding(x).to(x.dtype)\n\n return {\"vision_features\": x, \"vision_pos_enc\": [pos]}",
"chunk_type": "class",
"name": "MemoryEncoder",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py",
"start_line": 356,
"end_line": 438,
"start_col": 0,
"end_col": 62,
"parent_name": null,
"docstring": "Encode pixel features and masks into a memory representation for efficient image segmentation.\n\nThis class processes pixel-level features and masks, fusing them to generate encoded memory representations\nsuitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).\n\nAttributes:\n mask_downsampler (MaskDownSampler): Module for downsampling input masks.\n pix_feat_proj (nn.Conv2d): Convolutional layer for projecting pixel features.\n fuser (Fuser): Module for fusing pixel features and masks.\n position_encoding (PositionEmbeddingSine): Module for adding positional encoding to features.\n out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d.\n\nMethods:\n forward: Process input pixel features and masks to generate encoded memory representations.\n\nExamples:\n >>> import torch\n >>> encoder = MemoryEncoder(out_dim=256, in_dim=256)\n >>> pix_feat = torch.randn(1, 256, 64, 64)\n >>> masks = torch.randn(1, 1, 64, 64)\n >>> encoded_feat, pos = encoder(pix_feat, masks)\n >>> print(encoded_feat.shape, pos.shape)\n torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.nn.modules.LayerNorm2d",
"blocks.Block",
"blocks.CXBlock",
"blocks.Fuser",
"blocks.MaskDownSampler",
"blocks.MultiScaleBlock",
"blocks.PatchEmbed",
"blocks.PositionEmbeddingRandom",
"blocks.PositionEmbeddingSine",
"nn.Module"
],
"chunk_id": "class_MemoryEncoder_789e12d3"
},
{
"content": "class ImageEncoder(nn.Module):\n \"\"\"\n Encode images using a trunk-neck architecture, producing multiscale features and positional encodings.\n\n This class combines a trunk network for feature extraction with a neck network for feature refinement\n and positional encoding generation. It can optionally discard the lowest resolution features.\n\n Attributes:\n trunk (nn.Module): The trunk network for initial feature extraction.\n neck (nn.Module): The neck network for feature refinement and positional encoding generation.\n scalp (int): Number of lowest resolution feature levels to discard.\n\n Methods:\n forward: Process the input image through the trunk and neck networks.\n\n Examples:\n >>> trunk = SomeTrunkNetwork()\n >>> neck = SomeNeckNetwork()\n >>> encoder = ImageEncoder(trunk, neck, scalp=1)\n >>> image = torch.randn(1, 3, 224, 224)\n >>> output = encoder(image)\n >>> print(output.keys())\n dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])\n \"\"\"\n\n def __init__(\n self,\n trunk: nn.Module,\n neck: nn.Module,\n scalp: int = 0,\n ):\n \"\"\"\n Initialize the ImageEncoder with trunk and neck networks for feature extraction and refinement.\n\n This encoder combines a trunk network for feature extraction with a neck network for feature refinement\n and positional encoding generation. It can optionally discard the lowest resolution features.\n\n Args:\n trunk (nn.Module): The trunk network for initial feature extraction.\n neck (nn.Module): The neck network for feature refinement and positional encoding generation.\n scalp (int): Number of lowest resolution feature levels to discard.\n\n Examples:\n >>> trunk = SomeTrunkNetwork()\n >>> neck = SomeNeckNetwork()\n >>> encoder = ImageEncoder(trunk, neck, scalp=1)\n >>> image = torch.randn(1, 3, 224, 224)\n >>> output = encoder(image)\n >>> print(output.keys())\n dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])\n \"\"\"\n super().__init__()\n self.trunk = trunk\n self.neck = neck\n self.scalp = scalp\n assert self.trunk.channel_list == self.neck.backbone_channel_list, (\n f\"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match.\"\n )\n\n def forward(self, sample: torch.Tensor):\n \"\"\"Encode input through trunk and neck networks, returning multiscale features and positional encodings.\"\"\"\n features, pos = self.neck(self.trunk(sample))\n if self.scalp > 0:\n # Discard the lowest resolution features\n features, pos = features[: -self.scalp], pos[: -self.scalp]\n\n src = features[-1]\n return {\n \"vision_features\": src,\n \"vision_pos_enc\": pos,\n \"backbone_fpn\": features,\n }",
"chunk_type": "class",
"name": "ImageEncoder",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py",
"start_line": 441,
"end_line": 512,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "Encode images using a trunk-neck architecture, producing multiscale features and positional encodings.\n\nThis class combines a trunk network for feature extraction with a neck network for feature refinement\nand positional encoding generation. It can optionally discard the lowest resolution features.\n\nAttributes:\n trunk (nn.Module): The trunk network for initial feature extraction.\n neck (nn.Module): The neck network for feature refinement and positional encoding generation.\n scalp (int): Number of lowest resolution feature levels to discard.\n\nMethods:\n forward: Process the input image through the trunk and neck networks.\n\nExamples:\n >>> trunk = SomeTrunkNetwork()\n >>> neck = SomeNeckNetwork()\n >>> encoder = ImageEncoder(trunk, neck, scalp=1)\n >>> image = torch.randn(1, 3, 224, 224)\n >>> output = encoder(image)\n >>> print(output.keys())\n dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.nn.modules.LayerNorm2d",
"blocks.Block",
"blocks.CXBlock",
"blocks.Fuser",
"blocks.MaskDownSampler",
"blocks.MultiScaleBlock",
"blocks.PatchEmbed",
"blocks.PositionEmbeddingRandom",
"blocks.PositionEmbeddingSine",
"nn.Module"
],
"chunk_id": "class_ImageEncoder_25645eaa"
},
{
"content": "class FpnNeck(nn.Module):\n \"\"\"\n A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.\n\n This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,\n similar to ViT positional embedding interpolation.\n\n Attributes:\n position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module.\n convs (nn.ModuleList): List of convolutional layers for each backbone level.\n backbone_channel_list (List[int]): List of channel dimensions from the backbone.\n fpn_interp_model (str): Interpolation mode for FPN feature resizing.\n fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.\n fpn_top_down_levels (List[int]): Levels to have top-down features in outputs.\n\n Methods:\n forward: Perform forward pass through the FPN neck.\n\n Examples:\n >>> backbone_channels = [64, 128, 256, 512]\n >>> fpn_neck = FpnNeck(256, backbone_channels)\n >>> inputs = [torch.rand(1, c, 32, 32) for c in backbone_channels]\n >>> outputs, positions = fpn_neck(inputs)\n >>> print(len(outputs), len(positions))\n 4 4\n \"\"\"\n\n def __init__(\n self,\n d_model: int,\n backbone_channel_list: List[int],\n kernel_size: int = 1,\n stride: int = 1,\n padding: int = 0,\n fpn_interp_model: str = \"bilinear\",\n fuse_type: str = \"sum\",\n fpn_top_down_levels: Optional[List[int]] = None,\n ):\n \"\"\"\n Initialize a modified Feature Pyramid Network (FPN) neck.\n\n This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,\n similar to ViT positional embedding interpolation.\n\n Args:\n d_model (int): Dimension of the model.\n backbone_channel_list (List[int]): List of channel dimensions from the backbone.\n kernel_size (int): Kernel size for the convolutional layers.\n stride (int): Stride for the convolutional layers.\n padding (int): Padding for the convolutional layers.\n fpn_interp_model (str): Interpolation mode for FPN feature resizing.\n fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.\n fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs.\n\n Examples:\n >>> backbone_channels = [64, 128, 256, 512]\n >>> fpn_neck = FpnNeck(256, backbone_channels)\n >>> print(fpn_neck)\n \"\"\"\n super().__init__()\n self.position_encoding = PositionEmbeddingSine(num_pos_feats=256)\n self.convs = nn.ModuleList()\n self.backbone_channel_list = backbone_channel_list\n for dim in backbone_channel_list:\n current = nn.Sequential()\n current.add_module(\n \"conv\",\n nn.Conv2d(\n in_channels=dim,\n out_channels=d_model,\n kernel_size=kernel_size,\n stride=stride,\n padding=padding,\n ),\n )\n\n self.convs.append(current)\n self.fpn_interp_model = fpn_interp_model\n assert fuse_type in {\"sum\", \"avg\"}\n self.fuse_type = fuse_type\n\n # Levels to have top-down features in its outputs\n # e.g. 
if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3\n # have top-down propagation, while outputs of level 0 and level 1 have only\n # lateral features from the same backbone level\n if fpn_top_down_levels is None:\n # Default is to have top-down features on all levels\n fpn_top_down_levels = range(len(self.convs))\n self.fpn_top_down_levels = list(fpn_top_down_levels)\n\n def forward(self, xs: List[torch.Tensor]):\n \"\"\"\n Perform forward pass through the Feature Pyramid Network (FPN) neck.\n\n This method processes a list of input tensors from the backbone through the FPN, applying lateral connections\n and top-down feature fusion. It generates output feature maps and corresponding positional encodings.\n\n Args:\n xs (List[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W).\n\n Returns:\n out (List[torch.Tensor]): List of output feature maps after FPN processing, each with shape\n (B, d_model, H, W).\n pos (List[torch.Tensor]): List of positional encodings corresponding to each output feature map.\n\n Examples:\n >>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])\n >>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]]\n >>> outputs, positions = fpn_neck(inputs)\n >>> print(len(outputs), len(positions))\n 4 4\n \"\"\"\n out = [None] * len(self.convs)\n pos = [None] * len(self.convs)\n assert len(xs) == len(self.convs)\n # FPN forward pass\n # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py\n prev_features = None\n # Forward in top-down order (from low to high resolution)\n n = len(self.convs) - 1\n for i in range(n, -1, -1):\n x = xs[i]\n lateral_features = self.convs[n - i](x)\n if i in self.fpn_top_down_levels and prev_features is not None:\n top_down_features = F.interpolate(\n prev_features.to(dtype=torch.float32),\n scale_factor=2.0,\n mode=self.fpn_interp_model,\n align_corners=(None if self.fpn_interp_model == \"nearest\" else False),\n antialias=False,\n )\n prev_features = lateral_features + top_down_features\n if self.fuse_type == \"avg\":\n prev_features /= 2\n else:\n prev_features = lateral_features\n x_out = prev_features\n out[i] = x_out\n pos[i] = self.position_encoding(x_out).to(x_out.dtype)\n\n return out, pos",
"chunk_type": "class",
"name": "FpnNeck",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py",
"start_line": 515,
"end_line": 655,
"start_col": 0,
"end_col": 23,
"parent_name": null,
"docstring": "A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.\n\nThis FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,\nsimilar to ViT positional embedding interpolation.\n\nAttributes:\n position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module.\n convs (nn.ModuleList): List of convolutional layers for each backbone level.\n backbone_channel_list (List[int]): List of channel dimensions from the backbone.\n fpn_interp_model (str): Interpolation mode for FPN feature resizing.\n fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.\n fpn_top_down_levels (List[int]): Levels to have top-down features in outputs.\n\nMethods:\n forward: Perform forward pass through the FPN neck.\n\nExamples:\n >>> backbone_channels = [64, 128, 256, 512]\n >>> fpn_neck = FpnNeck(256, backbone_channels)\n >>> inputs = [torch.rand(1, c, 32, 32) for c in backbone_channels]\n >>> outputs, positions = fpn_neck(inputs)\n >>> print(len(outputs), len(positions))\n 4 4",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.nn.modules.LayerNorm2d",
"blocks.Block",
"blocks.CXBlock",
"blocks.Fuser",
"blocks.MaskDownSampler",
"blocks.MultiScaleBlock",
"blocks.PatchEmbed",
"blocks.PositionEmbeddingRandom",
"blocks.PositionEmbeddingSine",
"nn.Module"
],
"chunk_id": "class_FpnNeck_3bf0d7be"
},
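A minimal usage sketch for the FpnNeck chunk above, not taken from the source; the import path is inferred from the recorded file_path and may differ between ultralytics versions. Two details of forward() drive the shapes: xs[i] is passed through convs[n - i], so inputs are consumed in reverse order of backbone_channel_list, and top-down fusion upsamples by a fixed factor of 2, so spatial sizes should halve from one level to the next.

import torch
from ultralytics.models.sam.modules.encoders import FpnNeck  # path assumed from file_path

backbone_channels = [256, 128, 64]  # channel dims as passed to the constructor
fpn = FpnNeck(d_model=256, backbone_channel_list=backbone_channels)

# xs[i] must have backbone_channels[n - i] channels and halve in resolution with depth,
# so the scale_factor=2 top-down interpolation lines up with each lateral feature map.
xs = [
    torch.rand(1, 64, 32, 32),   # finest level, matches backbone_channels[-1]
    torch.rand(1, 128, 16, 16),
    torch.rand(1, 256, 8, 8),    # coarsest level, matches backbone_channels[0]
]
outs, pos = fpn(xs)
print([o.shape for o in outs])   # each (1, 256, H_i, W_i), same spatial sizes as the inputs
print([p.shape for p in pos])    # sinusoidal encodings with matching spatial sizes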
{
"content": "class Hiera(nn.Module):\n \"\"\"\n Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.\n\n This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for\n efficient multiscale feature extraction. It uses a series of transformer blocks organized into stages,\n with optional pooling and global attention mechanisms.\n\n Attributes:\n window_spec (Tuple[int, ...]): Window sizes for each stage.\n q_stride (Tuple[int, int]): Downsampling stride between stages.\n stage_ends (List[int]): Indices of the last block in each stage.\n q_pool_blocks (List[int]): Indices of blocks where pooling is applied.\n return_interm_layers (bool): Whether to return intermediate layer outputs.\n patch_embed (PatchEmbed): Module for patch embedding.\n global_att_blocks (Tuple[int, ...]): Indices of blocks with global attention.\n window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background.\n pos_embed (nn.Parameter): Positional embedding for the background.\n pos_embed_window (nn.Parameter): Positional embedding for the window.\n blocks (nn.ModuleList): List of MultiScaleBlock modules.\n channel_list (List[int]): List of output channel dimensions for each stage.\n\n Methods:\n _get_pos_embed: Generate positional embeddings by interpolating and combining window and background embeddings.\n forward: Perform the forward pass through the Hiera model.\n\n Examples:\n >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))\n >>> input_tensor = torch.randn(1, 3, 224, 224)\n >>> output_features = model(input_tensor)\n >>> for feat in output_features:\n ... print(feat.shape)\n \"\"\"\n\n def __init__(\n self,\n embed_dim: int = 96, # initial embed dim\n num_heads: int = 1, # initial number of heads\n drop_path_rate: float = 0.0, # stochastic depth\n q_pool: int = 3, # number of q_pool stages\n q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages\n stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage\n dim_mul: float = 2.0, # dim_mul factor at stage shift\n head_mul: float = 2.0, # head_mul factor at stage shift\n window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),\n # window size per stage, when not using global att.\n window_spec: Tuple[int, ...] = (\n 8,\n 4,\n 14,\n 7,\n ),\n # global attn in these blocks\n global_att_blocks: Tuple[int, ...] = (\n 12,\n 16,\n 20,\n ),\n return_interm_layers=True, # return feats from every stage\n ):\n \"\"\"\n Initialize a Hiera model, a hierarchical vision transformer for efficient multiscale feature extraction.\n\n Hiera is a hierarchical vision transformer architecture designed for efficient multiscale feature extraction\n in image processing tasks. 
It uses a series of transformer blocks organized into stages, with optional\n pooling and global attention mechanisms.\n\n Args:\n embed_dim (int): Initial embedding dimension for the model.\n num_heads (int): Initial number of attention heads.\n drop_path_rate (float): Stochastic depth rate.\n q_pool (int): Number of query pooling stages.\n q_stride (Tuple[int, int]): Downsampling stride between stages.\n stages (Tuple[int, ...]): Number of blocks per stage.\n dim_mul (float): Dimension multiplier factor at stage transitions.\n head_mul (float): Head multiplier factor at stage transitions.\n window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background.\n window_spec (Tuple[int, ...]): Window sizes for each stage when not using global attention.\n global_att_blocks (Tuple[int, ...]): Indices of blocks that use global attention.\n return_interm_layers (bool): Whether to return intermediate layer outputs.\n\n Examples:\n >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))\n >>> input_tensor = torch.randn(1, 3, 224, 224)\n >>> output_features = model(input_tensor)\n >>> for feat in output_features:\n ... print(feat.shape)\n \"\"\"\n super().__init__()\n\n assert len(stages) == len(window_spec)\n self.window_spec = window_spec\n\n depth = sum(stages)\n self.q_stride = q_stride\n self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]\n assert 0 <= q_pool <= len(self.stage_ends[:-1])\n self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]\n self.return_interm_layers = return_interm_layers\n\n self.patch_embed = PatchEmbed(\n embed_dim=embed_dim,\n kernel_size=(7, 7),\n stride=(4, 4),\n padding=(3, 3),\n )\n # Which blocks have global attention?\n self.global_att_blocks = global_att_blocks\n\n # Windowed positional embedding (https://arxiv.org/abs/2311.05613)\n self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size\n self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))\n self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))\n\n dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule\n\n cur_stage = 1\n self.blocks = nn.ModuleList()\n\n for i in range(depth):\n dim_out = embed_dim\n # Lags by a block, so first block of next stage uses an initial window size\n # of previous stage and final window size of current stage\n window_size = self.window_spec[cur_stage - 1]\n\n if self.global_att_blocks is not None:\n window_size = 0 if i in self.global_att_blocks else window_size\n\n if i - 1 in self.stage_ends:\n dim_out = int(embed_dim * dim_mul)\n num_heads = int(num_heads * head_mul)\n cur_stage += 1\n\n block = MultiScaleBlock(\n dim=embed_dim,\n dim_out=dim_out,\n num_heads=num_heads,\n drop_path=dpr[i],\n q_stride=self.q_stride if i in self.q_pool_blocks else None,\n window_size=window_size,\n )\n\n embed_dim = dim_out\n self.blocks.append(block)\n\n self.channel_list = (\n [self.blocks[i].dim_out for i in self.stage_ends[::-1]]\n if return_interm_layers\n else [self.blocks[-1].dim_out]\n )\n\n def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:\n \"\"\"Generate positional embeddings by interpolating and combining window and background embeddings.\"\"\"\n h, w = hw\n window_embed = self.pos_embed_window\n pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode=\"bicubic\")\n pos_embed = pos_embed + window_embed.tile([x // y for x, y 
in zip(pos_embed.shape, window_embed.shape)])\n pos_embed = pos_embed.permute(0, 2, 3, 1)\n return pos_embed\n\n def forward(self, x: torch.Tensor) -> List[torch.Tensor]:\n \"\"\"\n Perform forward pass through Hiera model, extracting multiscale features from input images.\n\n Args:\n x (torch.Tensor): Input tensor with shape (B, C, H, W) representing a batch of images.\n\n Returns:\n (List[torch.Tensor]): List of feature maps at different scales, each with shape (B, C_i, H_i, W_i), where\n C_i is the channel dimension and H_i, W_i are the spatial dimensions at scale i. The list is ordered\n from highest resolution (fine features) to lowest resolution (coarse features) if return_interm_layers\n is True, otherwise contains only the final output.\n\n Examples:\n >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))\n >>> input_tensor = torch.randn(1, 3, 224, 224)\n >>> output_features = model(input_tensor)\n >>> for feat in output_features:\n ... print(feat.shape)\n \"\"\"\n x = self.patch_embed(x)\n # x: (B, H, W, C)\n\n # Add positional embedding\n x = x + self._get_pos_embed(x.shape[1:3])\n\n outputs = []\n for i, blk in enumerate(self.blocks):\n x = blk(x)\n if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):\n feats = x.permute(0, 3, 1, 2)\n outputs.append(feats)\n\n return outputs",
"chunk_type": "class",
"name": "Hiera",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\encoders.py",
"start_line": 658,
"end_line": 851,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": "Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.\n\nThis class implements a Hiera model, which is a hierarchical vision transformer architecture designed for\nefficient multiscale feature extraction. It uses a series of transformer blocks organized into stages,\nwith optional pooling and global attention mechanisms.\n\nAttributes:\n window_spec (Tuple[int, ...]): Window sizes for each stage.\n q_stride (Tuple[int, int]): Downsampling stride between stages.\n stage_ends (List[int]): Indices of the last block in each stage.\n q_pool_blocks (List[int]): Indices of blocks where pooling is applied.\n return_interm_layers (bool): Whether to return intermediate layer outputs.\n patch_embed (PatchEmbed): Module for patch embedding.\n global_att_blocks (Tuple[int, ...]): Indices of blocks with global attention.\n window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background.\n pos_embed (nn.Parameter): Positional embedding for the background.\n pos_embed_window (nn.Parameter): Positional embedding for the window.\n blocks (nn.ModuleList): List of MultiScaleBlock modules.\n channel_list (List[int]): List of output channel dimensions for each stage.\n\nMethods:\n _get_pos_embed: Generate positional embeddings by interpolating and combining window and background embeddings.\n forward: Perform the forward pass through the Hiera model.\n\nExamples:\n >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))\n >>> input_tensor = torch.randn(1, 3, 224, 224)\n >>> output_features = model(input_tensor)\n >>> for feat in output_features:\n ... print(feat.shape)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Type",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.nn.modules.LayerNorm2d",
"blocks.Block",
"blocks.CXBlock",
"blocks.Fuser",
"blocks.MaskDownSampler",
"blocks.MultiScaleBlock",
"blocks.PatchEmbed",
"blocks.PositionEmbeddingRandom",
"blocks.PositionEmbeddingSine",
"nn.Module"
],
"chunk_id": "class_Hiera_8bb910ce"
},
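A usage sketch mirroring the Hiera docstring example above (import path again inferred from the recorded file_path). With the default stage and window configuration and a 224x224 input, the four returned maps run from fine to coarse.

import torch
from ultralytics.models.sam.modules.encoders import Hiera  # path assumed from file_path

model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = model(x)
for f in feats:
    # Expected to run from roughly (1, 96, 56, 56) down to (1, 768, 7, 7)
    print(f.shape)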
{
"content": "import copy",
"chunk_type": "import",
"name": "copy",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\memory_attention.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_copy_e7340124"
},
{
"content": "from typing import Optional",
"chunk_type": "import",
"name": "Optional",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\memory_attention.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 27,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Optional_9b86dd1e"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\memory_attention.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_9d64bdea"
},
{
"content": "from torch import nn",
"chunk_type": "import",
"name": "nn",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\memory_attention.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_nn_890e43f1"
},
{
"content": "from .blocks import RoPEAttention",
"chunk_type": "import",
"name": "RoPEAttention",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\memory_attention.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_RoPEAttention_4b80ab1c"
},
{
"content": "class MemoryAttentionLayer(nn.Module):\n \"\"\"\n Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.\n\n This class combines self-attention, cross-attention, and feedforward components to process input tensors and\n generate memory-based attention outputs.\n\n Attributes:\n d_model (int): Dimensionality of the model.\n dim_feedforward (int): Dimensionality of the feedforward network.\n dropout_value (float): Dropout rate for regularization.\n self_attn (RoPEAttention): Self-attention mechanism using RoPE (Rotary Position Embedding).\n cross_attn_image (RoPEAttention): Cross-attention mechanism for image processing.\n linear1 (nn.Linear): First linear layer of the feedforward network.\n linear2 (nn.Linear): Second linear layer of the feedforward network.\n norm1 (nn.LayerNorm): Layer normalization for self-attention output.\n norm2 (nn.LayerNorm): Layer normalization for cross-attention output.\n norm3 (nn.LayerNorm): Layer normalization for feedforward network output.\n dropout1 (nn.Dropout): Dropout layer after self-attention.\n dropout2 (nn.Dropout): Dropout layer after cross-attention.\n dropout3 (nn.Dropout): Dropout layer after feedforward network.\n activation (nn.ReLU): Activation function for the feedforward network.\n pos_enc_at_attn (bool): Flag to add positional encoding at attention.\n pos_enc_at_cross_attn_queries (bool): Flag to add positional encoding to cross-attention queries.\n pos_enc_at_cross_attn_keys (bool): Flag to add positional encoding to cross-attention keys.\n\n Methods:\n forward: Performs the full memory attention operation on input tensors.\n _forward_sa: Performs self-attention on input tensor.\n _forward_ca: Performs cross-attention between target and memory tensors.\n\n Examples:\n >>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)\n >>> tgt = torch.randn(1, 100, 256)\n >>> memory = torch.randn(1, 100, 64)\n >>> pos = torch.randn(1, 100, 256)\n >>> query_pos = torch.randn(1, 100, 256)\n >>> output = layer(tgt, memory, pos, query_pos)\n >>> print(output.shape)\n torch.Size([1, 100, 256])\n \"\"\"\n\n def __init__(\n self,\n d_model: int = 256,\n dim_feedforward: int = 2048,\n dropout: float = 0.1,\n pos_enc_at_attn: bool = False,\n pos_enc_at_cross_attn_keys: bool = True,\n pos_enc_at_cross_attn_queries: bool = False,\n ):\n \"\"\"\n Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.\n\n Args:\n d_model (int): Dimensionality of the model.\n dim_feedforward (int): Dimensionality of the feedforward network.\n dropout (float): Dropout rate for regularization.\n pos_enc_at_attn (bool): Whether to add positional encoding at attention.\n pos_enc_at_cross_attn_keys (bool): Whether to add positional encoding to cross-attention keys.\n pos_enc_at_cross_attn_queries (bool): Whether to add positional encoding to cross-attention queries.\n \"\"\"\n super().__init__()\n self.d_model = d_model\n self.dim_feedforward = dim_feedforward\n self.dropout_value = dropout\n self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)\n self.cross_attn_image = RoPEAttention(\n rope_k_repeat=True,\n embedding_dim=256,\n num_heads=1,\n downsample_rate=1,\n kv_in_dim=64,\n )\n\n # Implementation of Feedforward model\n self.linear1 = nn.Linear(d_model, dim_feedforward)\n self.dropout = nn.Dropout(dropout)\n self.linear2 = nn.Linear(dim_feedforward, d_model)\n\n self.norm1 = nn.LayerNorm(d_model)\n self.norm2 = 
nn.LayerNorm(d_model)\n self.norm3 = nn.LayerNorm(d_model)\n self.dropout1 = nn.Dropout(dropout)\n self.dropout2 = nn.Dropout(dropout)\n self.dropout3 = nn.Dropout(dropout)\n\n self.activation = nn.ReLU()\n\n # Where to add pos enc\n self.pos_enc_at_attn = pos_enc_at_attn\n self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries\n self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys\n\n def _forward_sa(self, tgt: torch.Tensor, query_pos: Optional[torch.Tensor]) -> torch.Tensor:\n \"\"\"Perform self-attention on input tensor using positional encoding and RoPE attention mechanism.\"\"\"\n tgt2 = self.norm1(tgt)\n q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2\n tgt2 = self.self_attn(q, k, v=tgt2)\n tgt = tgt + self.dropout1(tgt2)\n return tgt\n\n def _forward_ca(\n self,\n tgt: torch.Tensor,\n memory: torch.Tensor,\n query_pos: Optional[torch.Tensor],\n pos: Optional[torch.Tensor],\n num_k_exclude_rope: int = 0,\n ) -> torch.Tensor:\n \"\"\"Perform cross-attention between target and memory tensors using RoPEAttention mechanism.\"\"\"\n kwds = {}\n if num_k_exclude_rope > 0:\n assert isinstance(self.cross_attn_image, RoPEAttention)\n kwds = {\"num_k_exclude_rope\": num_k_exclude_rope}\n\n # Cross-Attention\n tgt2 = self.norm2(tgt)\n tgt2 = self.cross_attn_image(\n q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,\n k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,\n v=memory,\n **kwds,\n )\n tgt = tgt + self.dropout2(tgt2)\n return tgt\n\n def forward(\n self,\n tgt: torch.Tensor,\n memory: torch.Tensor,\n pos: Optional[torch.Tensor] = None,\n query_pos: Optional[torch.Tensor] = None,\n num_k_exclude_rope: int = 0,\n ) -> torch.Tensor:\n \"\"\"\n Process input tensors through self-attention, cross-attention, and feedforward network layers.\n\n Args:\n tgt (torch.Tensor): Target tensor for self-attention with shape (N, L, D).\n memory (torch.Tensor): Memory tensor for cross-attention with shape (N, S, D).\n pos (Optional[torch.Tensor]): Positional encoding for memory tensor.\n query_pos (Optional[torch.Tensor]): Positional encoding for target tensor.\n num_k_exclude_rope (int): Number of keys to exclude from rotary position embedding.\n\n Returns:\n (torch.Tensor): Processed tensor after attention and feedforward layers with shape (N, L, D).\n \"\"\"\n tgt = self._forward_sa(tgt, query_pos)\n tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)\n # MLP\n tgt2 = self.norm3(tgt)\n tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))\n tgt = tgt + self.dropout3(tgt2)\n return tgt",
"chunk_type": "class",
"name": "MemoryAttentionLayer",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\memory_attention.py",
"start_line": 12,
"end_line": 166,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": "Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.\n\nThis class combines self-attention, cross-attention, and feedforward components to process input tensors and\ngenerate memory-based attention outputs.\n\nAttributes:\n d_model (int): Dimensionality of the model.\n dim_feedforward (int): Dimensionality of the feedforward network.\n dropout_value (float): Dropout rate for regularization.\n self_attn (RoPEAttention): Self-attention mechanism using RoPE (Rotary Position Embedding).\n cross_attn_image (RoPEAttention): Cross-attention mechanism for image processing.\n linear1 (nn.Linear): First linear layer of the feedforward network.\n linear2 (nn.Linear): Second linear layer of the feedforward network.\n norm1 (nn.LayerNorm): Layer normalization for self-attention output.\n norm2 (nn.LayerNorm): Layer normalization for cross-attention output.\n norm3 (nn.LayerNorm): Layer normalization for feedforward network output.\n dropout1 (nn.Dropout): Dropout layer after self-attention.\n dropout2 (nn.Dropout): Dropout layer after cross-attention.\n dropout3 (nn.Dropout): Dropout layer after feedforward network.\n activation (nn.ReLU): Activation function for the feedforward network.\n pos_enc_at_attn (bool): Flag to add positional encoding at attention.\n pos_enc_at_cross_attn_queries (bool): Flag to add positional encoding to cross-attention queries.\n pos_enc_at_cross_attn_keys (bool): Flag to add positional encoding to cross-attention keys.\n\nMethods:\n forward: Performs the full memory attention operation on input tensors.\n _forward_sa: Performs self-attention on input tensor.\n _forward_ca: Performs cross-attention between target and memory tensors.\n\nExamples:\n >>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)\n >>> tgt = torch.randn(1, 100, 256)\n >>> memory = torch.randn(1, 100, 64)\n >>> pos = torch.randn(1, 100, 256)\n >>> query_pos = torch.randn(1, 100, 256)\n >>> output = layer(tgt, memory, pos, query_pos)\n >>> print(output.shape)\n torch.Size([1, 100, 256])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"typing.Optional",
"torch",
"torch.nn",
"blocks.RoPEAttention",
"nn.Module"
],
"chunk_id": "class_MemoryAttentionLayer_9da47992"
},
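A usage sketch for MemoryAttentionLayer, adapted from the docstring example above (import path inferred from the recorded file_path). Two shape constraints follow from the hard-coded RoPEAttention modules: token counts should form a square 2D grid (100 = 10 x 10 works), and because pos_enc_at_cross_attn_keys defaults to True, the memory positional encoding is added to the 64-channel memory tokens before projection, so it is given that width here rather than d_model.

import torch
from ultralytics.models.sam.modules.memory_attention import MemoryAttentionLayer

layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)
tgt = torch.randn(1, 100, 256)        # current tokens (B, L, d_model)
memory = torch.randn(1, 100, 64)      # memory tokens (B, S, kv_in_dim=64)
pos = torch.randn(1, 100, 64)         # positional encoding added to the memory keys
query_pos = torch.randn(1, 100, 256)  # only used if pos enc at queries/attention is enabled
out = layer(tgt, memory, pos=pos, query_pos=query_pos)
print(out.shape)  # torch.Size([1, 100, 256])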
{
"content": "class MemoryAttention(nn.Module):\n \"\"\"\n Memory attention module for processing sequential data with self and cross-attention mechanisms.\n\n This class implements a multi-layer attention mechanism that combines self-attention and cross-attention\n for processing sequential data, particularly useful in transformer-like architectures.\n\n Attributes:\n d_model (int): The dimension of the model's hidden state.\n layers (nn.ModuleList): A list of MemoryAttentionLayer modules.\n num_layers (int): The number of attention layers.\n norm (nn.LayerNorm): Layer normalization applied to the output.\n pos_enc_at_input (bool): Whether to apply positional encoding at the input.\n batch_first (bool): Whether the input tensors are in batch-first format.\n\n Methods:\n forward: Processes input tensors through the attention layers.\n\n Examples:\n >>> d_model = 256\n >>> layer = MemoryAttentionLayer(d_model)\n >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)\n >>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)\n >>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)\n >>> curr_pos = torch.randn(10, 32, d_model)\n >>> memory_pos = torch.randn(20, 32, d_model)\n >>> output = attention(curr, memory, curr_pos, memory_pos)\n >>> print(output.shape)\n torch.Size([10, 32, 256])\n \"\"\"\n\n def __init__(\n self,\n d_model: int,\n pos_enc_at_input: bool,\n layer: nn.Module,\n num_layers: int,\n batch_first: bool = True, # Do layers expect batch first input?\n ):\n \"\"\"\n Initialize MemoryAttention with specified layers and normalization for sequential data processing.\n\n This class implements a multi-layer attention mechanism that combines self-attention and cross-attention\n for processing sequential data, particularly useful in transformer-like architectures.\n\n Args:\n d_model (int): The dimension of the model's hidden state.\n pos_enc_at_input (bool): Whether to apply positional encoding at the input.\n layer (nn.Module): The attention layer to be used in the module.\n num_layers (int): The number of attention layers.\n batch_first (bool): Whether the input tensors are in batch-first format.\n\n Examples:\n >>> d_model = 256\n >>> layer = MemoryAttentionLayer(d_model)\n >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)\n >>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)\n >>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)\n >>> curr_pos = torch.randn(10, 32, d_model)\n >>> memory_pos = torch.randn(20, 32, d_model)\n >>> output = attention(curr, memory, curr_pos, memory_pos)\n >>> print(output.shape)\n torch.Size([10, 32, 256])\n \"\"\"\n super().__init__()\n self.d_model = d_model\n self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])\n self.num_layers = num_layers\n self.norm = nn.LayerNorm(d_model)\n self.pos_enc_at_input = pos_enc_at_input\n self.batch_first = batch_first\n\n def forward(\n self,\n curr: torch.Tensor, # self-attention inputs\n memory: torch.Tensor, # cross-attention inputs\n curr_pos: Optional[torch.Tensor] = None, # pos_enc for self-attention inputs\n memory_pos: Optional[torch.Tensor] = None, # pos_enc for cross-attention inputs\n num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*\n ) -> torch.Tensor:\n \"\"\"\n Process inputs through attention layers, applying self and cross-attention with positional encoding.\n\n Args:\n curr (torch.Tensor): Self-attention 
input tensor, representing the current state.\n memory (torch.Tensor): Cross-attention input tensor, representing memory information.\n curr_pos (Optional[torch.Tensor]): Positional encoding for self-attention inputs.\n memory_pos (Optional[torch.Tensor]): Positional encoding for cross-attention inputs.\n num_obj_ptr_tokens (int): Number of object pointer tokens to exclude from rotary position embedding.\n\n Returns:\n (torch.Tensor): Processed output tensor after applying attention layers and normalization.\n\n Examples:\n >>> d_model = 256\n >>> layer = MemoryAttentionLayer(d_model)\n >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)\n >>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)\n >>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)\n >>> curr_pos = torch.randn(10, 32, d_model)\n >>> memory_pos = torch.randn(20, 32, d_model)\n >>> output = attention(curr, memory, curr_pos, memory_pos)\n >>> print(output.shape)\n torch.Size([10, 32, 256])\n \"\"\"\n if isinstance(curr, list):\n assert isinstance(curr_pos, list)\n assert len(curr) == len(curr_pos) == 1\n curr, curr_pos = curr[0], curr_pos[0]\n\n assert curr.shape[1] == memory.shape[1], \"Batch size must be the same for curr and memory\"\n\n output = curr\n if self.pos_enc_at_input and curr_pos is not None:\n output = output + 0.1 * curr_pos\n\n if self.batch_first:\n # Convert to batch first\n output = output.transpose(0, 1)\n curr_pos = curr_pos.transpose(0, 1)\n memory = memory.transpose(0, 1)\n memory_pos = memory_pos.transpose(0, 1)\n\n for layer in self.layers:\n kwds = {}\n if isinstance(layer.cross_attn_image, RoPEAttention):\n kwds = {\"num_k_exclude_rope\": num_obj_ptr_tokens}\n\n output = layer(\n tgt=output,\n memory=memory,\n pos=memory_pos,\n query_pos=curr_pos,\n **kwds,\n )\n normed_output = self.norm(output)\n\n if self.batch_first:\n # Convert back to seq first\n normed_output = normed_output.transpose(0, 1)\n curr_pos = curr_pos.transpose(0, 1)\n\n return normed_output",
"chunk_type": "class",
"name": "MemoryAttention",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\memory_attention.py",
"start_line": 169,
"end_line": 311,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": "Memory attention module for processing sequential data with self and cross-attention mechanisms.\n\nThis class implements a multi-layer attention mechanism that combines self-attention and cross-attention\nfor processing sequential data, particularly useful in transformer-like architectures.\n\nAttributes:\n d_model (int): The dimension of the model's hidden state.\n layers (nn.ModuleList): A list of MemoryAttentionLayer modules.\n num_layers (int): The number of attention layers.\n norm (nn.LayerNorm): Layer normalization applied to the output.\n pos_enc_at_input (bool): Whether to apply positional encoding at the input.\n batch_first (bool): Whether the input tensors are in batch-first format.\n\nMethods:\n forward: Processes input tensors through the attention layers.\n\nExamples:\n >>> d_model = 256\n >>> layer = MemoryAttentionLayer(d_model)\n >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)\n >>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)\n >>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)\n >>> curr_pos = torch.randn(10, 32, d_model)\n >>> memory_pos = torch.randn(20, 32, d_model)\n >>> output = attention(curr, memory, curr_pos, memory_pos)\n >>> print(output.shape)\n torch.Size([10, 32, 256])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"typing.Optional",
"torch",
"torch.nn",
"blocks.RoPEAttention",
"nn.Module"
],
"chunk_id": "class_MemoryAttention_82ab8fdf"
},
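A usage sketch for MemoryAttention (import path inferred from the recorded file_path). As with the single layer, the shapes here deviate slightly from the docstring example so the hard-coded RoPE cross-attention is satisfied: memory tokens are 64-dimensional, the current sequence length is a perfect square, and the memory length is an integer multiple of it.

import torch
from ultralytics.models.sam.modules.memory_attention import MemoryAttention, MemoryAttentionLayer

d_model = 256
attention = MemoryAttention(
    d_model,
    pos_enc_at_input=True,
    layer=MemoryAttentionLayer(d_model=d_model),
    num_layers=2,
)
curr = torch.randn(16, 2, d_model)    # (seq_len, batch, d_model), sequence-first
curr_pos = torch.randn(16, 2, d_model)
memory = torch.randn(32, 2, 64)       # (mem_len, batch, mem_dim=64), mem_len a multiple of 16
memory_pos = torch.randn(32, 2, 64)
out = attention(curr, memory, curr_pos=curr_pos, memory_pos=memory_pos)
print(out.shape)  # torch.Size([16, 2, 256])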
{
"content": "from typing import List",
"chunk_type": "import",
"name": "List",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 23,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_List_d360fbfb"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_b9f1e54d"
},
{
"content": "import torch.nn.functional as F",
"chunk_type": "import",
"name": "torch.nn.functional",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn.functional_5af2f3e6"
},
{
"content": "from torch import nn",
"chunk_type": "import",
"name": "nn",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_nn_fa6e5c75"
},
{
"content": "from torch.nn.init import trunc_normal_",
"chunk_type": "import",
"name": "trunc_normal_",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 39,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_trunc_normal__e8ba48f7"
},
{
"content": "from ultralytics.nn.modules import MLP",
"chunk_type": "import",
"name": "MLP",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py",
"start_line": 16,
"end_line": 16,
"start_col": 0,
"end_col": 38,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_MLP_7c9a62a8"
},
{
"content": "from ultralytics.utils import LOGGER",
"chunk_type": "import",
"name": "LOGGER",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py",
"start_line": 17,
"end_line": 17,
"start_col": 0,
"end_col": 36,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER_fde5fd72"
},
{
"content": "from .blocks import SAM2TwoWayTransformer",
"chunk_type": "import",
"name": "SAM2TwoWayTransformer",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py",
"start_line": 19,
"end_line": 19,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SAM2TwoWayTransformer_ee3bfbb6"
},
{
"content": "from .decoders import MaskDecoder, SAM2MaskDecoder",
"chunk_type": "import",
"name": "MaskDecoder, SAM2MaskDecoder",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py",
"start_line": 20,
"end_line": 20,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_MaskDecoder, SAM2MaskDecoder_753b2615"
},
{
"content": "from .encoders import ImageEncoderViT, PromptEncoder",
"chunk_type": "import",
"name": "ImageEncoderViT, PromptEncoder",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py",
"start_line": 21,
"end_line": 21,
"start_col": 0,
"end_col": 52,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ImageEncoderViT, PromptEncoder_47cfddc7"
},
{
"content": "from .utils import get_1d_sine_pe, select_closest_cond_frames",
"chunk_type": "import",
"name": "get_1d_sine_pe, select_closest_cond_frames",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py",
"start_line": 22,
"end_line": 22,
"start_col": 0,
"end_col": 61,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_get_1d_sine_pe, select_closest_cond_frames_b5ce41f7"
},
{
"content": "NO_OBJ_SCORE = -1024.0",
"chunk_type": "variable",
"name": "NO_OBJ_SCORE",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py",
"start_line": 25,
"end_line": 25,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_NO_OBJ_SCORE_03a1cbc0"
},
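NO_OBJ_SCORE is a large negative placeholder logit that SAM2Model substitutes for mask logits on frames where no object is present, via torch.where. A tiny standalone illustration (tensor names are hypothetical):

import torch

NO_OBJ_SCORE = -1024.0
mask_logits = torch.randn(2, 3, 8, 8)           # hypothetical (batch, masks, H, W) logits
is_obj_appearing = torch.tensor([True, False])  # hypothetical per-sample object flags
suppressed = torch.where(is_obj_appearing[:, None, None, None], mask_logits, NO_OBJ_SCORE)
print(suppressed[1].sigmoid().max())            # ~0, so suppressed masks never pass a threshold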
{
"content": "class SAMModel(nn.Module):\n \"\"\"\n Segment Anything Model (SAM) for object segmentation tasks.\n\n This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images\n and input prompts.\n\n Attributes:\n mask_threshold (float): Threshold value for mask prediction.\n image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.\n prompt_encoder (PromptEncoder): Encoder for various types of input prompts.\n mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.\n pixel_mean (torch.Tensor): Mean values for normalizing pixels in the input image.\n pixel_std (torch.Tensor): Standard deviation values for normalizing pixels in the input image.\n\n Methods:\n set_imgsz: Set image size to make model compatible with different image sizes.\n\n Examples:\n >>> image_encoder = ImageEncoderViT(...)\n >>> prompt_encoder = PromptEncoder(...)\n >>> mask_decoder = MaskDecoder(...)\n >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)\n >>> # Further usage depends on SAMPredictor class\n\n Notes:\n All forward() operations are implemented in the SAMPredictor class.\n \"\"\"\n\n mask_threshold: float = 0.0\n\n def __init__(\n self,\n image_encoder: ImageEncoderViT,\n prompt_encoder: PromptEncoder,\n mask_decoder: MaskDecoder,\n pixel_mean: List[float] = (123.675, 116.28, 103.53),\n pixel_std: List[float] = (58.395, 57.12, 57.375),\n ) -> None:\n \"\"\"\n Initialize the SAMModel class to predict object masks from an image and input prompts.\n\n Args:\n image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.\n prompt_encoder (PromptEncoder): Encodes various types of input prompts.\n mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.\n pixel_mean (List[float]): Mean values for normalizing pixels in the input image.\n pixel_std (List[float]): Standard deviation values for normalizing pixels in the input image.\n\n Examples:\n >>> image_encoder = ImageEncoderViT(...)\n >>> prompt_encoder = PromptEncoder(...)\n >>> mask_decoder = MaskDecoder(...)\n >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)\n >>> # Further usage depends on SAMPredictor class\n\n Notes:\n All forward() operations moved to SAMPredictor.\n \"\"\"\n super().__init__()\n self.image_encoder = image_encoder\n self.prompt_encoder = prompt_encoder\n self.mask_decoder = mask_decoder\n self.register_buffer(\"pixel_mean\", torch.Tensor(pixel_mean).view(-1, 1, 1), False)\n self.register_buffer(\"pixel_std\", torch.Tensor(pixel_std).view(-1, 1, 1), False)\n\n def set_imgsz(self, imgsz):\n \"\"\"Set image size to make model compatible with different image sizes.\"\"\"\n if hasattr(self.image_encoder, \"set_imgsz\"):\n self.image_encoder.set_imgsz(imgsz)\n self.prompt_encoder.input_image_size = imgsz\n self.prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # 16 is fixed as patch size of ViT model\n self.image_encoder.img_size = imgsz[0]",
"chunk_type": "class",
"name": "SAMModel",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py",
"start_line": 28,
"end_line": 100,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": "Segment Anything Model (SAM) for object segmentation tasks.\n\nThis class combines image encoders, prompt encoders, and mask decoders to predict object masks from images\nand input prompts.\n\nAttributes:\n mask_threshold (float): Threshold value for mask prediction.\n image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.\n prompt_encoder (PromptEncoder): Encoder for various types of input prompts.\n mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.\n pixel_mean (torch.Tensor): Mean values for normalizing pixels in the input image.\n pixel_std (torch.Tensor): Standard deviation values for normalizing pixels in the input image.\n\nMethods:\n set_imgsz: Set image size to make model compatible with different image sizes.\n\nExamples:\n >>> image_encoder = ImageEncoderViT(...)\n >>> prompt_encoder = PromptEncoder(...)\n >>> mask_decoder = MaskDecoder(...)\n >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)\n >>> # Further usage depends on SAMPredictor class\n\nNotes:\n All forward() operations are implemented in the SAMPredictor class.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"torch",
"torch.nn.functional",
"torch.nn",
"torch.nn.init.trunc_normal_",
"ultralytics.nn.modules.MLP",
"ultralytics.utils.LOGGER",
"blocks.SAM2TwoWayTransformer",
"decoders.MaskDecoder",
"decoders.SAM2MaskDecoder",
"encoders.ImageEncoderViT",
"encoders.PromptEncoder",
"utils.get_1d_sine_pe",
"utils.select_closest_cond_frames",
"nn.Module"
],
"chunk_id": "class_SAMModel_74592926"
},
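The SAMModel docstring example above elides the encoder/decoder constructor arguments, so rather than guess them, here is a small standalone sketch of the normalization implied by the registered pixel_mean / pixel_std buffers and of the image_embedding_size arithmetic in set_imgsz (the division by 16 reflects the fixed ViT patch size noted in the source):

import torch

pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
image = torch.randint(0, 256, (3, 1024, 1024), dtype=torch.float32)  # hypothetical RGB input
normalized = (image - pixel_mean) / pixel_std                        # what the buffers are for
imgsz = (1024, 1024)
image_embedding_size = [x // 16 for x in imgsz]                      # as computed in set_imgsz
print(normalized.shape, image_embedding_size)                        # torch.Size([3, 1024, 1024]) [64, 64]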
{
"content": "class SAM2Model(torch.nn.Module):\n \"\"\"\n SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.\n\n This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms\n for temporal consistency and efficient tracking of objects across frames.\n\n Attributes:\n mask_threshold (float): Threshold value for mask prediction.\n image_encoder (ImageEncoderViT): Visual encoder for extracting image features.\n memory_attention (nn.Module): Module for attending to memory features.\n memory_encoder (nn.Module): Encoder for generating memory representations.\n num_maskmem (int): Number of accessible memory frames.\n image_size (int): Size of input images.\n backbone_stride (int): Stride of the backbone network output.\n sam_prompt_embed_dim (int): Dimension of SAM prompt embeddings.\n sam_image_embedding_size (int): Size of SAM image embeddings.\n sam_prompt_encoder (PromptEncoder): Encoder for processing input prompts.\n sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks.\n obj_ptr_proj (nn.Module): Projection layer for object pointers.\n obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers.\n hidden_dim (int): Hidden dimension of the model.\n mem_dim (int): Memory dimension for encoding features.\n use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.\n use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.\n max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder cross-attention.\n add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers.\n proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional\n encoding in object pointers.\n use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in temporal positional encoding.\n only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during\n evaluation.\n pred_obj_scores (bool): Whether to predict if there is an object in the frame.\n pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.\n fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.\n soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.\n use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.\n no_obj_embed_spatial (torch.Tensor | None): No-object embedding for spatial frames.\n max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.\n directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the\n first frame.\n multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial\n conditioning frames.\n multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.\n multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.\n multimask_output_for_tracking (bool): Whether to use multimask output for tracking.\n use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.\n iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].\n memory_temporal_stride_for_eval (int): Memory 
bank's temporal stride during evaluation.\n non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in\n memory encoder during evaluation.\n sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.\n sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.\n binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames\n with clicks during evaluation.\n use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM\n prompt encoder and mask decoder on frames with mask input.\n\n Methods:\n forward_image: Process image batch through encoder to extract multi-level features.\n track_step: Perform a single tracking step, updating object masks and memory features.\n set_binarize: Set binarize for VideoPredictor.\n set_imgsz: Set image size to make model compatible with different image sizes.\n\n Examples:\n >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)\n >>> image_batch = torch.rand(1, 3, 512, 512)\n >>> features = model.forward_image(image_batch)\n >>> track_results = model.track_step(0, True, features, None, None, None, {})\n \"\"\"\n\n mask_threshold: float = 0.0\n\n def __init__(\n self,\n image_encoder,\n memory_attention,\n memory_encoder,\n num_maskmem=7,\n image_size=512,\n backbone_stride=16,\n sigmoid_scale_for_mem_enc=1.0,\n sigmoid_bias_for_mem_enc=0.0,\n binarize_mask_from_pts_for_mem_enc=False,\n use_mask_input_as_output_without_sam=False,\n max_cond_frames_in_attn=-1,\n directly_add_no_mem_embed=False,\n use_high_res_features_in_sam=False,\n multimask_output_in_sam=False,\n multimask_min_pt_num=1,\n multimask_max_pt_num=1,\n multimask_output_for_tracking=False,\n use_multimask_token_for_obj_ptr: bool = False,\n iou_prediction_use_sigmoid=False,\n memory_temporal_stride_for_eval=1,\n non_overlap_masks_for_mem_enc=False,\n use_obj_ptrs_in_encoder=False,\n max_obj_ptrs_in_encoder=16,\n add_tpos_enc_to_obj_ptrs=True,\n proj_tpos_enc_in_obj_ptrs=False,\n use_signed_tpos_enc_to_obj_ptrs=False,\n only_obj_ptrs_in_the_past_for_eval=False,\n pred_obj_scores: bool = False,\n pred_obj_scores_mlp: bool = False,\n fixed_no_obj_ptr: bool = False,\n soft_no_obj_ptr: bool = False,\n use_mlp_for_obj_ptr_proj: bool = False,\n no_obj_embed_spatial: bool = False,\n sam_mask_decoder_extra_args=None,\n compile_image_encoder: bool = False,\n ):\n \"\"\"\n Initialize the SAM2Model for video object segmentation with memory-based tracking.\n\n Args:\n image_encoder (nn.Module): Visual encoder for extracting image features.\n memory_attention (nn.Module): Module for attending to memory features.\n memory_encoder (nn.Module): Encoder for generating memory representations.\n num_maskmem (int): Number of accessible memory frames.\n image_size (int): Size of input images.\n backbone_stride (int): Stride of the image backbone output.\n sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.\n sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.\n binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames\n with clicks during evaluation.\n use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM\n prompt encoder and mask decoder on frames with mask input.\n max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.\n directly_add_no_mem_embed 
(bool): Whether to directly add no-memory embedding to image feature on the\n first frame.\n use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.\n multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial\n conditioning frames.\n multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.\n multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.\n multimask_output_for_tracking (bool): Whether to use multimask output for tracking.\n use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.\n iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].\n memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.\n non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in\n memory encoder during evaluation.\n use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.\n max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder\n cross-attention.\n add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in\n the encoder.\n proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional\n encoding in object pointers.\n use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in the temporal positional encoding\n in the object pointers.\n only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past\n during evaluation.\n pred_obj_scores (bool): Whether to predict if there is an object in the frame.\n pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.\n fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.\n soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.\n use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.\n no_obj_embed_spatial (bool): Whether add no obj embedding to spatial frames.\n sam_mask_decoder_extra_args (dict | None): Extra arguments for constructing the SAM mask decoder.\n compile_image_encoder (bool): Whether to compile the image encoder for faster inference.\n\n Examples:\n >>> image_encoder = ImageEncoderViT(...)\n >>> memory_attention = SAM2TwoWayTransformer(...)\n >>> memory_encoder = nn.Sequential(...)\n >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)\n >>> image_batch = torch.rand(1, 3, 512, 512)\n >>> features = model.forward_image(image_batch)\n >>> track_results = model.track_step(0, True, features, None, None, None, {})\n \"\"\"\n super().__init__()\n\n # Part 1: the image backbone\n self.image_encoder = image_encoder\n # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting\n self.use_high_res_features_in_sam = use_high_res_features_in_sam\n self.num_feature_levels = 3 if use_high_res_features_in_sam else 1\n self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder\n self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder\n if use_obj_ptrs_in_encoder:\n # A conv layer to downsample the mask prompt to stride 4 (the same stride as\n # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,\n # so that it can be fed into the SAM mask decoder to 
generate a pointer.\n self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)\n self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs\n if proj_tpos_enc_in_obj_ptrs:\n assert add_tpos_enc_to_obj_ptrs # these options need to be used together\n self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs\n self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs\n self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval\n\n # Part 2: memory attention to condition current frame's visual features\n # with memories (and obj ptrs) from past frames\n self.memory_attention = memory_attention\n self.hidden_dim = memory_attention.d_model\n\n # Part 3: memory encoder for the previous frame's outputs\n self.memory_encoder = memory_encoder\n self.mem_dim = self.hidden_dim\n if hasattr(self.memory_encoder, \"out_proj\") and hasattr(self.memory_encoder.out_proj, \"weight\"):\n # if there is compression of memories along channel dim\n self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]\n self.num_maskmem = num_maskmem # Number of memories accessible\n # Temporal encoding of the memories\n self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))\n trunc_normal_(self.maskmem_tpos_enc, std=0.02)\n # a single token to indicate no memory embedding from previous frames\n self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))\n self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))\n trunc_normal_(self.no_mem_embed, std=0.02)\n trunc_normal_(self.no_mem_pos_enc, std=0.02)\n self.directly_add_no_mem_embed = directly_add_no_mem_embed\n # Apply sigmoid to the output raw mask logits (to turn them from\n # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder\n self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc\n self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc\n self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc\n self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc\n self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval\n # On frames with mask input, whether to directly output the input mask without\n # using a SAM prompt encoder + mask decoder\n self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam\n self.multimask_output_in_sam = multimask_output_in_sam\n self.multimask_min_pt_num = multimask_min_pt_num\n self.multimask_max_pt_num = multimask_max_pt_num\n self.multimask_output_for_tracking = multimask_output_for_tracking\n self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr\n self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid\n\n # Part 4: SAM-style prompt encoder (for both mask and point inputs)\n # and SAM-style mask decoder for the final mask output\n self.image_size = image_size\n self.backbone_stride = backbone_stride\n self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args\n self.pred_obj_scores = pred_obj_scores\n self.pred_obj_scores_mlp = pred_obj_scores_mlp\n self.fixed_no_obj_ptr = fixed_no_obj_ptr\n self.soft_no_obj_ptr = soft_no_obj_ptr\n if self.fixed_no_obj_ptr:\n assert self.pred_obj_scores\n assert self.use_obj_ptrs_in_encoder\n if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:\n self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))\n trunc_normal_(self.no_obj_ptr, std=0.02)\n self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj\n self.no_obj_embed_spatial = None\n if 
no_obj_embed_spatial:\n self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))\n trunc_normal_(self.no_obj_embed_spatial, std=0.02)\n\n self._build_sam_heads()\n self.max_cond_frames_in_attn = max_cond_frames_in_attn\n\n # Model compilation\n if compile_image_encoder:\n # Compile the forward function (not the full module) to allow loading checkpoints.\n LOGGER.info(\"Image encoder compilation is enabled. First forward pass will be slow.\")\n self.image_encoder.forward = torch.compile(\n self.image_encoder.forward,\n mode=\"max-autotune\",\n fullgraph=True,\n dynamic=False,\n )\n\n @property\n def device(self):\n \"\"\"Return the device on which the model's parameters are stored.\"\"\"\n return next(self.parameters()).device\n\n def forward(self, *args, **kwargs):\n \"\"\"Process image and prompt inputs to generate object masks and scores in video sequences.\"\"\"\n raise NotImplementedError(\n \"Please use the corresponding methods in SAM2VideoPredictor for inference.\"\n \"See notebooks/video_predictor_example.ipynb for an example.\"\n )\n\n def _build_sam_heads(self):\n \"\"\"Build SAM-style prompt encoder and mask decoder for image segmentation tasks.\"\"\"\n self.sam_prompt_embed_dim = self.hidden_dim\n self.sam_image_embedding_size = self.image_size // self.backbone_stride\n\n # Build PromptEncoder and MaskDecoder from SAM (hyperparameters like `mask_in_chans=16` are from SAM code)\n self.sam_prompt_encoder = PromptEncoder(\n embed_dim=self.sam_prompt_embed_dim,\n image_embedding_size=(\n self.sam_image_embedding_size,\n self.sam_image_embedding_size,\n ),\n input_image_size=(self.image_size, self.image_size),\n mask_in_chans=16,\n )\n self.sam_mask_decoder = SAM2MaskDecoder(\n num_multimask_outputs=3,\n transformer=SAM2TwoWayTransformer(\n depth=2,\n embedding_dim=self.sam_prompt_embed_dim,\n mlp_dim=2048,\n num_heads=8,\n ),\n transformer_dim=self.sam_prompt_embed_dim,\n iou_head_depth=3,\n iou_head_hidden_dim=256,\n use_high_res_features=self.use_high_res_features_in_sam,\n iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,\n pred_obj_scores=self.pred_obj_scores,\n pred_obj_scores_mlp=self.pred_obj_scores_mlp,\n use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,\n **(self.sam_mask_decoder_extra_args or {}),\n )\n if self.use_obj_ptrs_in_encoder:\n # a linear projection on SAM output tokens to turn them into object pointers\n self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)\n if self.use_mlp_for_obj_ptr_proj:\n self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)\n else:\n self.obj_ptr_proj = torch.nn.Identity()\n if self.proj_tpos_enc_in_obj_ptrs:\n # a linear projection on temporal positional encoding in object pointers to\n # avoid potential interference with spatial positional encoding\n self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)\n else:\n self.obj_ptr_tpos_proj = torch.nn.Identity()\n\n def _forward_sam_heads(\n self,\n backbone_features,\n point_inputs=None,\n mask_inputs=None,\n high_res_features=None,\n multimask_output=False,\n ):\n \"\"\"\n Forward pass through SAM prompt encoders and mask heads.\n\n This method processes image features and optional point/mask inputs to generate object masks and scores.\n\n Args:\n backbone_features (torch.Tensor): Image features with shape (B, C, H, W).\n point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts.\n 'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing 
absolute\n pixel-unit coordinates in (x, y) format for P input points.\n 'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,\n 0 means negative clicks, and -1 means padding.\n mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the\n same spatial size as the image.\n high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes\n (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps\n for SAM decoder.\n multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,\n output only 1 mask and its IoU estimate.\n\n Returns:\n low_res_multimasks (torch.Tensor): Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.\n high_res_multimasks (torch.Tensor): Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.\n ious (torch.Tensor): Tensor of shape (B, M) with estimated IoU for each output mask.\n low_res_masks (torch.Tensor): Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask.\n high_res_masks (torch.Tensor): Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.\n obj_ptr (torch.Tensor): Tensor of shape (B, C) with object pointer vector for the output mask.\n object_score_logits (torch.Tensor): Tensor of shape (B) with object score logits.\n\n Examples:\n >>> backbone_features = torch.rand(1, 256, 32, 32)\n >>> point_inputs = {\"point_coords\": torch.rand(1, 2, 2), \"point_labels\": torch.tensor([[1, 0]])}\n >>> mask_inputs = torch.rand(1, 1, 512, 512)\n >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)\n >>> (\n ... low_res_multimasks,\n ... high_res_multimasks,\n ... ious,\n ... low_res_masks,\n ... high_res_masks,\n ... obj_ptr,\n ... object_score_logits,\n ... 
) = results\n \"\"\"\n B = backbone_features.size(0)\n device = backbone_features.device\n assert backbone_features.size(1) == self.sam_prompt_embed_dim\n assert backbone_features.size(2) == self.sam_image_embedding_size\n assert backbone_features.size(3) == self.sam_image_embedding_size\n\n # a) Handle point prompts\n if point_inputs is not None:\n sam_point_coords = point_inputs[\"point_coords\"]\n sam_point_labels = point_inputs[\"point_labels\"]\n assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B\n else:\n # If no points are provide, pad with an empty point (with label -1)\n sam_point_coords = torch.zeros(B, 1, 2, device=device)\n sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)\n\n # b) Handle mask prompts\n if mask_inputs is not None:\n # If mask_inputs is provided, downsize it into low-res mask input if needed\n # and feed it as a dense mask prompt into the SAM mask encoder\n assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)\n if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:\n sam_mask_prompt = F.interpolate(\n mask_inputs.float(),\n size=self.sam_prompt_encoder.mask_input_size,\n align_corners=False,\n mode=\"bilinear\",\n antialias=True, # use antialias for downsampling\n )\n else:\n sam_mask_prompt = mask_inputs\n else:\n # Otherwise, simply feed None (and SAM's prompt encoder will add\n # a learned `no_mask_embed` to indicate no mask input in this case).\n sam_mask_prompt = None\n\n sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(\n points=(sam_point_coords, sam_point_labels),\n boxes=None,\n masks=sam_mask_prompt,\n )\n low_res_multimasks, ious, sam_output_tokens, object_score_logits = self.sam_mask_decoder(\n image_embeddings=backbone_features,\n image_pe=self.sam_prompt_encoder.get_dense_pe(),\n sparse_prompt_embeddings=sparse_embeddings,\n dense_prompt_embeddings=dense_embeddings,\n multimask_output=multimask_output,\n repeat_image=False, # the image is already batched\n high_res_features=high_res_features,\n )\n if self.pred_obj_scores:\n is_obj_appearing = object_score_logits > 0\n\n # Spatial memory mask is a *hard* choice between obj and no obj, consistent with actual mask prediction\n low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE)\n\n # convert masks from possibly bfloat16 (or float16) to float32\n # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)\n low_res_multimasks = low_res_multimasks.float()\n high_res_multimasks = F.interpolate(\n low_res_multimasks,\n size=(self.image_size, self.image_size),\n mode=\"bilinear\",\n align_corners=False,\n )\n\n sam_output_token = sam_output_tokens[:, 0]\n if multimask_output:\n # take the best mask prediction (with the highest IoU estimation)\n best_iou_inds = torch.argmax(ious, dim=-1)\n batch_inds = torch.arange(B, device=device)\n low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)\n high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)\n if sam_output_tokens.size(1) > 1:\n sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]\n else:\n low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks\n\n # Extract object pointer from the SAM output token (with occlusion handling)\n obj_ptr = self.obj_ptr_proj(sam_output_token)\n if self.pred_obj_scores:\n # Allow *soft* no obj ptr, unlike for masks\n if self.soft_no_obj_ptr:\n lambda_is_obj_appearing = object_score_logits.sigmoid()\n else:\n 
lambda_is_obj_appearing = is_obj_appearing.float()\n\n if self.fixed_no_obj_ptr:\n obj_ptr = lambda_is_obj_appearing * obj_ptr\n obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr\n\n return (\n low_res_multimasks,\n high_res_multimasks,\n ious,\n low_res_masks,\n high_res_masks,\n obj_ptr,\n object_score_logits,\n )\n\n def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):\n \"\"\"Process mask inputs directly as output, bypassing SAM encoder/decoder.\"\"\"\n # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).\n out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05\n mask_inputs_float = mask_inputs.float()\n high_res_masks = mask_inputs_float * out_scale + out_bias\n low_res_masks = F.interpolate(\n high_res_masks,\n size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),\n align_corners=False,\n mode=\"bilinear\",\n antialias=True, # use antialias for downsampling\n )\n # a dummy IoU prediction of all 1's under mask input\n ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()\n if not self.use_obj_ptrs_in_encoder:\n # all zeros as a dummy object pointer (of shape [B, C])\n obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)\n else:\n # produce an object pointer using the SAM decoder from the mask input\n _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(\n backbone_features=backbone_features,\n mask_inputs=self.mask_downsample(mask_inputs_float),\n high_res_features=high_res_features,\n )\n # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;\n # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying\n # on the object_scores from the SAM decoder.\n is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)\n is_obj_appearing = is_obj_appearing[..., None]\n lambda_is_obj_appearing = is_obj_appearing.float()\n object_score_logits = out_scale * lambda_is_obj_appearing + out_bias\n if self.pred_obj_scores:\n if self.fixed_no_obj_ptr:\n obj_ptr = lambda_is_obj_appearing * obj_ptr\n obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr\n\n return (\n low_res_masks,\n high_res_masks,\n ious,\n low_res_masks,\n high_res_masks,\n obj_ptr,\n object_score_logits,\n )\n\n def forward_image(self, img_batch: torch.Tensor):\n \"\"\"Process image batch through encoder to extract multi-level features for SAM model.\"\"\"\n backbone_out = self.image_encoder(img_batch)\n if self.use_high_res_features_in_sam:\n # precompute projected level 0 and level 1 features in SAM decoder\n # to avoid running it again on every SAM click\n backbone_out[\"backbone_fpn\"][0] = self.sam_mask_decoder.conv_s0(backbone_out[\"backbone_fpn\"][0])\n backbone_out[\"backbone_fpn\"][1] = self.sam_mask_decoder.conv_s1(backbone_out[\"backbone_fpn\"][1])\n return backbone_out\n\n def _prepare_backbone_features(self, backbone_out):\n \"\"\"Prepare and flatten visual features from the image backbone output for further processing.\"\"\"\n assert len(backbone_out[\"backbone_fpn\"]) == len(backbone_out[\"vision_pos_enc\"])\n assert len(backbone_out[\"backbone_fpn\"]) >= self.num_feature_levels\n\n feature_maps = backbone_out[\"backbone_fpn\"][-self.num_feature_levels :]\n vision_pos_embeds = backbone_out[\"vision_pos_enc\"][-self.num_feature_levels :]\n\n feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]\n # flatten NxCxHxW to HWxNxC\n vision_feats = 
[x.flatten(2).permute(2, 0, 1) for x in feature_maps]\n vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]\n\n return backbone_out, vision_feats, vision_pos_embeds, feat_sizes\n\n def _prepare_memory_conditioned_features(\n self,\n frame_idx,\n is_init_cond_frame,\n current_vision_feats,\n current_vision_pos_embeds,\n feat_sizes,\n output_dict,\n num_frames,\n track_in_reverse=False, # tracking in reverse time order (for demo usage)\n ):\n \"\"\"Prepare memory-conditioned features by fusing current frame's visual features with previous memories.\"\"\"\n B = current_vision_feats[-1].size(1) # batch size on this frame\n C = self.hidden_dim\n H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size\n device = current_vision_feats[-1].device\n # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.\n # In this case, we skip the fusion with any memory.\n if self.num_maskmem == 0: # Disable memory and skip fusion\n return current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)\n num_obj_ptr_tokens = 0\n tpos_sign_mul = -1 if track_in_reverse else 1\n # Step 1: condition the visual features of the current frame on previous memories\n if not is_init_cond_frame:\n # Retrieve the memories encoded with the maskmem backbone\n to_cat_memory, to_cat_memory_pos_embed = [], []\n # Add conditioning frame's output first (all cond frames have t_pos=0 for\n # when getting temporal positional embedding below)\n assert len(output_dict[\"cond_frame_outputs\"]) > 0\n # Select a maximum number of temporally closest cond frames for cross attention\n cond_outputs = output_dict[\"cond_frame_outputs\"]\n selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(\n frame_idx, cond_outputs, self.max_cond_frames_in_attn\n )\n t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]\n # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory\n # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1\n # We also allow taking the memory frame non-consecutively (with r>1), in which case\n # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.\n r = 1 if self.training else self.memory_temporal_stride_for_eval\n for t_pos in range(1, self.num_maskmem):\n t_rel = self.num_maskmem - t_pos # how many frames before current frame\n if t_rel == 1:\n # for t_rel == 1, we take the last frame (regardless of r)\n prev_frame_idx = frame_idx + t_rel if track_in_reverse else frame_idx - t_rel\n elif not track_in_reverse:\n # first find the nearest frame among every r-th frames before this frame\n # for r=1, this would be (frame_idx - 2)\n prev_frame_idx = ((frame_idx - 2) // r) * r\n # then seek further among every r-th frames\n prev_frame_idx = prev_frame_idx - (t_rel - 2) * r\n else:\n # first find the nearest frame among every r-th frames after this frame\n # for r=1, this would be (frame_idx + 2)\n prev_frame_idx = -(-(frame_idx + 2) // r) * r\n # then seek further among every r-th frames\n prev_frame_idx = prev_frame_idx + (t_rel - 2) * r\n out = output_dict[\"non_cond_frame_outputs\"].get(prev_frame_idx, None)\n if out is None:\n # If an unselected conditioning frame is among the last (self.num_maskmem - 1)\n # frames, we still attend to it as if it's a non-conditioning frame.\n out = unselected_cond_outputs.get(prev_frame_idx, None)\n t_pos_and_prevs.append((t_pos, out))\n\n for t_pos, prev in t_pos_and_prevs:\n if prev is None:\n continue # 
skip padding frames\n # \"maskmem_features\" might have been offloaded to CPU in demo use cases,\n # so we load it back to inference device (it's a no-op if it's already on device).\n feats = prev[\"maskmem_features\"].to(device=device, non_blocking=True)\n to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))\n # Spatial positional encoding (it might have been offloaded to CPU in eval)\n maskmem_enc = prev[\"maskmem_pos_enc\"][-1].to(device=device)\n maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)\n # Temporal positional encoding\n maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]\n to_cat_memory_pos_embed.append(maskmem_enc)\n\n # Construct the list of past object pointers\n if self.use_obj_ptrs_in_encoder:\n max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)\n # First add those object pointers from selected conditioning frames\n # (optionally, only include object pointers in the past during evaluation)\n if not self.training and self.only_obj_ptrs_in_the_past_for_eval:\n ptr_cond_outputs = {\n t: out\n for t, out in selected_cond_outputs.items()\n if (t >= frame_idx if track_in_reverse else t <= frame_idx)\n }\n else:\n ptr_cond_outputs = selected_cond_outputs\n pos_and_ptrs = [\n # Temporal pos encoding contains how far away each pointer is from current frame\n (\n (\n (frame_idx - t) * tpos_sign_mul\n if self.use_signed_tpos_enc_to_obj_ptrs\n else abs(frame_idx - t)\n ),\n out[\"obj_ptr\"],\n )\n for t, out in ptr_cond_outputs.items()\n ]\n # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame\n for t_diff in range(1, max_obj_ptrs_in_encoder):\n t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff\n if t < 0 or (num_frames is not None and t >= num_frames):\n break\n out = output_dict[\"non_cond_frame_outputs\"].get(t, unselected_cond_outputs.get(t, None))\n if out is not None:\n pos_and_ptrs.append((t_diff, out[\"obj_ptr\"]))\n # If we have at least one object pointer, add them to the across attention\n if pos_and_ptrs:\n pos_list, ptrs_list = zip(*pos_and_ptrs)\n # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape\n obj_ptrs = torch.stack(ptrs_list, dim=0)\n # a temporal positional embedding based on how far each object pointer is from\n # the current frame (sine embedding normalized by the max pointer num).\n if self.add_tpos_enc_to_obj_ptrs:\n t_diff_max = max_obj_ptrs_in_encoder - 1\n tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim\n obj_pos = torch.tensor(pos_list, device=device)\n obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)\n obj_pos = self.obj_ptr_tpos_proj(obj_pos)\n obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)\n else:\n obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)\n if self.mem_dim < C:\n # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C\n obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)\n obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)\n obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)\n to_cat_memory.append(obj_ptrs)\n to_cat_memory_pos_embed.append(obj_pos)\n num_obj_ptr_tokens = obj_ptrs.shape[0]\n else:\n num_obj_ptr_tokens = 0\n else:\n # for initial conditioning frames, encode them without using any previous memory\n if self.directly_add_no_mem_embed:\n # directly add no-mem embedding (instead of using the transformer encoder)\n pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed\n pix_feat_with_mem = 
pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)\n return pix_feat_with_mem\n\n # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)\n to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]\n to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]\n\n # Step 2: Concatenate the memories and forward through the transformer encoder\n memory = torch.cat(to_cat_memory, dim=0)\n memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)\n\n pix_feat_with_mem = self.memory_attention(\n curr=current_vision_feats,\n curr_pos=current_vision_pos_embeds,\n memory=memory,\n memory_pos=memory_pos_embed,\n num_obj_ptr_tokens=num_obj_ptr_tokens,\n )\n # reshape the output (HW)BC => BCHW\n pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)\n return pix_feat_with_mem\n\n def _encode_new_memory(\n self,\n current_vision_feats,\n feat_sizes,\n pred_masks_high_res,\n object_score_logits,\n is_mask_from_pts,\n ):\n \"\"\"Encode frame features and masks into a new memory representation for video segmentation.\"\"\"\n B = current_vision_feats[-1].size(1) # batch size on this frame\n C = self.hidden_dim\n H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size\n # top-level feature, (HW)BC => BCHW\n pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)\n if self.non_overlap_masks_for_mem_enc and not self.training:\n # optionally, apply non-overlapping constraints to the masks (it's applied\n # in the batch dimension and should only be used during eval, where all\n # the objects come from the same video under batch size 1).\n pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)\n # scale the raw mask logits with a temperature before applying sigmoid\n binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts\n if binarize and not self.training:\n mask_for_mem = (pred_masks_high_res > 0).float()\n else:\n # apply sigmoid on the raw mask logits to turn them into range (0, 1)\n mask_for_mem = torch.sigmoid(pred_masks_high_res)\n # apply scale and bias terms to the sigmoid probabilities\n if self.sigmoid_scale_for_mem_enc != 1.0:\n mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc\n if self.sigmoid_bias_for_mem_enc != 0.0:\n mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc\n maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied\n maskmem_features = maskmem_out[\"vision_features\"]\n maskmem_pos_enc = maskmem_out[\"vision_pos_enc\"]\n # add a no-object embedding to the spatial memory to indicate that the frame\n # is predicted to be occluded (i.e. 
no object is appearing in the frame)\n if self.no_obj_embed_spatial is not None:\n is_obj_appearing = (object_score_logits > 0).float()\n maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[\n ..., None, None\n ].expand(*maskmem_features.shape)\n\n return maskmem_features, maskmem_pos_enc\n\n def _track_step(\n self,\n frame_idx,\n is_init_cond_frame,\n current_vision_feats,\n current_vision_pos_embeds,\n feat_sizes,\n point_inputs,\n mask_inputs,\n output_dict,\n num_frames,\n track_in_reverse,\n prev_sam_mask_logits,\n ):\n \"\"\"Perform a single tracking step, updating object masks and memory features based on current frame inputs.\"\"\"\n current_out = {\"point_inputs\": point_inputs, \"mask_inputs\": mask_inputs}\n # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW\n if len(current_vision_feats) > 1:\n high_res_features = [\n x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)\n for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])\n ]\n else:\n high_res_features = None\n if mask_inputs is not None and self.use_mask_input_as_output_without_sam:\n # When use_mask_input_as_output_without_sam=True, we directly output the mask input\n # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.\n pix_feat = current_vision_feats[-1].permute(1, 2, 0)\n pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])\n sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)\n else:\n # fused the visual feature with previous memory features in the memory bank\n pix_feat = self._prepare_memory_conditioned_features(\n frame_idx=frame_idx,\n is_init_cond_frame=is_init_cond_frame,\n current_vision_feats=current_vision_feats[-1:],\n current_vision_pos_embeds=current_vision_pos_embeds[-1:],\n feat_sizes=feat_sizes[-1:],\n output_dict=output_dict,\n num_frames=num_frames,\n track_in_reverse=track_in_reverse,\n )\n # apply SAM-style segmentation head\n # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,\n # e.g. 
in demo where such logits come from earlier interaction instead of correction sampling\n # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)\n if prev_sam_mask_logits is not None:\n assert point_inputs is not None and mask_inputs is None\n mask_inputs = prev_sam_mask_logits\n multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)\n sam_outputs = self._forward_sam_heads(\n backbone_features=pix_feat,\n point_inputs=point_inputs,\n mask_inputs=mask_inputs,\n high_res_features=high_res_features,\n multimask_output=multimask_output,\n )\n return current_out, sam_outputs, high_res_features, pix_feat\n\n def _encode_memory_in_output(\n self,\n current_vision_feats,\n feat_sizes,\n point_inputs,\n run_mem_encoder,\n high_res_masks,\n object_score_logits,\n current_out,\n ):\n \"\"\"Run memory encoder on predicted mask to encode it into a new memory feature for future frames.\"\"\"\n if run_mem_encoder and self.num_maskmem > 0:\n high_res_masks_for_mem_enc = high_res_masks\n maskmem_features, maskmem_pos_enc = self._encode_new_memory(\n current_vision_feats=current_vision_feats,\n feat_sizes=feat_sizes,\n pred_masks_high_res=high_res_masks_for_mem_enc,\n object_score_logits=object_score_logits,\n is_mask_from_pts=(point_inputs is not None),\n )\n current_out[\"maskmem_features\"] = maskmem_features\n current_out[\"maskmem_pos_enc\"] = maskmem_pos_enc\n else:\n current_out[\"maskmem_features\"] = None\n current_out[\"maskmem_pos_enc\"] = None\n\n def track_step(\n self,\n frame_idx,\n is_init_cond_frame,\n current_vision_feats,\n current_vision_pos_embeds,\n feat_sizes,\n point_inputs,\n mask_inputs,\n output_dict,\n num_frames,\n track_in_reverse=False, # tracking in reverse time order (for demo usage)\n # Whether to run the memory encoder on the predicted masks. Sometimes we might want\n # to skip the memory encoder with `run_mem_encoder=False`. For example,\n # in demo we might call `track_step` multiple times for each user click,\n # and only encode the memory when the user finalizes their clicks. 
And in ablation\n # settings like SAM training on static images, we don't need the memory encoder.\n run_mem_encoder=True,\n # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).\n prev_sam_mask_logits=None,\n ):\n \"\"\"Perform a single tracking step, updating object masks and memory features based on current frame inputs.\"\"\"\n current_out, sam_outputs, _, _ = self._track_step(\n frame_idx,\n is_init_cond_frame,\n current_vision_feats,\n current_vision_pos_embeds,\n feat_sizes,\n point_inputs,\n mask_inputs,\n output_dict,\n num_frames,\n track_in_reverse,\n prev_sam_mask_logits,\n )\n _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs\n\n current_out[\"pred_masks\"] = low_res_masks\n current_out[\"pred_masks_high_res\"] = high_res_masks\n current_out[\"obj_ptr\"] = obj_ptr\n if not self.training:\n # Only add this in inference (to avoid unused param in activation checkpointing;\n # it's mainly used in the demo to encode spatial memories w/ consolidated masks)\n current_out[\"object_score_logits\"] = object_score_logits\n\n # Run memory encoder on the predicted mask to encode it into a new memory feature (for use in future frames)\n self._encode_memory_in_output(\n current_vision_feats,\n feat_sizes,\n point_inputs,\n run_mem_encoder,\n high_res_masks,\n object_score_logits,\n current_out,\n )\n\n return current_out\n\n def _use_multimask(self, is_init_cond_frame, point_inputs):\n \"\"\"Determine whether to use multiple mask outputs in the SAM head based on configuration and inputs.\"\"\"\n num_pts = 0 if point_inputs is None else point_inputs[\"point_labels\"].size(1)\n return (\n self.multimask_output_in_sam\n and (is_init_cond_frame or self.multimask_output_for_tracking)\n and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)\n )\n\n @staticmethod\n def _apply_non_overlapping_constraints(pred_masks):\n \"\"\"Apply non-overlapping constraints to masks, keeping the highest scoring object per location.\"\"\"\n batch_size = pred_masks.size(0)\n if batch_size == 1:\n return pred_masks\n\n device = pred_masks.device\n # \"max_obj_inds\": object index of the object with the highest score at each location\n max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)\n # \"batch_obj_inds\": object index of each object slice (along dim 0) in `pred_masks`\n batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]\n keep = max_obj_inds == batch_obj_inds\n # suppress overlapping regions' scores below -10.0 so that the foreground regions\n # don't overlap (here sigmoid(-10.0)=4.5398e-05)\n pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))\n return pred_masks\n\n def set_binarize(self, binarize=False):\n \"\"\"Set binarize for VideoPredictor.\"\"\"\n self.binarize_mask_from_pts_for_mem_enc = binarize\n\n def set_imgsz(self, imgsz):\n \"\"\"Set image size to make model compatible with different image sizes.\"\"\"\n self.image_size = imgsz[0]\n self.sam_prompt_encoder.input_image_size = imgsz\n self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # fixed ViT patch size of 16",
"chunk_type": "class",
"name": "SAM2Model",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\sam.py",
"start_line": 103,
"end_line": 1037,
"start_col": 0,
"end_col": 79,
"parent_name": null,
"docstring": "SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.\n\nThis class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms\nfor temporal consistency and efficient tracking of objects across frames.\n\nAttributes:\n mask_threshold (float): Threshold value for mask prediction.\n image_encoder (ImageEncoderViT): Visual encoder for extracting image features.\n memory_attention (nn.Module): Module for attending to memory features.\n memory_encoder (nn.Module): Encoder for generating memory representations.\n num_maskmem (int): Number of accessible memory frames.\n image_size (int): Size of input images.\n backbone_stride (int): Stride of the backbone network output.\n sam_prompt_embed_dim (int): Dimension of SAM prompt embeddings.\n sam_image_embedding_size (int): Size of SAM image embeddings.\n sam_prompt_encoder (PromptEncoder): Encoder for processing input prompts.\n sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks.\n obj_ptr_proj (nn.Module): Projection layer for object pointers.\n obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers.\n hidden_dim (int): Hidden dimension of the model.\n mem_dim (int): Memory dimension for encoding features.\n use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.\n use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.\n max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder cross-attention.\n add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers.\n proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional\n encoding in object pointers.\n use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in temporal positional encoding.\n only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during\n evaluation.\n pred_obj_scores (bool): Whether to predict if there is an object in the frame.\n pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.\n fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.\n soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.\n use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.\n no_obj_embed_spatial (torch.Tensor | None): No-object embedding for spatial frames.\n max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.\n directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the\n first frame.\n multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial\n conditioning frames.\n multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.\n multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.\n multimask_output_for_tracking (bool): Whether to use multimask output for tracking.\n use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.\n iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].\n memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.\n 
non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in\n memory encoder during evaluation.\n sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.\n sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.\n binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames\n with clicks during evaluation.\n use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM\n prompt encoder and mask decoder on frames with mask input.\n\nMethods:\n forward_image: Process image batch through encoder to extract multi-level features.\n track_step: Perform a single tracking step, updating object masks and memory features.\n set_binarize: Set binarize for VideoPredictor.\n set_imgsz: Set image size to make model compatible with different image sizes.\n\nExamples:\n >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)\n >>> image_batch = torch.rand(1, 3, 512, 512)\n >>> features = model.forward_image(image_batch)\n >>> track_results = model.track_step(0, True, features, None, None, None, {})",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"torch",
"torch.nn.functional",
"torch.nn",
"torch.nn.init.trunc_normal_",
"ultralytics.nn.modules.MLP",
"ultralytics.utils.LOGGER",
"blocks.SAM2TwoWayTransformer",
"decoders.MaskDecoder",
"decoders.SAM2MaskDecoder",
"encoders.ImageEncoderViT",
"encoders.PromptEncoder",
"utils.get_1d_sine_pe",
"utils.select_closest_cond_frames",
"torch.nn.Module"
],
"chunk_id": "class_SAM2Model_8b658922"
},
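The least obvious piece of the SAM2Model chunk above is the non-conditioning memory selection inside _prepare_memory_conditioned_features: besides always taking the immediately previous frame, it walks backwards through frames aligned to the memory stride r. The standalone sketch below (hypothetical helper name select_memory_frame_indices, not part of ultralytics) reproduces that index arithmetic for the forward-tracking case only; in the real code, indices that fall outside the video simply miss in the output dict and become padding.

def select_memory_frame_indices(frame_idx: int, num_maskmem: int, stride: int) -> list:
    """Return (t_pos, prev_frame_idx) pairs mirroring the forward-tracking branch above."""
    pairs = []
    for t_pos in range(1, num_maskmem):
        t_rel = num_maskmem - t_pos  # how many frames before the current frame
        if t_rel == 1:
            prev_frame_idx = frame_idx - 1  # the last frame is always taken, regardless of stride
        else:
            # nearest stride-aligned frame at or before (frame_idx - 2) ...
            prev_frame_idx = ((frame_idx - 2) // stride) * stride
            # ... then step further back in stride-sized hops
            prev_frame_idx -= (t_rel - 2) * stride
        pairs.append((t_pos, prev_frame_idx))
    return pairs

print(select_memory_frame_indices(frame_idx=10, num_maskmem=7, stride=1))
# [(1, 4), (2, 5), (3, 6), (4, 7), (5, 8), (6, 9)]  -> the six preceding frames
print(select_memory_frame_indices(frame_idx=10, num_maskmem=7, stride=2))
# [(1, 0), (2, 2), (3, 4), (4, 6), (5, 8), (6, 9)]  -> stride-2 frames plus the last frame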
{
"content": "import itertools",
"chunk_type": "import",
"name": "itertools",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_itertools_c0c04d63"
},
{
"content": "from typing import List, Optional, Tuple, Union",
"chunk_type": "import",
"name": "List, Optional, Tuple, Union",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_List, Optional, Tuple, Union_50b5127e"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_dd9d1b04"
},
{
"content": "import torch.nn as nn",
"chunk_type": "import",
"name": "torch.nn",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py",
"start_line": 16,
"end_line": 16,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn_76f46e0d"
},
{
"content": "import torch.nn.functional as F",
"chunk_type": "import",
"name": "torch.nn.functional",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py",
"start_line": 17,
"end_line": 17,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn.functional_9776ea9d"
},
{
"content": "from ultralytics.nn.modules import LayerNorm2d",
"chunk_type": "import",
"name": "LayerNorm2d",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py",
"start_line": 19,
"end_line": 19,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LayerNorm2d_ec2dd36e"
},
{
"content": "from ultralytics.utils.instance import to_2tuple",
"chunk_type": "import",
"name": "to_2tuple",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py",
"start_line": 20,
"end_line": 20,
"start_col": 0,
"end_col": 48,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_to_2tuple_dc97ab81"
},
{
"content": "class Conv2d_BN(torch.nn.Sequential):\n \"\"\"\n A sequential container that performs 2D convolution followed by batch normalization.\n\n This module combines a 2D convolution layer with batch normalization, providing a common building block\n for convolutional neural networks. The batch normalization weights and biases are initialized to specific\n values for optimal training performance.\n\n Attributes:\n c (torch.nn.Conv2d): 2D convolution layer.\n bn (torch.nn.BatchNorm2d): Batch normalization layer.\n\n Examples:\n >>> conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1)\n >>> input_tensor = torch.randn(1, 3, 224, 224)\n >>> output = conv_bn(input_tensor)\n >>> print(output.shape)\n torch.Size([1, 64, 224, 224])\n \"\"\"\n\n def __init__(\n self,\n a: int,\n b: int,\n ks: int = 1,\n stride: int = 1,\n pad: int = 0,\n dilation: int = 1,\n groups: int = 1,\n bn_weight_init: float = 1,\n ):\n \"\"\"\n Initialize a sequential container with 2D convolution followed by batch normalization.\n\n Args:\n a (int): Number of input channels.\n b (int): Number of output channels.\n ks (int, optional): Kernel size for the convolution.\n stride (int, optional): Stride for the convolution.\n pad (int, optional): Padding for the convolution.\n dilation (int, optional): Dilation factor for the convolution.\n groups (int, optional): Number of groups for the convolution.\n bn_weight_init (float, optional): Initial value for batch normalization weight.\n \"\"\"\n super().__init__()\n self.add_module(\"c\", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))\n bn = torch.nn.BatchNorm2d(b)\n torch.nn.init.constant_(bn.weight, bn_weight_init)\n torch.nn.init.constant_(bn.bias, 0)\n self.add_module(\"bn\", bn)",
"chunk_type": "class",
"name": "Conv2d_BN",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py",
"start_line": 23,
"end_line": 72,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": "A sequential container that performs 2D convolution followed by batch normalization.\n\nThis module combines a 2D convolution layer with batch normalization, providing a common building block\nfor convolutional neural networks. The batch normalization weights and biases are initialized to specific\nvalues for optimal training performance.\n\nAttributes:\n c (torch.nn.Conv2d): 2D convolution layer.\n bn (torch.nn.BatchNorm2d): Batch normalization layer.\n\nExamples:\n >>> conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1)\n >>> input_tensor = torch.randn(1, 3, 224, 224)\n >>> output = conv_bn(input_tensor)\n >>> print(output.shape)\n torch.Size([1, 64, 224, 224])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"itertools",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.utils.instance.to_2tuple",
"torch.nn.Sequential"
],
"chunk_id": "class_Conv2d_BN_4fa71554"
},
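One detail worth calling out in the Conv2d_BN chunk above is the bn_weight_init argument: initializing the BatchNorm weight (gamma) to zero makes the whole block output zeros at initialization, which is how the projection conv of the MBConv block further below starts out as a no-op on its residual branch. A minimal sketch, assuming Conv2d_BN as defined above is importable:

import torch

# With bn_weight_init=0.0 the BatchNorm scale is zero, so gamma * x_hat + beta == 0
# for any input; a surrounding residual connection therefore starts as an identity map.
conv_bn = Conv2d_BN(8, 8, ks=1, bn_weight_init=0.0)
x = torch.randn(2, 8, 16, 16)
print(conv_bn(x).abs().max())  # tensor(0., ...) -- the branch contributes nothing yet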
{
"content": "class PatchEmbed(nn.Module):\n \"\"\"\n Embed images into patches and project them into a specified embedding dimension.\n\n This module converts input images into patch embeddings using a sequence of convolutional layers,\n effectively downsampling the spatial dimensions while increasing the channel dimension.\n\n Attributes:\n patches_resolution (Tuple[int, int]): Resolution of the patches after embedding.\n num_patches (int): Total number of patches.\n in_chans (int): Number of input channels.\n embed_dim (int): Dimension of the embedding.\n seq (nn.Sequential): Sequence of convolutional and activation layers for patch embedding.\n\n Examples:\n >>> import torch\n >>> patch_embed = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU)\n >>> x = torch.randn(1, 3, 224, 224)\n >>> output = patch_embed(x)\n >>> print(output.shape)\n torch.Size([1, 96, 56, 56])\n \"\"\"\n\n def __init__(self, in_chans: int, embed_dim: int, resolution: int, activation):\n \"\"\"\n Initialize patch embedding with convolutional layers for image-to-patch conversion and projection.\n\n Args:\n in_chans (int): Number of input channels.\n embed_dim (int): Dimension of the embedding.\n resolution (int): Input image resolution.\n activation (nn.Module): Activation function to use between convolutions.\n \"\"\"\n super().__init__()\n img_size: Tuple[int, int] = to_2tuple(resolution)\n self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)\n self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]\n self.in_chans = in_chans\n self.embed_dim = embed_dim\n n = embed_dim\n self.seq = nn.Sequential(\n Conv2d_BN(in_chans, n // 2, 3, 2, 1),\n activation(),\n Conv2d_BN(n // 2, n, 3, 2, 1),\n )\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Process input tensor through patch embedding sequence, converting images to patch embeddings.\"\"\"\n return self.seq(x)",
"chunk_type": "class",
"name": "PatchEmbed",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py",
"start_line": 75,
"end_line": 123,
"start_col": 0,
"end_col": 26,
"parent_name": null,
"docstring": "Embed images into patches and project them into a specified embedding dimension.\n\nThis module converts input images into patch embeddings using a sequence of convolutional layers,\neffectively downsampling the spatial dimensions while increasing the channel dimension.\n\nAttributes:\n patches_resolution (Tuple[int, int]): Resolution of the patches after embedding.\n num_patches (int): Total number of patches.\n in_chans (int): Number of input channels.\n embed_dim (int): Dimension of the embedding.\n seq (nn.Sequential): Sequence of convolutional and activation layers for patch embedding.\n\nExamples:\n >>> import torch\n >>> patch_embed = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU)\n >>> x = torch.randn(1, 3, 224, 224)\n >>> output = patch_embed(x)\n >>> print(output.shape)\n torch.Size([1, 96, 56, 56])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"itertools",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.utils.instance.to_2tuple",
"nn.Module"
],
"chunk_id": "class_PatchEmbed_c1f155b0"
},
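As a quick sanity check on the arithmetic in the PatchEmbed chunk above (assuming PatchEmbed as defined there): the two stride-2 Conv2d_BN stages each halve the spatial grid, which is why patches_resolution is the input resolution divided by 4 and num_patches its product.

import torch
import torch.nn as nn

pe = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU)
x = torch.randn(1, 3, 224, 224)
print(pe(x).shape)     # torch.Size([1, 96, 56, 56]) -- 224 // 2 // 2 = 56
print(pe.num_patches)  # 3136 == 56 * 56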
{
"content": "class MBConv(nn.Module):\n \"\"\"\n Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.\n\n This module implements the mobile inverted bottleneck convolution with expansion, depthwise convolution,\n and projection phases, along with residual connections for improved gradient flow.\n\n Attributes:\n in_chans (int): Number of input channels.\n hidden_chans (int): Number of hidden channels after expansion.\n out_chans (int): Number of output channels.\n conv1 (Conv2d_BN): First convolutional layer for channel expansion.\n act1 (nn.Module): First activation function.\n conv2 (Conv2d_BN): Depthwise convolutional layer.\n act2 (nn.Module): Second activation function.\n conv3 (Conv2d_BN): Final convolutional layer for projection.\n act3 (nn.Module): Third activation function.\n drop_path (nn.Module): Drop path layer (Identity for inference).\n\n Examples:\n >>> in_chans, out_chans = 32, 64\n >>> mbconv = MBConv(in_chans, out_chans, expand_ratio=4, activation=nn.ReLU, drop_path=0.1)\n >>> x = torch.randn(1, in_chans, 56, 56)\n >>> output = mbconv(x)\n >>> print(output.shape)\n torch.Size([1, 64, 56, 56])\n \"\"\"\n\n def __init__(self, in_chans: int, out_chans: int, expand_ratio: float, activation, drop_path: float):\n \"\"\"\n Initialize the MBConv layer with specified input/output channels, expansion ratio, and activation.\n\n Args:\n in_chans (int): Number of input channels.\n out_chans (int): Number of output channels.\n expand_ratio (float): Channel expansion ratio for the hidden layer.\n activation (nn.Module): Activation function to use.\n drop_path (float): Drop path rate for stochastic depth.\n \"\"\"\n super().__init__()\n self.in_chans = in_chans\n self.hidden_chans = int(in_chans * expand_ratio)\n self.out_chans = out_chans\n\n self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)\n self.act1 = activation()\n\n self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, ks=3, stride=1, pad=1, groups=self.hidden_chans)\n self.act2 = activation()\n\n self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)\n self.act3 = activation()\n\n # NOTE: `DropPath` is needed only for training.\n # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n self.drop_path = nn.Identity()\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Implement the forward pass of MBConv, applying convolutions and skip connection.\"\"\"\n shortcut = x\n x = self.conv1(x)\n x = self.act1(x)\n x = self.conv2(x)\n x = self.act2(x)\n x = self.conv3(x)\n x = self.drop_path(x)\n x += shortcut\n return self.act3(x)",
"chunk_type": "class",
"name": "MBConv",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py",
"start_line": 126,
"end_line": 193,
"start_col": 0,
"end_col": 27,
"parent_name": null,
"docstring": "Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.\n\nThis module implements the mobile inverted bottleneck convolution with expansion, depthwise convolution,\nand projection phases, along with residual connections for improved gradient flow.\n\nAttributes:\n in_chans (int): Number of input channels.\n hidden_chans (int): Number of hidden channels after expansion.\n out_chans (int): Number of output channels.\n conv1 (Conv2d_BN): First convolutional layer for channel expansion.\n act1 (nn.Module): First activation function.\n conv2 (Conv2d_BN): Depthwise convolutional layer.\n act2 (nn.Module): Second activation function.\n conv3 (Conv2d_BN): Final convolutional layer for projection.\n act3 (nn.Module): Third activation function.\n drop_path (nn.Module): Drop path layer (Identity for inference).\n\nExamples:\n >>> in_chans, out_chans = 32, 64\n >>> mbconv = MBConv(in_chans, out_chans, expand_ratio=4, activation=nn.ReLU, drop_path=0.1)\n >>> x = torch.randn(1, in_chans, 56, 56)\n >>> output = mbconv(x)\n >>> print(output.shape)\n torch.Size([1, 64, 56, 56])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"itertools",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.utils.instance.to_2tuple",
"nn.Module"
],
"chunk_id": "class_MBConv_189eda56"
},
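Note that forward in the MBConv chunk above adds the input back with x += shortcut and there is no projection on the shortcut path, so the block only composes when in_chans == out_chans; this is why ConvLayer below always instantiates it as MBConv(dim, dim, ...). A minimal sketch under that constraint, assuming MBConv as defined above:

import torch
import torch.nn as nn

block = MBConv(64, 64, expand_ratio=4.0, activation=nn.GELU, drop_path=0.0)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 64, 56, 56]) -- shape preserved by the residual block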
{
"content": "class PatchMerging(nn.Module):\n \"\"\"\n Merge neighboring patches in the feature map and project to a new dimension.\n\n This class implements a patch merging operation that combines spatial information and adjusts the feature\n dimension using a series of convolutional layers with batch normalization. It effectively reduces spatial\n resolution while potentially increasing channel dimensions.\n\n Attributes:\n input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map.\n dim (int): The input dimension of the feature map.\n out_dim (int): The output dimension after merging and projection.\n act (nn.Module): The activation function used between convolutions.\n conv1 (Conv2d_BN): The first convolutional layer for dimension projection.\n conv2 (Conv2d_BN): The second convolutional layer for spatial merging.\n conv3 (Conv2d_BN): The third convolutional layer for final projection.\n\n Examples:\n >>> input_resolution = (56, 56)\n >>> patch_merging = PatchMerging(input_resolution, dim=64, out_dim=128, activation=nn.ReLU)\n >>> x = torch.randn(4, 64, 56, 56)\n >>> output = patch_merging(x)\n >>> print(output.shape)\n torch.Size([4, 3136, 128])\n \"\"\"\n\n def __init__(self, input_resolution: Tuple[int, int], dim: int, out_dim: int, activation):\n \"\"\"\n Initialize the PatchMerging module for merging and projecting neighboring patches in feature maps.\n\n Args:\n input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map.\n dim (int): The input dimension of the feature map.\n out_dim (int): The output dimension after merging and projection.\n activation (nn.Module): The activation function used between convolutions.\n \"\"\"\n super().__init__()\n\n self.input_resolution = input_resolution\n self.dim = dim\n self.out_dim = out_dim\n self.act = activation()\n self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)\n stride_c = 1 if out_dim in {320, 448, 576} else 2\n self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)\n self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply patch merging and dimension projection to the input feature map.\"\"\"\n if x.ndim == 3:\n H, W = self.input_resolution\n B = len(x)\n # (B, C, H, W)\n x = x.view(B, H, W, -1).permute(0, 3, 1, 2)\n\n x = self.conv1(x)\n x = self.act(x)\n\n x = self.conv2(x)\n x = self.act(x)\n x = self.conv3(x)\n return x.flatten(2).transpose(1, 2)",
"chunk_type": "class",
"name": "PatchMerging",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py",
"start_line": 196,
"end_line": 257,
"start_col": 0,
"end_col": 43,
"parent_name": null,
"docstring": "Merge neighboring patches in the feature map and project to a new dimension.\n\nThis class implements a patch merging operation that combines spatial information and adjusts the feature\ndimension using a series of convolutional layers with batch normalization. It effectively reduces spatial\nresolution while potentially increasing channel dimensions.\n\nAttributes:\n input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map.\n dim (int): The input dimension of the feature map.\n out_dim (int): The output dimension after merging and projection.\n act (nn.Module): The activation function used between convolutions.\n conv1 (Conv2d_BN): The first convolutional layer for dimension projection.\n conv2 (Conv2d_BN): The second convolutional layer for spatial merging.\n conv3 (Conv2d_BN): The third convolutional layer for final projection.\n\nExamples:\n >>> input_resolution = (56, 56)\n >>> patch_merging = PatchMerging(input_resolution, dim=64, out_dim=128, activation=nn.ReLU)\n >>> x = torch.randn(4, 64, 56, 56)\n >>> output = patch_merging(x)\n >>> print(output.shape)\n torch.Size([4, 3136, 128])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"itertools",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.utils.instance.to_2tuple",
"nn.Module"
],
"chunk_id": "class_PatchMerging_a0111844"
},
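To make the shape flow in the PatchMerging chunk above concrete (assuming PatchMerging as defined there): out_dim=128 is not one of {320, 448, 576}, so conv2 runs with stride 2 and a 56x56 grid is halved to 28x28 before being flattened into 784 tokens. Both channels-first maps and (B, N, C) token inputs are accepted, since 3-D inputs are reshaped internally.

import torch
import torch.nn as nn

pm = PatchMerging((56, 56), dim=64, out_dim=128, activation=nn.GELU)
x = torch.randn(4, 64, 56, 56)        # channels-first input is used as-is
print(pm(x).shape)                    # torch.Size([4, 784, 128])
tokens = torch.randn(4, 56 * 56, 64)  # (B, N, C) tokens are reshaped to (B, C, H, W) first
print(pm(tokens).shape)               # torch.Size([4, 784, 128])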
{
"content": "class ConvLayer(nn.Module):\n \"\"\"\n Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).\n\n This layer optionally applies downsample operations to the output and supports gradient checkpointing\n for memory efficiency during training.\n\n Attributes:\n dim (int): Dimensionality of the input and output.\n input_resolution (Tuple[int, int]): Resolution of the input image.\n depth (int): Number of MBConv layers in the block.\n use_checkpoint (bool): Whether to use gradient checkpointing to save memory.\n blocks (nn.ModuleList): List of MBConv layers.\n downsample (Optional[nn.Module]): Function for downsampling the output.\n\n Examples:\n >>> input_tensor = torch.randn(1, 64, 56, 56)\n >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)\n >>> output = conv_layer(input_tensor)\n >>> print(output.shape)\n torch.Size([1, 3136, 128])\n \"\"\"\n\n def __init__(\n self,\n dim: int,\n input_resolution: Tuple[int, int],\n depth: int,\n activation,\n drop_path: Union[float, List[float]] = 0.0,\n downsample: Optional[nn.Module] = None,\n use_checkpoint: bool = False,\n out_dim: Optional[int] = None,\n conv_expand_ratio: float = 4.0,\n ):\n \"\"\"\n Initialize the ConvLayer with the given dimensions and settings.\n\n This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and\n optionally applies downsampling to the output.\n\n Args:\n dim (int): The dimensionality of the input and output.\n input_resolution (Tuple[int, int]): The resolution of the input image.\n depth (int): The number of MBConv layers in the block.\n activation (nn.Module): Activation function applied after each convolution.\n drop_path (float | List[float], optional): Drop path rate. Single float or a list of floats for each MBConv.\n downsample (Optional[nn.Module], optional): Function for downsampling the output. None to skip downsampling.\n use_checkpoint (bool, optional): Whether to use gradient checkpointing to save memory.\n out_dim (Optional[int], optional): The dimensionality of the output. None means it will be the same as `dim`.\n conv_expand_ratio (float, optional): Expansion ratio for the MBConv layers.\n \"\"\"\n super().__init__()\n self.dim = dim\n self.input_resolution = input_resolution\n self.depth = depth\n self.use_checkpoint = use_checkpoint\n\n # Build blocks\n self.blocks = nn.ModuleList(\n [\n MBConv(\n dim,\n dim,\n conv_expand_ratio,\n activation,\n drop_path[i] if isinstance(drop_path, list) else drop_path,\n )\n for i in range(depth)\n ]\n )\n\n # Patch merging layer\n self.downsample = (\n None\n if downsample is None\n else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)\n )\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Process input through convolutional layers, applying MBConv blocks and optional downsampling.\"\"\"\n for blk in self.blocks:\n x = torch.utils.checkpoint(blk, x) if self.use_checkpoint else blk(x) # warn: checkpoint is slow import\n return x if self.downsample is None else self.downsample(x)",
"chunk_type": "class",
"name": "ConvLayer",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py",
"start_line": 260,
"end_line": 343,
"start_col": 0,
"end_col": 67,
"parent_name": null,
"docstring": "Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).\n\nThis layer optionally applies downsample operations to the output and supports gradient checkpointing\nfor memory efficiency during training.\n\nAttributes:\n dim (int): Dimensionality of the input and output.\n input_resolution (Tuple[int, int]): Resolution of the input image.\n depth (int): Number of MBConv layers in the block.\n use_checkpoint (bool): Whether to use gradient checkpointing to save memory.\n blocks (nn.ModuleList): List of MBConv layers.\n downsample (Optional[nn.Module]): Function for downsampling the output.\n\nExamples:\n >>> input_tensor = torch.randn(1, 64, 56, 56)\n >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)\n >>> output = conv_layer(input_tensor)\n >>> print(output.shape)\n torch.Size([1, 3136, 128])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"itertools",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.utils.instance.to_2tuple",
"nn.Module"
],
"chunk_id": "class_ConvLayer_b164e6de"
},
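A short usage sketch for the ConvLayer chunk above (assuming ConvLayer and PatchMerging as defined in this file, with use_checkpoint left at its default of False): with downsample=None the stacked MBConv blocks preserve the (B, C, H, W) shape, while passing PatchMerging plus an out_dim ends the stage with flattened (B, N, out_dim) tokens.

import torch
import torch.nn as nn

layer = ConvLayer(dim=64, input_resolution=(56, 56), depth=2, activation=nn.GELU)
x = torch.randn(1, 64, 56, 56)
print(layer(x).shape)  # torch.Size([1, 64, 56, 56]) -- no downsample, shape preserved

layer_ds = ConvLayer(
    dim=64,
    input_resolution=(56, 56),
    depth=2,
    activation=nn.GELU,
    downsample=PatchMerging,
    out_dim=128,
)
print(layer_ds(x).shape)  # torch.Size([1, 784, 128]) -- merged and flattened by PatchMerging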
{
"content": "class MLP(nn.Module):\n \"\"\"\n Multi-layer Perceptron (MLP) module for transformer architectures.\n\n This module applies layer normalization, two fully-connected layers with an activation function in between,\n and dropout. It is commonly used in transformer-based architectures for processing token embeddings.\n\n Attributes:\n norm (nn.LayerNorm): Layer normalization applied to the input.\n fc1 (nn.Linear): First fully-connected layer.\n fc2 (nn.Linear): Second fully-connected layer.\n act (nn.Module): Activation function applied after the first fully-connected layer.\n drop (nn.Dropout): Dropout layer applied after the activation function.\n\n Examples:\n >>> import torch\n >>> from torch import nn\n >>> mlp = MLP(in_features=256, hidden_features=512, out_features=256, activation=nn.GELU, drop=0.1)\n >>> x = torch.randn(32, 100, 256)\n >>> output = mlp(x)\n >>> print(output.shape)\n torch.Size([32, 100, 256])\n \"\"\"\n\n def __init__(\n self,\n in_features: int,\n hidden_features: Optional[int] = None,\n out_features: Optional[int] = None,\n activation=nn.GELU,\n drop: float = 0.0,\n ):\n \"\"\"\n Initialize a multi-layer perceptron with configurable input, hidden, and output dimensions.\n\n Args:\n in_features (int): Number of input features.\n hidden_features (Optional[int], optional): Number of hidden features.\n out_features (Optional[int], optional): Number of output features.\n activation (nn.Module): Activation function applied after the first fully-connected layer.\n drop (float, optional): Dropout probability.\n \"\"\"\n super().__init__()\n out_features = out_features or in_features\n hidden_features = hidden_features or in_features\n self.norm = nn.LayerNorm(in_features)\n self.fc1 = nn.Linear(in_features, hidden_features)\n self.fc2 = nn.Linear(hidden_features, out_features)\n self.act = activation()\n self.drop = nn.Dropout(drop)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply MLP operations: layer norm, FC layers, activation, and dropout to the input tensor.\"\"\"\n x = self.norm(x)\n x = self.fc1(x)\n x = self.act(x)\n x = self.drop(x)\n x = self.fc2(x)\n return self.drop(x)",
"chunk_type": "class",
"name": "MLP",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py",
"start_line": 346,
"end_line": 404,
"start_col": 0,
"end_col": 27,
"parent_name": null,
"docstring": "Multi-layer Perceptron (MLP) module for transformer architectures.\n\nThis module applies layer normalization, two fully-connected layers with an activation function in between,\nand dropout. It is commonly used in transformer-based architectures for processing token embeddings.\n\nAttributes:\n norm (nn.LayerNorm): Layer normalization applied to the input.\n fc1 (nn.Linear): First fully-connected layer.\n fc2 (nn.Linear): Second fully-connected layer.\n act (nn.Module): Activation function applied after the first fully-connected layer.\n drop (nn.Dropout): Dropout layer applied after the activation function.\n\nExamples:\n >>> import torch\n >>> from torch import nn\n >>> mlp = MLP(in_features=256, hidden_features=512, out_features=256, activation=nn.GELU, drop=0.1)\n >>> x = torch.randn(32, 100, 256)\n >>> output = mlp(x)\n >>> print(output.shape)\n torch.Size([32, 100, 256])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"itertools",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.utils.instance.to_2tuple",
"nn.Module"
],
"chunk_id": "class_MLP_abe157c9"
},
{
"content": "class Attention(torch.nn.Module):\n \"\"\"\n Multi-head attention module with spatial awareness and trainable attention biases.\n\n This module implements a multi-head attention mechanism with support for spatial awareness, applying\n attention biases based on spatial resolution. It includes trainable attention biases for each unique\n offset between spatial positions in the resolution grid.\n\n Attributes:\n num_heads (int): Number of attention heads.\n scale (float): Scaling factor for attention scores.\n key_dim (int): Dimensionality of the keys and queries.\n nh_kd (int): Product of num_heads and key_dim.\n d (int): Dimensionality of the value vectors.\n dh (int): Product of d and num_heads.\n attn_ratio (float): Attention ratio affecting the dimensions of the value vectors.\n norm (nn.LayerNorm): Layer normalization applied to input.\n qkv (nn.Linear): Linear layer for computing query, key, and value projections.\n proj (nn.Linear): Linear layer for final projection.\n attention_biases (nn.Parameter): Learnable attention biases.\n attention_bias_idxs (torch.Tensor): Indices for attention biases.\n ab (torch.Tensor): Cached attention biases for inference, deleted during training.\n\n Examples:\n >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))\n >>> x = torch.randn(1, 196, 256)\n >>> output = attn(x)\n >>> print(output.shape)\n torch.Size([1, 196, 256])\n \"\"\"\n\n def __init__(\n self,\n dim: int,\n key_dim: int,\n num_heads: int = 8,\n attn_ratio: float = 4,\n resolution: Tuple[int, int] = (14, 14),\n ):\n \"\"\"\n Initialize the Attention module for multi-head attention with spatial awareness.\n\n This module implements a multi-head attention mechanism with support for spatial awareness, applying\n attention biases based on spatial resolution. 
It includes trainable attention biases for each unique\n offset between spatial positions in the resolution grid.\n\n Args:\n dim (int): The dimensionality of the input and output.\n key_dim (int): The dimensionality of the keys and queries.\n num_heads (int, optional): Number of attention heads.\n attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors.\n resolution (Tuple[int, int], optional): Spatial resolution of the input feature map.\n \"\"\"\n super().__init__()\n\n assert isinstance(resolution, tuple) and len(resolution) == 2, \"'resolution' argument not tuple of length 2\"\n self.num_heads = num_heads\n self.scale = key_dim**-0.5\n self.key_dim = key_dim\n self.nh_kd = nh_kd = key_dim * num_heads\n self.d = int(attn_ratio * key_dim)\n self.dh = int(attn_ratio * key_dim) * num_heads\n self.attn_ratio = attn_ratio\n h = self.dh + nh_kd * 2\n\n self.norm = nn.LayerNorm(dim)\n self.qkv = nn.Linear(dim, h)\n self.proj = nn.Linear(self.dh, dim)\n\n points = list(itertools.product(range(resolution[0]), range(resolution[1])))\n N = len(points)\n attention_offsets = {}\n idxs = []\n for p1 in points:\n for p2 in points:\n offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))\n if offset not in attention_offsets:\n attention_offsets[offset] = len(attention_offsets)\n idxs.append(attention_offsets[offset])\n self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))\n self.register_buffer(\"attention_bias_idxs\", torch.LongTensor(idxs).view(N, N), persistent=False)\n\n @torch.no_grad()\n def train(self, mode: bool = True):\n \"\"\"Set the module in training mode and handle the 'ab' attribute for cached attention biases.\"\"\"\n super().train(mode)\n if mode and hasattr(self, \"ab\"):\n del self.ab\n else:\n self.ab = self.attention_biases[:, self.attention_bias_idxs]\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply multi-head attention with spatial awareness and trainable attention biases.\"\"\"\n B, N, _ = x.shape # B, N, C\n\n # Normalization\n x = self.norm(x)\n\n qkv = self.qkv(x)\n # (B, N, num_heads, d)\n q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)\n # (B, num_heads, N, d)\n q = q.permute(0, 2, 1, 3)\n k = k.permute(0, 2, 1, 3)\n v = v.permute(0, 2, 1, 3)\n self.ab = self.ab.to(self.attention_biases.device)\n\n attn = (q @ k.transpose(-2, -1)) * self.scale + (\n self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab\n )\n attn = attn.softmax(dim=-1)\n x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)\n return self.proj(x)",
"chunk_type": "class",
"name": "Attention",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py",
"start_line": 407,
"end_line": 519,
"start_col": 0,
"end_col": 27,
"parent_name": null,
"docstring": "Multi-head attention module with spatial awareness and trainable attention biases.\n\nThis module implements a multi-head attention mechanism with support for spatial awareness, applying\nattention biases based on spatial resolution. It includes trainable attention biases for each unique\noffset between spatial positions in the resolution grid.\n\nAttributes:\n num_heads (int): Number of attention heads.\n scale (float): Scaling factor for attention scores.\n key_dim (int): Dimensionality of the keys and queries.\n nh_kd (int): Product of num_heads and key_dim.\n d (int): Dimensionality of the value vectors.\n dh (int): Product of d and num_heads.\n attn_ratio (float): Attention ratio affecting the dimensions of the value vectors.\n norm (nn.LayerNorm): Layer normalization applied to input.\n qkv (nn.Linear): Linear layer for computing query, key, and value projections.\n proj (nn.Linear): Linear layer for final projection.\n attention_biases (nn.Parameter): Learnable attention biases.\n attention_bias_idxs (torch.Tensor): Indices for attention biases.\n ab (torch.Tensor): Cached attention biases for inference, deleted during training.\n\nExamples:\n >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))\n >>> x = torch.randn(1, 196, 256)\n >>> output = attn(x)\n >>> print(output.shape)\n torch.Size([1, 196, 256])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"itertools",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.utils.instance.to_2tuple",
"torch.nn.Module"
],
"chunk_id": "class_Attention_04fc8140"
},
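The `class_Attention_04fc8140` chunk above keys its learnable biases by absolute (|Δx|, |Δy|) offset rather than by position pair, which keeps the bias table compact. A minimal standalone sketch of that indexing step, assuming the chunk's default 14×14 resolution (illustration only, not part of the indexed repository):

```python
# Rebuild the spatial attention-bias index table the way the chunk's __init__ does.
import itertools

import torch

resolution = (14, 14)  # default resolution from the chunk
points = list(itertools.product(range(resolution[0]), range(resolution[1])))
attention_offsets, idxs = {}, []
for p1 in points:
    for p2 in points:
        offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
        if offset not in attention_offsets:
            attention_offsets[offset] = len(attention_offsets)
        idxs.append(attention_offsets[offset])

bias_idxs = torch.LongTensor(idxs).view(len(points), len(points))
print(len(points), len(attention_offsets), tuple(bias_idxs.shape))
# 196 positions, 196 unique offsets, and a (196, 196) index map into the (num_heads, 196) bias table.
```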
{
"content": "class TinyViTBlock(nn.Module):\n \"\"\"\n TinyViT Block that applies self-attention and a local convolution to the input.\n\n This block is a key component of the TinyViT architecture, combining self-attention mechanisms with\n local convolutions to process input features efficiently. It supports windowed attention for\n computational efficiency and includes residual connections.\n\n Attributes:\n dim (int): The dimensionality of the input and output.\n input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.\n num_heads (int): Number of attention heads.\n window_size (int): Size of the attention window.\n mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.\n drop_path (nn.Module): Stochastic depth layer, identity function during inference.\n attn (Attention): Self-attention module.\n mlp (MLP): Multi-layer perceptron module.\n local_conv (Conv2d_BN): Depth-wise local convolution layer.\n\n Examples:\n >>> input_tensor = torch.randn(1, 196, 192)\n >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)\n >>> output = block(input_tensor)\n >>> print(output.shape)\n torch.Size([1, 196, 192])\n \"\"\"\n\n def __init__(\n self,\n dim: int,\n input_resolution: Tuple[int, int],\n num_heads: int,\n window_size: int = 7,\n mlp_ratio: float = 4.0,\n drop: float = 0.0,\n drop_path: float = 0.0,\n local_conv_size: int = 3,\n activation=nn.GELU,\n ):\n \"\"\"\n Initialize a TinyViT block with self-attention and local convolution.\n\n This block is a key component of the TinyViT architecture, combining self-attention mechanisms with\n local convolutions to process input features efficiently.\n\n Args:\n dim (int): Dimensionality of the input and output features.\n input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).\n num_heads (int): Number of attention heads.\n window_size (int, optional): Size of the attention window. Must be greater than 0.\n mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension.\n drop (float, optional): Dropout rate.\n drop_path (float, optional): Stochastic depth rate.\n local_conv_size (int, optional): Kernel size of the local convolution.\n activation (nn.Module): Activation function for MLP.\n \"\"\"\n super().__init__()\n self.dim = dim\n self.input_resolution = input_resolution\n self.num_heads = num_heads\n assert window_size > 0, \"window_size must be greater than 0\"\n self.window_size = window_size\n self.mlp_ratio = mlp_ratio\n\n # NOTE: `DropPath` is needed only for training.\n # self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n self.drop_path = nn.Identity()\n\n assert dim % num_heads == 0, \"dim must be divisible by num_heads\"\n head_dim = dim // num_heads\n\n window_resolution = (window_size, window_size)\n self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution)\n\n mlp_hidden_dim = int(dim * mlp_ratio)\n mlp_activation = activation\n self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, activation=mlp_activation, drop=drop)\n\n pad = local_conv_size // 2\n self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply self-attention, local convolution, and MLP operations to the input tensor.\"\"\"\n h, w = self.input_resolution\n b, hw, c = x.shape # batch, height*width, channels\n assert hw == h * w, \"input feature has wrong size\"\n res_x = x\n if h == self.window_size and w == self.window_size:\n x = self.attn(x)\n else:\n x = x.view(b, h, w, c)\n pad_b = (self.window_size - h % self.window_size) % self.window_size\n pad_r = (self.window_size - w % self.window_size) % self.window_size\n padding = pad_b > 0 or pad_r > 0\n if padding:\n x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))\n\n pH, pW = h + pad_b, w + pad_r\n nH = pH // self.window_size\n nW = pW // self.window_size\n\n # Window partition\n x = (\n x.view(b, nH, self.window_size, nW, self.window_size, c)\n .transpose(2, 3)\n .reshape(b * nH * nW, self.window_size * self.window_size, c)\n )\n x = self.attn(x)\n\n # Window reverse\n x = x.view(b, nH, nW, self.window_size, self.window_size, c).transpose(2, 3).reshape(b, pH, pW, c)\n if padding:\n x = x[:, :h, :w].contiguous()\n\n x = x.view(b, hw, c)\n\n x = res_x + self.drop_path(x)\n x = x.transpose(1, 2).reshape(b, c, h, w)\n x = self.local_conv(x)\n x = x.view(b, c, hw).transpose(1, 2)\n\n return x + self.drop_path(self.mlp(x))\n\n def extra_repr(self) -> str:\n \"\"\"\n Return a string representation of the TinyViTBlock's parameters.\n\n This method provides a formatted string containing key information about the TinyViTBlock, including its\n dimension, input resolution, number of attention heads, window size, and MLP ratio.\n\n Returns:\n (str): A formatted string containing the block's parameters.\n\n Examples:\n >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0)\n >>> print(block.extra_repr())\n dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0\n \"\"\"\n return (\n f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \"\n f\"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}\"\n )",
"chunk_type": "class",
"name": "TinyViTBlock",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py",
"start_line": 522,
"end_line": 663,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "TinyViT Block that applies self-attention and a local convolution to the input.\n\nThis block is a key component of the TinyViT architecture, combining self-attention mechanisms with\nlocal convolutions to process input features efficiently. It supports windowed attention for\ncomputational efficiency and includes residual connections.\n\nAttributes:\n dim (int): The dimensionality of the input and output.\n input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.\n num_heads (int): Number of attention heads.\n window_size (int): Size of the attention window.\n mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.\n drop_path (nn.Module): Stochastic depth layer, identity function during inference.\n attn (Attention): Self-attention module.\n mlp (MLP): Multi-layer perceptron module.\n local_conv (Conv2d_BN): Depth-wise local convolution layer.\n\nExamples:\n >>> input_tensor = torch.randn(1, 196, 192)\n >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)\n >>> output = block(input_tensor)\n >>> print(output.shape)\n torch.Size([1, 196, 192])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"itertools",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.utils.instance.to_2tuple",
"nn.Module"
],
"chunk_id": "class_TinyViTBlock_f2ae92c6"
},
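The windowed-attention path in the `TinyViTBlock` chunk hinges on two reshapes: partitioning the (B, H, W, C) feature map into window tokens and reversing that partition afterwards. A hedged sketch with synthetic tensors, assuming H and W are already multiples of the window size so the padding branch is not exercised:

```python
import torch

b, h, w, c, window_size = 1, 14, 14, 192, 7
x = torch.randn(b, h, w, c)

nH, nW = h // window_size, w // window_size
# Partition into b * nH * nW windows of window_size * window_size tokens each.
windows = (
    x.view(b, nH, window_size, nW, window_size, c)
    .transpose(2, 3)
    .reshape(b * nH * nW, window_size * window_size, c)
)
assert windows.shape == (4, 49, 192)

# Reverse the partition back to the spatial layout; the round trip is lossless.
restored = windows.view(b, nH, nW, window_size, window_size, c).transpose(2, 3).reshape(b, h, w, c)
assert torch.equal(restored, x)
```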
{
"content": "class BasicLayer(nn.Module):\n \"\"\"\n A basic TinyViT layer for one stage in a TinyViT architecture.\n\n This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks\n and an optional downsampling operation. It processes features at a specific resolution and\n dimensionality within the overall architecture.\n\n Attributes:\n dim (int): The dimensionality of the input and output features.\n input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.\n depth (int): Number of TinyViT blocks in this layer.\n use_checkpoint (bool): Whether to use gradient checkpointing to save memory.\n blocks (nn.ModuleList): List of TinyViT blocks that make up this layer.\n downsample (nn.Module | None): Downsample layer at the end of the layer, if specified.\n\n Examples:\n >>> input_tensor = torch.randn(1, 3136, 192)\n >>> layer = BasicLayer(dim=192, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)\n >>> output = layer(input_tensor)\n >>> print(output.shape)\n torch.Size([1, 784, 384])\n \"\"\"\n\n def __init__(\n self,\n dim: int,\n input_resolution: Tuple[int, int],\n depth: int,\n num_heads: int,\n window_size: int,\n mlp_ratio: float = 4.0,\n drop: float = 0.0,\n drop_path: Union[float, List[float]] = 0.0,\n downsample: Optional[nn.Module] = None,\n use_checkpoint: bool = False,\n local_conv_size: int = 3,\n activation=nn.GELU,\n out_dim: Optional[int] = None,\n ):\n \"\"\"\n Initialize a BasicLayer in the TinyViT architecture.\n\n This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to\n process feature maps at a specific resolution and dimensionality within the TinyViT model.\n\n Args:\n dim (int): Dimensionality of the input and output features.\n input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).\n depth (int): Number of TinyViT blocks in this layer.\n num_heads (int): Number of attention heads in each TinyViT block.\n window_size (int): Size of the local window for attention computation.\n mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension.\n drop (float, optional): Dropout rate.\n drop_path (float | List[float], optional): Stochastic depth rate. Can be a float or a list of floats for each block.\n downsample (nn.Module | None, optional): Downsampling layer at the end of the layer. None to skip downsampling.\n use_checkpoint (bool, optional): Whether to use gradient checkpointing to save memory.\n local_conv_size (int, optional): Kernel size for the local convolution in each TinyViT block.\n activation (nn.Module): Activation function used in the MLP.\n out_dim (int | None, optional): Output dimension after downsampling. 
None means it will be the same as `dim`.\n \"\"\"\n super().__init__()\n self.dim = dim\n self.input_resolution = input_resolution\n self.depth = depth\n self.use_checkpoint = use_checkpoint\n\n # Build blocks\n self.blocks = nn.ModuleList(\n [\n TinyViTBlock(\n dim=dim,\n input_resolution=input_resolution,\n num_heads=num_heads,\n window_size=window_size,\n mlp_ratio=mlp_ratio,\n drop=drop,\n drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n local_conv_size=local_conv_size,\n activation=activation,\n )\n for i in range(depth)\n ]\n )\n\n # Patch merging layer\n self.downsample = (\n None\n if downsample is None\n else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)\n )\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Process input through TinyViT blocks and optional downsampling.\"\"\"\n for blk in self.blocks:\n x = torch.utils.checkpoint(blk, x) if self.use_checkpoint else blk(x) # warn: checkpoint is slow import\n return x if self.downsample is None else self.downsample(x)\n\n def extra_repr(self) -> str:\n \"\"\"Return a string with the layer's parameters for printing.\"\"\"\n return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"",
"chunk_type": "class",
"name": "BasicLayer",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py",
"start_line": 666,
"end_line": 766,
"start_col": 0,
"end_col": 94,
"parent_name": null,
"docstring": "A basic TinyViT layer for one stage in a TinyViT architecture.\n\nThis class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks\nand an optional downsampling operation. It processes features at a specific resolution and\ndimensionality within the overall architecture.\n\nAttributes:\n dim (int): The dimensionality of the input and output features.\n input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.\n depth (int): Number of TinyViT blocks in this layer.\n use_checkpoint (bool): Whether to use gradient checkpointing to save memory.\n blocks (nn.ModuleList): List of TinyViT blocks that make up this layer.\n downsample (nn.Module | None): Downsample layer at the end of the layer, if specified.\n\nExamples:\n >>> input_tensor = torch.randn(1, 3136, 192)\n >>> layer = BasicLayer(dim=192, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)\n >>> output = layer(input_tensor)\n >>> print(output.shape)\n torch.Size([1, 784, 384])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"itertools",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.utils.instance.to_2tuple",
"nn.Module"
],
"chunk_id": "class_BasicLayer_d9b97223"
},
{
"content": "class TinyViT(nn.Module):\n \"\"\"\n TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.\n\n This class implements the TinyViT model, which combines elements of vision transformers and convolutional\n neural networks for improved efficiency and performance on vision tasks. It features hierarchical processing\n with patch embedding, multiple stages of attention and convolution blocks, and a feature refinement neck.\n\n Attributes:\n img_size (int): Input image size.\n num_classes (int): Number of classification classes.\n depths (Tuple[int, int, int, int]): Number of blocks in each stage.\n num_layers (int): Total number of layers in the network.\n mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.\n patch_embed (PatchEmbed): Module for patch embedding.\n patches_resolution (Tuple[int, int]): Resolution of embedded patches.\n layers (nn.ModuleList): List of network layers.\n norm_head (nn.LayerNorm): Layer normalization for the classifier head.\n head (nn.Linear): Linear layer for final classification.\n neck (nn.Sequential): Neck module for feature refinement.\n\n Examples:\n >>> model = TinyViT(img_size=224, num_classes=1000)\n >>> x = torch.randn(1, 3, 224, 224)\n >>> features = model.forward_features(x)\n >>> print(features.shape)\n torch.Size([1, 256, 56, 56])\n \"\"\"\n\n def __init__(\n self,\n img_size: int = 224,\n in_chans: int = 3,\n num_classes: int = 1000,\n embed_dims: Tuple[int, int, int, int] = (96, 192, 384, 768),\n depths: Tuple[int, int, int, int] = (2, 2, 6, 2),\n num_heads: Tuple[int, int, int, int] = (3, 6, 12, 24),\n window_sizes: Tuple[int, int, int, int] = (7, 7, 14, 7),\n mlp_ratio: float = 4.0,\n drop_rate: float = 0.0,\n drop_path_rate: float = 0.1,\n use_checkpoint: bool = False,\n mbconv_expand_ratio: float = 4.0,\n local_conv_size: int = 3,\n layer_lr_decay: float = 1.0,\n ):\n \"\"\"\n Initialize the TinyViT model.\n\n This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of\n attention and convolution blocks, and a classification head.\n\n Args:\n img_size (int, optional): Size of the input image.\n in_chans (int, optional): Number of input channels.\n num_classes (int, optional): Number of classes for classification.\n embed_dims (Tuple[int, int, int, int], optional): Embedding dimensions for each stage.\n depths (Tuple[int, int, int, int], optional): Number of blocks in each stage.\n num_heads (Tuple[int, int, int, int], optional): Number of attention heads in each stage.\n window_sizes (Tuple[int, int, int, int], optional): Window sizes for each stage.\n mlp_ratio (float, optional): Ratio of MLP hidden dim to embedding dim.\n drop_rate (float, optional): Dropout rate.\n drop_path_rate (float, optional): Stochastic depth rate.\n use_checkpoint (bool, optional): Whether to use checkpointing to save memory.\n mbconv_expand_ratio (float, optional): Expansion ratio for MBConv layer.\n local_conv_size (int, optional): Kernel size for local convolutions.\n layer_lr_decay (float, optional): Layer-wise learning rate decay factor.\n \"\"\"\n super().__init__()\n self.img_size = img_size\n self.num_classes = num_classes\n self.depths = depths\n self.num_layers = len(depths)\n self.mlp_ratio = mlp_ratio\n\n activation = nn.GELU\n\n self.patch_embed = PatchEmbed(\n in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation\n )\n\n patches_resolution = self.patch_embed.patches_resolution\n 
self.patches_resolution = patches_resolution\n\n # Stochastic depth\n dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule\n\n # Build layers\n self.layers = nn.ModuleList()\n for i_layer in range(self.num_layers):\n kwargs = dict(\n dim=embed_dims[i_layer],\n input_resolution=(\n patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),\n patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),\n ),\n # input_resolution=(patches_resolution[0] // (2 ** i_layer),\n # patches_resolution[1] // (2 ** i_layer)),\n depth=depths[i_layer],\n drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],\n downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n use_checkpoint=use_checkpoint,\n out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],\n activation=activation,\n )\n if i_layer == 0:\n layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)\n else:\n layer = BasicLayer(\n num_heads=num_heads[i_layer],\n window_size=window_sizes[i_layer],\n mlp_ratio=self.mlp_ratio,\n drop=drop_rate,\n local_conv_size=local_conv_size,\n **kwargs,\n )\n self.layers.append(layer)\n\n # Classifier head\n self.norm_head = nn.LayerNorm(embed_dims[-1])\n self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()\n\n # Init weights\n self.apply(self._init_weights)\n self.set_layer_lr_decay(layer_lr_decay)\n self.neck = nn.Sequential(\n nn.Conv2d(\n embed_dims[-1],\n 256,\n kernel_size=1,\n bias=False,\n ),\n LayerNorm2d(256),\n nn.Conv2d(\n 256,\n 256,\n kernel_size=3,\n padding=1,\n bias=False,\n ),\n LayerNorm2d(256),\n )\n\n def set_layer_lr_decay(self, layer_lr_decay: float):\n \"\"\"Set layer-wise learning rate decay for the TinyViT model based on depth.\"\"\"\n decay_rate = layer_lr_decay\n\n # Layers -> blocks (depth)\n depth = sum(self.depths)\n lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]\n\n def _set_lr_scale(m, scale):\n \"\"\"Set the learning rate scale for each layer in the model based on the layer's depth.\"\"\"\n for p in m.parameters():\n p.lr_scale = scale\n\n self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))\n i = 0\n for layer in self.layers:\n for block in layer.blocks:\n block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))\n i += 1\n if layer.downsample is not None:\n layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1]))\n assert i == depth\n for m in {self.norm_head, self.head}:\n m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))\n\n for k, p in self.named_parameters():\n p.param_name = k\n\n def _check_lr_scale(m):\n \"\"\"Check if the learning rate scale attribute is present in module's parameters.\"\"\"\n for p in m.parameters():\n assert hasattr(p, \"lr_scale\"), p.param_name\n\n self.apply(_check_lr_scale)\n\n @staticmethod\n def _init_weights(m):\n \"\"\"Initialize weights for linear and normalization layers in the TinyViT model.\"\"\"\n if isinstance(m, nn.Linear):\n # NOTE: This initialization is needed only for training.\n # trunc_normal_(m.weight, std=.02)\n if m.bias is not None:\n nn.init.constant_(m.bias, 0)\n elif isinstance(m, nn.LayerNorm):\n nn.init.constant_(m.bias, 0)\n nn.init.constant_(m.weight, 1.0)\n\n @torch.jit.ignore\n def no_weight_decay_keywords(self):\n \"\"\"Return a set of keywords for parameters that should not use weight decay.\"\"\"\n return {\"attention_biases\"}\n\n def forward_features(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Process input 
through feature extraction layers, returning spatial features.\"\"\"\n x = self.patch_embed(x) # x input is (N, C, H, W)\n\n x = self.layers[0](x)\n start_i = 1\n\n for i in range(start_i, len(self.layers)):\n layer = self.layers[i]\n x = layer(x)\n batch, _, channel = x.shape\n x = x.view(batch, self.patches_resolution[0] // 4, self.patches_resolution[1] // 4, channel)\n x = x.permute(0, 3, 1, 2)\n return self.neck(x)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Perform the forward pass through the TinyViT model, extracting features from the input image.\"\"\"\n return self.forward_features(x)\n\n def set_imgsz(self, imgsz: List[int] = [1024, 1024]):\n \"\"\"Set image size to make model compatible with different image sizes.\"\"\"\n imgsz = [s // 4 for s in imgsz]\n self.patches_resolution = imgsz\n for i, layer in enumerate(self.layers):\n input_resolution = (\n imgsz[0] // (2 ** (i - 1 if i == 3 else i)),\n imgsz[1] // (2 ** (i - 1 if i == 3 else i)),\n )\n layer.input_resolution = input_resolution\n if layer.downsample is not None:\n layer.downsample.input_resolution = input_resolution\n if isinstance(layer, BasicLayer):\n for b in layer.blocks:\n b.input_resolution = input_resolution",
"chunk_type": "class",
"name": "TinyViT",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\tiny_encoder.py",
"start_line": 769,
"end_line": 997,
"start_col": 0,
"end_col": 57,
"parent_name": null,
"docstring": "TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.\n\nThis class implements the TinyViT model, which combines elements of vision transformers and convolutional\nneural networks for improved efficiency and performance on vision tasks. It features hierarchical processing\nwith patch embedding, multiple stages of attention and convolution blocks, and a feature refinement neck.\n\nAttributes:\n img_size (int): Input image size.\n num_classes (int): Number of classification classes.\n depths (Tuple[int, int, int, int]): Number of blocks in each stage.\n num_layers (int): Total number of layers in the network.\n mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.\n patch_embed (PatchEmbed): Module for patch embedding.\n patches_resolution (Tuple[int, int]): Resolution of embedded patches.\n layers (nn.ModuleList): List of network layers.\n norm_head (nn.LayerNorm): Layer normalization for the classifier head.\n head (nn.Linear): Linear layer for final classification.\n neck (nn.Sequential): Neck module for feature refinement.\n\nExamples:\n >>> model = TinyViT(img_size=224, num_classes=1000)\n >>> x = torch.randn(1, 3, 224, 224)\n >>> features = model.forward_features(x)\n >>> print(features.shape)\n torch.Size([1, 256, 56, 56])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"itertools",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.nn.modules.LayerNorm2d",
"ultralytics.utils.instance.to_2tuple",
"nn.Module"
],
"chunk_id": "class_TinyViT_bdde4483"
},
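The `set_layer_lr_decay` method in the `TinyViT` chunk assigns each block a learning-rate scale of `decay_rate ** (depth - i - 1)`, so earlier blocks get smaller scales than later ones. A small sketch of just that schedule (the 0.8 decay value is an assumption for illustration; the chunk's default `layer_lr_decay` is 1.0, which disables the decay):

```python
depths = (2, 2, 6, 2)  # default TinyViT stage depths from the chunk
decay_rate = 0.8       # assumed example value
depth = sum(depths)

lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
print(round(lr_scales[0], 4), round(lr_scales[-1], 4))  # ~0.0859 for the first block, 1.0 for the last
```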
{
"content": "import math",
"chunk_type": "import",
"name": "math",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\transformer.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_math_1a0c7fce"
},
{
"content": "from typing import Tuple, Type",
"chunk_type": "import",
"name": "Tuple, Type",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\transformer.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 30,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Tuple, Type_20611453"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\transformer.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_ed6be518"
},
{
"content": "from torch import Tensor, nn",
"chunk_type": "import",
"name": "Tensor, nn",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\transformer.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Tensor, nn_ba6ccb03"
},
{
"content": "from ultralytics.nn.modules import MLPBlock",
"chunk_type": "import",
"name": "MLPBlock",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\transformer.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 43,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_MLPBlock_8903c4ea"
},
{
"content": "class TwoWayTransformer(nn.Module):\n \"\"\"\n A Two-Way Transformer module for simultaneous attention to image and query points.\n\n This class implements a specialized transformer decoder that attends to an input image using queries with\n supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point\n cloud processing.\n\n Attributes:\n depth (int): Number of layers in the transformer.\n embedding_dim (int): Channel dimension for input embeddings.\n num_heads (int): Number of heads for multihead attention.\n mlp_dim (int): Internal channel dimension for the MLP block.\n layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer.\n final_attn_token_to_image (Attention): Final attention layer from queries to image.\n norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.\n\n Methods:\n forward: Process image and point embeddings through the transformer.\n\n Examples:\n >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)\n >>> image_embedding = torch.randn(1, 256, 32, 32)\n >>> image_pe = torch.randn(1, 256, 32, 32)\n >>> point_embedding = torch.randn(1, 100, 256)\n >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)\n >>> print(output_queries.shape, output_image.shape)\n \"\"\"\n\n def __init__(\n self,\n depth: int,\n embedding_dim: int,\n num_heads: int,\n mlp_dim: int,\n activation: Type[nn.Module] = nn.ReLU,\n attention_downsample_rate: int = 2,\n ) -> None:\n \"\"\"\n Initialize a Two-Way Transformer for simultaneous attention to image and query points.\n\n Args:\n depth (int): Number of layers in the transformer.\n embedding_dim (int): Channel dimension for input embeddings.\n num_heads (int): Number of heads for multihead attention. 
Must divide embedding_dim.\n mlp_dim (int): Internal channel dimension for the MLP block.\n activation (Type[nn.Module], optional): Activation function to use in the MLP block.\n attention_downsample_rate (int, optional): Downsampling rate for attention mechanism.\n \"\"\"\n super().__init__()\n self.depth = depth\n self.embedding_dim = embedding_dim\n self.num_heads = num_heads\n self.mlp_dim = mlp_dim\n self.layers = nn.ModuleList()\n\n for i in range(depth):\n self.layers.append(\n TwoWayAttentionBlock(\n embedding_dim=embedding_dim,\n num_heads=num_heads,\n mlp_dim=mlp_dim,\n activation=activation,\n attention_downsample_rate=attention_downsample_rate,\n skip_first_layer_pe=(i == 0),\n )\n )\n\n self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)\n self.norm_final_attn = nn.LayerNorm(embedding_dim)\n\n def forward(\n self,\n image_embedding: torch.Tensor,\n image_pe: torch.Tensor,\n point_embedding: torch.Tensor,\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"\n Process image and point embeddings through the Two-Way Transformer.\n\n Args:\n image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).\n image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding.\n point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).\n\n Returns:\n queries (torch.Tensor): Processed point embeddings with shape (B, N_points, embedding_dim).\n keys (torch.Tensor): Processed image embeddings with shape (B, H*W, embedding_dim).\n \"\"\"\n # BxCxHxW -> BxHWxC == B x N_image_tokens x C\n image_embedding = image_embedding.flatten(2).permute(0, 2, 1)\n image_pe = image_pe.flatten(2).permute(0, 2, 1)\n\n # Prepare queries\n queries = point_embedding\n keys = image_embedding\n\n # Apply transformer blocks and final layernorm\n for layer in self.layers:\n queries, keys = layer(\n queries=queries,\n keys=keys,\n query_pe=point_embedding,\n key_pe=image_pe,\n )\n\n # Apply the final attention layer from the points to the image\n q = queries + point_embedding\n k = keys + image_pe\n attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)\n queries = queries + attn_out\n queries = self.norm_final_attn(queries)\n\n return queries, keys",
"chunk_type": "class",
"name": "TwoWayTransformer",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\transformer.py",
"start_line": 12,
"end_line": 125,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": "A Two-Way Transformer module for simultaneous attention to image and query points.\n\nThis class implements a specialized transformer decoder that attends to an input image using queries with\nsupplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point\ncloud processing.\n\nAttributes:\n depth (int): Number of layers in the transformer.\n embedding_dim (int): Channel dimension for input embeddings.\n num_heads (int): Number of heads for multihead attention.\n mlp_dim (int): Internal channel dimension for the MLP block.\n layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer.\n final_attn_token_to_image (Attention): Final attention layer from queries to image.\n norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.\n\nMethods:\n forward: Process image and point embeddings through the transformer.\n\nExamples:\n >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)\n >>> image_embedding = torch.randn(1, 256, 32, 32)\n >>> image_pe = torch.randn(1, 256, 32, 32)\n >>> point_embedding = torch.randn(1, 100, 256)\n >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)\n >>> print(output_queries.shape, output_image.shape)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.Tuple",
"typing.Type",
"torch",
"torch.Tensor",
"torch.nn",
"ultralytics.nn.modules.MLPBlock",
"nn.Module"
],
"chunk_id": "class_TwoWayTransformer_15517ad5"
},
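Before any attention is applied, the `TwoWayTransformer` chunk flattens the (B, C, H, W) image embedding and its positional encoding into token sequences. A minimal sketch of that layout change, reusing the shapes from the chunk's own docstring example:

```python
import torch

image_embedding = torch.randn(1, 256, 32, 32)          # (B, C, H, W)
tokens = image_embedding.flatten(2).permute(0, 2, 1)   # (B, H*W, C)
print(tokens.shape)  # torch.Size([1, 1024, 256]) -> 1024 image tokens of width 256
```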
{
"content": "class TwoWayAttentionBlock(nn.Module):\n \"\"\"\n A two-way attention block for simultaneous attention to image and query points.\n\n This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,\n cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense\n inputs to sparse inputs.\n\n Attributes:\n self_attn (Attention): Self-attention layer for queries.\n norm1 (nn.LayerNorm): Layer normalization after self-attention.\n cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.\n norm2 (nn.LayerNorm): Layer normalization after token-to-image attention.\n mlp (MLPBlock): MLP block for transforming query embeddings.\n norm3 (nn.LayerNorm): Layer normalization after MLP block.\n norm4 (nn.LayerNorm): Layer normalization after image-to-token attention.\n cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.\n skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.\n\n Methods:\n forward: Apply self-attention and cross-attention to queries and keys.\n\n Examples:\n >>> embedding_dim, num_heads = 256, 8\n >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)\n >>> queries = torch.randn(1, 100, embedding_dim)\n >>> keys = torch.randn(1, 1000, embedding_dim)\n >>> query_pe = torch.randn(1, 100, embedding_dim)\n >>> key_pe = torch.randn(1, 1000, embedding_dim)\n >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)\n \"\"\"\n\n def __init__(\n self,\n embedding_dim: int,\n num_heads: int,\n mlp_dim: int = 2048,\n activation: Type[nn.Module] = nn.ReLU,\n attention_downsample_rate: int = 2,\n skip_first_layer_pe: bool = False,\n ) -> None:\n \"\"\"\n Initialize a TwoWayAttentionBlock for simultaneous attention to image and query points.\n\n This block implements a specialized transformer layer with four main components: self-attention on sparse\n inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention\n of dense inputs to sparse inputs.\n\n Args:\n embedding_dim (int): Channel dimension of the embeddings.\n num_heads (int): Number of attention heads in the attention layers.\n mlp_dim (int, optional): Hidden dimension of the MLP block.\n activation (Type[nn.Module], optional): Activation function for the MLP block.\n attention_downsample_rate (int, optional): Downsampling rate for the attention mechanism.\n skip_first_layer_pe (bool, optional): Whether to skip positional encoding in the first layer.\n \"\"\"\n super().__init__()\n self.self_attn = Attention(embedding_dim, num_heads)\n self.norm1 = nn.LayerNorm(embedding_dim)\n\n self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)\n self.norm2 = nn.LayerNorm(embedding_dim)\n\n self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)\n self.norm3 = nn.LayerNorm(embedding_dim)\n\n self.norm4 = nn.LayerNorm(embedding_dim)\n self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)\n\n self.skip_first_layer_pe = skip_first_layer_pe\n\n def forward(\n self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"\n Apply two-way attention to process query and key embeddings in a transformer block.\n\n Args:\n queries (torch.Tensor): Query embeddings with shape (B, N_queries, embedding_dim).\n keys 
(torch.Tensor): Key embeddings with shape (B, N_keys, embedding_dim).\n query_pe (torch.Tensor): Positional encodings for queries with same shape as queries.\n key_pe (torch.Tensor): Positional encodings for keys with same shape as keys.\n\n Returns:\n queries (torch.Tensor): Processed query embeddings with shape (B, N_queries, embedding_dim).\n keys (torch.Tensor): Processed key embeddings with shape (B, N_keys, embedding_dim).\n \"\"\"\n # Self attention block\n if self.skip_first_layer_pe:\n queries = self.self_attn(q=queries, k=queries, v=queries)\n else:\n q = queries + query_pe\n attn_out = self.self_attn(q=q, k=q, v=queries)\n queries = queries + attn_out\n queries = self.norm1(queries)\n\n # Cross attention block, tokens attending to image embedding\n q = queries + query_pe\n k = keys + key_pe\n attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)\n queries = queries + attn_out\n queries = self.norm2(queries)\n\n # MLP block\n mlp_out = self.mlp(queries)\n queries = queries + mlp_out\n queries = self.norm3(queries)\n\n # Cross attention block, image embedding attending to tokens\n q = queries + query_pe\n k = keys + key_pe\n attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)\n keys = keys + attn_out\n keys = self.norm4(keys)\n\n return queries, keys",
"chunk_type": "class",
"name": "TwoWayAttentionBlock",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\transformer.py",
"start_line": 128,
"end_line": 243,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": "A two-way attention block for simultaneous attention to image and query points.\n\nThis class implements a specialized transformer block with four main layers: self-attention on sparse inputs,\ncross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense\ninputs to sparse inputs.\n\nAttributes:\n self_attn (Attention): Self-attention layer for queries.\n norm1 (nn.LayerNorm): Layer normalization after self-attention.\n cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.\n norm2 (nn.LayerNorm): Layer normalization after token-to-image attention.\n mlp (MLPBlock): MLP block for transforming query embeddings.\n norm3 (nn.LayerNorm): Layer normalization after MLP block.\n norm4 (nn.LayerNorm): Layer normalization after image-to-token attention.\n cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.\n skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.\n\nMethods:\n forward: Apply self-attention and cross-attention to queries and keys.\n\nExamples:\n >>> embedding_dim, num_heads = 256, 8\n >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)\n >>> queries = torch.randn(1, 100, embedding_dim)\n >>> keys = torch.randn(1, 1000, embedding_dim)\n >>> query_pe = torch.randn(1, 100, embedding_dim)\n >>> key_pe = torch.randn(1, 1000, embedding_dim)\n >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.Tuple",
"typing.Type",
"torch",
"torch.Tensor",
"torch.nn",
"ultralytics.nn.modules.MLPBlock",
"nn.Module"
],
"chunk_id": "class_TwoWayAttentionBlock_851bcd69"
},
{
"content": "class Attention(nn.Module):\n \"\"\"\n An attention layer with downscaling capability for embedding size after projection.\n\n This class implements a multi-head attention mechanism with the option to downsample the internal\n dimension of queries, keys, and values.\n\n Attributes:\n embedding_dim (int): Dimensionality of input embeddings.\n kv_in_dim (int): Dimensionality of key and value inputs.\n internal_dim (int): Internal dimension after downsampling.\n num_heads (int): Number of attention heads.\n q_proj (nn.Linear): Linear projection for queries.\n k_proj (nn.Linear): Linear projection for keys.\n v_proj (nn.Linear): Linear projection for values.\n out_proj (nn.Linear): Linear projection for output.\n\n Methods:\n _separate_heads: Separate input tensor into attention heads.\n _recombine_heads: Recombine separated attention heads.\n forward: Compute attention output for given query, key, and value tensors.\n\n Examples:\n >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)\n >>> q = torch.randn(1, 100, 256)\n >>> k = v = torch.randn(1, 50, 256)\n >>> output = attn(q, k, v)\n >>> print(output.shape)\n torch.Size([1, 100, 256])\n \"\"\"\n\n def __init__(\n self,\n embedding_dim: int,\n num_heads: int,\n downsample_rate: int = 1,\n kv_in_dim: int = None,\n ) -> None:\n \"\"\"\n Initialize the Attention module with specified dimensions and settings.\n\n Args:\n embedding_dim (int): Dimensionality of input embeddings.\n num_heads (int): Number of attention heads.\n downsample_rate (int, optional): Factor by which internal dimensions are downsampled.\n kv_in_dim (int | None, optional): Dimensionality of key and value inputs. If None, uses embedding_dim.\n\n Raises:\n AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).\n \"\"\"\n super().__init__()\n self.embedding_dim = embedding_dim\n self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim\n self.internal_dim = embedding_dim // downsample_rate\n self.num_heads = num_heads\n assert self.internal_dim % num_heads == 0, \"num_heads must divide embedding_dim.\"\n\n self.q_proj = nn.Linear(embedding_dim, self.internal_dim)\n self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)\n self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)\n self.out_proj = nn.Linear(self.internal_dim, embedding_dim)\n\n @staticmethod\n def _separate_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:\n \"\"\"Separate the input tensor into the specified number of attention heads.\"\"\"\n b, n, c = x.shape\n x = x.reshape(b, n, num_heads, c // num_heads)\n return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head\n\n @staticmethod\n def _recombine_heads(x: Tensor) -> Tensor:\n \"\"\"Recombine separated attention heads into a single tensor.\"\"\"\n b, n_heads, n_tokens, c_per_head = x.shape\n x = x.transpose(1, 2)\n return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C\n\n def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Apply multi-head attention to query, key, and value tensors with optional downsampling.\n\n Args:\n q (torch.Tensor): Query tensor with shape (B, N_q, embedding_dim).\n k (torch.Tensor): Key tensor with shape (B, N_k, embedding_dim).\n v (torch.Tensor): Value tensor with shape (B, N_k, embedding_dim).\n\n Returns:\n (torch.Tensor): Output tensor after attention with shape (B, N_q, embedding_dim).\n \"\"\"\n # Input projections\n q = self.q_proj(q)\n k = self.k_proj(k)\n v 
= self.v_proj(v)\n\n # Separate into heads\n q = self._separate_heads(q, self.num_heads)\n k = self._separate_heads(k, self.num_heads)\n v = self._separate_heads(v, self.num_heads)\n\n # Attention\n _, _, _, c_per_head = q.shape\n attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens\n attn = attn / math.sqrt(c_per_head)\n attn = torch.softmax(attn, dim=-1)\n\n # Get output\n out = attn @ v\n out = self._recombine_heads(out)\n return self.out_proj(out)",
"chunk_type": "class",
"name": "Attention",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\transformer.py",
"start_line": 246,
"end_line": 353,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": "An attention layer with downscaling capability for embedding size after projection.\n\nThis class implements a multi-head attention mechanism with the option to downsample the internal\ndimension of queries, keys, and values.\n\nAttributes:\n embedding_dim (int): Dimensionality of input embeddings.\n kv_in_dim (int): Dimensionality of key and value inputs.\n internal_dim (int): Internal dimension after downsampling.\n num_heads (int): Number of attention heads.\n q_proj (nn.Linear): Linear projection for queries.\n k_proj (nn.Linear): Linear projection for keys.\n v_proj (nn.Linear): Linear projection for values.\n out_proj (nn.Linear): Linear projection for output.\n\nMethods:\n _separate_heads: Separate input tensor into attention heads.\n _recombine_heads: Recombine separated attention heads.\n forward: Compute attention output for given query, key, and value tensors.\n\nExamples:\n >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)\n >>> q = torch.randn(1, 100, 256)\n >>> k = v = torch.randn(1, 50, 256)\n >>> output = attn(q, k, v)\n >>> print(output.shape)\n torch.Size([1, 100, 256])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.Tuple",
"typing.Type",
"torch",
"torch.Tensor",
"torch.nn",
"ultralytics.nn.modules.MLPBlock",
"nn.Module"
],
"chunk_id": "class_Attention_d0036d25"
},
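The projection and head bookkeeping in this `Attention` chunk is easiest to see on shapes: with `downsample_rate=2` the internal dimension is halved before the heads are split. A hedged sketch with synthetic tensors showing that `_separate_heads` and `_recombine_heads` are exact inverses:

```python
import torch

embedding_dim, num_heads, downsample_rate = 256, 8, 2
internal_dim = embedding_dim // downsample_rate   # 128, must be divisible by num_heads
q = torch.randn(1, 100, internal_dim)             # stands in for an already-projected query

b, n, c = q.shape
q_heads = q.reshape(b, n, num_heads, c // num_heads).transpose(1, 2)
print(q_heads.shape)  # torch.Size([1, 8, 100, 16]) -> B x N_heads x N_tokens x C_per_head

recombined = q_heads.transpose(1, 2).reshape(b, n, c)
assert torch.equal(recombined, q)
```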
{
"content": "from typing import Any, Dict, Tuple",
"chunk_type": "import",
"name": "Any, Dict, Tuple",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, Tuple_80431a72"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_7f36154e"
},
{
"content": "import torch.nn.functional as F",
"chunk_type": "import",
"name": "torch.nn.functional",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn.functional_659e6c2d"
},
{
"content": "def select_closest_cond_frames(frame_idx: int, cond_frame_outputs: Dict[int, Any], max_cond_frame_num: int):\n \"\"\"\n Select the closest conditioning frames to a given frame index.\n\n Args:\n frame_idx (int): Current frame index.\n cond_frame_outputs (Dict[int, Any]): Dictionary of conditioning frame outputs keyed by frame indices.\n max_cond_frame_num (int): Maximum number of conditioning frames to select.\n\n Returns:\n selected_outputs (Dict[int, Any]): Selected items from cond_frame_outputs.\n unselected_outputs (Dict[int, Any]): Items not selected from cond_frame_outputs.\n\n Examples:\n >>> frame_idx = 5\n >>> cond_frame_outputs = {1: \"a\", 3: \"b\", 7: \"c\", 9: \"d\"}\n >>> max_cond_frame_num = 2\n >>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)\n >>> print(selected)\n {3: 'b', 7: 'c'}\n >>> print(unselected)\n {1: 'a', 9: 'd'}\n \"\"\"\n if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:\n selected_outputs = cond_frame_outputs\n unselected_outputs = {}\n else:\n assert max_cond_frame_num >= 2, \"we should allow using 2+ conditioning frames\"\n selected_outputs = {}\n\n # The closest conditioning frame before `frame_idx` (if any)\n idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)\n if idx_before is not None:\n selected_outputs[idx_before] = cond_frame_outputs[idx_before]\n\n # The closest conditioning frame after `frame_idx` (if any)\n idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)\n if idx_after is not None:\n selected_outputs[idx_after] = cond_frame_outputs[idx_after]\n\n # Add other temporally closest conditioning frames until reaching a total\n # of `max_cond_frame_num` conditioning frames.\n num_remain = max_cond_frame_num - len(selected_outputs)\n inds_remain = sorted(\n (t for t in cond_frame_outputs if t not in selected_outputs),\n key=lambda x: abs(x - frame_idx),\n )[:num_remain]\n selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)\n unselected_outputs = {t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs}\n\n return selected_outputs, unselected_outputs",
"chunk_type": "function",
"name": "select_closest_cond_frames",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py",
"start_line": 9,
"end_line": 59,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": "Select the closest conditioning frames to a given frame index.\n\nArgs:\n frame_idx (int): Current frame index.\n cond_frame_outputs (Dict[int, Any]): Dictionary of conditioning frame outputs keyed by frame indices.\n max_cond_frame_num (int): Maximum number of conditioning frames to select.\n\nReturns:\n selected_outputs (Dict[int, Any]): Selected items from cond_frame_outputs.\n unselected_outputs (Dict[int, Any]): Items not selected from cond_frame_outputs.\n\nExamples:\n >>> frame_idx = 5\n >>> cond_frame_outputs = {1: \"a\", 3: \"b\", 7: \"c\", 9: \"d\"}\n >>> max_cond_frame_num = 2\n >>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)\n >>> print(selected)\n {3: 'b', 7: 'c'}\n >>> print(unselected)\n {1: 'a', 9: 'd'}",
"parameters": [
"frame_idx: int",
"cond_frame_outputs: Dict[int, Any]",
"max_cond_frame_num: int"
],
"return_type": null,
"decorators": [],
"complexity_score": 9,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.Tuple",
"torch",
"torch.nn.functional"
],
"chunk_id": "function_select_closest_cond_frames_f98c91cb"
},
{
"content": "def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000):\n \"\"\"\n Generate 1D sinusoidal positional embeddings for given positions and dimensions.\n\n Args:\n pos_inds (torch.Tensor): Position indices for which to generate embeddings.\n dim (int): Dimension of the positional embeddings. Should be an even number.\n temperature (float, optional): Scaling factor for the frequency of the sinusoidal functions.\n\n Returns:\n (torch.Tensor): Sinusoidal positional embeddings with shape (pos_inds.shape, dim).\n\n Examples:\n >>> pos = torch.tensor([0, 1, 2, 3])\n >>> embeddings = get_1d_sine_pe(pos, 128)\n >>> embeddings.shape\n torch.Size([4, 128])\n \"\"\"\n pe_dim = dim // 2\n dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)\n dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)\n\n pos_embed = pos_inds.unsqueeze(-1) / dim_t\n pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)\n return pos_embed",
"chunk_type": "function",
"name": "get_1d_sine_pe",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py",
"start_line": 62,
"end_line": 86,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Generate 1D sinusoidal positional embeddings for given positions and dimensions.\n\nArgs:\n pos_inds (torch.Tensor): Position indices for which to generate embeddings.\n dim (int): Dimension of the positional embeddings. Should be an even number.\n temperature (float, optional): Scaling factor for the frequency of the sinusoidal functions.\n\nReturns:\n (torch.Tensor): Sinusoidal positional embeddings with shape (pos_inds.shape, dim).\n\nExamples:\n >>> pos = torch.tensor([0, 1, 2, 3])\n >>> embeddings = get_1d_sine_pe(pos, 128)\n >>> embeddings.shape\n torch.Size([4, 128])",
"parameters": [
"pos_inds: torch.Tensor",
"dim: int",
"temperature: float"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.Tuple",
"torch",
"torch.nn.functional"
],
"chunk_id": "function_get_1d_sine_pe_0de510ca"
},
{
"content": "def init_t_xy(end_x: int, end_y: int):\n \"\"\"\n Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.\n\n This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index tensor\n and corresponding x and y coordinate tensors.\n\n Args:\n end_x (int): Width of the grid (number of columns).\n end_y (int): Height of the grid (number of rows).\n\n Returns:\n t_x (torch.Tensor): X-coordinates for each position, with shape (end_x * end_y).\n t_y (torch.Tensor): Y-coordinates for each position, with shape (end_x * end_y).\n\n Examples:\n >>> t_x, t_y = init_t_xy(3, 2)\n >>> print(t_x)\n tensor([0., 1., 2., 0., 1., 2.])\n >>> print(t_y)\n tensor([0., 0., 0., 1., 1., 1.])\n \"\"\"\n t = torch.arange(end_x * end_y, dtype=torch.float32)\n t_x = (t % end_x).float()\n t_y = torch.div(t, end_x, rounding_mode=\"floor\").float()\n return t_x, t_y",
"chunk_type": "function",
"name": "init_t_xy",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py",
"start_line": 89,
"end_line": 114,
"start_col": 0,
"end_col": 19,
"parent_name": null,
"docstring": "Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.\n\nThis function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index tensor\nand corresponding x and y coordinate tensors.\n\nArgs:\n end_x (int): Width of the grid (number of columns).\n end_y (int): Height of the grid (number of rows).\n\nReturns:\n t_x (torch.Tensor): X-coordinates for each position, with shape (end_x * end_y).\n t_y (torch.Tensor): Y-coordinates for each position, with shape (end_x * end_y).\n\nExamples:\n >>> t_x, t_y = init_t_xy(3, 2)\n >>> print(t_x)\n tensor([0., 1., 2., 0., 1., 2.])\n >>> print(t_y)\n tensor([0., 0., 0., 1., 1., 1.])",
"parameters": [
"end_x: int",
"end_y: int"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.Tuple",
"torch",
"torch.nn.functional"
],
"chunk_id": "function_init_t_xy_c89e4fba"
},
{
"content": "def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):\n \"\"\"\n Compute axial complex exponential positional encodings for 2D spatial positions in a grid.\n\n This function generates complex exponential positional encodings for a 2D grid of spatial positions,\n using separate frequency components for the x and y dimensions.\n\n Args:\n dim (int): Dimension of the positional encoding.\n end_x (int): Width of the 2D grid.\n end_y (int): Height of the 2D grid.\n theta (float, optional): Scaling factor for frequency computation.\n\n Returns:\n (torch.Tensor): Complex exponential positional encodings with shape (end_x*end_y, dim//2).\n\n Examples:\n >>> dim, end_x, end_y = 128, 8, 8\n >>> freqs_cis = compute_axial_cis(dim, end_x, end_y)\n >>> freqs_cis.shape\n torch.Size([64, 64])\n \"\"\"\n freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))\n freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))\n\n t_x, t_y = init_t_xy(end_x, end_y)\n freqs_x = torch.outer(t_x, freqs_x)\n freqs_y = torch.outer(t_y, freqs_y)\n freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)\n freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)\n return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)",
"chunk_type": "function",
"name": "compute_axial_cis",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py",
"start_line": 117,
"end_line": 147,
"start_col": 0,
"end_col": 56,
"parent_name": null,
"docstring": "Compute axial complex exponential positional encodings for 2D spatial positions in a grid.\n\nThis function generates complex exponential positional encodings for a 2D grid of spatial positions,\nusing separate frequency components for the x and y dimensions.\n\nArgs:\n dim (int): Dimension of the positional encoding.\n end_x (int): Width of the 2D grid.\n end_y (int): Height of the 2D grid.\n theta (float, optional): Scaling factor for frequency computation.\n\nReturns:\n (torch.Tensor): Complex exponential positional encodings with shape (end_x*end_y, dim//2).\n\nExamples:\n >>> dim, end_x, end_y = 128, 8, 8\n >>> freqs_cis = compute_axial_cis(dim, end_x, end_y)\n >>> freqs_cis.shape\n torch.Size([64, 64])",
"parameters": [
"dim: int",
"end_x: int",
"end_y: int",
"theta: float"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.Tuple",
"torch",
"torch.nn.functional"
],
"chunk_id": "function_compute_axial_cis_723b9554"
},
{
"content": "def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):\n \"\"\"\n Reshape frequency tensor for broadcasting with input tensor.\n\n Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor.\n This function is typically used in positional encoding operations.\n\n Args:\n freqs_cis (torch.Tensor): Frequency tensor with shape matching the last two dimensions of x.\n x (torch.Tensor): Input tensor to broadcast with.\n\n Returns:\n (torch.Tensor): Reshaped frequency tensor ready for broadcasting with the input tensor.\n\n Raises:\n AssertionError: If the shape of freqs_cis doesn't match the last two dimensions of x.\n \"\"\"\n ndim = x.ndim\n assert 0 <= 1 < ndim\n assert freqs_cis.shape == (x.shape[-2], x.shape[-1])\n shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]\n return freqs_cis.view(*shape)",
"chunk_type": "function",
"name": "reshape_for_broadcast",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py",
"start_line": 150,
"end_line": 171,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": "Reshape frequency tensor for broadcasting with input tensor.\n\nReshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor.\nThis function is typically used in positional encoding operations.\n\nArgs:\n freqs_cis (torch.Tensor): Frequency tensor with shape matching the last two dimensions of x.\n x (torch.Tensor): Input tensor to broadcast with.\n\nReturns:\n (torch.Tensor): Reshaped frequency tensor ready for broadcasting with the input tensor.\n\nRaises:\n AssertionError: If the shape of freqs_cis doesn't match the last two dimensions of x.",
"parameters": [
"freqs_cis: torch.Tensor",
"x: torch.Tensor"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.Tuple",
"torch",
"torch.nn.functional"
],
"chunk_id": "function_reshape_for_broadcast_c523e5cb"
},
{
"content": "def apply_rotary_enc(\n xq: torch.Tensor,\n xk: torch.Tensor,\n freqs_cis: torch.Tensor,\n repeat_freqs_k: bool = False,\n):\n \"\"\"\n Apply rotary positional encoding to query and key tensors.\n\n This function applies rotary positional encoding (RoPE) to query and key tensors using complex-valued frequency\n components. RoPE is a technique that injects relative position information into self-attention mechanisms.\n\n Args:\n xq (torch.Tensor): Query tensor to encode with positional information.\n xk (torch.Tensor): Key tensor to encode with positional information.\n freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the\n last two dimensions of xq.\n repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension\n to match key sequence length.\n\n Returns:\n xq_out (torch.Tensor): Query tensor with rotary positional encoding applied.\n xk_out (torch.Tensor): Key tensor with rotary positional encoding applied, or original xk if xk is empty.\n\n Examples:\n >>> import torch\n >>> xq = torch.randn(2, 8, 16, 64) # [batch, heads, seq_len, dim]\n >>> xk = torch.randn(2, 8, 16, 64)\n >>> freqs_cis = compute_axial_cis(64, 4, 4) # For a 4x4 spatial grid with dim=64\n >>> q_encoded, k_encoded = apply_rotary_enc(xq, xk, freqs_cis)\n \"\"\"\n xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))\n xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None\n freqs_cis = reshape_for_broadcast(freqs_cis, xq_)\n xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)\n if xk_ is None:\n # No keys to rotate, due to dropout\n return xq_out.type_as(xq).to(xq.device), xk\n # Repeat freqs along seq_len dim to match k seq_len\n if repeat_freqs_k:\n r = xk_.shape[-2] // xq_.shape[-2]\n freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)\n xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)\n return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)",
"chunk_type": "function",
"name": "apply_rotary_enc",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py",
"start_line": 174,
"end_line": 217,
"start_col": 0,
"end_col": 77,
"parent_name": null,
"docstring": "Apply rotary positional encoding to query and key tensors.\n\nThis function applies rotary positional encoding (RoPE) to query and key tensors using complex-valued frequency\ncomponents. RoPE is a technique that injects relative position information into self-attention mechanisms.\n\nArgs:\n xq (torch.Tensor): Query tensor to encode with positional information.\n xk (torch.Tensor): Key tensor to encode with positional information.\n freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the\n last two dimensions of xq.\n repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension\n to match key sequence length.\n\nReturns:\n xq_out (torch.Tensor): Query tensor with rotary positional encoding applied.\n xk_out (torch.Tensor): Key tensor with rotary positional encoding applied, or original xk if xk is empty.\n\nExamples:\n >>> import torch\n >>> xq = torch.randn(2, 8, 16, 64) # [batch, heads, seq_len, dim]\n >>> xk = torch.randn(2, 8, 16, 64)\n >>> freqs_cis = compute_axial_cis(64, 4, 4) # For a 4x4 spatial grid with dim=64\n >>> q_encoded, k_encoded = apply_rotary_enc(xq, xk, freqs_cis)",
"parameters": [
"xq: torch.Tensor",
"xk: torch.Tensor",
"freqs_cis: torch.Tensor",
"repeat_freqs_k: bool"
],
"return_type": null,
"decorators": [],
"complexity_score": 3,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.Tuple",
"torch",
"torch.nn.functional"
],
"chunk_id": "function_apply_rotary_enc_5a2215dd"
},
{
"content": "def window_partition(x: torch.Tensor, window_size: int):\n \"\"\"\n Partition input tensor into non-overlapping windows with padding if needed.\n\n Args:\n x (torch.Tensor): Input tensor with shape (B, H, W, C).\n window_size (int): Size of each window.\n\n Returns:\n windows (torch.Tensor): Partitioned windows with shape (B * num_windows, window_size, window_size, C).\n padded_h_w (Tuple[int, int]): Padded height and width before partition.\n\n Examples:\n >>> x = torch.randn(1, 16, 16, 3)\n >>> windows, (Hp, Wp) = window_partition(x, window_size=4)\n >>> print(windows.shape, Hp, Wp)\n torch.Size([16, 4, 4, 3]) 16 16\n \"\"\"\n B, H, W, C = x.shape\n\n pad_h = (window_size - H % window_size) % window_size\n pad_w = (window_size - W % window_size) % window_size\n if pad_h > 0 or pad_w > 0:\n x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))\n Hp, Wp = H + pad_h, W + pad_w\n\n x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)\n windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n return windows, (Hp, Wp)",
"chunk_type": "function",
"name": "window_partition",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py",
"start_line": 220,
"end_line": 248,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": "Partition input tensor into non-overlapping windows with padding if needed.\n\nArgs:\n x (torch.Tensor): Input tensor with shape (B, H, W, C).\n window_size (int): Size of each window.\n\nReturns:\n windows (torch.Tensor): Partitioned windows with shape (B * num_windows, window_size, window_size, C).\n padded_h_w (Tuple[int, int]): Padded height and width before partition.\n\nExamples:\n >>> x = torch.randn(1, 16, 16, 3)\n >>> windows, (Hp, Wp) = window_partition(x, window_size=4)\n >>> print(windows.shape, Hp, Wp)\n torch.Size([16, 4, 4, 3]) 16 16",
"parameters": [
"x: torch.Tensor",
"window_size: int"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.Tuple",
"torch",
"torch.nn.functional"
],
"chunk_id": "function_window_partition_c19df743"
},
{
"content": "def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]):\n \"\"\"\n Unpartition windowed sequences into original sequences and remove padding.\n\n This function reverses the windowing process, reconstructing the original input from windowed segments\n and removing any padding that was added during the windowing process.\n\n Args:\n windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size,\n window_size, C), where B is the batch size, num_windows is the number of windows, window_size is\n the size of each window, and C is the number of channels.\n window_size (int): Size of each window.\n pad_hw (Tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing.\n hw (Tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.\n\n Returns:\n (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W\n are the original height and width, and C is the number of channels.\n\n Examples:\n >>> windows = torch.rand(32, 8, 8, 64) # 32 windows of size 8x8 with 64 channels\n >>> pad_hw = (16, 16) # Padded height and width\n >>> hw = (15, 14) # Original height and width\n >>> x = window_unpartition(windows, window_size=8, pad_hw=pad_hw, hw=hw)\n >>> print(x.shape)\n torch.Size([1, 15, 14, 64])\n \"\"\"\n Hp, Wp = pad_hw\n H, W = hw\n B = windows.shape[0] // (Hp * Wp // window_size // window_size)\n x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)\n x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)\n\n if Hp > H or Wp > W:\n x = x[:, :H, :W, :].contiguous()\n return x",
"chunk_type": "function",
"name": "window_unpartition",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py",
"start_line": 251,
"end_line": 286,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": "Unpartition windowed sequences into original sequences and remove padding.\n\nThis function reverses the windowing process, reconstructing the original input from windowed segments\nand removing any padding that was added during the windowing process.\n\nArgs:\n windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size,\n window_size, C), where B is the batch size, num_windows is the number of windows, window_size is\n the size of each window, and C is the number of channels.\n window_size (int): Size of each window.\n pad_hw (Tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing.\n hw (Tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.\n\nReturns:\n (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W\n are the original height and width, and C is the number of channels.\n\nExamples:\n >>> windows = torch.rand(32, 8, 8, 64) # 32 windows of size 8x8 with 64 channels\n >>> pad_hw = (16, 16) # Padded height and width\n >>> hw = (15, 14) # Original height and width\n >>> x = window_unpartition(windows, window_size=8, pad_hw=pad_hw, hw=hw)\n >>> print(x.shape)\n torch.Size([1, 15, 14, 64])",
"parameters": [
"windows: torch.Tensor",
"window_size: int",
"pad_hw: Tuple[int, int]",
"hw: Tuple[int, int]"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.Tuple",
"torch",
"torch.nn.functional"
],
"chunk_id": "function_window_unpartition_c242b26b"
},
{
"content": "def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Extract relative positional embeddings based on query and key sizes.\n\n Args:\n q_size (int): Size of the query.\n k_size (int): Size of the key.\n rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative\n distance and C is the embedding dimension.\n\n Returns:\n (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,\n k_size, C).\n\n Examples:\n >>> q_size, k_size = 8, 16\n >>> rel_pos = torch.randn(31, 64) # 31 = 2 * max(8, 16) - 1\n >>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos)\n >>> print(extracted_pos.shape)\n torch.Size([8, 16, 64])\n \"\"\"\n max_rel_dist = int(2 * max(q_size, k_size) - 1)\n # Interpolate rel pos if needed.\n if rel_pos.shape[0] != max_rel_dist:\n # Interpolate rel pos.\n rel_pos_resized = F.interpolate(\n rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),\n size=max_rel_dist,\n mode=\"linear\",\n )\n rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)\n else:\n rel_pos_resized = rel_pos\n\n # Scale the coords with short length if shapes for q and k are different.\n q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)\n k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)\n relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)\n\n return rel_pos_resized[relative_coords.long()]",
"chunk_type": "function",
"name": "get_rel_pos",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py",
"start_line": 289,
"end_line": 328,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": "Extract relative positional embeddings based on query and key sizes.\n\nArgs:\n q_size (int): Size of the query.\n k_size (int): Size of the key.\n rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative\n distance and C is the embedding dimension.\n\nReturns:\n (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,\n k_size, C).\n\nExamples:\n >>> q_size, k_size = 8, 16\n >>> rel_pos = torch.randn(31, 64) # 31 = 2 * max(8, 16) - 1\n >>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos)\n >>> print(extracted_pos.shape)\n torch.Size([8, 16, 64])",
"parameters": [
"q_size: int",
"k_size: int",
"rel_pos: torch.Tensor"
],
"return_type": "torch.Tensor",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.Tuple",
"torch",
"torch.nn.functional"
],
"chunk_id": "function_get_rel_pos_ab6a89b4"
},
{
"content": "def add_decomposed_rel_pos(\n attn: torch.Tensor,\n q: torch.Tensor,\n rel_pos_h: torch.Tensor,\n rel_pos_w: torch.Tensor,\n q_size: Tuple[int, int],\n k_size: Tuple[int, int],\n) -> torch.Tensor:\n \"\"\"\n Add decomposed Relative Positional Embeddings to the attention map.\n\n This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2\n paper. It enhances the attention mechanism by incorporating spatial relationships between query and key\n positions.\n\n Args:\n attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w).\n q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C).\n rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C).\n rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C).\n q_size (Tuple[int, int]): Spatial sequence size of query q as (q_h, q_w).\n k_size (Tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).\n\n Returns:\n (torch.Tensor): Updated attention map with added relative positional embeddings, shape\n (B, q_h * q_w, k_h * k_w).\n\n Examples:\n >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8\n >>> attn = torch.rand(B, q_h * q_w, k_h * k_w)\n >>> q = torch.rand(B, q_h * q_w, C)\n >>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C)\n >>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C)\n >>> q_size, k_size = (q_h, q_w), (k_h, k_w)\n >>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size)\n >>> print(updated_attn.shape)\n torch.Size([1, 64, 64])\n\n References:\n https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py\n \"\"\"\n q_h, q_w = q_size\n k_h, k_w = k_size\n Rh = get_rel_pos(q_h, k_h, rel_pos_h)\n Rw = get_rel_pos(q_w, k_w, rel_pos_w)\n\n B, _, dim = q.shape\n r_q = q.reshape(B, q_h, q_w, dim)\n rel_h = torch.einsum(\"bhwc,hkc->bhwk\", r_q, Rh)\n rel_w = torch.einsum(\"bhwc,wkc->bhwk\", r_q, Rw)\n\n attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(\n B, q_h * q_w, k_h * k_w\n )\n\n return attn",
"chunk_type": "function",
"name": "add_decomposed_rel_pos",
"file_path": "ultralytics\\ultralytics\\models\\sam\\modules\\utils.py",
"start_line": 331,
"end_line": 386,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": "Add decomposed Relative Positional Embeddings to the attention map.\n\nThis function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2\npaper. It enhances the attention mechanism by incorporating spatial relationships between query and key\npositions.\n\nArgs:\n attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w).\n q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C).\n rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C).\n rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C).\n q_size (Tuple[int, int]): Spatial sequence size of query q as (q_h, q_w).\n k_size (Tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).\n\nReturns:\n (torch.Tensor): Updated attention map with added relative positional embeddings, shape\n (B, q_h * q_w, k_h * k_w).\n\nExamples:\n >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8\n >>> attn = torch.rand(B, q_h * q_w, k_h * k_w)\n >>> q = torch.rand(B, q_h * q_w, C)\n >>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C)\n >>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C)\n >>> q_size, k_size = (q_h, q_w), (k_h, k_w)\n >>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size)\n >>> print(updated_attn.shape)\n torch.Size([1, 64, 64])\n\nReferences:\n https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py",
"parameters": [
"attn: torch.Tensor",
"q: torch.Tensor",
"rel_pos_h: torch.Tensor",
"rel_pos_w: torch.Tensor",
"q_size: Tuple[int, int]",
"k_size: Tuple[int, int]"
],
"return_type": "torch.Tensor",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"typing.Any",
"typing.Dict",
"typing.Tuple",
"torch",
"torch.nn.functional"
],
"chunk_id": "function_add_decomposed_rel_pos_ea656b85"
},
{
"content": "import cv2",
"chunk_type": "import",
"name": "cv2",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\predict.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 10,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_cv2_0028d373"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\predict.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_1499b24c"
},
{
"content": "from PIL import Image",
"chunk_type": "import",
"name": "Image",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\predict.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Image_5a9b3a61"
},
{
"content": "from ultralytics.data.augment import classify_transforms",
"chunk_type": "import",
"name": "classify_transforms",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\predict.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 56,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_classify_transforms_47decfd8"
},
{
"content": "from ultralytics.engine.predictor import BasePredictor",
"chunk_type": "import",
"name": "BasePredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\predict.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 54,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BasePredictor_637f1b05"
},
{
"content": "from ultralytics.engine.results import Results",
"chunk_type": "import",
"name": "Results",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\predict.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Results_b1491de7"
},
{
"content": "from ultralytics.utils import DEFAULT_CFG, ops",
"chunk_type": "import",
"name": "DEFAULT_CFG, ops",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\predict.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DEFAULT_CFG, ops_4cdc861c"
},
{
"content": "class ClassificationPredictor(BasePredictor):\n \"\"\"\n A class extending the BasePredictor class for prediction based on a classification model.\n\n This predictor handles the specific requirements of classification models, including preprocessing images\n and postprocessing predictions to generate classification results.\n\n Attributes:\n args (dict): Configuration arguments for the predictor.\n\n Methods:\n preprocess: Convert input images to model-compatible format.\n postprocess: Process model predictions into Results objects.\n\n Notes:\n - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.\n\n Examples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.classify import ClassificationPredictor\n >>> args = dict(model=\"yolo11n-cls.pt\", source=ASSETS)\n >>> predictor = ClassificationPredictor(overrides=args)\n >>> predictor.predict_cli()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):\n \"\"\"\n Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.\n\n This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification\n tasks. It ensures the task is set to 'classify' regardless of input configuration.\n\n Args:\n cfg (dict): Default configuration dictionary containing prediction settings.\n overrides (dict, optional): Configuration overrides that take precedence over cfg.\n _callbacks (list, optional): List of callback functions to be executed during prediction.\n \"\"\"\n super().__init__(cfg, overrides, _callbacks)\n self.args.task = \"classify\"\n\n def setup_source(self, source):\n \"\"\"Set up source and inference mode and classify transforms.\"\"\"\n super().setup_source(source)\n updated = (\n self.model.model.transforms.transforms[0].size != max(self.imgsz)\n if hasattr(self.model.model, \"transforms\") and hasattr(self.model.model.transforms.transforms[0], \"size\")\n else False\n )\n self.transforms = (\n classify_transforms(self.imgsz) if updated or not self.model.pt else self.model.model.transforms\n )\n\n def preprocess(self, img):\n \"\"\"Convert input images to model-compatible tensor format with appropriate normalization.\"\"\"\n if not isinstance(img, torch.Tensor):\n img = torch.stack(\n [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0\n )\n img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)\n return img.half() if self.model.fp16 else img.float() # Convert uint8 to fp16/32\n\n def postprocess(self, preds, img, orig_imgs):\n \"\"\"\n Process predictions to return Results objects with classification probabilities.\n\n Args:\n preds (torch.Tensor): Raw predictions from the model.\n img (torch.Tensor): Input images after preprocessing.\n orig_imgs (List[np.ndarray] | torch.Tensor): Original images before preprocessing.\n\n Returns:\n (List[Results]): List of Results objects containing classification results for each image.\n \"\"\"\n if not isinstance(orig_imgs, list): # Input images are a torch.Tensor, not a list\n orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)\n\n preds = preds[0] if isinstance(preds, (list, tuple)) else preds\n return [\n Results(orig_img, path=img_path, names=self.model.names, probs=pred)\n for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])\n ]",
"chunk_type": "class",
"name": "ClassificationPredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\predict.py",
"start_line": 13,
"end_line": 93,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "A class extending the BasePredictor class for prediction based on a classification model.\n\nThis predictor handles the specific requirements of classification models, including preprocessing images\nand postprocessing predictions to generate classification results.\n\nAttributes:\n args (dict): Configuration arguments for the predictor.\n\nMethods:\n preprocess: Convert input images to model-compatible format.\n postprocess: Process model predictions into Results objects.\n\nNotes:\n - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.\n\nExamples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.classify import ClassificationPredictor\n >>> args = dict(model=\"yolo11n-cls.pt\", source=ASSETS)\n >>> predictor = ClassificationPredictor(overrides=args)\n >>> predictor.predict_cli()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"cv2",
"torch",
"PIL.Image",
"ultralytics.data.augment.classify_transforms",
"ultralytics.engine.predictor.BasePredictor",
"ultralytics.engine.results.Results",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.ops",
"BasePredictor"
],
"chunk_id": "class_ClassificationPredictor_2e4f6b41"
},
{
"content": "from copy import copy",
"chunk_type": "import",
"name": "copy",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_copy_6f9c54a3"
},
{
"content": "from typing import Any, Dict, Optional",
"chunk_type": "import",
"name": "Any, Dict, Optional",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 38,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, Optional_433400f4"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_d70f3532"
},
{
"content": "from ultralytics.data import ClassificationDataset, build_dataloader",
"chunk_type": "import",
"name": "ClassificationDataset, build_dataloader",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 68,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ClassificationDataset, build_dataloader_2e41d386"
},
{
"content": "from ultralytics.engine.trainer import BaseTrainer",
"chunk_type": "import",
"name": "BaseTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BaseTrainer_3e625bf0"
},
{
"content": "from ultralytics.models import yolo",
"chunk_type": "import",
"name": "yolo",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_yolo_08fdf44e"
},
{
"content": "from ultralytics.nn.tasks import ClassificationModel",
"chunk_type": "import",
"name": "ClassificationModel",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 52,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ClassificationModel_f3fdd470"
},
{
"content": "from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK",
"chunk_type": "import",
"name": "DEFAULT_CFG, LOGGER, RANK",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 55,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DEFAULT_CFG, LOGGER, RANK_b6bdd43b"
},
{
"content": "from ultralytics.utils.plotting import plot_images, plot_results",
"chunk_type": "import",
"name": "plot_images, plot_results",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 64,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_plot_images, plot_results_5e0ad354"
},
{
"content": "from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first",
"chunk_type": "import",
"name": "is_parallel, strip_optimizer, torch_distributed_zero_first",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 100,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_is_parallel, strip_optimizer, torch_distributed_zero_first_058ff9e3"
},
{
"content": "class ClassificationTrainer(BaseTrainer):\n \"\"\"\n A trainer class extending BaseTrainer for training image classification models.\n\n This trainer handles the training process for image classification tasks, supporting both YOLO classification models\n and torchvision models with comprehensive dataset handling and validation.\n\n Attributes:\n model (ClassificationModel): The classification model to be trained.\n data (Dict[str, Any]): Dictionary containing dataset information including class names and number of classes.\n loss_names (List[str]): Names of the loss functions used during training.\n validator (ClassificationValidator): Validator instance for model evaluation.\n\n Methods:\n set_model_attributes: Set the model's class names from the loaded dataset.\n get_model: Return a modified PyTorch model configured for training.\n setup_model: Load, create or download model for classification.\n build_dataset: Create a ClassificationDataset instance.\n get_dataloader: Return PyTorch DataLoader with transforms for image preprocessing.\n preprocess_batch: Preprocess a batch of images and classes.\n progress_string: Return a formatted string showing training progress.\n get_validator: Return an instance of ClassificationValidator.\n label_loss_items: Return a loss dict with labelled training loss items.\n plot_metrics: Plot metrics from a CSV file.\n final_eval: Evaluate trained model and save validation results.\n plot_training_samples: Plot training samples with their annotations.\n\n Examples:\n Initialize and train a classification model\n >>> from ultralytics.models.yolo.classify import ClassificationTrainer\n >>> args = dict(model=\"yolo11n-cls.pt\", data=\"imagenet10\", epochs=3)\n >>> trainer = ClassificationTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None):\n \"\"\"\n Initialize a ClassificationTrainer object.\n\n This constructor sets up a trainer for image classification tasks, configuring the task type and default\n image size if not specified.\n\n Args:\n cfg (Dict[str, Any], optional): Default configuration dictionary containing training parameters.\n overrides (Dict[str, Any], optional): Dictionary of parameter overrides for the default configuration.\n _callbacks (List[Any], optional): List of callback functions to be executed during training.\n\n Examples:\n Create a trainer with custom configuration\n >>> from ultralytics.models.yolo.classify import ClassificationTrainer\n >>> args = dict(model=\"yolo11n-cls.pt\", data=\"imagenet10\", epochs=3)\n >>> trainer = ClassificationTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n if overrides is None:\n overrides = {}\n overrides[\"task\"] = \"classify\"\n if overrides.get(\"imgsz\") is None:\n overrides[\"imgsz\"] = 224\n super().__init__(cfg, overrides, _callbacks)\n\n def set_model_attributes(self):\n \"\"\"Set the YOLO model's class names from the loaded dataset.\"\"\"\n self.model.names = self.data[\"names\"]\n\n def get_model(self, cfg=None, weights=None, verbose: bool = True):\n \"\"\"\n Return a modified PyTorch model configured for training YOLO classification.\n\n Args:\n cfg (Any, optional): Model configuration.\n weights (Any, optional): Pre-trained model weights.\n verbose (bool, optional): Whether to display model information.\n\n Returns:\n (ClassificationModel): Configured PyTorch model for classification.\n \"\"\"\n model = ClassificationModel(cfg, nc=self.data[\"nc\"], 
ch=self.data[\"channels\"], verbose=verbose and RANK == -1)\n if weights:\n model.load(weights)\n\n for m in model.modules():\n if not self.args.pretrained and hasattr(m, \"reset_parameters\"):\n m.reset_parameters()\n if isinstance(m, torch.nn.Dropout) and self.args.dropout:\n m.p = self.args.dropout # set dropout\n for p in model.parameters():\n p.requires_grad = True # for training\n return model\n\n def setup_model(self):\n \"\"\"\n Load, create or download model for classification tasks.\n\n Returns:\n (Any): Model checkpoint if applicable, otherwise None.\n \"\"\"\n import torchvision # scope for faster 'import ultralytics'\n\n if str(self.model) in torchvision.models.__dict__:\n self.model = torchvision.models.__dict__[self.model](\n weights=\"IMAGENET1K_V1\" if self.args.pretrained else None\n )\n ckpt = None\n else:\n ckpt = super().setup_model()\n ClassificationModel.reshape_outputs(self.model, self.data[\"nc\"])\n return ckpt\n\n def build_dataset(self, img_path: str, mode: str = \"train\", batch=None):\n \"\"\"\n Create a ClassificationDataset instance given an image path and mode.\n\n Args:\n img_path (str): Path to the dataset images.\n mode (str, optional): Dataset mode ('train', 'val', or 'test').\n batch (Any, optional): Batch information (unused in this implementation).\n\n Returns:\n (ClassificationDataset): Dataset for the specified mode.\n \"\"\"\n return ClassificationDataset(root=img_path, args=self.args, augment=mode == \"train\", prefix=mode)\n\n def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = \"train\"):\n \"\"\"\n Return PyTorch DataLoader with transforms to preprocess images.\n\n Args:\n dataset_path (str): Path to the dataset.\n batch_size (int, optional): Number of images per batch.\n rank (int, optional): Process rank for distributed training.\n mode (str, optional): 'train', 'val', or 'test' mode.\n\n Returns:\n (torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.\n \"\"\"\n with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP\n dataset = self.build_dataset(dataset_path, mode)\n\n loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)\n # Attach inference transforms\n if mode != \"train\":\n if is_parallel(self.model):\n self.model.module.transforms = loader.dataset.torch_transforms\n else:\n self.model.transforms = loader.dataset.torch_transforms\n return loader\n\n def preprocess_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n \"\"\"Preprocess a batch of images and classes.\"\"\"\n batch[\"img\"] = batch[\"img\"].to(self.device)\n batch[\"cls\"] = batch[\"cls\"].to(self.device)\n return batch\n\n def progress_string(self) -> str:\n \"\"\"Return a formatted string showing training progress.\"\"\"\n return (\"\\n\" + \"%11s\" * (4 + len(self.loss_names))) % (\n \"Epoch\",\n \"GPU_mem\",\n *self.loss_names,\n \"Instances\",\n \"Size\",\n )\n\n def get_validator(self):\n \"\"\"Return an instance of ClassificationValidator for validation.\"\"\"\n self.loss_names = [\"loss\"]\n return yolo.classify.ClassificationValidator(\n self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks\n )\n\n def label_loss_items(self, loss_items: Optional[torch.Tensor] = None, prefix: str = \"train\"):\n \"\"\"\n Return a loss dict with labelled training loss items tensor.\n\n Args:\n loss_items (torch.Tensor, optional): Loss tensor items.\n prefix (str, optional): Prefix to prepend to loss names.\n\n Returns:\n 
keys (List[str]): List of loss keys if loss_items is None.\n loss_dict (Dict[str, float]): Dictionary of loss items if loss_items is provided.\n \"\"\"\n keys = [f\"{prefix}/{x}\" for x in self.loss_names]\n if loss_items is None:\n return keys\n loss_items = [round(float(loss_items), 5)]\n return dict(zip(keys, loss_items))\n\n def plot_metrics(self):\n \"\"\"Plot metrics from a CSV file.\"\"\"\n plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png\n\n def final_eval(self):\n \"\"\"Evaluate trained model and save validation results.\"\"\"\n for f in self.last, self.best:\n if f.exists():\n strip_optimizer(f) # strip optimizers\n if f is self.best:\n LOGGER.info(f\"\\nValidating {f}...\")\n self.validator.args.data = self.args.data\n self.validator.args.plots = self.args.plots\n self.metrics = self.validator(model=f)\n self.metrics.pop(\"fitness\", None)\n self.run_callbacks(\"on_fit_epoch_end\")\n\n def plot_training_samples(self, batch: Dict[str, torch.Tensor], ni: int):\n \"\"\"\n Plot training samples with their annotations.\n\n Args:\n batch (Dict[str, torch.Tensor]): Batch containing images and class labels.\n ni (int): Number of iterations.\n \"\"\"\n batch[\"batch_idx\"] = torch.arange(len(batch[\"img\"])) # add batch index for plotting\n plot_images(\n labels=batch,\n fname=self.save_dir / f\"train_batch{ni}.jpg\",\n on_plot=self.on_plot,\n )",
"chunk_type": "class",
"name": "ClassificationTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\train.py",
"start_line": 17,
"end_line": 236,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "A trainer class extending BaseTrainer for training image classification models.\n\nThis trainer handles the training process for image classification tasks, supporting both YOLO classification models\nand torchvision models with comprehensive dataset handling and validation.\n\nAttributes:\n model (ClassificationModel): The classification model to be trained.\n data (Dict[str, Any]): Dictionary containing dataset information including class names and number of classes.\n loss_names (List[str]): Names of the loss functions used during training.\n validator (ClassificationValidator): Validator instance for model evaluation.\n\nMethods:\n set_model_attributes: Set the model's class names from the loaded dataset.\n get_model: Return a modified PyTorch model configured for training.\n setup_model: Load, create or download model for classification.\n build_dataset: Create a ClassificationDataset instance.\n get_dataloader: Return PyTorch DataLoader with transforms for image preprocessing.\n preprocess_batch: Preprocess a batch of images and classes.\n progress_string: Return a formatted string showing training progress.\n get_validator: Return an instance of ClassificationValidator.\n label_loss_items: Return a loss dict with labelled training loss items.\n plot_metrics: Plot metrics from a CSV file.\n final_eval: Evaluate trained model and save validation results.\n plot_training_samples: Plot training samples with their annotations.\n\nExamples:\n Initialize and train a classification model\n >>> from ultralytics.models.yolo.classify import ClassificationTrainer\n >>> args = dict(model=\"yolo11n-cls.pt\", data=\"imagenet10\", epochs=3)\n >>> trainer = ClassificationTrainer(overrides=args)\n >>> trainer.train()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy.copy",
"typing.Any",
"typing.Dict",
"typing.Optional",
"torch",
"ultralytics.data.ClassificationDataset",
"ultralytics.data.build_dataloader",
"ultralytics.engine.trainer.BaseTrainer",
"ultralytics.models.yolo",
"ultralytics.nn.tasks.ClassificationModel",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.plotting.plot_images",
"ultralytics.utils.plotting.plot_results",
"ultralytics.utils.torch_utils.is_parallel",
"ultralytics.utils.torch_utils.strip_optimizer",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"torchvision",
"BaseTrainer"
],
"chunk_id": "class_ClassificationTrainer_dd748521"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\val.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_3632a29c"
},
{
"content": "from typing import Any, Dict, List, Tuple, Union",
"chunk_type": "import",
"name": "Any, Dict, List, Tuple, Union",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\val.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 48,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, List, Tuple, Union_63990ab4"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\val.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_3f12a6f1"
},
{
"content": "from ultralytics.data import ClassificationDataset, build_dataloader",
"chunk_type": "import",
"name": "ClassificationDataset, build_dataloader",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\val.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 68,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ClassificationDataset, build_dataloader_5bab6936"
},
{
"content": "from ultralytics.engine.validator import BaseValidator",
"chunk_type": "import",
"name": "BaseValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\val.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 54,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BaseValidator_fbeb8d27"
},
{
"content": "from ultralytics.utils import LOGGER",
"chunk_type": "import",
"name": "LOGGER",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\val.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 36,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER_c5ddaab2"
},
{
"content": "from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix",
"chunk_type": "import",
"name": "ClassifyMetrics, ConfusionMatrix",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\val.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 70,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ClassifyMetrics, ConfusionMatrix_76591643"
},
{
"content": "from ultralytics.utils.plotting import plot_images",
"chunk_type": "import",
"name": "plot_images",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\val.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_plot_images_121c0ec7"
},
{
"content": "class ClassificationValidator(BaseValidator):\n \"\"\"\n A class extending the BaseValidator class for validation based on a classification model.\n\n This validator handles the validation process for classification models, including metrics calculation,\n confusion matrix generation, and visualization of results.\n\n Attributes:\n targets (List[torch.Tensor]): Ground truth class labels.\n pred (List[torch.Tensor]): Model predictions.\n metrics (ClassifyMetrics): Object to calculate and store classification metrics.\n names (dict): Mapping of class indices to class names.\n nc (int): Number of classes.\n confusion_matrix (ConfusionMatrix): Matrix to evaluate model performance across classes.\n\n Methods:\n get_desc: Return a formatted string summarizing classification metrics.\n init_metrics: Initialize confusion matrix, class names, and tracking containers.\n preprocess: Preprocess input batch by moving data to device.\n update_metrics: Update running metrics with model predictions and batch targets.\n finalize_metrics: Finalize metrics including confusion matrix and processing speed.\n postprocess: Extract the primary prediction from model output.\n get_stats: Calculate and return a dictionary of metrics.\n build_dataset: Create a ClassificationDataset instance for validation.\n get_dataloader: Build and return a data loader for classification validation.\n print_results: Print evaluation metrics for the classification model.\n plot_val_samples: Plot validation image samples with their ground truth labels.\n plot_predictions: Plot images with their predicted class labels.\n\n Examples:\n >>> from ultralytics.models.yolo.classify import ClassificationValidator\n >>> args = dict(model=\"yolo11n-cls.pt\", data=\"imagenet10\")\n >>> validator = ClassificationValidator(args=args)\n >>> validator()\n\n Notes:\n Torchvision classification models can also be passed to the 'model' argument, i.e. 
model='resnet18'.\n \"\"\"\n\n def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:\n \"\"\"\n Initialize ClassificationValidator with dataloader, save directory, and other parameters.\n\n Args:\n dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.\n save_dir (str | Path, optional): Directory to save results.\n args (dict, optional): Arguments containing model and validation configuration.\n _callbacks (list, optional): List of callback functions to be called during validation.\n\n Examples:\n >>> from ultralytics.models.yolo.classify import ClassificationValidator\n >>> args = dict(model=\"yolo11n-cls.pt\", data=\"imagenet10\")\n >>> validator = ClassificationValidator(args=args)\n >>> validator()\n \"\"\"\n super().__init__(dataloader, save_dir, args, _callbacks)\n self.targets = None\n self.pred = None\n self.args.task = \"classify\"\n self.metrics = ClassifyMetrics()\n\n def get_desc(self) -> str:\n \"\"\"Return a formatted string summarizing classification metrics.\"\"\"\n return (\"%22s\" + \"%11s\" * 2) % (\"classes\", \"top1_acc\", \"top5_acc\")\n\n def init_metrics(self, model: torch.nn.Module) -> None:\n \"\"\"Initialize confusion matrix, class names, and tracking containers for predictions and targets.\"\"\"\n self.names = model.names\n self.nc = len(model.names)\n self.pred = []\n self.targets = []\n self.confusion_matrix = ConfusionMatrix(names=list(model.names.values()))\n\n def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"Preprocess input batch by moving data to device and converting to appropriate dtype.\"\"\"\n batch[\"img\"] = batch[\"img\"].to(self.device, non_blocking=True)\n batch[\"img\"] = batch[\"img\"].half() if self.args.half else batch[\"img\"].float()\n batch[\"cls\"] = batch[\"cls\"].to(self.device)\n return batch\n\n def update_metrics(self, preds: torch.Tensor, batch: Dict[str, Any]) -> None:\n \"\"\"\n Update running metrics with model predictions and batch targets.\n\n Args:\n preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.\n batch (dict): Batch data containing images and class labels.\n\n Notes:\n This method appends the top-N predictions (sorted by confidence in descending order) to the\n prediction list for later evaluation. 
N is limited to the minimum of 5 and the number of classes.\n \"\"\"\n n5 = min(len(self.names), 5)\n self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())\n self.targets.append(batch[\"cls\"].type(torch.int32).cpu())\n\n def finalize_metrics(self) -> None:\n \"\"\"\n Finalize metrics including confusion matrix and processing speed.\n\n Notes:\n This method processes the accumulated predictions and targets to generate the confusion matrix,\n optionally plots it, and updates the metrics object with speed information.\n\n Examples:\n >>> validator = ClassificationValidator()\n >>> validator.pred = [torch.tensor([[0, 1, 2]])] # Top-3 predictions for one sample\n >>> validator.targets = [torch.tensor([0])] # Ground truth class\n >>> validator.finalize_metrics()\n >>> print(validator.metrics.confusion_matrix) # Access the confusion matrix\n \"\"\"\n self.confusion_matrix.process_cls_preds(self.pred, self.targets)\n if self.args.plots:\n for normalize in True, False:\n self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)\n self.metrics.speed = self.speed\n self.metrics.save_dir = self.save_dir\n self.metrics.confusion_matrix = self.confusion_matrix\n\n def postprocess(self, preds: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> torch.Tensor:\n \"\"\"Extract the primary prediction from model output if it's in a list or tuple format.\"\"\"\n return preds[0] if isinstance(preds, (list, tuple)) else preds\n\n def get_stats(self) -> Dict[str, float]:\n \"\"\"Calculate and return a dictionary of metrics by processing targets and predictions.\"\"\"\n self.metrics.process(self.targets, self.pred)\n return self.metrics.results_dict\n\n def build_dataset(self, img_path: str) -> ClassificationDataset:\n \"\"\"Create a ClassificationDataset instance for validation.\"\"\"\n return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)\n\n def get_dataloader(self, dataset_path: Union[Path, str], batch_size: int) -> torch.utils.data.DataLoader:\n \"\"\"\n Build and return a data loader for classification validation.\n\n Args:\n dataset_path (str | Path): Path to the dataset directory.\n batch_size (int): Number of samples per batch.\n\n Returns:\n (torch.utils.data.DataLoader): DataLoader object for the classification validation dataset.\n \"\"\"\n dataset = self.build_dataset(dataset_path)\n return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)\n\n def print_results(self) -> None:\n \"\"\"Print evaluation metrics for the classification model.\"\"\"\n pf = \"%22s\" + \"%11.3g\" * len(self.metrics.keys) # print format\n LOGGER.info(pf % (\"all\", self.metrics.top1, self.metrics.top5))\n\n def plot_val_samples(self, batch: Dict[str, Any], ni: int) -> None:\n \"\"\"\n Plot validation image samples with their ground truth labels.\n\n Args:\n batch (Dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).\n ni (int): Batch index used for naming the output file.\n\n Examples:\n >>> validator = ClassificationValidator()\n >>> batch = {\"img\": torch.rand(16, 3, 224, 224), \"cls\": torch.randint(0, 10, (16,))}\n >>> validator.plot_val_samples(batch, 0)\n \"\"\"\n batch[\"batch_idx\"] = torch.arange(len(batch[\"img\"])) # add batch index for plotting\n plot_images(\n labels=batch,\n fname=self.save_dir / f\"val_batch{ni}_labels.jpg\",\n names=self.names,\n on_plot=self.on_plot,\n )\n\n def plot_predictions(self, batch: Dict[str, Any], 
preds: torch.Tensor, ni: int) -> None:\n \"\"\"\n Plot images with their predicted class labels and save the visualization.\n\n Args:\n batch (Dict[str, Any]): Batch data containing images and other information.\n preds (torch.Tensor): Model predictions with shape (batch_size, num_classes).\n ni (int): Batch index used for naming the output file.\n\n Examples:\n >>> validator = ClassificationValidator()\n >>> batch = {\"img\": torch.rand(16, 3, 224, 224)}\n >>> preds = torch.rand(16, 10) # 16 images, 10 classes\n >>> validator.plot_predictions(batch, preds, 0)\n \"\"\"\n batched_preds = dict(\n img=batch[\"img\"],\n batch_idx=torch.arange(len(batch[\"img\"])),\n cls=torch.argmax(preds, dim=1),\n )\n plot_images(\n batched_preds,\n fname=self.save_dir / f\"val_batch{ni}_pred.jpg\",\n names=self.names,\n on_plot=self.on_plot,\n ) # pred",
"chunk_type": "class",
"name": "ClassificationValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\val.py",
"start_line": 15,
"end_line": 212,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "A class extending the BaseValidator class for validation based on a classification model.\n\nThis validator handles the validation process for classification models, including metrics calculation,\nconfusion matrix generation, and visualization of results.\n\nAttributes:\n targets (List[torch.Tensor]): Ground truth class labels.\n pred (List[torch.Tensor]): Model predictions.\n metrics (ClassifyMetrics): Object to calculate and store classification metrics.\n names (dict): Mapping of class indices to class names.\n nc (int): Number of classes.\n confusion_matrix (ConfusionMatrix): Matrix to evaluate model performance across classes.\n\nMethods:\n get_desc: Return a formatted string summarizing classification metrics.\n init_metrics: Initialize confusion matrix, class names, and tracking containers.\n preprocess: Preprocess input batch by moving data to device.\n update_metrics: Update running metrics with model predictions and batch targets.\n finalize_metrics: Finalize metrics including confusion matrix and processing speed.\n postprocess: Extract the primary prediction from model output.\n get_stats: Calculate and return a dictionary of metrics.\n build_dataset: Create a ClassificationDataset instance for validation.\n get_dataloader: Build and return a data loader for classification validation.\n print_results: Print evaluation metrics for the classification model.\n plot_val_samples: Plot validation image samples with their ground truth labels.\n plot_predictions: Plot images with their predicted class labels.\n\nExamples:\n >>> from ultralytics.models.yolo.classify import ClassificationValidator\n >>> args = dict(model=\"yolo11n-cls.pt\", data=\"imagenet10\")\n >>> validator = ClassificationValidator(args=args)\n >>> validator()\n\nNotes:\n Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"torch",
"ultralytics.data.ClassificationDataset",
"ultralytics.data.build_dataloader",
"ultralytics.engine.validator.BaseValidator",
"ultralytics.utils.LOGGER",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.ConfusionMatrix",
"ultralytics.utils.plotting.plot_images",
"BaseValidator"
],
"chunk_id": "class_ClassificationValidator_e697968f"
},
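A brief, hedged illustration of the top-N ranking performed by ClassificationValidator.update_metrics in the chunk above: predictions are argsorted per sample in descending order and only the first min(5, num_classes) columns are kept as the top-N guesses. The tensors below are toy values, not taken from the repository.

import torch

# Toy logits: 3 samples, 4 classes. update_metrics keeps the first min(5, nc) ranked
# class indices per sample (here all 4, since nc < 5) and stores them as int32 on CPU.
preds = torch.tensor([[0.1, 0.7, 0.1, 0.1], [0.6, 0.2, 0.1, 0.1], [0.2, 0.2, 0.5, 0.1]])
n5 = min(preds.shape[1], 5)
topn = preds.argsort(1, descending=True)[:, :n5].type(torch.int32)
print(topn[:, 0])  # tensor([1, 0, 2], dtype=torch.int32) -> top-1 class per sample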
{
"content": "from ultralytics.models.yolo.classify.predict import ClassificationPredictor",
"chunk_type": "import",
"name": "ClassificationPredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\__init__.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 76,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ClassificationPredictor_fc0b1be3"
},
{
"content": "from ultralytics.models.yolo.classify.train import ClassificationTrainer",
"chunk_type": "import",
"name": "ClassificationTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\__init__.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 72,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ClassificationTrainer_b0059f59"
},
{
"content": "from ultralytics.models.yolo.classify.val import ClassificationValidator",
"chunk_type": "import",
"name": "ClassificationValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\__init__.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 72,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ClassificationValidator_463f349f"
},
{
"content": "__all__ = \"ClassificationPredictor\", \"ClassificationTrainer\", \"ClassificationValidator\"",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\classify\\__init__.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 87,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___ab6e9af8"
},
{
"content": "from ultralytics.engine.predictor import BasePredictor",
"chunk_type": "import",
"name": "BasePredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\predict.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 54,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BasePredictor_da3146f8"
},
{
"content": "from ultralytics.engine.results import Results",
"chunk_type": "import",
"name": "Results",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\predict.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Results_360d0d86"
},
{
"content": "from ultralytics.utils import ops",
"chunk_type": "import",
"name": "ops",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\predict.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ops_eb03b6c3"
},
{
"content": "class DetectionPredictor(BasePredictor):\n \"\"\"\n A class extending the BasePredictor class for prediction based on a detection model.\n\n This predictor specializes in object detection tasks, processing model outputs into meaningful detection results\n with bounding boxes and class predictions.\n\n Attributes:\n args (namespace): Configuration arguments for the predictor.\n model (nn.Module): The detection model used for inference.\n batch (list): Batch of images and metadata for processing.\n\n Methods:\n postprocess: Process raw model predictions into detection results.\n construct_results: Build Results objects from processed predictions.\n construct_result: Create a single Result object from a prediction.\n get_obj_feats: Extract object features from the feature maps.\n\n Examples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.detect import DetectionPredictor\n >>> args = dict(model=\"yolo11n.pt\", source=ASSETS)\n >>> predictor = DetectionPredictor(overrides=args)\n >>> predictor.predict_cli()\n \"\"\"\n\n def postprocess(self, preds, img, orig_imgs, **kwargs):\n \"\"\"\n Post-process predictions and return a list of Results objects.\n\n This method applies non-maximum suppression to raw model predictions and prepares them for visualization and\n further analysis.\n\n Args:\n preds (torch.Tensor): Raw predictions from the model.\n img (torch.Tensor): Processed input image tensor in model input format.\n orig_imgs (torch.Tensor | list): Original input images before preprocessing.\n **kwargs (Any): Additional keyword arguments.\n\n Returns:\n (list): List of Results objects containing the post-processed predictions.\n\n Examples:\n >>> predictor = DetectionPredictor(overrides=dict(model=\"yolo11n.pt\"))\n >>> results = predictor.predict(\"path/to/image.jpg\")\n >>> processed_results = predictor.postprocess(preds, img, orig_imgs)\n \"\"\"\n save_feats = getattr(self, \"_feats\", None) is not None\n preds = ops.non_max_suppression(\n preds,\n self.args.conf,\n self.args.iou,\n self.args.classes,\n self.args.agnostic_nms,\n max_det=self.args.max_det,\n nc=0 if self.args.task == \"detect\" else len(self.model.names),\n end2end=getattr(self.model, \"end2end\", False),\n rotated=self.args.task == \"obb\",\n return_idxs=save_feats,\n )\n\n if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list\n orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)\n\n if save_feats:\n obj_feats = self.get_obj_feats(self._feats, preds[1])\n preds = preds[0]\n\n results = self.construct_results(preds, img, orig_imgs, **kwargs)\n\n if save_feats:\n for r, f in zip(results, obj_feats):\n r.feats = f # add object features to results\n\n return results\n\n def get_obj_feats(self, feat_maps, idxs):\n \"\"\"Extract object features from the feature maps.\"\"\"\n import torch\n\n s = min([x.shape[1] for x in feat_maps]) # find smallest vector length\n obj_feats = torch.cat(\n [x.permute(0, 2, 3, 1).reshape(x.shape[0], -1, s, x.shape[1] // s).mean(dim=-1) for x in feat_maps], dim=1\n ) # mean reduce all vectors to same length\n return [feats[idx] if len(idx) else [] for feats, idx in zip(obj_feats, idxs)] # for each img in batch\n\n def construct_results(self, preds, img, orig_imgs):\n \"\"\"\n Construct a list of Results objects from model predictions.\n\n Args:\n preds (List[torch.Tensor]): List of predicted bounding boxes and scores for each image.\n img (torch.Tensor): Batch of preprocessed images used for inference.\n orig_imgs 
(List[np.ndarray]): List of original images before preprocessing.\n\n Returns:\n (List[Results]): List of Results objects containing detection information for each image.\n \"\"\"\n return [\n self.construct_result(pred, img, orig_img, img_path)\n for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])\n ]\n\n def construct_result(self, pred, img, orig_img, img_path):\n \"\"\"\n Construct a single Results object from one image prediction.\n\n Args:\n pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.\n img (torch.Tensor): Preprocessed image tensor used for inference.\n orig_img (np.ndarray): Original image before preprocessing.\n img_path (str): Path to the original image file.\n\n Returns:\n (Results): Results object containing the original image, image path, class names, and scaled bounding boxes.\n \"\"\"\n pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)\n return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])",
"chunk_type": "class",
"name": "DetectionPredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\predict.py",
"start_line": 8,
"end_line": 125,
"start_col": 0,
"end_col": 90,
"parent_name": null,
"docstring": "A class extending the BasePredictor class for prediction based on a detection model.\n\nThis predictor specializes in object detection tasks, processing model outputs into meaningful detection results\nwith bounding boxes and class predictions.\n\nAttributes:\n args (namespace): Configuration arguments for the predictor.\n model (nn.Module): The detection model used for inference.\n batch (list): Batch of images and metadata for processing.\n\nMethods:\n postprocess: Process raw model predictions into detection results.\n construct_results: Build Results objects from processed predictions.\n construct_result: Create a single Result object from a prediction.\n get_obj_feats: Extract object features from the feature maps.\n\nExamples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.detect import DetectionPredictor\n >>> args = dict(model=\"yolo11n.pt\", source=ASSETS)\n >>> predictor = DetectionPredictor(overrides=args)\n >>> predictor.predict_cli()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"ultralytics.engine.predictor.BasePredictor",
"ultralytics.engine.results.Results",
"ultralytics.utils.ops",
"torch",
"BasePredictor"
],
"chunk_id": "class_DetectionPredictor_3c9d40be"
},
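A small, self-contained sketch of the reshaping done by DetectionPredictor.get_obj_feats above, which mean-reduces feature maps with different channel widths to the smallest width before concatenating them. The shapes and random tensors below are illustrative assumptions only.

import torch

# Two toy feature maps with 8 and 16 channels; each is permuted to (B, H, W, C),
# reshaped to (B, HW, s, C // s) with s = smallest channel count, then mean-reduced
# over the last dim so every spatial position yields a length-s vector.
feat_maps = [torch.rand(1, 8, 4, 4), torch.rand(1, 16, 2, 2)]
s = min(x.shape[1] for x in feat_maps)  # s = 8
obj_feats = torch.cat(
    [x.permute(0, 2, 3, 1).reshape(x.shape[0], -1, s, x.shape[1] // s).mean(dim=-1) for x in feat_maps],
    dim=1,
)
print(obj_feats.shape)  # torch.Size([1, 20, 8]) -> 16 + 4 spatial positions, length-8 vectors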
{
"content": "import math",
"chunk_type": "import",
"name": "math",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_math_f0299af0"
},
{
"content": "import random",
"chunk_type": "import",
"name": "random",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 13,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_random_91b66083"
},
{
"content": "from copy import copy",
"chunk_type": "import",
"name": "copy",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_copy_a93f4f48"
},
{
"content": "from typing import Any, Dict, List, Optional",
"chunk_type": "import",
"name": "Any, Dict, List, Optional",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, List, Optional_cea6eb92"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_d326ad88"
},
{
"content": "import torch.nn as nn",
"chunk_type": "import",
"name": "torch.nn",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn_0a5f5c0d"
},
{
"content": "from ultralytics.data import build_dataloader, build_yolo_dataset",
"chunk_type": "import",
"name": "build_dataloader, build_yolo_dataset",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 65,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_build_dataloader, build_yolo_dataset_4e1469bb"
},
{
"content": "from ultralytics.engine.trainer import BaseTrainer",
"chunk_type": "import",
"name": "BaseTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BaseTrainer_4b4cf5ee"
},
{
"content": "from ultralytics.models import yolo",
"chunk_type": "import",
"name": "yolo",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_yolo_7f0d97b6"
},
{
"content": "from ultralytics.nn.tasks import DetectionModel",
"chunk_type": "import",
"name": "DetectionModel",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DetectionModel_f9862d77"
},
{
"content": "from ultralytics.utils import LOGGER, RANK",
"chunk_type": "import",
"name": "LOGGER, RANK",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER, RANK_5098cbb2"
},
{
"content": "from ultralytics.utils.patches import override_configs",
"chunk_type": "import",
"name": "override_configs",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py",
"start_line": 16,
"end_line": 16,
"start_col": 0,
"end_col": 54,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_override_configs_db25ed44"
},
{
"content": "from ultralytics.utils.plotting import plot_images, plot_labels, plot_results",
"chunk_type": "import",
"name": "plot_images, plot_labels, plot_results",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py",
"start_line": 17,
"end_line": 17,
"start_col": 0,
"end_col": 77,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_plot_images, plot_labels, plot_results_7ca697a2"
},
{
"content": "from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first",
"chunk_type": "import",
"name": "de_parallel, torch_distributed_zero_first",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py",
"start_line": 18,
"end_line": 18,
"start_col": 0,
"end_col": 83,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_de_parallel, torch_distributed_zero_first_585c0ffc"
},
{
"content": "class DetectionTrainer(BaseTrainer):\n \"\"\"\n A class extending the BaseTrainer class for training based on a detection model.\n\n This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models\n for object detection including dataset building, data loading, preprocessing, and model configuration.\n\n Attributes:\n model (DetectionModel): The YOLO detection model being trained.\n data (Dict): Dictionary containing dataset information including class names and number of classes.\n loss_names (tuple): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).\n\n Methods:\n build_dataset: Build YOLO dataset for training or validation.\n get_dataloader: Construct and return dataloader for the specified mode.\n preprocess_batch: Preprocess a batch of images by scaling and converting to float.\n set_model_attributes: Set model attributes based on dataset information.\n get_model: Return a YOLO detection model.\n get_validator: Return a validator for model evaluation.\n label_loss_items: Return a loss dictionary with labeled training loss items.\n progress_string: Return a formatted string of training progress.\n plot_training_samples: Plot training samples with their annotations.\n plot_metrics: Plot metrics from a CSV file.\n plot_training_labels: Create a labeled training plot of the YOLO model.\n auto_batch: Calculate optimal batch size based on model memory requirements.\n\n Examples:\n >>> from ultralytics.models.yolo.detect import DetectionTrainer\n >>> args = dict(model=\"yolo11n.pt\", data=\"coco8.yaml\", epochs=3)\n >>> trainer = DetectionTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n\n def build_dataset(self, img_path: str, mode: str = \"train\", batch: Optional[int] = None):\n \"\"\"\n Build YOLO Dataset for training or validation.\n\n Args:\n img_path (str): Path to the folder containing images.\n mode (str): 'train' mode or 'val' mode, users are able to customize different augmentations for each mode.\n batch (int, optional): Size of batches, this is for 'rect' mode.\n\n Returns:\n (Dataset): YOLO dataset object configured for the specified mode.\n \"\"\"\n gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)\n return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == \"val\", stride=gs)\n\n def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = \"train\"):\n \"\"\"\n Construct and return dataloader for the specified mode.\n\n Args:\n dataset_path (str): Path to the dataset.\n batch_size (int): Number of images per batch.\n rank (int): Process rank for distributed training.\n mode (str): 'train' for training dataloader, 'val' for validation dataloader.\n\n Returns:\n (DataLoader): PyTorch dataloader object.\n \"\"\"\n assert mode in {\"train\", \"val\"}, f\"Mode must be 'train' or 'val', not {mode}.\"\n with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP\n dataset = self.build_dataset(dataset_path, mode, batch_size)\n shuffle = mode == \"train\"\n if getattr(dataset, \"rect\", False) and shuffle:\n LOGGER.warning(\"'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False\")\n shuffle = False\n workers = self.args.workers if mode == \"train\" else self.args.workers * 2\n return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader\n\n def preprocess_batch(self, batch: Dict) -> Dict:\n \"\"\"\n Preprocess a batch of images by 
scaling and converting to float.\n\n Args:\n batch (Dict): Dictionary containing batch data with 'img' tensor.\n\n Returns:\n (Dict): Preprocessed batch with normalized images.\n \"\"\"\n batch[\"img\"] = batch[\"img\"].to(self.device, non_blocking=True).float() / 255\n if self.args.multi_scale:\n imgs = batch[\"img\"]\n sz = (\n random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))\n // self.stride\n * self.stride\n ) # size\n sf = sz / max(imgs.shape[2:]) # scale factor\n if sf != 1:\n ns = [\n math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]\n ] # new shape (stretched to gs-multiple)\n imgs = nn.functional.interpolate(imgs, size=ns, mode=\"bilinear\", align_corners=False)\n batch[\"img\"] = imgs\n return batch\n\n def set_model_attributes(self):\n \"\"\"Set model attributes based on dataset information.\"\"\"\n # Nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)\n # self.args.box *= 3 / nl # scale to layers\n # self.args.cls *= self.data[\"nc\"] / 80 * 3 / nl # scale to classes and layers\n # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers\n self.model.nc = self.data[\"nc\"] # attach number of classes to model\n self.model.names = self.data[\"names\"] # attach class names to model\n self.model.args = self.args # attach hyperparameters to model\n # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc\n\n def get_model(self, cfg: Optional[str] = None, weights: Optional[str] = None, verbose: bool = True):\n \"\"\"\n Return a YOLO detection model.\n\n Args:\n cfg (str, optional): Path to model configuration file.\n weights (str, optional): Path to model weights.\n verbose (bool): Whether to display model information.\n\n Returns:\n (DetectionModel): YOLO detection model.\n \"\"\"\n model = DetectionModel(cfg, nc=self.data[\"nc\"], ch=self.data[\"channels\"], verbose=verbose and RANK == -1)\n if weights:\n model.load(weights)\n return model\n\n def get_validator(self):\n \"\"\"Return a DetectionValidator for YOLO model validation.\"\"\"\n self.loss_names = \"box_loss\", \"cls_loss\", \"dfl_loss\"\n return yolo.detect.DetectionValidator(\n self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks\n )\n\n def label_loss_items(self, loss_items: Optional[List[float]] = None, prefix: str = \"train\"):\n \"\"\"\n Return a loss dict with labeled training loss items tensor.\n\n Args:\n loss_items (List[float], optional): List of loss values.\n prefix (str): Prefix for keys in the returned dictionary.\n\n Returns:\n (Dict | List): Dictionary of labeled loss items if loss_items is provided, otherwise list of keys.\n \"\"\"\n keys = [f\"{prefix}/{x}\" for x in self.loss_names]\n if loss_items is not None:\n loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats\n return dict(zip(keys, loss_items))\n else:\n return keys\n\n def progress_string(self):\n \"\"\"Return a formatted string of training progress with epoch, GPU memory, loss, instances and size.\"\"\"\n return (\"\\n\" + \"%11s\" * (4 + len(self.loss_names))) % (\n \"Epoch\",\n \"GPU_mem\",\n *self.loss_names,\n \"Instances\",\n \"Size\",\n )\n\n def plot_training_samples(self, batch: Dict[str, Any], ni: int) -> None:\n \"\"\"\n Plot training samples with their annotations.\n\n Args:\n batch (Dict[str, Any]): Dictionary containing batch data.\n ni (int): Number of iterations.\n \"\"\"\n plot_images(\n 
labels=batch,\n paths=batch[\"im_file\"],\n fname=self.save_dir / f\"train_batch{ni}.jpg\",\n on_plot=self.on_plot,\n )\n\n def plot_metrics(self):\n \"\"\"Plot metrics from a CSV file.\"\"\"\n plot_results(file=self.csv, on_plot=self.on_plot) # save results.png\n\n def plot_training_labels(self):\n \"\"\"Create a labeled training plot of the YOLO model.\"\"\"\n boxes = np.concatenate([lb[\"bboxes\"] for lb in self.train_loader.dataset.labels], 0)\n cls = np.concatenate([lb[\"cls\"] for lb in self.train_loader.dataset.labels], 0)\n plot_labels(boxes, cls.squeeze(), names=self.data[\"names\"], save_dir=self.save_dir, on_plot=self.on_plot)\n\n def auto_batch(self):\n \"\"\"\n Get optimal batch size by calculating memory occupation of model.\n\n Returns:\n (int): Optimal batch size.\n \"\"\"\n with override_configs(self.args, overrides={\"cache\": False}) as self.args:\n train_dataset = self.build_dataset(self.data[\"train\"], mode=\"train\", batch=16)\n max_num_obj = max(len(label[\"cls\"]) for label in train_dataset.labels) * 4 # 4 for mosaic augmentation\n del train_dataset # free memory\n return super().auto_batch(max_num_obj)",
"chunk_type": "class",
"name": "DetectionTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\train.py",
"start_line": 21,
"end_line": 218,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": "A class extending the BaseTrainer class for training based on a detection model.\n\nThis trainer specializes in object detection tasks, handling the specific requirements for training YOLO models\nfor object detection including dataset building, data loading, preprocessing, and model configuration.\n\nAttributes:\n model (DetectionModel): The YOLO detection model being trained.\n data (Dict): Dictionary containing dataset information including class names and number of classes.\n loss_names (tuple): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).\n\nMethods:\n build_dataset: Build YOLO dataset for training or validation.\n get_dataloader: Construct and return dataloader for the specified mode.\n preprocess_batch: Preprocess a batch of images by scaling and converting to float.\n set_model_attributes: Set model attributes based on dataset information.\n get_model: Return a YOLO detection model.\n get_validator: Return a validator for model evaluation.\n label_loss_items: Return a loss dictionary with labeled training loss items.\n progress_string: Return a formatted string of training progress.\n plot_training_samples: Plot training samples with their annotations.\n plot_metrics: Plot metrics from a CSV file.\n plot_training_labels: Create a labeled training plot of the YOLO model.\n auto_batch: Calculate optimal batch size based on model memory requirements.\n\nExamples:\n >>> from ultralytics.models.yolo.detect import DetectionTrainer\n >>> args = dict(model=\"yolo11n.pt\", data=\"coco8.yaml\", epochs=3)\n >>> trainer = DetectionTrainer(overrides=args)\n >>> trainer.train()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"random",
"copy.copy",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"numpy",
"torch.nn",
"ultralytics.data.build_dataloader",
"ultralytics.data.build_yolo_dataset",
"ultralytics.engine.trainer.BaseTrainer",
"ultralytics.models.yolo",
"ultralytics.nn.tasks.DetectionModel",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.patches.override_configs",
"ultralytics.utils.plotting.plot_images",
"ultralytics.utils.plotting.plot_labels",
"ultralytics.utils.plotting.plot_results",
"ultralytics.utils.torch_utils.de_parallel",
"ultralytics.utils.torch_utils.torch_distributed_zero_first",
"BaseTrainer"
],
"chunk_id": "class_DetectionTrainer_63774454"
},
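A minimal sketch of the multi-scale arithmetic in DetectionTrainer.preprocess_batch above: a random stride-aligned size is drawn from roughly [0.5*imgsz, 1.5*imgsz] and the batch shape is stretched to the nearest stride multiples. The values imgsz=640, stride=32, and the 640x480 shape are illustrative assumptions.

import math
import random

# Pick a random training size aligned to the stride, then compute the resized shape
# that preprocess_batch would interpolate the batch to.
imgsz, stride = 640, 32
sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5 + stride)) // stride * stride
shape = (640, 480)  # toy batch height, width
sf = sz / max(shape)  # scale factor
new_shape = [math.ceil(x * sf / stride) * stride for x in shape]
print(sz, new_shape)  # e.g. 416 [416, 320]; varies per call because sz is random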
{
"content": "import os",
"chunk_type": "import",
"name": "os",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_os_37352805"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_41a2c22a"
},
{
"content": "from typing import Any, Dict, List, Optional, Tuple, Union",
"chunk_type": "import",
"name": "Any, Dict, List, Optional, Tuple, Union",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 58,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, List, Optional, Tuple, Union_f7770b17"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_4c28ded0"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_8baa87eb"
},
{
"content": "from ultralytics.data import build_dataloader, build_yolo_dataset, converter",
"chunk_type": "import",
"name": "build_dataloader, build_yolo_dataset, converter",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 76,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_build_dataloader, build_yolo_dataset, converter_105d54f6"
},
{
"content": "from ultralytics.engine.validator import BaseValidator",
"chunk_type": "import",
"name": "BaseValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 54,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_BaseValidator_10db63cf"
},
{
"content": "from ultralytics.utils import LOGGER, ops",
"chunk_type": "import",
"name": "LOGGER, ops",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER, ops_7a47d0e3"
},
{
"content": "from ultralytics.utils.checks import check_requirements",
"chunk_type": "import",
"name": "check_requirements",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 55,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_check_requirements_ba0bad5b"
},
{
"content": "from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou",
"chunk_type": "import",
"name": "ConfusionMatrix, DetMetrics, box_iou",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 74,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ConfusionMatrix, DetMetrics, box_iou_e52987ac"
},
{
"content": "from ultralytics.utils.plotting import plot_images",
"chunk_type": "import",
"name": "plot_images",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_plot_images_e7d8b65c"
},
{
"content": "class DetectionValidator(BaseValidator):\n \"\"\"\n A class extending the BaseValidator class for validation based on a detection model.\n\n This class implements validation functionality specific to object detection tasks, including metrics calculation,\n prediction processing, and visualization of results.\n\n Attributes:\n is_coco (bool): Whether the dataset is COCO.\n is_lvis (bool): Whether the dataset is LVIS.\n class_map (List[int]): Mapping from model class indices to dataset class indices.\n metrics (DetMetrics): Object detection metrics calculator.\n iouv (torch.Tensor): IoU thresholds for mAP calculation.\n niou (int): Number of IoU thresholds.\n lb (List[Any]): List for storing ground truth labels for hybrid saving.\n jdict (List[Dict[str, Any]]): List for storing JSON detection results.\n stats (Dict[str, List[torch.Tensor]]): Dictionary for storing statistics during validation.\n\n Examples:\n >>> from ultralytics.models.yolo.detect import DetectionValidator\n >>> args = dict(model=\"yolo11n.pt\", data=\"coco8.yaml\")\n >>> validator = DetectionValidator(args=args)\n >>> validator()\n \"\"\"\n\n def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:\n \"\"\"\n Initialize detection validator with necessary variables and settings.\n\n Args:\n dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.\n save_dir (Path, optional): Directory to save results.\n args (Dict[str, Any], optional): Arguments for the validator.\n _callbacks (List[Any], optional): List of callback functions.\n \"\"\"\n super().__init__(dataloader, save_dir, args, _callbacks)\n self.is_coco = False\n self.is_lvis = False\n self.class_map = None\n self.args.task = \"detect\"\n self.iouv = torch.linspace(0.5, 0.95, 10) # IoU vector for mAP@0.5:0.95\n self.niou = self.iouv.numel()\n self.metrics = DetMetrics()\n\n def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Preprocess batch of images for YOLO validation.\n\n Args:\n batch (Dict[str, Any]): Batch containing images and annotations.\n\n Returns:\n (Dict[str, Any]): Preprocessed batch.\n \"\"\"\n batch[\"img\"] = batch[\"img\"].to(self.device, non_blocking=True)\n batch[\"img\"] = (batch[\"img\"].half() if self.args.half else batch[\"img\"].float()) / 255\n for k in {\"batch_idx\", \"cls\", \"bboxes\"}:\n batch[k] = batch[k].to(self.device)\n\n return batch\n\n def init_metrics(self, model: torch.nn.Module) -> None:\n \"\"\"\n Initialize evaluation metrics for YOLO detection validation.\n\n Args:\n model (torch.nn.Module): Model to validate.\n \"\"\"\n val = self.data.get(self.args.split, \"\") # validation path\n self.is_coco = (\n isinstance(val, str)\n and \"coco\" in val\n and (val.endswith(f\"{os.sep}val2017.txt\") or val.endswith(f\"{os.sep}test-dev2017.txt\"))\n ) # is COCO\n self.is_lvis = isinstance(val, str) and \"lvis\" in val and not self.is_coco # is LVIS\n self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1, len(model.names) + 1))\n self.args.save_json |= self.args.val and (self.is_coco or self.is_lvis) and not self.training # run final val\n self.names = model.names\n self.nc = len(model.names)\n self.end2end = getattr(model, \"end2end\", False)\n self.seen = 0\n self.jdict = []\n self.metrics.names = self.names\n self.confusion_matrix = ConfusionMatrix(names=list(model.names.values()))\n\n def get_desc(self) -> str:\n \"\"\"Return a formatted string summarizing class metrics of YOLO model.\"\"\"\n return (\"%22s\" 
+ \"%11s\" * 6) % (\"Class\", \"Images\", \"Instances\", \"Box(P\", \"R\", \"mAP50\", \"mAP50-95)\")\n\n def postprocess(self, preds: torch.Tensor) -> List[Dict[str, torch.Tensor]]:\n \"\"\"\n Apply Non-maximum suppression to prediction outputs.\n\n Args:\n preds (torch.Tensor): Raw predictions from the model.\n\n Returns:\n (List[Dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains\n 'bboxes', 'conf', 'cls', and 'extra' tensors.\n \"\"\"\n outputs = ops.non_max_suppression(\n preds,\n self.args.conf,\n self.args.iou,\n nc=0 if self.args.task == \"detect\" else self.nc,\n multi_label=True,\n agnostic=self.args.single_cls or self.args.agnostic_nms,\n max_det=self.args.max_det,\n end2end=self.end2end,\n rotated=self.args.task == \"obb\",\n )\n return [{\"bboxes\": x[:, :4], \"conf\": x[:, 4], \"cls\": x[:, 5], \"extra\": x[:, 6:]} for x in outputs]\n\n def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Prepare a batch of images and annotations for validation.\n\n Args:\n si (int): Batch index.\n batch (Dict[str, Any]): Batch data containing images and annotations.\n\n Returns:\n (Dict[str, Any]): Prepared batch with processed annotations.\n \"\"\"\n idx = batch[\"batch_idx\"] == si\n cls = batch[\"cls\"][idx].squeeze(-1)\n bbox = batch[\"bboxes\"][idx]\n ori_shape = batch[\"ori_shape\"][si]\n imgsz = batch[\"img\"].shape[2:]\n ratio_pad = batch[\"ratio_pad\"][si]\n if len(cls):\n bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes\n ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad) # native-space labels\n return {\"cls\": cls, \"bboxes\": bbox, \"ori_shape\": ori_shape, \"imgsz\": imgsz, \"ratio_pad\": ratio_pad}\n\n def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:\n \"\"\"\n Prepare predictions for evaluation against ground truth.\n\n Args:\n pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.\n pbatch (Dict[str, Any]): Prepared batch information.\n\n Returns:\n (Dict[str, torch.Tensor]): Prepared predictions in native space.\n \"\"\"\n cls = pred[\"cls\"]\n if self.args.single_cls:\n cls *= 0\n # predn = pred.clone()\n bboxes = ops.scale_boxes(\n pbatch[\"imgsz\"], pred[\"bboxes\"].clone(), pbatch[\"ori_shape\"], ratio_pad=pbatch[\"ratio_pad\"]\n ) # native-space pred\n return {\"bboxes\": bboxes, \"conf\": pred[\"conf\"], \"cls\": cls}\n\n def update_metrics(self, preds: List[Dict[str, torch.Tensor]], batch: Dict[str, Any]) -> None:\n \"\"\"\n Update metrics with new predictions and ground truth.\n\n Args:\n preds (List[Dict[str, torch.Tensor]]): List of predictions from the model.\n batch (Dict[str, Any]): Batch data containing ground truth.\n \"\"\"\n for si, pred in enumerate(preds):\n self.seen += 1\n pbatch = self._prepare_batch(si, batch)\n predn = self._prepare_pred(pred, pbatch)\n\n cls = pbatch[\"cls\"].cpu().numpy()\n no_pred = len(predn[\"cls\"]) == 0\n self.metrics.update_stats(\n {\n **self._process_batch(predn, pbatch),\n \"target_cls\": cls,\n \"target_img\": np.unique(cls),\n \"conf\": np.zeros(0) if no_pred else predn[\"conf\"].cpu().numpy(),\n \"pred_cls\": np.zeros(0) if no_pred else predn[\"cls\"].cpu().numpy(),\n }\n )\n # Evaluate\n if self.args.plots:\n self.confusion_matrix.process_batch(predn, pbatch, conf=self.args.conf)\n\n if no_pred:\n continue\n\n # Save\n if self.args.save_json:\n self.pred_to_json(predn, batch[\"im_file\"][si])\n if self.args.save_txt:\n 
self.save_one_txt(\n predn,\n self.args.save_conf,\n pbatch[\"ori_shape\"],\n self.save_dir / \"labels\" / f\"{Path(batch['im_file'][si]).stem}.txt\",\n )\n\n def finalize_metrics(self) -> None:\n \"\"\"Set final values for metrics speed and confusion matrix.\"\"\"\n if self.args.plots:\n for normalize in True, False:\n self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)\n self.metrics.speed = self.speed\n self.metrics.confusion_matrix = self.confusion_matrix\n self.metrics.save_dir = self.save_dir\n\n def get_stats(self) -> Dict[str, Any]:\n \"\"\"\n Calculate and return metrics statistics.\n\n Returns:\n (Dict[str, Any]): Dictionary containing metrics results.\n \"\"\"\n self.metrics.process(save_dir=self.save_dir, plot=self.args.plots, on_plot=self.on_plot)\n self.metrics.clear_stats()\n return self.metrics.results_dict\n\n def print_results(self) -> None:\n \"\"\"Print training/validation set metrics per class.\"\"\"\n pf = \"%22s\" + \"%11i\" * 2 + \"%11.3g\" * len(self.metrics.keys) # print format\n LOGGER.info(pf % (\"all\", self.seen, self.metrics.nt_per_class.sum(), *self.metrics.mean_results()))\n if self.metrics.nt_per_class.sum() == 0:\n LOGGER.warning(f\"no labels found in {self.args.task} set, can not compute metrics without labels\")\n\n # Print results per class\n if self.args.verbose and not self.training and self.nc > 1 and len(self.metrics.stats):\n for i, c in enumerate(self.metrics.ap_class_index):\n LOGGER.info(\n pf\n % (\n self.names[c],\n self.metrics.nt_per_image[c],\n self.metrics.nt_per_class[c],\n *self.metrics.class_result(i),\n )\n )\n\n def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:\n \"\"\"\n Return correct prediction matrix.\n\n Args:\n preds (Dict[str, torch.Tensor]): Dictionary containing prediction data with 'bboxes' and 'cls' keys.\n batch (Dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' and 'cls' keys.\n\n Returns:\n (Dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for 10 IoU levels.\n \"\"\"\n if len(batch[\"cls\"]) == 0 or len(preds[\"cls\"]) == 0:\n return {\"tp\": np.zeros((len(preds[\"cls\"]), self.niou), dtype=bool)}\n iou = box_iou(batch[\"bboxes\"], preds[\"bboxes\"])\n return {\"tp\": self.match_predictions(preds[\"cls\"], batch[\"cls\"], iou).cpu().numpy()}\n\n def build_dataset(self, img_path: str, mode: str = \"val\", batch: Optional[int] = None) -> torch.utils.data.Dataset:\n \"\"\"\n Build YOLO Dataset.\n\n Args:\n img_path (str): Path to the folder containing images.\n mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.\n batch (int, optional): Size of batches, this is for `rect`.\n\n Returns:\n (Dataset): YOLO dataset.\n \"\"\"\n return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)\n\n def get_dataloader(self, dataset_path: str, batch_size: int) -> torch.utils.data.DataLoader:\n \"\"\"\n Construct and return dataloader.\n\n Args:\n dataset_path (str): Path to the dataset.\n batch_size (int): Size of each batch.\n\n Returns:\n (torch.utils.data.DataLoader): Dataloader for validation.\n \"\"\"\n dataset = self.build_dataset(dataset_path, batch=batch_size, mode=\"val\")\n return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1) # return dataloader\n\n def plot_val_samples(self, batch: Dict[str, Any], ni: int) -> None:\n \"\"\"\n 
Plot validation image samples.\n\n Args:\n batch (Dict[str, Any]): Batch containing images and annotations.\n ni (int): Batch index.\n \"\"\"\n plot_images(\n labels=batch,\n paths=batch[\"im_file\"],\n fname=self.save_dir / f\"val_batch{ni}_labels.jpg\",\n names=self.names,\n on_plot=self.on_plot,\n )\n\n def plot_predictions(\n self, batch: Dict[str, Any], preds: List[Dict[str, torch.Tensor]], ni: int, max_det: Optional[int] = None\n ) -> None:\n \"\"\"\n Plot predicted bounding boxes on input images and save the result.\n\n Args:\n batch (Dict[str, Any]): Batch containing images and annotations.\n preds (List[Dict[str, torch.Tensor]]): List of predictions from the model.\n ni (int): Batch index.\n max_det (Optional[int]): Maximum number of detections to plot.\n \"\"\"\n # TODO: optimize this\n for i, pred in enumerate(preds):\n pred[\"batch_idx\"] = torch.ones_like(pred[\"conf\"]) * i # add batch index to predictions\n keys = preds[0].keys()\n max_det = max_det or self.args.max_det\n batched_preds = {k: torch.cat([x[k][:max_det] for x in preds], dim=0) for k in keys}\n # TODO: fix this\n batched_preds[\"bboxes\"][:, :4] = ops.xyxy2xywh(batched_preds[\"bboxes\"][:, :4]) # convert to xywh format\n plot_images(\n images=batch[\"img\"],\n labels=batched_preds,\n paths=batch[\"im_file\"],\n fname=self.save_dir / f\"val_batch{ni}_pred.jpg\",\n names=self.names,\n on_plot=self.on_plot,\n ) # pred\n\n def save_one_txt(self, predn: Dict[str, torch.Tensor], save_conf: bool, shape: Tuple[int, int], file: Path) -> None:\n \"\"\"\n Save YOLO detections to a txt file in normalized coordinates in a specific format.\n\n Args:\n predn (Dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', and 'cls'.\n save_conf (bool): Whether to save confidence scores.\n shape (Tuple[int, int]): Shape of the original image (height, width).\n file (Path): File path to save the detections.\n \"\"\"\n from ultralytics.engine.results import Results\n\n Results(\n np.zeros((shape[0], shape[1]), dtype=np.uint8),\n path=None,\n names=self.names,\n boxes=torch.cat([predn[\"bboxes\"], predn[\"conf\"].unsqueeze(-1), predn[\"cls\"].unsqueeze(-1)], dim=1),\n ).save_txt(file, save_conf=save_conf)\n\n def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: str) -> None:\n \"\"\"\n Serialize YOLO predictions to COCO json format.\n\n Args:\n predn (Dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys\n with bounding box coordinates, confidence scores, and class predictions.\n filename (str): Image filename.\n \"\"\"\n stem = Path(filename).stem\n image_id = int(stem) if stem.isnumeric() else stem\n box = ops.xyxy2xywh(predn[\"bboxes\"]) # xywh\n box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner\n for b, s, c in zip(box.tolist(), predn[\"conf\"].tolist(), predn[\"cls\"].tolist()):\n self.jdict.append(\n {\n \"image_id\": image_id,\n \"category_id\": self.class_map[int(c)],\n \"bbox\": [round(x, 3) for x in b],\n \"score\": round(s, 5),\n }\n )\n\n def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Evaluate YOLO output in JSON format and return performance statistics.\n\n Args:\n stats (Dict[str, Any]): Current statistics dictionary.\n\n Returns:\n (Dict[str, Any]): Updated statistics dictionary with COCO/LVIS evaluation results.\n \"\"\"\n pred_json = self.save_dir / \"predictions.json\" # predictions\n anno_json = (\n self.data[\"path\"]\n / \"annotations\"\n / (\"instances_val2017.json\" if self.is_coco else 
f\"lvis_v1_{self.args.split}.json\")\n ) # annotations\n return self.coco_evaluate(stats, pred_json, anno_json)\n\n def coco_evaluate(\n self,\n stats: Dict[str, Any],\n pred_json: str,\n anno_json: str,\n iou_types: Union[str, List[str]] = \"bbox\",\n suffix: Union[str, List[str]] = \"Box\",\n ) -> Dict[str, Any]:\n \"\"\"\n Evaluate COCO/LVIS metrics using faster-coco-eval library.\n\n Performs evaluation using the faster-coco-eval library to compute mAP metrics\n for object detection. Updates the provided stats dictionary with computed metrics\n including mAP50, mAP50-95, and LVIS-specific metrics if applicable.\n\n Args:\n stats (Dict[str, Any]): Dictionary to store computed metrics and statistics.\n pred_json (str | Path]): Path to JSON file containing predictions in COCO format.\n anno_json (str | Path]): Path to JSON file containing ground truth annotations in COCO format.\n iou_types (str | List[str]]): IoU type(s) for evaluation. Can be single string or list of strings.\n Common values include \"bbox\", \"segm\", \"keypoints\". Defaults to \"bbox\".\n suffix (str | List[str]]): Suffix to append to metric names in stats dictionary. Should correspond\n to iou_types if multiple types provided. Defaults to \"Box\".\n\n Returns:\n (Dict[str, Any]): Updated stats dictionary containing the computed COCO/LVIS evaluation metrics.\n \"\"\"\n if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):\n LOGGER.info(f\"\\nEvaluating faster-coco-eval mAP using {pred_json} and {anno_json}...\")\n try:\n for x in pred_json, anno_json:\n assert x.is_file(), f\"{x} file not found\"\n iou_types = [iou_types] if isinstance(iou_types, str) else iou_types\n suffix = [suffix] if isinstance(suffix, str) else suffix\n check_requirements(\"faster-coco-eval>=1.6.7\")\n from faster_coco_eval import COCO, COCOeval_faster\n\n anno = COCO(anno_json)\n pred = anno.loadRes(pred_json)\n for i, iou_type in enumerate(iou_types):\n val = COCOeval_faster(\n anno, pred, iouType=iou_type, lvis_style=self.is_lvis, print_function=LOGGER.info\n )\n val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval\n val.evaluate()\n val.accumulate()\n val.summarize()\n\n # update mAP50-95 and mAP50\n stats[f\"metrics/mAP50({suffix[i][0]})\"] = val.stats_as_dict[\"AP_50\"]\n stats[f\"metrics/mAP50-95({suffix[i][0]})\"] = val.stats_as_dict[\"AP_all\"]\n\n if self.is_lvis:\n stats[f\"metrics/APr({suffix[i][0]})\"] = val.stats_as_dict[\"APr\"]\n stats[f\"metrics/APc({suffix[i][0]})\"] = val.stats_as_dict[\"APc\"]\n stats[f\"metrics/APf({suffix[i][0]})\"] = val.stats_as_dict[\"APf\"]\n\n if self.is_lvis:\n stats[\"fitness\"] = stats[\"metrics/mAP50-95(B)\"] # always use box mAP50-95 for fitness\n except Exception as e:\n LOGGER.warning(f\"faster-coco-eval unable to run: {e}\")\n return stats",
"chunk_type": "class",
"name": "DetectionValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\val.py",
"start_line": 18,
"end_line": 465,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "A class extending the BaseValidator class for validation based on a detection model.\n\nThis class implements validation functionality specific to object detection tasks, including metrics calculation,\nprediction processing, and visualization of results.\n\nAttributes:\n is_coco (bool): Whether the dataset is COCO.\n is_lvis (bool): Whether the dataset is LVIS.\n class_map (List[int]): Mapping from model class indices to dataset class indices.\n metrics (DetMetrics): Object detection metrics calculator.\n iouv (torch.Tensor): IoU thresholds for mAP calculation.\n niou (int): Number of IoU thresholds.\n lb (List[Any]): List for storing ground truth labels for hybrid saving.\n jdict (List[Dict[str, Any]]): List for storing JSON detection results.\n stats (Dict[str, List[torch.Tensor]]): Dictionary for storing statistics during validation.\n\nExamples:\n >>> from ultralytics.models.yolo.detect import DetectionValidator\n >>> args = dict(model=\"yolo11n.pt\", data=\"coco8.yaml\")\n >>> validator = DetectionValidator(args=args)\n >>> validator()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"os",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.data.build_dataloader",
"ultralytics.data.build_yolo_dataset",
"ultralytics.data.converter",
"ultralytics.engine.validator.BaseValidator",
"ultralytics.utils.LOGGER",
"ultralytics.utils.ops",
"ultralytics.utils.checks.check_requirements",
"ultralytics.utils.metrics.ConfusionMatrix",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.box_iou",
"ultralytics.utils.plotting.plot_images",
"ultralytics.engine.results.Results",
"faster_coco_eval.COCO",
"faster_coco_eval.COCOeval_faster",
"BaseValidator"
],
"chunk_id": "class_DetectionValidator_602f977a"
},
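A short sketch of the coordinate conversion used by DetectionValidator.pred_to_json above: xyxy boxes are converted to xywh with ops.xyxy2xywh and the center is shifted to the top-left corner, which is what COCO-format JSON expects. The box values are made up for illustration.

import torch

from ultralytics.utils import ops

# One toy xyxy box -> COCO-style [x_top_left, y_top_left, width, height].
xyxy = torch.tensor([[10.0, 20.0, 50.0, 80.0]])
box = ops.xyxy2xywh(xyxy)     # [[30., 50., 40., 60.]] (center x, center y, w, h)
box[:, :2] -= box[:, 2:] / 2  # [[10., 20., 40., 60.]] (top-left x, y, w, h)
print(box.tolist())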
{
"content": "from .predict import DetectionPredictor",
"chunk_type": "import",
"name": "DetectionPredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\__init__.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 39,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DetectionPredictor_5b7f284f"
},
{
"content": "from .train import DetectionTrainer",
"chunk_type": "import",
"name": "DetectionTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\__init__.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DetectionTrainer_1e8eb3ba"
},
{
"content": "from .val import DetectionValidator",
"chunk_type": "import",
"name": "DetectionValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\__init__.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DetectionValidator_a53db3c2"
},
{
"content": "__all__ = \"DetectionPredictor\", \"DetectionTrainer\", \"DetectionValidator\"",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\detect\\__init__.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 72,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___f412584c"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\predict.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_aeec3423"
},
{
"content": "from ultralytics.engine.results import Results",
"chunk_type": "import",
"name": "Results",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\predict.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Results_2191c58a"
},
{
"content": "from ultralytics.models.yolo.detect.predict import DetectionPredictor",
"chunk_type": "import",
"name": "DetectionPredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\predict.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 69,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DetectionPredictor_ced97923"
},
{
"content": "from ultralytics.utils import DEFAULT_CFG, ops",
"chunk_type": "import",
"name": "DEFAULT_CFG, ops",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\predict.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DEFAULT_CFG, ops_6bc7b389"
},
{
"content": "class OBBPredictor(DetectionPredictor):\n \"\"\"\n A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.\n\n This predictor handles oriented bounding box detection tasks, processing images and returning results with rotated\n bounding boxes.\n\n Attributes:\n args (namespace): Configuration arguments for the predictor.\n model (torch.nn.Module): The loaded YOLO OBB model.\n\n Examples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.obb import OBBPredictor\n >>> args = dict(model=\"yolo11n-obb.pt\", source=ASSETS)\n >>> predictor = OBBPredictor(overrides=args)\n >>> predictor.predict_cli()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):\n \"\"\"\n Initialize OBBPredictor with optional model and data configuration overrides.\n\n Args:\n cfg (dict, optional): Default configuration for the predictor.\n overrides (dict, optional): Configuration overrides that take precedence over the default config.\n _callbacks (list, optional): List of callback functions to be invoked during prediction.\n\n Examples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.obb import OBBPredictor\n >>> args = dict(model=\"yolo11n-obb.pt\", source=ASSETS)\n >>> predictor = OBBPredictor(overrides=args)\n \"\"\"\n super().__init__(cfg, overrides, _callbacks)\n self.args.task = \"obb\"\n\n def construct_result(self, pred, img, orig_img, img_path):\n \"\"\"\n Construct the result object from the prediction.\n\n Args:\n pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where\n the last dimension contains [x, y, w, h, confidence, class_id, angle].\n img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).\n orig_img (np.ndarray): The original image before preprocessing.\n img_path (str): The path to the original image.\n\n Returns:\n (Results): The result object containing the original image, image path, class names, and oriented bounding\n boxes.\n \"\"\"\n rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))\n rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)\n obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)\n return Results(orig_img, path=img_path, names=self.model.names, obb=obb)",
"chunk_type": "class",
"name": "OBBPredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\predict.py",
"start_line": 10,
"end_line": 65,
"start_col": 0,
"end_col": 80,
"parent_name": null,
"docstring": "A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.\n\nThis predictor handles oriented bounding box detection tasks, processing images and returning results with rotated\nbounding boxes.\n\nAttributes:\n args (namespace): Configuration arguments for the predictor.\n model (torch.nn.Module): The loaded YOLO OBB model.\n\nExamples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.obb import OBBPredictor\n >>> args = dict(model=\"yolo11n-obb.pt\", source=ASSETS)\n >>> predictor = OBBPredictor(overrides=args)\n >>> predictor.predict_cli()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"torch",
"ultralytics.engine.results.Results",
"ultralytics.models.yolo.detect.predict.DetectionPredictor",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.ops",
"DetectionPredictor"
],
"chunk_id": "class_OBBPredictor_0d773c5c"
},
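A usage sketch for the OBBPredictor chunk above, lifted from its docstring Examples block; ASSETS and the yolo11n-obb.pt checkpoint are the docstring's own illustrative inputs.

```python
# Run OBB prediction on the bundled demo assets, as in the docstring above.
from ultralytics.utils import ASSETS
from ultralytics.models.yolo.obb import OBBPredictor

args = dict(model="yolo11n-obb.pt", source=ASSETS)
predictor = OBBPredictor(overrides=args)
predictor.predict_cli()  # construct_result attaches rotated boxes to Results.obb
```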
{
"content": "from copy import copy",
"chunk_type": "import",
"name": "copy",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\train.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_copy_570af0db"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\train.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_8caf5746"
},
{
"content": "from typing import Any, List, Optional, Union",
"chunk_type": "import",
"name": "Any, List, Optional, Union",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\train.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, List, Optional, Union_5cedf4ff"
},
{
"content": "from ultralytics.models import yolo",
"chunk_type": "import",
"name": "yolo",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\train.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_yolo_7cdc48ea"
},
{
"content": "from ultralytics.nn.tasks import OBBModel",
"chunk_type": "import",
"name": "OBBModel",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\train.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_OBBModel_cab655dc"
},
{
"content": "from ultralytics.utils import DEFAULT_CFG, RANK",
"chunk_type": "import",
"name": "DEFAULT_CFG, RANK",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\train.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DEFAULT_CFG, RANK_23eb45fd"
},
{
"content": "class OBBTrainer(yolo.detect.DetectionTrainer):\n \"\"\"\n A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.\n\n This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for\n detecting objects at arbitrary angles rather than just axis-aligned rectangles.\n\n Attributes:\n loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss,\n and dfl_loss.\n\n Methods:\n get_model: Return OBBModel initialized with specified config and weights.\n get_validator: Return an instance of OBBValidator for validation of YOLO model.\n\n Examples:\n >>> from ultralytics.models.yolo.obb import OBBTrainer\n >>> args = dict(model=\"yolo11n-obb.pt\", data=\"dota8.yaml\", epochs=3)\n >>> trainer = OBBTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[dict] = None, _callbacks: Optional[List[Any]] = None):\n \"\"\"\n Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.\n\n This trainer extends the DetectionTrainer class to specialize in training models that detect oriented\n bounding boxes. It automatically sets the task to 'obb' in the configuration.\n\n Args:\n cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and\n model configuration.\n overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here\n will take precedence over those in cfg.\n _callbacks (List[Any], optional): List of callback functions to be invoked during training.\n\n Examples:\n >>> from ultralytics.models.yolo.obb import OBBTrainer\n >>> args = dict(model=\"yolo11n-obb.pt\", data=\"dota8.yaml\", epochs=3)\n >>> trainer = OBBTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n if overrides is None:\n overrides = {}\n overrides[\"task\"] = \"obb\"\n super().__init__(cfg, overrides, _callbacks)\n\n def get_model(\n self, cfg: Optional[Union[str, dict]] = None, weights: Optional[Union[str, Path]] = None, verbose: bool = True\n ) -> OBBModel:\n \"\"\"\n Return OBBModel initialized with specified config and weights.\n\n Args:\n cfg (str | dict, optional): Model configuration. Can be a path to a YAML config file, a dictionary\n containing configuration parameters, or None to use default configuration.\n weights (str | Path, optional): Path to pretrained weights file. If None, random initialization is used.\n verbose (bool): Whether to display model information during initialization.\n\n Returns:\n (OBBModel): Initialized OBBModel with the specified configuration and weights.\n\n Examples:\n >>> trainer = OBBTrainer()\n >>> model = trainer.get_model(cfg=\"yolo11n-obb.yaml\", weights=\"yolo11n-obb.pt\")\n \"\"\"\n model = OBBModel(cfg, nc=self.data[\"nc\"], ch=self.data[\"channels\"], verbose=verbose and RANK == -1)\n if weights:\n model.load(weights)\n\n return model\n\n def get_validator(self):\n \"\"\"Return an instance of OBBValidator for validation of YOLO model.\"\"\"\n self.loss_names = \"box_loss\", \"cls_loss\", \"dfl_loss\"\n return yolo.obb.OBBValidator(\n self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks\n )",
"chunk_type": "class",
"name": "OBBTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\train.py",
"start_line": 12,
"end_line": 89,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.\n\nThis trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for\ndetecting objects at arbitrary angles rather than just axis-aligned rectangles.\n\nAttributes:\n loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss,\n and dfl_loss.\n\nMethods:\n get_model: Return OBBModel initialized with specified config and weights.\n get_validator: Return an instance of OBBValidator for validation of YOLO model.\n\nExamples:\n >>> from ultralytics.models.yolo.obb import OBBTrainer\n >>> args = dict(model=\"yolo11n-obb.pt\", data=\"dota8.yaml\", epochs=3)\n >>> trainer = OBBTrainer(overrides=args)\n >>> trainer.train()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy.copy",
"pathlib.Path",
"typing.Any",
"typing.List",
"typing.Optional",
"typing.Union",
"ultralytics.models.yolo",
"ultralytics.nn.tasks.OBBModel",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.RANK",
"yolo.detect.DetectionTrainer"
],
"chunk_id": "class_OBBTrainer_efd0a4ee"
},
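A training sketch for the OBBTrainer chunk above, based on its docstring Examples block; dota8.yaml and epochs=3 are the docstring's illustrative settings.

```python
# Train an OBB model on the small demo dataset, as in the docstring above.
from ultralytics.models.yolo.obb import OBBTrainer

args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
trainer = OBBTrainer(overrides=args)
trainer.train()  # get_validator() wires in OBBValidator with box/cls/dfl losses
```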
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\val.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_a1b3e038"
},
{
"content": "from typing import Any, Dict, List, Tuple, Union",
"chunk_type": "import",
"name": "Any, Dict, List, Tuple, Union",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\val.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 48,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, List, Tuple, Union_ebd9e595"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\val.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_e2f852ec"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\val.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_c4d8299d"
},
{
"content": "from ultralytics.models.yolo.detect import DetectionValidator",
"chunk_type": "import",
"name": "DetectionValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\val.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 61,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DetectionValidator_2c9c1948"
},
{
"content": "from ultralytics.utils import LOGGER, ops",
"chunk_type": "import",
"name": "LOGGER, ops",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\val.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER, ops_d88f0de0"
},
{
"content": "from ultralytics.utils.metrics import OBBMetrics, batch_probiou",
"chunk_type": "import",
"name": "OBBMetrics, batch_probiou",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\val.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 63,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_OBBMetrics, batch_probiou_2ab63936"
},
{
"content": "class OBBValidator(DetectionValidator):\n \"\"\"\n A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.\n\n This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and\n satellite imagery where objects can appear at various orientations.\n\n Attributes:\n args (dict): Configuration arguments for the validator.\n metrics (OBBMetrics): Metrics object for evaluating OBB model performance.\n is_dota (bool): Flag indicating whether the validation dataset is in DOTA format.\n\n Methods:\n init_metrics: Initialize evaluation metrics for YOLO.\n _process_batch: Process batch of detections and ground truth boxes to compute IoU matrix.\n _prepare_batch: Prepare batch data for OBB validation.\n _prepare_pred: Prepare predictions with scaled and padded bounding boxes.\n plot_predictions: Plot predicted bounding boxes on input images.\n pred_to_json: Serialize YOLO predictions to COCO json format.\n save_one_txt: Save YOLO detections to a txt file in normalized coordinates.\n eval_json: Evaluate YOLO output in JSON format and return performance statistics.\n\n Examples:\n >>> from ultralytics.models.yolo.obb import OBBValidator\n >>> args = dict(model=\"yolo11n-obb.pt\", data=\"dota8.yaml\")\n >>> validator = OBBValidator(args=args)\n >>> validator(model=args[\"model\"])\n \"\"\"\n\n def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:\n \"\"\"\n Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.\n\n This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models.\n It extends the DetectionValidator class and configures it specifically for the OBB task.\n\n Args:\n dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.\n save_dir (str | Path, optional): Directory to save results.\n args (dict | SimpleNamespace, optional): Arguments containing validation parameters.\n _callbacks (list, optional): List of callback functions to be called during validation.\n \"\"\"\n super().__init__(dataloader, save_dir, args, _callbacks)\n self.args.task = \"obb\"\n self.metrics = OBBMetrics()\n\n def init_metrics(self, model: torch.nn.Module) -> None:\n \"\"\"\n Initialize evaluation metrics for YOLO obb validation.\n\n Args:\n model (torch.nn.Module): Model to validate.\n \"\"\"\n super().init_metrics(model)\n val = self.data.get(self.args.split, \"\") # validation path\n self.is_dota = isinstance(val, str) and \"DOTA\" in val # check if dataset is DOTA format\n\n def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:\n \"\"\"\n Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.\n\n Args:\n preds (Dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected\n class labels and bounding boxes.\n batch (Dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth\n class labels and bounding boxes.\n\n Returns:\n (Dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy\n array with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy\n of predictions compared to the ground truth.\n\n Examples:\n >>> detections = torch.rand(100, 7) # 100 sample detections\n >>> gt_bboxes = torch.rand(50, 5) # 50 sample ground truth boxes\n 
>>> gt_cls = torch.randint(0, 5, (50,)) # 50 ground truth class labels\n >>> correct_matrix = validator._process_batch(detections, gt_bboxes, gt_cls)\n \"\"\"\n if len(batch[\"cls\"]) == 0 or len(preds[\"cls\"]) == 0:\n return {\"tp\": np.zeros((len(preds[\"cls\"]), self.niou), dtype=bool)}\n iou = batch_probiou(batch[\"bboxes\"], preds[\"bboxes\"])\n return {\"tp\": self.match_predictions(preds[\"cls\"], batch[\"cls\"], iou).cpu().numpy()}\n\n def postprocess(self, preds: torch.Tensor) -> List[Dict[str, torch.Tensor]]:\n \"\"\"\n Args:\n preds (torch.Tensor): Raw predictions from the model.\n\n Returns:\n (List[Dict[str, torch.Tensor]]): Processed predictions with angle information concatenated to bboxes.\n \"\"\"\n preds = super().postprocess(preds)\n for pred in preds:\n pred[\"bboxes\"] = torch.cat([pred[\"bboxes\"], pred.pop(\"extra\")], dim=-1) # concatenate angle\n return preds\n\n def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Prepare batch data for OBB validation with proper scaling and formatting.\n\n Args:\n si (int): Batch index to process.\n batch (Dict[str, Any]): Dictionary containing batch data with keys:\n - batch_idx: Tensor of batch indices\n - cls: Tensor of class labels\n - bboxes: Tensor of bounding boxes\n - ori_shape: Original image shapes\n - img: Batch of images\n - ratio_pad: Ratio and padding information\n\n Returns:\n (Dict[str, Any]): Prepared batch data with scaled bounding boxes and metadata.\n \"\"\"\n idx = batch[\"batch_idx\"] == si\n cls = batch[\"cls\"][idx].squeeze(-1)\n bbox = batch[\"bboxes\"][idx]\n ori_shape = batch[\"ori_shape\"][si]\n imgsz = batch[\"img\"].shape[2:]\n ratio_pad = batch[\"ratio_pad\"][si]\n if len(cls):\n bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]) # target boxes\n ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True) # native-space labels\n return {\"cls\": cls, \"bboxes\": bbox, \"ori_shape\": ori_shape, \"imgsz\": imgsz, \"ratio_pad\": ratio_pad}\n\n def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:\n \"\"\"\n Prepare predictions by scaling bounding boxes to original image dimensions.\n\n This method takes prediction tensors containing bounding box coordinates and scales them from the model's\n input dimensions to the original image dimensions using the provided batch information.\n\n Args:\n pred (Dict[str, torch.Tensor]): Prediction dictionary containing bounding box coordinates and other information.\n pbatch (Dict[str, Any]): Dictionary containing batch information with keys:\n - imgsz (tuple): Model input image size.\n - ori_shape (tuple): Original image shape.\n - ratio_pad (tuple): Ratio and padding information for scaling.\n\n Returns:\n (Dict[str, torch.Tensor]): Scaled prediction dictionary with bounding boxes in original image dimensions.\n \"\"\"\n cls = pred[\"cls\"]\n if self.args.single_cls:\n cls *= 0\n bboxes = ops.scale_boxes(\n pbatch[\"imgsz\"], pred[\"bboxes\"].clone(), pbatch[\"ori_shape\"], ratio_pad=pbatch[\"ratio_pad\"], xywh=True\n ) # native-space pred\n return {\"bboxes\": bboxes, \"conf\": pred[\"conf\"], \"cls\": cls}\n\n def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:\n \"\"\"\n Plot predicted bounding boxes on input images and save the result.\n\n Args:\n batch (Dict[str, Any]): Batch data containing images, file paths, and other metadata.\n preds (List[torch.Tensor]): List of prediction tensors for each 
image in the batch.\n ni (int): Batch index used for naming the output file.\n\n Examples:\n >>> validator = OBBValidator()\n >>> batch = {\"img\": images, \"im_file\": paths}\n >>> preds = [torch.rand(10, 7)] # Example predictions for one image\n >>> validator.plot_predictions(batch, preds, 0)\n \"\"\"\n for p in preds:\n # TODO: fix this duplicated `xywh2xyxy`\n p[\"bboxes\"][:, :4] = ops.xywh2xyxy(p[\"bboxes\"][:, :4]) # convert to xyxy format for plotting\n super().plot_predictions(batch, preds, ni) # plot bboxes\n\n def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: Union[str, Path]) -> None:\n \"\"\"\n Convert YOLO predictions to COCO JSON format with rotated bounding box information.\n\n Args:\n predn (Dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys\n with bounding box coordinates, confidence scores, and class predictions.\n filename (str | Path): Path to the image file for which predictions are being processed.\n\n Notes:\n This method processes rotated bounding box predictions and converts them to both rbox format\n (x, y, w, h, angle) and polygon format (x1, y1, x2, y2, x3, y3, x4, y4) before adding them\n to the JSON dictionary.\n \"\"\"\n stem = Path(filename).stem\n image_id = int(stem) if stem.isnumeric() else stem\n rbox = predn[\"bboxes\"]\n poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)\n for r, b, s, c in zip(rbox.tolist(), poly.tolist(), predn[\"conf\"].tolist(), predn[\"cls\"].tolist()):\n self.jdict.append(\n {\n \"image_id\": image_id,\n \"category_id\": self.class_map[int(c)],\n \"score\": round(s, 5),\n \"rbox\": [round(x, 3) for x in r],\n \"poly\": [round(x, 3) for x in b],\n }\n )\n\n def save_one_txt(self, predn: Dict[str, torch.Tensor], save_conf: bool, shape: Tuple[int, int], file: Path) -> None:\n \"\"\"\n Save YOLO OBB detections to a text file in normalized coordinates.\n\n Args:\n predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,\n class predictions, and angles in format (x, y, w, h, conf, cls, angle).\n save_conf (bool): Whether to save confidence scores in the text file.\n shape (Tuple[int, int]): Original image shape in format (height, width).\n file (Path): Output file path to save detections.\n\n Examples:\n >>> validator = OBBValidator()\n >>> predn = torch.tensor([[100, 100, 50, 30, 0.9, 0, 45]]) # One detection: x,y,w,h,conf,cls,angle\n >>> validator.save_one_txt(predn, True, (640, 480), \"detection.txt\")\n \"\"\"\n import numpy as np\n\n from ultralytics.engine.results import Results\n\n Results(\n np.zeros((shape[0], shape[1]), dtype=np.uint8),\n path=None,\n names=self.names,\n obb=torch.cat([predn[\"bboxes\"], predn[\"conf\"].unsqueeze(-1), predn[\"cls\"].unsqueeze(-1)], dim=1),\n ).save_txt(file, save_conf=save_conf)\n\n def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Evaluate YOLO output in JSON format and save predictions in DOTA format.\n\n Args:\n stats (Dict[str, Any]): Performance statistics dictionary.\n\n Returns:\n (Dict[str, Any]): Updated performance statistics.\n \"\"\"\n if self.args.save_json and self.is_dota and len(self.jdict):\n import json\n import re\n from collections import defaultdict\n\n pred_json = self.save_dir / \"predictions.json\" # predictions\n pred_txt = self.save_dir / \"predictions_txt\" # predictions\n pred_txt.mkdir(parents=True, exist_ok=True)\n data = json.load(open(pred_json))\n # Save split results\n LOGGER.info(f\"Saving predictions with DOTA format to 
{pred_txt}...\")\n for d in data:\n image_id = d[\"image_id\"]\n score = d[\"score\"]\n classname = self.names[d[\"category_id\"] - 1].replace(\" \", \"-\")\n p = d[\"poly\"]\n\n with open(f\"{pred_txt / f'Task1_{classname}'}.txt\", \"a\", encoding=\"utf-8\") as f:\n f.writelines(f\"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\\n\")\n # Save merged results, this could result slightly lower map than using official merging script,\n # because of the probiou calculation.\n pred_merged_txt = self.save_dir / \"predictions_merged_txt\" # predictions\n pred_merged_txt.mkdir(parents=True, exist_ok=True)\n merged_results = defaultdict(list)\n LOGGER.info(f\"Saving merged predictions with DOTA format to {pred_merged_txt}...\")\n for d in data:\n image_id = d[\"image_id\"].split(\"__\", 1)[0]\n pattern = re.compile(r\"\\d+___\\d+\")\n x, y = (int(c) for c in re.findall(pattern, d[\"image_id\"])[0].split(\"___\"))\n bbox, score, cls = d[\"rbox\"], d[\"score\"], d[\"category_id\"] - 1\n bbox[0] += x\n bbox[1] += y\n bbox.extend([score, cls])\n merged_results[image_id].append(bbox)\n for image_id, bbox in merged_results.items():\n bbox = torch.tensor(bbox)\n max_wh = torch.max(bbox[:, :2]).item() * 2\n c = bbox[:, 6:7] * max_wh # classes\n scores = bbox[:, 5] # scores\n b = bbox[:, :5].clone()\n b[:, :2] += c\n # 0.3 could get results close to the ones from official merging script, even slightly better.\n i = ops.nms_rotated(b, scores, 0.3)\n bbox = bbox[i]\n\n b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)\n for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():\n classname = self.names[int(x[-1])].replace(\" \", \"-\")\n p = [round(i, 3) for i in x[:-2]] # poly\n score = round(x[-2], 3)\n\n with open(f\"{pred_merged_txt / f'Task1_{classname}'}.txt\", \"a\", encoding=\"utf-8\") as f:\n f.writelines(f\"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\\n\")\n\n return stats",
"chunk_type": "class",
"name": "OBBValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\val.py",
"start_line": 14,
"end_line": 303,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.\n\nThis validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and\nsatellite imagery where objects can appear at various orientations.\n\nAttributes:\n args (dict): Configuration arguments for the validator.\n metrics (OBBMetrics): Metrics object for evaluating OBB model performance.\n is_dota (bool): Flag indicating whether the validation dataset is in DOTA format.\n\nMethods:\n init_metrics: Initialize evaluation metrics for YOLO.\n _process_batch: Process batch of detections and ground truth boxes to compute IoU matrix.\n _prepare_batch: Prepare batch data for OBB validation.\n _prepare_pred: Prepare predictions with scaled and padded bounding boxes.\n plot_predictions: Plot predicted bounding boxes on input images.\n pred_to_json: Serialize YOLO predictions to COCO json format.\n save_one_txt: Save YOLO detections to a txt file in normalized coordinates.\n eval_json: Evaluate YOLO output in JSON format and return performance statistics.\n\nExamples:\n >>> from ultralytics.models.yolo.obb import OBBValidator\n >>> args = dict(model=\"yolo11n-obb.pt\", data=\"dota8.yaml\")\n >>> validator = OBBValidator(args=args)\n >>> validator(model=args[\"model\"])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"typing.Union",
"numpy",
"torch",
"ultralytics.models.yolo.detect.DetectionValidator",
"ultralytics.utils.LOGGER",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.batch_probiou",
"numpy",
"ultralytics.engine.results.Results",
"json",
"re",
"collections.defaultdict",
"DetectionValidator"
],
"chunk_id": "class_OBBValidator_afeefb6c"
},
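A validation sketch for the OBBValidator chunk above, following its docstring Examples block; note from eval_json in the chunk that save_json=True on a DOTA-format split additionally writes Task1_<class>.txt files under the save directory.

```python
# Validate an OBB model, mirroring the docstring above (dota8.yaml is illustrative).
from ultralytics.models.yolo.obb import OBBValidator

args = dict(model="yolo11n-obb.pt", data="dota8.yaml")
validator = OBBValidator(args=args)
validator(model=args["model"])
```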
{
"content": "from .predict import OBBPredictor",
"chunk_type": "import",
"name": "OBBPredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\__init__.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_OBBPredictor_6b8ebe96"
},
{
"content": "from .train import OBBTrainer",
"chunk_type": "import",
"name": "OBBTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\__init__.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_OBBTrainer_ac16a03d"
},
{
"content": "from .val import OBBValidator",
"chunk_type": "import",
"name": "OBBValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\__init__.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_OBBValidator_6794c3f5"
},
{
"content": "__all__ = \"OBBPredictor\", \"OBBTrainer\", \"OBBValidator\"",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\obb\\__init__.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 54,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___512c9b5a"
},
{
"content": "from ultralytics.models.yolo.detect.predict import DetectionPredictor",
"chunk_type": "import",
"name": "DetectionPredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\predict.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 69,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DetectionPredictor_083c2bc6"
},
{
"content": "from ultralytics.utils import DEFAULT_CFG, LOGGER, ops",
"chunk_type": "import",
"name": "DEFAULT_CFG, LOGGER, ops",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\predict.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 54,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DEFAULT_CFG, LOGGER, ops_d2bc05fe"
},
{
"content": "class PosePredictor(DetectionPredictor):\n \"\"\"\n A class extending the DetectionPredictor class for prediction based on a pose model.\n\n This class specializes in pose estimation, handling keypoints detection alongside standard object detection\n capabilities inherited from DetectionPredictor.\n\n Attributes:\n args (namespace): Configuration arguments for the predictor.\n model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.\n\n Methods:\n construct_result: Construct the result object from the prediction, including keypoints.\n\n Examples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.pose import PosePredictor\n >>> args = dict(model=\"yolo11n-pose.pt\", source=ASSETS)\n >>> predictor = PosePredictor(overrides=args)\n >>> predictor.predict_cli()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):\n \"\"\"\n Initialize PosePredictor for pose estimation tasks.\n\n Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific\n warnings for Apple MPS.\n\n Args:\n cfg (Any): Configuration for the predictor.\n overrides (dict, optional): Configuration overrides that take precedence over cfg.\n _callbacks (list, optional): List of callback functions to be invoked during prediction.\n\n Examples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.pose import PosePredictor\n >>> args = dict(model=\"yolo11n-pose.pt\", source=ASSETS)\n >>> predictor = PosePredictor(overrides=args)\n >>> predictor.predict_cli()\n \"\"\"\n super().__init__(cfg, overrides, _callbacks)\n self.args.task = \"pose\"\n if isinstance(self.args.device, str) and self.args.device.lower() == \"mps\":\n LOGGER.warning(\n \"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. \"\n \"See https://github.com/ultralytics/ultralytics/issues/4031.\"\n )\n\n def construct_result(self, pred, img, orig_img, img_path):\n \"\"\"\n Construct the result object from the prediction, including keypoints.\n\n Extends the parent class implementation by extracting keypoint data from predictions and adding them to the\n result object.\n\n Args:\n pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints with shape (N, 6+K*D) where N is\n the number of detections, K is the number of keypoints, and D is the keypoint dimension.\n img (torch.Tensor): The processed input image tensor with shape (B, C, H, W).\n orig_img (np.ndarray): The original unprocessed image as a numpy array.\n img_path (str): The path to the original image file.\n\n Returns:\n (Results): The result object containing the original image, image path, class names, bounding boxes, and\n keypoints.\n \"\"\"\n result = super().construct_result(pred, img, orig_img, img_path)\n # Extract keypoints from prediction and reshape according to model's keypoint shape\n pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape)\n # Scale keypoints coordinates to match the original image dimensions\n pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)\n result.update(keypoints=pred_kpts)\n return result",
"chunk_type": "class",
"name": "PosePredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\predict.py",
"start_line": 7,
"end_line": 80,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": "A class extending the DetectionPredictor class for prediction based on a pose model.\n\nThis class specializes in pose estimation, handling keypoints detection alongside standard object detection\ncapabilities inherited from DetectionPredictor.\n\nAttributes:\n args (namespace): Configuration arguments for the predictor.\n model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.\n\nMethods:\n construct_result: Construct the result object from the prediction, including keypoints.\n\nExamples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.pose import PosePredictor\n >>> args = dict(model=\"yolo11n-pose.pt\", source=ASSETS)\n >>> predictor = PosePredictor(overrides=args)\n >>> predictor.predict_cli()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"ultralytics.models.yolo.detect.predict.DetectionPredictor",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.LOGGER",
"ultralytics.utils.ops",
"DetectionPredictor"
],
"chunk_id": "class_PosePredictor_1d8d9515"
},
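A usage sketch for the PosePredictor chunk above, from its docstring Examples block; per construct_result, each Results object carries keypoints reshaped to the model's kpt_shape and scaled back to the original image.

```python
# Run pose prediction on the bundled demo assets, as in the docstring above.
from ultralytics.utils import ASSETS
from ultralytics.models.yolo.pose import PosePredictor

args = dict(model="yolo11n-pose.pt", source=ASSETS)
predictor = PosePredictor(overrides=args)
predictor.predict_cli()
```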
{
"content": "from copy import copy",
"chunk_type": "import",
"name": "copy",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\train.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_copy_cf2de10a"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\train.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_2aaaa099"
},
{
"content": "from typing import Any, Dict, Optional, Union",
"chunk_type": "import",
"name": "Any, Dict, Optional, Union",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\train.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, Optional, Union_42f903a3"
},
{
"content": "from ultralytics.models import yolo",
"chunk_type": "import",
"name": "yolo",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\train.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_yolo_66810175"
},
{
"content": "from ultralytics.nn.tasks import PoseModel",
"chunk_type": "import",
"name": "PoseModel",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\train.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_PoseModel_5bf44d06"
},
{
"content": "from ultralytics.utils import DEFAULT_CFG, LOGGER",
"chunk_type": "import",
"name": "DEFAULT_CFG, LOGGER",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\train.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 49,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DEFAULT_CFG, LOGGER_8a6b8b43"
},
{
"content": "from ultralytics.utils.plotting import plot_results",
"chunk_type": "import",
"name": "plot_results",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\train.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 51,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_plot_results_7414fa96"
},
{
"content": "class PoseTrainer(yolo.detect.DetectionTrainer):\n \"\"\"\n A class extending the DetectionTrainer class for training YOLO pose estimation models.\n\n This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization\n of pose keypoints alongside bounding boxes.\n\n Attributes:\n args (dict): Configuration arguments for training.\n model (PoseModel): The pose estimation model being trained.\n data (dict): Dataset configuration including keypoint shape information.\n loss_names (tuple): Names of the loss components used in training.\n\n Methods:\n get_model: Retrieve a pose estimation model with specified configuration.\n set_model_attributes: Set keypoints shape attribute on the model.\n get_validator: Create a validator instance for model evaluation.\n plot_training_samples: Visualize training samples with keypoints.\n plot_metrics: Generate and save training/validation metric plots.\n get_dataset: Retrieve the dataset and ensure it contains required kpt_shape key.\n\n Examples:\n >>> from ultralytics.models.yolo.pose import PoseTrainer\n >>> args = dict(model=\"yolo11n-pose.pt\", data=\"coco8-pose.yaml\", epochs=3)\n >>> trainer = PoseTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None):\n \"\"\"\n Initialize a PoseTrainer object for training YOLO pose estimation models.\n\n This initializes a trainer specialized for pose estimation tasks, setting the task to 'pose' and\n handling specific configurations needed for keypoint detection models.\n\n Args:\n cfg (dict, optional): Default configuration dictionary containing training parameters.\n overrides (dict, optional): Dictionary of parameter overrides for the default configuration.\n _callbacks (list, optional): List of callback functions to be executed during training.\n\n Notes:\n This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.\n A warning is issued when using Apple MPS device due to known bugs with pose models.\n\n Examples:\n >>> from ultralytics.models.yolo.pose import PoseTrainer\n >>> args = dict(model=\"yolo11n-pose.pt\", data=\"coco8-pose.yaml\", epochs=3)\n >>> trainer = PoseTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n if overrides is None:\n overrides = {}\n overrides[\"task\"] = \"pose\"\n super().__init__(cfg, overrides, _callbacks)\n\n if isinstance(self.args.device, str) and self.args.device.lower() == \"mps\":\n LOGGER.warning(\n \"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. 
\"\n \"See https://github.com/ultralytics/ultralytics/issues/4031.\"\n )\n\n def get_model(\n self,\n cfg: Optional[Union[str, Path, Dict[str, Any]]] = None,\n weights: Optional[Union[str, Path]] = None,\n verbose: bool = True,\n ) -> PoseModel:\n \"\"\"\n Get pose estimation model with specified configuration and weights.\n\n Args:\n cfg (str | Path | dict, optional): Model configuration file path or dictionary.\n weights (str | Path, optional): Path to the model weights file.\n verbose (bool): Whether to display model information.\n\n Returns:\n (PoseModel): Initialized pose estimation model.\n \"\"\"\n model = PoseModel(\n cfg, nc=self.data[\"nc\"], ch=self.data[\"channels\"], data_kpt_shape=self.data[\"kpt_shape\"], verbose=verbose\n )\n if weights:\n model.load(weights)\n\n return model\n\n def set_model_attributes(self):\n \"\"\"Set keypoints shape attribute of PoseModel.\"\"\"\n super().set_model_attributes()\n self.model.kpt_shape = self.data[\"kpt_shape\"]\n\n def get_validator(self):\n \"\"\"Return an instance of the PoseValidator class for validation.\"\"\"\n self.loss_names = \"box_loss\", \"pose_loss\", \"kobj_loss\", \"cls_loss\", \"dfl_loss\"\n return yolo.pose.PoseValidator(\n self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks\n )\n\n def plot_metrics(self):\n \"\"\"Plot training/validation metrics.\"\"\"\n plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png\n\n def get_dataset(self) -> Dict[str, Any]:\n \"\"\"\n Retrieve the dataset and ensure it contains the required `kpt_shape` key.\n\n Returns:\n (dict): A dictionary containing the training/validation/test dataset and category names.\n\n Raises:\n KeyError: If the `kpt_shape` key is not present in the dataset.\n \"\"\"\n data = super().get_dataset()\n if \"kpt_shape\" not in data:\n raise KeyError(f\"No `kpt_shape` in the {self.args.data}. See https://docs.ultralytics.com/datasets/pose/\")\n return data",
"chunk_type": "class",
"name": "PoseTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\train.py",
"start_line": 13,
"end_line": 128,
"start_col": 0,
"end_col": 19,
"parent_name": null,
"docstring": "A class extending the DetectionTrainer class for training YOLO pose estimation models.\n\nThis trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization\nof pose keypoints alongside bounding boxes.\n\nAttributes:\n args (dict): Configuration arguments for training.\n model (PoseModel): The pose estimation model being trained.\n data (dict): Dataset configuration including keypoint shape information.\n loss_names (tuple): Names of the loss components used in training.\n\nMethods:\n get_model: Retrieve a pose estimation model with specified configuration.\n set_model_attributes: Set keypoints shape attribute on the model.\n get_validator: Create a validator instance for model evaluation.\n plot_training_samples: Visualize training samples with keypoints.\n plot_metrics: Generate and save training/validation metric plots.\n get_dataset: Retrieve the dataset and ensure it contains required kpt_shape key.\n\nExamples:\n >>> from ultralytics.models.yolo.pose import PoseTrainer\n >>> args = dict(model=\"yolo11n-pose.pt\", data=\"coco8-pose.yaml\", epochs=3)\n >>> trainer = PoseTrainer(overrides=args)\n >>> trainer.train()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy.copy",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Optional",
"typing.Union",
"ultralytics.models.yolo",
"ultralytics.nn.tasks.PoseModel",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.LOGGER",
"ultralytics.utils.plotting.plot_results",
"yolo.detect.DetectionTrainer"
],
"chunk_id": "class_PoseTrainer_75b00860"
},
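A training sketch for the PoseTrainer chunk above, from its docstring Examples block; as documented in get_dataset, the dataset YAML must define kpt_shape or a KeyError is raised.

```python
# Train a pose model on the small demo dataset, as in the docstring above.
from ultralytics.models.yolo.pose import PoseTrainer

args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
trainer = PoseTrainer(overrides=args)
trainer.train()  # get_validator() wires in PoseValidator with pose/kobj losses
```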
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\val.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_5106f23b"
},
{
"content": "from typing import Any, Dict, Tuple",
"chunk_type": "import",
"name": "Any, Dict, Tuple",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\val.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, Tuple_2193a97b"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\val.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_21b0730b"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\val.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_35e58519"
},
{
"content": "from ultralytics.models.yolo.detect import DetectionValidator",
"chunk_type": "import",
"name": "DetectionValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\val.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 61,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DetectionValidator_d1720e2b"
},
{
"content": "from ultralytics.utils import LOGGER, ops",
"chunk_type": "import",
"name": "LOGGER, ops",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\val.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER, ops_5b4a1692"
},
{
"content": "from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou",
"chunk_type": "import",
"name": "OKS_SIGMA, PoseMetrics, kpt_iou",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\val.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 69,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_OKS_SIGMA, PoseMetrics, kpt_iou_0354a920"
},
{
"content": "class PoseValidator(DetectionValidator):\n \"\"\"\n A class extending the DetectionValidator class for validation based on a pose model.\n\n This validator is specifically designed for pose estimation tasks, handling keypoints and implementing\n specialized metrics for pose evaluation.\n\n Attributes:\n sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.\n kpt_shape (List[int]): Shape of the keypoints, typically [17, 3] for COCO format.\n args (dict): Arguments for the validator including task set to \"pose\".\n metrics (PoseMetrics): Metrics object for pose evaluation.\n\n Methods:\n preprocess: Preprocess batch by converting keypoints data to float and moving it to the device.\n get_desc: Return description of evaluation metrics in string format.\n init_metrics: Initialize pose estimation metrics for YOLO model.\n _prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original\n dimensions.\n _prepare_pred: Prepare and scale keypoints in predictions for pose processing.\n _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between\n detections and ground truth.\n plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.\n plot_predictions: Plot and save model predictions with bounding boxes and keypoints.\n save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.\n pred_to_json: Convert YOLO predictions to COCO JSON format.\n eval_json: Evaluate object detection model using COCO JSON format.\n\n Examples:\n >>> from ultralytics.models.yolo.pose import PoseValidator\n >>> args = dict(model=\"yolo11n-pose.pt\", data=\"coco8-pose.yaml\")\n >>> validator = PoseValidator(args=args)\n >>> validator()\n \"\"\"\n\n def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:\n \"\"\"\n Initialize a PoseValidator object for pose estimation validation.\n\n This validator is specifically designed for pose estimation tasks, handling keypoints and implementing\n specialized metrics for pose evaluation.\n\n Args:\n dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.\n save_dir (Path | str, optional): Directory to save results.\n args (dict, optional): Arguments for the validator including task set to \"pose\".\n _callbacks (list, optional): List of callback functions to be executed during validation.\n\n Examples:\n >>> from ultralytics.models.yolo.pose import PoseValidator\n >>> args = dict(model=\"yolo11n-pose.pt\", data=\"coco8-pose.yaml\")\n >>> validator = PoseValidator(args=args)\n >>> validator()\n\n Notes:\n This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values\n for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS\n due to a known bug with pose models.\n \"\"\"\n super().__init__(dataloader, save_dir, args, _callbacks)\n self.sigma = None\n self.kpt_shape = None\n self.args.task = \"pose\"\n self.metrics = PoseMetrics()\n if isinstance(self.args.device, str) and self.args.device.lower() == \"mps\":\n LOGGER.warning(\n \"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. 
\"\n \"See https://github.com/ultralytics/ultralytics/issues/4031.\"\n )\n\n def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"Preprocess batch by converting keypoints data to float and moving it to the device.\"\"\"\n batch = super().preprocess(batch)\n batch[\"keypoints\"] = batch[\"keypoints\"].to(self.device).float()\n return batch\n\n def get_desc(self) -> str:\n \"\"\"Return description of evaluation metrics in string format.\"\"\"\n return (\"%22s\" + \"%11s\" * 10) % (\n \"Class\",\n \"Images\",\n \"Instances\",\n \"Box(P\",\n \"R\",\n \"mAP50\",\n \"mAP50-95)\",\n \"Pose(P\",\n \"R\",\n \"mAP50\",\n \"mAP50-95)\",\n )\n\n def init_metrics(self, model: torch.nn.Module) -> None:\n \"\"\"\n Initialize evaluation metrics for YOLO pose validation.\n\n Args:\n model (torch.nn.Module): Model to validate.\n \"\"\"\n super().init_metrics(model)\n self.kpt_shape = self.data[\"kpt_shape\"]\n is_pose = self.kpt_shape == [17, 3]\n nkpt = self.kpt_shape[0]\n self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt\n\n def postprocess(self, preds: torch.Tensor) -> Dict[str, torch.Tensor]:\n \"\"\"\n Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.\n\n This method extends the parent class postprocessing by extracting keypoints from the 'extra'\n field of predictions and reshaping them according to the keypoint shape configuration.\n The keypoints are reshaped from a flattened format to the proper dimensional structure\n (typically [N, 17, 3] for COCO pose format).\n\n Args:\n preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing\n bounding boxes, confidence scores, class predictions, and keypoint data.\n\n Returns:\n (Dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:\n - 'bboxes': Bounding box coordinates\n - 'conf': Confidence scores\n - 'cls': Class predictions\n - 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)\n\n Note:\n If no keypoints are present in a prediction (empty keypoints), that prediction\n is skipped and continues to the next one. 
The keypoints are extracted from the\n 'extra' field which contains additional task-specific data beyond basic detection.\n \"\"\"\n preds = super().postprocess(preds)\n for pred in preds:\n pred[\"keypoints\"] = pred.pop(\"extra\").view(-1, *self.kpt_shape) # remove extra if exists\n return preds\n\n def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.\n\n Args:\n si (int): Batch index.\n batch (Dict[str, Any]): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.\n\n Returns:\n (Dict[str, Any]): Prepared batch with keypoints scaled to original image dimensions.\n\n Notes:\n This method extends the parent class's _prepare_batch method by adding keypoint processing.\n Keypoints are scaled from normalized coordinates to original image dimensions.\n \"\"\"\n pbatch = super()._prepare_batch(si, batch)\n kpts = batch[\"keypoints\"][batch[\"batch_idx\"] == si]\n h, w = pbatch[\"imgsz\"]\n kpts = kpts.clone()\n kpts[..., 0] *= w\n kpts[..., 1] *= h\n kpts = ops.scale_coords(pbatch[\"imgsz\"], kpts, pbatch[\"ori_shape\"], ratio_pad=pbatch[\"ratio_pad\"])\n pbatch[\"keypoints\"] = kpts\n return pbatch\n\n def _prepare_pred(self, pred: Dict[str, Any], pbatch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Prepare and scale keypoints in predictions for pose processing.\n\n This method extends the parent class's _prepare_pred method to handle keypoint scaling. It first calls\n the parent method to get the basic prediction boxes, then extracts and scales the keypoint coordinates\n to match the original image dimensions.\n\n Args:\n pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.\n pbatch (Dict[str, Any]): Processed batch dictionary containing image information including:\n - imgsz: Image size used for inference\n - ori_shape: Original image shape\n - ratio_pad: Ratio and padding information for coordinate scaling\n\n Returns:\n (Dict[str, Any]): Processed prediction dictionary with keypoints scaled to original image dimensions.\n \"\"\"\n predn = super()._prepare_pred(pred, pbatch)\n predn[\"keypoints\"] = ops.scale_coords(\n pbatch[\"imgsz\"], pred.get(\"keypoints\").clone(), pbatch[\"ori_shape\"], ratio_pad=pbatch[\"ratio_pad\"]\n )\n return predn\n\n def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:\n \"\"\"\n Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.\n\n Args:\n preds (Dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions\n and 'keypoints' for keypoint predictions.\n batch (Dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,\n 'bboxes' for bounding boxes, and 'keypoints' for keypoint annotations.\n\n Returns:\n (Dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose\n true positives across 10 IoU levels.\n\n Notes:\n `0.53` scale factor used in area computation is referenced from\n https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.\n \"\"\"\n tp = super()._process_batch(preds, batch)\n gt_cls = batch[\"cls\"]\n if len(gt_cls) == 0 or len(preds[\"cls\"]) == 0:\n tp_p = np.zeros((len(preds[\"cls\"]), self.niou), dtype=bool)\n else:\n # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384\n area 
= ops.xyxy2xywh(batch[\"bboxes\"])[:, 2:].prod(1) * 0.53\n iou = kpt_iou(batch[\"keypoints\"], preds[\"keypoints\"], sigma=self.sigma, area=area)\n tp_p = self.match_predictions(preds[\"cls\"], gt_cls, iou).cpu().numpy()\n tp.update({\"tp_p\": tp_p}) # update tp with kpts IoU\n return tp\n\n def save_one_txt(self, predn: Dict[str, torch.Tensor], save_conf: bool, shape: Tuple[int, int], file: Path) -> None:\n \"\"\"\n Save YOLO pose detections to a text file in normalized coordinates.\n\n Args:\n predn (Dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls' and 'keypoints.\n save_conf (bool): Whether to save confidence scores.\n shape (Tuple[int, int]): Shape of the original image (height, width).\n file (Path): Output file path to save detections.\n\n Notes:\n The output format is: class_id x_center y_center width height confidence keypoints where keypoints are\n normalized (x, y, visibility) values for each point.\n \"\"\"\n from ultralytics.engine.results import Results\n\n Results(\n np.zeros((shape[0], shape[1]), dtype=np.uint8),\n path=None,\n names=self.names,\n boxes=torch.cat([predn[\"bboxes\"], predn[\"conf\"].unsqueeze(-1), predn[\"cls\"].unsqueeze(-1)], dim=1),\n keypoints=predn[\"keypoints\"],\n ).save_txt(file, save_conf=save_conf)\n\n def pred_to_json(self, predn: Dict[str, torch.Tensor], filename: str) -> None:\n \"\"\"\n Convert YOLO predictions to COCO JSON format.\n\n This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format\n to COCO format, and appends the results to the internal JSON dictionary (self.jdict).\n\n Args:\n predn (Dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',\n and 'keypoints' tensors.\n filename (str): Path to the image file for which predictions are being processed.\n\n Notes:\n The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),\n converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to top-left corner\n before saving to the JSON dictionary.\n \"\"\"\n stem = Path(filename).stem\n image_id = int(stem) if stem.isnumeric() else stem\n box = ops.xyxy2xywh(predn[\"bboxes\"]) # xywh\n box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner\n for b, s, c, k in zip(\n box.tolist(),\n predn[\"conf\"].tolist(),\n predn[\"cls\"].tolist(),\n predn[\"keypoints\"].flatten(1, 2).tolist(),\n ):\n self.jdict.append(\n {\n \"image_id\": image_id,\n \"category_id\": self.class_map[int(c)],\n \"bbox\": [round(x, 3) for x in b],\n \"keypoints\": k,\n \"score\": round(s, 5),\n }\n )\n\n def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"Evaluate object detection model using COCO JSON format.\"\"\"\n anno_json = self.data[\"path\"] / \"annotations/person_keypoints_val2017.json\" # annotations\n pred_json = self.save_dir / \"predictions.json\" # predictions\n return super().coco_evaluate(stats, pred_json, anno_json, [\"bbox\", \"keypoints\"], suffix=[\"Box\", \"Pose\"])",
"chunk_type": "class",
"name": "PoseValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\val.py",
"start_line": 14,
"end_line": 293,
"start_col": 0,
"end_col": 112,
"parent_name": null,
"docstring": "A class extending the DetectionValidator class for validation based on a pose model.\n\nThis validator is specifically designed for pose estimation tasks, handling keypoints and implementing\nspecialized metrics for pose evaluation.\n\nAttributes:\n sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.\n kpt_shape (List[int]): Shape of the keypoints, typically [17, 3] for COCO format.\n args (dict): Arguments for the validator including task set to \"pose\".\n metrics (PoseMetrics): Metrics object for pose evaluation.\n\nMethods:\n preprocess: Preprocess batch by converting keypoints data to float and moving it to the device.\n get_desc: Return description of evaluation metrics in string format.\n init_metrics: Initialize pose estimation metrics for YOLO model.\n _prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original\n dimensions.\n _prepare_pred: Prepare and scale keypoints in predictions for pose processing.\n _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between\n detections and ground truth.\n plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.\n plot_predictions: Plot and save model predictions with bounding boxes and keypoints.\n save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.\n pred_to_json: Convert YOLO predictions to COCO JSON format.\n eval_json: Evaluate object detection model using COCO JSON format.\n\nExamples:\n >>> from ultralytics.models.yolo.pose import PoseValidator\n >>> args = dict(model=\"yolo11n-pose.pt\", data=\"coco8-pose.yaml\")\n >>> validator = PoseValidator(args=args)\n >>> validator()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Tuple",
"numpy",
"torch",
"ultralytics.models.yolo.detect.DetectionValidator",
"ultralytics.utils.LOGGER",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.OKS_SIGMA",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.kpt_iou",
"ultralytics.engine.results.Results",
"DetectionValidator"
],
"chunk_id": "class_PoseValidator_b7ce6f5d"
},
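The `_process_batch` method in the PoseValidator chunk above matches pose predictions to ground truth with a keypoint-similarity score scaled by 0.53 times the ground-truth box area. The following is a minimal, standalone sketch of that OKS-style similarity for a single pose pair; the helper name `oks`, the random test data, and the exact normalization constants are illustrative assumptions, not the library's `kpt_iou` implementation.

```python
import numpy as np

def oks(gt_kpts, pred_kpts, box_wh, sigmas, eps=1e-7):
    """OKS between one GT pose and one prediction; gt_kpts/pred_kpts are (K, 3) arrays of (x, y, vis)."""
    area = box_wh.prod() * 0.53                              # scaled GT box area, as in _process_batch
    d2 = ((gt_kpts[:, :2] - pred_kpts[:, :2]) ** 2).sum(-1)  # squared distance per keypoint
    vis = gt_kpts[:, 2] > 0                                  # score only visible GT keypoints
    e = d2 / ((2 * sigmas) ** 2 * (area + eps) * 2)          # normalized error term
    return float(np.exp(-e)[vis].mean()) if vis.any() else 0.0

if __name__ == "__main__":
    k = 17
    sigmas = np.ones(k) / k                                  # fallback when kpt_shape != [17, 3]
    gt = np.random.rand(k, 3)
    gt[:, 2] = 1.0
    pred = gt.copy()
    pred[:, :2] += 0.01
    print(round(oks(gt, pred, np.array([100.0, 200.0]), sigmas), 4))  # close to 1.0
```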
{
"content": "from .predict import PosePredictor",
"chunk_type": "import",
"name": "PosePredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\__init__.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 34,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_PosePredictor_c2d3074f"
},
{
"content": "from .train import PoseTrainer",
"chunk_type": "import",
"name": "PoseTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\__init__.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 30,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_PoseTrainer_84be1717"
},
{
"content": "from .val import PoseValidator",
"chunk_type": "import",
"name": "PoseValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\__init__.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 30,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_PoseValidator_f05dff60"
},
{
"content": "__all__ = \"PoseTrainer\", \"PoseValidator\", \"PosePredictor\"",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\pose\\__init__.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 57,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___a8d8fd20"
},
{
"content": "from ultralytics.engine.results import Results",
"chunk_type": "import",
"name": "Results",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\predict.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Results_1f6f7c59"
},
{
"content": "from ultralytics.models.yolo.detect.predict import DetectionPredictor",
"chunk_type": "import",
"name": "DetectionPredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\predict.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 69,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DetectionPredictor_6d55ce8c"
},
{
"content": "from ultralytics.utils import DEFAULT_CFG, ops",
"chunk_type": "import",
"name": "DEFAULT_CFG, ops",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\predict.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DEFAULT_CFG, ops_c9819379"
},
{
"content": "class SegmentationPredictor(DetectionPredictor):\n \"\"\"\n A class extending the DetectionPredictor class for prediction based on a segmentation model.\n\n This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the\n prediction results.\n\n Attributes:\n args (dict): Configuration arguments for the predictor.\n model (torch.nn.Module): The loaded YOLO segmentation model.\n batch (list): Current batch of images being processed.\n\n Methods:\n postprocess: Apply non-max suppression and process segmentation detections.\n construct_results: Construct a list of result objects from predictions.\n construct_result: Construct a single result object from a prediction.\n\n Examples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.segment import SegmentationPredictor\n >>> args = dict(model=\"yolo11n-seg.pt\", source=ASSETS)\n >>> predictor = SegmentationPredictor(overrides=args)\n >>> predictor.predict_cli()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):\n \"\"\"\n Initialize the SegmentationPredictor with configuration, overrides, and callbacks.\n\n This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the\n prediction results.\n\n Args:\n cfg (dict): Configuration for the predictor.\n overrides (dict, optional): Configuration overrides that take precedence over cfg.\n _callbacks (list, optional): List of callback functions to be invoked during prediction.\n \"\"\"\n super().__init__(cfg, overrides, _callbacks)\n self.args.task = \"segment\"\n\n def postprocess(self, preds, img, orig_imgs):\n \"\"\"\n Apply non-max suppression and process segmentation detections for each image in the input batch.\n\n Args:\n preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.\n img (torch.Tensor): Input image tensor in model format, with shape (B, C, H, W).\n orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.\n\n Returns:\n (list): List of Results objects containing the segmentation predictions for each image in the batch.\n Each Results object includes both bounding boxes and segmentation masks.\n\n Examples:\n >>> predictor = SegmentationPredictor(overrides=dict(model=\"yolo11n-seg.pt\"))\n >>> results = predictor.postprocess(preds, img, orig_img)\n \"\"\"\n # Extract protos - tuple if PyTorch model or array if exported\n protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]\n return super().postprocess(preds[0], img, orig_imgs, protos=protos)\n\n def construct_results(self, preds, img, orig_imgs, protos):\n \"\"\"\n Construct a list of result objects from the predictions.\n\n Args:\n preds (List[torch.Tensor]): List of predicted bounding boxes, scores, and masks.\n img (torch.Tensor): The image after preprocessing.\n orig_imgs (List[np.ndarray]): List of original images before preprocessing.\n protos (List[torch.Tensor]): List of prototype masks.\n\n Returns:\n (List[Results]): List of result objects containing the original images, image paths, class names,\n bounding boxes, and masks.\n \"\"\"\n return [\n self.construct_result(pred, img, orig_img, img_path, proto)\n for pred, orig_img, img_path, proto in zip(preds, orig_imgs, self.batch[0], protos)\n ]\n\n def construct_result(self, pred, img, orig_img, img_path, proto):\n \"\"\"\n Construct a single result object from the prediction.\n\n Args:\n pred (np.ndarray): The predicted bounding 
boxes, scores, and masks.\n img (torch.Tensor): The image after preprocessing.\n orig_img (np.ndarray): The original image before preprocessing.\n img_path (str): The path to the original image.\n proto (torch.Tensor): The prototype masks.\n\n Returns:\n (Results): Result object containing the original image, image path, class names, bounding boxes, and masks.\n \"\"\"\n if not len(pred): # save empty boxes\n masks = None\n elif self.args.retina_masks:\n pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)\n masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC\n else:\n masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC\n pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)\n if masks is not None:\n keep = masks.sum((-2, -1)) > 0 # only keep predictions with masks\n pred, masks = pred[keep], masks[keep]\n return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)",
"chunk_type": "class",
"name": "SegmentationPredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\predict.py",
"start_line": 8,
"end_line": 113,
"start_col": 0,
"end_col": 103,
"parent_name": null,
"docstring": "A class extending the DetectionPredictor class for prediction based on a segmentation model.\n\nThis class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the\nprediction results.\n\nAttributes:\n args (dict): Configuration arguments for the predictor.\n model (torch.nn.Module): The loaded YOLO segmentation model.\n batch (list): Current batch of images being processed.\n\nMethods:\n postprocess: Apply non-max suppression and process segmentation detections.\n construct_results: Construct a list of result objects from predictions.\n construct_result: Construct a single result object from a prediction.\n\nExamples:\n >>> from ultralytics.utils import ASSETS\n >>> from ultralytics.models.yolo.segment import SegmentationPredictor\n >>> args = dict(model=\"yolo11n-seg.pt\", source=ASSETS)\n >>> predictor = SegmentationPredictor(overrides=args)\n >>> predictor.predict_cli()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"ultralytics.engine.results.Results",
"ultralytics.models.yolo.detect.predict.DetectionPredictor",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.ops",
"DetectionPredictor"
],
"chunk_id": "class_SegmentationPredictor_bd7c5e5e"
},
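As a quick usage sketch of the SegmentationPredictor chunk above via the high-level API: this assumes the `ultralytics` package and its bundled `ASSETS` images are available locally, and simply reads boxes and masks back from the returned `Results` objects. It is an illustration, not part of the indexed source.

```python
from ultralytics import YOLO
from ultralytics.utils import ASSETS

model = YOLO("yolo11n-seg.pt")                      # routes to SegmentationPredictor for the segment task
results = model.predict(ASSETS, retina_masks=False, verbose=False)
for r in results:
    n_boxes = 0 if r.boxes is None else len(r.boxes)
    print(f"{n_boxes} instances in {r.path}")
    if r.masks is not None:
        print("mask tensor shape:", tuple(r.masks.data.shape))  # (N, H, W)
```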
{
"content": "from copy import copy",
"chunk_type": "import",
"name": "copy",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\train.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_copy_caaa5ce4"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\train.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_dbbdf7a2"
},
{
"content": "from typing import Dict, Optional, Union",
"chunk_type": "import",
"name": "Dict, Optional, Union",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\train.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Dict, Optional, Union_bbea9a10"
},
{
"content": "from ultralytics.models import yolo",
"chunk_type": "import",
"name": "yolo",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\train.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_yolo_310640d9"
},
{
"content": "from ultralytics.nn.tasks import SegmentationModel",
"chunk_type": "import",
"name": "SegmentationModel",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\train.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SegmentationModel_54884489"
},
{
"content": "from ultralytics.utils import DEFAULT_CFG, RANK",
"chunk_type": "import",
"name": "DEFAULT_CFG, RANK",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\train.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DEFAULT_CFG, RANK_62644128"
},
{
"content": "from ultralytics.utils.plotting import plot_results",
"chunk_type": "import",
"name": "plot_results",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\train.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 51,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_plot_results_70bc1128"
},
{
"content": "class SegmentationTrainer(yolo.detect.DetectionTrainer):\n \"\"\"\n A class extending the DetectionTrainer class for training based on a segmentation model.\n\n This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific\n functionality including model initialization, validation, and visualization.\n\n Attributes:\n loss_names (Tuple[str]): Names of the loss components used during training.\n\n Examples:\n >>> from ultralytics.models.yolo.segment import SegmentationTrainer\n >>> args = dict(model=\"yolo11n-seg.pt\", data=\"coco8-seg.yaml\", epochs=3)\n >>> trainer = SegmentationTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict] = None, _callbacks=None):\n \"\"\"\n Initialize a SegmentationTrainer object.\n\n This initializes a trainer for segmentation tasks, extending the detection trainer with segmentation-specific\n functionality. It sets the task to 'segment' and prepares the trainer for training segmentation models.\n\n Args:\n cfg (dict): Configuration dictionary with default training settings.\n overrides (dict, optional): Dictionary of parameter overrides for the default configuration.\n _callbacks (list, optional): List of callback functions to be executed during training.\n\n Examples:\n >>> from ultralytics.models.yolo.segment import SegmentationTrainer\n >>> args = dict(model=\"yolo11n-seg.pt\", data=\"coco8-seg.yaml\", epochs=3)\n >>> trainer = SegmentationTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n if overrides is None:\n overrides = {}\n overrides[\"task\"] = \"segment\"\n super().__init__(cfg, overrides, _callbacks)\n\n def get_model(\n self, cfg: Optional[Union[Dict, str]] = None, weights: Optional[Union[str, Path]] = None, verbose: bool = True\n ):\n \"\"\"\n Initialize and return a SegmentationModel with specified configuration and weights.\n\n Args:\n cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.\n weights (str | Path, optional): Path to pretrained weights file.\n verbose (bool): Whether to display model information during initialization.\n\n Returns:\n (SegmentationModel): Initialized segmentation model with loaded weights if specified.\n\n Examples:\n >>> trainer = SegmentationTrainer()\n >>> model = trainer.get_model(cfg=\"yolo11n-seg.yaml\")\n >>> model = trainer.get_model(weights=\"yolo11n-seg.pt\", verbose=False)\n \"\"\"\n model = SegmentationModel(cfg, nc=self.data[\"nc\"], ch=self.data[\"channels\"], verbose=verbose and RANK == -1)\n if weights:\n model.load(weights)\n\n return model\n\n def get_validator(self):\n \"\"\"Return an instance of SegmentationValidator for validation of YOLO model.\"\"\"\n self.loss_names = \"box_loss\", \"seg_loss\", \"cls_loss\", \"dfl_loss\"\n return yolo.segment.SegmentationValidator(\n self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks\n )\n\n def plot_metrics(self):\n \"\"\"Plot training/validation metrics.\"\"\"\n plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png",
"chunk_type": "class",
"name": "SegmentationTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\train.py",
"start_line": 13,
"end_line": 87,
"start_col": 0,
"end_col": 71,
"parent_name": null,
"docstring": "A class extending the DetectionTrainer class for training based on a segmentation model.\n\nThis trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific\nfunctionality including model initialization, validation, and visualization.\n\nAttributes:\n loss_names (Tuple[str]): Names of the loss components used during training.\n\nExamples:\n >>> from ultralytics.models.yolo.segment import SegmentationTrainer\n >>> args = dict(model=\"yolo11n-seg.pt\", data=\"coco8-seg.yaml\", epochs=3)\n >>> trainer = SegmentationTrainer(overrides=args)\n >>> trainer.train()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy.copy",
"pathlib.Path",
"typing.Dict",
"typing.Optional",
"typing.Union",
"ultralytics.models.yolo",
"ultralytics.nn.tasks.SegmentationModel",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.RANK",
"ultralytics.utils.plotting.plot_results",
"yolo.detect.DetectionTrainer"
],
"chunk_id": "class_SegmentationTrainer_cbc8899e"
},
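A minimal training sketch for the SegmentationTrainer chunk above, mirroring its docstring example; the one-epoch run, image size, and dataset choice are illustrative and assume `coco8-seg.yaml` can be resolved or downloaded in your environment.

```python
from ultralytics.models.yolo.segment import SegmentationTrainer

args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=1, imgsz=640)
trainer = SegmentationTrainer(overrides=args)
trainer.train()
# During setup, get_validator() defines the four loss columns
# ("box_loss", "seg_loss", "cls_loss", "dfl_loss"), and plot_metrics()
# later writes results.png via plot_results(..., segment=True).
```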
{
"content": "from multiprocessing.pool import ThreadPool",
"chunk_type": "import",
"name": "ThreadPool",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 43,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ThreadPool_06f926f7"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_d8db4424"
},
{
"content": "from typing import Any, Dict, List, Tuple",
"chunk_type": "import",
"name": "Any, Dict, List, Tuple",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, List, Tuple_48b48602"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_f1554a43"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_b3575745"
},
{
"content": "import torch.nn.functional as F",
"chunk_type": "import",
"name": "torch.nn.functional",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn.functional_660eeb7e"
},
{
"content": "from ultralytics.models.yolo.detect import DetectionValidator",
"chunk_type": "import",
"name": "DetectionValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 61,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DetectionValidator_eddf39dd"
},
{
"content": "from ultralytics.utils import LOGGER, NUM_THREADS, ops",
"chunk_type": "import",
"name": "LOGGER, NUM_THREADS, ops",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 54,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER, NUM_THREADS, ops_5c13f41b"
},
{
"content": "from ultralytics.utils.checks import check_requirements",
"chunk_type": "import",
"name": "check_requirements",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 55,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_check_requirements_40a3b414"
},
{
"content": "from ultralytics.utils.metrics import SegmentMetrics, mask_iou",
"chunk_type": "import",
"name": "SegmentMetrics, mask_iou",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 62,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SegmentMetrics, mask_iou_743469aa"
},
{
"content": "class SegmentationValidator(DetectionValidator):\n \"\"\"\n A class extending the DetectionValidator class for validation based on a segmentation model.\n\n This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions\n to compute metrics such as mAP for both detection and segmentation tasks.\n\n Attributes:\n plot_masks (list): List to store masks for plotting.\n process (callable): Function to process masks based on save_json and save_txt flags.\n args (namespace): Arguments for the validator.\n metrics (SegmentMetrics): Metrics calculator for segmentation tasks.\n stats (dict): Dictionary to store statistics during validation.\n\n Examples:\n >>> from ultralytics.models.yolo.segment import SegmentationValidator\n >>> args = dict(model=\"yolo11n-seg.pt\", data=\"coco8-seg.yaml\")\n >>> validator = SegmentationValidator(args=args)\n >>> validator()\n \"\"\"\n\n def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:\n \"\"\"\n Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.\n\n Args:\n dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.\n save_dir (Path, optional): Directory to save results.\n args (namespace, optional): Arguments for the validator.\n _callbacks (list, optional): List of callback functions.\n \"\"\"\n super().__init__(dataloader, save_dir, args, _callbacks)\n self.process = None\n self.args.task = \"segment\"\n self.metrics = SegmentMetrics()\n\n def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Preprocess batch of images for YOLO segmentation validation.\n\n Args:\n batch (Dict[str, Any]): Batch containing images and annotations.\n\n Returns:\n (Dict[str, Any]): Preprocessed batch.\n \"\"\"\n batch = super().preprocess(batch)\n batch[\"masks\"] = batch[\"masks\"].to(self.device).float()\n return batch\n\n def init_metrics(self, model: torch.nn.Module) -> None:\n \"\"\"\n Initialize metrics and select mask processing function based on save_json flag.\n\n Args:\n model (torch.nn.Module): Model to validate.\n \"\"\"\n super().init_metrics(model)\n if self.args.save_json:\n check_requirements(\"faster-coco-eval>=1.6.7\")\n # More accurate vs faster\n self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask\n\n def get_desc(self) -> str:\n \"\"\"Return a formatted description of evaluation metrics.\"\"\"\n return (\"%22s\" + \"%11s\" * 10) % (\n \"Class\",\n \"Images\",\n \"Instances\",\n \"Box(P\",\n \"R\",\n \"mAP50\",\n \"mAP50-95)\",\n \"Mask(P\",\n \"R\",\n \"mAP50\",\n \"mAP50-95)\",\n )\n\n def postprocess(self, preds: List[torch.Tensor]) -> List[Dict[str, torch.Tensor]]:\n \"\"\"\n Post-process YOLO predictions and return output detections with proto.\n\n Args:\n preds (List[torch.Tensor]): Raw predictions from the model.\n\n Returns:\n List[Dict[str, torch.Tensor]]: Processed detection predictions with masks.\n \"\"\"\n proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported\n preds = super().postprocess(preds[0])\n imgsz = [4 * x for x in proto.shape[2:]] # get image size from proto\n for i, pred in enumerate(preds):\n coefficient = pred.pop(\"extra\")\n pred[\"masks\"] = (\n self.process(proto[i], coefficient, pred[\"bboxes\"], shape=imgsz)\n if len(coefficient)\n else torch.zeros(\n (0, *(imgsz if self.process is ops.process_mask_native else proto.shape[2:])),\n 
dtype=torch.uint8,\n device=pred[\"bboxes\"].device,\n )\n )\n return preds\n\n def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"\n Prepare a batch for training or inference by processing images and targets.\n\n Args:\n si (int): Batch index.\n batch (Dict[str, Any]): Batch data containing images and annotations.\n\n Returns:\n (Dict[str, Any]): Prepared batch with processed annotations.\n \"\"\"\n prepared_batch = super()._prepare_batch(si, batch)\n midx = [si] if self.args.overlap_mask else batch[\"batch_idx\"] == si\n prepared_batch[\"masks\"] = batch[\"masks\"][midx]\n return prepared_batch\n\n def _prepare_pred(self, pred: Dict[str, torch.Tensor], pbatch: Dict[str, Any]) -> Dict[str, torch.Tensor]:\n \"\"\"\n Prepare predictions for evaluation by processing bounding boxes and masks.\n\n Args:\n pred (Dict[str, torch.Tensor]): Post-processed predictions from the model.\n pbatch (Dict[str, Any]): Prepared batch information.\n\n Returns:\n Dict[str, torch.Tensor]: Processed bounding box predictions.\n \"\"\"\n predn = super()._prepare_pred(pred, pbatch)\n predn[\"masks\"] = pred[\"masks\"]\n if self.args.save_json and len(predn[\"masks\"]):\n coco_masks = torch.as_tensor(pred[\"masks\"], dtype=torch.uint8)\n coco_masks = ops.scale_image(\n coco_masks.permute(1, 2, 0).contiguous().cpu().numpy(),\n pbatch[\"ori_shape\"],\n ratio_pad=pbatch[\"ratio_pad\"],\n )\n predn[\"coco_masks\"] = coco_masks\n return predn\n\n def _process_batch(self, preds: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, np.ndarray]:\n \"\"\"\n Compute correct prediction matrix for a batch based on bounding boxes and optional masks.\n\n Args:\n preds (Dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.\n batch (Dict[str, Any]): Dictionary containing batch data with keys like 'cls' and 'masks'.\n\n Returns:\n (Dict[str, np.ndarray]): A dictionary containing correct prediction matrices including 'tp_m' for mask IoU.\n\n Notes:\n - If `masks` is True, the function computes IoU between predicted and ground truth masks.\n - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.\n\n Examples:\n >>> preds = {\"cls\": torch.tensor([1, 0]), \"masks\": torch.rand(2, 640, 640), \"bboxes\": torch.rand(2, 4)}\n >>> batch = {\"cls\": torch.tensor([1, 0]), \"masks\": torch.rand(2, 640, 640), \"bboxes\": torch.rand(2, 4)}\n >>> correct_preds = validator._process_batch(preds, batch)\n \"\"\"\n tp = super()._process_batch(preds, batch)\n gt_cls, gt_masks = batch[\"cls\"], batch[\"masks\"]\n if len(gt_cls) == 0 or len(preds[\"cls\"]) == 0:\n tp_m = np.zeros((len(preds[\"cls\"]), self.niou), dtype=bool)\n else:\n pred_masks = preds[\"masks\"]\n if self.args.overlap_mask:\n nl = len(gt_cls)\n index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1\n gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)\n gt_masks = torch.where(gt_masks == index, 1.0, 0.0)\n if gt_masks.shape[1:] != pred_masks.shape[1:]:\n gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode=\"bilinear\", align_corners=False)[0]\n gt_masks = gt_masks.gt_(0.5)\n iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))\n tp_m = self.match_predictions(preds[\"cls\"], gt_cls, iou).cpu().numpy()\n tp.update({\"tp_m\": tp_m}) # update tp with mask IoU\n return tp\n\n def plot_predictions(self, batch: Dict[str, Any], preds: List[Dict[str, torch.Tensor]], ni: int) 
-> None:\n \"\"\"\n Plot batch predictions with masks and bounding boxes.\n\n Args:\n batch (Dict[str, Any]): Batch containing images and annotations.\n preds (List[Dict[str, torch.Tensor]]): List of predictions from the model.\n ni (int): Batch index.\n \"\"\"\n for p in preds:\n masks = p[\"masks\"]\n if masks.shape[0] > 50:\n LOGGER.warning(\"Limiting validation plots to first 50 items per image for speed...\")\n p[\"masks\"] = torch.as_tensor(masks[:50], dtype=torch.uint8).cpu()\n super().plot_predictions(batch, preds, ni, max_det=50) # plot bboxes\n\n def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Path) -> None:\n \"\"\"\n Save YOLO detections to a txt file in normalized coordinates in a specific format.\n\n Args:\n predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).\n save_conf (bool): Whether to save confidence scores.\n shape (Tuple[int, int]): Shape of the original image.\n file (Path): File path to save the detections.\n \"\"\"\n from ultralytics.engine.results import Results\n\n Results(\n np.zeros((shape[0], shape[1]), dtype=np.uint8),\n path=None,\n names=self.names,\n boxes=torch.cat([predn[\"bboxes\"], predn[\"conf\"].unsqueeze(-1), predn[\"cls\"].unsqueeze(-1)], dim=1),\n masks=torch.as_tensor(predn[\"masks\"], dtype=torch.uint8),\n ).save_txt(file, save_conf=save_conf)\n\n def pred_to_json(self, predn: torch.Tensor, filename: str) -> None:\n \"\"\"\n Save one JSON result for COCO evaluation.\n\n Args:\n predn (Dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.\n filename (str): Image filename.\n\n Examples:\n >>> result = {\"image_id\": 42, \"category_id\": 18, \"bbox\": [258.15, 41.29, 348.26, 243.78], \"score\": 0.236}\n \"\"\"\n from faster_coco_eval.core.mask import encode # noqa\n\n def single_encode(x):\n \"\"\"Encode predicted masks as RLE and append results to jdict.\"\"\"\n rle = encode(np.asarray(x[:, :, None], order=\"F\", dtype=\"uint8\"))[0]\n rle[\"counts\"] = rle[\"counts\"].decode(\"utf-8\")\n return rle\n\n stem = Path(filename).stem\n image_id = int(stem) if stem.isnumeric() else stem\n box = ops.xyxy2xywh(predn[\"bboxes\"]) # xywh\n box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner\n pred_masks = np.transpose(predn[\"coco_masks\"], (2, 0, 1))\n with ThreadPool(NUM_THREADS) as pool:\n rles = pool.map(single_encode, pred_masks)\n for i, (b, s, c) in enumerate(zip(box.tolist(), predn[\"conf\"].tolist(), predn[\"cls\"].tolist())):\n self.jdict.append(\n {\n \"image_id\": image_id,\n \"category_id\": self.class_map[int(c)],\n \"bbox\": [round(x, 3) for x in b],\n \"score\": round(s, 5),\n \"segmentation\": rles[i],\n }\n )\n\n def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"Return COCO-style instance segmentation evaluation metrics.\"\"\"\n pred_json = self.save_dir / \"predictions.json\" # predictions\n anno_json = (\n self.data[\"path\"]\n / \"annotations\"\n / (\"instances_val2017.json\" if self.is_coco else f\"lvis_v1_{self.args.split}.json\")\n ) # annotations\n return super().coco_evaluate(stats, pred_json, anno_json, [\"bbox\", \"segm\"], suffix=[\"Box\", \"Mask\"])",
"chunk_type": "class",
"name": "SegmentationValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\val.py",
"start_line": 17,
"end_line": 281,
"start_col": 0,
"end_col": 107,
"parent_name": null,
"docstring": "A class extending the DetectionValidator class for validation based on a segmentation model.\n\nThis validator handles the evaluation of segmentation models, processing both bounding box and mask predictions\nto compute metrics such as mAP for both detection and segmentation tasks.\n\nAttributes:\n plot_masks (list): List to store masks for plotting.\n process (callable): Function to process masks based on save_json and save_txt flags.\n args (namespace): Arguments for the validator.\n metrics (SegmentMetrics): Metrics calculator for segmentation tasks.\n stats (dict): Dictionary to store statistics during validation.\n\nExamples:\n >>> from ultralytics.models.yolo.segment import SegmentationValidator\n >>> args = dict(model=\"yolo11n-seg.pt\", data=\"coco8-seg.yaml\")\n >>> validator = SegmentationValidator(args=args)\n >>> validator()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"multiprocessing.pool.ThreadPool",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Tuple",
"numpy",
"torch",
"torch.nn.functional",
"ultralytics.models.yolo.detect.DetectionValidator",
"ultralytics.utils.LOGGER",
"ultralytics.utils.NUM_THREADS",
"ultralytics.utils.ops",
"ultralytics.utils.checks.check_requirements",
"ultralytics.utils.metrics.SegmentMetrics",
"ultralytics.utils.metrics.mask_iou",
"ultralytics.engine.results.Results",
"faster_coco_eval.core.mask.encode",
"DetectionValidator"
],
"chunk_id": "class_SegmentationValidator_6dc7ef2b"
},
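The SegmentationValidator's `_process_batch` flattens masks to `(N, H*W)` and scores them with pixel-level IoU before class-aware matching. Below is a self-contained sketch of that pairwise mask-IoU step; the function name and random test masks are assumptions for illustration, not the library's `mask_iou`.

```python
import torch

def pairwise_mask_iou(gt_masks, pred_masks, eps=1e-7):
    """gt_masks: (M, H, W), pred_masks: (N, H, W) binary masks -> (M, N) IoU matrix."""
    gt = gt_masks.flatten(1).float()            # (M, H*W)
    pred = pred_masks.flatten(1).float()        # (N, H*W)
    inter = gt @ pred.T                         # (M, N) shared foreground pixels
    union = gt.sum(1)[:, None] + pred.sum(1)[None] - inter
    return inter / (union + eps)

if __name__ == "__main__":
    gt = torch.rand(3, 160, 160) > 0.5
    pred = torch.rand(5, 160, 160) > 0.5
    print(pairwise_mask_iou(gt, pred).shape)    # torch.Size([3, 5])
```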
{
"content": "from .predict import SegmentationPredictor",
"chunk_type": "import",
"name": "SegmentationPredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\__init__.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SegmentationPredictor_d11aa346"
},
{
"content": "from .train import SegmentationTrainer",
"chunk_type": "import",
"name": "SegmentationTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\__init__.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 38,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SegmentationTrainer_0c904de2"
},
{
"content": "from .val import SegmentationValidator",
"chunk_type": "import",
"name": "SegmentationValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\__init__.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 38,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SegmentationValidator_2396777f"
},
{
"content": "__all__ = \"SegmentationPredictor\", \"SegmentationTrainer\", \"SegmentationValidator\"",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\segment\\__init__.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 81,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___02a783e6"
},
{
"content": "import itertools",
"chunk_type": "import",
"name": "itertools",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_itertools_b4d994a3"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_6c0c7cc5"
},
{
"content": "from typing import Any, Dict, List, Optional",
"chunk_type": "import",
"name": "Any, Dict, List, Optional",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, List, Optional_bdf62c65"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_7ad04b71"
},
{
"content": "from ultralytics.data import build_yolo_dataset",
"chunk_type": "import",
"name": "build_yolo_dataset",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_build_yolo_dataset_fc6bae89"
},
{
"content": "from ultralytics.models.yolo.detect import DetectionTrainer",
"chunk_type": "import",
"name": "DetectionTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 59,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DetectionTrainer_91e8b8a8"
},
{
"content": "from ultralytics.nn.tasks import WorldModel",
"chunk_type": "import",
"name": "WorldModel",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 43,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_WorldModel_ba3f49ae"
},
{
"content": "from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK",
"chunk_type": "import",
"name": "DEFAULT_CFG, LOGGER, RANK",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 55,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DEFAULT_CFG, LOGGER, RANK_52b50460"
},
{
"content": "from ultralytics.utils.torch_utils import de_parallel",
"chunk_type": "import",
"name": "de_parallel",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_de_parallel_ed6180ba"
},
{
"content": "def on_pretrain_routine_end(trainer) -> None:\n \"\"\"Set up model classes and text encoder at the end of the pretrain routine.\"\"\"\n if RANK in {-1, 0}:\n # Set class names for evaluation\n names = [name.split(\"/\", 1)[0] for name in list(trainer.test_loader.dataset.data[\"names\"].values())]\n de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)",
"chunk_type": "function",
"name": "on_pretrain_routine_end",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py",
"start_line": 16,
"end_line": 21,
"start_col": 0,
"end_col": 79,
"parent_name": null,
"docstring": "Set up model classes and text encoder at the end of the pretrain routine.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"itertools",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"torch",
"ultralytics.data.build_yolo_dataset",
"ultralytics.models.yolo.detect.DetectionTrainer",
"ultralytics.nn.tasks.WorldModel",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.torch_utils.de_parallel"
],
"chunk_id": "function_on_pretrain_routine_end_a2916ba6"
},
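A tiny illustration of the name normalization performed by `on_pretrain_routine_end` before `set_classes()` is called: multi-alias names separated by `/` keep only the first alias. The sample names below are made up.

```python
names = {0: "person/human", 1: "bus", 2: "traffic light/signal"}
cleaned = [name.split("/", 1)[0] for name in names.values()]
print(cleaned)  # ['person', 'bus', 'traffic light']
```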
{
"content": "class WorldTrainer(DetectionTrainer):\n \"\"\"\n A trainer class for fine-tuning YOLO World models on close-set datasets.\n\n This trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual\n features for improved object detection and understanding. It handles text embedding generation and caching to\n accelerate training with multi-modal data.\n\n Attributes:\n text_embeddings (Dict[str, torch.Tensor] | None): Cached text embeddings for category names to accelerate\n training.\n model (WorldModel): The YOLO World model being trained.\n data (Dict[str, Any]): Dataset configuration containing class information.\n args (Any): Training arguments and configuration.\n\n Methods:\n get_model: Return WorldModel initialized with specified config and weights.\n build_dataset: Build YOLO Dataset for training or validation.\n set_text_embeddings: Set text embeddings for datasets to accelerate training.\n generate_text_embeddings: Generate text embeddings for a list of text samples.\n preprocess_batch: Preprocess a batch of images and text for YOLOWorld training.\n\n Examples:\n Initialize and train a YOLO World model\n >>> from ultralytics.models.yolo.world import WorldTrainer\n >>> args = dict(model=\"yolov8s-world.pt\", data=\"coco8.yaml\", epochs=3)\n >>> trainer = WorldTrainer(overrides=args)\n >>> trainer.train()\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None):\n \"\"\"\n Initialize a WorldTrainer object with given arguments.\n\n Args:\n cfg (Dict[str, Any]): Configuration for the trainer.\n overrides (Dict[str, Any], optional): Configuration overrides.\n _callbacks (List[Any], optional): List of callback functions.\n \"\"\"\n if overrides is None:\n overrides = {}\n super().__init__(cfg, overrides, _callbacks)\n self.text_embeddings = None\n\n def get_model(self, cfg=None, weights: Optional[str] = None, verbose: bool = True) -> WorldModel:\n \"\"\"\n Return WorldModel initialized with specified config and weights.\n\n Args:\n cfg (Dict[str, Any] | str, optional): Model configuration.\n weights (str, optional): Path to pretrained weights.\n verbose (bool): Whether to display model info.\n\n Returns:\n (WorldModel): Initialized WorldModel.\n \"\"\"\n # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.\n # NOTE: Following the official config, nc hard-coded to 80 for now.\n model = WorldModel(\n cfg[\"yaml_file\"] if isinstance(cfg, dict) else cfg,\n ch=self.data[\"channels\"],\n nc=min(self.data[\"nc\"], 80),\n verbose=verbose and RANK == -1,\n )\n if weights:\n model.load(weights)\n self.add_callback(\"on_pretrain_routine_end\", on_pretrain_routine_end)\n\n return model\n\n def build_dataset(self, img_path: str, mode: str = \"train\", batch: Optional[int] = None):\n \"\"\"\n Build YOLO Dataset for training or validation.\n\n Args:\n img_path (str): Path to the folder containing images.\n mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.\n batch (int, optional): Size of batches, this is for `rect`.\n\n Returns:\n (Any): YOLO dataset configured for training or validation.\n \"\"\"\n gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)\n dataset = build_yolo_dataset(\n self.args, img_path, batch, self.data, mode=mode, rect=mode == \"val\", stride=gs, multi_modal=mode == \"train\"\n )\n if mode == \"train\":\n 
self.set_text_embeddings([dataset], batch) # cache text embeddings to accelerate training\n return dataset\n\n def set_text_embeddings(self, datasets: List[Any], batch: Optional[int]) -> None:\n \"\"\"\n Set text embeddings for datasets to accelerate training by caching category names.\n\n This method collects unique category names from all datasets, then generates and caches text embeddings\n for these categories to improve training efficiency.\n\n Args:\n datasets (List[Any]): List of datasets from which to extract category names.\n batch (int | None): Batch size used for processing.\n\n Notes:\n This method collects category names from datasets that have the 'category_names' attribute,\n then uses the first dataset's image path to determine where to cache the generated text embeddings.\n \"\"\"\n text_embeddings = {}\n for dataset in datasets:\n if not hasattr(dataset, \"category_names\"):\n continue\n text_embeddings.update(\n self.generate_text_embeddings(\n list(dataset.category_names), batch, cache_dir=Path(dataset.img_path).parent\n )\n )\n self.text_embeddings = text_embeddings\n\n def generate_text_embeddings(self, texts: List[str], batch: int, cache_dir: Path) -> Dict[str, torch.Tensor]:\n \"\"\"\n Generate text embeddings for a list of text samples.\n\n Args:\n texts (List[str]): List of text samples to encode.\n batch (int): Batch size for processing.\n cache_dir (Path): Directory to save/load cached embeddings.\n\n Returns:\n (Dict[str, torch.Tensor]): Dictionary mapping text samples to their embeddings.\n \"\"\"\n model = \"clip:ViT-B/32\"\n cache_path = cache_dir / f\"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt\"\n if cache_path.exists():\n LOGGER.info(f\"Reading existed cache from '{cache_path}'\")\n txt_map = torch.load(cache_path, map_location=self.device)\n if sorted(txt_map.keys()) == sorted(texts):\n return txt_map\n LOGGER.info(f\"Caching text embeddings to '{cache_path}'\")\n assert self.model is not None\n txt_feats = de_parallel(self.model).get_text_pe(texts, batch, cache_clip_model=False)\n txt_map = dict(zip(texts, txt_feats.squeeze(0)))\n torch.save(txt_map, cache_path)\n return txt_map\n\n def preprocess_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"Preprocess a batch of images and text for YOLOWorld training.\"\"\"\n batch = DetectionTrainer.preprocess_batch(self, batch)\n\n # Add text features\n texts = list(itertools.chain(*batch[\"texts\"]))\n txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device)\n txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)\n batch[\"txt_feats\"] = txt_feats.reshape(len(batch[\"texts\"]), -1, txt_feats.shape[-1])\n return batch",
"chunk_type": "class",
"name": "WorldTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train.py",
"start_line": 24,
"end_line": 175,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "A trainer class for fine-tuning YOLO World models on close-set datasets.\n\nThis trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual\nfeatures for improved object detection and understanding. It handles text embedding generation and caching to\naccelerate training with multi-modal data.\n\nAttributes:\n text_embeddings (Dict[str, torch.Tensor] | None): Cached text embeddings for category names to accelerate\n training.\n model (WorldModel): The YOLO World model being trained.\n data (Dict[str, Any]): Dataset configuration containing class information.\n args (Any): Training arguments and configuration.\n\nMethods:\n get_model: Return WorldModel initialized with specified config and weights.\n build_dataset: Build YOLO Dataset for training or validation.\n set_text_embeddings: Set text embeddings for datasets to accelerate training.\n generate_text_embeddings: Generate text embeddings for a list of text samples.\n preprocess_batch: Preprocess a batch of images and text for YOLOWorld training.\n\nExamples:\n Initialize and train a YOLO World model\n >>> from ultralytics.models.yolo.world import WorldTrainer\n >>> args = dict(model=\"yolov8s-world.pt\", data=\"coco8.yaml\", epochs=3)\n >>> trainer = WorldTrainer(overrides=args)\n >>> trainer.train()",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"itertools",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.List",
"typing.Optional",
"torch",
"ultralytics.data.build_yolo_dataset",
"ultralytics.models.yolo.detect.DetectionTrainer",
"ultralytics.nn.tasks.WorldModel",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.torch_utils.de_parallel",
"DetectionTrainer"
],
"chunk_id": "class_WorldTrainer_dced3b02"
},
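The WorldTrainer chunk above caches CLIP text embeddings to a `.pt` file keyed by category names so repeated runs do not re-encode the same vocabulary. This sketch reproduces only that caching pattern; the random-vector encoder, file name, and function name are placeholders rather than the real `get_text_pe` path.

```python
from pathlib import Path

import torch

def cached_text_embeddings(texts, cache_dir, dim=512):
    """Return {text: embedding}, reusing a .pt cache when its keys match the requested texts."""
    cache_path = Path(cache_dir) / "text_embeddings_demo.pt"
    if cache_path.exists():
        cached = torch.load(cache_path)
        if sorted(cached) == sorted(texts):       # reuse only when the cached vocabulary matches
            return cached
    feats = {t: torch.randn(dim) for t in texts}  # stand-in for the real CLIP text encoder call
    torch.save(feats, cache_path)
    return feats

if __name__ == "__main__":
    emb = cached_text_embeddings(["person", "bus"], ".")
    print(len(emb), emb["person"].shape)
```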
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train_world.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_cebd899c"
},
{
"content": "from ultralytics.data import YOLOConcatDataset, build_grounding, build_yolo_dataset",
"chunk_type": "import",
"name": "YOLOConcatDataset, build_grounding, build_yolo_dataset",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train_world.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 83,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLOConcatDataset, build_grounding, build_yolo_dataset_cf0238a8"
},
{
"content": "from ultralytics.data.utils import check_det_dataset",
"chunk_type": "import",
"name": "check_det_dataset",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train_world.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 52,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_check_det_dataset_581e0267"
},
{
"content": "from ultralytics.models.yolo.world import WorldTrainer",
"chunk_type": "import",
"name": "WorldTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train_world.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 54,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_WorldTrainer_afd3fd51"
},
{
"content": "from ultralytics.utils import DATASETS_DIR, DEFAULT_CFG, LOGGER",
"chunk_type": "import",
"name": "DATASETS_DIR, DEFAULT_CFG, LOGGER",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train_world.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 63,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DATASETS_DIR, DEFAULT_CFG, LOGGER_29233daf"
},
{
"content": "from ultralytics.utils.torch_utils import de_parallel",
"chunk_type": "import",
"name": "de_parallel",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train_world.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_de_parallel_c46092b2"
},
{
"content": "class WorldTrainerFromScratch(WorldTrainer):\n \"\"\"\n A class extending the WorldTrainer for training a world model from scratch on open-set datasets.\n\n This trainer specializes in handling mixed datasets including both object detection and grounding datasets,\n supporting training YOLO-World models with combined vision-language capabilities.\n\n Attributes:\n cfg (dict): Configuration dictionary with default parameters for model training.\n overrides (dict): Dictionary of parameter overrides to customize the configuration.\n _callbacks (list): List of callback functions to be executed during different stages of training.\n data (dict): Final processed data configuration containing train/val paths and metadata.\n training_data (dict): Dictionary mapping training dataset paths to their configurations.\n\n Methods:\n build_dataset: Build YOLO Dataset for training or validation with mixed dataset support.\n get_dataset: Get train and validation paths from data dictionary.\n plot_training_labels: Skip label plotting for YOLO-World training.\n final_eval: Perform final evaluation and validation for the YOLO-World model.\n\n Examples:\n >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch\n >>> from ultralytics import YOLOWorld\n >>> data = dict(\n ... train=dict(\n ... yolo_data=[\"Objects365.yaml\"],\n ... grounding_data=[\n ... dict(\n ... img_path=\"flickr30k/images\",\n ... json_file=\"flickr30k/final_flickr_separateGT_train.json\",\n ... ),\n ... dict(\n ... img_path=\"GQA/images\",\n ... json_file=\"GQA/final_mixed_train_no_coco.json\",\n ... ),\n ... ],\n ... ),\n ... val=dict(yolo_data=[\"lvis.yaml\"]),\n ... )\n >>> model = YOLOWorld(\"yolov8s-worldv2.yaml\")\n >>> model.train(data=data, trainer=WorldTrainerFromScratch)\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):\n \"\"\"\n Initialize a WorldTrainerFromScratch object.\n\n This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both\n object detection and grounding datasets for vision-language capabilities.\n\n Args:\n cfg (dict): Configuration dictionary with default parameters for model training.\n overrides (dict, optional): Dictionary of parameter overrides to customize the configuration.\n _callbacks (list, optional): List of callback functions to be executed during different stages of training.\n\n Examples:\n >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch\n >>> from ultralytics import YOLOWorld\n >>> data = dict(\n ... train=dict(\n ... yolo_data=[\"Objects365.yaml\"],\n ... grounding_data=[\n ... dict(\n ... img_path=\"flickr30k/images\",\n ... json_file=\"flickr30k/final_flickr_separateGT_train.json\",\n ... ),\n ... ],\n ... ),\n ... val=dict(yolo_data=[\"lvis.yaml\"]),\n ... 
)\n >>> model = YOLOWorld(\"yolov8s-worldv2.yaml\")\n >>> model.train(data=data, trainer=WorldTrainerFromScratch)\n \"\"\"\n if overrides is None:\n overrides = {}\n super().__init__(cfg, overrides, _callbacks)\n\n def build_dataset(self, img_path, mode=\"train\", batch=None):\n \"\"\"\n Build YOLO Dataset for training or validation.\n\n This method constructs appropriate datasets based on the mode and input paths, handling both\n standard YOLO datasets and grounding datasets with different formats.\n\n Args:\n img_path (List[str] | str): Path to the folder containing images or list of paths.\n mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.\n batch (int, optional): Size of batches, used for rectangular training/validation.\n\n Returns:\n (YOLOConcatDataset | Dataset): The constructed dataset for training or validation.\n \"\"\"\n gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)\n if mode != \"train\":\n return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=False, stride=gs)\n datasets = [\n build_yolo_dataset(self.args, im_path, batch, self.training_data[im_path], stride=gs, multi_modal=True)\n if isinstance(im_path, str)\n else build_grounding(\n # assign `nc` from validation set to max number of text samples for training consistency\n self.args,\n im_path[\"img_path\"],\n im_path[\"json_file\"],\n batch,\n stride=gs,\n max_samples=self.data[\"nc\"],\n )\n for im_path in img_path\n ]\n self.set_text_embeddings(datasets, batch) # cache text embeddings to accelerate training\n return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]\n\n def get_dataset(self):\n \"\"\"\n Get train and validation paths from data dictionary.\n\n Processes the data configuration to extract paths for training and validation datasets,\n handling both YOLO detection datasets and grounding datasets.\n\n Returns:\n train_path (str): Train dataset path.\n val_path (str): Validation dataset path.\n\n Raises:\n AssertionError: If train or validation datasets are not found, or if validation has multiple datasets.\n \"\"\"\n final_data = {}\n data_yaml = self.args.data\n assert data_yaml.get(\"train\", False), \"train dataset not found\" # object365.yaml\n assert data_yaml.get(\"val\", False), \"validation dataset not found\" # lvis.yaml\n data = {k: [check_det_dataset(d) for d in v.get(\"yolo_data\", [])] for k, v in data_yaml.items()}\n assert len(data[\"val\"]) == 1, f\"Only support validating on 1 dataset for now, but got {len(data['val'])}.\"\n val_split = \"minival\" if \"lvis\" in data[\"val\"][0][\"val\"] else \"val\"\n for d in data[\"val\"]:\n if d.get(\"minival\") is None: # for lvis dataset\n continue\n d[\"minival\"] = str(d[\"path\"] / d[\"minival\"])\n for s in {\"train\", \"val\"}:\n final_data[s] = [d[\"train\" if s == \"train\" else val_split] for d in data[s]]\n # save grounding data if there's one\n grounding_data = data_yaml[s].get(\"grounding_data\")\n if grounding_data is None:\n continue\n grounding_data = grounding_data if isinstance(grounding_data, list) else [grounding_data]\n for g in grounding_data:\n assert isinstance(g, dict), f\"Grounding data should be provided in dict format, but got {type(g)}\"\n for k in {\"img_path\", \"json_file\"}:\n path = Path(g[k])\n if not path.exists() and not path.is_absolute():\n g[k] = str((DATASETS_DIR / g[k]).resolve()) # path relative to DATASETS_DIR\n final_data[s] += grounding_data\n # assign the first val dataset as currently only one 
validation set is supported\n data[\"val\"] = data[\"val\"][0]\n final_data[\"val\"] = final_data[\"val\"][0]\n # NOTE: to make training work properly, set `nc` and `names`\n final_data[\"nc\"] = data[\"val\"][\"nc\"]\n final_data[\"names\"] = data[\"val\"][\"names\"]\n # NOTE: add path with lvis path\n final_data[\"path\"] = data[\"val\"][\"path\"]\n final_data[\"channels\"] = data[\"val\"][\"channels\"]\n self.data = final_data\n if self.args.single_cls: # consistent with base trainer\n LOGGER.info(\"Overriding class names with single class.\")\n self.data[\"names\"] = {0: \"object\"}\n self.data[\"nc\"] = 1\n self.training_data = {}\n for d in data[\"train\"]:\n if self.args.single_cls:\n d[\"names\"] = {0: \"object\"}\n d[\"nc\"] = 1\n self.training_data[d[\"train\"]] = d\n return final_data\n\n def plot_training_labels(self):\n \"\"\"Skip label plotting for YOLO-World training.\"\"\"\n pass\n\n def final_eval(self):\n \"\"\"\n Perform final evaluation and validation for the YOLO-World model.\n\n Configures the validator with appropriate dataset and split information before running evaluation.\n\n Returns:\n (dict): Dictionary containing evaluation metrics and results.\n \"\"\"\n val = self.args.data[\"val\"][\"yolo_data\"][0]\n self.validator.args.data = val\n self.validator.args.split = \"minival\" if isinstance(val, str) and \"lvis\" in val else \"val\"\n return super().final_eval()",
"chunk_type": "class",
"name": "WorldTrainerFromScratch",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\train_world.py",
"start_line": 12,
"end_line": 201,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": "A class extending the WorldTrainer for training a world model from scratch on open-set datasets.\n\nThis trainer specializes in handling mixed datasets including both object detection and grounding datasets,\nsupporting training YOLO-World models with combined vision-language capabilities.\n\nAttributes:\n cfg (dict): Configuration dictionary with default parameters for model training.\n overrides (dict): Dictionary of parameter overrides to customize the configuration.\n _callbacks (list): List of callback functions to be executed during different stages of training.\n data (dict): Final processed data configuration containing train/val paths and metadata.\n training_data (dict): Dictionary mapping training dataset paths to their configurations.\n\nMethods:\n build_dataset: Build YOLO Dataset for training or validation with mixed dataset support.\n get_dataset: Get train and validation paths from data dictionary.\n plot_training_labels: Skip label plotting for YOLO-World training.\n final_eval: Perform final evaluation and validation for the YOLO-World model.\n\nExamples:\n >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch\n >>> from ultralytics import YOLOWorld\n >>> data = dict(\n ... train=dict(\n ... yolo_data=[\"Objects365.yaml\"],\n ... grounding_data=[\n ... dict(\n ... img_path=\"flickr30k/images\",\n ... json_file=\"flickr30k/final_flickr_separateGT_train.json\",\n ... ),\n ... dict(\n ... img_path=\"GQA/images\",\n ... json_file=\"GQA/final_mixed_train_no_coco.json\",\n ... ),\n ... ],\n ... ),\n ... val=dict(yolo_data=[\"lvis.yaml\"]),\n ... )\n >>> model = YOLOWorld(\"yolov8s-worldv2.yaml\")\n >>> model.train(data=data, trainer=WorldTrainerFromScratch)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"pathlib.Path",
"ultralytics.data.YOLOConcatDataset",
"ultralytics.data.build_grounding",
"ultralytics.data.build_yolo_dataset",
"ultralytics.data.utils.check_det_dataset",
"ultralytics.models.yolo.world.WorldTrainer",
"ultralytics.utils.DATASETS_DIR",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.LOGGER",
"ultralytics.utils.torch_utils.de_parallel",
"WorldTrainer"
],
"chunk_id": "class_WorldTrainerFromScratch_777a0ffc"
},
{
"content": "from .train import WorldTrainer",
"chunk_type": "import",
"name": "WorldTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\__init__.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_WorldTrainer_2412de85"
},
{
"content": "__all__ = [\"WorldTrainer\"]",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\world\\__init__.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 26,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___aac26038"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\predict.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_492ba178"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\predict.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_cae74323"
},
{
"content": "from ultralytics.data.augment import LoadVisualPrompt",
"chunk_type": "import",
"name": "LoadVisualPrompt",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\predict.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LoadVisualPrompt_636b68f2"
},
{
"content": "from ultralytics.models.yolo.detect import DetectionPredictor",
"chunk_type": "import",
"name": "DetectionPredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\predict.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 61,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DetectionPredictor_52754e1d"
},
{
"content": "from ultralytics.models.yolo.segment import SegmentationPredictor",
"chunk_type": "import",
"name": "SegmentationPredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\predict.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 65,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SegmentationPredictor_504c32b1"
},
{
"content": "class YOLOEVPDetectPredictor(DetectionPredictor):\n \"\"\"\n A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.\n\n This mixin provides common functionality for YOLO models that use visual prompting, including\n model setup, prompt handling, and preprocessing transformations.\n\n Attributes:\n model (torch.nn.Module): The YOLO model for inference.\n device (torch.device): Device to run the model on (CPU or CUDA).\n prompts (dict | torch.Tensor): Visual prompts containing class indices and bounding boxes or masks.\n\n Methods:\n setup_model: Initialize the YOLO model and set it to evaluation mode.\n set_prompts: Set the visual prompts for the model.\n pre_transform: Preprocess images and prompts before inference.\n inference: Run inference with visual prompts.\n get_vpe: Process source to get visual prompt embeddings.\n \"\"\"\n\n def setup_model(self, model, verbose: bool = True):\n \"\"\"\n Set up the model for prediction.\n\n Args:\n model (torch.nn.Module): Model to load or use.\n verbose (bool, optional): If True, provides detailed logging.\n \"\"\"\n super().setup_model(model, verbose=verbose)\n self.done_warmup = True\n\n def set_prompts(self, prompts):\n \"\"\"\n Set the visual prompts for the model.\n\n Args:\n prompts (dict): Dictionary containing class indices and bounding boxes or masks.\n Must include a 'cls' key with class indices.\n \"\"\"\n self.prompts = prompts\n\n def pre_transform(self, im):\n \"\"\"\n Preprocess images and prompts before inference.\n\n This method applies letterboxing to the input image and transforms the visual prompts\n (bounding boxes or masks) accordingly.\n\n Args:\n im (list): List containing a single input image.\n\n Returns:\n (list): Preprocessed image ready for model inference.\n\n Raises:\n ValueError: If neither valid bounding boxes nor masks are provided in the prompts.\n \"\"\"\n img = super().pre_transform(im)\n bboxes = self.prompts.pop(\"bboxes\", None)\n masks = self.prompts.pop(\"masks\", None)\n category = self.prompts[\"cls\"]\n if len(img) == 1:\n visuals = self._process_single_image(img[0].shape[:2], im[0].shape[:2], category, bboxes, masks)\n self.prompts = visuals.unsqueeze(0).to(self.device) # (1, N, H, W)\n else:\n # NOTE: only supports bboxes as prompts for now\n assert bboxes is not None, f\"Expected bboxes, but got {bboxes}!\"\n # NOTE: needs List[np.ndarray]\n assert isinstance(bboxes, list) and all(isinstance(b, np.ndarray) for b in bboxes), (\n f\"Expected List[np.ndarray], but got {bboxes}!\"\n )\n assert isinstance(category, list) and all(isinstance(b, np.ndarray) for b in category), (\n f\"Expected List[np.ndarray], but got {category}!\"\n )\n assert len(im) == len(category) == len(bboxes), (\n f\"Expected same length for all inputs, but got {len(im)}vs{len(category)}vs{len(bboxes)}!\"\n )\n visuals = [\n self._process_single_image(img[i].shape[:2], im[i].shape[:2], category[i], bboxes[i])\n for i in range(len(img))\n ]\n self.prompts = torch.nn.utils.rnn.pad_sequence(visuals, batch_first=True).to(self.device)\n\n return img\n\n def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):\n \"\"\"\n Process a single image by resizing bounding boxes or masks and generating visuals.\n\n Args:\n dst_shape (tuple): The target shape (height, width) of the image.\n src_shape (tuple): The original shape (height, width) of the image.\n category (str): The category of the image for visual prompts.\n bboxes (list | np.ndarray, optional): A list of bounding boxes in the 
format [x1, y1, x2, y2].\n masks (np.ndarray, optional): A list of masks corresponding to the image.\n\n Returns:\n (torch.Tensor): The processed visuals for the image.\n\n Raises:\n ValueError: If neither `bboxes` nor `masks` are provided.\n \"\"\"\n if bboxes is not None and len(bboxes):\n bboxes = np.array(bboxes, dtype=np.float32)\n if bboxes.ndim == 1:\n bboxes = bboxes[None, :]\n # Calculate scaling factor and adjust bounding boxes\n gain = min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1]) # gain = old / new\n bboxes *= gain\n bboxes[..., 0::2] += round((dst_shape[1] - src_shape[1] * gain) / 2 - 0.1)\n bboxes[..., 1::2] += round((dst_shape[0] - src_shape[0] * gain) / 2 - 0.1)\n elif masks is not None:\n # Resize and process masks\n resized_masks = super().pre_transform(masks)\n masks = np.stack(resized_masks) # (N, H, W)\n masks[masks == 114] = 0 # Reset padding values to 0\n else:\n raise ValueError(\"Please provide valid bboxes or masks\")\n\n # Generate visuals using the visual prompt loader\n return LoadVisualPrompt().get_visuals(category, dst_shape, bboxes, masks)\n\n def inference(self, im, *args, **kwargs):\n \"\"\"\n Run inference with visual prompts.\n\n Args:\n im (torch.Tensor): Input image tensor.\n *args (Any): Variable length argument list.\n **kwargs (Any): Arbitrary keyword arguments.\n\n Returns:\n (torch.Tensor): Model prediction results.\n \"\"\"\n return super().inference(im, vpe=self.prompts, *args, **kwargs)\n\n def get_vpe(self, source):\n \"\"\"\n Process the source to get the visual prompt embeddings (VPE).\n\n Args:\n source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source\n of the image to make predictions on. Accepts various types including file paths, URLs, PIL\n images, numpy arrays, and torch tensors.\n\n Returns:\n (torch.Tensor): The visual prompt embeddings (VPE) from the model.\n \"\"\"\n self.setup_source(source)\n assert len(self.dataset) == 1, \"get_vpe only supports one image!\"\n for _, im0s, _ in self.dataset:\n im = self.preprocess(im0s)\n return self.model(im, vpe=self.prompts, return_vpe=True)",
"chunk_type": "class",
"name": "YOLOEVPDetectPredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\predict.py",
"start_line": 11,
"end_line": 163,
"start_col": 0,
"end_col": 68,
"parent_name": null,
"docstring": "A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.\n\nThis mixin provides common functionality for YOLO models that use visual prompting, including\nmodel setup, prompt handling, and preprocessing transformations.\n\nAttributes:\n model (torch.nn.Module): The YOLO model for inference.\n device (torch.device): Device to run the model on (CPU or CUDA).\n prompts (dict | torch.Tensor): Visual prompts containing class indices and bounding boxes or masks.\n\nMethods:\n setup_model: Initialize the YOLO model and set it to evaluation mode.\n set_prompts: Set the visual prompts for the model.\n pre_transform: Preprocess images and prompts before inference.\n inference: Run inference with visual prompts.\n get_vpe: Process source to get visual prompt embeddings.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"numpy",
"torch",
"ultralytics.data.augment.LoadVisualPrompt",
"ultralytics.models.yolo.detect.DetectionPredictor",
"ultralytics.models.yolo.segment.SegmentationPredictor",
"DetectionPredictor"
],
"chunk_id": "class_YOLOEVPDetectPredictor_e62fb0e7"
},
{
"content": "class YOLOEVPSegPredictor(YOLOEVPDetectPredictor, SegmentationPredictor):\n \"\"\"Predictor for YOLO-EVP segmentation tasks combining detection and segmentation capabilities.\"\"\"\n\n pass",
"chunk_type": "class",
"name": "YOLOEVPSegPredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\predict.py",
"start_line": 166,
"end_line": 169,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Predictor for YOLO-EVP segmentation tasks combining detection and segmentation capabilities.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"numpy",
"torch",
"ultralytics.data.augment.LoadVisualPrompt",
"ultralytics.models.yolo.detect.DetectionPredictor",
"ultralytics.models.yolo.segment.SegmentationPredictor",
"YOLOEVPDetectPredictor",
"SegmentationPredictor"
],
"chunk_id": "class_YOLOEVPSegPredictor_169e6011"
},
{
"content": "import itertools",
"chunk_type": "import",
"name": "itertools",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_itertools_83951bc0"
},
{
"content": "from copy import copy, deepcopy",
"chunk_type": "import",
"name": "copy, deepcopy",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_copy, deepcopy_a95ce580"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_a1b0f115"
},
{
"content": "from typing import Dict, List, Optional, Union",
"chunk_type": "import",
"name": "Dict, List, Optional, Union",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Dict, List, Optional, Union_a4a23fb8"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_6310badd"
},
{
"content": "from ultralytics.data import YOLOConcatDataset, build_yolo_dataset",
"chunk_type": "import",
"name": "YOLOConcatDataset, build_yolo_dataset",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 66,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLOConcatDataset, build_yolo_dataset_80e4499d"
},
{
"content": "from ultralytics.data.augment import LoadVisualPrompt",
"chunk_type": "import",
"name": "LoadVisualPrompt",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LoadVisualPrompt_6cfad980"
},
{
"content": "from ultralytics.models.yolo.detect import DetectionTrainer, DetectionValidator",
"chunk_type": "import",
"name": "DetectionTrainer, DetectionValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 79,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DetectionTrainer, DetectionValidator_cd51ace3"
},
{
"content": "from ultralytics.nn.tasks import YOLOEModel",
"chunk_type": "import",
"name": "YOLOEModel",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 43,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLOEModel_feeb20fc"
},
{
"content": "from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK",
"chunk_type": "import",
"name": "DEFAULT_CFG, LOGGER, RANK",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 55,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DEFAULT_CFG, LOGGER, RANK_f06b862c"
},
{
"content": "from ultralytics.utils.torch_utils import de_parallel",
"chunk_type": "import",
"name": "de_parallel",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_de_parallel_f537c039"
},
{
"content": "from ..world.train_world import WorldTrainerFromScratch",
"chunk_type": "import",
"name": "WorldTrainerFromScratch",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py",
"start_line": 17,
"end_line": 17,
"start_col": 0,
"end_col": 55,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_WorldTrainerFromScratch_d87e7e7f"
},
{
"content": "from .val import YOLOEDetectValidator",
"chunk_type": "import",
"name": "YOLOEDetectValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py",
"start_line": 18,
"end_line": 18,
"start_col": 0,
"end_col": 37,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLOEDetectValidator_390eccc6"
},
{
"content": "class YOLOETrainer(DetectionTrainer):\n \"\"\"\n A trainer class for YOLOE object detection models.\n\n This class extends DetectionTrainer to provide specialized training functionality for YOLOE models,\n including custom model initialization, validation, and dataset building with multi-modal support.\n\n Attributes:\n loss_names (tuple): Names of loss components used during training.\n\n Methods:\n get_model: Initialize and return a YOLOEModel with specified configuration.\n get_validator: Return a YOLOEDetectValidator for model validation.\n build_dataset: Build YOLO dataset with multi-modal support for training.\n \"\"\"\n\n def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict] = None, _callbacks=None):\n \"\"\"\n Initialize the YOLOE Trainer with specified configurations.\n\n This method sets up the YOLOE trainer with the provided configuration and overrides, initializing\n the training environment, model, and callbacks for YOLOE object detection training.\n\n Args:\n cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.\n overrides (dict, optional): Dictionary of parameter overrides for the default configuration.\n _callbacks (list, optional): List of callback functions to be applied during training.\n \"\"\"\n if overrides is None:\n overrides = {}\n overrides[\"overlap_mask\"] = False\n super().__init__(cfg, overrides, _callbacks)\n\n def get_model(self, cfg=None, weights=None, verbose: bool = True):\n \"\"\"\n Return a YOLOEModel initialized with the specified configuration and weights.\n\n Args:\n cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key,\n a direct path to a YAML file, or None to use default configuration.\n weights (str | Path, optional): Path to pretrained weights file to load into the model.\n verbose (bool): Whether to display model information during initialization.\n\n Returns:\n (YOLOEModel): The initialized YOLOE model.\n\n Notes:\n - The number of classes (nc) is hard-coded to a maximum of 80 following the official configuration.\n - The nc parameter here represents the maximum number of different text samples in one image,\n rather than the actual number of classes.\n \"\"\"\n # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.\n # NOTE: Following the official config, nc hard-coded to 80 for now.\n model = YOLOEModel(\n cfg[\"yaml_file\"] if isinstance(cfg, dict) else cfg,\n ch=self.data[\"channels\"],\n nc=min(self.data[\"nc\"], 80),\n verbose=verbose and RANK == -1,\n )\n if weights:\n model.load(weights)\n\n return model\n\n def get_validator(self):\n \"\"\"Return a YOLOEDetectValidator for YOLOE model validation.\"\"\"\n self.loss_names = \"box\", \"cls\", \"dfl\"\n return YOLOEDetectValidator(\n self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks\n )\n\n def build_dataset(self, img_path: str, mode: str = \"train\", batch: Optional[int] = None):\n \"\"\"\n Build YOLO Dataset.\n\n Args:\n img_path (str): Path to the folder containing images.\n mode (str): 'train' mode or 'val' mode, users are able to customize different augmentations for each mode.\n batch (int, optional): Size of batches, this is for rectangular training.\n\n Returns:\n (Dataset): YOLO dataset configured for training or validation.\n \"\"\"\n gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)\n return build_yolo_dataset(\n self.args, img_path, batch, self.data, mode=mode, 
rect=mode == \"val\", stride=gs, multi_modal=mode == \"train\"\n )",
"chunk_type": "class",
"name": "YOLOETrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py",
"start_line": 21,
"end_line": 107,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "A trainer class for YOLOE object detection models.\n\nThis class extends DetectionTrainer to provide specialized training functionality for YOLOE models,\nincluding custom model initialization, validation, and dataset building with multi-modal support.\n\nAttributes:\n loss_names (tuple): Names of loss components used during training.\n\nMethods:\n get_model: Initialize and return a YOLOEModel with specified configuration.\n get_validator: Return a YOLOEDetectValidator for model validation.\n build_dataset: Build YOLO dataset with multi-modal support for training.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"itertools",
"copy.copy",
"copy.deepcopy",
"pathlib.Path",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Union",
"torch",
"ultralytics.data.YOLOConcatDataset",
"ultralytics.data.build_yolo_dataset",
"ultralytics.data.augment.LoadVisualPrompt",
"ultralytics.models.yolo.detect.DetectionTrainer",
"ultralytics.models.yolo.detect.DetectionValidator",
"ultralytics.nn.tasks.YOLOEModel",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.torch_utils.de_parallel",
"world.train_world.WorldTrainerFromScratch",
"val.YOLOEDetectValidator",
"DetectionTrainer"
],
"chunk_id": "class_YOLOETrainer_876d5b58"
},
{
"content": "class YOLOEPETrainer(DetectionTrainer):\n \"\"\"\n Fine-tune YOLOE model using linear probing approach.\n\n This trainer freezes most model layers and only trains specific projection layers for efficient\n fine-tuning on new datasets while preserving pretrained features.\n\n Methods:\n get_model: Initialize YOLOEModel with frozen layers except projection layers.\n \"\"\"\n\n def get_model(self, cfg=None, weights=None, verbose: bool = True):\n \"\"\"\n Return YOLOEModel initialized with specified config and weights.\n\n Args:\n cfg (dict | str, optional): Model configuration.\n weights (str, optional): Path to pretrained weights.\n verbose (bool): Whether to display model information.\n\n Returns:\n (YOLOEModel): Initialized model with frozen layers except for specific projection layers.\n \"\"\"\n # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.\n # NOTE: Following the official config, nc hard-coded to 80 for now.\n model = YOLOEModel(\n cfg[\"yaml_file\"] if isinstance(cfg, dict) else cfg,\n ch=self.data[\"channels\"],\n nc=self.data[\"nc\"],\n verbose=verbose and RANK == -1,\n )\n\n del model.model[-1].savpe\n\n assert weights is not None, \"Pretrained weights must be provided for linear probing.\"\n if weights:\n model.load(weights)\n\n model.eval()\n names = list(self.data[\"names\"].values())\n # NOTE: `get_text_pe` related to text model and YOLOEDetect.reprta,\n # it'd get correct results as long as loading proper pretrained weights.\n tpe = model.get_text_pe(names)\n model.set_classes(names, tpe)\n model.model[-1].fuse(model.pe) # fuse text embeddings to classify head\n model.model[-1].cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)\n model.model[-1].cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)\n model.model[-1].cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)\n del model.pe\n model.train()\n\n return model",
"chunk_type": "class",
"name": "YOLOEPETrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py",
"start_line": 110,
"end_line": 161,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Fine-tune YOLOE model using linear probing approach.\n\nThis trainer freezes most model layers and only trains specific projection layers for efficient\nfine-tuning on new datasets while preserving pretrained features.\n\nMethods:\n get_model: Initialize YOLOEModel with frozen layers except projection layers.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"itertools",
"copy.copy",
"copy.deepcopy",
"pathlib.Path",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Union",
"torch",
"ultralytics.data.YOLOConcatDataset",
"ultralytics.data.build_yolo_dataset",
"ultralytics.data.augment.LoadVisualPrompt",
"ultralytics.models.yolo.detect.DetectionTrainer",
"ultralytics.models.yolo.detect.DetectionValidator",
"ultralytics.nn.tasks.YOLOEModel",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.torch_utils.de_parallel",
"world.train_world.WorldTrainerFromScratch",
"val.YOLOEDetectValidator",
"DetectionTrainer"
],
"chunk_id": "class_YOLOEPETrainer_12abeb65"
},
{
"content": "class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):\n \"\"\"\n Train YOLOE models from scratch with text embedding support.\n\n This trainer combines YOLOE training capabilities with world training features, enabling\n training from scratch with text embeddings and grounding datasets.\n\n Methods:\n build_dataset: Build datasets for training with grounding support.\n preprocess_batch: Process batches with text features.\n generate_text_embeddings: Generate and cache text embeddings for training.\n \"\"\"\n\n def build_dataset(self, img_path: Union[List[str], str], mode: str = \"train\", batch: Optional[int] = None):\n \"\"\"\n Build YOLO Dataset for training or validation.\n\n This method constructs appropriate datasets based on the mode and input paths, handling both\n standard YOLO datasets and grounding datasets with different formats.\n\n Args:\n img_path (List[str] | str): Path to the folder containing images or list of paths.\n mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.\n batch (int, optional): Size of batches, used for rectangular training/validation.\n\n Returns:\n (YOLOConcatDataset | Dataset): The constructed dataset for training or validation.\n \"\"\"\n return WorldTrainerFromScratch.build_dataset(self, img_path, mode, batch)\n\n def preprocess_batch(self, batch):\n \"\"\"Process batch for training, moving text features to the appropriate device.\"\"\"\n batch = DetectionTrainer.preprocess_batch(self, batch)\n\n texts = list(itertools.chain(*batch[\"texts\"]))\n txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device)\n txt_feats = txt_feats.reshape(len(batch[\"texts\"]), -1, txt_feats.shape[-1])\n batch[\"txt_feats\"] = txt_feats\n return batch\n\n def generate_text_embeddings(self, texts: List[str], batch: int, cache_dir: Path):\n \"\"\"\n Generate text embeddings for a list of text samples.\n\n Args:\n texts (List[str]): List of text samples to encode.\n batch (int): Batch size for processing.\n cache_dir (Path): Directory to save/load cached embeddings.\n\n Returns:\n (dict): Dictionary mapping text samples to their embeddings.\n \"\"\"\n model = \"mobileclip:blt\"\n cache_path = cache_dir / f\"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt\"\n if cache_path.exists():\n LOGGER.info(f\"Reading existed cache from '{cache_path}'\")\n txt_map = torch.load(cache_path, map_location=self.device)\n if sorted(txt_map.keys()) == sorted(texts):\n return txt_map\n LOGGER.info(f\"Caching text embeddings to '{cache_path}'\")\n assert self.model is not None\n txt_feats = de_parallel(self.model).get_text_pe(texts, batch, without_reprta=True, cache_clip_model=False)\n txt_map = dict(zip(texts, txt_feats.squeeze(0)))\n torch.save(txt_map, cache_path)\n return txt_map",
"chunk_type": "class",
"name": "YOLOETrainerFromScratch",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py",
"start_line": 164,
"end_line": 228,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": "Train YOLOE models from scratch with text embedding support.\n\nThis trainer combines YOLOE training capabilities with world training features, enabling\ntraining from scratch with text embeddings and grounding datasets.\n\nMethods:\n build_dataset: Build datasets for training with grounding support.\n preprocess_batch: Process batches with text features.\n generate_text_embeddings: Generate and cache text embeddings for training.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"itertools",
"copy.copy",
"copy.deepcopy",
"pathlib.Path",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Union",
"torch",
"ultralytics.data.YOLOConcatDataset",
"ultralytics.data.build_yolo_dataset",
"ultralytics.data.augment.LoadVisualPrompt",
"ultralytics.models.yolo.detect.DetectionTrainer",
"ultralytics.models.yolo.detect.DetectionValidator",
"ultralytics.nn.tasks.YOLOEModel",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.torch_utils.de_parallel",
"world.train_world.WorldTrainerFromScratch",
"val.YOLOEDetectValidator",
"YOLOETrainer",
"WorldTrainerFromScratch"
],
"chunk_id": "class_YOLOETrainerFromScratch_2a19afb9"
},
{
"content": "class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):\n \"\"\"\n Train prompt-free YOLOE model.\n\n This trainer combines linear probing capabilities with from-scratch training for prompt-free\n YOLOE models that don't require text prompts during inference.\n\n Methods:\n get_validator: Return standard DetectionValidator for validation.\n preprocess_batch: Preprocess batches without text features.\n set_text_embeddings: Set text embeddings for datasets (no-op for prompt-free).\n \"\"\"\n\n def get_validator(self):\n \"\"\"Return a DetectionValidator for YOLO model validation.\"\"\"\n self.loss_names = \"box\", \"cls\", \"dfl\"\n return DetectionValidator(\n self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks\n )\n\n def preprocess_batch(self, batch):\n \"\"\"Preprocess a batch of images for YOLOE training, adjusting formatting and dimensions as needed.\"\"\"\n batch = DetectionTrainer.preprocess_batch(self, batch)\n return batch\n\n def set_text_embeddings(self, datasets, batch: int):\n \"\"\"\n Set text embeddings for datasets to accelerate training by caching category names.\n\n This method collects unique category names from all datasets, generates text embeddings for them,\n and caches these embeddings to improve training efficiency. The embeddings are stored in a file\n in the parent directory of the first dataset's image path.\n\n Args:\n datasets (List[Dataset]): List of datasets containing category names to process.\n batch (int): Batch size for processing text embeddings.\n\n Notes:\n The method creates a dictionary mapping text samples to their embeddings and stores it\n at the path specified by 'cache_path'. If the cache file already exists, it will be loaded\n instead of regenerating the embeddings.\n \"\"\"\n pass",
"chunk_type": "class",
"name": "YOLOEPEFreeTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py",
"start_line": 231,
"end_line": 273,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": "Train prompt-free YOLOE model.\n\nThis trainer combines linear probing capabilities with from-scratch training for prompt-free\nYOLOE models that don't require text prompts during inference.\n\nMethods:\n get_validator: Return standard DetectionValidator for validation.\n preprocess_batch: Preprocess batches without text features.\n set_text_embeddings: Set text embeddings for datasets (no-op for prompt-free).",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"itertools",
"copy.copy",
"copy.deepcopy",
"pathlib.Path",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Union",
"torch",
"ultralytics.data.YOLOConcatDataset",
"ultralytics.data.build_yolo_dataset",
"ultralytics.data.augment.LoadVisualPrompt",
"ultralytics.models.yolo.detect.DetectionTrainer",
"ultralytics.models.yolo.detect.DetectionValidator",
"ultralytics.nn.tasks.YOLOEModel",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.torch_utils.de_parallel",
"world.train_world.WorldTrainerFromScratch",
"val.YOLOEDetectValidator",
"YOLOEPETrainer",
"YOLOETrainerFromScratch"
],
"chunk_id": "class_YOLOEPEFreeTrainer_a87d96eb"
},
{
"content": "class YOLOEVPTrainer(YOLOETrainerFromScratch):\n \"\"\"\n Train YOLOE model with visual prompts.\n\n This trainer extends YOLOETrainerFromScratch to support visual prompt-based training,\n where visual cues are provided alongside images to guide the detection process.\n\n Methods:\n build_dataset: Build dataset with visual prompt loading transforms.\n preprocess_batch: Preprocess batches with visual prompts.\n \"\"\"\n\n def build_dataset(self, img_path: Union[List[str], str], mode: str = \"train\", batch: Optional[int] = None):\n \"\"\"\n Build YOLO Dataset for training or validation with visual prompts.\n\n Args:\n img_path (List[str] | str): Path to the folder containing images or list of paths.\n mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.\n batch (int, optional): Size of batches, used for rectangular training/validation.\n\n Returns:\n (Dataset): YOLO dataset configured for training or validation, with visual prompts for training mode.\n \"\"\"\n dataset = super().build_dataset(img_path, mode, batch)\n if isinstance(dataset, YOLOConcatDataset):\n for d in dataset.datasets:\n d.transforms.append(LoadVisualPrompt())\n else:\n dataset.transforms.append(LoadVisualPrompt())\n return dataset\n\n def _close_dataloader_mosaic(self):\n \"\"\"Close mosaic augmentation and add visual prompt loading to the training dataset.\"\"\"\n super()._close_dataloader_mosaic()\n if isinstance(self.train_loader.dataset, YOLOConcatDataset):\n for d in self.train_loader.dataset.datasets:\n d.transforms.append(LoadVisualPrompt())\n else:\n self.train_loader.dataset.transforms.append(LoadVisualPrompt())\n\n def preprocess_batch(self, batch):\n \"\"\"Preprocess a batch of images for YOLOE training, moving visual prompts to the appropriate device.\"\"\"\n batch = super().preprocess_batch(batch)\n batch[\"visuals\"] = batch[\"visuals\"].to(self.device)\n return batch",
"chunk_type": "class",
"name": "YOLOEVPTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train.py",
"start_line": 276,
"end_line": 321,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Train YOLOE model with visual prompts.\n\nThis trainer extends YOLOETrainerFromScratch to support visual prompt-based training,\nwhere visual cues are provided alongside images to guide the detection process.\n\nMethods:\n build_dataset: Build dataset with visual prompt loading transforms.\n preprocess_batch: Preprocess batches with visual prompts.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"itertools",
"copy.copy",
"copy.deepcopy",
"pathlib.Path",
"typing.Dict",
"typing.List",
"typing.Optional",
"typing.Union",
"torch",
"ultralytics.data.YOLOConcatDataset",
"ultralytics.data.build_yolo_dataset",
"ultralytics.data.augment.LoadVisualPrompt",
"ultralytics.models.yolo.detect.DetectionTrainer",
"ultralytics.models.yolo.detect.DetectionValidator",
"ultralytics.nn.tasks.YOLOEModel",
"ultralytics.utils.DEFAULT_CFG",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.torch_utils.de_parallel",
"world.train_world.WorldTrainerFromScratch",
"val.YOLOEDetectValidator",
"YOLOETrainerFromScratch"
],
"chunk_id": "class_YOLOEVPTrainer_21b21ae1"
},
{
"content": "from copy import copy, deepcopy",
"chunk_type": "import",
"name": "copy, deepcopy",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_copy, deepcopy_14f8fd7e"
},
{
"content": "from ultralytics.models.yolo.segment import SegmentationTrainer",
"chunk_type": "import",
"name": "SegmentationTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 63,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SegmentationTrainer_7b8fa4f0"
},
{
"content": "from ultralytics.nn.tasks import YOLOESegModel",
"chunk_type": "import",
"name": "YOLOESegModel",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLOESegModel_223f18e7"
},
{
"content": "from ultralytics.utils import RANK",
"chunk_type": "import",
"name": "RANK",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 34,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_RANK_a43de328"
},
{
"content": "from .train import YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer",
"chunk_type": "import",
"name": "YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 72,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer_48db9426"
},
{
"content": "from .val import YOLOESegValidator",
"chunk_type": "import",
"name": "YOLOESegValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 34,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLOESegValidator_73974eef"
},
{
"content": "class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):\n \"\"\"\n Trainer class for YOLOE segmentation models.\n\n This class combines YOLOETrainer and SegmentationTrainer to provide training functionality specifically for YOLOE\n segmentation models, enabling both object detection and instance segmentation capabilities.\n\n Attributes:\n cfg (dict): Configuration dictionary with training parameters.\n overrides (dict): Dictionary with parameter overrides.\n _callbacks (list): List of callback functions for training events.\n \"\"\"\n\n def get_model(self, cfg=None, weights=None, verbose=True):\n \"\"\"\n Return YOLOESegModel initialized with specified config and weights.\n\n Args:\n cfg (dict | str, optional): Model configuration dictionary or YAML file path.\n weights (str, optional): Path to pretrained weights file.\n verbose (bool): Whether to display model information.\n\n Returns:\n (YOLOESegModel): Initialized YOLOE segmentation model.\n \"\"\"\n # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.\n # NOTE: Following the official config, nc hard-coded to 80 for now.\n model = YOLOESegModel(\n cfg[\"yaml_file\"] if isinstance(cfg, dict) else cfg,\n ch=self.data[\"channels\"],\n nc=min(self.data[\"nc\"], 80),\n verbose=verbose and RANK == -1,\n )\n if weights:\n model.load(weights)\n\n return model\n\n def get_validator(self):\n \"\"\"\n Create and return a validator for YOLOE segmentation model evaluation.\n\n Returns:\n (YOLOESegValidator): Validator for YOLOE segmentation models.\n \"\"\"\n self.loss_names = \"box\", \"seg\", \"cls\", \"dfl\"\n return YOLOESegValidator(\n self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks\n )",
"chunk_type": "class",
"name": "YOLOESegTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py",
"start_line": 13,
"end_line": 61,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "Trainer class for YOLOE segmentation models.\n\nThis class combines YOLOETrainer and SegmentationTrainer to provide training functionality specifically for YOLOE\nsegmentation models, enabling both object detection and instance segmentation capabilities.\n\nAttributes:\n cfg (dict): Configuration dictionary with training parameters.\n overrides (dict): Dictionary with parameter overrides.\n _callbacks (list): List of callback functions for training events.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy.copy",
"copy.deepcopy",
"ultralytics.models.yolo.segment.SegmentationTrainer",
"ultralytics.nn.tasks.YOLOESegModel",
"ultralytics.utils.RANK",
"train.YOLOETrainer",
"train.YOLOETrainerFromScratch",
"train.YOLOEVPTrainer",
"val.YOLOESegValidator",
"YOLOETrainer",
"SegmentationTrainer"
],
"chunk_id": "class_YOLOESegTrainer_902592da"
},
{
"content": "class YOLOEPESegTrainer(SegmentationTrainer):\n \"\"\"\n Fine-tune YOLOESeg model in linear probing way.\n\n This trainer specializes in fine-tuning YOLOESeg models using a linear probing approach, which involves freezing\n most of the model and only training specific layers for efficient adaptation to new tasks.\n\n Attributes:\n data (dict): Dataset configuration containing channels, class names, and number of classes.\n \"\"\"\n\n def get_model(self, cfg=None, weights=None, verbose=True):\n \"\"\"\n Return YOLOESegModel initialized with specified config and weights for linear probing.\n\n Args:\n cfg (dict | str, optional): Model configuration dictionary or YAML file path.\n weights (str, optional): Path to pretrained weights file.\n verbose (bool): Whether to display model information.\n\n Returns:\n (YOLOESegModel): Initialized YOLOE segmentation model configured for linear probing.\n \"\"\"\n # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.\n # NOTE: Following the official config, nc hard-coded to 80 for now.\n model = YOLOESegModel(\n cfg[\"yaml_file\"] if isinstance(cfg, dict) else cfg,\n ch=self.data[\"channels\"],\n nc=self.data[\"nc\"],\n verbose=verbose and RANK == -1,\n )\n\n del model.model[-1].savpe\n\n assert weights is not None, \"Pretrained weights must be provided for linear probing.\"\n if weights:\n model.load(weights)\n\n model.eval()\n names = list(self.data[\"names\"].values())\n # NOTE: `get_text_pe` related to text model and YOLOEDetect.reprta,\n # it'd get correct results as long as loading proper pretrained weights.\n tpe = model.get_text_pe(names)\n model.set_classes(names, tpe)\n model.model[-1].fuse(model.pe)\n model.model[-1].cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)\n model.model[-1].cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)\n model.model[-1].cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)\n del model.pe\n model.train()\n\n return model",
"chunk_type": "class",
"name": "YOLOEPESegTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py",
"start_line": 64,
"end_line": 115,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "Fine-tune YOLOESeg model in linear probing way.\n\nThis trainer specializes in fine-tuning YOLOESeg models using a linear probing approach, which involves freezing\nmost of the model and only training specific layers for efficient adaptation to new tasks.\n\nAttributes:\n data (dict): Dataset configuration containing channels, class names, and number of classes.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy.copy",
"copy.deepcopy",
"ultralytics.models.yolo.segment.SegmentationTrainer",
"ultralytics.nn.tasks.YOLOESegModel",
"ultralytics.utils.RANK",
"train.YOLOETrainer",
"train.YOLOETrainerFromScratch",
"train.YOLOEVPTrainer",
"val.YOLOESegValidator",
"SegmentationTrainer"
],
"chunk_id": "class_YOLOEPESegTrainer_87305228"
},
{
"content": "class YOLOESegTrainerFromScratch(YOLOETrainerFromScratch, YOLOESegTrainer):\n \"\"\"Trainer for YOLOE segmentation models trained from scratch without pretrained weights.\"\"\"\n\n pass",
"chunk_type": "class",
"name": "YOLOESegTrainerFromScratch",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py",
"start_line": 118,
"end_line": 121,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Trainer for YOLOE segmentation models trained from scratch without pretrained weights.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy.copy",
"copy.deepcopy",
"ultralytics.models.yolo.segment.SegmentationTrainer",
"ultralytics.nn.tasks.YOLOESegModel",
"ultralytics.utils.RANK",
"train.YOLOETrainer",
"train.YOLOETrainerFromScratch",
"train.YOLOEVPTrainer",
"val.YOLOESegValidator",
"YOLOETrainerFromScratch",
"YOLOESegTrainer"
],
"chunk_id": "class_YOLOESegTrainerFromScratch_ac610604"
},
{
"content": "class YOLOESegVPTrainer(YOLOEVPTrainer, YOLOESegTrainerFromScratch):\n \"\"\"Trainer for YOLOE segmentation models with Vision Prompt (VP) capabilities.\"\"\"\n\n pass",
"chunk_type": "class",
"name": "YOLOESegVPTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\train_seg.py",
"start_line": 124,
"end_line": 127,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Trainer for YOLOE segmentation models with Vision Prompt (VP) capabilities.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy.copy",
"copy.deepcopy",
"ultralytics.models.yolo.segment.SegmentationTrainer",
"ultralytics.nn.tasks.YOLOESegModel",
"ultralytics.utils.RANK",
"train.YOLOETrainer",
"train.YOLOETrainerFromScratch",
"train.YOLOEVPTrainer",
"val.YOLOESegValidator",
"YOLOEVPTrainer",
"YOLOESegTrainerFromScratch"
],
"chunk_id": "class_YOLOESegVPTrainer_848dc5cf"
},
{
"content": "from copy import deepcopy",
"chunk_type": "import",
"name": "deepcopy",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_deepcopy_f9e40614"
},
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_9f78b6b9"
},
{
"content": "from typing import Any, Dict, Optional, Union",
"chunk_type": "import",
"name": "Any, Dict, Optional, Union",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, Dict, Optional, Union_05bd7c6f"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_34502e9c"
},
{
"content": "from torch.nn import functional as F",
"chunk_type": "import",
"name": "functional",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 36,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_functional_5511c916"
},
{
"content": "from ultralytics.data import YOLOConcatDataset, build_dataloader, build_yolo_dataset",
"chunk_type": "import",
"name": "YOLOConcatDataset, build_dataloader, build_yolo_dataset",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 84,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLOConcatDataset, build_dataloader, build_yolo_dataset_3e8eb75d"
},
{
"content": "from ultralytics.data.augment import LoadVisualPrompt",
"chunk_type": "import",
"name": "LoadVisualPrompt",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LoadVisualPrompt_7f66e959"
},
{
"content": "from ultralytics.data.utils import check_det_dataset",
"chunk_type": "import",
"name": "check_det_dataset",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 52,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_check_det_dataset_76dcfbb2"
},
{
"content": "from ultralytics.models.yolo.detect import DetectionValidator",
"chunk_type": "import",
"name": "DetectionValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 61,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DetectionValidator_60711022"
},
{
"content": "from ultralytics.models.yolo.segment import SegmentationValidator",
"chunk_type": "import",
"name": "SegmentationValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 65,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SegmentationValidator_c1531403"
},
{
"content": "from ultralytics.nn.modules.head import YOLOEDetect",
"chunk_type": "import",
"name": "YOLOEDetect",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py",
"start_line": 15,
"end_line": 15,
"start_col": 0,
"end_col": 51,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLOEDetect_bcda8aa3"
},
{
"content": "from ultralytics.nn.tasks import YOLOEModel",
"chunk_type": "import",
"name": "YOLOEModel",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py",
"start_line": 16,
"end_line": 16,
"start_col": 0,
"end_col": 43,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLOEModel_589c4e76"
},
{
"content": "from ultralytics.utils import LOGGER, TQDM",
"chunk_type": "import",
"name": "LOGGER, TQDM",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py",
"start_line": 17,
"end_line": 17,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER, TQDM_dde9a176"
},
{
"content": "from ultralytics.utils.torch_utils import select_device, smart_inference_mode",
"chunk_type": "import",
"name": "select_device, smart_inference_mode",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py",
"start_line": 18,
"end_line": 18,
"start_col": 0,
"end_col": 77,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_select_device, smart_inference_mode_c3de7d32"
},
{
"content": "class YOLOEDetectValidator(DetectionValidator):\n \"\"\"\n A validator class for YOLOE detection models that handles both text and visual prompt embeddings.\n\n This class extends DetectionValidator to provide specialized validation functionality for YOLOE models.\n It supports validation using either text prompts or visual prompt embeddings extracted from training samples,\n enabling flexible evaluation strategies for prompt-based object detection.\n\n Attributes:\n device (torch.device): The device on which validation is performed.\n args (namespace): Configuration arguments for validation.\n dataloader (DataLoader): DataLoader for validation data.\n\n Methods:\n get_visual_pe: Extract visual prompt embeddings from training samples.\n preprocess: Preprocess batch data ensuring visuals are on the same device as images.\n get_vpe_dataloader: Create a dataloader for LVIS training visual prompt samples.\n __call__: Run validation using either text or visual prompt embeddings.\n\n Examples:\n Validate with text prompts\n >>> validator = YOLOEDetectValidator()\n >>> stats = validator(model=model, load_vp=False)\n\n Validate with visual prompts\n >>> stats = validator(model=model, refer_data=\"path/to/data.yaml\", load_vp=True)\n \"\"\"\n\n @smart_inference_mode()\n def get_visual_pe(self, dataloader: torch.utils.data.DataLoader, model: YOLOEModel) -> torch.Tensor:\n \"\"\"\n Extract visual prompt embeddings from training samples.\n\n This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model.\n It normalizes the embeddings and handles cases where no samples exist for a class by setting their\n embeddings to zero.\n\n Args:\n dataloader (torch.utils.data.DataLoader): The dataloader providing training samples.\n model (YOLOEModel): The YOLOE model from which to extract visual prompt embeddings.\n\n Returns:\n (torch.Tensor): Visual prompt embeddings with shape (1, num_classes, embed_dim).\n \"\"\"\n assert isinstance(model, YOLOEModel)\n names = [name.split(\"/\", 1)[0] for name in list(dataloader.dataset.data[\"names\"].values())]\n visual_pe = torch.zeros(len(names), model.model[-1].embed, device=self.device)\n cls_visual_num = torch.zeros(len(names))\n\n desc = \"Get visual prompt embeddings from samples\"\n\n # Count samples per class\n for batch in dataloader:\n cls = batch[\"cls\"].squeeze(-1).to(torch.int).unique()\n count = torch.bincount(cls, minlength=len(names))\n cls_visual_num += count\n\n cls_visual_num = cls_visual_num.to(self.device)\n\n # Extract visual prompt embeddings\n pbar = TQDM(dataloader, total=len(dataloader), desc=desc)\n for batch in pbar:\n batch = self.preprocess(batch)\n preds = model.get_visual_pe(batch[\"img\"], visual=batch[\"visuals\"]) # (B, max_n, embed_dim)\n\n batch_idx = batch[\"batch_idx\"]\n for i in range(preds.shape[0]):\n cls = batch[\"cls\"][batch_idx == i].squeeze(-1).to(torch.int).unique(sorted=True)\n pad_cls = torch.ones(preds.shape[1], device=self.device) * -1\n pad_cls[: len(cls)] = cls\n for c in cls:\n visual_pe[c] += preds[i][pad_cls == c].sum(0) / cls_visual_num[c]\n\n # Normalize embeddings for classes with samples, set others to zero\n visual_pe[cls_visual_num != 0] = F.normalize(visual_pe[cls_visual_num != 0], dim=-1, p=2)\n visual_pe[cls_visual_num == 0] = 0\n return visual_pe.unsqueeze(0)\n\n def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:\n \"\"\"Preprocess batch data, ensuring visuals are on the same device as images.\"\"\"\n batch = super().preprocess(batch)\n 
if \"visuals\" in batch:\n batch[\"visuals\"] = batch[\"visuals\"].to(batch[\"img\"].device)\n return batch\n\n def get_vpe_dataloader(self, data: Dict[str, Any]) -> torch.utils.data.DataLoader:\n \"\"\"\n Create a dataloader for LVIS training visual prompt samples.\n\n This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset.\n It applies necessary transformations including LoadVisualPrompt and configurations to the dataset\n for validation purposes.\n\n Args:\n data (dict): Dataset configuration dictionary containing paths and settings.\n\n Returns:\n (torch.utils.data.DataLoader): The dataloader for visual prompt samples.\n \"\"\"\n dataset = build_yolo_dataset(\n self.args,\n data.get(self.args.split, data.get(\"val\")),\n self.args.batch,\n data,\n mode=\"val\",\n rect=False,\n )\n if isinstance(dataset, YOLOConcatDataset):\n for d in dataset.datasets:\n d.transforms.append(LoadVisualPrompt())\n else:\n dataset.transforms.append(LoadVisualPrompt())\n return build_dataloader(\n dataset,\n self.args.batch,\n self.args.workers,\n shuffle=False,\n rank=-1,\n )\n\n @smart_inference_mode()\n def __call__(\n self,\n trainer: Optional[Any] = None,\n model: Optional[Union[YOLOEModel, str]] = None,\n refer_data: Optional[str] = None,\n load_vp: bool = False,\n ) -> Dict[str, Any]:\n \"\"\"\n Run validation on the model using either text or visual prompt embeddings.\n\n This method validates the model using either text prompts or visual prompts, depending on the load_vp flag.\n It supports validation during training (using a trainer object) or standalone validation with a provided\n model. For visual prompts, reference data can be specified to extract embeddings from a different dataset.\n\n Args:\n trainer (object, optional): Trainer object containing the model and device.\n model (YOLOEModel | str, optional): Model to validate. Required if trainer is not provided.\n refer_data (str, optional): Path to reference data for visual prompts.\n load_vp (bool): Whether to load visual prompts. 
If False, text prompts are used.\n\n Returns:\n (dict): Validation statistics containing metrics computed during validation.\n \"\"\"\n if trainer is not None:\n self.device = trainer.device\n model = trainer.ema.ema\n names = [name.split(\"/\", 1)[0] for name in list(self.dataloader.dataset.data[\"names\"].values())]\n\n if load_vp:\n LOGGER.info(\"Validate using the visual prompt.\")\n self.args.half = False\n # Directly use the same dataloader for visual embeddings extracted during training\n vpe = self.get_visual_pe(self.dataloader, model)\n model.set_classes(names, vpe)\n else:\n LOGGER.info(\"Validate using the text prompt.\")\n tpe = model.get_text_pe(names)\n model.set_classes(names, tpe)\n stats = super().__call__(trainer, model)\n else:\n if refer_data is not None:\n assert load_vp, \"Refer data is only used for visual prompt validation.\"\n self.device = select_device(self.args.device)\n\n if isinstance(model, (str, Path)):\n from ultralytics.nn.tasks import attempt_load_weights\n\n model = attempt_load_weights(model, device=self.device, inplace=True)\n model.eval().to(self.device)\n data = check_det_dataset(refer_data or self.args.data)\n names = [name.split(\"/\", 1)[0] for name in list(data[\"names\"].values())]\n\n if load_vp:\n LOGGER.info(\"Validate using the visual prompt.\")\n self.args.half = False\n # TODO: need to check if the names from refer data is consistent with the evaluated dataset\n # could use same dataset or refer to extract visual prompt embeddings\n dataloader = self.get_vpe_dataloader(data)\n vpe = self.get_visual_pe(dataloader, model)\n model.set_classes(names, vpe)\n stats = super().__call__(model=deepcopy(model))\n elif isinstance(model.model[-1], YOLOEDetect) and hasattr(model.model[-1], \"lrpc\"): # prompt-free\n return super().__call__(trainer, model)\n else:\n LOGGER.info(\"Validate using the text prompt.\")\n tpe = model.get_text_pe(names)\n model.set_classes(names, tpe)\n stats = super().__call__(model=deepcopy(model))\n return stats",
"chunk_type": "class",
"name": "YOLOEDetectValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py",
"start_line": 21,
"end_line": 210,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "A validator class for YOLOE detection models that handles both text and visual prompt embeddings.\n\nThis class extends DetectionValidator to provide specialized validation functionality for YOLOE models.\nIt supports validation using either text prompts or visual prompt embeddings extracted from training samples,\nenabling flexible evaluation strategies for prompt-based object detection.\n\nAttributes:\n device (torch.device): The device on which validation is performed.\n args (namespace): Configuration arguments for validation.\n dataloader (DataLoader): DataLoader for validation data.\n\nMethods:\n get_visual_pe: Extract visual prompt embeddings from training samples.\n preprocess: Preprocess batch data ensuring visuals are on the same device as images.\n get_vpe_dataloader: Create a dataloader for LVIS training visual prompt samples.\n __call__: Run validation using either text or visual prompt embeddings.\n\nExamples:\n Validate with text prompts\n >>> validator = YOLOEDetectValidator()\n >>> stats = validator(model=model, load_vp=False)\n\n Validate with visual prompts\n >>> stats = validator(model=model, refer_data=\"path/to/data.yaml\", load_vp=True)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy.deepcopy",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Optional",
"typing.Union",
"torch",
"torch.nn.functional",
"ultralytics.data.YOLOConcatDataset",
"ultralytics.data.build_dataloader",
"ultralytics.data.build_yolo_dataset",
"ultralytics.data.augment.LoadVisualPrompt",
"ultralytics.data.utils.check_det_dataset",
"ultralytics.models.yolo.detect.DetectionValidator",
"ultralytics.models.yolo.segment.SegmentationValidator",
"ultralytics.nn.modules.head.YOLOEDetect",
"ultralytics.nn.tasks.YOLOEModel",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TQDM",
"ultralytics.utils.torch_utils.select_device",
"ultralytics.utils.torch_utils.smart_inference_mode",
"ultralytics.nn.tasks.attempt_load_weights",
"DetectionValidator"
],
"chunk_id": "class_YOLOEDetectValidator_f61b65ec"
},
{
"content": "class YOLOESegValidator(YOLOEDetectValidator, SegmentationValidator):\n \"\"\"YOLOE segmentation validator that supports both text and visual prompt embeddings.\"\"\"\n\n pass",
"chunk_type": "class",
"name": "YOLOESegValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\val.py",
"start_line": 213,
"end_line": 216,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "YOLOE segmentation validator that supports both text and visual prompt embeddings.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy.deepcopy",
"pathlib.Path",
"typing.Any",
"typing.Dict",
"typing.Optional",
"typing.Union",
"torch",
"torch.nn.functional",
"ultralytics.data.YOLOConcatDataset",
"ultralytics.data.build_dataloader",
"ultralytics.data.build_yolo_dataset",
"ultralytics.data.augment.LoadVisualPrompt",
"ultralytics.data.utils.check_det_dataset",
"ultralytics.models.yolo.detect.DetectionValidator",
"ultralytics.models.yolo.segment.SegmentationValidator",
"ultralytics.nn.modules.head.YOLOEDetect",
"ultralytics.nn.tasks.YOLOEModel",
"ultralytics.utils.LOGGER",
"ultralytics.utils.TQDM",
"ultralytics.utils.torch_utils.select_device",
"ultralytics.utils.torch_utils.smart_inference_mode",
"ultralytics.nn.tasks.attempt_load_weights",
"YOLOEDetectValidator",
"SegmentationValidator"
],
"chunk_id": "class_YOLOESegValidator_bbed2d6f"
},
{
"content": "from .predict import YOLOEVPDetectPredictor, YOLOEVPSegPredictor",
"chunk_type": "import",
"name": "YOLOEVPDetectPredictor, YOLOEVPSegPredictor",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\__init__.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 64,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLOEVPDetectPredictor, YOLOEVPSegPredictor_08ad326a"
},
{
"content": "from .train import YOLOEPEFreeTrainer, YOLOEPETrainer, YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer",
"chunk_type": "import",
"name": "YOLOEPEFreeTrainer, YOLOEPETrainer, YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\__init__.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 108,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLOEPEFreeTrainer, YOLOEPETrainer, YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer_7dca38c0"
},
{
"content": "from .train_seg import YOLOEPESegTrainer, YOLOESegTrainer, YOLOESegTrainerFromScratch, YOLOESegVPTrainer",
"chunk_type": "import",
"name": "YOLOEPESegTrainer, YOLOESegTrainer, YOLOESegTrainerFromScratch, YOLOESegVPTrainer",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\__init__.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 104,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLOEPESegTrainer, YOLOESegTrainer, YOLOESegTrainerFromScratch, YOLOESegVPTrainer_d356eaa9"
},
{
"content": "from .val import YOLOEDetectValidator, YOLOESegValidator",
"chunk_type": "import",
"name": "YOLOEDetectValidator, YOLOESegValidator",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\__init__.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 56,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_YOLOEDetectValidator, YOLOESegValidator_ba883fe7"
},
{
"content": "__all__ = [\n \"YOLOETrainer\",\n \"YOLOEPETrainer\",\n \"YOLOESegTrainer\",\n \"YOLOEDetectValidator\",\n \"YOLOESegValidator\",\n \"YOLOEPESegTrainer\",\n \"YOLOESegTrainerFromScratch\",\n \"YOLOESegVPTrainer\",\n \"YOLOEVPTrainer\",\n \"YOLOEPEFreeTrainer\",\n \"YOLOEVPDetectPredictor\",\n \"YOLOEVPSegPredictor\",\n \"YOLOETrainerFromScratch\",\n]",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\models\\yolo\\yoloe\\__init__.py",
"start_line": 8,
"end_line": 22,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___7ce3e0c1"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\activation.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_5eac896e"
},
{
"content": "import torch.nn as nn",
"chunk_type": "import",
"name": "torch.nn",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\activation.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn_cf4e57aa"
},
{
"content": "class AGLU(nn.Module):\n \"\"\"\n Unified activation function module from AGLU.\n\n This class implements a parameterized activation function with learnable parameters lambda and kappa, based on the\n AGLU (Adaptive Gated Linear Unit) approach.\n\n Attributes:\n act (nn.Softplus): Softplus activation function with negative beta.\n lambd (nn.Parameter): Learnable lambda parameter initialized with uniform distribution.\n kappa (nn.Parameter): Learnable kappa parameter initialized with uniform distribution.\n\n Methods:\n forward: Compute the forward pass of the Unified activation function.\n\n Examples:\n >>> import torch\n >>> m = AGLU()\n >>> input = torch.randn(2)\n >>> output = m(input)\n >>> print(output.shape)\n torch.Size([2])\n\n References:\n https://github.com/kostas1515/AGLU\n \"\"\"\n\n def __init__(self, device=None, dtype=None) -> None:\n \"\"\"Initialize the Unified activation function with learnable parameters.\"\"\"\n super().__init__()\n self.act = nn.Softplus(beta=-1.0)\n self.lambd = nn.Parameter(nn.init.uniform_(torch.empty(1, device=device, dtype=dtype))) # lambda parameter\n self.kappa = nn.Parameter(nn.init.uniform_(torch.empty(1, device=device, dtype=dtype))) # kappa parameter\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Apply the Adaptive Gated Linear Unit (AGLU) activation function.\n\n This forward method implements the AGLU activation function with learnable parameters lambda and kappa.\n The function applies a transformation that adaptively combines linear and non-linear components.\n\n Args:\n x (torch.Tensor): Input tensor to apply the activation function to.\n\n Returns:\n (torch.Tensor): Output tensor after applying the AGLU activation function, with the same shape as the input.\n \"\"\"\n lam = torch.clamp(self.lambd, min=0.0001) # Clamp lambda to avoid division by zero\n return torch.exp((1 / lam) * self.act((self.kappa * x) - torch.log(lam)))",
"chunk_type": "class",
"name": "AGLU",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\activation.py",
"start_line": 8,
"end_line": 56,
"start_col": 0,
"end_col": 81,
"parent_name": null,
"docstring": "Unified activation function module from AGLU.\n\nThis class implements a parameterized activation function with learnable parameters lambda and kappa, based on the\nAGLU (Adaptive Gated Linear Unit) approach.\n\nAttributes:\n act (nn.Softplus): Softplus activation function with negative beta.\n lambd (nn.Parameter): Learnable lambda parameter initialized with uniform distribution.\n kappa (nn.Parameter): Learnable kappa parameter initialized with uniform distribution.\n\nMethods:\n forward: Compute the forward pass of the Unified activation function.\n\nExamples:\n >>> import torch\n >>> m = AGLU()\n >>> input = torch.randn(2)\n >>> output = m(input)\n >>> print(output.shape)\n torch.Size([2])\n\nReferences:\n https://github.com/kostas1515/AGLU",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"torch",
"torch.nn",
"nn.Module"
],
"chunk_id": "class_AGLU_856d70d7"
},
{
"content": "from typing import List, Optional, Tuple",
"chunk_type": "import",
"name": "List, Optional, Tuple",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_List, Optional, Tuple_3be83d74"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_8d4ac29b"
},
{
"content": "import torch.nn as nn",
"chunk_type": "import",
"name": "torch.nn",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn_f1a1c3a8"
},
{
"content": "import torch.nn.functional as F",
"chunk_type": "import",
"name": "torch.nn.functional",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn.functional_9c8fb0f7"
},
{
"content": "from ultralytics.utils.torch_utils import fuse_conv_and_bn",
"chunk_type": "import",
"name": "fuse_conv_and_bn",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 58,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_fuse_conv_and_bn_5c231990"
},
{
"content": "from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad",
"chunk_type": "import",
"name": "Conv, DWConv, GhostConv, LightConv, RepConv, autopad",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 70,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Conv, DWConv, GhostConv, LightConv, RepConv, autopad_4db33685"
},
{
"content": "from .transformer import TransformerBlock",
"chunk_type": "import",
"name": "TransformerBlock",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_TransformerBlock_5f97c6ea"
},
{
"content": "__all__ = (\n \"DFL\",\n \"HGBlock\",\n \"HGStem\",\n \"SPP\",\n \"SPPF\",\n \"C1\",\n \"C2\",\n \"C3\",\n \"C2f\",\n \"C2fAttn\",\n \"ImagePoolingAttn\",\n \"ContrastiveHead\",\n \"BNContrastiveHead\",\n \"C3x\",\n \"C3TR\",\n \"C3Ghost\",\n \"GhostBottleneck\",\n \"Bottleneck\",\n \"BottleneckCSP\",\n \"Proto\",\n \"RepC3\",\n \"ResNetLayer\",\n \"RepNCSPELAN4\",\n \"ELAN1\",\n \"ADown\",\n \"AConv\",\n \"SPPELAN\",\n \"CBFuse\",\n \"CBLinear\",\n \"C3k2\",\n \"C2fPSA\",\n \"C2PSA\",\n \"RepVGGDW\",\n \"CIB\",\n \"C2fCIB\",\n \"Attention\",\n \"PSA\",\n \"SCDown\",\n \"TorchVision\",\n)",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 15,
"end_line": 55,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___2c799b66"
},
{
"content": "class DFL(nn.Module):\n \"\"\"\n Integral module of Distribution Focal Loss (DFL).\n\n Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391\n \"\"\"\n\n def __init__(self, c1: int = 16):\n \"\"\"\n Initialize a convolutional layer with a given number of input channels.\n\n Args:\n c1 (int): Number of input channels.\n \"\"\"\n super().__init__()\n self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)\n x = torch.arange(c1, dtype=torch.float)\n self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))\n self.c1 = c1\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply the DFL module to input tensor and return transformed output.\"\"\"\n b, _, a = x.shape # batch, channels, anchors\n return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)",
"chunk_type": "class",
"name": "DFL",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 58,
"end_line": 81,
"start_col": 0,
"end_col": 91,
"parent_name": null,
"docstring": "Integral module of Distribution Focal Loss (DFL).\n\nProposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_DFL_1fd1ef42"
},
{
"content": "class Proto(nn.Module):\n \"\"\"Ultralytics YOLO models mask Proto module for segmentation models.\"\"\"\n\n def __init__(self, c1: int, c_: int = 256, c2: int = 32):\n \"\"\"\n Initialize the Ultralytics YOLO models mask Proto module with specified number of protos and masks.\n\n Args:\n c1 (int): Input channels.\n c_ (int): Intermediate channels.\n c2 (int): Output channels (number of protos).\n \"\"\"\n super().__init__()\n self.cv1 = Conv(c1, c_, k=3)\n self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) # nn.Upsample(scale_factor=2, mode='nearest')\n self.cv2 = Conv(c_, c_, k=3)\n self.cv3 = Conv(c_, c2)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Perform a forward pass through layers using an upsampled input image.\"\"\"\n return self.cv3(self.cv2(self.upsample(self.cv1(x))))",
"chunk_type": "class",
"name": "Proto",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 85,
"end_line": 105,
"start_col": 0,
"end_col": 61,
"parent_name": null,
"docstring": "Ultralytics YOLO models mask Proto module for segmentation models.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_Proto_d5d6f109"
},
{
"content": "class HGStem(nn.Module):\n \"\"\"\n StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d.\n\n https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py\n \"\"\"\n\n def __init__(self, c1: int, cm: int, c2: int):\n \"\"\"\n Initialize the StemBlock of PPHGNetV2.\n\n Args:\n c1 (int): Input channels.\n cm (int): Middle channels.\n c2 (int): Output channels.\n \"\"\"\n super().__init__()\n self.stem1 = Conv(c1, cm, 3, 2, act=nn.ReLU())\n self.stem2a = Conv(cm, cm // 2, 2, 1, 0, act=nn.ReLU())\n self.stem2b = Conv(cm // 2, cm, 2, 1, 0, act=nn.ReLU())\n self.stem3 = Conv(cm * 2, cm, 3, 2, act=nn.ReLU())\n self.stem4 = Conv(cm, c2, 1, 1, act=nn.ReLU())\n self.pool = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, ceil_mode=True)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass of a PPHGNetV2 backbone layer.\"\"\"\n x = self.stem1(x)\n x = F.pad(x, [0, 1, 0, 1])\n x2 = self.stem2a(x)\n x2 = F.pad(x2, [0, 1, 0, 1])\n x2 = self.stem2b(x2)\n x1 = self.pool(x)\n x = torch.cat([x1, x2], dim=1)\n x = self.stem3(x)\n x = self.stem4(x)\n return x",
"chunk_type": "class",
"name": "HGStem",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 108,
"end_line": 143,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d.\n\nhttps://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_HGStem_49c65caf"
},
{
"content": "class HGBlock(nn.Module):\n \"\"\"\n HG_Block of PPHGNetV2 with 2 convolutions and LightConv.\n\n https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py\n \"\"\"\n\n def __init__(\n self,\n c1: int,\n cm: int,\n c2: int,\n k: int = 3,\n n: int = 6,\n lightconv: bool = False,\n shortcut: bool = False,\n act: nn.Module = nn.ReLU(),\n ):\n \"\"\"\n Initialize HGBlock with specified parameters.\n\n Args:\n c1 (int): Input channels.\n cm (int): Middle channels.\n c2 (int): Output channels.\n k (int): Kernel size.\n n (int): Number of LightConv or Conv blocks.\n lightconv (bool): Whether to use LightConv.\n shortcut (bool): Whether to use shortcut connection.\n act (nn.Module): Activation function.\n \"\"\"\n super().__init__()\n block = LightConv if lightconv else Conv\n self.m = nn.ModuleList(block(c1 if i == 0 else cm, cm, k=k, act=act) for i in range(n))\n self.sc = Conv(c1 + n * cm, c2 // 2, 1, 1, act=act) # squeeze conv\n self.ec = Conv(c2 // 2, c2, 1, 1, act=act) # excitation conv\n self.add = shortcut and c1 == c2\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass of a PPHGNetV2 backbone layer.\"\"\"\n y = [x]\n y.extend(m(y[-1]) for m in self.m)\n y = self.ec(self.sc(torch.cat(y, 1)))\n return y + x if self.add else y",
"chunk_type": "class",
"name": "HGBlock",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 146,
"end_line": 189,
"start_col": 0,
"end_col": 39,
"parent_name": null,
"docstring": "HG_Block of PPHGNetV2 with 2 convolutions and LightConv.\n\nhttps://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_HGBlock_934eb761"
},
{
"content": "class SPP(nn.Module):\n \"\"\"Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729.\"\"\"\n\n def __init__(self, c1: int, c2: int, k: Tuple[int, ...] = (5, 9, 13)):\n \"\"\"\n Initialize the SPP layer with input/output channels and pooling kernel sizes.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n k (tuple): Kernel sizes for max pooling.\n \"\"\"\n super().__init__()\n c_ = c1 // 2 # hidden channels\n self.cv1 = Conv(c1, c_, 1, 1)\n self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)\n self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass of the SPP layer, performing spatial pyramid pooling.\"\"\"\n x = self.cv1(x)\n return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))",
"chunk_type": "class",
"name": "SPP",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 192,
"end_line": 213,
"start_col": 0,
"end_col": 67,
"parent_name": null,
"docstring": "Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_SPP_936a261d"
},
{
"content": "class SPPF(nn.Module):\n \"\"\"Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher.\"\"\"\n\n def __init__(self, c1: int, c2: int, k: int = 5):\n \"\"\"\n Initialize the SPPF layer with given input/output channels and kernel size.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n k (int): Kernel size.\n\n Notes:\n This module is equivalent to SPP(k=(5, 9, 13)).\n \"\"\"\n super().__init__()\n c_ = c1 // 2 # hidden channels\n self.cv1 = Conv(c1, c_, 1, 1)\n self.cv2 = Conv(c_ * 4, c2, 1, 1)\n self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply sequential pooling operations to input and return concatenated feature maps.\"\"\"\n y = [self.cv1(x)]\n y.extend(self.m(y[-1]) for _ in range(3))\n return self.cv2(torch.cat(y, 1))",
"chunk_type": "class",
"name": "SPPF",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 216,
"end_line": 241,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": "Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_SPPF_4ada2457"
},
{
"content": "class C1(nn.Module):\n \"\"\"CSP Bottleneck with 1 convolution.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1):\n \"\"\"\n Initialize the CSP Bottleneck with 1 convolution.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of convolutions.\n \"\"\"\n super().__init__()\n self.cv1 = Conv(c1, c2, 1, 1)\n self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n)))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply convolution and residual connection to input tensor.\"\"\"\n y = self.cv1(x)\n return self.m(y) + y",
"chunk_type": "class",
"name": "C1",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 244,
"end_line": 263,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": "CSP Bottleneck with 1 convolution.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_C1_39b5aedd"
},
{
"content": "class C2(nn.Module):\n \"\"\"CSP Bottleneck with 2 convolutions.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):\n \"\"\"\n Initialize a CSP Bottleneck with 2 convolutions.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Bottleneck blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n self.c = int(c2 * e) # hidden channels\n self.cv1 = Conv(c1, 2 * self.c, 1, 1)\n self.cv2 = Conv(2 * self.c, c2, 1) # optional act=FReLU(c2)\n # self.attention = ChannelAttention(2 * self.c) # or SpatialAttention()\n self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through the CSP bottleneck with 2 convolutions.\"\"\"\n a, b = self.cv1(x).chunk(2, 1)\n return self.cv2(torch.cat((self.m(a), b), 1))",
"chunk_type": "class",
"name": "C2",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 266,
"end_line": 291,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": "CSP Bottleneck with 2 convolutions.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_C2_cbb49698"
},
{
"content": "class C2f(nn.Module):\n \"\"\"Faster Implementation of CSP Bottleneck with 2 convolutions.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = False, g: int = 1, e: float = 0.5):\n \"\"\"\n Initialize a CSP bottleneck with 2 convolutions.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Bottleneck blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n self.c = int(c2 * e) # hidden channels\n self.cv1 = Conv(c1, 2 * self.c, 1, 1)\n self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2)\n self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through C2f layer.\"\"\"\n y = list(self.cv1(x).chunk(2, 1))\n y.extend(m(y[-1]) for m in self.m)\n return self.cv2(torch.cat(y, 1))\n\n def forward_split(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass using split() instead of chunk().\"\"\"\n y = self.cv1(x).split((self.c, self.c), 1)\n y = [y[0], y[1]]\n y.extend(m(y[-1]) for m in self.m)\n return self.cv2(torch.cat(y, 1))",
"chunk_type": "class",
"name": "C2f",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 294,
"end_line": 326,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": "Faster Implementation of CSP Bottleneck with 2 convolutions.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_C2f_947106e9"
},
{
"content": "class C3(nn.Module):\n \"\"\"CSP Bottleneck with 3 convolutions.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):\n \"\"\"\n Initialize the CSP Bottleneck with 3 convolutions.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Bottleneck blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n c_ = int(c2 * e) # hidden channels\n self.cv1 = Conv(c1, c_, 1, 1)\n self.cv2 = Conv(c1, c_, 1, 1)\n self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)\n self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n)))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through the CSP bottleneck with 3 convolutions.\"\"\"\n return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))",
"chunk_type": "class",
"name": "C3",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 329,
"end_line": 353,
"start_col": 0,
"end_col": 73,
"parent_name": null,
"docstring": "CSP Bottleneck with 3 convolutions.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_C3_eccd69e4"
},
{
"content": "class C3x(C3):\n \"\"\"C3 module with cross-convolutions.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):\n \"\"\"\n Initialize C3 module with cross-convolutions.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Bottleneck blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__(c1, c2, n, shortcut, g, e)\n self.c_ = int(c2 * e)\n self.m = nn.Sequential(*(Bottleneck(self.c_, self.c_, shortcut, g, k=((1, 3), (3, 1)), e=1) for _ in range(n)))",
"chunk_type": "class",
"name": "C3x",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 356,
"end_line": 373,
"start_col": 0,
"end_col": 119,
"parent_name": null,
"docstring": "C3 module with cross-convolutions.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"C3"
],
"chunk_id": "class_C3x_a86a4bf9"
},
{
"content": "class RepC3(nn.Module):\n \"\"\"Rep C3.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 3, e: float = 1.0):\n \"\"\"\n Initialize CSP Bottleneck with a single convolution.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of RepConv blocks.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n c_ = int(c2 * e) # hidden channels\n self.cv1 = Conv(c1, c_, 1, 1)\n self.cv2 = Conv(c1, c_, 1, 1)\n self.m = nn.Sequential(*[RepConv(c_, c_) for _ in range(n)])\n self.cv3 = Conv(c_, c2, 1, 1) if c_ != c2 else nn.Identity()\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass of RepC3 module.\"\"\"\n return self.cv3(self.m(self.cv1(x)) + self.cv2(x))",
"chunk_type": "class",
"name": "RepC3",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 376,
"end_line": 398,
"start_col": 0,
"end_col": 58,
"parent_name": null,
"docstring": "Rep C3.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_RepC3_a0dfcb95"
},
{
"content": "class C3TR(C3):\n \"\"\"C3 module with TransformerBlock().\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):\n \"\"\"\n Initialize C3 module with TransformerBlock.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Transformer blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__(c1, c2, n, shortcut, g, e)\n c_ = int(c2 * e)\n self.m = TransformerBlock(c_, c_, 4, n)",
"chunk_type": "class",
"name": "C3TR",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 401,
"end_line": 418,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": "C3 module with TransformerBlock().",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"C3"
],
"chunk_id": "class_C3TR_fe347d69"
},
{
"content": "class C3Ghost(C3):\n \"\"\"C3 module with GhostBottleneck().\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):\n \"\"\"\n Initialize C3 module with GhostBottleneck.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Ghost bottleneck blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__(c1, c2, n, shortcut, g, e)\n c_ = int(c2 * e) # hidden channels\n self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))",
"chunk_type": "class",
"name": "C3Ghost",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 421,
"end_line": 438,
"start_col": 0,
"end_col": 76,
"parent_name": null,
"docstring": "C3 module with GhostBottleneck().",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"C3"
],
"chunk_id": "class_C3Ghost_42be1069"
},
{
"content": "class GhostBottleneck(nn.Module):\n \"\"\"Ghost Bottleneck https://github.com/huawei-noah/Efficient-AI-Backbones.\"\"\"\n\n def __init__(self, c1: int, c2: int, k: int = 3, s: int = 1):\n \"\"\"\n Initialize Ghost Bottleneck module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n k (int): Kernel size.\n s (int): Stride.\n \"\"\"\n super().__init__()\n c_ = c2 // 2\n self.conv = nn.Sequential(\n GhostConv(c1, c_, 1, 1), # pw\n DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw\n GhostConv(c_, c2, 1, 1, act=False), # pw-linear\n )\n self.shortcut = (\n nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()\n )\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply skip connection and concatenation to input tensor.\"\"\"\n return self.conv(x) + self.shortcut(x)",
"chunk_type": "class",
"name": "GhostBottleneck",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 441,
"end_line": 467,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": "Ghost Bottleneck https://github.com/huawei-noah/Efficient-AI-Backbones.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_GhostBottleneck_87d94d85"
},
{
"content": "class Bottleneck(nn.Module):\n \"\"\"Standard bottleneck.\"\"\"\n\n def __init__(\n self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: Tuple[int, int] = (3, 3), e: float = 0.5\n ):\n \"\"\"\n Initialize a standard bottleneck module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n shortcut (bool): Whether to use shortcut connection.\n g (int): Groups for convolutions.\n k (tuple): Kernel sizes for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n c_ = int(c2 * e) # hidden channels\n self.cv1 = Conv(c1, c_, k[0], 1)\n self.cv2 = Conv(c_, c2, k[1], 1, g=g)\n self.add = shortcut and c1 == c2\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply bottleneck with optional shortcut connection.\"\"\"\n return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))",
"chunk_type": "class",
"name": "Bottleneck",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 470,
"end_line": 495,
"start_col": 0,
"end_col": 79,
"parent_name": null,
"docstring": "Standard bottleneck.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_Bottleneck_b6c60a65"
},
{
"content": "class BottleneckCSP(nn.Module):\n \"\"\"CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):\n \"\"\"\n Initialize CSP Bottleneck.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Bottleneck blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n c_ = int(c2 * e) # hidden channels\n self.cv1 = Conv(c1, c_, 1, 1)\n self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)\n self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)\n self.cv4 = Conv(2 * c_, c2, 1, 1)\n self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)\n self.act = nn.SiLU()\n self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply CSP bottleneck with 3 convolutions.\"\"\"\n y1 = self.cv3(self.m(self.cv1(x)))\n y2 = self.cv2(x)\n return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))",
"chunk_type": "class",
"name": "BottleneckCSP",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 498,
"end_line": 527,
"start_col": 0,
"end_col": 66,
"parent_name": null,
"docstring": "CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_BottleneckCSP_d85467ed"
},
{
"content": "class ResNetBlock(nn.Module):\n \"\"\"ResNet block with standard convolution layers.\"\"\"\n\n def __init__(self, c1: int, c2: int, s: int = 1, e: int = 4):\n \"\"\"\n Initialize ResNet block.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n s (int): Stride.\n e (int): Expansion ratio.\n \"\"\"\n super().__init__()\n c3 = e * c2\n self.cv1 = Conv(c1, c2, k=1, s=1, act=True)\n self.cv2 = Conv(c2, c2, k=3, s=s, p=1, act=True)\n self.cv3 = Conv(c2, c3, k=1, act=False)\n self.shortcut = nn.Sequential(Conv(c1, c3, k=1, s=s, act=False)) if s != 1 or c1 != c3 else nn.Identity()\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through the ResNet block.\"\"\"\n return F.relu(self.cv3(self.cv2(self.cv1(x))) + self.shortcut(x))",
"chunk_type": "class",
"name": "ResNetBlock",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 530,
"end_line": 552,
"start_col": 0,
"end_col": 73,
"parent_name": null,
"docstring": "ResNet block with standard convolution layers.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_ResNetBlock_0bc30168"
},
{
"content": "class ResNetLayer(nn.Module):\n \"\"\"ResNet layer with multiple ResNet blocks.\"\"\"\n\n def __init__(self, c1: int, c2: int, s: int = 1, is_first: bool = False, n: int = 1, e: int = 4):\n \"\"\"\n Initialize ResNet layer.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n s (int): Stride.\n is_first (bool): Whether this is the first layer.\n n (int): Number of ResNet blocks.\n e (int): Expansion ratio.\n \"\"\"\n super().__init__()\n self.is_first = is_first\n\n if self.is_first:\n self.layer = nn.Sequential(\n Conv(c1, c2, k=7, s=2, p=3, act=True), nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n )\n else:\n blocks = [ResNetBlock(c1, c2, s, e=e)]\n blocks.extend([ResNetBlock(e * c2, c2, 1, e=e) for _ in range(n - 1)])\n self.layer = nn.Sequential(*blocks)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through the ResNet layer.\"\"\"\n return self.layer(x)",
"chunk_type": "class",
"name": "ResNetLayer",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 555,
"end_line": 584,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": "ResNet layer with multiple ResNet blocks.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_ResNetLayer_07394af0"
},
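Illustrative pairing of the two ResNet chunks above: an is_first layer acts as a /4 stem, while a regular stage stacks n ResNetBlocks and expands channels by e.

    import torch
    from ultralytics.nn.modules.block import ResNetLayer

    stem = ResNetLayer(3, 64, is_first=True)        # 7x7 s2 conv + 3x3 s2 max pool -> /4
    stage = ResNetLayer(64, 64, s=1, n=3, e=4)      # 3 ResNetBlocks, expansion 4 -> 256 channels
    x = torch.randn(1, 3, 224, 224)
    print(stage(stem(x)).shape)                     # torch.Size([1, 256, 56, 56])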
{
"content": "class MaxSigmoidAttnBlock(nn.Module):\n \"\"\"Max Sigmoid attention block.\"\"\"\n\n def __init__(self, c1: int, c2: int, nh: int = 1, ec: int = 128, gc: int = 512, scale: bool = False):\n \"\"\"\n Initialize MaxSigmoidAttnBlock.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n nh (int): Number of heads.\n ec (int): Embedding channels.\n gc (int): Guide channels.\n scale (bool): Whether to use learnable scale parameter.\n \"\"\"\n super().__init__()\n self.nh = nh\n self.hc = c2 // nh\n self.ec = Conv(c1, ec, k=1, act=False) if c1 != ec else None\n self.gl = nn.Linear(gc, ec)\n self.bias = nn.Parameter(torch.zeros(nh))\n self.proj_conv = Conv(c1, c2, k=3, s=1, act=False)\n self.scale = nn.Parameter(torch.ones(1, nh, 1, 1)) if scale else 1.0\n\n def forward(self, x: torch.Tensor, guide: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass of MaxSigmoidAttnBlock.\n\n Args:\n x (torch.Tensor): Input tensor.\n guide (torch.Tensor): Guide tensor.\n\n Returns:\n (torch.Tensor): Output tensor after attention.\n \"\"\"\n bs, _, h, w = x.shape\n\n guide = self.gl(guide)\n guide = guide.view(bs, guide.shape[1], self.nh, self.hc)\n embed = self.ec(x) if self.ec is not None else x\n embed = embed.view(bs, self.nh, self.hc, h, w)\n\n aw = torch.einsum(\"bmchw,bnmc->bmhwn\", embed, guide)\n aw = aw.max(dim=-1)[0]\n aw = aw / (self.hc**0.5)\n aw = aw + self.bias[None, :, None, None]\n aw = aw.sigmoid() * self.scale\n\n x = self.proj_conv(x)\n x = x.view(bs, self.nh, -1, h, w)\n x = x * aw.unsqueeze(2)\n return x.view(bs, -1, h, w)",
"chunk_type": "class",
"name": "MaxSigmoidAttnBlock",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 587,
"end_line": 638,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": "Max Sigmoid attention block.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_MaxSigmoidAttnBlock_e6b84d8e"
},
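Shape sketch for MaxSigmoidAttnBlock: guide is a batch of text/guide embeddings of width gc, and for the internal reshapes to line up, ec should equal nh * (c2 // nh) (here 256 = 8 * 32). The tensor sizes below are illustrative, not taken from any model config.

    import torch
    from ultralytics.nn.modules.block import MaxSigmoidAttnBlock

    m = MaxSigmoidAttnBlock(c1=256, c2=256, nh=8, ec=256, gc=512)
    x = torch.randn(1, 256, 20, 20)   # image features
    guide = torch.randn(1, 4, 512)    # e.g. 4 guide embeddings of width gc
    print(m(x, guide).shape)          # torch.Size([1, 256, 20, 20])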
{
"content": "class C2fAttn(nn.Module):\n \"\"\"C2f module with an additional attn module.\"\"\"\n\n def __init__(\n self,\n c1: int,\n c2: int,\n n: int = 1,\n ec: int = 128,\n nh: int = 1,\n gc: int = 512,\n shortcut: bool = False,\n g: int = 1,\n e: float = 0.5,\n ):\n \"\"\"\n Initialize C2f module with attention mechanism.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Bottleneck blocks.\n ec (int): Embedding channels for attention.\n nh (int): Number of heads for attention.\n gc (int): Guide channels for attention.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n self.c = int(c2 * e) # hidden channels\n self.cv1 = Conv(c1, 2 * self.c, 1, 1)\n self.cv2 = Conv((3 + n) * self.c, c2, 1) # optional act=FReLU(c2)\n self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))\n self.attn = MaxSigmoidAttnBlock(self.c, self.c, gc=gc, ec=ec, nh=nh)\n\n def forward(self, x: torch.Tensor, guide: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass through C2f layer with attention.\n\n Args:\n x (torch.Tensor): Input tensor.\n guide (torch.Tensor): Guide tensor for attention.\n\n Returns:\n (torch.Tensor): Output tensor after processing.\n \"\"\"\n y = list(self.cv1(x).chunk(2, 1))\n y.extend(m(y[-1]) for m in self.m)\n y.append(self.attn(y[-1], guide))\n return self.cv2(torch.cat(y, 1))\n\n def forward_split(self, x: torch.Tensor, guide: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass using split() instead of chunk().\n\n Args:\n x (torch.Tensor): Input tensor.\n guide (torch.Tensor): Guide tensor for attention.\n\n Returns:\n (torch.Tensor): Output tensor after processing.\n \"\"\"\n y = list(self.cv1(x).split((self.c, self.c), 1))\n y.extend(m(y[-1]) for m in self.m)\n y.append(self.attn(y[-1], guide))\n return self.cv2(torch.cat(y, 1))",
"chunk_type": "class",
"name": "C2fAttn",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 641,
"end_line": 707,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": "C2f module with an additional attn module.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_C2fAttn_77c66cdf"
},
{
"content": "class ImagePoolingAttn(nn.Module):\n \"\"\"ImagePoolingAttn: Enhance the text embeddings with image-aware information.\"\"\"\n\n def __init__(\n self, ec: int = 256, ch: Tuple[int, ...] = (), ct: int = 512, nh: int = 8, k: int = 3, scale: bool = False\n ):\n \"\"\"\n Initialize ImagePoolingAttn module.\n\n Args:\n ec (int): Embedding channels.\n ch (tuple): Channel dimensions for feature maps.\n ct (int): Channel dimension for text embeddings.\n nh (int): Number of attention heads.\n k (int): Kernel size for pooling.\n scale (bool): Whether to use learnable scale parameter.\n \"\"\"\n super().__init__()\n\n nf = len(ch)\n self.query = nn.Sequential(nn.LayerNorm(ct), nn.Linear(ct, ec))\n self.key = nn.Sequential(nn.LayerNorm(ec), nn.Linear(ec, ec))\n self.value = nn.Sequential(nn.LayerNorm(ec), nn.Linear(ec, ec))\n self.proj = nn.Linear(ec, ct)\n self.scale = nn.Parameter(torch.tensor([0.0]), requires_grad=True) if scale else 1.0\n self.projections = nn.ModuleList([nn.Conv2d(in_channels, ec, kernel_size=1) for in_channels in ch])\n self.im_pools = nn.ModuleList([nn.AdaptiveMaxPool2d((k, k)) for _ in range(nf)])\n self.ec = ec\n self.nh = nh\n self.nf = nf\n self.hc = ec // nh\n self.k = k\n\n def forward(self, x: List[torch.Tensor], text: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass of ImagePoolingAttn.\n\n Args:\n x (List[torch.Tensor]): List of input feature maps.\n text (torch.Tensor): Text embeddings.\n\n Returns:\n (torch.Tensor): Enhanced text embeddings.\n \"\"\"\n bs = x[0].shape[0]\n assert len(x) == self.nf\n num_patches = self.k**2\n x = [pool(proj(x)).view(bs, -1, num_patches) for (x, proj, pool) in zip(x, self.projections, self.im_pools)]\n x = torch.cat(x, dim=-1).transpose(1, 2)\n q = self.query(text)\n k = self.key(x)\n v = self.value(x)\n\n # q = q.reshape(1, text.shape[1], self.nh, self.hc).repeat(bs, 1, 1, 1)\n q = q.reshape(bs, -1, self.nh, self.hc)\n k = k.reshape(bs, -1, self.nh, self.hc)\n v = v.reshape(bs, -1, self.nh, self.hc)\n\n aw = torch.einsum(\"bnmc,bkmc->bmnk\", q, k)\n aw = aw / (self.hc**0.5)\n aw = F.softmax(aw, dim=-1)\n\n x = torch.einsum(\"bmnk,bkmc->bnmc\", aw, v)\n x = self.proj(x.reshape(bs, -1, self.ec))\n return x * self.scale + text",
"chunk_type": "class",
"name": "ImagePoolingAttn",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 710,
"end_line": 774,
"start_col": 0,
"end_col": 36,
"parent_name": null,
"docstring": "ImagePoolingAttn: Enhance the text embeddings with image-aware information.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_ImagePoolingAttn_7cec88fc"
},
{
"content": "class ContrastiveHead(nn.Module):\n \"\"\"Implements contrastive learning head for region-text similarity in vision-language models.\"\"\"\n\n def __init__(self):\n \"\"\"Initialize ContrastiveHead with region-text similarity parameters.\"\"\"\n super().__init__()\n # NOTE: use -10.0 to keep the init cls loss consistency with other losses\n self.bias = nn.Parameter(torch.tensor([-10.0]))\n self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())\n\n def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward function of contrastive learning.\n\n Args:\n x (torch.Tensor): Image features.\n w (torch.Tensor): Text features.\n\n Returns:\n (torch.Tensor): Similarity scores.\n \"\"\"\n x = F.normalize(x, dim=1, p=2)\n w = F.normalize(w, dim=-1, p=2)\n x = torch.einsum(\"bchw,bkc->bkhw\", x, w)\n return x * self.logit_scale.exp() + self.bias",
"chunk_type": "class",
"name": "ContrastiveHead",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 777,
"end_line": 801,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": "Implements contrastive learning head for region-text similarity in vision-language models.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_ContrastiveHead_c63f9e5c"
},
{
"content": "class BNContrastiveHead(nn.Module):\n \"\"\"\n Batch Norm Contrastive Head using batch norm instead of l2-normalization.\n\n Args:\n embed_dims (int): Embed dimensions of text and image features.\n \"\"\"\n\n def __init__(self, embed_dims: int):\n \"\"\"\n Initialize BNContrastiveHead.\n\n Args:\n embed_dims (int): Embedding dimensions for features.\n \"\"\"\n super().__init__()\n self.norm = nn.BatchNorm2d(embed_dims)\n # NOTE: use -10.0 to keep the init cls loss consistency with other losses\n self.bias = nn.Parameter(torch.tensor([-10.0]))\n # use -1.0 is more stable\n self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))\n\n def fuse(self):\n \"\"\"Fuse the batch normalization layer in the BNContrastiveHead module.\"\"\"\n del self.norm\n del self.bias\n del self.logit_scale\n self.forward = self.forward_fuse\n\n def forward_fuse(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:\n \"\"\"Passes input out unchanged.\"\"\"\n return x\n\n def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward function of contrastive learning with batch normalization.\n\n Args:\n x (torch.Tensor): Image features.\n w (torch.Tensor): Text features.\n\n Returns:\n (torch.Tensor): Similarity scores.\n \"\"\"\n x = self.norm(x)\n w = F.normalize(w, dim=-1, p=2)\n\n x = torch.einsum(\"bchw,bkc->bkhw\", x, w)\n return x * self.logit_scale.exp() + self.bias",
"chunk_type": "class",
"name": "BNContrastiveHead",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 804,
"end_line": 852,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": "Batch Norm Contrastive Head using batch norm instead of l2-normalization.\n\nArgs:\n embed_dims (int): Embed dimensions of text and image features.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_BNContrastiveHead_523689c4"
},
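Both contrastive heads score image features against text features and return one similarity map per text embedding; a minimal shape check, same import assumption as above:

    import torch
    from ultralytics.nn.modules.block import ContrastiveHead, BNContrastiveHead

    x = torch.randn(2, 256, 20, 20)   # region/image features (B, C, H, W)
    w = torch.randn(2, 80, 256)       # text features (B, num_texts, C)
    print(ContrastiveHead()(x, w).shape)       # torch.Size([2, 80, 20, 20])
    print(BNContrastiveHead(256)(x, w).shape)  # torch.Size([2, 80, 20, 20])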
{
"content": "class RepBottleneck(Bottleneck):\n \"\"\"Rep bottleneck.\"\"\"\n\n def __init__(\n self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: Tuple[int, int] = (3, 3), e: float = 0.5\n ):\n \"\"\"\n Initialize RepBottleneck.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n shortcut (bool): Whether to use shortcut connection.\n g (int): Groups for convolutions.\n k (tuple): Kernel sizes for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__(c1, c2, shortcut, g, k, e)\n c_ = int(c2 * e) # hidden channels\n self.cv1 = RepConv(c1, c_, k[0], 1)",
"chunk_type": "class",
"name": "RepBottleneck",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 855,
"end_line": 874,
"start_col": 0,
"end_col": 43,
"parent_name": null,
"docstring": "Rep bottleneck.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"Bottleneck"
],
"chunk_id": "class_RepBottleneck_872e277d"
},
{
"content": "class RepCSP(C3):\n \"\"\"Repeatable Cross Stage Partial Network (RepCSP) module for efficient feature extraction.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5):\n \"\"\"\n Initialize RepCSP layer.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of RepBottleneck blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__(c1, c2, n, shortcut, g, e)\n c_ = int(c2 * e) # hidden channels\n self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))",
"chunk_type": "class",
"name": "RepCSP",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 877,
"end_line": 894,
"start_col": 0,
"end_col": 94,
"parent_name": null,
"docstring": "Repeatable Cross Stage Partial Network (RepCSP) module for efficient feature extraction.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"C3"
],
"chunk_id": "class_RepCSP_44069c52"
},
{
"content": "class RepNCSPELAN4(nn.Module):\n \"\"\"CSP-ELAN.\"\"\"\n\n def __init__(self, c1: int, c2: int, c3: int, c4: int, n: int = 1):\n \"\"\"\n Initialize CSP-ELAN layer.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n c3 (int): Intermediate channels.\n c4 (int): Intermediate channels for RepCSP.\n n (int): Number of RepCSP blocks.\n \"\"\"\n super().__init__()\n self.c = c3 // 2\n self.cv1 = Conv(c1, c3, 1, 1)\n self.cv2 = nn.Sequential(RepCSP(c3 // 2, c4, n), Conv(c4, c4, 3, 1))\n self.cv3 = nn.Sequential(RepCSP(c4, c4, n), Conv(c4, c4, 3, 1))\n self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through RepNCSPELAN4 layer.\"\"\"\n y = list(self.cv1(x).chunk(2, 1))\n y.extend((m(y[-1])) for m in [self.cv2, self.cv3])\n return self.cv4(torch.cat(y, 1))\n\n def forward_split(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass using split() instead of chunk().\"\"\"\n y = list(self.cv1(x).split((self.c, self.c), 1))\n y.extend(m(y[-1]) for m in [self.cv2, self.cv3])\n return self.cv4(torch.cat(y, 1))",
"chunk_type": "class",
"name": "RepNCSPELAN4",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 897,
"end_line": 928,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": "CSP-ELAN.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_RepNCSPELAN4_6f8bb196"
},
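RepNCSPELAN4 shape sketch: the two RepCSP branches each emit c4 channels, so cv4 maps c3 + 2*c4 concatenated channels to c2.

    import torch
    from ultralytics.nn.modules.block import RepNCSPELAN4

    m = RepNCSPELAN4(c1=64, c2=128, c3=64, c4=32, n=1)
    x = torch.randn(1, 64, 40, 40)
    print(m(x).shape)   # torch.Size([1, 128, 40, 40]); cv4 input is 64 + 2*32 = 128 channels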
{
"content": "class ELAN1(RepNCSPELAN4):\n \"\"\"ELAN1 module with 4 convolutions.\"\"\"\n\n def __init__(self, c1: int, c2: int, c3: int, c4: int):\n \"\"\"\n Initialize ELAN1 layer.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n c3 (int): Intermediate channels.\n c4 (int): Intermediate channels for convolutions.\n \"\"\"\n super().__init__(c1, c2, c3, c4)\n self.c = c3 // 2\n self.cv1 = Conv(c1, c3, 1, 1)\n self.cv2 = Conv(c3 // 2, c4, 3, 1)\n self.cv3 = Conv(c4, c4, 3, 1)\n self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1)",
"chunk_type": "class",
"name": "ELAN1",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 931,
"end_line": 949,
"start_col": 0,
"end_col": 48,
"parent_name": null,
"docstring": "ELAN1 module with 4 convolutions.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"RepNCSPELAN4"
],
"chunk_id": "class_ELAN1_988f4e6d"
},
{
"content": "class AConv(nn.Module):\n \"\"\"AConv.\"\"\"\n\n def __init__(self, c1: int, c2: int):\n \"\"\"\n Initialize AConv module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n \"\"\"\n super().__init__()\n self.cv1 = Conv(c1, c2, 3, 2, 1)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through AConv layer.\"\"\"\n x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)\n return self.cv1(x)",
"chunk_type": "class",
"name": "AConv",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 952,
"end_line": 969,
"start_col": 0,
"end_col": 26,
"parent_name": null,
"docstring": "AConv.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_AConv_6ac2b074"
},
{
"content": "class ADown(nn.Module):\n \"\"\"ADown.\"\"\"\n\n def __init__(self, c1: int, c2: int):\n \"\"\"\n Initialize ADown module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n \"\"\"\n super().__init__()\n self.c = c2 // 2\n self.cv1 = Conv(c1 // 2, self.c, 3, 2, 1)\n self.cv2 = Conv(c1 // 2, self.c, 1, 1, 0)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through ADown layer.\"\"\"\n x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)\n x1, x2 = x.chunk(2, 1)\n x1 = self.cv1(x1)\n x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)\n x2 = self.cv2(x2)\n return torch.cat((x1, x2), 1)",
"chunk_type": "class",
"name": "ADown",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 972,
"end_line": 995,
"start_col": 0,
"end_col": 37,
"parent_name": null,
"docstring": "ADown.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_ADown_d418e88d"
},
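AConv and ADown are both downsamplers (a stride-1 average pool followed by stride-2 paths), so each roughly halves the spatial resolution:

    import torch
    from ultralytics.nn.modules.block import AConv, ADown

    x = torch.randn(1, 64, 64, 64)
    print(AConv(64, 128)(x).shape)   # torch.Size([1, 128, 32, 32])
    print(ADown(64, 128)(x).shape)   # torch.Size([1, 128, 32, 32])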
{
"content": "class SPPELAN(nn.Module):\n \"\"\"SPP-ELAN.\"\"\"\n\n def __init__(self, c1: int, c2: int, c3: int, k: int = 5):\n \"\"\"\n Initialize SPP-ELAN block.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n c3 (int): Intermediate channels.\n k (int): Kernel size for max pooling.\n \"\"\"\n super().__init__()\n self.c = c3\n self.cv1 = Conv(c1, c3, 1, 1)\n self.cv2 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)\n self.cv3 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)\n self.cv4 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)\n self.cv5 = Conv(4 * c3, c2, 1, 1)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through SPPELAN layer.\"\"\"\n y = [self.cv1(x)]\n y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4])\n return self.cv5(torch.cat(y, 1))",
"chunk_type": "class",
"name": "SPPELAN",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 998,
"end_line": 1023,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": "SPP-ELAN.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_SPPELAN_54c70ab7"
},
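SPPELAN chains three k x k max-pools at stride 1, so spatial size is preserved while the pooled features are concatenated and projected to c2:

    import torch
    from ultralytics.nn.modules.block import SPPELAN

    m = SPPELAN(c1=256, c2=256, c3=128, k=5)
    x = torch.randn(1, 256, 20, 20)
    print(m(x).shape)   # torch.Size([1, 256, 20, 20]); cv5 sees 4 * c3 = 512 channels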
{
"content": "class CBLinear(nn.Module):\n \"\"\"CBLinear.\"\"\"\n\n def __init__(self, c1: int, c2s: List[int], k: int = 1, s: int = 1, p: Optional[int] = None, g: int = 1):\n \"\"\"\n Initialize CBLinear module.\n\n Args:\n c1 (int): Input channels.\n c2s (List[int]): List of output channel sizes.\n k (int): Kernel size.\n s (int): Stride.\n p (int | None): Padding.\n g (int): Groups.\n \"\"\"\n super().__init__()\n self.c2s = c2s\n self.conv = nn.Conv2d(c1, sum(c2s), k, s, autopad(k, p), groups=g, bias=True)\n\n def forward(self, x: torch.Tensor) -> List[torch.Tensor]:\n \"\"\"Forward pass through CBLinear layer.\"\"\"\n return self.conv(x).split(self.c2s, dim=1)",
"chunk_type": "class",
"name": "CBLinear",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1026,
"end_line": 1047,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": "CBLinear.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_CBLinear_8170ffd3"
},
{
"content": "class CBFuse(nn.Module):\n \"\"\"CBFuse.\"\"\"\n\n def __init__(self, idx: List[int]):\n \"\"\"\n Initialize CBFuse module.\n\n Args:\n idx (List[int]): Indices for feature selection.\n \"\"\"\n super().__init__()\n self.idx = idx\n\n def forward(self, xs: List[torch.Tensor]) -> torch.Tensor:\n \"\"\"\n Forward pass through CBFuse layer.\n\n Args:\n xs (List[torch.Tensor]): List of input tensors.\n\n Returns:\n (torch.Tensor): Fused output tensor.\n \"\"\"\n target_size = xs[-1].shape[2:]\n res = [F.interpolate(x[self.idx[i]], size=target_size, mode=\"nearest\") for i, x in enumerate(xs[:-1])]\n return torch.sum(torch.stack(res + xs[-1:]), dim=0)",
"chunk_type": "class",
"name": "CBFuse",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1050,
"end_line": 1075,
"start_col": 0,
"end_col": 59,
"parent_name": null,
"docstring": "CBFuse.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_CBFuse_36c0c305"
},
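CBLinear and CBFuse are meant to be used together: CBLinear splits one feature map into per-branch channel groups, and CBFuse picks one group per source, resizes it to the last tensor's size, and sums. The 20x20 tensor below is an illustrative placeholder for the receiving branch.

    import torch
    from ultralytics.nn.modules.block import CBLinear, CBFuse

    lin = CBLinear(c1=256, c2s=[64, 128])
    feats = lin(torch.randn(1, 256, 40, 40))         # splits: (1, 64, 40, 40) and (1, 128, 40, 40)
    fuse = CBFuse(idx=[1])                           # take the 128-channel split from the first source
    y = fuse([feats, torch.randn(1, 128, 20, 20)])   # last element sets the target size
    print(y.shape)                                   # torch.Size([1, 128, 20, 20])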
{
"content": "class C3f(nn.Module):\n \"\"\"Faster Implementation of CSP Bottleneck with 2 convolutions.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = False, g: int = 1, e: float = 0.5):\n \"\"\"\n Initialize CSP bottleneck layer with two convolutions.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Bottleneck blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n c_ = int(c2 * e) # hidden channels\n self.cv1 = Conv(c1, c_, 1, 1)\n self.cv2 = Conv(c1, c_, 1, 1)\n self.cv3 = Conv((2 + n) * c_, c2, 1) # optional act=FReLU(c2)\n self.m = nn.ModuleList(Bottleneck(c_, c_, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Forward pass through C3f layer.\"\"\"\n y = [self.cv2(x), self.cv1(x)]\n y.extend(m(y[-1]) for m in self.m)\n return self.cv3(torch.cat(y, 1))",
"chunk_type": "class",
"name": "C3f",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1078,
"end_line": 1104,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": "Faster Implementation of CSP Bottleneck with 2 convolutions.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_C3f_4070fa33"
},
{
"content": "class C3k2(C2f):\n \"\"\"Faster Implementation of CSP Bottleneck with 2 convolutions.\"\"\"\n\n def __init__(\n self, c1: int, c2: int, n: int = 1, c3k: bool = False, e: float = 0.5, g: int = 1, shortcut: bool = True\n ):\n \"\"\"\n Initialize C3k2 module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of blocks.\n c3k (bool): Whether to use C3k blocks.\n e (float): Expansion ratio.\n g (int): Groups for convolutions.\n shortcut (bool): Whether to use shortcut connections.\n \"\"\"\n super().__init__(c1, c2, n, shortcut, g, e)\n self.m = nn.ModuleList(\n C3k(self.c, self.c, 2, shortcut, g) if c3k else Bottleneck(self.c, self.c, shortcut, g) for _ in range(n)\n )",
"chunk_type": "class",
"name": "C3k2",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1107,
"end_line": 1128,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "Faster Implementation of CSP Bottleneck with 2 convolutions.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"C2f"
],
"chunk_id": "class_C3k2_37260c67"
},
{
"content": "class C3k(C3):\n \"\"\"C3k is a CSP bottleneck module with customizable kernel sizes for feature extraction in neural networks.\"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, shortcut: bool = True, g: int = 1, e: float = 0.5, k: int = 3):\n \"\"\"\n Initialize C3k module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of Bottleneck blocks.\n shortcut (bool): Whether to use shortcut connections.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n k (int): Kernel size.\n \"\"\"\n super().__init__(c1, c2, n, shortcut, g, e)\n c_ = int(c2 * e) # hidden channels\n # self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))\n self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))",
"chunk_type": "class",
"name": "C3k",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1131,
"end_line": 1150,
"start_col": 0,
"end_col": 101,
"parent_name": null,
"docstring": "C3k is a CSP bottleneck module with customizable kernel sizes for feature extraction in neural networks.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"C3"
],
"chunk_id": "class_C3k_f04eef9c"
},
{
"content": "class RepVGGDW(torch.nn.Module):\n \"\"\"RepVGGDW is a class that represents a depth wise separable convolutional block in RepVGG architecture.\"\"\"\n\n def __init__(self, ed: int) -> None:\n \"\"\"\n Initialize RepVGGDW module.\n\n Args:\n ed (int): Input and output channels.\n \"\"\"\n super().__init__()\n self.conv = Conv(ed, ed, 7, 1, 3, g=ed, act=False)\n self.conv1 = Conv(ed, ed, 3, 1, 1, g=ed, act=False)\n self.dim = ed\n self.act = nn.SiLU()\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Perform a forward pass of the RepVGGDW block.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after applying the depth wise separable convolution.\n \"\"\"\n return self.act(self.conv(x) + self.conv1(x))\n\n def forward_fuse(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Perform a forward pass of the RepVGGDW block without fusing the convolutions.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after applying the depth wise separable convolution.\n \"\"\"\n return self.act(self.conv(x))\n\n @torch.no_grad()\n def fuse(self):\n \"\"\"\n Fuse the convolutional layers in the RepVGGDW block.\n\n This method fuses the convolutional layers and updates the weights and biases accordingly.\n \"\"\"\n conv = fuse_conv_and_bn(self.conv.conv, self.conv.bn)\n conv1 = fuse_conv_and_bn(self.conv1.conv, self.conv1.bn)\n\n conv_w = conv.weight\n conv_b = conv.bias\n conv1_w = conv1.weight\n conv1_b = conv1.bias\n\n conv1_w = torch.nn.functional.pad(conv1_w, [2, 2, 2, 2])\n\n final_conv_w = conv_w + conv1_w\n final_conv_b = conv_b + conv1_b\n\n conv.weight.data.copy_(final_conv_w)\n conv.bias.data.copy_(final_conv_b)\n\n self.conv = conv\n del self.conv1",
"chunk_type": "class",
"name": "RepVGGDW",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1153,
"end_line": 1217,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": "RepVGGDW is a class that represents a depth wise separable convolutional block in RepVGG architecture.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"torch.nn.Module"
],
"chunk_id": "class_RepVGGDW_87a6072f"
},
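RepVGGDW's two depthwise branches (7x7 and 3x3, each with BN) fold into a single 7x7 conv via fuse(); with BatchNorm in eval mode the fused path should match the unfused one numerically (the tolerance below is an assumption, not a documented guarantee).

    import torch
    from ultralytics.nn.modules.block import RepVGGDW

    m = RepVGGDW(64).eval()        # eval() so BatchNorm uses running statistics
    x = torch.randn(1, 64, 32, 32)
    y = m(x)                       # sum of the 7x7 and 3x3 depthwise branches
    m.fuse()                       # folds both branches (and their BN) into one 7x7 conv
    y_fused = m.forward_fuse(x)    # single-branch inference path used after fusing
    print(torch.allclose(y, y_fused, atol=1e-5), y_fused.shape)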
{
"content": "class CIB(nn.Module):\n \"\"\"\n Conditional Identity Block (CIB) module.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n shortcut (bool, optional): Whether to add a shortcut connection. Defaults to True.\n e (float, optional): Scaling factor for the hidden channels. Defaults to 0.5.\n lk (bool, optional): Whether to use RepVGGDW for the third convolutional layer. Defaults to False.\n \"\"\"\n\n def __init__(self, c1: int, c2: int, shortcut: bool = True, e: float = 0.5, lk: bool = False):\n \"\"\"\n Initialize the CIB module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n shortcut (bool): Whether to use shortcut connection.\n e (float): Expansion ratio.\n lk (bool): Whether to use RepVGGDW.\n \"\"\"\n super().__init__()\n c_ = int(c2 * e) # hidden channels\n self.cv1 = nn.Sequential(\n Conv(c1, c1, 3, g=c1),\n Conv(c1, 2 * c_, 1),\n RepVGGDW(2 * c_) if lk else Conv(2 * c_, 2 * c_, 3, g=2 * c_),\n Conv(2 * c_, c2, 1),\n Conv(c2, c2, 3, g=c2),\n )\n\n self.add = shortcut and c1 == c2\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass of the CIB module.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return x + self.cv1(x) if self.add else self.cv1(x)",
"chunk_type": "class",
"name": "CIB",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1220,
"end_line": 1265,
"start_col": 0,
"end_col": 59,
"parent_name": null,
"docstring": "Conditional Identity Block (CIB) module.\n\nArgs:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n shortcut (bool, optional): Whether to add a shortcut connection. Defaults to True.\n e (float, optional): Scaling factor for the hidden channels. Defaults to 0.5.\n lk (bool, optional): Whether to use RepVGGDW for the third convolutional layer. Defaults to False.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_CIB_348da215"
},
{
"content": "class C2fCIB(C2f):\n \"\"\"\n C2fCIB class represents a convolutional block with C2f and CIB modules.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n n (int, optional): Number of CIB modules to stack. Defaults to 1.\n shortcut (bool, optional): Whether to use shortcut connection. Defaults to False.\n lk (bool, optional): Whether to use local key connection. Defaults to False.\n g (int, optional): Number of groups for grouped convolution. Defaults to 1.\n e (float, optional): Expansion ratio for CIB modules. Defaults to 0.5.\n \"\"\"\n\n def __init__(\n self, c1: int, c2: int, n: int = 1, shortcut: bool = False, lk: bool = False, g: int = 1, e: float = 0.5\n ):\n \"\"\"\n Initialize C2fCIB module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of CIB modules.\n shortcut (bool): Whether to use shortcut connection.\n lk (bool): Whether to use local key connection.\n g (int): Groups for convolutions.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__(c1, c2, n, shortcut, g, e)\n self.m = nn.ModuleList(CIB(self.c, self.c, shortcut, e=1.0, lk=lk) for _ in range(n))",
"chunk_type": "class",
"name": "C2fCIB",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1268,
"end_line": 1298,
"start_col": 0,
"end_col": 93,
"parent_name": null,
"docstring": "C2fCIB class represents a convolutional block with C2f and CIB modules.\n\nArgs:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n n (int, optional): Number of CIB modules to stack. Defaults to 1.\n shortcut (bool, optional): Whether to use shortcut connection. Defaults to False.\n lk (bool, optional): Whether to use local key connection. Defaults to False.\n g (int, optional): Number of groups for grouped convolution. Defaults to 1.\n e (float, optional): Expansion ratio for CIB modules. Defaults to 0.5.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"C2f"
],
"chunk_id": "class_C2fCIB_ddc7b517"
},
{
"content": "class Attention(nn.Module):\n \"\"\"\n Attention module that performs self-attention on the input tensor.\n\n Args:\n dim (int): The input tensor dimension.\n num_heads (int): The number of attention heads.\n attn_ratio (float): The ratio of the attention key dimension to the head dimension.\n\n Attributes:\n num_heads (int): The number of attention heads.\n head_dim (int): The dimension of each attention head.\n key_dim (int): The dimension of the attention key.\n scale (float): The scaling factor for the attention scores.\n qkv (Conv): Convolutional layer for computing the query, key, and value.\n proj (Conv): Convolutional layer for projecting the attended values.\n pe (Conv): Convolutional layer for positional encoding.\n \"\"\"\n\n def __init__(self, dim: int, num_heads: int = 8, attn_ratio: float = 0.5):\n \"\"\"\n Initialize multi-head attention module.\n\n Args:\n dim (int): Input dimension.\n num_heads (int): Number of attention heads.\n attn_ratio (float): Attention ratio for key dimension.\n \"\"\"\n super().__init__()\n self.num_heads = num_heads\n self.head_dim = dim // num_heads\n self.key_dim = int(self.head_dim * attn_ratio)\n self.scale = self.key_dim**-0.5\n nh_kd = self.key_dim * num_heads\n h = dim + nh_kd * 2\n self.qkv = Conv(dim, h, 1, act=False)\n self.proj = Conv(dim, dim, 1, act=False)\n self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass of the Attention module.\n\n Args:\n x (torch.Tensor): The input tensor.\n\n Returns:\n (torch.Tensor): The output tensor after self-attention.\n \"\"\"\n B, C, H, W = x.shape\n N = H * W\n qkv = self.qkv(x)\n q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split(\n [self.key_dim, self.key_dim, self.head_dim], dim=2\n )\n\n attn = (q.transpose(-2, -1) @ k) * self.scale\n attn = attn.softmax(dim=-1)\n x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))\n x = self.proj(x)\n return x",
"chunk_type": "class",
"name": "Attention",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1301,
"end_line": 1361,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "Attention module that performs self-attention on the input tensor.\n\nArgs:\n dim (int): The input tensor dimension.\n num_heads (int): The number of attention heads.\n attn_ratio (float): The ratio of the attention key dimension to the head dimension.\n\nAttributes:\n num_heads (int): The number of attention heads.\n head_dim (int): The dimension of each attention head.\n key_dim (int): The dimension of the attention key.\n scale (float): The scaling factor for the attention scores.\n qkv (Conv): Convolutional layer for computing the query, key, and value.\n proj (Conv): Convolutional layer for projecting the attended values.\n pe (Conv): Convolutional layer for positional encoding.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_Attention_0b70783b"
},
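Shape check for the Attention block: with dim=256, num_heads=8 and attn_ratio=0.5, each head uses a 32-wide value and a 16-wide key, and the output keeps the input shape.

    import torch
    from ultralytics.nn.modules.block import Attention

    m = Attention(dim=256, num_heads=8, attn_ratio=0.5)
    x = torch.randn(1, 256, 20, 20)
    print(m(x).shape)   # torch.Size([1, 256, 20, 20])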
{
"content": "class PSABlock(nn.Module):\n \"\"\"\n PSABlock class implementing a Position-Sensitive Attention block for neural networks.\n\n This class encapsulates the functionality for applying multi-head attention and feed-forward neural network layers\n with optional shortcut connections.\n\n Attributes:\n attn (Attention): Multi-head attention module.\n ffn (nn.Sequential): Feed-forward neural network module.\n add (bool): Flag indicating whether to add shortcut connections.\n\n Methods:\n forward: Performs a forward pass through the PSABlock, applying attention and feed-forward layers.\n\n Examples:\n Create a PSABlock and perform a forward pass\n >>> psablock = PSABlock(c=128, attn_ratio=0.5, num_heads=4, shortcut=True)\n >>> input_tensor = torch.randn(1, 128, 32, 32)\n >>> output_tensor = psablock(input_tensor)\n \"\"\"\n\n def __init__(self, c: int, attn_ratio: float = 0.5, num_heads: int = 4, shortcut: bool = True) -> None:\n \"\"\"\n Initialize the PSABlock.\n\n Args:\n c (int): Input and output channels.\n attn_ratio (float): Attention ratio for key dimension.\n num_heads (int): Number of attention heads.\n shortcut (bool): Whether to use shortcut connections.\n \"\"\"\n super().__init__()\n\n self.attn = Attention(c, attn_ratio=attn_ratio, num_heads=num_heads)\n self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False))\n self.add = shortcut\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Execute a forward pass through PSABlock.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after attention and feed-forward processing.\n \"\"\"\n x = x + self.attn(x) if self.add else self.attn(x)\n x = x + self.ffn(x) if self.add else self.ffn(x)\n return x",
"chunk_type": "class",
"name": "PSABlock",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1364,
"end_line": 1414,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "PSABlock class implementing a Position-Sensitive Attention block for neural networks.\n\nThis class encapsulates the functionality for applying multi-head attention and feed-forward neural network layers\nwith optional shortcut connections.\n\nAttributes:\n attn (Attention): Multi-head attention module.\n ffn (nn.Sequential): Feed-forward neural network module.\n add (bool): Flag indicating whether to add shortcut connections.\n\nMethods:\n forward: Performs a forward pass through the PSABlock, applying attention and feed-forward layers.\n\nExamples:\n Create a PSABlock and perform a forward pass\n >>> psablock = PSABlock(c=128, attn_ratio=0.5, num_heads=4, shortcut=True)\n >>> input_tensor = torch.randn(1, 128, 32, 32)\n >>> output_tensor = psablock(input_tensor)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_PSABlock_04e7ec49"
},
{
"content": "class PSA(nn.Module):\n \"\"\"\n PSA class for implementing Position-Sensitive Attention in neural networks.\n\n This class encapsulates the functionality for applying position-sensitive attention and feed-forward networks to\n input tensors, enhancing feature extraction and processing capabilities.\n\n Attributes:\n c (int): Number of hidden channels after applying the initial convolution.\n cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.\n cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.\n attn (Attention): Attention module for position-sensitive attention.\n ffn (nn.Sequential): Feed-forward network for further processing.\n\n Methods:\n forward: Applies position-sensitive attention and feed-forward network to the input tensor.\n\n Examples:\n Create a PSA module and apply it to an input tensor\n >>> psa = PSA(c1=128, c2=128, e=0.5)\n >>> input_tensor = torch.randn(1, 128, 64, 64)\n >>> output_tensor = psa.forward(input_tensor)\n \"\"\"\n\n def __init__(self, c1: int, c2: int, e: float = 0.5):\n \"\"\"\n Initialize PSA module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n assert c1 == c2\n self.c = int(c1 * e)\n self.cv1 = Conv(c1, 2 * self.c, 1, 1)\n self.cv2 = Conv(2 * self.c, c1, 1)\n\n self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 64)\n self.ffn = nn.Sequential(Conv(self.c, self.c * 2, 1), Conv(self.c * 2, self.c, 1, act=False))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Execute forward pass in PSA module.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after attention and feed-forward processing.\n \"\"\"\n a, b = self.cv1(x).split((self.c, self.c), dim=1)\n b = b + self.attn(b)\n b = b + self.ffn(b)\n return self.cv2(torch.cat((a, b), 1))",
"chunk_type": "class",
"name": "PSA",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1417,
"end_line": 1472,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": "PSA class for implementing Position-Sensitive Attention in neural networks.\n\nThis class encapsulates the functionality for applying position-sensitive attention and feed-forward networks to\ninput tensors, enhancing feature extraction and processing capabilities.\n\nAttributes:\n c (int): Number of hidden channels after applying the initial convolution.\n cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.\n cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.\n attn (Attention): Attention module for position-sensitive attention.\n ffn (nn.Sequential): Feed-forward network for further processing.\n\nMethods:\n forward: Applies position-sensitive attention and feed-forward network to the input tensor.\n\nExamples:\n Create a PSA module and apply it to an input tensor\n >>> psa = PSA(c1=128, c2=128, e=0.5)\n >>> input_tensor = torch.randn(1, 128, 64, 64)\n >>> output_tensor = psa.forward(input_tensor)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_PSA_92a2933c"
},
{
"content": "class C2PSA(nn.Module):\n \"\"\"\n C2PSA module with attention mechanism for enhanced feature extraction and processing.\n\n This module implements a convolutional block with attention mechanisms to enhance feature extraction and processing\n capabilities. It includes a series of PSABlock modules for self-attention and feed-forward operations.\n\n Attributes:\n c (int): Number of hidden channels.\n cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.\n cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.\n m (nn.Sequential): Sequential container of PSABlock modules for attention and feed-forward operations.\n\n Methods:\n forward: Performs a forward pass through the C2PSA module, applying attention and feed-forward operations.\n\n Notes:\n This module essentially is the same as PSA module, but refactored to allow stacking more PSABlock modules.\n\n Examples:\n >>> c2psa = C2PSA(c1=256, c2=256, n=3, e=0.5)\n >>> input_tensor = torch.randn(1, 256, 64, 64)\n >>> output_tensor = c2psa(input_tensor)\n \"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, e: float = 0.5):\n \"\"\"\n Initialize C2PSA module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of PSABlock modules.\n e (float): Expansion ratio.\n \"\"\"\n super().__init__()\n assert c1 == c2\n self.c = int(c1 * e)\n self.cv1 = Conv(c1, 2 * self.c, 1, 1)\n self.cv2 = Conv(2 * self.c, c1, 1)\n\n self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n)))\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Process the input tensor through a series of PSA blocks.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after processing.\n \"\"\"\n a, b = self.cv1(x).split((self.c, self.c), dim=1)\n b = self.m(b)\n return self.cv2(torch.cat((a, b), 1))",
"chunk_type": "class",
"name": "C2PSA",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1475,
"end_line": 1530,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": "C2PSA module with attention mechanism for enhanced feature extraction and processing.\n\nThis module implements a convolutional block with attention mechanisms to enhance feature extraction and processing\ncapabilities. It includes a series of PSABlock modules for self-attention and feed-forward operations.\n\nAttributes:\n c (int): Number of hidden channels.\n cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.\n cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.\n m (nn.Sequential): Sequential container of PSABlock modules for attention and feed-forward operations.\n\nMethods:\n forward: Performs a forward pass through the C2PSA module, applying attention and feed-forward operations.\n\nNotes:\n This module essentially is the same as PSA module, but refactored to allow stacking more PSABlock modules.\n\nExamples:\n >>> c2psa = C2PSA(c1=256, c2=256, n=3, e=0.5)\n >>> input_tensor = torch.randn(1, 256, 64, 64)\n >>> output_tensor = c2psa(input_tensor)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_C2PSA_f5974732"
},
{
"content": "class C2fPSA(C2f):\n \"\"\"\n C2fPSA module with enhanced feature extraction using PSA blocks.\n\n This class extends the C2f module by incorporating PSA blocks for improved attention mechanisms and feature extraction.\n\n Attributes:\n c (int): Number of hidden channels.\n cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.\n cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.\n m (nn.ModuleList): List of PSA blocks for feature extraction.\n\n Methods:\n forward: Performs a forward pass through the C2fPSA module.\n forward_split: Performs a forward pass using split() instead of chunk().\n\n Examples:\n >>> import torch\n >>> from ultralytics.models.common import C2fPSA\n >>> model = C2fPSA(c1=64, c2=64, n=3, e=0.5)\n >>> x = torch.randn(1, 64, 128, 128)\n >>> output = model(x)\n >>> print(output.shape)\n \"\"\"\n\n def __init__(self, c1: int, c2: int, n: int = 1, e: float = 0.5):\n \"\"\"\n Initialize C2fPSA module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n n (int): Number of PSABlock modules.\n e (float): Expansion ratio.\n \"\"\"\n assert c1 == c2\n super().__init__(c1, c2, n=n, e=e)\n self.m = nn.ModuleList(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n))",
"chunk_type": "class",
"name": "C2fPSA",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1533,
"end_line": 1570,
"start_col": 0,
"end_col": 106,
"parent_name": null,
"docstring": "C2fPSA module with enhanced feature extraction using PSA blocks.\n\nThis class extends the C2f module by incorporating PSA blocks for improved attention mechanisms and feature extraction.\n\nAttributes:\n c (int): Number of hidden channels.\n cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.\n cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.\n m (nn.ModuleList): List of PSA blocks for feature extraction.\n\nMethods:\n forward: Performs a forward pass through the C2fPSA module.\n forward_split: Performs a forward pass using split() instead of chunk().\n\nExamples:\n >>> import torch\n >>> from ultralytics.models.common import C2fPSA\n >>> model = C2fPSA(c1=64, c2=64, n=3, e=0.5)\n >>> x = torch.randn(1, 64, 128, 128)\n >>> output = model(x)\n >>> print(output.shape)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"C2f"
],
"chunk_id": "class_C2fPSA_8f7445a7"
},
{
"content": "class SCDown(nn.Module):\n \"\"\"\n SCDown module for downsampling with separable convolutions.\n\n This module performs downsampling using a combination of pointwise and depthwise convolutions, which helps in\n efficiently reducing the spatial dimensions of the input tensor while maintaining the channel information.\n\n Attributes:\n cv1 (Conv): Pointwise convolution layer that reduces the number of channels.\n cv2 (Conv): Depthwise convolution layer that performs spatial downsampling.\n\n Methods:\n forward: Applies the SCDown module to the input tensor.\n\n Examples:\n >>> import torch\n >>> from ultralytics import SCDown\n >>> model = SCDown(c1=64, c2=128, k=3, s=2)\n >>> x = torch.randn(1, 64, 128, 128)\n >>> y = model(x)\n >>> print(y.shape)\n torch.Size([1, 128, 64, 64])\n \"\"\"\n\n def __init__(self, c1: int, c2: int, k: int, s: int):\n \"\"\"\n Initialize SCDown module.\n\n Args:\n c1 (int): Input channels.\n c2 (int): Output channels.\n k (int): Kernel size.\n s (int): Stride.\n \"\"\"\n super().__init__()\n self.cv1 = Conv(c1, c2, 1, 1)\n self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Apply convolution and downsampling to the input tensor.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Downsampled output tensor.\n \"\"\"\n return self.cv2(self.cv1(x))",
"chunk_type": "class",
"name": "SCDown",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1573,
"end_line": 1621,
"start_col": 0,
"end_col": 36,
"parent_name": null,
"docstring": "SCDown module for downsampling with separable convolutions.\n\nThis module performs downsampling using a combination of pointwise and depthwise convolutions, which helps in\nefficiently reducing the spatial dimensions of the input tensor while maintaining the channel information.\n\nAttributes:\n cv1 (Conv): Pointwise convolution layer that reduces the number of channels.\n cv2 (Conv): Depthwise convolution layer that performs spatial downsampling.\n\nMethods:\n forward: Applies the SCDown module to the input tensor.\n\nExamples:\n >>> import torch\n >>> from ultralytics import SCDown\n >>> model = SCDown(c1=64, c2=128, k=3, s=2)\n >>> x = torch.randn(1, 64, 128, 128)\n >>> y = model(x)\n >>> print(y.shape)\n torch.Size([1, 128, 64, 64])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_SCDown_49a4a000"
},
{
"content": "class TorchVision(nn.Module):\n \"\"\"\n TorchVision module to allow loading any torchvision model.\n\n This class provides a way to load a model from the torchvision library, optionally load pre-trained weights, and customize the model by truncating or unwrapping layers.\n\n Attributes:\n m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.\n\n Args:\n model (str): Name of the torchvision model to load.\n weights (str, optional): Pre-trained weights to load. Default is \"DEFAULT\".\n unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate` layers. Default is True.\n truncate (int, optional): Number of layers to truncate from the end if `unwrap` is True. Default is 2.\n split (bool, optional): Returns output from intermediate child modules as list. Default is False.\n \"\"\"\n\n def __init__(\n self, model: str, weights: str = \"DEFAULT\", unwrap: bool = True, truncate: int = 2, split: bool = False\n ):\n \"\"\"\n Load the model and weights from torchvision.\n\n Args:\n model (str): Name of the torchvision model to load.\n weights (str): Pre-trained weights to load.\n unwrap (bool): Whether to unwrap the model.\n truncate (int): Number of layers to truncate.\n split (bool): Whether to split the output.\n \"\"\"\n import torchvision # scope for faster 'import ultralytics'\n\n super().__init__()\n if hasattr(torchvision.models, \"get_model\"):\n self.m = torchvision.models.get_model(model, weights=weights)\n else:\n self.m = torchvision.models.__dict__[model](pretrained=bool(weights))\n if unwrap:\n layers = list(self.m.children())\n if isinstance(layers[0], nn.Sequential): # Second-level for some models like EfficientNet, Swin\n layers = [*list(layers[0].children()), *layers[1:]]\n self.m = nn.Sequential(*(layers[:-truncate] if truncate else layers))\n self.split = split\n else:\n self.split = False\n self.m.head = self.m.heads = nn.Identity()\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass through the model.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor | List[torch.Tensor]): Output tensor or list of tensors.\n \"\"\"\n if self.split:\n y = [x]\n y.extend(m(y[-1]) for m in self.m)\n else:\n y = self.m(x)\n return y",
"chunk_type": "class",
"name": "TorchVision",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1624,
"end_line": 1686,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "TorchVision module to allow loading any torchvision model.\n\nThis class provides a way to load a model from the torchvision library, optionally load pre-trained weights, and customize the model by truncating or unwrapping layers.\n\nAttributes:\n m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.\n\nArgs:\n model (str): Name of the torchvision model to load.\n weights (str, optional): Pre-trained weights to load. Default is \"DEFAULT\".\n unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate` layers. Default is True.\n truncate (int, optional): Number of layers to truncate from the end if `unwrap` is True. Default is 2.\n split (bool, optional): Returns output from intermediate child modules as list. Default is False.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_TorchVision_6883286a"
},
{
"content": "class AAttn(nn.Module):\n \"\"\"\n Area-attention module for YOLO models, providing efficient attention mechanisms.\n\n This module implements an area-based attention mechanism that processes input features in a spatially-aware manner,\n making it particularly effective for object detection tasks.\n\n Attributes:\n area (int): Number of areas the feature map is divided.\n num_heads (int): Number of heads into which the attention mechanism is divided.\n head_dim (int): Dimension of each attention head.\n qkv (Conv): Convolution layer for computing query, key and value tensors.\n proj (Conv): Projection convolution layer.\n pe (Conv): Position encoding convolution layer.\n\n Methods:\n forward: Applies area-attention to input tensor.\n\n Examples:\n >>> attn = AAttn(dim=256, num_heads=8, area=4)\n >>> x = torch.randn(1, 256, 32, 32)\n >>> output = attn(x)\n >>> print(output.shape)\n torch.Size([1, 256, 32, 32])\n \"\"\"\n\n def __init__(self, dim: int, num_heads: int, area: int = 1):\n \"\"\"\n Initialize an Area-attention module for YOLO models.\n\n Args:\n dim (int): Number of hidden channels.\n num_heads (int): Number of heads into which the attention mechanism is divided.\n area (int): Number of areas the feature map is divided.\n \"\"\"\n super().__init__()\n self.area = area\n\n self.num_heads = num_heads\n self.head_dim = head_dim = dim // num_heads\n all_head_dim = head_dim * self.num_heads\n\n self.qkv = Conv(dim, all_head_dim * 3, 1, act=False)\n self.proj = Conv(all_head_dim, dim, 1, act=False)\n self.pe = Conv(all_head_dim, dim, 7, 1, 3, g=dim, act=False)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Process the input tensor through the area-attention.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after area-attention.\n \"\"\"\n B, C, H, W = x.shape\n N = H * W\n\n qkv = self.qkv(x).flatten(2).transpose(1, 2)\n if self.area > 1:\n qkv = qkv.reshape(B * self.area, N // self.area, C * 3)\n B, N, _ = qkv.shape\n q, k, v = (\n qkv.view(B, N, self.num_heads, self.head_dim * 3)\n .permute(0, 2, 3, 1)\n .split([self.head_dim, self.head_dim, self.head_dim], dim=2)\n )\n attn = (q.transpose(-2, -1) @ k) * (self.head_dim**-0.5)\n attn = attn.softmax(dim=-1)\n x = v @ attn.transpose(-2, -1)\n x = x.permute(0, 3, 1, 2)\n v = v.permute(0, 3, 1, 2)\n\n if self.area > 1:\n x = x.reshape(B // self.area, N * self.area, C)\n v = v.reshape(B // self.area, N * self.area, C)\n B, N, _ = x.shape\n\n x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()\n v = v.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()\n\n x = x + self.pe(v)\n return self.proj(x)",
"chunk_type": "class",
"name": "AAttn",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1689,
"end_line": 1772,
"start_col": 0,
"end_col": 27,
"parent_name": null,
"docstring": "Area-attention module for YOLO models, providing efficient attention mechanisms.\n\nThis module implements an area-based attention mechanism that processes input features in a spatially-aware manner,\nmaking it particularly effective for object detection tasks.\n\nAttributes:\n area (int): Number of areas the feature map is divided.\n num_heads (int): Number of heads into which the attention mechanism is divided.\n head_dim (int): Dimension of each attention head.\n qkv (Conv): Convolution layer for computing query, key and value tensors.\n proj (Conv): Projection convolution layer.\n pe (Conv): Position encoding convolution layer.\n\nMethods:\n forward: Applies area-attention to input tensor.\n\nExamples:\n >>> attn = AAttn(dim=256, num_heads=8, area=4)\n >>> x = torch.randn(1, 256, 32, 32)\n >>> output = attn(x)\n >>> print(output.shape)\n torch.Size([1, 256, 32, 32])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_AAttn_f769942c"
},
{
"content": "class ABlock(nn.Module):\n \"\"\"\n Area-attention block module for efficient feature extraction in YOLO models.\n\n This module implements an area-attention mechanism combined with a feed-forward network for processing feature maps.\n It uses a novel area-based attention approach that is more efficient than traditional self-attention while\n maintaining effectiveness.\n\n Attributes:\n attn (AAttn): Area-attention module for processing spatial features.\n mlp (nn.Sequential): Multi-layer perceptron for feature transformation.\n\n Methods:\n _init_weights: Initializes module weights using truncated normal distribution.\n forward: Applies area-attention and feed-forward processing to input tensor.\n\n Examples:\n >>> block = ABlock(dim=256, num_heads=8, mlp_ratio=1.2, area=1)\n >>> x = torch.randn(1, 256, 32, 32)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 256, 32, 32])\n \"\"\"\n\n def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 1.2, area: int = 1):\n \"\"\"\n Initialize an Area-attention block module.\n\n Args:\n dim (int): Number of input channels.\n num_heads (int): Number of heads into which the attention mechanism is divided.\n mlp_ratio (float): Expansion ratio for MLP hidden dimension.\n area (int): Number of areas the feature map is divided.\n \"\"\"\n super().__init__()\n\n self.attn = AAttn(dim, num_heads=num_heads, area=area)\n mlp_hidden_dim = int(dim * mlp_ratio)\n self.mlp = nn.Sequential(Conv(dim, mlp_hidden_dim, 1), Conv(mlp_hidden_dim, dim, 1, act=False))\n\n self.apply(self._init_weights)\n\n def _init_weights(self, m: nn.Module):\n \"\"\"\n Initialize weights using a truncated normal distribution.\n\n Args:\n m (nn.Module): Module to initialize.\n \"\"\"\n if isinstance(m, nn.Conv2d):\n nn.init.trunc_normal_(m.weight, std=0.02)\n if m.bias is not None:\n nn.init.constant_(m.bias, 0)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass through ABlock.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after area-attention and feed-forward processing.\n \"\"\"\n x = x + self.attn(x)\n return x + self.mlp(x)",
"chunk_type": "class",
"name": "ABlock",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1775,
"end_line": 1840,
"start_col": 0,
"end_col": 30,
"parent_name": null,
"docstring": "Area-attention block module for efficient feature extraction in YOLO models.\n\nThis module implements an area-attention mechanism combined with a feed-forward network for processing feature maps.\nIt uses a novel area-based attention approach that is more efficient than traditional self-attention while\nmaintaining effectiveness.\n\nAttributes:\n attn (AAttn): Area-attention module for processing spatial features.\n mlp (nn.Sequential): Multi-layer perceptron for feature transformation.\n\nMethods:\n _init_weights: Initializes module weights using truncated normal distribution.\n forward: Applies area-attention and feed-forward processing to input tensor.\n\nExamples:\n >>> block = ABlock(dim=256, num_heads=8, mlp_ratio=1.2, area=1)\n >>> x = torch.randn(1, 256, 32, 32)\n >>> output = block(x)\n >>> print(output.shape)\n torch.Size([1, 256, 32, 32])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_ABlock_d1e5ffc8"
},
{
"content": "class A2C2f(nn.Module):\n \"\"\"\n Area-Attention C2f module for enhanced feature extraction with area-based attention mechanisms.\n\n This module extends the C2f architecture by incorporating area-attention and ABlock layers for improved feature\n processing. It supports both area-attention and standard convolution modes.\n\n Attributes:\n cv1 (Conv): Initial 1x1 convolution layer that reduces input channels to hidden channels.\n cv2 (Conv): Final 1x1 convolution layer that processes concatenated features.\n gamma (nn.Parameter | None): Learnable parameter for residual scaling when using area attention.\n m (nn.ModuleList): List of either ABlock or C3k modules for feature processing.\n\n Methods:\n forward: Processes input through area-attention or standard convolution pathway.\n\n Examples:\n >>> m = A2C2f(512, 512, n=1, a2=True, area=1)\n >>> x = torch.randn(1, 512, 32, 32)\n >>> output = m(x)\n >>> print(output.shape)\n torch.Size([1, 512, 32, 32])\n \"\"\"\n\n def __init__(\n self,\n c1: int,\n c2: int,\n n: int = 1,\n a2: bool = True,\n area: int = 1,\n residual: bool = False,\n mlp_ratio: float = 2.0,\n e: float = 0.5,\n g: int = 1,\n shortcut: bool = True,\n ):\n \"\"\"\n Initialize Area-Attention C2f module.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n n (int): Number of ABlock or C3k modules to stack.\n a2 (bool): Whether to use area attention blocks. If False, uses C3k blocks instead.\n area (int): Number of areas the feature map is divided.\n residual (bool): Whether to use residual connections with learnable gamma parameter.\n mlp_ratio (float): Expansion ratio for MLP hidden dimension.\n e (float): Channel expansion ratio for hidden channels.\n g (int): Number of groups for grouped convolutions.\n shortcut (bool): Whether to use shortcut connections in C3k blocks.\n \"\"\"\n super().__init__()\n c_ = int(c2 * e) # hidden channels\n assert c_ % 32 == 0, \"Dimension of ABlock be a multiple of 32.\"\n\n self.cv1 = Conv(c1, c_, 1, 1)\n self.cv2 = Conv((1 + n) * c_, c2, 1)\n\n self.gamma = nn.Parameter(0.01 * torch.ones(c2), requires_grad=True) if a2 and residual else None\n self.m = nn.ModuleList(\n nn.Sequential(*(ABlock(c_, c_ // 32, mlp_ratio, area) for _ in range(2)))\n if a2\n else C3k(c_, c_, 2, shortcut, g)\n for _ in range(n)\n )\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass through A2C2f layer.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after processing.\n \"\"\"\n y = [self.cv1(x)]\n y.extend(m(y[-1]) for m in self.m)\n y = self.cv2(torch.cat(y, 1))\n if self.gamma is not None:\n return x + self.gamma.view(-1, len(self.gamma), 1, 1) * y\n return y",
"chunk_type": "class",
"name": "A2C2f",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1843,
"end_line": 1925,
"start_col": 0,
"end_col": 16,
"parent_name": null,
"docstring": "Area-Attention C2f module for enhanced feature extraction with area-based attention mechanisms.\n\nThis module extends the C2f architecture by incorporating area-attention and ABlock layers for improved feature\nprocessing. It supports both area-attention and standard convolution modes.\n\nAttributes:\n cv1 (Conv): Initial 1x1 convolution layer that reduces input channels to hidden channels.\n cv2 (Conv): Final 1x1 convolution layer that processes concatenated features.\n gamma (nn.Parameter | None): Learnable parameter for residual scaling when using area attention.\n m (nn.ModuleList): List of either ABlock or C3k modules for feature processing.\n\nMethods:\n forward: Processes input through area-attention or standard convolution pathway.\n\nExamples:\n >>> m = A2C2f(512, 512, n=1, a2=True, area=1)\n >>> x = torch.randn(1, 512, 32, 32)\n >>> output = m(x)\n >>> print(output.shape)\n torch.Size([1, 512, 32, 32])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_A2C2f_da7e61c0"
},
{
"content": "class SwiGLUFFN(nn.Module):\n \"\"\"SwiGLU Feed-Forward Network for transformer-based architectures.\"\"\"\n\n def __init__(self, gc: int, ec: int, e: int = 4) -> None:\n \"\"\"\n Initialize SwiGLU FFN with input dimension, output dimension, and expansion factor.\n\n Args:\n gc (int): Guide channels.\n ec (int): Embedding channels.\n e (int): Expansion factor.\n \"\"\"\n super().__init__()\n self.w12 = nn.Linear(gc, e * ec)\n self.w3 = nn.Linear(e * ec // 2, ec)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply SwiGLU transformation to input features.\"\"\"\n x12 = self.w12(x)\n x1, x2 = x12.chunk(2, dim=-1)\n hidden = F.silu(x1) * x2\n return self.w3(hidden)",
"chunk_type": "class",
"name": "SwiGLUFFN",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1928,
"end_line": 1949,
"start_col": 0,
"end_col": 30,
"parent_name": null,
"docstring": "SwiGLU Feed-Forward Network for transformer-based architectures.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_SwiGLUFFN_2e5d045a"
},
{
"content": "class Residual(nn.Module):\n \"\"\"Residual connection wrapper for neural network modules.\"\"\"\n\n def __init__(self, m: nn.Module) -> None:\n \"\"\"\n Initialize residual module with the wrapped module.\n\n Args:\n m (nn.Module): Module to wrap with residual connection.\n \"\"\"\n super().__init__()\n self.m = m\n nn.init.zeros_(self.m.w3.bias)\n # For models with l scale, please change the initialization to\n # nn.init.constant_(self.m.w3.weight, 1e-6)\n nn.init.zeros_(self.m.w3.weight)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"Apply residual connection to input features.\"\"\"\n return x + self.m(x)",
"chunk_type": "class",
"name": "Residual",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1952,
"end_line": 1971,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": "Residual connection wrapper for neural network modules.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_Residual_4db16cd0"
},
{
"content": "class SAVPE(nn.Module):\n \"\"\"Spatial-Aware Visual Prompt Embedding module for feature enhancement.\"\"\"\n\n def __init__(self, ch: List[int], c3: int, embed: int):\n \"\"\"\n Initialize SAVPE module with channels, intermediate channels, and embedding dimension.\n\n Args:\n ch (List[int]): List of input channel dimensions.\n c3 (int): Intermediate channels.\n embed (int): Embedding dimension.\n \"\"\"\n super().__init__()\n self.cv1 = nn.ModuleList(\n nn.Sequential(\n Conv(x, c3, 3), Conv(c3, c3, 3), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity()\n )\n for i, x in enumerate(ch)\n )\n\n self.cv2 = nn.ModuleList(\n nn.Sequential(Conv(x, c3, 1), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity())\n for i, x in enumerate(ch)\n )\n\n self.c = 16\n self.cv3 = nn.Conv2d(3 * c3, embed, 1)\n self.cv4 = nn.Conv2d(3 * c3, self.c, 3, padding=1)\n self.cv5 = nn.Conv2d(1, self.c, 3, padding=1)\n self.cv6 = nn.Sequential(Conv(2 * self.c, self.c, 3), nn.Conv2d(self.c, self.c, 3, padding=1))\n\n def forward(self, x: List[torch.Tensor], vp: torch.Tensor) -> torch.Tensor:\n \"\"\"Process input features and visual prompts to generate enhanced embeddings.\"\"\"\n y = [self.cv2[i](xi) for i, xi in enumerate(x)]\n y = self.cv4(torch.cat(y, dim=1))\n\n x = [self.cv1[i](xi) for i, xi in enumerate(x)]\n x = self.cv3(torch.cat(x, dim=1))\n\n B, C, H, W = x.shape\n\n Q = vp.shape[1]\n\n x = x.view(B, C, -1)\n\n y = y.reshape(B, 1, self.c, H, W).expand(-1, Q, -1, -1, -1).reshape(B * Q, self.c, H, W)\n vp = vp.reshape(B, Q, 1, H, W).reshape(B * Q, 1, H, W)\n\n y = self.cv6(torch.cat((y, self.cv5(vp)), dim=1))\n\n y = y.reshape(B, Q, self.c, -1)\n vp = vp.reshape(B, Q, 1, -1)\n\n score = y * vp + torch.logical_not(vp) * torch.finfo(y.dtype).min\n\n score = F.softmax(score, dim=-1, dtype=torch.float).to(score.dtype)\n\n aggregated = score.transpose(-2, -3) @ x.reshape(B, self.c, C // self.c, -1).transpose(-1, -2)\n\n return F.normalize(aggregated.transpose(-2, -3).reshape(B, Q, -1), dim=-1, p=2)",
"chunk_type": "class",
"name": "SAVPE",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\block.py",
"start_line": 1974,
"end_line": 2033,
"start_col": 0,
"end_col": 87,
"parent_name": null,
"docstring": "Spatial-Aware Visual Prompt Embedding module for feature enhancement.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"typing.List",
"typing.Optional",
"typing.Tuple",
"torch",
"torch.nn",
"torch.nn.functional",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"conv.Conv",
"conv.DWConv",
"conv.GhostConv",
"conv.LightConv",
"conv.RepConv",
"conv.autopad",
"transformer.TransformerBlock",
"torchvision",
"nn.Module"
],
"chunk_id": "class_SAVPE_0cc8314c"
},
{
"content": "import math",
"chunk_type": "import",
"name": "math",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_math_197fdc46"
},
{
"content": "from typing import List",
"chunk_type": "import",
"name": "List",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 23,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_List_aa56bce6"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_4e2a9c1b"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_905ccb83"
},
{
"content": "import torch.nn as nn",
"chunk_type": "import",
"name": "torch.nn",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn_ee8edbaf"
},
{
"content": "__all__ = (\n \"Conv\",\n \"Conv2\",\n \"LightConv\",\n \"DWConv\",\n \"DWConvTranspose2d\",\n \"ConvTranspose\",\n \"Focus\",\n \"GhostConv\",\n \"ChannelAttention\",\n \"SpatialAttention\",\n \"CBAM\",\n \"Concat\",\n \"RepConv\",\n \"Index\",\n)",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 11,
"end_line": 26,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___baa4b10a"
},
{
"content": "def autopad(k, p=None, d=1): # kernel, padding, dilation\n \"\"\"Pad to 'same' shape outputs.\"\"\"\n if d > 1:\n k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size\n if p is None:\n p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad\n return p",
"chunk_type": "function",
"name": "autopad",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 29,
"end_line": 35,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": "Pad to 'same' shape outputs.",
"parameters": [
"k",
"p",
"d"
],
"return_type": null,
"decorators": [],
"complexity_score": 5,
"dependencies": [
"math",
"typing.List",
"numpy",
"torch",
"torch.nn"
],
"chunk_id": "function_autopad_28ae3a8c"
},
{
"content": "class Conv(nn.Module):\n \"\"\"\n Standard convolution module with batch normalization and activation.\n\n Attributes:\n conv (nn.Conv2d): Convolutional layer.\n bn (nn.BatchNorm2d): Batch normalization layer.\n act (nn.Module): Activation function layer.\n default_act (nn.Module): Default activation function (SiLU).\n \"\"\"\n\n default_act = nn.SiLU() # default activation\n\n def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):\n \"\"\"\n Initialize Conv layer with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n k (int): Kernel size.\n s (int): Stride.\n p (int, optional): Padding.\n g (int): Groups.\n d (int): Dilation.\n act (bool | nn.Module): Activation function.\n \"\"\"\n super().__init__()\n self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)\n self.bn = nn.BatchNorm2d(c2)\n self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()\n\n def forward(self, x):\n \"\"\"\n Apply convolution, batch normalization and activation to input tensor.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return self.act(self.bn(self.conv(x)))\n\n def forward_fuse(self, x):\n \"\"\"\n Apply convolution and activation without batch normalization.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return self.act(self.conv(x))",
"chunk_type": "class",
"name": "Conv",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 38,
"end_line": 92,
"start_col": 0,
"end_col": 37,
"parent_name": null,
"docstring": "Standard convolution module with batch normalization and activation.\n\nAttributes:\n conv (nn.Conv2d): Convolutional layer.\n bn (nn.BatchNorm2d): Batch normalization layer.\n act (nn.Module): Activation function layer.\n default_act (nn.Module): Default activation function (SiLU).",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"numpy",
"torch",
"torch.nn",
"nn.Module"
],
"chunk_id": "class_Conv_1bc938d1"
},
{
"content": "class Conv2(Conv):\n \"\"\"\n Simplified RepConv module with Conv fusing.\n\n Attributes:\n conv (nn.Conv2d): Main 3x3 convolutional layer.\n cv2 (nn.Conv2d): Additional 1x1 convolutional layer.\n bn (nn.BatchNorm2d): Batch normalization layer.\n act (nn.Module): Activation function layer.\n \"\"\"\n\n def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True):\n \"\"\"\n Initialize Conv2 layer with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n k (int): Kernel size.\n s (int): Stride.\n p (int, optional): Padding.\n g (int): Groups.\n d (int): Dilation.\n act (bool | nn.Module): Activation function.\n \"\"\"\n super().__init__(c1, c2, k, s, p, g=g, d=d, act=act)\n self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False) # add 1x1 conv\n\n def forward(self, x):\n \"\"\"\n Apply convolution, batch normalization and activation to input tensor.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return self.act(self.bn(self.conv(x) + self.cv2(x)))\n\n def forward_fuse(self, x):\n \"\"\"\n Apply fused convolution, batch normalization and activation to input tensor.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return self.act(self.bn(self.conv(x)))\n\n def fuse_convs(self):\n \"\"\"Fuse parallel convolutions.\"\"\"\n w = torch.zeros_like(self.conv.weight.data)\n i = [x // 2 for x in w.shape[2:]]\n w[:, :, i[0] : i[0] + 1, i[1] : i[1] + 1] = self.cv2.weight.data.clone()\n self.conv.weight.data += w\n self.__delattr__(\"cv2\")\n self.forward = self.forward_fuse",
"chunk_type": "class",
"name": "Conv2",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 95,
"end_line": 154,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": "Simplified RepConv module with Conv fusing.\n\nAttributes:\n conv (nn.Conv2d): Main 3x3 convolutional layer.\n cv2 (nn.Conv2d): Additional 1x1 convolutional layer.\n bn (nn.BatchNorm2d): Batch normalization layer.\n act (nn.Module): Activation function layer.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"numpy",
"torch",
"torch.nn",
"Conv"
],
"chunk_id": "class_Conv2_f2d55d14"
},
{
"content": "class LightConv(nn.Module):\n \"\"\"\n Light convolution module with 1x1 and depthwise convolutions.\n\n This implementation is based on the PaddleDetection HGNetV2 backbone.\n\n Attributes:\n conv1 (Conv): 1x1 convolution layer.\n conv2 (DWConv): Depthwise convolution layer.\n \"\"\"\n\n def __init__(self, c1, c2, k=1, act=nn.ReLU()):\n \"\"\"\n Initialize LightConv layer with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n k (int): Kernel size for depthwise convolution.\n act (nn.Module): Activation function.\n \"\"\"\n super().__init__()\n self.conv1 = Conv(c1, c2, 1, act=False)\n self.conv2 = DWConv(c2, c2, k, act=act)\n\n def forward(self, x):\n \"\"\"\n Apply 2 convolutions to input tensor.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return self.conv2(self.conv1(x))",
"chunk_type": "class",
"name": "LightConv",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 157,
"end_line": 192,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": "Light convolution module with 1x1 and depthwise convolutions.\n\nThis implementation is based on the PaddleDetection HGNetV2 backbone.\n\nAttributes:\n conv1 (Conv): 1x1 convolution layer.\n conv2 (DWConv): Depthwise convolution layer.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"numpy",
"torch",
"torch.nn",
"nn.Module"
],
"chunk_id": "class_LightConv_f71a4c9f"
},
{
"content": "class DWConv(Conv):\n \"\"\"Depth-wise convolution module.\"\"\"\n\n def __init__(self, c1, c2, k=1, s=1, d=1, act=True):\n \"\"\"\n Initialize depth-wise convolution with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n k (int): Kernel size.\n s (int): Stride.\n d (int): Dilation.\n act (bool | nn.Module): Activation function.\n \"\"\"\n super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)",
"chunk_type": "class",
"name": "DWConv",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 195,
"end_line": 210,
"start_col": 0,
"end_col": 72,
"parent_name": null,
"docstring": "Depth-wise convolution module.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"numpy",
"torch",
"torch.nn",
"Conv"
],
"chunk_id": "class_DWConv_e26260ac"
},
{
"content": "class DWConvTranspose2d(nn.ConvTranspose2d):\n \"\"\"Depth-wise transpose convolution module.\"\"\"\n\n def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0):\n \"\"\"\n Initialize depth-wise transpose convolution with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n k (int): Kernel size.\n s (int): Stride.\n p1 (int): Padding.\n p2 (int): Output padding.\n \"\"\"\n super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))",
"chunk_type": "class",
"name": "DWConvTranspose2d",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 213,
"end_line": 228,
"start_col": 0,
"end_col": 71,
"parent_name": null,
"docstring": "Depth-wise transpose convolution module.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"numpy",
"torch",
"torch.nn",
"nn.ConvTranspose2d"
],
"chunk_id": "class_DWConvTranspose2d_b683d7f6"
},
{
"content": "class ConvTranspose(nn.Module):\n \"\"\"\n Convolution transpose module with optional batch normalization and activation.\n\n Attributes:\n conv_transpose (nn.ConvTranspose2d): Transposed convolution layer.\n bn (nn.BatchNorm2d | nn.Identity): Batch normalization layer.\n act (nn.Module): Activation function layer.\n default_act (nn.Module): Default activation function (SiLU).\n \"\"\"\n\n default_act = nn.SiLU() # default activation\n\n def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):\n \"\"\"\n Initialize ConvTranspose layer with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n k (int): Kernel size.\n s (int): Stride.\n p (int): Padding.\n bn (bool): Use batch normalization.\n act (bool | nn.Module): Activation function.\n \"\"\"\n super().__init__()\n self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)\n self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()\n self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()\n\n def forward(self, x):\n \"\"\"\n Apply transposed convolution, batch normalization and activation to input.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return self.act(self.bn(self.conv_transpose(x)))\n\n def forward_fuse(self, x):\n \"\"\"\n Apply activation and convolution transpose operation to input.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return self.act(self.conv_transpose(x))",
"chunk_type": "class",
"name": "ConvTranspose",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 231,
"end_line": 284,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": "Convolution transpose module with optional batch normalization and activation.\n\nAttributes:\n conv_transpose (nn.ConvTranspose2d): Transposed convolution layer.\n bn (nn.BatchNorm2d | nn.Identity): Batch normalization layer.\n act (nn.Module): Activation function layer.\n default_act (nn.Module): Default activation function (SiLU).",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"numpy",
"torch",
"torch.nn",
"nn.Module"
],
"chunk_id": "class_ConvTranspose_8c6636c1"
},
{
"content": "class Focus(nn.Module):\n \"\"\"\n Focus module for concentrating feature information.\n\n Slices input tensor into 4 parts and concatenates them in the channel dimension.\n\n Attributes:\n conv (Conv): Convolution layer.\n \"\"\"\n\n def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):\n \"\"\"\n Initialize Focus module with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n k (int): Kernel size.\n s (int): Stride.\n p (int, optional): Padding.\n g (int): Groups.\n act (bool | nn.Module): Activation function.\n \"\"\"\n super().__init__()\n self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act)\n # self.contract = Contract(gain=2)\n\n def forward(self, x):\n \"\"\"\n Apply Focus operation and convolution to input tensor.\n\n Input shape is (B, C, W, H) and output shape is (B, 4C, W/2, H/2).\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))",
"chunk_type": "class",
"name": "Focus",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 287,
"end_line": 326,
"start_col": 0,
"end_col": 116,
"parent_name": null,
"docstring": "Focus module for concentrating feature information.\n\nSlices input tensor into 4 parts and concatenates them in the channel dimension.\n\nAttributes:\n conv (Conv): Convolution layer.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"numpy",
"torch",
"torch.nn",
"nn.Module"
],
"chunk_id": "class_Focus_023c9a71"
},
{
"content": "class GhostConv(nn.Module):\n \"\"\"\n Ghost Convolution module.\n\n Generates more features with fewer parameters by using cheap operations.\n\n Attributes:\n cv1 (Conv): Primary convolution.\n cv2 (Conv): Cheap operation convolution.\n\n References:\n https://github.com/huawei-noah/Efficient-AI-Backbones\n \"\"\"\n\n def __init__(self, c1, c2, k=1, s=1, g=1, act=True):\n \"\"\"\n Initialize Ghost Convolution module with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n k (int): Kernel size.\n s (int): Stride.\n g (int): Groups.\n act (bool | nn.Module): Activation function.\n \"\"\"\n super().__init__()\n c_ = c2 // 2 # hidden channels\n self.cv1 = Conv(c1, c_, k, s, None, g, act=act)\n self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act)\n\n def forward(self, x):\n \"\"\"\n Apply Ghost Convolution to input tensor.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor with concatenated features.\n \"\"\"\n y = self.cv1(x)\n return torch.cat((y, self.cv2(y)), 1)",
"chunk_type": "class",
"name": "GhostConv",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 330,
"end_line": 372,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": "Ghost Convolution module.\n\nGenerates more features with fewer parameters by using cheap operations.\n\nAttributes:\n cv1 (Conv): Primary convolution.\n cv2 (Conv): Cheap operation convolution.\n\nReferences:\n https://github.com/huawei-noah/Efficient-AI-Backbones",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"numpy",
"torch",
"torch.nn",
"nn.Module"
],
"chunk_id": "class_GhostConv_f69fae21"
},
{
"content": "class RepConv(nn.Module):\n \"\"\"\n RepConv module with training and deploy modes.\n\n This module is used in RT-DETR and can fuse convolutions during inference for efficiency.\n\n Attributes:\n conv1 (Conv): 3x3 convolution.\n conv2 (Conv): 1x1 convolution.\n bn (nn.BatchNorm2d, optional): Batch normalization for identity branch.\n act (nn.Module): Activation function.\n default_act (nn.Module): Default activation function (SiLU).\n\n References:\n https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py\n \"\"\"\n\n default_act = nn.SiLU() # default activation\n\n def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):\n \"\"\"\n Initialize RepConv module with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output channels.\n k (int): Kernel size.\n s (int): Stride.\n p (int): Padding.\n g (int): Groups.\n d (int): Dilation.\n act (bool | nn.Module): Activation function.\n bn (bool): Use batch normalization for identity branch.\n deploy (bool): Deploy mode for inference.\n \"\"\"\n super().__init__()\n assert k == 3 and p == 1\n self.g = g\n self.c1 = c1\n self.c2 = c2\n self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()\n\n self.bn = nn.BatchNorm2d(num_features=c1) if bn and c2 == c1 and s == 1 else None\n self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False)\n self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)\n\n def forward_fuse(self, x):\n \"\"\"\n Forward pass for deploy mode.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n return self.act(self.conv(x))\n\n def forward(self, x):\n \"\"\"\n Forward pass for training mode.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor.\n \"\"\"\n id_out = 0 if self.bn is None else self.bn(x)\n return self.act(self.conv1(x) + self.conv2(x) + id_out)\n\n def get_equivalent_kernel_bias(self):\n \"\"\"\n Calculate equivalent kernel and bias by fusing convolutions.\n\n Returns:\n (torch.Tensor): Equivalent kernel\n (torch.Tensor): Equivalent bias\n \"\"\"\n kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)\n kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)\n kernelid, biasid = self._fuse_bn_tensor(self.bn)\n return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid\n\n @staticmethod\n def _pad_1x1_to_3x3_tensor(kernel1x1):\n \"\"\"\n Pad a 1x1 kernel to 3x3 size.\n\n Args:\n kernel1x1 (torch.Tensor): 1x1 convolution kernel.\n\n Returns:\n (torch.Tensor): Padded 3x3 kernel.\n \"\"\"\n if kernel1x1 is None:\n return 0\n else:\n return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])\n\n def _fuse_bn_tensor(self, branch):\n \"\"\"\n Fuse batch normalization with convolution weights.\n\n Args:\n branch (Conv | nn.BatchNorm2d | None): Branch to fuse.\n\n Returns:\n kernel (torch.Tensor): Fused kernel.\n bias (torch.Tensor): Fused bias.\n \"\"\"\n if branch is None:\n return 0, 0\n if isinstance(branch, Conv):\n kernel = branch.conv.weight\n running_mean = branch.bn.running_mean\n running_var = branch.bn.running_var\n gamma = branch.bn.weight\n beta = branch.bn.bias\n eps = branch.bn.eps\n elif isinstance(branch, nn.BatchNorm2d):\n if not hasattr(self, \"id_tensor\"):\n input_dim = self.c1 // self.g\n kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)\n for i in range(self.c1):\n kernel_value[i, i % input_dim, 1, 1] = 1\n self.id_tensor = 
torch.from_numpy(kernel_value).to(branch.weight.device)\n kernel = self.id_tensor\n running_mean = branch.running_mean\n running_var = branch.running_var\n gamma = branch.weight\n beta = branch.bias\n eps = branch.eps\n std = (running_var + eps).sqrt()\n t = (gamma / std).reshape(-1, 1, 1, 1)\n return kernel * t, beta - running_mean * gamma / std\n\n def fuse_convs(self):\n \"\"\"Fuse convolutions for inference by creating a single equivalent convolution.\"\"\"\n if hasattr(self, \"conv\"):\n return\n kernel, bias = self.get_equivalent_kernel_bias()\n self.conv = nn.Conv2d(\n in_channels=self.conv1.conv.in_channels,\n out_channels=self.conv1.conv.out_channels,\n kernel_size=self.conv1.conv.kernel_size,\n stride=self.conv1.conv.stride,\n padding=self.conv1.conv.padding,\n dilation=self.conv1.conv.dilation,\n groups=self.conv1.conv.groups,\n bias=True,\n ).requires_grad_(False)\n self.conv.weight.data = kernel\n self.conv.bias.data = bias\n for para in self.parameters():\n para.detach_()\n self.__delattr__(\"conv1\")\n self.__delattr__(\"conv2\")\n if hasattr(self, \"nm\"):\n self.__delattr__(\"nm\")\n if hasattr(self, \"bn\"):\n self.__delattr__(\"bn\")\n if hasattr(self, \"id_tensor\"):\n self.__delattr__(\"id_tensor\")",
"chunk_type": "class",
"name": "RepConv",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 375,
"end_line": 538,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": "RepConv module with training and deploy modes.\n\nThis module is used in RT-DETR and can fuse convolutions during inference for efficiency.\n\nAttributes:\n conv1 (Conv): 3x3 convolution.\n conv2 (Conv): 1x1 convolution.\n bn (nn.BatchNorm2d, optional): Batch normalization for identity branch.\n act (nn.Module): Activation function.\n default_act (nn.Module): Default activation function (SiLU).\n\nReferences:\n https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"numpy",
"torch",
"torch.nn",
"nn.Module"
],
"chunk_id": "class_RepConv_d1e80245"
},
{
"content": "class ChannelAttention(nn.Module):\n \"\"\"\n Channel-attention module for feature recalibration.\n\n Applies attention weights to channels based on global average pooling.\n\n Attributes:\n pool (nn.AdaptiveAvgPool2d): Global average pooling.\n fc (nn.Conv2d): Fully connected layer implemented as 1x1 convolution.\n act (nn.Sigmoid): Sigmoid activation for attention weights.\n\n References:\n https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet\n \"\"\"\n\n def __init__(self, channels: int) -> None:\n \"\"\"\n Initialize Channel-attention module.\n\n Args:\n channels (int): Number of input channels.\n \"\"\"\n super().__init__()\n self.pool = nn.AdaptiveAvgPool2d(1)\n self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)\n self.act = nn.Sigmoid()\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Apply channel attention to input tensor.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Channel-attended output tensor.\n \"\"\"\n return x * self.act(self.fc(self.pool(x)))",
"chunk_type": "class",
"name": "ChannelAttention",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 541,
"end_line": 578,
"start_col": 0,
"end_col": 50,
"parent_name": null,
"docstring": "Channel-attention module for feature recalibration.\n\nApplies attention weights to channels based on global average pooling.\n\nAttributes:\n pool (nn.AdaptiveAvgPool2d): Global average pooling.\n fc (nn.Conv2d): Fully connected layer implemented as 1x1 convolution.\n act (nn.Sigmoid): Sigmoid activation for attention weights.\n\nReferences:\n https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"numpy",
"torch",
"torch.nn",
"nn.Module"
],
"chunk_id": "class_ChannelAttention_0d43909d"
},
{
"content": "class SpatialAttention(nn.Module):\n \"\"\"\n Spatial-attention module for feature recalibration.\n\n Applies attention weights to spatial dimensions based on channel statistics.\n\n Attributes:\n cv1 (nn.Conv2d): Convolution layer for spatial attention.\n act (nn.Sigmoid): Sigmoid activation for attention weights.\n \"\"\"\n\n def __init__(self, kernel_size=7):\n \"\"\"\n Initialize Spatial-attention module.\n\n Args:\n kernel_size (int): Size of the convolutional kernel (3 or 7).\n \"\"\"\n super().__init__()\n assert kernel_size in {3, 7}, \"kernel size must be 3 or 7\"\n padding = 3 if kernel_size == 7 else 1\n self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)\n self.act = nn.Sigmoid()\n\n def forward(self, x):\n \"\"\"\n Apply spatial attention to input tensor.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Spatial-attended output tensor.\n \"\"\"\n return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))",
"chunk_type": "class",
"name": "SpatialAttention",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 581,
"end_line": 615,
"start_col": 0,
"end_col": 119,
"parent_name": null,
"docstring": "Spatial-attention module for feature recalibration.\n\nApplies attention weights to spatial dimensions based on channel statistics.\n\nAttributes:\n cv1 (nn.Conv2d): Convolution layer for spatial attention.\n act (nn.Sigmoid): Sigmoid activation for attention weights.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"numpy",
"torch",
"torch.nn",
"nn.Module"
],
"chunk_id": "class_SpatialAttention_57c216fb"
},
{
"content": "class CBAM(nn.Module):\n \"\"\"\n Convolutional Block Attention Module.\n\n Combines channel and spatial attention mechanisms for comprehensive feature refinement.\n\n Attributes:\n channel_attention (ChannelAttention): Channel attention module.\n spatial_attention (SpatialAttention): Spatial attention module.\n \"\"\"\n\n def __init__(self, c1, kernel_size=7):\n \"\"\"\n Initialize CBAM with given parameters.\n\n Args:\n c1 (int): Number of input channels.\n kernel_size (int): Size of the convolutional kernel for spatial attention.\n \"\"\"\n super().__init__()\n self.channel_attention = ChannelAttention(c1)\n self.spatial_attention = SpatialAttention(kernel_size)\n\n def forward(self, x):\n \"\"\"\n Apply channel and spatial attention sequentially to input tensor.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Attended output tensor.\n \"\"\"\n return self.spatial_attention(self.channel_attention(x))",
"chunk_type": "class",
"name": "CBAM",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 618,
"end_line": 651,
"start_col": 0,
"end_col": 64,
"parent_name": null,
"docstring": "Convolutional Block Attention Module.\n\nCombines channel and spatial attention mechanisms for comprehensive feature refinement.\n\nAttributes:\n channel_attention (ChannelAttention): Channel attention module.\n spatial_attention (SpatialAttention): Spatial attention module.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"numpy",
"torch",
"torch.nn",
"nn.Module"
],
"chunk_id": "class_CBAM_174201a4"
},
{
"content": "class Concat(nn.Module):\n \"\"\"\n Concatenate a list of tensors along specified dimension.\n\n Attributes:\n d (int): Dimension along which to concatenate tensors.\n \"\"\"\n\n def __init__(self, dimension=1):\n \"\"\"\n Initialize Concat module.\n\n Args:\n dimension (int): Dimension along which to concatenate tensors.\n \"\"\"\n super().__init__()\n self.d = dimension\n\n def forward(self, x: List[torch.Tensor]):\n \"\"\"\n Concatenate input tensors along specified dimension.\n\n Args:\n x (List[torch.Tensor]): List of input tensors.\n\n Returns:\n (torch.Tensor): Concatenated tensor.\n \"\"\"\n return torch.cat(x, self.d)",
"chunk_type": "class",
"name": "Concat",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 654,
"end_line": 682,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": "Concatenate a list of tensors along specified dimension.\n\nAttributes:\n d (int): Dimension along which to concatenate tensors.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"numpy",
"torch",
"torch.nn",
"nn.Module"
],
"chunk_id": "class_Concat_89f2ccf2"
},
{
"content": "class Index(nn.Module):\n \"\"\"\n Returns a particular index of the input.\n\n Attributes:\n index (int): Index to select from input.\n \"\"\"\n\n def __init__(self, index=0):\n \"\"\"\n Initialize Index module.\n\n Args:\n index (int): Index to select from input.\n \"\"\"\n super().__init__()\n self.index = index\n\n def forward(self, x: List[torch.Tensor]):\n \"\"\"\n Select and return a particular index from input.\n\n Args:\n x (List[torch.Tensor]): List of input tensors.\n\n Returns:\n (torch.Tensor): Selected tensor.\n \"\"\"\n return x[self.index]",
"chunk_type": "class",
"name": "Index",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\conv.py",
"start_line": 685,
"end_line": 713,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": "Returns a particular index of the input.\n\nAttributes:\n index (int): Index to select from input.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"numpy",
"torch",
"torch.nn",
"nn.Module"
],
"chunk_id": "class_Index_72f4bfca"
},
{
"content": "import copy",
"chunk_type": "import",
"name": "copy",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_copy_db5e079d"
},
{
"content": "import math",
"chunk_type": "import",
"name": "math",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_math_3d4ed46a"
},
{
"content": "from typing import List, Optional, Tuple, Union",
"chunk_type": "import",
"name": "List, Optional, Tuple, Union",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_List, Optional, Tuple, Union_7be0b002"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_2a9f0f53"
},
{
"content": "import torch.nn as nn",
"chunk_type": "import",
"name": "torch.nn",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn_51b0a839"
},
{
"content": "import torch.nn.functional as F",
"chunk_type": "import",
"name": "torch.nn.functional",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn.functional_a2c8732e"
},
{
"content": "from torch.nn.init import constant_, xavier_uniform_",
"chunk_type": "import",
"name": "constant_, xavier_uniform_",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 52,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_constant_, xavier_uniform__dd452c3d"
},
{
"content": "from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors",
"chunk_type": "import",
"name": "TORCH_1_10, dist2bbox, dist2rbox, make_anchors",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 80,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_TORCH_1_10, dist2bbox, dist2rbox, make_anchors_8e915633"
},
{
"content": "from ultralytics.utils.torch_utils import fuse_conv_and_bn, smart_inference_mode",
"chunk_type": "import",
"name": "fuse_conv_and_bn, smart_inference_mode",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 14,
"end_line": 14,
"start_col": 0,
"end_col": 80,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_fuse_conv_and_bn, smart_inference_mode_5e4099c9"
},
{
"content": "from .block import DFL, SAVPE, BNContrastiveHead, ContrastiveHead, Proto, Residual, SwiGLUFFN",
"chunk_type": "import",
"name": "DFL, SAVPE, BNContrastiveHead, ContrastiveHead, Proto, Residual, SwiGLUFFN",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 16,
"end_line": 16,
"start_col": 0,
"end_col": 93,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_DFL, SAVPE, BNContrastiveHead, ContrastiveHead, Proto, Residual, SwiGLUFFN_7e717db0"
},
{
"content": "from .conv import Conv, DWConv",
"chunk_type": "import",
"name": "Conv, DWConv",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 17,
"end_line": 17,
"start_col": 0,
"end_col": 30,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Conv, DWConv_163d563d"
},
{
"content": "from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer",
"chunk_type": "import",
"name": "MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 18,
"end_line": 18,
"start_col": 0,
"end_col": 93,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer_e8a853e2"
},
{
"content": "from .utils import bias_init_with_prob, linear_init",
"chunk_type": "import",
"name": "bias_init_with_prob, linear_init",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 19,
"end_line": 19,
"start_col": 0,
"end_col": 51,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_bias_init_with_prob, linear_init_f82d045a"
},
{
"content": "__all__ = \"Detect\", \"Segment\", \"Pose\", \"Classify\", \"OBB\", \"RTDETRDecoder\", \"v10Detect\", \"YOLOEDetect\", \"YOLOESegment\"",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 21,
"end_line": 21,
"start_col": 0,
"end_col": 117,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___04011789"
},
{
"content": "class Detect(nn.Module):\n \"\"\"\n YOLO Detect head for object detection models.\n\n This class implements the detection head used in YOLO models for predicting bounding boxes and class probabilities.\n It supports both training and inference modes, with optional end-to-end detection capabilities.\n\n Attributes:\n dynamic (bool): Force grid reconstruction.\n export (bool): Export mode flag.\n format (str): Export format.\n end2end (bool): End-to-end detection mode.\n max_det (int): Maximum detections per image.\n shape (tuple): Input shape.\n anchors (torch.Tensor): Anchor points.\n strides (torch.Tensor): Feature map strides.\n legacy (bool): Backward compatibility for v3/v5/v8/v9 models.\n xyxy (bool): Output format, xyxy or xywh.\n nc (int): Number of classes.\n nl (int): Number of detection layers.\n reg_max (int): DFL channels.\n no (int): Number of outputs per anchor.\n stride (torch.Tensor): Strides computed during build.\n cv2 (nn.ModuleList): Convolution layers for box regression.\n cv3 (nn.ModuleList): Convolution layers for classification.\n dfl (nn.Module): Distribution Focal Loss layer.\n one2one_cv2 (nn.ModuleList): One-to-one convolution layers for box regression.\n one2one_cv3 (nn.ModuleList): One-to-one convolution layers for classification.\n\n Methods:\n forward: Perform forward pass and return predictions.\n forward_end2end: Perform forward pass for end-to-end detection.\n bias_init: Initialize detection head biases.\n decode_bboxes: Decode bounding boxes from predictions.\n postprocess: Post-process model predictions.\n\n Examples:\n Create a detection head for 80 classes\n >>> detect = Detect(nc=80, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = detect(x)\n \"\"\"\n\n dynamic = False # force grid reconstruction\n export = False # export mode\n format = None # export format\n end2end = False # end2end\n max_det = 300 # max_det\n shape = None\n anchors = torch.empty(0) # init\n strides = torch.empty(0) # init\n legacy = False # backward compatibility for v3/v5/v8/v9 models\n xyxy = False # xyxy or xywh output\n\n def __init__(self, nc: int = 80, ch: Tuple = ()):\n \"\"\"\n Initialize the YOLO detection layer with specified number of classes and channels.\n\n Args:\n nc (int): Number of classes.\n ch (tuple): Tuple of channel sizes from backbone feature maps.\n \"\"\"\n super().__init__()\n self.nc = nc # number of classes\n self.nl = len(ch) # number of detection layers\n self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)\n self.no = nc + self.reg_max * 4 # number of outputs per anchor\n self.stride = torch.zeros(self.nl) # strides computed during build\n c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels\n self.cv2 = nn.ModuleList(\n nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch\n )\n self.cv3 = (\n nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)\n if self.legacy\n else nn.ModuleList(\n nn.Sequential(\n nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),\n nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),\n nn.Conv2d(c3, self.nc, 1),\n )\n for x in ch\n )\n )\n self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()\n\n if self.end2end:\n self.one2one_cv2 = copy.deepcopy(self.cv2)\n self.one2one_cv3 = copy.deepcopy(self.cv3)\n\n def forward(self, x: List[torch.Tensor]) -> 
Union[List[torch.Tensor], Tuple]:\n \"\"\"Concatenate and return predicted bounding boxes and class probabilities.\"\"\"\n if self.end2end:\n return self.forward_end2end(x)\n\n for i in range(self.nl):\n x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)\n if self.training: # Training path\n return x\n y = self._inference(x)\n return y if self.export else (y, x)\n\n def forward_end2end(self, x: List[torch.Tensor]) -> Union[dict, Tuple]:\n \"\"\"\n Perform forward pass of the v10Detect module.\n\n Args:\n x (List[torch.Tensor]): Input feature maps from different levels.\n\n Returns:\n outputs (dict | tuple): Training mode returns dict with one2many and one2one outputs.\n Inference mode returns processed detections or tuple with detections and raw outputs.\n \"\"\"\n x_detach = [xi.detach() for xi in x]\n one2one = [\n torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)\n ]\n for i in range(self.nl):\n x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)\n if self.training: # Training path\n return {\"one2many\": x, \"one2one\": one2one}\n\n y = self._inference(one2one)\n y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)\n return y if self.export else (y, {\"one2many\": x, \"one2one\": one2one})\n\n def _inference(self, x: List[torch.Tensor]) -> torch.Tensor:\n \"\"\"\n Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.\n\n Args:\n x (List[torch.Tensor]): List of feature maps from different detection layers.\n\n Returns:\n (torch.Tensor): Concatenated tensor of decoded bounding boxes and class probabilities.\n \"\"\"\n # Inference path\n shape = x[0].shape # BCHW\n x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)\n if self.format != \"imx\" and (self.dynamic or self.shape != shape):\n self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))\n self.shape = shape\n\n if self.export and self.format in {\"saved_model\", \"pb\", \"tflite\", \"edgetpu\", \"tfjs\"}: # avoid TF FlexSplitV ops\n box = x_cat[:, : self.reg_max * 4]\n cls = x_cat[:, self.reg_max * 4 :]\n else:\n box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)\n\n if self.export and self.format in {\"tflite\", \"edgetpu\"}:\n # Precompute normalization factor to increase numerical stability\n # See https://github.com/ultralytics/ultralytics/issues/7371\n grid_h = shape[2]\n grid_w = shape[3]\n grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)\n norm = self.strides / (self.stride[0] * grid_size)\n dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])\n else:\n dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides\n if self.export and self.format == \"imx\":\n return dbox.transpose(1, 2), cls.sigmoid().permute(0, 2, 1)\n return torch.cat((dbox, cls.sigmoid()), 1)\n\n def bias_init(self):\n \"\"\"Initialize Detect() biases, WARNING: requires stride availability.\"\"\"\n m = self # self.model[-1] # Detect() module\n # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1\n # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency\n for a, b, s in zip(m.cv2, m.cv3, m.stride): # from\n a[-1].bias.data[:] = 1.0 # box\n b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)\n if self.end2end:\n for a, b, s in 
zip(m.one2one_cv2, m.one2one_cv3, m.stride): # from\n a[-1].bias.data[:] = 1.0 # box\n b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)\n\n def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor, xywh: bool = True) -> torch.Tensor:\n \"\"\"Decode bounding boxes from predictions.\"\"\"\n return dist2bbox(bboxes, anchors, xywh=xywh and not (self.end2end or self.xyxy), dim=1)\n\n @staticmethod\n def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor:\n \"\"\"\n Post-process YOLO model predictions.\n\n Args:\n preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension\n format [x, y, w, h, class_probs].\n max_det (int): Maximum detections per image.\n nc (int, optional): Number of classes.\n\n Returns:\n (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last\n dimension format [x, y, w, h, max_class_prob, class_index].\n \"\"\"\n batch_size, anchors, _ = preds.shape # i.e. shape(16,8400,84)\n boxes, scores = preds.split([4, nc], dim=-1)\n index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)\n boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))\n scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))\n scores, index = scores.flatten(1).topk(min(max_det, anchors))\n i = torch.arange(batch_size)[..., None] # batch indices\n return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)",
"chunk_type": "class",
"name": "Detect",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 24,
"end_line": 226,
"start_col": 0,
"end_col": 109,
"parent_name": null,
"docstring": "YOLO Detect head for object detection models.\n\nThis class implements the detection head used in YOLO models for predicting bounding boxes and class probabilities.\nIt supports both training and inference modes, with optional end-to-end detection capabilities.\n\nAttributes:\n dynamic (bool): Force grid reconstruction.\n export (bool): Export mode flag.\n format (str): Export format.\n end2end (bool): End-to-end detection mode.\n max_det (int): Maximum detections per image.\n shape (tuple): Input shape.\n anchors (torch.Tensor): Anchor points.\n strides (torch.Tensor): Feature map strides.\n legacy (bool): Backward compatibility for v3/v5/v8/v9 models.\n xyxy (bool): Output format, xyxy or xywh.\n nc (int): Number of classes.\n nl (int): Number of detection layers.\n reg_max (int): DFL channels.\n no (int): Number of outputs per anchor.\n stride (torch.Tensor): Strides computed during build.\n cv2 (nn.ModuleList): Convolution layers for box regression.\n cv3 (nn.ModuleList): Convolution layers for classification.\n dfl (nn.Module): Distribution Focal Loss layer.\n one2one_cv2 (nn.ModuleList): One-to-one convolution layers for box regression.\n one2one_cv3 (nn.ModuleList): One-to-one convolution layers for classification.\n\nMethods:\n forward: Perform forward pass and return predictions.\n forward_end2end: Perform forward pass for end-to-end detection.\n bias_init: Initialize detection head biases.\n decode_bboxes: Decode bounding boxes from predictions.\n postprocess: Post-process model predictions.\n\nExamples:\n Create a detection head for 80 classes\n >>> detect = Detect(nc=80, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = detect(x)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"ultralytics.utils.tal.TORCH_1_10",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"ultralytics.utils.torch_utils.smart_inference_mode",
"block.DFL",
"block.SAVPE",
"block.BNContrastiveHead",
"block.ContrastiveHead",
"block.Proto",
"block.Residual",
"block.SwiGLUFFN",
"conv.Conv",
"conv.DWConv",
"transformer.MLP",
"transformer.DeformableTransformerDecoder",
"transformer.DeformableTransformerDecoderLayer",
"utils.bias_init_with_prob",
"utils.linear_init",
"ultralytics.models.utils.ops.get_cdn_group",
"nn.Module"
],
"chunk_id": "class_Detect_72afa236"
},
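The Detect.postprocess method recorded in this chunk selects the top-scoring (anchor, class) pairs from raw predictions. Below is a minimal standalone sketch of that selection using only torch; the helper name postprocess_sketch and the input shapes are illustrative, following the docstring example of 8400 anchors and 80 classes.

import torch

def postprocess_sketch(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor:
    # preds: (batch, num_anchors, 4 + nc), last dim is [x, y, w, h, class_probs]
    batch_size, anchors, _ = preds.shape
    boxes, scores = preds.split([4, nc], dim=-1)
    # Keep the anchors with the highest per-anchor class confidence
    index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
    boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
    scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
    # Flatten (anchor, class) pairs and take the global top-k scores
    scores, index = scores.flatten(1).topk(min(max_det, anchors))
    i = torch.arange(batch_size)[..., None]  # batch indices
    return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)

preds = torch.rand(2, 8400, 84)   # 8400 anchors, 80 classes
out = postprocess_sketch(preds, max_det=300)
print(out.shape)                  # torch.Size([2, 300, 6]) -> [x, y, w, h, score, class]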
{
"content": "class Segment(Detect):\n \"\"\"\n YOLO Segment head for segmentation models.\n\n This class extends the Detect head to include mask prediction capabilities for instance segmentation tasks.\n\n Attributes:\n nm (int): Number of masks.\n npr (int): Number of protos.\n proto (Proto): Prototype generation module.\n cv4 (nn.ModuleList): Convolution layers for mask coefficients.\n\n Methods:\n forward: Return model outputs and mask coefficients.\n\n Examples:\n Create a segmentation head\n >>> segment = Segment(nc=80, nm=32, npr=256, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = segment(x)\n \"\"\"\n\n def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, ch: Tuple = ()):\n \"\"\"\n Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.\n\n Args:\n nc (int): Number of classes.\n nm (int): Number of masks.\n npr (int): Number of protos.\n ch (tuple): Tuple of channel sizes from backbone feature maps.\n \"\"\"\n super().__init__(nc, ch)\n self.nm = nm # number of masks\n self.npr = npr # number of protos\n self.proto = Proto(ch[0], self.npr, self.nm) # protos\n\n c4 = max(ch[0] // 4, self.nm)\n self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)\n\n def forward(self, x: List[torch.Tensor]) -> Union[Tuple, List[torch.Tensor]]:\n \"\"\"Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients.\"\"\"\n p = self.proto(x[0]) # mask protos\n bs = p.shape[0] # batch size\n\n mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients\n x = Detect.forward(self, x)\n if self.training:\n return x, mc, p\n return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))",
"chunk_type": "class",
"name": "Segment",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 229,
"end_line": 278,
"start_col": 0,
"end_col": 103,
"parent_name": null,
"docstring": "YOLO Segment head for segmentation models.\n\nThis class extends the Detect head to include mask prediction capabilities for instance segmentation tasks.\n\nAttributes:\n nm (int): Number of masks.\n npr (int): Number of protos.\n proto (Proto): Prototype generation module.\n cv4 (nn.ModuleList): Convolution layers for mask coefficients.\n\nMethods:\n forward: Return model outputs and mask coefficients.\n\nExamples:\n Create a segmentation head\n >>> segment = Segment(nc=80, nm=32, npr=256, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = segment(x)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"ultralytics.utils.tal.TORCH_1_10",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"ultralytics.utils.torch_utils.smart_inference_mode",
"block.DFL",
"block.SAVPE",
"block.BNContrastiveHead",
"block.ContrastiveHead",
"block.Proto",
"block.Residual",
"block.SwiGLUFFN",
"conv.Conv",
"conv.DWConv",
"transformer.MLP",
"transformer.DeformableTransformerDecoder",
"transformer.DeformableTransformerDecoderLayer",
"utils.bias_init_with_prob",
"utils.linear_init",
"ultralytics.models.utils.ops.get_cdn_group",
"Detect"
],
"chunk_id": "class_Segment_477a5588"
},
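A minimal sketch of exercising the Segment head in training mode, assuming the ultralytics package is installed so that ultralytics.nn.modules.head (this chunk's file_path) is importable; input shapes follow the docstring example.

import torch
from ultralytics.nn.modules.head import Segment

head = Segment(nc=80, nm=32, npr=256, ch=(256, 512, 1024))
head.train()  # training mode returns (detections, mask coefficients, prototypes)
x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
dets, mask_coeffs, protos = head(x)
print(mask_coeffs.shape, protos.shape)  # (1, 32, 8400) and (1, 32, 160, 160)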
{
"content": "class OBB(Detect):\n \"\"\"\n YOLO OBB detection head for detection with rotation models.\n\n This class extends the Detect head to include oriented bounding box prediction with rotation angles.\n\n Attributes:\n ne (int): Number of extra parameters.\n cv4 (nn.ModuleList): Convolution layers for angle prediction.\n angle (torch.Tensor): Predicted rotation angles.\n\n Methods:\n forward: Concatenate and return predicted bounding boxes and class probabilities.\n decode_bboxes: Decode rotated bounding boxes.\n\n Examples:\n Create an OBB detection head\n >>> obb = OBB(nc=80, ne=1, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = obb(x)\n \"\"\"\n\n def __init__(self, nc: int = 80, ne: int = 1, ch: Tuple = ()):\n \"\"\"\n Initialize OBB with number of classes `nc` and layer channels `ch`.\n\n Args:\n nc (int): Number of classes.\n ne (int): Number of extra parameters.\n ch (tuple): Tuple of channel sizes from backbone feature maps.\n \"\"\"\n super().__init__(nc, ch)\n self.ne = ne # number of extra parameters\n\n c4 = max(ch[0] // 4, self.ne)\n self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)\n\n def forward(self, x: List[torch.Tensor]) -> Union[torch.Tensor, Tuple]:\n \"\"\"Concatenate and return predicted bounding boxes and class probabilities.\"\"\"\n bs = x[0].shape[0] # batch size\n angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits\n # NOTE: set `angle` as an attribute so that `decode_bboxes` could use it.\n angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4]\n # angle = angle.sigmoid() * math.pi / 2 # [0, pi/2]\n if not self.training:\n self.angle = angle\n x = Detect.forward(self, x)\n if self.training:\n return x, angle\n return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))\n\n def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:\n \"\"\"Decode rotated bounding boxes.\"\"\"\n return dist2rbox(bboxes, self.angle, anchors, dim=1)",
"chunk_type": "class",
"name": "OBB",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 281,
"end_line": 334,
"start_col": 0,
"end_col": 60,
"parent_name": null,
"docstring": "YOLO OBB detection head for detection with rotation models.\n\nThis class extends the Detect head to include oriented bounding box prediction with rotation angles.\n\nAttributes:\n ne (int): Number of extra parameters.\n cv4 (nn.ModuleList): Convolution layers for angle prediction.\n angle (torch.Tensor): Predicted rotation angles.\n\nMethods:\n forward: Concatenate and return predicted bounding boxes and class probabilities.\n decode_bboxes: Decode rotated bounding boxes.\n\nExamples:\n Create an OBB detection head\n >>> obb = OBB(nc=80, ne=1, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = obb(x)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"ultralytics.utils.tal.TORCH_1_10",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"ultralytics.utils.torch_utils.smart_inference_mode",
"block.DFL",
"block.SAVPE",
"block.BNContrastiveHead",
"block.ContrastiveHead",
"block.Proto",
"block.Residual",
"block.SwiGLUFFN",
"conv.Conv",
"conv.DWConv",
"transformer.MLP",
"transformer.DeformableTransformerDecoder",
"transformer.DeformableTransformerDecoderLayer",
"utils.bias_init_with_prob",
"utils.linear_init",
"ultralytics.models.utils.ops.get_cdn_group",
"Detect"
],
"chunk_id": "class_OBB_85f3b3c3"
},
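The OBB head maps raw theta logits to rotation angles with (sigmoid(angle) - 0.25) * pi, giving a range of roughly [-pi/4, 3pi/4). A tiny self-contained check of that activation, with arbitrary logit values:

import math
import torch

logits = torch.linspace(-8.0, 8.0, steps=5)   # arbitrary raw theta logits
angle = (logits.sigmoid() - 0.25) * math.pi   # same activation as in OBB.forward
print(angle / math.pi)                        # sweeps from about -0.25*pi up to about 0.75*pi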
{
"content": "class Pose(Detect):\n \"\"\"\n YOLO Pose head for keypoints models.\n\n This class extends the Detect head to include keypoint prediction capabilities for pose estimation tasks.\n\n Attributes:\n kpt_shape (tuple): Number of keypoints and dimensions (2 for x,y or 3 for x,y,visible).\n nk (int): Total number of keypoint values.\n cv4 (nn.ModuleList): Convolution layers for keypoint prediction.\n\n Methods:\n forward: Perform forward pass through YOLO model and return predictions.\n kpts_decode: Decode keypoints from predictions.\n\n Examples:\n Create a pose detection head\n >>> pose = Pose(nc=80, kpt_shape=(17, 3), ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = pose(x)\n \"\"\"\n\n def __init__(self, nc: int = 80, kpt_shape: Tuple = (17, 3), ch: Tuple = ()):\n \"\"\"\n Initialize YOLO network with default parameters and Convolutional Layers.\n\n Args:\n nc (int): Number of classes.\n kpt_shape (tuple): Number of keypoints, number of dims (2 for x,y or 3 for x,y,visible).\n ch (tuple): Tuple of channel sizes from backbone feature maps.\n \"\"\"\n super().__init__(nc, ch)\n self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)\n self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total\n\n c4 = max(ch[0] // 4, self.nk)\n self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)\n\n def forward(self, x: List[torch.Tensor]) -> Union[torch.Tensor, Tuple]:\n \"\"\"Perform forward pass through YOLO model and return predictions.\"\"\"\n bs = x[0].shape[0] # batch size\n kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)\n x = Detect.forward(self, x)\n if self.training:\n return x, kpt\n pred_kpt = self.kpts_decode(bs, kpt)\n if self.export and self.format == \"imx\":\n return (*x, pred_kpt.permute(0, 2, 1))\n return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))\n\n def kpts_decode(self, bs: int, kpts: torch.Tensor) -> torch.Tensor:\n \"\"\"Decode keypoints from predictions.\"\"\"\n ndim = self.kpt_shape[1]\n if self.export:\n if self.format in {\n \"tflite\",\n \"edgetpu\",\n }: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug\n # Precompute normalization factor to increase numerical stability\n y = kpts.view(bs, *self.kpt_shape, -1)\n grid_h, grid_w = self.shape[2], self.shape[3]\n grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)\n norm = self.strides / (self.stride[0] * grid_size)\n a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm\n else:\n # NCNN fix\n y = kpts.view(bs, *self.kpt_shape, -1)\n a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides\n if ndim == 3:\n a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)\n return a.view(bs, self.nk, -1)\n else:\n y = kpts.clone()\n if ndim == 3:\n y[:, 2::ndim] = y[:, 2::ndim].sigmoid() # sigmoid (WARNING: inplace .sigmoid_() Apple MPS bug)\n y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides\n y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides\n return y",
"chunk_type": "class",
"name": "Pose",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 337,
"end_line": 414,
"start_col": 0,
"end_col": 20,
"parent_name": null,
"docstring": "YOLO Pose head for keypoints models.\n\nThis class extends the Detect head to include keypoint prediction capabilities for pose estimation tasks.\n\nAttributes:\n kpt_shape (tuple): Number of keypoints and dimensions (2 for x,y or 3 for x,y,visible).\n nk (int): Total number of keypoint values.\n cv4 (nn.ModuleList): Convolution layers for keypoint prediction.\n\nMethods:\n forward: Perform forward pass through YOLO model and return predictions.\n kpts_decode: Decode keypoints from predictions.\n\nExamples:\n Create a pose detection head\n >>> pose = Pose(nc=80, kpt_shape=(17, 3), ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = pose(x)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"ultralytics.utils.tal.TORCH_1_10",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"ultralytics.utils.torch_utils.smart_inference_mode",
"block.DFL",
"block.SAVPE",
"block.BNContrastiveHead",
"block.ContrastiveHead",
"block.Proto",
"block.Residual",
"block.SwiGLUFFN",
"conv.Conv",
"conv.DWConv",
"transformer.MLP",
"transformer.DeformableTransformerDecoder",
"transformer.DeformableTransformerDecoderLayer",
"utils.bias_init_with_prob",
"utils.linear_init",
"ultralytics.models.utils.ops.get_cdn_group",
"Detect"
],
"chunk_id": "class_Pose_852b4e1a"
},
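A toy sketch of the non-export branch of kpts_decode shown above: x/y channels are scaled by 2, offset by the anchor centres, and multiplied by the per-anchor stride, while the visibility channel only receives a sigmoid. The anchor and stride values below are made up purely for illustration.

import torch

nk, ndim, na = 17 * 3, 3, 4                     # 17 keypoints x (x, y, vis), 4 anchor positions
kpts = torch.randn(1, nk, na)                   # raw keypoint predictions
anchors = torch.rand(2, na)                     # hypothetical anchor centres (x row, y row)
strides = torch.tensor([8.0, 8.0, 16.0, 32.0])  # hypothetical per-anchor strides

y = kpts.clone()
y[:, 2::ndim] = y[:, 2::ndim].sigmoid()                               # visibility scores
y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (anchors[0] - 0.5)) * strides  # x in image space
y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (anchors[1] - 0.5)) * strides  # y in image space
print(y.shape)  # torch.Size([1, 51, 4])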
{
"content": "class Classify(nn.Module):\n \"\"\"\n YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).\n\n This class implements a classification head that transforms feature maps into class predictions.\n\n Attributes:\n export (bool): Export mode flag.\n conv (Conv): Convolutional layer for feature transformation.\n pool (nn.AdaptiveAvgPool2d): Global average pooling layer.\n drop (nn.Dropout): Dropout layer for regularization.\n linear (nn.Linear): Linear layer for final classification.\n\n Methods:\n forward: Perform forward pass of the YOLO model on input image data.\n\n Examples:\n Create a classification head\n >>> classify = Classify(c1=1024, c2=1000)\n >>> x = torch.randn(1, 1024, 20, 20)\n >>> output = classify(x)\n \"\"\"\n\n export = False # export mode\n\n def __init__(self, c1: int, c2: int, k: int = 1, s: int = 1, p: Optional[int] = None, g: int = 1):\n \"\"\"\n Initialize YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.\n\n Args:\n c1 (int): Number of input channels.\n c2 (int): Number of output classes.\n k (int, optional): Kernel size.\n s (int, optional): Stride.\n p (int, optional): Padding.\n g (int, optional): Groups.\n \"\"\"\n super().__init__()\n c_ = 1280 # efficientnet_b0 size\n self.conv = Conv(c1, c_, k, s, p, g)\n self.pool = nn.AdaptiveAvgPool2d(1) # to x(b,c_,1,1)\n self.drop = nn.Dropout(p=0.0, inplace=True)\n self.linear = nn.Linear(c_, c2) # to x(b,c2)\n\n def forward(self, x: Union[List[torch.Tensor], torch.Tensor]) -> Union[torch.Tensor, Tuple]:\n \"\"\"Perform forward pass of the YOLO model on input image data.\"\"\"\n if isinstance(x, list):\n x = torch.cat(x, 1)\n x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))\n if self.training:\n return x\n y = x.softmax(1) # get final output\n return y if self.export else (y, x)",
"chunk_type": "class",
"name": "Classify",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 417,
"end_line": 469,
"start_col": 0,
"end_col": 43,
"parent_name": null,
"docstring": "YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).\n\nThis class implements a classification head that transforms feature maps into class predictions.\n\nAttributes:\n export (bool): Export mode flag.\n conv (Conv): Convolutional layer for feature transformation.\n pool (nn.AdaptiveAvgPool2d): Global average pooling layer.\n drop (nn.Dropout): Dropout layer for regularization.\n linear (nn.Linear): Linear layer for final classification.\n\nMethods:\n forward: Perform forward pass of the YOLO model on input image data.\n\nExamples:\n Create a classification head\n >>> classify = Classify(c1=1024, c2=1000)\n >>> x = torch.randn(1, 1024, 20, 20)\n >>> output = classify(x)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"ultralytics.utils.tal.TORCH_1_10",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"ultralytics.utils.torch_utils.smart_inference_mode",
"block.DFL",
"block.SAVPE",
"block.BNContrastiveHead",
"block.ContrastiveHead",
"block.Proto",
"block.Residual",
"block.SwiGLUFFN",
"conv.Conv",
"conv.DWConv",
"transformer.MLP",
"transformer.DeformableTransformerDecoder",
"transformer.DeformableTransformerDecoderLayer",
"utils.bias_init_with_prob",
"utils.linear_init",
"ultralytics.models.utils.ops.get_cdn_group",
"nn.Module"
],
"chunk_id": "class_Classify_2697ab2b"
},
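A minimal sketch of running the Classify head in eval mode, assuming the ultralytics package is installed; with export disabled it returns both the softmax probabilities and the raw logits, as the forward method above shows.

import torch
from ultralytics.nn.modules.head import Classify

head = Classify(c1=1024, c2=1000).eval()
x = torch.randn(2, 1024, 20, 20)
probs, logits = head(x)                         # eval mode returns (softmax, raw logits)
print(probs.shape, float(probs.sum(dim=1)[0]))  # torch.Size([2, 1000]) and ~1.0 per row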
{
"content": "class WorldDetect(Detect):\n \"\"\"\n Head for integrating YOLO detection models with semantic understanding from text embeddings.\n\n This class extends the standard Detect head to incorporate text embeddings for enhanced semantic understanding\n in object detection tasks.\n\n Attributes:\n cv3 (nn.ModuleList): Convolution layers for embedding features.\n cv4 (nn.ModuleList): Contrastive head layers for text-vision alignment.\n\n Methods:\n forward: Concatenate and return predicted bounding boxes and class probabilities.\n bias_init: Initialize detection head biases.\n\n Examples:\n Create a WorldDetect head\n >>> world_detect = WorldDetect(nc=80, embed=512, with_bn=False, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> text = torch.randn(1, 80, 512)\n >>> outputs = world_detect(x, text)\n \"\"\"\n\n def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: Tuple = ()):\n \"\"\"\n Initialize YOLO detection layer with nc classes and layer channels ch.\n\n Args:\n nc (int): Number of classes.\n embed (int): Embedding dimension.\n with_bn (bool): Whether to use batch normalization in contrastive head.\n ch (tuple): Tuple of channel sizes from backbone feature maps.\n \"\"\"\n super().__init__(nc, ch)\n c3 = max(ch[0], min(self.nc, 100))\n self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)\n self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)\n\n def forward(self, x: List[torch.Tensor], text: torch.Tensor) -> Union[List[torch.Tensor], Tuple]:\n \"\"\"Concatenate and return predicted bounding boxes and class probabilities.\"\"\"\n for i in range(self.nl):\n x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1)\n if self.training:\n return x\n self.no = self.nc + self.reg_max * 4 # self.nc could be changed when inference with different texts\n y = self._inference(x)\n return y if self.export else (y, x)\n\n def bias_init(self):\n \"\"\"Initialize Detect() biases, WARNING: requires stride availability.\"\"\"\n m = self # self.model[-1] # Detect() module\n # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1\n # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency\n for a, b, s in zip(m.cv2, m.cv3, m.stride): # from\n a[-1].bias.data[:] = 1.0 # box",
"chunk_type": "class",
"name": "WorldDetect",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 472,
"end_line": 526,
"start_col": 0,
"end_col": 36,
"parent_name": null,
"docstring": "Head for integrating YOLO detection models with semantic understanding from text embeddings.\n\nThis class extends the standard Detect head to incorporate text embeddings for enhanced semantic understanding\nin object detection tasks.\n\nAttributes:\n cv3 (nn.ModuleList): Convolution layers for embedding features.\n cv4 (nn.ModuleList): Contrastive head layers for text-vision alignment.\n\nMethods:\n forward: Concatenate and return predicted bounding boxes and class probabilities.\n bias_init: Initialize detection head biases.\n\nExamples:\n Create a WorldDetect head\n >>> world_detect = WorldDetect(nc=80, embed=512, with_bn=False, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> text = torch.randn(1, 80, 512)\n >>> outputs = world_detect(x, text)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"ultralytics.utils.tal.TORCH_1_10",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"ultralytics.utils.torch_utils.smart_inference_mode",
"block.DFL",
"block.SAVPE",
"block.BNContrastiveHead",
"block.ContrastiveHead",
"block.Proto",
"block.Residual",
"block.SwiGLUFFN",
"conv.Conv",
"conv.DWConv",
"transformer.MLP",
"transformer.DeformableTransformerDecoder",
"transformer.DeformableTransformerDecoderLayer",
"utils.bias_init_with_prob",
"utils.linear_init",
"ultralytics.models.utils.ops.get_cdn_group",
"Detect"
],
"chunk_id": "class_WorldDetect_4c307729"
},
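A hedged sketch of driving WorldDetect with per-class text embeddings in training mode, assuming the ultralytics package is importable; shapes follow the docstring example.

import torch
from ultralytics.nn.modules.head import WorldDetect

head = WorldDetect(nc=80, embed=512, with_bn=False, ch=(256, 512, 1024))
head.train()
x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
text = torch.randn(1, 80, 512)            # one 512-d embedding per class
outputs = head(x, text)                   # training mode: list of per-level tensors
print([tuple(o.shape) for o in outputs])  # per level: 4 * reg_max + nc = 64 + 80 channels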
{
"content": "class LRPCHead(nn.Module):\n \"\"\"\n Lightweight Region Proposal and Classification Head for efficient object detection.\n\n This head combines region proposal filtering with classification to enable efficient detection with\n dynamic vocabulary support.\n\n Attributes:\n vocab (nn.Module): Vocabulary/classification layer.\n pf (nn.Module): Proposal filter module.\n loc (nn.Module): Localization module.\n enabled (bool): Whether the head is enabled.\n\n Methods:\n conv2linear: Convert a 1x1 convolutional layer to a linear layer.\n forward: Process classification and localization features to generate detection proposals.\n\n Examples:\n Create an LRPC head\n >>> vocab = nn.Conv2d(256, 80, 1)\n >>> pf = nn.Conv2d(256, 1, 1)\n >>> loc = nn.Conv2d(256, 4, 1)\n >>> head = LRPCHead(vocab, pf, loc, enabled=True)\n \"\"\"\n\n def __init__(self, vocab: nn.Module, pf: nn.Module, loc: nn.Module, enabled: bool = True):\n \"\"\"\n Initialize LRPCHead with vocabulary, proposal filter, and localization components.\n\n Args:\n vocab (nn.Module): Vocabulary/classification module.\n pf (nn.Module): Proposal filter module.\n loc (nn.Module): Localization module.\n enabled (bool): Whether to enable the head functionality.\n \"\"\"\n super().__init__()\n self.vocab = self.conv2linear(vocab) if enabled else vocab\n self.pf = pf\n self.loc = loc\n self.enabled = enabled\n\n def conv2linear(self, conv: nn.Conv2d) -> nn.Linear:\n \"\"\"Convert a 1x1 convolutional layer to a linear layer.\"\"\"\n assert isinstance(conv, nn.Conv2d) and conv.kernel_size == (1, 1)\n linear = nn.Linear(conv.in_channels, conv.out_channels)\n linear.weight.data = conv.weight.view(conv.out_channels, -1).data\n linear.bias.data = conv.bias.data\n return linear\n\n def forward(self, cls_feat: torch.Tensor, loc_feat: torch.Tensor, conf: float) -> Tuple[Tuple, torch.Tensor]:\n \"\"\"Process classification and localization features to generate detection proposals.\"\"\"\n if self.enabled:\n pf_score = self.pf(cls_feat)[0, 0].flatten(0)\n mask = pf_score.sigmoid() > conf\n cls_feat = cls_feat.flatten(2).transpose(-1, -2)\n cls_feat = self.vocab(cls_feat[:, mask] if conf else cls_feat * mask.unsqueeze(-1).int())\n return (self.loc(loc_feat), cls_feat.transpose(-1, -2)), mask\n else:\n cls_feat = self.vocab(cls_feat)\n loc_feat = self.loc(loc_feat)\n return (loc_feat, cls_feat.flatten(2)), torch.ones(\n cls_feat.shape[2] * cls_feat.shape[3], device=cls_feat.device, dtype=torch.bool\n )",
"chunk_type": "class",
"name": "LRPCHead",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 530,
"end_line": 592,
"start_col": 0,
"end_col": 13,
"parent_name": null,
"docstring": "Lightweight Region Proposal and Classification Head for efficient object detection.\n\nThis head combines region proposal filtering with classification to enable efficient detection with\ndynamic vocabulary support.\n\nAttributes:\n vocab (nn.Module): Vocabulary/classification layer.\n pf (nn.Module): Proposal filter module.\n loc (nn.Module): Localization module.\n enabled (bool): Whether the head is enabled.\n\nMethods:\n conv2linear: Convert a 1x1 convolutional layer to a linear layer.\n forward: Process classification and localization features to generate detection proposals.\n\nExamples:\n Create an LRPC head\n >>> vocab = nn.Conv2d(256, 80, 1)\n >>> pf = nn.Conv2d(256, 1, 1)\n >>> loc = nn.Conv2d(256, 4, 1)\n >>> head = LRPCHead(vocab, pf, loc, enabled=True)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"ultralytics.utils.tal.TORCH_1_10",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"ultralytics.utils.torch_utils.smart_inference_mode",
"block.DFL",
"block.SAVPE",
"block.BNContrastiveHead",
"block.ContrastiveHead",
"block.Proto",
"block.Residual",
"block.SwiGLUFFN",
"conv.Conv",
"conv.DWConv",
"transformer.MLP",
"transformer.DeformableTransformerDecoder",
"transformer.DeformableTransformerDecoderLayer",
"utils.bias_init_with_prob",
"utils.linear_init",
"ultralytics.models.utils.ops.get_cdn_group",
"nn.Module"
],
"chunk_id": "class_LRPCHead_4219ae44"
},
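A standalone check of the conv2linear idea used by LRPCHead: a 1x1 Conv2d applied over spatial positions is equivalent to a Linear layer applied to the flattened positions. Only torch is required; the layer sizes below are arbitrary.

import torch
import torch.nn as nn

conv = nn.Conv2d(256, 80, kernel_size=1)
linear = nn.Linear(256, 80)
linear.weight.data = conv.weight.view(80, -1).data  # copy the 1x1 conv weights into the linear layer
linear.bias.data = conv.bias.data

x = torch.randn(2, 256, 20, 20)
out_conv = conv(x).flatten(2).transpose(-1, -2)         # (2, 400, 80)
out_linear = linear(x.flatten(2).transpose(-1, -2))     # (2, 400, 80)
print(torch.allclose(out_conv, out_linear, atol=1e-5))  # True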
{
"content": "class YOLOEDetect(Detect):\n \"\"\"\n Head for integrating YOLO detection models with semantic understanding from text embeddings.\n\n This class extends the standard Detect head to support text-guided detection with enhanced semantic understanding\n through text embeddings and visual prompt embeddings.\n\n Attributes:\n is_fused (bool): Whether the model is fused for inference.\n cv3 (nn.ModuleList): Convolution layers for embedding features.\n cv4 (nn.ModuleList): Contrastive head layers for text-vision alignment.\n reprta (Residual): Residual block for text prompt embeddings.\n savpe (SAVPE): Spatial-aware visual prompt embeddings module.\n embed (int): Embedding dimension.\n\n Methods:\n fuse: Fuse text features with model weights for efficient inference.\n get_tpe: Get text prompt embeddings with normalization.\n get_vpe: Get visual prompt embeddings with spatial awareness.\n forward_lrpc: Process features with fused text embeddings for prompt-free model.\n forward: Process features with class prompt embeddings to generate detections.\n bias_init: Initialize biases for detection heads.\n\n Examples:\n Create a YOLOEDetect head\n >>> yoloe_detect = YOLOEDetect(nc=80, embed=512, with_bn=True, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> cls_pe = torch.randn(1, 80, 512)\n >>> outputs = yoloe_detect(x, cls_pe)\n \"\"\"\n\n is_fused = False\n\n def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: Tuple = ()):\n \"\"\"\n Initialize YOLO detection layer with nc classes and layer channels ch.\n\n Args:\n nc (int): Number of classes.\n embed (int): Embedding dimension.\n with_bn (bool): Whether to use batch normalization in contrastive head.\n ch (tuple): Tuple of channel sizes from backbone feature maps.\n \"\"\"\n super().__init__(nc, ch)\n c3 = max(ch[0], min(self.nc, 100))\n assert c3 <= embed\n assert with_bn is True\n self.cv3 = (\n nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)\n if self.legacy\n else nn.ModuleList(\n nn.Sequential(\n nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),\n nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),\n nn.Conv2d(c3, embed, 1),\n )\n for x in ch\n )\n )\n\n self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)\n\n self.reprta = Residual(SwiGLUFFN(embed, embed))\n self.savpe = SAVPE(ch, c3, embed)\n self.embed = embed\n\n @smart_inference_mode()\n def fuse(self, txt_feats: torch.Tensor):\n \"\"\"Fuse text features with model weights for efficient inference.\"\"\"\n if self.is_fused:\n return\n\n assert not self.training\n txt_feats = txt_feats.to(torch.float32).squeeze(0)\n for cls_head, bn_head in zip(self.cv3, self.cv4):\n assert isinstance(cls_head, nn.Sequential)\n assert isinstance(bn_head, BNContrastiveHead)\n conv = cls_head[-1]\n assert isinstance(conv, nn.Conv2d)\n logit_scale = bn_head.logit_scale\n bias = bn_head.bias\n norm = bn_head.norm\n\n t = txt_feats * logit_scale.exp()\n conv: nn.Conv2d = fuse_conv_and_bn(conv, norm)\n\n w = conv.weight.data.squeeze(-1).squeeze(-1)\n b = conv.bias.data\n\n w = t @ w\n b1 = (t @ b.reshape(-1).unsqueeze(-1)).squeeze(-1)\n b2 = torch.ones_like(b1) * bias\n\n conv = (\n nn.Conv2d(\n conv.in_channels,\n w.shape[0],\n kernel_size=1,\n )\n .requires_grad_(False)\n .to(conv.weight.device)\n )\n\n conv.weight.data.copy_(w.unsqueeze(-1).unsqueeze(-1))\n conv.bias.data.copy_(b1 + b2)\n cls_head[-1] = 
conv\n\n bn_head.fuse()\n\n del self.reprta\n self.reprta = nn.Identity()\n self.is_fused = True\n\n def get_tpe(self, tpe: Optional[torch.Tensor]) -> Optional[torch.Tensor]:\n \"\"\"Get text prompt embeddings with normalization.\"\"\"\n return None if tpe is None else F.normalize(self.reprta(tpe), dim=-1, p=2)\n\n def get_vpe(self, x: List[torch.Tensor], vpe: torch.Tensor) -> torch.Tensor:\n \"\"\"Get visual prompt embeddings with spatial awareness.\"\"\"\n if vpe.shape[1] == 0: # no visual prompt embeddings\n return torch.zeros(x[0].shape[0], 0, self.embed, device=x[0].device)\n if vpe.ndim == 4: # (B, N, H, W)\n vpe = self.savpe(x, vpe)\n assert vpe.ndim == 3 # (B, N, D)\n return vpe\n\n def forward_lrpc(self, x: List[torch.Tensor], return_mask: bool = False) -> Union[torch.Tensor, Tuple]:\n \"\"\"Process features with fused text embeddings to generate detections for prompt-free model.\"\"\"\n masks = []\n assert self.is_fused, \"Prompt-free inference requires model to be fused!\"\n for i in range(self.nl):\n cls_feat = self.cv3[i](x[i])\n loc_feat = self.cv2[i](x[i])\n assert isinstance(self.lrpc[i], LRPCHead)\n x[i], mask = self.lrpc[i](\n cls_feat, loc_feat, 0 if self.export and not self.dynamic else getattr(self, \"conf\", 0.001)\n )\n masks.append(mask)\n shape = x[0][0].shape\n if self.dynamic or self.shape != shape:\n self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors([b[0] for b in x], self.stride, 0.5))\n self.shape = shape\n box = torch.cat([xi[0].view(shape[0], self.reg_max * 4, -1) for xi in x], 2)\n cls = torch.cat([xi[1] for xi in x], 2)\n\n if self.export and self.format in {\"tflite\", \"edgetpu\"}:\n # Precompute normalization factor to increase numerical stability\n # See https://github.com/ultralytics/ultralytics/issues/7371\n grid_h = shape[2]\n grid_w = shape[3]\n grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)\n norm = self.strides / (self.stride[0] * grid_size)\n dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])\n else:\n dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides\n\n mask = torch.cat(masks)\n y = torch.cat((dbox if self.export and not self.dynamic else dbox[..., mask], cls.sigmoid()), 1)\n\n if return_mask:\n return (y, mask) if self.export else ((y, x), mask)\n else:\n return y if self.export else (y, x)\n\n def forward(\n self, x: List[torch.Tensor], cls_pe: torch.Tensor, return_mask: bool = False\n ) -> Union[torch.Tensor, Tuple]:\n \"\"\"Process features with class prompt embeddings to generate detections.\"\"\"\n if hasattr(self, \"lrpc\"): # for prompt-free inference\n return self.forward_lrpc(x, return_mask)\n for i in range(self.nl):\n x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), cls_pe)), 1)\n if self.training:\n return x\n self.no = self.nc + self.reg_max * 4 # self.nc could be changed when inference with different texts\n y = self._inference(x)\n return y if self.export else (y, x)\n\n def bias_init(self):\n \"\"\"Initialize biases for detection heads.\"\"\"\n m = self # self.model[-1] # Detect() module\n # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1\n # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency\n for a, b, c, s in zip(m.cv2, m.cv3, m.cv4, m.stride): # from\n a[-1].bias.data[:] = 1.0 # box\n # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 
objects, 80 classes, 640 img)\n b[-1].bias.data[:] = 0.0\n c.bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)",
"chunk_type": "class",
"name": "YOLOEDetect",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 595,
"end_line": 782,
"start_col": 0,
"end_col": 64,
"parent_name": null,
"docstring": "Head for integrating YOLO detection models with semantic understanding from text embeddings.\n\nThis class extends the standard Detect head to support text-guided detection with enhanced semantic understanding\nthrough text embeddings and visual prompt embeddings.\n\nAttributes:\n is_fused (bool): Whether the model is fused for inference.\n cv3 (nn.ModuleList): Convolution layers for embedding features.\n cv4 (nn.ModuleList): Contrastive head layers for text-vision alignment.\n reprta (Residual): Residual block for text prompt embeddings.\n savpe (SAVPE): Spatial-aware visual prompt embeddings module.\n embed (int): Embedding dimension.\n\nMethods:\n fuse: Fuse text features with model weights for efficient inference.\n get_tpe: Get text prompt embeddings with normalization.\n get_vpe: Get visual prompt embeddings with spatial awareness.\n forward_lrpc: Process features with fused text embeddings for prompt-free model.\n forward: Process features with class prompt embeddings to generate detections.\n bias_init: Initialize biases for detection heads.\n\nExamples:\n Create a YOLOEDetect head\n >>> yoloe_detect = YOLOEDetect(nc=80, embed=512, with_bn=True, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> cls_pe = torch.randn(1, 80, 512)\n >>> outputs = yoloe_detect(x, cls_pe)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"ultralytics.utils.tal.TORCH_1_10",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"ultralytics.utils.torch_utils.smart_inference_mode",
"block.DFL",
"block.SAVPE",
"block.BNContrastiveHead",
"block.ContrastiveHead",
"block.Proto",
"block.Residual",
"block.SwiGLUFFN",
"conv.Conv",
"conv.DWConv",
"transformer.MLP",
"transformer.DeformableTransformerDecoder",
"transformer.DeformableTransformerDecoderLayer",
"utils.bias_init_with_prob",
"utils.linear_init",
"ultralytics.models.utils.ops.get_cdn_group",
"Detect"
],
"chunk_id": "class_YOLOEDetect_9945901b"
},
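A simplified sketch of the reparameterisation behind YOLOEDetect.fuse: projecting features into the embedding space with a 1x1 conv and then taking dot products against frozen text embeddings equals a single 1x1 conv whose weight is text @ W. The BatchNorm folding, logit scale, and bias handling of the real fuse() are deliberately omitted, so this only illustrates the core idea; all layer sizes are placeholders.

import torch
import torch.nn as nn

embed, nc = 512, 80
proj = nn.Conv2d(256, embed, 1, bias=False)   # stand-in for the last 1x1 conv of cv3
text = torch.randn(nc, embed)                 # frozen per-class text embeddings

x = torch.randn(1, 256, 40, 40)
feats = proj(x)                                               # (1, 512, 40, 40)
logits_two_step = torch.einsum("nd,bdhw->bnhw", text, feats)  # dot product per class

fused = nn.Conv2d(256, nc, 1, bias=False)                     # single fused 1x1 conv
fused.weight.data = (text @ proj.weight.data.squeeze(-1).squeeze(-1)).unsqueeze(-1).unsqueeze(-1)
logits_fused = fused(x)
print(torch.allclose(logits_two_step, logits_fused, atol=1e-4))  # True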
{
"content": "class YOLOESegment(YOLOEDetect):\n \"\"\"\n YOLO segmentation head with text embedding capabilities.\n\n This class extends YOLOEDetect to include mask prediction capabilities for instance segmentation tasks\n with text-guided semantic understanding.\n\n Attributes:\n nm (int): Number of masks.\n npr (int): Number of protos.\n proto (Proto): Prototype generation module.\n cv5 (nn.ModuleList): Convolution layers for mask coefficients.\n\n Methods:\n forward: Return model outputs and mask coefficients.\n\n Examples:\n Create a YOLOESegment head\n >>> yoloe_segment = YOLOESegment(nc=80, nm=32, npr=256, embed=512, with_bn=True, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> text = torch.randn(1, 80, 512)\n >>> outputs = yoloe_segment(x, text)\n \"\"\"\n\n def __init__(\n self, nc: int = 80, nm: int = 32, npr: int = 256, embed: int = 512, with_bn: bool = False, ch: Tuple = ()\n ):\n \"\"\"\n Initialize YOLOESegment with class count, mask parameters, and embedding dimensions.\n\n Args:\n nc (int): Number of classes.\n nm (int): Number of masks.\n npr (int): Number of protos.\n embed (int): Embedding dimension.\n with_bn (bool): Whether to use batch normalization in contrastive head.\n ch (tuple): Tuple of channel sizes from backbone feature maps.\n \"\"\"\n super().__init__(nc, embed, with_bn, ch)\n self.nm = nm\n self.npr = npr\n self.proto = Proto(ch[0], self.npr, self.nm)\n\n c5 = max(ch[0] // 4, self.nm)\n self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch)\n\n def forward(self, x: List[torch.Tensor], text: torch.Tensor) -> Union[Tuple, torch.Tensor]:\n \"\"\"Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients.\"\"\"\n p = self.proto(x[0]) # mask protos\n bs = p.shape[0] # batch size\n\n mc = torch.cat([self.cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients\n has_lrpc = hasattr(self, \"lrpc\")\n\n if not has_lrpc:\n x = YOLOEDetect.forward(self, x, text)\n else:\n x, mask = YOLOEDetect.forward(self, x, text, return_mask=True)\n\n if self.training:\n return x, mc, p\n\n if has_lrpc:\n mc = (mc * mask.int()) if self.export and not self.dynamic else mc[..., mask]\n\n return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))",
"chunk_type": "class",
"name": "YOLOESegment",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 785,
"end_line": 850,
"start_col": 0,
"end_col": 103,
"parent_name": null,
"docstring": "YOLO segmentation head with text embedding capabilities.\n\nThis class extends YOLOEDetect to include mask prediction capabilities for instance segmentation tasks\nwith text-guided semantic understanding.\n\nAttributes:\n nm (int): Number of masks.\n npr (int): Number of protos.\n proto (Proto): Prototype generation module.\n cv5 (nn.ModuleList): Convolution layers for mask coefficients.\n\nMethods:\n forward: Return model outputs and mask coefficients.\n\nExamples:\n Create a YOLOESegment head\n >>> yoloe_segment = YOLOESegment(nc=80, nm=32, npr=256, embed=512, with_bn=True, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> text = torch.randn(1, 80, 512)\n >>> outputs = yoloe_segment(x, text)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"ultralytics.utils.tal.TORCH_1_10",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"ultralytics.utils.torch_utils.smart_inference_mode",
"block.DFL",
"block.SAVPE",
"block.BNContrastiveHead",
"block.ContrastiveHead",
"block.Proto",
"block.Residual",
"block.SwiGLUFFN",
"conv.Conv",
"conv.DWConv",
"transformer.MLP",
"transformer.DeformableTransformerDecoder",
"transformer.DeformableTransformerDecoderLayer",
"utils.bias_init_with_prob",
"utils.linear_init",
"ultralytics.models.utils.ops.get_cdn_group",
"YOLOEDetect"
],
"chunk_id": "class_YOLOESegment_8116a2f2"
},
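A minimal sketch of running YOLOESegment in training mode, assuming the ultralytics package is installed; with_bn=True is required by the parent YOLOEDetect assertion, and shapes follow the docstring example.

import torch
from ultralytics.nn.modules.head import YOLOESegment

head = YOLOESegment(nc=80, nm=32, npr=256, embed=512, with_bn=True, ch=(256, 512, 1024))
head.train()
x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
text = torch.randn(1, 80, 512)              # per-class text prompt embeddings
dets, mask_coeffs, protos = head(x, text)   # training mode returns (x, mc, p)
print(mask_coeffs.shape, protos.shape)      # (1, 32, 8400) and (1, 32, 160, 160)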
{
"content": "class RTDETRDecoder(nn.Module):\n \"\"\"\n Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.\n\n This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes\n and class labels for objects in an image. It integrates features from multiple layers and runs through a series of\n Transformer decoder layers to output the final predictions.\n\n Attributes:\n export (bool): Export mode flag.\n hidden_dim (int): Dimension of hidden layers.\n nhead (int): Number of heads in multi-head attention.\n nl (int): Number of feature levels.\n nc (int): Number of classes.\n num_queries (int): Number of query points.\n num_decoder_layers (int): Number of decoder layers.\n input_proj (nn.ModuleList): Input projection layers for backbone features.\n decoder (DeformableTransformerDecoder): Transformer decoder module.\n denoising_class_embed (nn.Embedding): Class embeddings for denoising.\n num_denoising (int): Number of denoising queries.\n label_noise_ratio (float): Label noise ratio for training.\n box_noise_scale (float): Box noise scale for training.\n learnt_init_query (bool): Whether to learn initial query embeddings.\n tgt_embed (nn.Embedding): Target embeddings for queries.\n query_pos_head (MLP): Query position head.\n enc_output (nn.Sequential): Encoder output layers.\n enc_score_head (nn.Linear): Encoder score prediction head.\n enc_bbox_head (MLP): Encoder bbox prediction head.\n dec_score_head (nn.ModuleList): Decoder score prediction heads.\n dec_bbox_head (nn.ModuleList): Decoder bbox prediction heads.\n\n Methods:\n forward: Run forward pass and return bounding box and classification scores.\n\n Examples:\n Create an RTDETRDecoder\n >>> decoder = RTDETRDecoder(nc=80, ch=(512, 1024, 2048), hd=256, nq=300)\n >>> x = [torch.randn(1, 512, 64, 64), torch.randn(1, 1024, 32, 32), torch.randn(1, 2048, 16, 16)]\n >>> outputs = decoder(x)\n \"\"\"\n\n export = False # export mode\n\n def __init__(\n self,\n nc: int = 80,\n ch: Tuple = (512, 1024, 2048),\n hd: int = 256, # hidden dim\n nq: int = 300, # num queries\n ndp: int = 4, # num decoder points\n nh: int = 8, # num head\n ndl: int = 6, # num decoder layers\n d_ffn: int = 1024, # dim of feedforward\n dropout: float = 0.0,\n act: nn.Module = nn.ReLU(),\n eval_idx: int = -1,\n # Training args\n nd: int = 100, # num denoising\n label_noise_ratio: float = 0.5,\n box_noise_scale: float = 1.0,\n learnt_init_query: bool = False,\n ):\n \"\"\"\n Initialize the RTDETRDecoder module with the given parameters.\n\n Args:\n nc (int): Number of classes.\n ch (tuple): Channels in the backbone feature maps.\n hd (int): Dimension of hidden layers.\n nq (int): Number of query points.\n ndp (int): Number of decoder points.\n nh (int): Number of heads in multi-head attention.\n ndl (int): Number of decoder layers.\n d_ffn (int): Dimension of the feed-forward networks.\n dropout (float): Dropout rate.\n act (nn.Module): Activation function.\n eval_idx (int): Evaluation index.\n nd (int): Number of denoising.\n label_noise_ratio (float): Label noise ratio.\n box_noise_scale (float): Box noise scale.\n learnt_init_query (bool): Whether to learn initial query embeddings.\n \"\"\"\n super().__init__()\n self.hidden_dim = hd\n self.nhead = nh\n self.nl = len(ch) # num level\n self.nc = nc\n self.num_queries = nq\n self.num_decoder_layers = ndl\n\n # Backbone feature projection\n self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) 
for x in ch)\n # NOTE: simplified version but it's not consistent with .pt weights.\n # self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch)\n\n # Transformer module\n decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp)\n self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx)\n\n # Denoising part\n self.denoising_class_embed = nn.Embedding(nc, hd)\n self.num_denoising = nd\n self.label_noise_ratio = label_noise_ratio\n self.box_noise_scale = box_noise_scale\n\n # Decoder embedding\n self.learnt_init_query = learnt_init_query\n if learnt_init_query:\n self.tgt_embed = nn.Embedding(nq, hd)\n self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2)\n\n # Encoder head\n self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd))\n self.enc_score_head = nn.Linear(hd, nc)\n self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3)\n\n # Decoder head\n self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)])\n self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)])\n\n self._reset_parameters()\n\n def forward(self, x: List[torch.Tensor], batch: Optional[dict] = None) -> Union[Tuple, torch.Tensor]:\n \"\"\"\n Run the forward pass of the module, returning bounding box and classification scores for the input.\n\n Args:\n x (List[torch.Tensor]): List of feature maps from the backbone.\n batch (dict, optional): Batch information for training.\n\n Returns:\n outputs (tuple | torch.Tensor): During training, returns a tuple of bounding boxes, scores, and other\n metadata. During inference, returns a tensor of shape (bs, 300, 4+nc) containing bounding boxes and\n class scores.\n \"\"\"\n from ultralytics.models.utils.ops import get_cdn_group\n\n # Input projection and embedding\n feats, shapes = self._get_encoder_input(x)\n\n # Prepare denoising training\n dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group(\n batch,\n self.nc,\n self.num_queries,\n self.denoising_class_embed.weight,\n self.num_denoising,\n self.label_noise_ratio,\n self.box_noise_scale,\n self.training,\n )\n\n embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)\n\n # Decoder\n dec_bboxes, dec_scores = self.decoder(\n embed,\n refer_bbox,\n feats,\n shapes,\n self.dec_bbox_head,\n self.dec_score_head,\n self.query_pos_head,\n attn_mask=attn_mask,\n )\n x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta\n if self.training:\n return x\n # (bs, 300, 4+nc)\n y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)\n return y if self.export else (y, x)\n\n def _generate_anchors(\n self,\n shapes: List[List[int]],\n grid_size: float = 0.05,\n dtype: torch.dtype = torch.float32,\n device: str = \"cpu\",\n eps: float = 1e-2,\n ) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"\n Generate anchor bounding boxes for given shapes with specific grid size and validate them.\n\n Args:\n shapes (list): List of feature map shapes.\n grid_size (float, optional): Base size of grid cells.\n dtype (torch.dtype, optional): Data type for tensors.\n device (str, optional): Device to create tensors on.\n eps (float, optional): Small value for numerical stability.\n\n Returns:\n anchors (torch.Tensor): Generated anchor boxes.\n valid_mask (torch.Tensor): Valid mask for anchors.\n \"\"\"\n anchors = []\n for i, (h, w) in enumerate(shapes):\n sy = torch.arange(end=h, dtype=dtype, device=device)\n sx = torch.arange(end=w, dtype=dtype, device=device)\n grid_y, 
grid_x = torch.meshgrid(sy, sx, indexing=\"ij\") if TORCH_1_10 else torch.meshgrid(sy, sx)\n grid_xy = torch.stack([grid_x, grid_y], -1) # (h, w, 2)\n\n valid_WH = torch.tensor([w, h], dtype=dtype, device=device)\n grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH # (1, h, w, 2)\n wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**i)\n anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4)) # (1, h*w, 4)\n\n anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4)\n valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1\n anchors = torch.log(anchors / (1 - anchors))\n anchors = anchors.masked_fill(~valid_mask, float(\"inf\"))\n return anchors, valid_mask\n\n def _get_encoder_input(self, x: List[torch.Tensor]) -> Tuple[torch.Tensor, List[List[int]]]:\n \"\"\"\n Process and return encoder inputs by getting projection features from input and concatenating them.\n\n Args:\n x (List[torch.Tensor]): List of feature maps from the backbone.\n\n Returns:\n feats (torch.Tensor): Processed features.\n shapes (list): List of feature map shapes.\n \"\"\"\n # Get projection features\n x = [self.input_proj[i](feat) for i, feat in enumerate(x)]\n # Get encoder inputs\n feats = []\n shapes = []\n for feat in x:\n h, w = feat.shape[2:]\n # [b, c, h, w] -> [b, h*w, c]\n feats.append(feat.flatten(2).permute(0, 2, 1))\n # [nl, 2]\n shapes.append([h, w])\n\n # [b, h*w, c]\n feats = torch.cat(feats, 1)\n return feats, shapes\n\n def _get_decoder_input(\n self,\n feats: torch.Tensor,\n shapes: List[List[int]],\n dn_embed: Optional[torch.Tensor] = None,\n dn_bbox: Optional[torch.Tensor] = None,\n ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n \"\"\"\n Generate and prepare the input required for the decoder from the provided features and shapes.\n\n Args:\n feats (torch.Tensor): Processed features from encoder.\n shapes (list): List of feature map shapes.\n dn_embed (torch.Tensor, optional): Denoising embeddings.\n dn_bbox (torch.Tensor, optional): Denoising bounding boxes.\n\n Returns:\n embeddings (torch.Tensor): Query embeddings for decoder.\n refer_bbox (torch.Tensor): Reference bounding boxes.\n enc_bboxes (torch.Tensor): Encoded bounding boxes.\n enc_scores (torch.Tensor): Encoded scores.\n \"\"\"\n bs = feats.shape[0]\n # Prepare input for decoder\n anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)\n features = self.enc_output(valid_mask * feats) # bs, h*w, 256\n\n enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc)\n\n # Query selection\n # (bs, num_queries)\n topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)\n # (bs, num_queries)\n batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)\n\n # (bs, num_queries, 256)\n top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)\n # (bs, num_queries, 4)\n top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1)\n\n # Dynamic anchors + static content\n refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors\n\n enc_bboxes = refer_bbox.sigmoid()\n if dn_bbox is not None:\n refer_bbox = torch.cat([dn_bbox, refer_bbox], 1)\n enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)\n\n embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) if self.learnt_init_query else top_k_features\n if self.training:\n refer_bbox = refer_bbox.detach()\n if not 
self.learnt_init_query:\n embeddings = embeddings.detach()\n if dn_embed is not None:\n embeddings = torch.cat([dn_embed, embeddings], 1)\n\n return embeddings, refer_bbox, enc_bboxes, enc_scores\n\n def _reset_parameters(self):\n \"\"\"Initialize or reset the parameters of the model's various components with predefined weights and biases.\"\"\"\n # Class and bbox head init\n bias_cls = bias_init_with_prob(0.01) / 80 * self.nc\n # NOTE: the weight initialization in `linear_init` would cause NaN when training with custom datasets.\n # linear_init(self.enc_score_head)\n constant_(self.enc_score_head.bias, bias_cls)\n constant_(self.enc_bbox_head.layers[-1].weight, 0.0)\n constant_(self.enc_bbox_head.layers[-1].bias, 0.0)\n for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):\n # linear_init(cls_)\n constant_(cls_.bias, bias_cls)\n constant_(reg_.layers[-1].weight, 0.0)\n constant_(reg_.layers[-1].bias, 0.0)\n\n linear_init(self.enc_output[0])\n xavier_uniform_(self.enc_output[0].weight)\n if self.learnt_init_query:\n xavier_uniform_(self.tgt_embed.weight)\n xavier_uniform_(self.query_pos_head.layers[0].weight)\n xavier_uniform_(self.query_pos_head.layers[1].weight)\n for layer in self.input_proj:\n xavier_uniform_(layer[0].weight)",
"chunk_type": "class",
"name": "RTDETRDecoder",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 853,
"end_line": 1172,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": "Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.\n\nThis decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes\nand class labels for objects in an image. It integrates features from multiple layers and runs through a series of\nTransformer decoder layers to output the final predictions.\n\nAttributes:\n export (bool): Export mode flag.\n hidden_dim (int): Dimension of hidden layers.\n nhead (int): Number of heads in multi-head attention.\n nl (int): Number of feature levels.\n nc (int): Number of classes.\n num_queries (int): Number of query points.\n num_decoder_layers (int): Number of decoder layers.\n input_proj (nn.ModuleList): Input projection layers for backbone features.\n decoder (DeformableTransformerDecoder): Transformer decoder module.\n denoising_class_embed (nn.Embedding): Class embeddings for denoising.\n num_denoising (int): Number of denoising queries.\n label_noise_ratio (float): Label noise ratio for training.\n box_noise_scale (float): Box noise scale for training.\n learnt_init_query (bool): Whether to learn initial query embeddings.\n tgt_embed (nn.Embedding): Target embeddings for queries.\n query_pos_head (MLP): Query position head.\n enc_output (nn.Sequential): Encoder output layers.\n enc_score_head (nn.Linear): Encoder score prediction head.\n enc_bbox_head (MLP): Encoder bbox prediction head.\n dec_score_head (nn.ModuleList): Decoder score prediction heads.\n dec_bbox_head (nn.ModuleList): Decoder bbox prediction heads.\n\nMethods:\n forward: Run forward pass and return bounding box and classification scores.\n\nExamples:\n Create an RTDETRDecoder\n >>> decoder = RTDETRDecoder(nc=80, ch=(512, 1024, 2048), hd=256, nq=300)\n >>> x = [torch.randn(1, 512, 64, 64), torch.randn(1, 1024, 32, 32), torch.randn(1, 2048, 16, 16)]\n >>> outputs = decoder(x)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"ultralytics.utils.tal.TORCH_1_10",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"ultralytics.utils.torch_utils.smart_inference_mode",
"block.DFL",
"block.SAVPE",
"block.BNContrastiveHead",
"block.ContrastiveHead",
"block.Proto",
"block.Residual",
"block.SwiGLUFFN",
"conv.Conv",
"conv.DWConv",
"transformer.MLP",
"transformer.DeformableTransformerDecoder",
"transformer.DeformableTransformerDecoderLayer",
"utils.bias_init_with_prob",
"utils.linear_init",
"ultralytics.models.utils.ops.get_cdn_group",
"nn.Module"
],
"chunk_id": "class_RTDETRDecoder_7ce5df20"
},
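Illustrative usage sketch for the RTDETRDecoder chunk above (not part of the extracted source): the import path is inferred from the chunk's file_path and the tensor shapes follow the docstring example; per the forward code shown, eval mode with export=False returns a (y, x) pair where y is (bs, 300, 4 + nc).
>>> import torch
>>> from ultralytics.nn.modules.head import RTDETRDecoder
>>> decoder = RTDETRDecoder(nc=80, ch=(512, 1024, 2048), hd=256, nq=300).eval()
>>> feats = [torch.randn(1, 512, 64, 64), torch.randn(1, 1024, 32, 32), torch.randn(1, 2048, 16, 16)]
>>> with torch.no_grad(): y, aux = decoder(feats)  # y: (1, 300, 84) boxes plus per-class scores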
{
"content": "class v10Detect(Detect):\n \"\"\"\n v10 Detection head from https://arxiv.org/pdf/2405.14458.\n\n This class implements the YOLOv10 detection head with dual-assignment training and consistent dual predictions\n for improved efficiency and performance.\n\n Attributes:\n end2end (bool): End-to-end detection mode.\n max_det (int): Maximum number of detections.\n cv3 (nn.ModuleList): Light classification head layers.\n one2one_cv3 (nn.ModuleList): One-to-one classification head layers.\n\n Methods:\n __init__: Initialize the v10Detect object with specified number of classes and input channels.\n forward: Perform forward pass of the v10Detect module.\n bias_init: Initialize biases of the Detect module.\n fuse: Remove the one2many head for inference optimization.\n\n Examples:\n Create a v10Detect head\n >>> v10_detect = v10Detect(nc=80, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = v10_detect(x)\n \"\"\"\n\n end2end = True\n\n def __init__(self, nc: int = 80, ch: Tuple = ()):\n \"\"\"\n Initialize the v10Detect object with the specified number of classes and input channels.\n\n Args:\n nc (int): Number of classes.\n ch (tuple): Tuple of channel sizes from backbone feature maps.\n \"\"\"\n super().__init__(nc, ch)\n c3 = max(ch[0], min(self.nc, 100)) # channels\n # Light cls head\n self.cv3 = nn.ModuleList(\n nn.Sequential(\n nn.Sequential(Conv(x, x, 3, g=x), Conv(x, c3, 1)),\n nn.Sequential(Conv(c3, c3, 3, g=c3), Conv(c3, c3, 1)),\n nn.Conv2d(c3, self.nc, 1),\n )\n for x in ch\n )\n self.one2one_cv3 = copy.deepcopy(self.cv3)\n\n def fuse(self):\n \"\"\"Remove the one2many head for inference optimization.\"\"\"\n self.cv2 = self.cv3 = nn.ModuleList([nn.Identity()] * self.nl)",
"chunk_type": "class",
"name": "v10Detect",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\head.py",
"start_line": 1175,
"end_line": 1226,
"start_col": 0,
"end_col": 70,
"parent_name": null,
"docstring": "v10 Detection head from https://arxiv.org/pdf/2405.14458.\n\nThis class implements the YOLOv10 detection head with dual-assignment training and consistent dual predictions\nfor improved efficiency and performance.\n\nAttributes:\n end2end (bool): End-to-end detection mode.\n max_det (int): Maximum number of detections.\n cv3 (nn.ModuleList): Light classification head layers.\n one2one_cv3 (nn.ModuleList): One-to-one classification head layers.\n\nMethods:\n __init__: Initialize the v10Detect object with specified number of classes and input channels.\n forward: Perform forward pass of the v10Detect module.\n bias_init: Initialize biases of the Detect module.\n fuse: Remove the one2many head for inference optimization.\n\nExamples:\n Create a v10Detect head\n >>> v10_detect = v10Detect(nc=80, ch=(256, 512, 1024))\n >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]\n >>> outputs = v10_detect(x)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"math",
"typing.List",
"typing.Optional",
"typing.Tuple",
"typing.Union",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"ultralytics.utils.tal.TORCH_1_10",
"ultralytics.utils.tal.dist2bbox",
"ultralytics.utils.tal.dist2rbox",
"ultralytics.utils.tal.make_anchors",
"ultralytics.utils.torch_utils.fuse_conv_and_bn",
"ultralytics.utils.torch_utils.smart_inference_mode",
"block.DFL",
"block.SAVPE",
"block.BNContrastiveHead",
"block.ContrastiveHead",
"block.Proto",
"block.Residual",
"block.SwiGLUFFN",
"conv.Conv",
"conv.DWConv",
"transformer.MLP",
"transformer.DeformableTransformerDecoder",
"transformer.DeformableTransformerDecoderLayer",
"utils.bias_init_with_prob",
"utils.linear_init",
"ultralytics.models.utils.ops.get_cdn_group",
"Detect"
],
"chunk_id": "class_v10Detect_6bfc5509"
},
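Illustrative sketch for the v10Detect chunk above (import path assumed from the chunk's file_path): calling the freshly constructed head runs it in training mode, as in the docstring example, and yields the raw outputs consumed by the dual-assignment loss rather than decoded detections.
>>> import torch
>>> from ultralytics.nn.modules.head import v10Detect
>>> head = v10Detect(nc=80, ch=(256, 512, 1024))
>>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
>>> outputs = head(x)  # training-mode outputs for the one-to-many and one-to-one branches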
{
"content": "import math",
"chunk_type": "import",
"name": "math",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_math_b656c83b"
},
{
"content": "from typing import List, Optional",
"chunk_type": "import",
"name": "List, Optional",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_List, Optional_bceca4d3"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_2c3e4f94"
},
{
"content": "import torch.nn as nn",
"chunk_type": "import",
"name": "torch.nn",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn_15d9547d"
},
{
"content": "import torch.nn.functional as F",
"chunk_type": "import",
"name": "torch.nn.functional",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn.functional_6458c49b"
},
{
"content": "from torch.nn.init import constant_, xavier_uniform_",
"chunk_type": "import",
"name": "constant_, xavier_uniform_",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 52,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_constant_, xavier_uniform__cf516011"
},
{
"content": "from .conv import Conv",
"chunk_type": "import",
"name": "Conv",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Conv_28de563c"
},
{
"content": "from .utils import _get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch",
"chunk_type": "import",
"name": "_get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 13,
"end_line": 13,
"start_col": 0,
"end_col": 84,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import__get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch_60657b4f"
},
{
"content": "__all__ = (\n \"TransformerEncoderLayer\",\n \"TransformerLayer\",\n \"TransformerBlock\",\n \"MLPBlock\",\n \"LayerNorm2d\",\n \"AIFI\",\n \"DeformableTransformerDecoder\",\n \"DeformableTransformerDecoderLayer\",\n \"MSDeformAttn\",\n \"MLP\",\n)",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 15,
"end_line": 26,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___515bdabe"
},
{
"content": "class TransformerEncoderLayer(nn.Module):\n \"\"\"\n A single layer of the transformer encoder.\n\n This class implements a standard transformer encoder layer with multi-head attention and feedforward network,\n supporting both pre-normalization and post-normalization configurations.\n\n Attributes:\n ma (nn.MultiheadAttention): Multi-head attention module.\n fc1 (nn.Linear): First linear layer in the feedforward network.\n fc2 (nn.Linear): Second linear layer in the feedforward network.\n norm1 (nn.LayerNorm): Layer normalization after attention.\n norm2 (nn.LayerNorm): Layer normalization after feedforward network.\n dropout (nn.Dropout): Dropout layer for the feedforward network.\n dropout1 (nn.Dropout): Dropout layer after attention.\n dropout2 (nn.Dropout): Dropout layer after feedforward network.\n act (nn.Module): Activation function.\n normalize_before (bool): Whether to apply normalization before attention and feedforward.\n \"\"\"\n\n def __init__(\n self,\n c1: int,\n cm: int = 2048,\n num_heads: int = 8,\n dropout: float = 0.0,\n act: nn.Module = nn.GELU(),\n normalize_before: bool = False,\n ):\n \"\"\"\n Initialize the TransformerEncoderLayer with specified parameters.\n\n Args:\n c1 (int): Input dimension.\n cm (int): Hidden dimension in the feedforward network.\n num_heads (int): Number of attention heads.\n dropout (float): Dropout probability.\n act (nn.Module): Activation function.\n normalize_before (bool): Whether to apply normalization before attention and feedforward.\n \"\"\"\n super().__init__()\n from ...utils.torch_utils import TORCH_1_9\n\n if not TORCH_1_9:\n raise ModuleNotFoundError(\n \"TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True).\"\n )\n self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)\n # Implementation of Feedforward model\n self.fc1 = nn.Linear(c1, cm)\n self.fc2 = nn.Linear(cm, c1)\n\n self.norm1 = nn.LayerNorm(c1)\n self.norm2 = nn.LayerNorm(c1)\n self.dropout = nn.Dropout(dropout)\n self.dropout1 = nn.Dropout(dropout)\n self.dropout2 = nn.Dropout(dropout)\n\n self.act = act\n self.normalize_before = normalize_before\n\n @staticmethod\n def with_pos_embed(tensor: torch.Tensor, pos: Optional[torch.Tensor] = None) -> torch.Tensor:\n \"\"\"Add position embeddings to the tensor if provided.\"\"\"\n return tensor if pos is None else tensor + pos\n\n def forward_post(\n self,\n src: torch.Tensor,\n src_mask: Optional[torch.Tensor] = None,\n src_key_padding_mask: Optional[torch.Tensor] = None,\n pos: Optional[torch.Tensor] = None,\n ) -> torch.Tensor:\n \"\"\"\n Perform forward pass with post-normalization.\n\n Args:\n src (torch.Tensor): Input tensor.\n src_mask (torch.Tensor, optional): Mask for the src sequence.\n src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.\n pos (torch.Tensor, optional): Positional encoding.\n\n Returns:\n (torch.Tensor): Output tensor after attention and feedforward.\n \"\"\"\n q = k = self.with_pos_embed(src, pos)\n src2 = self.ma(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]\n src = src + self.dropout1(src2)\n src = self.norm1(src)\n src2 = self.fc2(self.dropout(self.act(self.fc1(src))))\n src = src + self.dropout2(src2)\n return self.norm2(src)\n\n def forward_pre(\n self,\n src: torch.Tensor,\n src_mask: Optional[torch.Tensor] = None,\n src_key_padding_mask: Optional[torch.Tensor] = None,\n pos: Optional[torch.Tensor] = None,\n ) -> torch.Tensor:\n \"\"\"\n Perform forward 
pass with pre-normalization.\n\n Args:\n src (torch.Tensor): Input tensor.\n src_mask (torch.Tensor, optional): Mask for the src sequence.\n src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.\n pos (torch.Tensor, optional): Positional encoding.\n\n Returns:\n (torch.Tensor): Output tensor after attention and feedforward.\n \"\"\"\n src2 = self.norm1(src)\n q = k = self.with_pos_embed(src2, pos)\n src2 = self.ma(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]\n src = src + self.dropout1(src2)\n src2 = self.norm2(src)\n src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))\n return src + self.dropout2(src2)\n\n def forward(\n self,\n src: torch.Tensor,\n src_mask: Optional[torch.Tensor] = None,\n src_key_padding_mask: Optional[torch.Tensor] = None,\n pos: Optional[torch.Tensor] = None,\n ) -> torch.Tensor:\n \"\"\"\n Forward propagate the input through the encoder module.\n\n Args:\n src (torch.Tensor): Input tensor.\n src_mask (torch.Tensor, optional): Mask for the src sequence.\n src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.\n pos (torch.Tensor, optional): Positional encoding.\n\n Returns:\n (torch.Tensor): Output tensor after transformer encoder layer.\n \"\"\"\n if self.normalize_before:\n return self.forward_pre(src, src_mask, src_key_padding_mask, pos)\n return self.forward_post(src, src_mask, src_key_padding_mask, pos)",
"chunk_type": "class",
"name": "TransformerEncoderLayer",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 29,
"end_line": 170,
"start_col": 0,
"end_col": 74,
"parent_name": null,
"docstring": "A single layer of the transformer encoder.\n\nThis class implements a standard transformer encoder layer with multi-head attention and feedforward network,\nsupporting both pre-normalization and post-normalization configurations.\n\nAttributes:\n ma (nn.MultiheadAttention): Multi-head attention module.\n fc1 (nn.Linear): First linear layer in the feedforward network.\n fc2 (nn.Linear): Second linear layer in the feedforward network.\n norm1 (nn.LayerNorm): Layer normalization after attention.\n norm2 (nn.LayerNorm): Layer normalization after feedforward network.\n dropout (nn.Dropout): Dropout layer for the feedforward network.\n dropout1 (nn.Dropout): Dropout layer after attention.\n dropout2 (nn.Dropout): Dropout layer after feedforward network.\n act (nn.Module): Activation function.\n normalize_before (bool): Whether to apply normalization before attention and feedforward.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"typing.Optional",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"conv.Conv",
"utils._get_clones",
"utils.inverse_sigmoid",
"utils.multi_scale_deformable_attn_pytorch",
"utils.torch_utils.TORCH_1_9",
"nn.Module"
],
"chunk_id": "class_TransformerEncoderLayer_836709d4"
},
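Minimal sketch for the TransformerEncoderLayer chunk above (import path inferred from the chunk's file_path, shapes illustrative): the attention is built with batch_first=True, so inputs are (batch, tokens, channels) and the default post-norm path is taken.
>>> import torch
>>> from ultralytics.nn.modules.transformer import TransformerEncoderLayer
>>> layer = TransformerEncoderLayer(c1=256, cm=1024, num_heads=8)
>>> src = torch.randn(2, 400, 256)  # (batch, tokens, channels)
>>> out = layer(src)  # shape preserved: (2, 400, 256)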
{
"content": "class AIFI(TransformerEncoderLayer):\n \"\"\"\n AIFI transformer layer for 2D data with positional embeddings.\n\n This class extends TransformerEncoderLayer to work with 2D feature maps by adding 2D sine-cosine positional\n embeddings and handling the spatial dimensions appropriately.\n \"\"\"\n\n def __init__(\n self,\n c1: int,\n cm: int = 2048,\n num_heads: int = 8,\n dropout: float = 0,\n act: nn.Module = nn.GELU(),\n normalize_before: bool = False,\n ):\n \"\"\"\n Initialize the AIFI instance with specified parameters.\n\n Args:\n c1 (int): Input dimension.\n cm (int): Hidden dimension in the feedforward network.\n num_heads (int): Number of attention heads.\n dropout (float): Dropout probability.\n act (nn.Module): Activation function.\n normalize_before (bool): Whether to apply normalization before attention and feedforward.\n \"\"\"\n super().__init__(c1, cm, num_heads, dropout, act, normalize_before)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass for the AIFI transformer layer.\n\n Args:\n x (torch.Tensor): Input tensor with shape [B, C, H, W].\n\n Returns:\n (torch.Tensor): Output tensor with shape [B, C, H, W].\n \"\"\"\n c, h, w = x.shape[1:]\n pos_embed = self.build_2d_sincos_position_embedding(w, h, c)\n # Flatten [B, C, H, W] to [B, HxW, C]\n x = super().forward(x.flatten(2).permute(0, 2, 1), pos=pos_embed.to(device=x.device, dtype=x.dtype))\n return x.permute(0, 2, 1).view([-1, c, h, w]).contiguous()\n\n @staticmethod\n def build_2d_sincos_position_embedding(\n w: int, h: int, embed_dim: int = 256, temperature: float = 10000.0\n ) -> torch.Tensor:\n \"\"\"\n Build 2D sine-cosine position embedding.\n\n Args:\n w (int): Width of the feature map.\n h (int): Height of the feature map.\n embed_dim (int): Embedding dimension.\n temperature (float): Temperature for the sine/cosine functions.\n\n Returns:\n (torch.Tensor): Position embedding with shape [1, embed_dim, h*w].\n \"\"\"\n assert embed_dim % 4 == 0, \"Embed dimension must be divisible by 4 for 2D sin-cos position embedding\"\n grid_w = torch.arange(w, dtype=torch.float32)\n grid_h = torch.arange(h, dtype=torch.float32)\n grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing=\"ij\")\n pos_dim = embed_dim // 4\n omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim\n omega = 1.0 / (temperature**omega)\n\n out_w = grid_w.flatten()[..., None] @ omega[None]\n out_h = grid_h.flatten()[..., None] @ omega[None]\n\n return torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], 1)[None]",
"chunk_type": "class",
"name": "AIFI",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 173,
"end_line": 246,
"start_col": 0,
"end_col": 107,
"parent_name": null,
"docstring": "AIFI transformer layer for 2D data with positional embeddings.\n\nThis class extends TransformerEncoderLayer to work with 2D feature maps by adding 2D sine-cosine positional\nembeddings and handling the spatial dimensions appropriately.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"typing.Optional",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"conv.Conv",
"utils._get_clones",
"utils.inverse_sigmoid",
"utils.multi_scale_deformable_attn_pytorch",
"utils.torch_utils.TORCH_1_9",
"TransformerEncoderLayer"
],
"chunk_id": "class_AIFI_8f6b96f3"
},
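Minimal sketch for the AIFI chunk above (assumed import path, illustrative shapes): the channel count must be divisible by 4 so the 2D sin-cos position embedding can be built inside forward.
>>> import torch
>>> from ultralytics.nn.modules.transformer import AIFI
>>> aifi = AIFI(c1=256, cm=1024, num_heads=8)
>>> x = torch.randn(1, 256, 20, 20)  # [B, C, H, W]
>>> y = aifi(x)  # flattened to [B, HW, C] internally, then reshaped back to (1, 256, 20, 20)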
{
"content": "class TransformerLayer(nn.Module):\n \"\"\"Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance).\"\"\"\n\n def __init__(self, c: int, num_heads: int):\n \"\"\"\n Initialize a self-attention mechanism using linear transformations and multi-head attention.\n\n Args:\n c (int): Input and output channel dimension.\n num_heads (int): Number of attention heads.\n \"\"\"\n super().__init__()\n self.q = nn.Linear(c, c, bias=False)\n self.k = nn.Linear(c, c, bias=False)\n self.v = nn.Linear(c, c, bias=False)\n self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)\n self.fc1 = nn.Linear(c, c, bias=False)\n self.fc2 = nn.Linear(c, c, bias=False)\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Apply a transformer block to the input x and return the output.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after transformer layer.\n \"\"\"\n x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x\n return self.fc2(self.fc1(x)) + x",
"chunk_type": "class",
"name": "TransformerLayer",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 249,
"end_line": 279,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": "Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance).",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"typing.Optional",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"conv.Conv",
"utils._get_clones",
"utils.inverse_sigmoid",
"utils.multi_scale_deformable_attn_pytorch",
"utils.torch_utils.TORCH_1_9",
"nn.Module"
],
"chunk_id": "class_TransformerLayer_f04f01ea"
},
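Minimal sketch for the TransformerLayer chunk above (assumed import path): the attention here is not batch_first, so the expected layout is (tokens, batch, channels).
>>> import torch
>>> from ultralytics.nn.modules.transformer import TransformerLayer
>>> layer = TransformerLayer(c=256, num_heads=8)
>>> x = torch.randn(100, 2, 256)  # (tokens, batch, channels)
>>> y = layer(x)  # residual attention plus residual two-layer FFN, shape unchanged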
{
"content": "class TransformerBlock(nn.Module):\n \"\"\"\n Vision Transformer block based on https://arxiv.org/abs/2010.11929.\n\n This class implements a complete transformer block with optional convolution layer for channel adjustment,\n learnable position embedding, and multiple transformer layers.\n\n Attributes:\n conv (Conv, optional): Convolution layer if input and output channels differ.\n linear (nn.Linear): Learnable position embedding.\n tr (nn.Sequential): Sequential container of transformer layers.\n c2 (int): Output channel dimension.\n \"\"\"\n\n def __init__(self, c1: int, c2: int, num_heads: int, num_layers: int):\n \"\"\"\n Initialize a Transformer module with position embedding and specified number of heads and layers.\n\n Args:\n c1 (int): Input channel dimension.\n c2 (int): Output channel dimension.\n num_heads (int): Number of attention heads.\n num_layers (int): Number of transformer layers.\n \"\"\"\n super().__init__()\n self.conv = None\n if c1 != c2:\n self.conv = Conv(c1, c2)\n self.linear = nn.Linear(c2, c2) # learnable position embedding\n self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))\n self.c2 = c2\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward propagate the input through the transformer block.\n\n Args:\n x (torch.Tensor): Input tensor with shape [b, c1, w, h].\n\n Returns:\n (torch.Tensor): Output tensor with shape [b, c2, w, h].\n \"\"\"\n if self.conv is not None:\n x = self.conv(x)\n b, _, w, h = x.shape\n p = x.flatten(2).permute(2, 0, 1)\n return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)",
"chunk_type": "class",
"name": "TransformerBlock",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 282,
"end_line": 328,
"start_col": 0,
"end_col": 85,
"parent_name": null,
"docstring": "Vision Transformer block based on https://arxiv.org/abs/2010.11929.\n\nThis class implements a complete transformer block with optional convolution layer for channel adjustment,\nlearnable position embedding, and multiple transformer layers.\n\nAttributes:\n conv (Conv, optional): Convolution layer if input and output channels differ.\n linear (nn.Linear): Learnable position embedding.\n tr (nn.Sequential): Sequential container of transformer layers.\n c2 (int): Output channel dimension.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"typing.Optional",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"conv.Conv",
"utils._get_clones",
"utils.inverse_sigmoid",
"utils.multi_scale_deformable_attn_pytorch",
"utils.torch_utils.TORCH_1_9",
"nn.Module"
],
"chunk_id": "class_TransformerBlock_def174e0"
},
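Minimal sketch for the TransformerBlock chunk above (assumed import path, illustrative shapes): with c1 != c2 the optional Conv adjusts the channels before the transformer layers.
>>> import torch
>>> from ultralytics.nn.modules.transformer import TransformerBlock
>>> block = TransformerBlock(c1=128, c2=256, num_heads=8, num_layers=2)
>>> x = torch.randn(1, 128, 20, 20)
>>> y = block(x)  # output (1, 256, 20, 20)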
{
"content": "class MLPBlock(nn.Module):\n \"\"\"A single block of a multi-layer perceptron.\"\"\"\n\n def __init__(self, embedding_dim: int, mlp_dim: int, act=nn.GELU):\n \"\"\"\n Initialize the MLPBlock with specified embedding dimension, MLP dimension, and activation function.\n\n Args:\n embedding_dim (int): Input and output dimension.\n mlp_dim (int): Hidden dimension.\n act (nn.Module): Activation function.\n \"\"\"\n super().__init__()\n self.lin1 = nn.Linear(embedding_dim, mlp_dim)\n self.lin2 = nn.Linear(mlp_dim, embedding_dim)\n self.act = act()\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass for the MLPBlock.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after MLP block.\n \"\"\"\n return self.lin2(self.act(self.lin1(x)))",
"chunk_type": "class",
"name": "MLPBlock",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 331,
"end_line": 358,
"start_col": 0,
"end_col": 48,
"parent_name": null,
"docstring": "A single block of a multi-layer perceptron.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"typing.Optional",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"conv.Conv",
"utils._get_clones",
"utils.inverse_sigmoid",
"utils.multi_scale_deformable_attn_pytorch",
"utils.torch_utils.TORCH_1_9",
"nn.Module"
],
"chunk_id": "class_MLPBlock_cfb56b63"
},
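Minimal sketch for the MLPBlock chunk above (assumed import path): a Linear -> GELU -> Linear bottleneck applied on the last dimension.
>>> import torch
>>> from ultralytics.nn.modules.transformer import MLPBlock
>>> mlp = MLPBlock(embedding_dim=256, mlp_dim=1024)
>>> x = torch.randn(4, 100, 256)
>>> y = mlp(x)  # shape unchanged: (4, 100, 256)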
{
"content": "class MLP(nn.Module):\n \"\"\"\n A simple multi-layer perceptron (also called FFN).\n\n This class implements a configurable MLP with multiple linear layers, activation functions, and optional\n sigmoid output activation.\n\n Attributes:\n num_layers (int): Number of layers in the MLP.\n layers (nn.ModuleList): List of linear layers.\n sigmoid (bool): Whether to apply sigmoid to the output.\n act (nn.Module): Activation function.\n \"\"\"\n\n def __init__(\n self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act=nn.ReLU, sigmoid: bool = False\n ):\n \"\"\"\n Initialize the MLP with specified input, hidden, output dimensions and number of layers.\n\n Args:\n input_dim (int): Input dimension.\n hidden_dim (int): Hidden dimension.\n output_dim (int): Output dimension.\n num_layers (int): Number of layers.\n act (nn.Module): Activation function.\n sigmoid (bool): Whether to apply sigmoid to the output.\n \"\"\"\n super().__init__()\n self.num_layers = num_layers\n h = [hidden_dim] * (num_layers - 1)\n self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))\n self.sigmoid = sigmoid\n self.act = act()\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Forward pass for the entire MLP.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after MLP.\n \"\"\"\n for i, layer in enumerate(self.layers):\n x = getattr(self, \"act\", nn.ReLU())(layer(x)) if i < self.num_layers - 1 else layer(x)\n return x.sigmoid() if getattr(self, \"sigmoid\", False) else x",
"chunk_type": "class",
"name": "MLP",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 361,
"end_line": 408,
"start_col": 0,
"end_col": 68,
"parent_name": null,
"docstring": "A simple multi-layer perceptron (also called FFN).\n\nThis class implements a configurable MLP with multiple linear layers, activation functions, and optional\nsigmoid output activation.\n\nAttributes:\n num_layers (int): Number of layers in the MLP.\n layers (nn.ModuleList): List of linear layers.\n sigmoid (bool): Whether to apply sigmoid to the output.\n act (nn.Module): Activation function.",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"typing.Optional",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"conv.Conv",
"utils._get_clones",
"utils.inverse_sigmoid",
"utils.multi_scale_deformable_attn_pytorch",
"utils.torch_utils.TORCH_1_9",
"nn.Module"
],
"chunk_id": "class_MLP_7a99ec10"
},
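Minimal sketch for the MLP chunk above (assumed import path, illustrative shapes), configured the way RTDETRDecoder uses it as a bbox head: num_layers=3 and output_dim=4, with ReLU between layers and no sigmoid by default.
>>> import torch
>>> from ultralytics.nn.modules.transformer import MLP
>>> bbox_head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
>>> q = torch.randn(1, 300, 256)  # e.g. 300 decoder queries
>>> boxes = bbox_head(q)  # (1, 300, 4)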
{
"content": "class LayerNorm2d(nn.Module):\n \"\"\"\n 2D Layer Normalization module inspired by Detectron2 and ConvNeXt implementations.\n\n This class implements layer normalization for 2D feature maps, normalizing across the channel dimension\n while preserving spatial dimensions.\n\n Attributes:\n weight (nn.Parameter): Learnable scale parameter.\n bias (nn.Parameter): Learnable bias parameter.\n eps (float): Small constant for numerical stability.\n\n References:\n https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py\n https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py\n \"\"\"\n\n def __init__(self, num_channels: int, eps: float = 1e-6):\n \"\"\"\n Initialize LayerNorm2d with the given parameters.\n\n Args:\n num_channels (int): Number of channels in the input.\n eps (float): Small constant for numerical stability.\n \"\"\"\n super().__init__()\n self.weight = nn.Parameter(torch.ones(num_channels))\n self.bias = nn.Parameter(torch.zeros(num_channels))\n self.eps = eps\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Perform forward pass for 2D layer normalization.\n\n Args:\n x (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Normalized output tensor.\n \"\"\"\n u = x.mean(1, keepdim=True)\n s = (x - u).pow(2).mean(1, keepdim=True)\n x = (x - u) / torch.sqrt(s + self.eps)\n return self.weight[:, None, None] * x + self.bias[:, None, None]",
"chunk_type": "class",
"name": "LayerNorm2d",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 411,
"end_line": 454,
"start_col": 0,
"end_col": 72,
"parent_name": null,
"docstring": "2D Layer Normalization module inspired by Detectron2 and ConvNeXt implementations.\n\nThis class implements layer normalization for 2D feature maps, normalizing across the channel dimension\nwhile preserving spatial dimensions.\n\nAttributes:\n weight (nn.Parameter): Learnable scale parameter.\n bias (nn.Parameter): Learnable bias parameter.\n eps (float): Small constant for numerical stability.\n\nReferences:\n https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py\n https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"typing.Optional",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"conv.Conv",
"utils._get_clones",
"utils.inverse_sigmoid",
"utils.multi_scale_deformable_attn_pytorch",
"utils.torch_utils.TORCH_1_9",
"nn.Module"
],
"chunk_id": "class_LayerNorm2d_933f0d90"
},
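Minimal sketch for the LayerNorm2d chunk above (assumed import path): each spatial position is normalized across its channels, so with the default weight=1 and bias=0 the per-position mean is close to zero.
>>> import torch
>>> from ultralytics.nn.modules.transformer import LayerNorm2d
>>> norm = LayerNorm2d(num_channels=64)
>>> x = torch.randn(2, 64, 32, 32)
>>> y = norm(x)  # (2, 64, 32, 32)
>>> m = float(y.mean(dim=1).abs().max())  # near zero with the default affine parameters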
{
"content": "class MSDeformAttn(nn.Module):\n \"\"\"\n Multiscale Deformable Attention Module based on Deformable-DETR and PaddleDetection implementations.\n\n This module implements multiscale deformable attention that can attend to features at multiple scales\n with learnable sampling locations and attention weights.\n\n Attributes:\n im2col_step (int): Step size for im2col operations.\n d_model (int): Model dimension.\n n_levels (int): Number of feature levels.\n n_heads (int): Number of attention heads.\n n_points (int): Number of sampling points per attention head per feature level.\n sampling_offsets (nn.Linear): Linear layer for generating sampling offsets.\n attention_weights (nn.Linear): Linear layer for generating attention weights.\n value_proj (nn.Linear): Linear layer for projecting values.\n output_proj (nn.Linear): Linear layer for projecting output.\n\n References:\n https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py\n \"\"\"\n\n def __init__(self, d_model: int = 256, n_levels: int = 4, n_heads: int = 8, n_points: int = 4):\n \"\"\"\n Initialize MSDeformAttn with the given parameters.\n\n Args:\n d_model (int): Model dimension.\n n_levels (int): Number of feature levels.\n n_heads (int): Number of attention heads.\n n_points (int): Number of sampling points per attention head per feature level.\n \"\"\"\n super().__init__()\n if d_model % n_heads != 0:\n raise ValueError(f\"d_model must be divisible by n_heads, but got {d_model} and {n_heads}\")\n _d_per_head = d_model // n_heads\n # Better to set _d_per_head to a power of 2 which is more efficient in a CUDA implementation\n assert _d_per_head * n_heads == d_model, \"`d_model` must be divisible by `n_heads`\"\n\n self.im2col_step = 64\n\n self.d_model = d_model\n self.n_levels = n_levels\n self.n_heads = n_heads\n self.n_points = n_points\n\n self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)\n self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)\n self.value_proj = nn.Linear(d_model, d_model)\n self.output_proj = nn.Linear(d_model, d_model)\n\n self._reset_parameters()\n\n def _reset_parameters(self):\n \"\"\"Reset module parameters.\"\"\"\n constant_(self.sampling_offsets.weight.data, 0.0)\n thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)\n grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)\n grid_init = (\n (grid_init / grid_init.abs().max(-1, keepdim=True)[0])\n .view(self.n_heads, 1, 1, 2)\n .repeat(1, self.n_levels, self.n_points, 1)\n )\n for i in range(self.n_points):\n grid_init[:, :, i, :] *= i + 1\n with torch.no_grad():\n self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))\n constant_(self.attention_weights.weight.data, 0.0)\n constant_(self.attention_weights.bias.data, 0.0)\n xavier_uniform_(self.value_proj.weight.data)\n constant_(self.value_proj.bias.data, 0.0)\n xavier_uniform_(self.output_proj.weight.data)\n constant_(self.output_proj.bias.data, 0.0)\n\n def forward(\n self,\n query: torch.Tensor,\n refer_bbox: torch.Tensor,\n value: torch.Tensor,\n value_shapes: List,\n value_mask: Optional[torch.Tensor] = None,\n ) -> torch.Tensor:\n \"\"\"\n Perform forward pass for multiscale deformable attention.\n\n Args:\n query (torch.Tensor): Query tensor with shape [bs, query_length, C].\n refer_bbox (torch.Tensor): Reference bounding boxes with shape [bs, query_length, n_levels, 2],\n range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding 
area.\n value (torch.Tensor): Value tensor with shape [bs, value_length, C].\n value_shapes (list): List with shape [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})].\n value_mask (torch.Tensor, optional): Mask tensor with shape [bs, value_length], True for non-padding\n elements, False for padding elements.\n\n Returns:\n (torch.Tensor): Output tensor with shape [bs, Length_{query}, C].\n\n References:\n https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py\n \"\"\"\n bs, len_q = query.shape[:2]\n len_v = value.shape[1]\n assert sum(s[0] * s[1] for s in value_shapes) == len_v\n\n value = self.value_proj(value)\n if value_mask is not None:\n value = value.masked_fill(value_mask[..., None], float(0))\n value = value.view(bs, len_v, self.n_heads, self.d_model // self.n_heads)\n sampling_offsets = self.sampling_offsets(query).view(bs, len_q, self.n_heads, self.n_levels, self.n_points, 2)\n attention_weights = self.attention_weights(query).view(bs, len_q, self.n_heads, self.n_levels * self.n_points)\n attention_weights = F.softmax(attention_weights, -1).view(bs, len_q, self.n_heads, self.n_levels, self.n_points)\n # N, Len_q, n_heads, n_levels, n_points, 2\n num_points = refer_bbox.shape[-1]\n if num_points == 2:\n offset_normalizer = torch.as_tensor(value_shapes, dtype=query.dtype, device=query.device).flip(-1)\n add = sampling_offsets / offset_normalizer[None, None, None, :, None, :]\n sampling_locations = refer_bbox[:, :, None, :, None, :] + add\n elif num_points == 4:\n add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5\n sampling_locations = refer_bbox[:, :, None, :, None, :2] + add\n else:\n raise ValueError(f\"Last dim of reference_points must be 2 or 4, but got {num_points}.\")\n output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)\n return self.output_proj(output)",
"chunk_type": "class",
"name": "MSDeformAttn",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 457,
"end_line": 580,
"start_col": 0,
"end_col": 39,
"parent_name": null,
"docstring": "Multiscale Deformable Attention Module based on Deformable-DETR and PaddleDetection implementations.\n\nThis module implements multiscale deformable attention that can attend to features at multiple scales\nwith learnable sampling locations and attention weights.\n\nAttributes:\n im2col_step (int): Step size for im2col operations.\n d_model (int): Model dimension.\n n_levels (int): Number of feature levels.\n n_heads (int): Number of attention heads.\n n_points (int): Number of sampling points per attention head per feature level.\n sampling_offsets (nn.Linear): Linear layer for generating sampling offsets.\n attention_weights (nn.Linear): Linear layer for generating attention weights.\n value_proj (nn.Linear): Linear layer for projecting values.\n output_proj (nn.Linear): Linear layer for projecting output.\n\nReferences:\n https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"typing.Optional",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"conv.Conv",
"utils._get_clones",
"utils.inverse_sigmoid",
"utils.multi_scale_deformable_attn_pytorch",
"utils.torch_utils.TORCH_1_9",
"nn.Module"
],
"chunk_id": "class_MSDeformAttn_b3c584f0"
},
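Minimal sketch for the MSDeformAttn chunk above (assumed import path, illustrative single-level shapes): the value length must equal the sum of H*W over value_shapes, and refer_bbox holds normalized (x, y) centers per level.
>>> import torch
>>> from ultralytics.nn.modules.transformer import MSDeformAttn
>>> attn = MSDeformAttn(d_model=256, n_levels=1, n_heads=8, n_points=4)
>>> query = torch.randn(1, 300, 256)
>>> value = torch.randn(1, 20 * 20, 256)
>>> refer_bbox = torch.rand(1, 300, 1, 2)  # (bs, len_q, n_levels, 2), normalized to [0, 1]
>>> out = attn(query, refer_bbox, value, value_shapes=[(20, 20)])  # (1, 300, 256)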
{
"content": "class DeformableTransformerDecoderLayer(nn.Module):\n \"\"\"\n Deformable Transformer Decoder Layer inspired by PaddleDetection and Deformable-DETR implementations.\n\n This class implements a single decoder layer with self-attention, cross-attention using multiscale deformable\n attention, and a feedforward network.\n\n Attributes:\n self_attn (nn.MultiheadAttention): Self-attention module.\n dropout1 (nn.Dropout): Dropout after self-attention.\n norm1 (nn.LayerNorm): Layer normalization after self-attention.\n cross_attn (MSDeformAttn): Cross-attention module.\n dropout2 (nn.Dropout): Dropout after cross-attention.\n norm2 (nn.LayerNorm): Layer normalization after cross-attention.\n linear1 (nn.Linear): First linear layer in the feedforward network.\n act (nn.Module): Activation function.\n dropout3 (nn.Dropout): Dropout in the feedforward network.\n linear2 (nn.Linear): Second linear layer in the feedforward network.\n dropout4 (nn.Dropout): Dropout after the feedforward network.\n norm3 (nn.LayerNorm): Layer normalization after the feedforward network.\n\n References:\n https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py\n https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py\n \"\"\"\n\n def __init__(\n self,\n d_model: int = 256,\n n_heads: int = 8,\n d_ffn: int = 1024,\n dropout: float = 0.0,\n act: nn.Module = nn.ReLU(),\n n_levels: int = 4,\n n_points: int = 4,\n ):\n \"\"\"\n Initialize the DeformableTransformerDecoderLayer with the given parameters.\n\n Args:\n d_model (int): Model dimension.\n n_heads (int): Number of attention heads.\n d_ffn (int): Dimension of the feedforward network.\n dropout (float): Dropout probability.\n act (nn.Module): Activation function.\n n_levels (int): Number of feature levels.\n n_points (int): Number of sampling points.\n \"\"\"\n super().__init__()\n\n # Self attention\n self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)\n self.dropout1 = nn.Dropout(dropout)\n self.norm1 = nn.LayerNorm(d_model)\n\n # Cross attention\n self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)\n self.dropout2 = nn.Dropout(dropout)\n self.norm2 = nn.LayerNorm(d_model)\n\n # FFN\n self.linear1 = nn.Linear(d_model, d_ffn)\n self.act = act\n self.dropout3 = nn.Dropout(dropout)\n self.linear2 = nn.Linear(d_ffn, d_model)\n self.dropout4 = nn.Dropout(dropout)\n self.norm3 = nn.LayerNorm(d_model)\n\n @staticmethod\n def with_pos_embed(tensor: torch.Tensor, pos: Optional[torch.Tensor]) -> torch.Tensor:\n \"\"\"Add positional embeddings to the input tensor, if provided.\"\"\"\n return tensor if pos is None else tensor + pos\n\n def forward_ffn(self, tgt: torch.Tensor) -> torch.Tensor:\n \"\"\"\n Perform forward pass through the Feed-Forward Network part of the layer.\n\n Args:\n tgt (torch.Tensor): Input tensor.\n\n Returns:\n (torch.Tensor): Output tensor after FFN.\n \"\"\"\n tgt2 = self.linear2(self.dropout3(self.act(self.linear1(tgt))))\n tgt = tgt + self.dropout4(tgt2)\n return self.norm3(tgt)\n\n def forward(\n self,\n embed: torch.Tensor,\n refer_bbox: torch.Tensor,\n feats: torch.Tensor,\n shapes: List,\n padding_mask: Optional[torch.Tensor] = None,\n attn_mask: Optional[torch.Tensor] = None,\n query_pos: Optional[torch.Tensor] = None,\n ) -> torch.Tensor:\n \"\"\"\n Perform the forward pass through the entire decoder layer.\n\n Args:\n embed (torch.Tensor): Input embeddings.\n refer_bbox (torch.Tensor): Reference 
bounding boxes.\n feats (torch.Tensor): Feature maps.\n shapes (list): Feature shapes.\n padding_mask (torch.Tensor, optional): Padding mask.\n attn_mask (torch.Tensor, optional): Attention mask.\n query_pos (torch.Tensor, optional): Query position embeddings.\n\n Returns:\n (torch.Tensor): Output tensor after decoder layer.\n \"\"\"\n # Self attention\n q = k = self.with_pos_embed(embed, query_pos)\n tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), attn_mask=attn_mask)[\n 0\n ].transpose(0, 1)\n embed = embed + self.dropout1(tgt)\n embed = self.norm1(embed)\n\n # Cross attention\n tgt = self.cross_attn(\n self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, padding_mask\n )\n embed = embed + self.dropout2(tgt)\n embed = self.norm2(embed)\n\n # FFN\n return self.forward_ffn(embed)",
"chunk_type": "class",
"name": "DeformableTransformerDecoderLayer",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 583,
"end_line": 711,
"start_col": 0,
"end_col": 38,
"parent_name": null,
"docstring": "Deformable Transformer Decoder Layer inspired by PaddleDetection and Deformable-DETR implementations.\n\nThis class implements a single decoder layer with self-attention, cross-attention using multiscale deformable\nattention, and a feedforward network.\n\nAttributes:\n self_attn (nn.MultiheadAttention): Self-attention module.\n dropout1 (nn.Dropout): Dropout after self-attention.\n norm1 (nn.LayerNorm): Layer normalization after self-attention.\n cross_attn (MSDeformAttn): Cross-attention module.\n dropout2 (nn.Dropout): Dropout after cross-attention.\n norm2 (nn.LayerNorm): Layer normalization after cross-attention.\n linear1 (nn.Linear): First linear layer in the feedforward network.\n act (nn.Module): Activation function.\n dropout3 (nn.Dropout): Dropout in the feedforward network.\n linear2 (nn.Linear): Second linear layer in the feedforward network.\n dropout4 (nn.Dropout): Dropout after the feedforward network.\n norm3 (nn.LayerNorm): Layer normalization after the feedforward network.\n\nReferences:\n https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py\n https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"typing.Optional",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"conv.Conv",
"utils._get_clones",
"utils.inverse_sigmoid",
"utils.multi_scale_deformable_attn_pytorch",
"utils.torch_utils.TORCH_1_9",
"nn.Module"
],
"chunk_id": "class_DeformableTransformerDecoderLayer_be16c3bf"
},
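Minimal sketch for the DeformableTransformerDecoderLayer chunk above (assumed import path, illustrative single-level shapes): 300 queries with 4-dim normalized reference boxes attend to one flattened feature map.
>>> import torch
>>> from ultralytics.nn.modules.transformer import DeformableTransformerDecoderLayer
>>> layer = DeformableTransformerDecoderLayer(d_model=256, n_heads=8, d_ffn=1024, n_levels=1, n_points=4)
>>> embed = torch.randn(1, 300, 256)
>>> refer_bbox = torch.rand(1, 300, 4)  # normalized (cx, cy, w, h) per query
>>> feats = torch.randn(1, 20 * 20, 256)
>>> out = layer(embed, refer_bbox, feats, shapes=[(20, 20)])  # (1, 300, 256)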
{
"content": "class DeformableTransformerDecoder(nn.Module):\n \"\"\"\n Deformable Transformer Decoder based on PaddleDetection implementation.\n\n This class implements a complete deformable transformer decoder with multiple decoder layers and prediction\n heads for bounding box regression and classification.\n\n Attributes:\n layers (nn.ModuleList): List of decoder layers.\n num_layers (int): Number of decoder layers.\n hidden_dim (int): Hidden dimension.\n eval_idx (int): Index of the layer to use during evaluation.\n\n References:\n https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py\n \"\"\"\n\n def __init__(self, hidden_dim: int, decoder_layer: nn.Module, num_layers: int, eval_idx: int = -1):\n \"\"\"\n Initialize the DeformableTransformerDecoder with the given parameters.\n\n Args:\n hidden_dim (int): Hidden dimension.\n decoder_layer (nn.Module): Decoder layer module.\n num_layers (int): Number of decoder layers.\n eval_idx (int): Index of the layer to use during evaluation.\n \"\"\"\n super().__init__()\n self.layers = _get_clones(decoder_layer, num_layers)\n self.num_layers = num_layers\n self.hidden_dim = hidden_dim\n self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx\n\n def forward(\n self,\n embed: torch.Tensor, # decoder embeddings\n refer_bbox: torch.Tensor, # anchor\n feats: torch.Tensor, # image features\n shapes: List, # feature shapes\n bbox_head: nn.Module,\n score_head: nn.Module,\n pos_mlp: nn.Module,\n attn_mask: Optional[torch.Tensor] = None,\n padding_mask: Optional[torch.Tensor] = None,\n ):\n \"\"\"\n Perform the forward pass through the entire decoder.\n\n Args:\n embed (torch.Tensor): Decoder embeddings.\n refer_bbox (torch.Tensor): Reference bounding boxes.\n feats (torch.Tensor): Image features.\n shapes (list): Feature shapes.\n bbox_head (nn.Module): Bounding box prediction head.\n score_head (nn.Module): Score prediction head.\n pos_mlp (nn.Module): Position MLP.\n attn_mask (torch.Tensor, optional): Attention mask.\n padding_mask (torch.Tensor, optional): Padding mask.\n\n Returns:\n dec_bboxes (torch.Tensor): Decoded bounding boxes.\n dec_cls (torch.Tensor): Decoded classification scores.\n \"\"\"\n output = embed\n dec_bboxes = []\n dec_cls = []\n last_refined_bbox = None\n refer_bbox = refer_bbox.sigmoid()\n for i, layer in enumerate(self.layers):\n output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox))\n\n bbox = bbox_head[i](output)\n refined_bbox = torch.sigmoid(bbox + inverse_sigmoid(refer_bbox))\n\n if self.training:\n dec_cls.append(score_head[i](output))\n if i == 0:\n dec_bboxes.append(refined_bbox)\n else:\n dec_bboxes.append(torch.sigmoid(bbox + inverse_sigmoid(last_refined_bbox)))\n elif i == self.eval_idx:\n dec_cls.append(score_head[i](output))\n dec_bboxes.append(refined_bbox)\n break\n\n last_refined_bbox = refined_bbox\n refer_bbox = refined_bbox.detach() if self.training else refined_bbox\n\n return torch.stack(dec_bboxes), torch.stack(dec_cls)",
"chunk_type": "class",
"name": "DeformableTransformerDecoder",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\transformer.py",
"start_line": 714,
"end_line": 802,
"start_col": 0,
"end_col": 60,
"parent_name": null,
"docstring": "Deformable Transformer Decoder based on PaddleDetection implementation.\n\nThis class implements a complete deformable transformer decoder with multiple decoder layers and prediction\nheads for bounding box regression and classification.\n\nAttributes:\n layers (nn.ModuleList): List of decoder layers.\n num_layers (int): Number of decoder layers.\n hidden_dim (int): Hidden dimension.\n eval_idx (int): Index of the layer to use during evaluation.\n\nReferences:\n https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"math",
"typing.List",
"typing.Optional",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.constant_",
"torch.nn.init.xavier_uniform_",
"conv.Conv",
"utils._get_clones",
"utils.inverse_sigmoid",
"utils.multi_scale_deformable_attn_pytorch",
"utils.torch_utils.TORCH_1_9",
"nn.Module"
],
"chunk_id": "class_DeformableTransformerDecoder_ac5c8729"
},
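Note on the decoder chunk above: each layer refines the reference boxes in logit space, adding the predicted offset to `inverse_sigmoid(refer_bbox)` and squashing the sum back through a sigmoid. A minimal sketch of that single refinement step follows; the tensor names and sizes are illustrative only, and `inverse_sigmoid` is repeated inline so the snippet runs standalone.

```python
import torch

def inverse_sigmoid(x, eps=1e-5):
    # Same logic as ultralytics.nn.modules.utils.inverse_sigmoid (see the utils chunks below).
    x = x.clamp(min=0, max=1)
    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))

refer_bbox = torch.rand(2, 100, 4)          # normalized (cx, cy, w, h) anchors in [0, 1]
bbox_delta = torch.randn(2, 100, 4) * 0.1   # stand-in for bbox_head[i](output)

# The offset is applied in logit space, so the refined box always lands back in [0, 1]
# regardless of how large the raw delta is.
refined_bbox = torch.sigmoid(bbox_delta + inverse_sigmoid(refer_bbox))
print(refined_bbox.min().item() >= 0, refined_bbox.max().item() <= 1)  # True True
```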
{
"content": "import copy",
"chunk_type": "import",
"name": "copy",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_copy_6a0eaf33"
},
{
"content": "import math",
"chunk_type": "import",
"name": "math",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_math_a6e310e4"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_c5fb0e69"
},
{
"content": "import torch",
"chunk_type": "import",
"name": "torch",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch_8b75e7ed"
},
{
"content": "import torch.nn as nn",
"chunk_type": "import",
"name": "torch.nn",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn_b05f2489"
},
{
"content": "import torch.nn.functional as F",
"chunk_type": "import",
"name": "torch.nn.functional",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_torch.nn.functional_5e58c264"
},
{
"content": "from torch.nn.init import uniform_",
"chunk_type": "import",
"name": "uniform_",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 34,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_uniform__81c9763b"
},
{
"content": "__all__ = \"multi_scale_deformable_attn_pytorch\", \"inverse_sigmoid\"",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py",
"start_line": 12,
"end_line": 12,
"start_col": 0,
"end_col": 66,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___85de65a2"
},
{
"content": "def _get_clones(module, n):\n \"\"\"\n Create a list of cloned modules from the given module.\n\n Args:\n module (nn.Module): The module to be cloned.\n n (int): Number of clones to create.\n\n Returns:\n (nn.ModuleList): A ModuleList containing n clones of the input module.\n\n Examples:\n >>> import torch.nn as nn\n >>> layer = nn.Linear(10, 10)\n >>> clones = _get_clones(layer, 3)\n >>> len(clones)\n 3\n \"\"\"\n return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])",
"chunk_type": "function",
"name": "_get_clones",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py",
"start_line": 15,
"end_line": 33,
"start_col": 0,
"end_col": 67,
"parent_name": null,
"docstring": "Create a list of cloned modules from the given module.\n\nArgs:\n module (nn.Module): The module to be cloned.\n n (int): Number of clones to create.\n\nReturns:\n (nn.ModuleList): A ModuleList containing n clones of the input module.\n\nExamples:\n >>> import torch.nn as nn\n >>> layer = nn.Linear(10, 10)\n >>> clones = _get_clones(layer, 3)\n >>> len(clones)\n 3",
"parameters": [
"module",
"n"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"copy",
"math",
"numpy",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.uniform_"
],
"chunk_id": "function__get_clones_c4c6091d"
},
{
"content": "def bias_init_with_prob(prior_prob=0.01):\n \"\"\"\n Initialize conv/fc bias value according to a given probability value.\n\n This function calculates the bias initialization value based on a prior probability using the inverse error function.\n It's commonly used in object detection models to initialize classification layers with a specific positive prediction\n probability.\n\n Args:\n prior_prob (float, optional): Prior probability for bias initialization.\n\n Returns:\n (float): Bias initialization value calculated from the prior probability.\n\n Examples:\n >>> bias = bias_init_with_prob(0.01)\n >>> print(f\"Bias initialization value: {bias:.4f}\")\n Bias initialization value: -4.5951\n \"\"\"\n return float(-np.log((1 - prior_prob) / prior_prob)) # return bias_init",
"chunk_type": "function",
"name": "bias_init_with_prob",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py",
"start_line": 36,
"end_line": 55,
"start_col": 0,
"end_col": 56,
"parent_name": null,
"docstring": "Initialize conv/fc bias value according to a given probability value.\n\nThis function calculates the bias initialization value based on a prior probability using the inverse error function.\nIt's commonly used in object detection models to initialize classification layers with a specific positive prediction\nprobability.\n\nArgs:\n prior_prob (float, optional): Prior probability for bias initialization.\n\nReturns:\n (float): Bias initialization value calculated from the prior probability.\n\nExamples:\n >>> bias = bias_init_with_prob(0.01)\n >>> print(f\"Bias initialization value: {bias:.4f}\")\n Bias initialization value: -4.5951",
"parameters": [
"prior_prob"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"copy",
"math",
"numpy",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.uniform_"
],
"chunk_id": "function_bias_init_with_prob_e7a9d08b"
},
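The value returned by `bias_init_with_prob` is the log-odds (logit) of `prior_prob`, so applying a sigmoid to the initialized bias recovers the prior positive rate. A small numerical check, duplicating the formula so the snippet runs on its own:

```python
import numpy as np

def bias_init_with_prob(prior_prob=0.01):
    # Same formula as the chunk above: the log-odds of the prior probability.
    return float(-np.log((1 - prior_prob) / prior_prob))

b = bias_init_with_prob(0.01)
print(round(b, 4))                     # -4.5951, matching the docstring example
print(round(1 / (1 + np.exp(-b)), 4))  # 0.01 -> sigmoid(bias) recovers the prior
```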
{
"content": "def linear_init(module):\n \"\"\"\n Initialize the weights and biases of a linear module.\n\n This function initializes the weights of a linear module using a uniform distribution within bounds calculated\n from the input dimension. If the module has a bias, it is also initialized.\n\n Args:\n module (nn.Module): Linear module to initialize.\n\n Returns:\n (nn.Module): The initialized module.\n\n Examples:\n >>> import torch.nn as nn\n >>> linear = nn.Linear(10, 5)\n >>> initialized_linear = linear_init(linear)\n \"\"\"\n bound = 1 / math.sqrt(module.weight.shape[0])\n uniform_(module.weight, -bound, bound)\n if hasattr(module, \"bias\") and module.bias is not None:\n uniform_(module.bias, -bound, bound)",
"chunk_type": "function",
"name": "linear_init",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py",
"start_line": 58,
"end_line": 79,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": "Initialize the weights and biases of a linear module.\n\nThis function initializes the weights of a linear module using a uniform distribution within bounds calculated\nfrom the input dimension. If the module has a bias, it is also initialized.\n\nArgs:\n module (nn.Module): Linear module to initialize.\n\nReturns:\n (nn.Module): The initialized module.\n\nExamples:\n >>> import torch.nn as nn\n >>> linear = nn.Linear(10, 5)\n >>> initialized_linear = linear_init(linear)",
"parameters": [
"module"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"copy",
"math",
"numpy",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.uniform_"
],
"chunk_id": "function_linear_init_1a00c166"
},
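`linear_init` draws weights (and the bias, if present) uniformly from the interval ±1/sqrt(weight.shape[0]). A minimal standalone bound check, re-declaring the function so the snippet is self-contained; the layer size is arbitrary:

```python
import math
import torch.nn as nn
from torch.nn.init import uniform_

def linear_init(module):
    # Same logic as the chunk above.
    bound = 1 / math.sqrt(module.weight.shape[0])
    uniform_(module.weight, -bound, bound)
    if hasattr(module, "bias") and module.bias is not None:
        uniform_(module.bias, -bound, bound)

layer = nn.Linear(10, 5)
linear_init(layer)
bound = 1 / math.sqrt(layer.weight.shape[0])          # weight.shape[0] is out_features (5 here)
print(layer.weight.abs().max().item() <= bound)       # True: all weights lie inside the bound
```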
{
"content": "def inverse_sigmoid(x, eps=1e-5):\n \"\"\"\n Calculate the inverse sigmoid function for a tensor.\n\n This function applies the inverse of the sigmoid function to a tensor, which is useful in various neural network\n operations, particularly in attention mechanisms and coordinate transformations.\n\n Args:\n x (torch.Tensor): Input tensor with values in range [0, 1].\n eps (float, optional): Small epsilon value to prevent numerical instability.\n\n Returns:\n (torch.Tensor): Tensor after applying the inverse sigmoid function.\n\n Examples:\n >>> x = torch.tensor([0.2, 0.5, 0.8])\n >>> inverse_sigmoid(x)\n tensor([-1.3863, 0.0000, 1.3863])\n \"\"\"\n x = x.clamp(min=0, max=1)\n x1 = x.clamp(min=eps)\n x2 = (1 - x).clamp(min=eps)\n return torch.log(x1 / x2)",
"chunk_type": "function",
"name": "inverse_sigmoid",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py",
"start_line": 82,
"end_line": 104,
"start_col": 0,
"end_col": 29,
"parent_name": null,
"docstring": "Calculate the inverse sigmoid function for a tensor.\n\nThis function applies the inverse of the sigmoid function to a tensor, which is useful in various neural network\noperations, particularly in attention mechanisms and coordinate transformations.\n\nArgs:\n x (torch.Tensor): Input tensor with values in range [0, 1].\n eps (float, optional): Small epsilon value to prevent numerical instability.\n\nReturns:\n (torch.Tensor): Tensor after applying the inverse sigmoid function.\n\nExamples:\n >>> x = torch.tensor([0.2, 0.5, 0.8])\n >>> inverse_sigmoid(x)\n tensor([-1.3863, 0.0000, 1.3863])",
"parameters": [
"x",
"eps"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"copy",
"math",
"numpy",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.uniform_"
],
"chunk_id": "function_inverse_sigmoid_9e224e18"
},
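`inverse_sigmoid` inverts `torch.sigmoid` exactly away from the clamped endpoints, which is what makes the logit-space box refinement above well defined. A quick round-trip check, with the function repeated so the snippet runs standalone:

```python
import torch

def inverse_sigmoid(x, eps=1e-5):
    # Same logic as the chunk above: clamp to [0, 1], then take log-odds with eps for stability.
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)

x = torch.tensor([0.2, 0.5, 0.8])
print(torch.allclose(torch.sigmoid(inverse_sigmoid(x)), x, atol=1e-6))  # True
```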
{
"content": "def multi_scale_deformable_attn_pytorch(\n value: torch.Tensor,\n value_spatial_shapes: torch.Tensor,\n sampling_locations: torch.Tensor,\n attention_weights: torch.Tensor,\n) -> torch.Tensor:\n \"\"\"\n Implement multi-scale deformable attention in PyTorch.\n\n This function performs deformable attention across multiple feature map scales, allowing the model to attend to\n different spatial locations with learned offsets.\n\n Args:\n value (torch.Tensor): The value tensor with shape (bs, num_keys, num_heads, embed_dims).\n value_spatial_shapes (torch.Tensor): Spatial shapes of the value tensor with shape (num_levels, 2).\n sampling_locations (torch.Tensor): The sampling locations with shape\n (bs, num_queries, num_heads, num_levels, num_points, 2).\n attention_weights (torch.Tensor): The attention weights with shape\n (bs, num_queries, num_heads, num_levels, num_points).\n\n Returns:\n (torch.Tensor): The output tensor with shape (bs, num_queries, embed_dims).\n\n References:\n https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py\n \"\"\"\n bs, _, num_heads, embed_dims = value.shape\n _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape\n value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)\n sampling_grids = 2 * sampling_locations - 1\n sampling_value_list = []\n for level, (H_, W_) in enumerate(value_spatial_shapes):\n # bs, H_*W_, num_heads, embed_dims ->\n # bs, H_*W_, num_heads*embed_dims ->\n # bs, num_heads*embed_dims, H_*W_ ->\n # bs*num_heads, embed_dims, H_, W_\n value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)\n # bs, num_queries, num_heads, num_points, 2 ->\n # bs, num_heads, num_queries, num_points, 2 ->\n # bs*num_heads, num_queries, num_points, 2\n sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)\n # bs*num_heads, embed_dims, num_queries, num_points\n sampling_value_l_ = F.grid_sample(\n value_l_, sampling_grid_l_, mode=\"bilinear\", padding_mode=\"zeros\", align_corners=False\n )\n sampling_value_list.append(sampling_value_l_)\n # (bs, num_queries, num_heads, num_levels, num_points) ->\n # (bs, num_heads, num_queries, num_levels, num_points) ->\n # (bs, num_heads, 1, num_queries, num_levels*num_points)\n attention_weights = attention_weights.transpose(1, 2).reshape(\n bs * num_heads, 1, num_queries, num_levels * num_points\n )\n output = (\n (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)\n .sum(-1)\n .view(bs, num_heads * embed_dims, num_queries)\n )\n return output.transpose(1, 2).contiguous()",
"chunk_type": "function",
"name": "multi_scale_deformable_attn_pytorch",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\utils.py",
"start_line": 107,
"end_line": 164,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": "Implement multi-scale deformable attention in PyTorch.\n\nThis function performs deformable attention across multiple feature map scales, allowing the model to attend to\ndifferent spatial locations with learned offsets.\n\nArgs:\n value (torch.Tensor): The value tensor with shape (bs, num_keys, num_heads, embed_dims).\n value_spatial_shapes (torch.Tensor): Spatial shapes of the value tensor with shape (num_levels, 2).\n sampling_locations (torch.Tensor): The sampling locations with shape\n (bs, num_queries, num_heads, num_levels, num_points, 2).\n attention_weights (torch.Tensor): The attention weights with shape\n (bs, num_queries, num_heads, num_levels, num_points).\n\nReturns:\n (torch.Tensor): The output tensor with shape (bs, num_queries, embed_dims).\n\nReferences:\n https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py",
"parameters": [
"value: torch.Tensor",
"value_spatial_shapes: torch.Tensor",
"sampling_locations: torch.Tensor",
"attention_weights: torch.Tensor"
],
"return_type": "torch.Tensor",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"copy",
"math",
"numpy",
"torch",
"torch.nn",
"torch.nn.functional",
"torch.nn.init.uniform_"
],
"chunk_id": "function_multi_scale_deformable_attn_pytorch_46591a4c"
},
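The easiest way to read `multi_scale_deformable_attn_pytorch` is through its input shapes. The sketch below builds random inputs for two feature levels and checks the output shape; all sizes are arbitrary, the spatial shapes are passed as a plain list of (H, W) pairs for simplicity, and the function from the chunk above is assumed to be in scope.

```python
import torch

bs, num_heads, head_dim = 2, 8, 32
shapes = [(32, 32), (16, 16)]                         # (H, W) per feature level
num_keys = sum(h * w for h, w in shapes)              # 32*32 + 16*16 = 1280 flattened locations
num_queries, num_levels, num_points = 100, 2, 4

value = torch.rand(bs, num_keys, num_heads, head_dim)
sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)  # in [0, 1]
attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
attention_weights = attention_weights.softmax(-1)     # rough normalization so the weights behave like attention

out = multi_scale_deformable_attn_pytorch(value, shapes, sampling_locations, attention_weights)
print(out.shape)  # torch.Size([2, 100, 256]) -> (bs, num_queries, num_heads * head_dim)
```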
{
"content": "from .block import (\n C1,\n C2,\n C2PSA,\n C3,\n C3TR,\n CIB,\n DFL,\n ELAN1,\n PSA,\n SPP,\n SPPELAN,\n SPPF,\n A2C2f,\n AConv,\n ADown,\n Attention,\n BNContrastiveHead,\n Bottleneck,\n BottleneckCSP,\n C2f,\n C2fAttn,\n C2fCIB,\n C2fPSA,\n C3Ghost,\n C3k2,\n C3x,\n CBFuse,\n CBLinear,\n ContrastiveHead,\n GhostBottleneck,\n HGBlock,\n HGStem,\n ImagePoolingAttn,\n MaxSigmoidAttnBlock,\n Proto,\n RepC3,\n RepNCSPELAN4,\n RepVGGDW,\n ResNetLayer,\n SCDown,\n TorchVision,\n)",
"chunk_type": "import",
"name": "C1, C2, C2PSA, C3, C3TR, CIB, DFL, ELAN1, PSA, SPP, SPPELAN, SPPF, A2C2f, AConv, ADown, Attention, BNContrastiveHead, Bottleneck, BottleneckCSP, C2f, C2fAttn, C2fCIB, C2fPSA, C3Ghost, C3k2, C3x, CBFuse, CBLinear, ContrastiveHead, GhostBottleneck, HGBlock, HGStem, ImagePoolingAttn, MaxSigmoidAttnBlock, Proto, RepC3, RepNCSPELAN4, RepVGGDW, ResNetLayer, SCDown, TorchVision",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\__init__.py",
"start_line": 20,
"end_line": 62,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_C1, C2, C2PSA, C3, C3TR, CIB, DFL, ELAN1, PSA, SPP, SPPELAN, SPPF, A2C2f, AConv, ADown, Attention, BNContrastiveHead, Bottleneck, BottleneckCSP, C2f, C2fAttn, C2fCIB, C2fPSA, C3Ghost, C3k2, C3x, CBFuse, CBLinear, ContrastiveHead, GhostBottleneck, HGBlock, HGStem, ImagePoolingAttn, MaxSigmoidAttnBlock, Proto, RepC3, RepNCSPELAN4, RepVGGDW, ResNetLayer, SCDown, TorchVision_65ef2def"
},
{
"content": "from .conv import (\n CBAM,\n ChannelAttention,\n Concat,\n Conv,\n Conv2,\n ConvTranspose,\n DWConv,\n DWConvTranspose2d,\n Focus,\n GhostConv,\n Index,\n LightConv,\n RepConv,\n SpatialAttention,\n)",
"chunk_type": "import",
"name": "CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus, GhostConv, Index, LightConv, RepConv, SpatialAttention",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\__init__.py",
"start_line": 63,
"end_line": 78,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus, GhostConv, Index, LightConv, RepConv, SpatialAttention_a70cef6e"
},
{
"content": "from .head import (\n OBB,\n Classify,\n Detect,\n LRPCHead,\n Pose,\n RTDETRDecoder,\n Segment,\n WorldDetect,\n YOLOEDetect,\n YOLOESegment,\n v10Detect,\n)",
"chunk_type": "import",
"name": "OBB, Classify, Detect, LRPCHead, Pose, RTDETRDecoder, Segment, WorldDetect, YOLOEDetect, YOLOESegment, v10Detect",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\__init__.py",
"start_line": 79,
"end_line": 91,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_OBB, Classify, Detect, LRPCHead, Pose, RTDETRDecoder, Segment, WorldDetect, YOLOEDetect, YOLOESegment, v10Detect_35091f91"
},
{
"content": "from .transformer import (\n AIFI,\n MLP,\n DeformableTransformerDecoder,\n DeformableTransformerDecoderLayer,\n LayerNorm2d,\n MLPBlock,\n MSDeformAttn,\n TransformerBlock,\n TransformerEncoderLayer,\n TransformerLayer,\n)",
"chunk_type": "import",
"name": "AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d, MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\__init__.py",
"start_line": 92,
"end_line": 103,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d, MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer_c27e4da3"
},
{
"content": "__all__ = (\n \"Conv\",\n \"Conv2\",\n \"LightConv\",\n \"RepConv\",\n \"DWConv\",\n \"DWConvTranspose2d\",\n \"ConvTranspose\",\n \"Focus\",\n \"GhostConv\",\n \"ChannelAttention\",\n \"SpatialAttention\",\n \"CBAM\",\n \"Concat\",\n \"TransformerLayer\",\n \"TransformerBlock\",\n \"MLPBlock\",\n \"LayerNorm2d\",\n \"DFL\",\n \"HGBlock\",\n \"HGStem\",\n \"SPP\",\n \"SPPF\",\n \"C1\",\n \"C2\",\n \"C3\",\n \"C2f\",\n \"C3k2\",\n \"SCDown\",\n \"C2fPSA\",\n \"C2PSA\",\n \"C2fAttn\",\n \"C3x\",\n \"C3TR\",\n \"C3Ghost\",\n \"GhostBottleneck\",\n \"Bottleneck\",\n \"BottleneckCSP\",\n \"Proto\",\n \"Detect\",\n \"Segment\",\n \"Pose\",\n \"Classify\",\n \"TransformerEncoderLayer\",\n \"RepC3\",\n \"RTDETRDecoder\",\n \"AIFI\",\n \"DeformableTransformerDecoder\",\n \"DeformableTransformerDecoderLayer\",\n \"MSDeformAttn\",\n \"MLP\",\n \"ResNetLayer\",\n \"OBB\",\n \"WorldDetect\",\n \"YOLOEDetect\",\n \"YOLOESegment\",\n \"v10Detect\",\n \"LRPCHead\",\n \"ImagePoolingAttn\",\n \"MaxSigmoidAttnBlock\",\n \"ContrastiveHead\",\n \"BNContrastiveHead\",\n \"RepNCSPELAN4\",\n \"ADown\",\n \"SPPELAN\",\n \"CBFuse\",\n \"CBLinear\",\n \"AConv\",\n \"ELAN1\",\n \"RepVGGDW\",\n \"CIB\",\n \"C2fCIB\",\n \"Attention\",\n \"PSA\",\n \"TorchVision\",\n \"Index\",\n \"A2C2f\",\n)",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\nn\\modules\\__init__.py",
"start_line": 105,
"end_line": 182,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___b3080452"
},
{
"content": "import copy",
"chunk_type": "import",
"name": "copy",
"file_path": "ultralytics\\ultralytics\\trackers\\utils\\gmc.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_copy_6bb5f9c9"
},
{
"content": "from typing import List, Optional",
"chunk_type": "import",
"name": "List, Optional",
"file_path": "ultralytics\\ultralytics\\trackers\\utils\\gmc.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_List, Optional_845e34bc"
},
{
"content": "import cv2",
"chunk_type": "import",
"name": "cv2",
"file_path": "ultralytics\\ultralytics\\trackers\\utils\\gmc.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 10,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_cv2_a3836a3f"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\trackers\\utils\\gmc.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_c9f60795"
},
{
"content": "from ultralytics.utils import LOGGER",
"chunk_type": "import",
"name": "LOGGER",
"file_path": "ultralytics\\ultralytics\\trackers\\utils\\gmc.py",
"start_line": 9,
"end_line": 9,
"start_col": 0,
"end_col": 36,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER_40dda891"
},
{
"content": "class GMC:\n \"\"\"\n Generalized Motion Compensation (GMC) class for tracking and object detection in video frames.\n\n This class provides methods for tracking and detecting objects based on several tracking algorithms including ORB,\n SIFT, ECC, and Sparse Optical Flow. It also supports downscaling of frames for computational efficiency.\n\n Attributes:\n method (str): The tracking method to use. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.\n downscale (int): Factor by which to downscale the frames for processing.\n prevFrame (np.ndarray): Previous frame for tracking.\n prevKeyPoints (List): Keypoints from the previous frame.\n prevDescriptors (np.ndarray): Descriptors from the previous frame.\n initializedFirstFrame (bool): Flag indicating if the first frame has been processed.\n\n Methods:\n apply: Apply the chosen method to a raw frame and optionally use provided detections.\n apply_ecc: Apply the ECC algorithm to a raw frame.\n apply_features: Apply feature-based methods like ORB or SIFT to a raw frame.\n apply_sparseoptflow: Apply the Sparse Optical Flow method to a raw frame.\n reset_params: Reset the internal parameters of the GMC object.\n\n Examples:\n Create a GMC object and apply it to a frame\n >>> gmc = GMC(method=\"sparseOptFlow\", downscale=2)\n >>> frame = np.array([[1, 2, 3], [4, 5, 6]])\n >>> processed_frame = gmc.apply(frame)\n >>> print(processed_frame)\n array([[1, 2, 3],\n [4, 5, 6]])\n \"\"\"\n\n def __init__(self, method: str = \"sparseOptFlow\", downscale: int = 2) -> None:\n \"\"\"\n Initialize a Generalized Motion Compensation (GMC) object with tracking method and downscale factor.\n\n Args:\n method (str): The tracking method to use. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.\n downscale (int): Downscale factor for processing frames.\n\n Examples:\n Initialize a GMC object with the 'sparseOptFlow' method and a downscale factor of 2\n >>> gmc = GMC(method=\"sparseOptFlow\", downscale=2)\n \"\"\"\n super().__init__()\n\n self.method = method\n self.downscale = max(1, downscale)\n\n if self.method == \"orb\":\n self.detector = cv2.FastFeatureDetector_create(20)\n self.extractor = cv2.ORB_create()\n self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)\n\n elif self.method == \"sift\":\n self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)\n self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)\n self.matcher = cv2.BFMatcher(cv2.NORM_L2)\n\n elif self.method == \"ecc\":\n number_of_iterations = 5000\n termination_eps = 1e-6\n self.warp_mode = cv2.MOTION_EUCLIDEAN\n self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)\n\n elif self.method == \"sparseOptFlow\":\n self.feature_params = dict(\n maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3, useHarrisDetector=False, k=0.04\n )\n\n elif self.method in {\"none\", \"None\", None}:\n self.method = None\n else:\n raise ValueError(f\"Unknown GMC method: {method}\")\n\n self.prevFrame = None\n self.prevKeyPoints = None\n self.prevDescriptors = None\n self.initializedFirstFrame = False\n\n def apply(self, raw_frame: np.ndarray, detections: Optional[List] = None) -> np.ndarray:\n \"\"\"\n Apply object detection on a raw frame using the specified method.\n\n Args:\n raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).\n detections (List, optional): List of detections to be used in the processing.\n\n Returns:\n 
(np.ndarray): Transformation matrix with shape (2, 3).\n\n Examples:\n >>> gmc = GMC(method=\"sparseOptFlow\")\n >>> raw_frame = np.random.rand(480, 640, 3)\n >>> transformation_matrix = gmc.apply(raw_frame)\n >>> print(transformation_matrix.shape)\n (2, 3)\n \"\"\"\n if self.method in {\"orb\", \"sift\"}:\n return self.apply_features(raw_frame, detections)\n elif self.method == \"ecc\":\n return self.apply_ecc(raw_frame)\n elif self.method == \"sparseOptFlow\":\n return self.apply_sparseoptflow(raw_frame)\n else:\n return np.eye(2, 3)\n\n def apply_ecc(self, raw_frame: np.ndarray) -> np.ndarray:\n \"\"\"\n Apply the ECC (Enhanced Correlation Coefficient) algorithm to a raw frame for motion compensation.\n\n Args:\n raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).\n\n Returns:\n (np.ndarray): Transformation matrix with shape (2, 3).\n\n Examples:\n >>> gmc = GMC(method=\"ecc\")\n >>> processed_frame = gmc.apply_ecc(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]))\n >>> print(processed_frame)\n [[1. 0. 0.]\n [0. 1. 0.]]\n \"\"\"\n height, width, c = raw_frame.shape\n frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame\n H = np.eye(2, 3, dtype=np.float32)\n\n # Downscale image for computational efficiency\n if self.downscale > 1.0:\n frame = cv2.GaussianBlur(frame, (3, 3), 1.5)\n frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))\n\n # Handle first frame initialization\n if not self.initializedFirstFrame:\n self.prevFrame = frame.copy()\n self.initializedFirstFrame = True\n return H\n\n # Run the ECC algorithm to find transformation matrix\n try:\n (_, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)\n except Exception as e:\n LOGGER.warning(f\"find transform failed. 
Set warp as identity {e}\")\n\n return H\n\n def apply_features(self, raw_frame: np.ndarray, detections: Optional[List] = None) -> np.ndarray:\n \"\"\"\n Apply feature-based methods like ORB or SIFT to a raw frame.\n\n Args:\n raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).\n detections (List, optional): List of detections to be used in the processing.\n\n Returns:\n (np.ndarray): Transformation matrix with shape (2, 3).\n\n Examples:\n >>> gmc = GMC(method=\"orb\")\n >>> raw_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)\n >>> transformation_matrix = gmc.apply_features(raw_frame)\n >>> print(transformation_matrix.shape)\n (2, 3)\n \"\"\"\n height, width, c = raw_frame.shape\n frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame\n H = np.eye(2, 3)\n\n # Downscale image for computational efficiency\n if self.downscale > 1.0:\n frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))\n width = width // self.downscale\n height = height // self.downscale\n\n # Create mask for keypoint detection, excluding border regions\n mask = np.zeros_like(frame)\n mask[int(0.02 * height) : int(0.98 * height), int(0.02 * width) : int(0.98 * width)] = 255\n\n # Exclude detection regions from mask to avoid tracking detected objects\n if detections is not None:\n for det in detections:\n tlbr = (det[:4] / self.downscale).astype(np.int_)\n mask[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2]] = 0\n\n # Find keypoints and compute descriptors\n keypoints = self.detector.detect(frame, mask)\n keypoints, descriptors = self.extractor.compute(frame, keypoints)\n\n # Handle first frame initialization\n if not self.initializedFirstFrame:\n self.prevFrame = frame.copy()\n self.prevKeyPoints = copy.copy(keypoints)\n self.prevDescriptors = copy.copy(descriptors)\n self.initializedFirstFrame = True\n return H\n\n # Match descriptors between previous and current frame\n knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)\n\n # Filter matches based on spatial distance constraints\n matches = []\n spatialDistances = []\n maxSpatialDistance = 0.25 * np.array([width, height])\n\n # Handle empty matches case\n if len(knnMatches) == 0:\n self.prevFrame = frame.copy()\n self.prevKeyPoints = copy.copy(keypoints)\n self.prevDescriptors = copy.copy(descriptors)\n return H\n\n # Apply Lowe's ratio test and spatial distance filtering\n for m, n in knnMatches:\n if m.distance < 0.9 * n.distance:\n prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt\n currKeyPointLocation = keypoints[m.trainIdx].pt\n\n spatialDistance = (\n prevKeyPointLocation[0] - currKeyPointLocation[0],\n prevKeyPointLocation[1] - currKeyPointLocation[1],\n )\n\n if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and (\n np.abs(spatialDistance[1]) < maxSpatialDistance[1]\n ):\n spatialDistances.append(spatialDistance)\n matches.append(m)\n\n # Filter outliers using statistical analysis\n meanSpatialDistances = np.mean(spatialDistances, 0)\n stdSpatialDistances = np.std(spatialDistances, 0)\n inliers = (spatialDistances - meanSpatialDistances) < 2.5 * stdSpatialDistances\n\n # Extract good matches and corresponding points\n goodMatches = []\n prevPoints = []\n currPoints = []\n for i in range(len(matches)):\n if inliers[i, 0] and inliers[i, 1]:\n goodMatches.append(matches[i])\n prevPoints.append(self.prevKeyPoints[matches[i].queryIdx].pt)\n currPoints.append(keypoints[matches[i].trainIdx].pt)\n\n prevPoints = np.array(prevPoints)\n currPoints = 
np.array(currPoints)\n\n # Estimate transformation matrix using RANSAC\n if prevPoints.shape[0] > 4:\n H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)\n\n # Scale translation components back to original resolution\n if self.downscale > 1.0:\n H[0, 2] *= self.downscale\n H[1, 2] *= self.downscale\n else:\n LOGGER.warning(\"not enough matching points\")\n\n # Store current frame data for next iteration\n self.prevFrame = frame.copy()\n self.prevKeyPoints = copy.copy(keypoints)\n self.prevDescriptors = copy.copy(descriptors)\n\n return H\n\n def apply_sparseoptflow(self, raw_frame: np.ndarray) -> np.ndarray:\n \"\"\"\n Apply Sparse Optical Flow method to a raw frame.\n\n Args:\n raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).\n\n Returns:\n (np.ndarray): Transformation matrix with shape (2, 3).\n\n Examples:\n >>> gmc = GMC()\n >>> result = gmc.apply_sparseoptflow(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]))\n >>> print(result)\n [[1. 0. 0.]\n [0. 1. 0.]]\n \"\"\"\n height, width, c = raw_frame.shape\n frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame\n H = np.eye(2, 3)\n\n # Downscale image for computational efficiency\n if self.downscale > 1.0:\n frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))\n\n # Find good features to track\n keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params)\n\n # Handle first frame initialization\n if not self.initializedFirstFrame or self.prevKeyPoints is None:\n self.prevFrame = frame.copy()\n self.prevKeyPoints = copy.copy(keypoints)\n self.initializedFirstFrame = True\n return H\n\n # Calculate optical flow using Lucas-Kanade method\n matchedKeypoints, status, _ = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)\n\n # Extract successfully tracked points\n prevPoints = []\n currPoints = []\n\n for i in range(len(status)):\n if status[i]:\n prevPoints.append(self.prevKeyPoints[i])\n currPoints.append(matchedKeypoints[i])\n\n prevPoints = np.array(prevPoints)\n currPoints = np.array(currPoints)\n\n # Estimate transformation matrix using RANSAC\n if (prevPoints.shape[0] > 4) and (prevPoints.shape[0] == currPoints.shape[0]):\n H, _ = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)\n\n # Scale translation components back to original resolution\n if self.downscale > 1.0:\n H[0, 2] *= self.downscale\n H[1, 2] *= self.downscale\n else:\n LOGGER.warning(\"not enough matching points\")\n\n # Store current frame data for next iteration\n self.prevFrame = frame.copy()\n self.prevKeyPoints = copy.copy(keypoints)\n\n return H\n\n def reset_params(self) -> None:\n \"\"\"Reset the internal parameters including previous frame, keypoints, and descriptors.\"\"\"\n self.prevFrame = None\n self.prevKeyPoints = None\n self.prevDescriptors = None\n self.initializedFirstFrame = False",
"chunk_type": "class",
"name": "GMC",
"file_path": "ultralytics\\ultralytics\\trackers\\utils\\gmc.py",
"start_line": 12,
"end_line": 349,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": "Generalized Motion Compensation (GMC) class for tracking and object detection in video frames.\n\nThis class provides methods for tracking and detecting objects based on several tracking algorithms including ORB,\nSIFT, ECC, and Sparse Optical Flow. It also supports downscaling of frames for computational efficiency.\n\nAttributes:\n method (str): The tracking method to use. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.\n downscale (int): Factor by which to downscale the frames for processing.\n prevFrame (np.ndarray): Previous frame for tracking.\n prevKeyPoints (List): Keypoints from the previous frame.\n prevDescriptors (np.ndarray): Descriptors from the previous frame.\n initializedFirstFrame (bool): Flag indicating if the first frame has been processed.\n\nMethods:\n apply: Apply the chosen method to a raw frame and optionally use provided detections.\n apply_ecc: Apply the ECC algorithm to a raw frame.\n apply_features: Apply feature-based methods like ORB or SIFT to a raw frame.\n apply_sparseoptflow: Apply the Sparse Optical Flow method to a raw frame.\n reset_params: Reset the internal parameters of the GMC object.\n\nExamples:\n Create a GMC object and apply it to a frame\n >>> gmc = GMC(method=\"sparseOptFlow\", downscale=2)\n >>> frame = np.array([[1, 2, 3], [4, 5, 6]])\n >>> processed_frame = gmc.apply(frame)\n >>> print(processed_frame)\n array([[1, 2, 3],\n [4, 5, 6]])",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"copy",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER"
],
"chunk_id": "class_GMC_ca52d141"
},
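A minimal usage sketch for the GMC chunk above, assuming the class is importable from ultralytics.trackers.utils.gmc as recorded in file_path. The two synthetic frames and the simulated camera shift are invented for illustration; the first apply call only initializes internal state and returns an identity warp.

```python
import numpy as np
# Assumed import path, taken from the chunk's file_path:
# from ultralytics.trackers.utils.gmc import GMC

gmc = GMC(method="sparseOptFlow", downscale=2)

# Two synthetic BGR frames; real usage would pass consecutive video frames.
rng = np.random.default_rng(0)
frame1 = rng.integers(0, 255, (480, 640, 3), dtype=np.uint8)
frame2 = np.roll(frame1, shift=(2, 3), axis=(0, 1))  # crude stand-in for a small camera shift

H1 = gmc.apply(frame1)  # first call initializes state -> identity 2x3 warp
H2 = gmc.apply(frame2)  # second call estimates the inter-frame affine transform
print(H1.shape, H2.shape)  # (2, 3) (2, 3)
```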
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\trackers\\utils\\kalman_filter.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_ea19bbd7"
},
{
"content": "import scipy.linalg",
"chunk_type": "import",
"name": "scipy.linalg",
"file_path": "ultralytics\\ultralytics\\trackers\\utils\\kalman_filter.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 19,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_scipy.linalg_f50fd77e"
},
{
"content": "class KalmanFilterXYAH:\n \"\"\"\n A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter.\n\n Implements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space\n (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their\n respective velocities. Object motion follows a constant velocity model, and bounding box location (x, y, a, h) is\n taken as a direct observation of the state space (linear observation model).\n\n Attributes:\n _motion_mat (np.ndarray): The motion matrix for the Kalman filter.\n _update_mat (np.ndarray): The update matrix for the Kalman filter.\n _std_weight_position (float): Standard deviation weight for position.\n _std_weight_velocity (float): Standard deviation weight for velocity.\n\n Methods:\n initiate: Create a track from an unassociated measurement.\n predict: Run the Kalman filter prediction step.\n project: Project the state distribution to measurement space.\n multi_predict: Run the Kalman filter prediction step (vectorized version).\n update: Run the Kalman filter correction step.\n gating_distance: Compute the gating distance between state distribution and measurements.\n\n Examples:\n Initialize the Kalman filter and create a track from a measurement\n >>> kf = KalmanFilterXYAH()\n >>> measurement = np.array([100, 200, 1.5, 50])\n >>> mean, covariance = kf.initiate(measurement)\n >>> print(mean)\n >>> print(covariance)\n \"\"\"\n\n def __init__(self):\n \"\"\"\n Initialize Kalman filter model matrices with motion and observation uncertainty weights.\n\n The Kalman filter is initialized with an 8-dimensional state space (x, y, a, h, vx, vy, va, vh), where (x, y)\n represents the bounding box center position, 'a' is the aspect ratio, 'h' is the height, and their respective\n velocities are (vx, vy, va, vh). The filter uses a constant velocity model for object motion and a linear\n observation model for bounding box location.\n\n Examples:\n Initialize a Kalman filter for tracking:\n >>> kf = KalmanFilterXYAH()\n \"\"\"\n ndim, dt = 4, 1.0\n\n # Create Kalman filter model matrices\n self._motion_mat = np.eye(2 * ndim, 2 * ndim)\n for i in range(ndim):\n self._motion_mat[i, ndim + i] = dt\n self._update_mat = np.eye(ndim, 2 * ndim)\n\n # Motion and observation uncertainty are chosen relative to the current state estimate\n self._std_weight_position = 1.0 / 20\n self._std_weight_velocity = 1.0 / 160\n\n def initiate(self, measurement: np.ndarray):\n \"\"\"\n Create a track from an unassociated measurement.\n\n Args:\n measurement (np.ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a,\n and height h.\n\n Returns:\n mean (np.ndarray): Mean vector (8-dimensional) of the new track. 
Unobserved velocities are initialized to 0 mean.\n covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.\n\n Examples:\n >>> kf = KalmanFilterXYAH()\n >>> measurement = np.array([100, 50, 1.5, 200])\n >>> mean, covariance = kf.initiate(measurement)\n \"\"\"\n mean_pos = measurement\n mean_vel = np.zeros_like(mean_pos)\n mean = np.r_[mean_pos, mean_vel]\n\n std = [\n 2 * self._std_weight_position * measurement[3],\n 2 * self._std_weight_position * measurement[3],\n 1e-2,\n 2 * self._std_weight_position * measurement[3],\n 10 * self._std_weight_velocity * measurement[3],\n 10 * self._std_weight_velocity * measurement[3],\n 1e-5,\n 10 * self._std_weight_velocity * measurement[3],\n ]\n covariance = np.diag(np.square(std))\n return mean, covariance\n\n def predict(self, mean: np.ndarray, covariance: np.ndarray):\n \"\"\"\n Run Kalman filter prediction step.\n\n Args:\n mean (np.ndarray): The 8-dimensional mean vector of the object state at the previous time step.\n covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.\n\n Returns:\n mean (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.\n covariance (np.ndarray): Covariance matrix of the predicted state.\n\n Examples:\n >>> kf = KalmanFilterXYAH()\n >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])\n >>> covariance = np.eye(8)\n >>> predicted_mean, predicted_covariance = kf.predict(mean, covariance)\n \"\"\"\n std_pos = [\n self._std_weight_position * mean[3],\n self._std_weight_position * mean[3],\n 1e-2,\n self._std_weight_position * mean[3],\n ]\n std_vel = [\n self._std_weight_velocity * mean[3],\n self._std_weight_velocity * mean[3],\n 1e-5,\n self._std_weight_velocity * mean[3],\n ]\n motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))\n\n mean = np.dot(mean, self._motion_mat.T)\n covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov\n\n return mean, covariance\n\n def project(self, mean: np.ndarray, covariance: np.ndarray):\n \"\"\"\n Project state distribution to measurement space.\n\n Args:\n mean (np.ndarray): The state's mean vector (8 dimensional array).\n covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).\n\n Returns:\n mean (np.ndarray): Projected mean of the given state estimate.\n covariance (np.ndarray): Projected covariance matrix of the given state estimate.\n\n Examples:\n >>> kf = KalmanFilterXYAH()\n >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])\n >>> covariance = np.eye(8)\n >>> projected_mean, projected_covariance = kf.project(mean, covariance)\n \"\"\"\n std = [\n self._std_weight_position * mean[3],\n self._std_weight_position * mean[3],\n 1e-1,\n self._std_weight_position * mean[3],\n ]\n innovation_cov = np.diag(np.square(std))\n\n mean = np.dot(self._update_mat, mean)\n covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))\n return mean, covariance + innovation_cov\n\n def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):\n \"\"\"\n Run Kalman filter prediction step for multiple object states (Vectorized version).\n\n Args:\n mean (np.ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.\n covariance (np.ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.\n\n Returns:\n mean (np.ndarray): Mean matrix of the predicted states with shape (N, 8).\n covariance (np.ndarray): Covariance matrix of the predicted 
states with shape (N, 8, 8).\n\n Examples:\n >>> mean = np.random.rand(10, 8) # 10 object states\n >>> covariance = np.random.rand(10, 8, 8) # Covariance matrices for 10 object states\n >>> predicted_mean, predicted_covariance = kalman_filter.multi_predict(mean, covariance)\n \"\"\"\n std_pos = [\n self._std_weight_position * mean[:, 3],\n self._std_weight_position * mean[:, 3],\n 1e-2 * np.ones_like(mean[:, 3]),\n self._std_weight_position * mean[:, 3],\n ]\n std_vel = [\n self._std_weight_velocity * mean[:, 3],\n self._std_weight_velocity * mean[:, 3],\n 1e-5 * np.ones_like(mean[:, 3]),\n self._std_weight_velocity * mean[:, 3],\n ]\n sqr = np.square(np.r_[std_pos, std_vel]).T\n\n motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]\n motion_cov = np.asarray(motion_cov)\n\n mean = np.dot(mean, self._motion_mat.T)\n left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))\n covariance = np.dot(left, self._motion_mat.T) + motion_cov\n\n return mean, covariance\n\n def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray):\n \"\"\"\n Run Kalman filter correction step.\n\n Args:\n mean (np.ndarray): The predicted state's mean vector (8 dimensional).\n covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).\n measurement (np.ndarray): The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center\n position, a the aspect ratio, and h the height of the bounding box.\n\n Returns:\n new_mean (np.ndarray): Measurement-corrected state mean.\n new_covariance (np.ndarray): Measurement-corrected state covariance.\n\n Examples:\n >>> kf = KalmanFilterXYAH()\n >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])\n >>> covariance = np.eye(8)\n >>> measurement = np.array([1, 1, 1, 1])\n >>> new_mean, new_covariance = kf.update(mean, covariance, measurement)\n \"\"\"\n projected_mean, projected_cov = self.project(mean, covariance)\n\n chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)\n kalman_gain = scipy.linalg.cho_solve(\n (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, check_finite=False\n ).T\n innovation = measurement - projected_mean\n\n new_mean = mean + np.dot(innovation, kalman_gain.T)\n new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))\n return new_mean, new_covariance\n\n def gating_distance(\n self,\n mean: np.ndarray,\n covariance: np.ndarray,\n measurements: np.ndarray,\n only_position: bool = False,\n metric: str = \"maha\",\n ) -> np.ndarray:\n \"\"\"\n Compute gating distance between state distribution and measurements.\n\n A suitable distance threshold can be obtained from `chi2inv95`. If `only_position` is False, the chi-square\n distribution has 4 degrees of freedom, otherwise 2.\n\n Args:\n mean (np.ndarray): Mean vector over the state distribution (8 dimensional).\n covariance (np.ndarray): Covariance of the state distribution (8x8 dimensional).\n measurements (np.ndarray): An (N, 4) matrix of N measurements, each in format (x, y, a, h) where (x, y) is the\n bounding box center position, a the aspect ratio, and h the height.\n only_position (bool, optional): If True, distance computation is done with respect to box center position only.\n metric (str, optional): The metric to use for calculating the distance. 
Options are 'gaussian' for the squared\n Euclidean distance and 'maha' for the squared Mahalanobis distance.\n\n Returns:\n (np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between\n (mean, covariance) and `measurements[i]`.\n\n Examples:\n Compute gating distance using Mahalanobis metric:\n >>> kf = KalmanFilterXYAH()\n >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])\n >>> covariance = np.eye(8)\n >>> measurements = np.array([[1, 1, 1, 1], [2, 2, 1, 1]])\n >>> distances = kf.gating_distance(mean, covariance, measurements, only_position=False, metric=\"maha\")\n \"\"\"\n mean, covariance = self.project(mean, covariance)\n if only_position:\n mean, covariance = mean[:2], covariance[:2, :2]\n measurements = measurements[:, :2]\n\n d = measurements - mean\n if metric == \"gaussian\":\n return np.sum(d * d, axis=1)\n elif metric == \"maha\":\n cholesky_factor = np.linalg.cholesky(covariance)\n z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)\n return np.sum(z * z, axis=0) # square maha\n else:\n raise ValueError(\"Invalid distance metric\")",
"chunk_type": "class",
"name": "KalmanFilterXYAH",
"file_path": "ultralytics\\ultralytics\\trackers\\utils\\kalman_filter.py",
"start_line": 7,
"end_line": 286,
"start_col": 0,
"end_col": 55,
"parent_name": null,
"docstring": "A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter.\n\nImplements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space\n(x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their\nrespective velocities. Object motion follows a constant velocity model, and bounding box location (x, y, a, h) is\ntaken as a direct observation of the state space (linear observation model).\n\nAttributes:\n _motion_mat (np.ndarray): The motion matrix for the Kalman filter.\n _update_mat (np.ndarray): The update matrix for the Kalman filter.\n _std_weight_position (float): Standard deviation weight for position.\n _std_weight_velocity (float): Standard deviation weight for velocity.\n\nMethods:\n initiate: Create a track from an unassociated measurement.\n predict: Run the Kalman filter prediction step.\n project: Project the state distribution to measurement space.\n multi_predict: Run the Kalman filter prediction step (vectorized version).\n update: Run the Kalman filter correction step.\n gating_distance: Compute the gating distance between state distribution and measurements.\n\nExamples:\n Initialize the Kalman filter and create a track from a measurement\n >>> kf = KalmanFilterXYAH()\n >>> measurement = np.array([100, 200, 1.5, 50])\n >>> mean, covariance = kf.initiate(measurement)\n >>> print(mean)\n >>> print(covariance)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"numpy",
"scipy.linalg"
],
"chunk_id": "class_KalmanFilterXYAH_03d35260"
},
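A short initiate/predict/update cycle for the KalmanFilterXYAH chunk above, assuming the class is importable from ultralytics.trackers.utils.kalman_filter. The measurements are invented (x, y, aspect, height) values; the last line shows the Mahalanobis gating distances typically used for track-detection association.

```python
import numpy as np
# Assumed import path, taken from the chunk's file_path:
# from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYAH

kf = KalmanFilterXYAH()

# Start a track from a single (x, y, aspect, height) measurement.
mean, cov = kf.initiate(np.array([320.0, 240.0, 0.5, 100.0]))

# One constant-velocity prediction step, then a correction against a slightly shifted measurement.
mean, cov = kf.predict(mean, cov)
mean, cov = kf.update(mean, cov, np.array([324.0, 238.0, 0.5, 102.0]))
print(mean.shape, cov.shape)  # (8,) (8, 8)

# Squared Mahalanobis distances between the track and two candidate detections.
candidates = np.array([[325.0, 239.0, 0.5, 101.0], [600.0, 50.0, 0.4, 80.0]])
print(kf.gating_distance(mean, cov, candidates, metric="maha"))
```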
{
"content": "class KalmanFilterXYWH(KalmanFilterXYAH):\n \"\"\"\n A KalmanFilterXYWH class for tracking bounding boxes in image space using a Kalman filter.\n\n Implements a Kalman filter for tracking bounding boxes with state space (x, y, w, h, vx, vy, vw, vh), where\n (x, y) is the center position, w is the width, h is the height, and vx, vy, vw, vh are their respective velocities.\n The object motion follows a constant velocity model, and the bounding box location (x, y, w, h) is taken as a direct\n observation of the state space (linear observation model).\n\n Attributes:\n _motion_mat (np.ndarray): The motion matrix for the Kalman filter.\n _update_mat (np.ndarray): The update matrix for the Kalman filter.\n _std_weight_position (float): Standard deviation weight for position.\n _std_weight_velocity (float): Standard deviation weight for velocity.\n\n Methods:\n initiate: Create a track from an unassociated measurement.\n predict: Run the Kalman filter prediction step.\n project: Project the state distribution to measurement space.\n multi_predict: Run the Kalman filter prediction step in a vectorized manner.\n update: Run the Kalman filter correction step.\n\n Examples:\n Create a Kalman filter and initialize a track\n >>> kf = KalmanFilterXYWH()\n >>> measurement = np.array([100, 50, 20, 40])\n >>> mean, covariance = kf.initiate(measurement)\n >>> print(mean)\n >>> print(covariance)\n \"\"\"\n\n def initiate(self, measurement: np.ndarray):\n \"\"\"\n Create track from unassociated measurement.\n\n Args:\n measurement (np.ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and height.\n\n Returns:\n mean (np.ndarray): Mean vector (8 dimensional) of the new track. Unobserved velocities are initialized to 0 mean.\n covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.\n\n Examples:\n >>> kf = KalmanFilterXYWH()\n >>> measurement = np.array([100, 50, 20, 40])\n >>> mean, covariance = kf.initiate(measurement)\n >>> print(mean)\n [100. 50. 20. 40. 0. 0. 0. 0.]\n >>> print(covariance)\n [[ 4. 0. 0. 0. 0. 0. 0. 0.]\n [ 0. 4. 0. 0. 0. 0. 0. 0.]\n [ 0. 0. 4. 0. 0. 0. 0. 0.]\n [ 0. 0. 0. 4. 0. 0. 0. 0.]\n [ 0. 0. 0. 0. 0.25 0. 0. 0.]\n [ 0. 0. 0. 0. 0. 0.25 0. 0.]\n [ 0. 0. 0. 0. 0. 0. 0.25 0.]\n [ 0. 0. 0. 0. 0. 0. 0. 0.25]]\n \"\"\"\n mean_pos = measurement\n mean_vel = np.zeros_like(mean_pos)\n mean = np.r_[mean_pos, mean_vel]\n\n std = [\n 2 * self._std_weight_position * measurement[2],\n 2 * self._std_weight_position * measurement[3],\n 2 * self._std_weight_position * measurement[2],\n 2 * self._std_weight_position * measurement[3],\n 10 * self._std_weight_velocity * measurement[2],\n 10 * self._std_weight_velocity * measurement[3],\n 10 * self._std_weight_velocity * measurement[2],\n 10 * self._std_weight_velocity * measurement[3],\n ]\n covariance = np.diag(np.square(std))\n return mean, covariance\n\n def predict(self, mean: np.ndarray, covariance: np.ndarray):\n \"\"\"\n Run Kalman filter prediction step.\n\n Args:\n mean (np.ndarray): The 8-dimensional mean vector of the object state at the previous time step.\n covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.\n\n Returns:\n mean (np.ndarray): Mean vector of the predicted state. 
Unobserved velocities are initialized to 0 mean.\n covariance (np.ndarray): Covariance matrix of the predicted state.\n\n Examples:\n >>> kf = KalmanFilterXYWH()\n >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])\n >>> covariance = np.eye(8)\n >>> predicted_mean, predicted_covariance = kf.predict(mean, covariance)\n \"\"\"\n std_pos = [\n self._std_weight_position * mean[2],\n self._std_weight_position * mean[3],\n self._std_weight_position * mean[2],\n self._std_weight_position * mean[3],\n ]\n std_vel = [\n self._std_weight_velocity * mean[2],\n self._std_weight_velocity * mean[3],\n self._std_weight_velocity * mean[2],\n self._std_weight_velocity * mean[3],\n ]\n motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))\n\n mean = np.dot(mean, self._motion_mat.T)\n covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov\n\n return mean, covariance\n\n def project(self, mean: np.ndarray, covariance: np.ndarray):\n \"\"\"\n Project state distribution to measurement space.\n\n Args:\n mean (np.ndarray): The state's mean vector (8 dimensional array).\n covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).\n\n Returns:\n mean (np.ndarray): Projected mean of the given state estimate.\n covariance (np.ndarray): Projected covariance matrix of the given state estimate.\n\n Examples:\n >>> kf = KalmanFilterXYWH()\n >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])\n >>> covariance = np.eye(8)\n >>> projected_mean, projected_cov = kf.project(mean, covariance)\n \"\"\"\n std = [\n self._std_weight_position * mean[2],\n self._std_weight_position * mean[3],\n self._std_weight_position * mean[2],\n self._std_weight_position * mean[3],\n ]\n innovation_cov = np.diag(np.square(std))\n\n mean = np.dot(self._update_mat, mean)\n covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))\n return mean, covariance + innovation_cov\n\n def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):\n \"\"\"\n Run Kalman filter prediction step (Vectorized version).\n\n Args:\n mean (np.ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.\n covariance (np.ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.\n\n Returns:\n mean (np.ndarray): Mean matrix of the predicted states with shape (N, 8).\n covariance (np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8).\n\n Examples:\n >>> mean = np.random.rand(5, 8) # 5 objects with 8-dimensional state vectors\n >>> covariance = np.random.rand(5, 8, 8) # 5 objects with 8x8 covariance matrices\n >>> kf = KalmanFilterXYWH()\n >>> predicted_mean, predicted_covariance = kf.multi_predict(mean, covariance)\n \"\"\"\n std_pos = [\n self._std_weight_position * mean[:, 2],\n self._std_weight_position * mean[:, 3],\n self._std_weight_position * mean[:, 2],\n self._std_weight_position * mean[:, 3],\n ]\n std_vel = [\n self._std_weight_velocity * mean[:, 2],\n self._std_weight_velocity * mean[:, 3],\n self._std_weight_velocity * mean[:, 2],\n self._std_weight_velocity * mean[:, 3],\n ]\n sqr = np.square(np.r_[std_pos, std_vel]).T\n\n motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]\n motion_cov = np.asarray(motion_cov)\n\n mean = np.dot(mean, self._motion_mat.T)\n left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))\n covariance = np.dot(left, self._motion_mat.T) + motion_cov\n\n return mean, covariance\n\n def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: 
np.ndarray):\n \"\"\"\n Run Kalman filter correction step.\n\n Args:\n mean (np.ndarray): The predicted state's mean vector (8 dimensional).\n covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).\n measurement (np.ndarray): The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center\n position, w the width, and h the height of the bounding box.\n\n Returns:\n new_mean (np.ndarray): Measurement-corrected state mean.\n new_covariance (np.ndarray): Measurement-corrected state covariance.\n\n Examples:\n >>> kf = KalmanFilterXYWH()\n >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])\n >>> covariance = np.eye(8)\n >>> measurement = np.array([0.5, 0.5, 1.2, 1.2])\n >>> new_mean, new_covariance = kf.update(mean, covariance, measurement)\n \"\"\"\n return super().update(mean, covariance, measurement)",
"chunk_type": "class",
"name": "KalmanFilterXYWH",
"file_path": "ultralytics\\ultralytics\\trackers\\utils\\kalman_filter.py",
"start_line": 289,
"end_line": 493,
"start_col": 0,
"end_col": 60,
"parent_name": null,
"docstring": "A KalmanFilterXYWH class for tracking bounding boxes in image space using a Kalman filter.\n\nImplements a Kalman filter for tracking bounding boxes with state space (x, y, w, h, vx, vy, vw, vh), where\n(x, y) is the center position, w is the width, h is the height, and vx, vy, vw, vh are their respective velocities.\nThe object motion follows a constant velocity model, and the bounding box location (x, y, w, h) is taken as a direct\nobservation of the state space (linear observation model).\n\nAttributes:\n _motion_mat (np.ndarray): The motion matrix for the Kalman filter.\n _update_mat (np.ndarray): The update matrix for the Kalman filter.\n _std_weight_position (float): Standard deviation weight for position.\n _std_weight_velocity (float): Standard deviation weight for velocity.\n\nMethods:\n initiate: Create a track from an unassociated measurement.\n predict: Run the Kalman filter prediction step.\n project: Project the state distribution to measurement space.\n multi_predict: Run the Kalman filter prediction step in a vectorized manner.\n update: Run the Kalman filter correction step.\n\nExamples:\n Create a Kalman filter and initialize a track\n >>> kf = KalmanFilterXYWH()\n >>> measurement = np.array([100, 50, 20, 40])\n >>> mean, covariance = kf.initiate(measurement)\n >>> print(mean)\n >>> print(covariance)",
"parameters": null,
"return_type": null,
"decorators": [],
"complexity_score": null,
"dependencies": [
"numpy",
"scipy.linalg",
"KalmanFilterXYAH"
],
"chunk_id": "class_KalmanFilterXYWH_509215f6"
},
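A minimal sketch of the predict/update cycle the KalmanFilterXYWH chunk above describes, assuming the class is importable from the file_path recorded in its metadata; all box values are illustrative.

import numpy as np
from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYWH

kf = KalmanFilterXYWH()
mean, cov = kf.initiate(np.array([100.0, 50.0, 20.0, 40.0]))  # (x, y, w, h), center format

# Constant-velocity prediction, then correction with a hypothetical new measurement
mean, cov = kf.predict(mean, cov)
mean, cov = kf.update(mean, cov, np.array([102.0, 51.0, 20.5, 40.5]))
print(mean[:4])  # corrected (x, y, w, h) estimate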
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\trackers\\utils\\matching.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_46358698"
},
{
"content": "import scipy",
"chunk_type": "import",
"name": "scipy",
"file_path": "ultralytics\\ultralytics\\trackers\\utils\\matching.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 12,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_scipy_72d84f80"
},
{
"content": "from scipy.spatial.distance import cdist",
"chunk_type": "import",
"name": "cdist",
"file_path": "ultralytics\\ultralytics\\trackers\\utils\\matching.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_cdist_40636338"
},
{
"content": "from ultralytics.utils.metrics import batch_probiou, bbox_ioa",
"chunk_type": "import",
"name": "batch_probiou, bbox_ioa",
"file_path": "ultralytics\\ultralytics\\trackers\\utils\\matching.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 61,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_batch_probiou, bbox_ioa_cc293f11"
},
{
"content": "def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True):\n \"\"\"\n Perform linear assignment using either the scipy or lap.lapjv method.\n\n Args:\n cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).\n thresh (float): Threshold for considering an assignment valid.\n use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used.\n\n Returns:\n matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches.\n unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,).\n unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,).\n\n Examples:\n >>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n >>> thresh = 5.0\n >>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True)\n \"\"\"\n if cost_matrix.size == 0:\n return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))\n\n if use_lap:\n # Use lap.lapjv\n # https://github.com/gatagat/lap\n _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)\n matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]\n unmatched_a = np.where(x < 0)[0]\n unmatched_b = np.where(y < 0)[0]\n else:\n # Use scipy.optimize.linear_sum_assignment\n # https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html\n x, y = scipy.optimize.linear_sum_assignment(cost_matrix) # row x, col y\n matches = np.asarray([[x[i], y[i]] for i in range(len(x)) if cost_matrix[x[i], y[i]] <= thresh])\n if len(matches) == 0:\n unmatched_a = list(np.arange(cost_matrix.shape[0]))\n unmatched_b = list(np.arange(cost_matrix.shape[1]))\n else:\n unmatched_a = list(frozenset(np.arange(cost_matrix.shape[0])) - frozenset(matches[:, 0]))\n unmatched_b = list(frozenset(np.arange(cost_matrix.shape[1])) - frozenset(matches[:, 1]))\n\n return matches, unmatched_a, unmatched_b",
"chunk_type": "function",
"name": "linear_assignment",
"file_path": "ultralytics\\ultralytics\\trackers\\utils\\matching.py",
"start_line": 20,
"end_line": 61,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": "Perform linear assignment using either the scipy or lap.lapjv method.\n\nArgs:\n cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).\n thresh (float): Threshold for considering an assignment valid.\n use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used.\n\nReturns:\n matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches.\n unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,).\n unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,).\n\nExamples:\n >>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n >>> thresh = 5.0\n >>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True)",
"parameters": [
"cost_matrix: np.ndarray",
"thresh: float",
"use_lap: bool"
],
"return_type": null,
"decorators": [],
"complexity_score": 6,
"dependencies": [
"numpy",
"scipy",
"scipy.spatial.distance.cdist",
"ultralytics.utils.metrics.batch_probiou",
"ultralytics.utils.metrics.bbox_ioa",
"lap",
"ultralytics.utils.checks.check_requirements",
"lap"
],
"chunk_id": "function_linear_assignment_c451a178"
},
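A short sketch of the thresholded assignment that linear_assignment performs on its scipy branch (use_lap=False); the cost values below are made up for illustration.

import numpy as np
import scipy.optimize

cost = np.array([[1.0, 9.0, 9.0], [9.0, 2.0, 9.0]])
rows, cols = scipy.optimize.linear_sum_assignment(cost)              # Hungarian algorithm
matches = [(r, c) for r, c in zip(rows, cols) if cost[r, c] <= 5.0]  # keep matches under the threshold
print(matches)  # [(0, 0), (1, 1)]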
{
"content": "def iou_distance(atracks: list, btracks: list) -> np.ndarray:\n \"\"\"\n Compute cost based on Intersection over Union (IoU) between tracks.\n\n Args:\n atracks (List[STrack] | List[np.ndarray]): List of tracks 'a' or bounding boxes.\n btracks (List[STrack] | List[np.ndarray]): List of tracks 'b' or bounding boxes.\n\n Returns:\n (np.ndarray): Cost matrix computed based on IoU with shape (len(atracks), len(btracks)).\n\n Examples:\n Compute IoU distance between two sets of tracks\n >>> atracks = [np.array([0, 0, 10, 10]), np.array([20, 20, 30, 30])]\n >>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])]\n >>> cost_matrix = iou_distance(atracks, btracks)\n \"\"\"\n if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):\n atlbrs = atracks\n btlbrs = btracks\n else:\n atlbrs = [track.xywha if track.angle is not None else track.xyxy for track in atracks]\n btlbrs = [track.xywha if track.angle is not None else track.xyxy for track in btracks]\n\n ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)\n if len(atlbrs) and len(btlbrs):\n if len(atlbrs[0]) == 5 and len(btlbrs[0]) == 5:\n ious = batch_probiou(\n np.ascontiguousarray(atlbrs, dtype=np.float32),\n np.ascontiguousarray(btlbrs, dtype=np.float32),\n ).numpy()\n else:\n ious = bbox_ioa(\n np.ascontiguousarray(atlbrs, dtype=np.float32),\n np.ascontiguousarray(btlbrs, dtype=np.float32),\n iou=True,\n )\n return 1 - ious # cost matrix",
"chunk_type": "function",
"name": "iou_distance",
"file_path": "ultralytics\\ultralytics\\trackers\\utils\\matching.py",
"start_line": 64,
"end_line": 101,
"start_col": 0,
"end_col": 19,
"parent_name": null,
"docstring": "Compute cost based on Intersection over Union (IoU) between tracks.\n\nArgs:\n atracks (List[STrack] | List[np.ndarray]): List of tracks 'a' or bounding boxes.\n btracks (List[STrack] | List[np.ndarray]): List of tracks 'b' or bounding boxes.\n\nReturns:\n (np.ndarray): Cost matrix computed based on IoU with shape (len(atracks), len(btracks)).\n\nExamples:\n Compute IoU distance between two sets of tracks\n >>> atracks = [np.array([0, 0, 10, 10]), np.array([20, 20, 30, 30])]\n >>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])]\n >>> cost_matrix = iou_distance(atracks, btracks)",
"parameters": [
"atracks: list",
"btracks: list"
],
"return_type": "np.ndarray",
"decorators": [],
"complexity_score": 6,
"dependencies": [
"numpy",
"scipy",
"scipy.spatial.distance.cdist",
"ultralytics.utils.metrics.batch_probiou",
"ultralytics.utils.metrics.bbox_ioa",
"lap",
"ultralytics.utils.checks.check_requirements",
"lap"
],
"chunk_id": "function_iou_distance_046b5443"
},
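The iou_distance chunk above returns 1 - IoU as the matching cost; a small sketch of that convention using bbox_ioa (imported earlier in the same module), with hypothetical xyxy boxes.

import numpy as np
from ultralytics.utils.metrics import bbox_ioa

a = np.ascontiguousarray([[0, 0, 10, 10]], dtype=np.float32)  # track boxes, xyxy
b = np.ascontiguousarray([[5, 5, 15, 15]], dtype=np.float32)  # detection boxes, xyxy
cost = 1 - bbox_ioa(a, b, iou=True)  # same "1 - IoU" cost convention as iou_distance
print(cost.shape)  # (1, 1)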
{
"content": "def embedding_distance(tracks: list, detections: list, metric: str = \"cosine\") -> np.ndarray:\n \"\"\"\n Compute distance between tracks and detections based on embeddings.\n\n Args:\n tracks (List[STrack]): List of tracks, where each track contains embedding features.\n detections (List[BaseTrack]): List of detections, where each detection contains embedding features.\n metric (str): Metric for distance computation. Supported metrics include 'cosine', 'euclidean', etc.\n\n Returns:\n (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks\n and M is the number of detections.\n\n Examples:\n Compute the embedding distance between tracks and detections using cosine metric\n >>> tracks = [STrack(...), STrack(...)] # List of track objects with embedding features\n >>> detections = [BaseTrack(...), BaseTrack(...)] # List of detection objects with embedding features\n >>> cost_matrix = embedding_distance(tracks, detections, metric=\"cosine\")\n \"\"\"\n cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)\n if cost_matrix.size == 0:\n return cost_matrix\n det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32)\n # for i, track in enumerate(tracks):\n # cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric))\n track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32)\n cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Normalized features\n return cost_matrix",
"chunk_type": "function",
"name": "embedding_distance",
"file_path": "ultralytics\\ultralytics\\trackers\\utils\\matching.py",
"start_line": 104,
"end_line": 131,
"start_col": 0,
"end_col": 22,
"parent_name": null,
"docstring": "Compute distance between tracks and detections based on embeddings.\n\nArgs:\n tracks (List[STrack]): List of tracks, where each track contains embedding features.\n detections (List[BaseTrack]): List of detections, where each detection contains embedding features.\n metric (str): Metric for distance computation. Supported metrics include 'cosine', 'euclidean', etc.\n\nReturns:\n (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks\n and M is the number of detections.\n\nExamples:\n Compute the embedding distance between tracks and detections using cosine metric\n >>> tracks = [STrack(...), STrack(...)] # List of track objects with embedding features\n >>> detections = [BaseTrack(...), BaseTrack(...)] # List of detection objects with embedding features\n >>> cost_matrix = embedding_distance(tracks, detections, metric=\"cosine\")",
"parameters": [
"tracks: list",
"detections: list",
"metric: str"
],
"return_type": "np.ndarray",
"decorators": [],
"complexity_score": 4,
"dependencies": [
"numpy",
"scipy",
"scipy.spatial.distance.cdist",
"ultralytics.utils.metrics.batch_probiou",
"ultralytics.utils.metrics.bbox_ioa",
"lap",
"ultralytics.utils.checks.check_requirements",
"lap"
],
"chunk_id": "function_embedding_distance_5fb9ec69"
},
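embedding_distance reduces to a clipped cdist over stacked feature vectors; a sketch with random placeholder embeddings (the 128-dim size is an assumption, not a requirement).

import numpy as np
from scipy.spatial.distance import cdist

track_feats = np.random.rand(3, 128).astype(np.float32)  # stand-ins for track.smooth_feat
det_feats = np.random.rand(5, 128).astype(np.float32)    # stand-ins for detection.curr_feat
cost_matrix = np.maximum(0.0, cdist(track_feats, det_feats, "cosine"))
print(cost_matrix.shape)  # (3, 5)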
{
"content": "def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:\n \"\"\"\n Fuse cost matrix with detection scores to produce a single similarity matrix.\n\n Args:\n cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).\n detections (List[BaseTrack]): List of detections, each containing a score attribute.\n\n Returns:\n (np.ndarray): Fused similarity matrix with shape (N, M).\n\n Examples:\n Fuse a cost matrix with detection scores\n >>> cost_matrix = np.random.rand(5, 10) # 5 tracks and 10 detections\n >>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)]\n >>> fused_matrix = fuse_score(cost_matrix, detections)\n \"\"\"\n if cost_matrix.size == 0:\n return cost_matrix\n iou_sim = 1 - cost_matrix\n det_scores = np.array([det.score for det in detections])\n det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)\n fuse_sim = iou_sim * det_scores\n return 1 - fuse_sim # fuse_cost",
"chunk_type": "function",
"name": "fuse_score",
"file_path": "ultralytics\\ultralytics\\trackers\\utils\\matching.py",
"start_line": 134,
"end_line": 157,
"start_col": 0,
"end_col": 23,
"parent_name": null,
"docstring": "Fuse cost matrix with detection scores to produce a single similarity matrix.\n\nArgs:\n cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).\n detections (List[BaseTrack]): List of detections, each containing a score attribute.\n\nReturns:\n (np.ndarray): Fused similarity matrix with shape (N, M).\n\nExamples:\n Fuse a cost matrix with detection scores\n >>> cost_matrix = np.random.rand(5, 10) # 5 tracks and 10 detections\n >>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)]\n >>> fused_matrix = fuse_score(cost_matrix, detections)",
"parameters": [
"cost_matrix: np.ndarray",
"detections: list"
],
"return_type": "np.ndarray",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"numpy",
"scipy",
"scipy.spatial.distance.cdist",
"ultralytics.utils.metrics.batch_probiou",
"ultralytics.utils.metrics.bbox_ioa",
"lap",
"ultralytics.utils.checks.check_requirements",
"lap"
],
"chunk_id": "function_fuse_score_2d29eb18"
},
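fuse_score effectively computes 1 - (1 - cost) * score with detection confidences broadcast across rows; shown here on plain arrays instead of BaseTrack objects.

import numpy as np

cost_matrix = np.array([[0.2, 0.8], [0.6, 0.1]])  # IoU-based costs (illustrative)
det_scores = np.array([0.9, 0.5])                 # detection confidences (illustrative)
fused = 1 - (1 - cost_matrix) * det_scores[None, :]
print(fused)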
{
"content": "from collections import defaultdict",
"chunk_type": "import",
"name": "defaultdict",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 35,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_defaultdict_32412c51"
},
{
"content": "from copy import deepcopy",
"chunk_type": "import",
"name": "deepcopy",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 25,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_deepcopy_fb9df1a7"
},
{
"content": "def on_pretrain_routine_start(trainer):\n \"\"\"Called before the pretraining routine starts.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_pretrain_routine_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 10,
"end_line": 12,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called before the pretraining routine starts.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_pretrain_routine_start_f6e22321"
},
{
"content": "def on_pretrain_routine_end(trainer):\n \"\"\"Called after the pretraining routine ends.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_pretrain_routine_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 15,
"end_line": 17,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called after the pretraining routine ends.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_pretrain_routine_end_82d57cc4"
},
{
"content": "def on_train_start(trainer):\n \"\"\"Called when the training starts.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_train_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 20,
"end_line": 22,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called when the training starts.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_train_start_678ff1f2"
},
{
"content": "def on_train_epoch_start(trainer):\n \"\"\"Called at the start of each training epoch.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_train_epoch_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 25,
"end_line": 27,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called at the start of each training epoch.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_train_epoch_start_c5ea4d1e"
},
{
"content": "def on_train_batch_start(trainer):\n \"\"\"Called at the start of each training batch.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_train_batch_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 30,
"end_line": 32,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called at the start of each training batch.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_train_batch_start_1b5f49c6"
},
{
"content": "def optimizer_step(trainer):\n \"\"\"Called when the optimizer takes a step.\"\"\"\n pass",
"chunk_type": "function",
"name": "optimizer_step",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 35,
"end_line": 37,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called when the optimizer takes a step.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_optimizer_step_26375cc8"
},
{
"content": "def on_before_zero_grad(trainer):\n \"\"\"Called before the gradients are set to zero.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_before_zero_grad",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 40,
"end_line": 42,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called before the gradients are set to zero.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_before_zero_grad_460f4a5b"
},
{
"content": "def on_train_batch_end(trainer):\n \"\"\"Called at the end of each training batch.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_train_batch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 45,
"end_line": 47,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called at the end of each training batch.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_train_batch_end_00f24bdc"
},
{
"content": "def on_train_epoch_end(trainer):\n \"\"\"Called at the end of each training epoch.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_train_epoch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 50,
"end_line": 52,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called at the end of each training epoch.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_train_epoch_end_604d33e9"
},
{
"content": "def on_fit_epoch_end(trainer):\n \"\"\"Called at the end of each fit epoch (train + val).\"\"\"\n pass",
"chunk_type": "function",
"name": "on_fit_epoch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 55,
"end_line": 57,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called at the end of each fit epoch (train + val).",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_fit_epoch_end_0025f6b1"
},
{
"content": "def on_model_save(trainer):\n \"\"\"Called when the model is saved.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_model_save",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 60,
"end_line": 62,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called when the model is saved.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_model_save_14d6ee39"
},
{
"content": "def on_train_end(trainer):\n \"\"\"Called when the training ends.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_train_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 65,
"end_line": 67,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called when the training ends.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_train_end_019e91c8"
},
{
"content": "def on_params_update(trainer):\n \"\"\"Called when the model parameters are updated.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_params_update",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 70,
"end_line": 72,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called when the model parameters are updated.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_params_update_aae0e6c4"
},
{
"content": "def teardown(trainer):\n \"\"\"Called during the teardown of the training process.\"\"\"\n pass",
"chunk_type": "function",
"name": "teardown",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 75,
"end_line": 77,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called during the teardown of the training process.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_teardown_2d541efb"
},
{
"content": "def on_val_start(validator):\n \"\"\"Called when the validation starts.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_val_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 83,
"end_line": 85,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called when the validation starts.",
"parameters": [
"validator"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_val_start_8e1a914d"
},
{
"content": "def on_val_batch_start(validator):\n \"\"\"Called at the start of each validation batch.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_val_batch_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 88,
"end_line": 90,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called at the start of each validation batch.",
"parameters": [
"validator"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_val_batch_start_5dd04932"
},
{
"content": "def on_val_batch_end(validator):\n \"\"\"Called at the end of each validation batch.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_val_batch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 93,
"end_line": 95,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called at the end of each validation batch.",
"parameters": [
"validator"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_val_batch_end_9c0eee9c"
},
{
"content": "def on_val_end(validator):\n \"\"\"Called when the validation ends.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_val_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 98,
"end_line": 100,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called when the validation ends.",
"parameters": [
"validator"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_val_end_bd55536d"
},
{
"content": "def on_predict_start(predictor):\n \"\"\"Called when the prediction starts.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_predict_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 106,
"end_line": 108,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called when the prediction starts.",
"parameters": [
"predictor"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_predict_start_7f9e42dc"
},
{
"content": "def on_predict_batch_start(predictor):\n \"\"\"Called at the start of each prediction batch.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_predict_batch_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 111,
"end_line": 113,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called at the start of each prediction batch.",
"parameters": [
"predictor"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_predict_batch_start_f9aaf040"
},
{
"content": "def on_predict_batch_end(predictor):\n \"\"\"Called at the end of each prediction batch.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_predict_batch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 116,
"end_line": 118,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called at the end of each prediction batch.",
"parameters": [
"predictor"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_predict_batch_end_0ae684b0"
},
{
"content": "def on_predict_postprocess_end(predictor):\n \"\"\"Called after the post-processing of the prediction ends.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_predict_postprocess_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 121,
"end_line": 123,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called after the post-processing of the prediction ends.",
"parameters": [
"predictor"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_predict_postprocess_end_357c0141"
},
{
"content": "def on_predict_end(predictor):\n \"\"\"Called when the prediction ends.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_predict_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 126,
"end_line": 128,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called when the prediction ends.",
"parameters": [
"predictor"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_predict_end_9942070b"
},
{
"content": "def on_export_start(exporter):\n \"\"\"Called when the model export starts.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_export_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 134,
"end_line": 136,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called when the model export starts.",
"parameters": [
"exporter"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_export_start_d6541e30"
},
{
"content": "def on_export_end(exporter):\n \"\"\"Called when the model export ends.\"\"\"\n pass",
"chunk_type": "function",
"name": "on_export_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 139,
"end_line": 141,
"start_col": 0,
"end_col": 8,
"parent_name": null,
"docstring": "Called when the model export ends.",
"parameters": [
"exporter"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_on_export_end_6fd84381"
},
{
"content": "default_callbacks = {\n # Run in trainer\n \"on_pretrain_routine_start\": [on_pretrain_routine_start],\n \"on_pretrain_routine_end\": [on_pretrain_routine_end],\n \"on_train_start\": [on_train_start],\n \"on_train_epoch_start\": [on_train_epoch_start],\n \"on_train_batch_start\": [on_train_batch_start],\n \"optimizer_step\": [optimizer_step],\n \"on_before_zero_grad\": [on_before_zero_grad],\n \"on_train_batch_end\": [on_train_batch_end],\n \"on_train_epoch_end\": [on_train_epoch_end],\n \"on_fit_epoch_end\": [on_fit_epoch_end], # fit = train + val\n \"on_model_save\": [on_model_save],\n \"on_train_end\": [on_train_end],\n \"on_params_update\": [on_params_update],\n \"teardown\": [teardown],\n # Run in validator\n \"on_val_start\": [on_val_start],\n \"on_val_batch_start\": [on_val_batch_start],\n \"on_val_batch_end\": [on_val_batch_end],\n \"on_val_end\": [on_val_end],\n # Run in predictor\n \"on_predict_start\": [on_predict_start],\n \"on_predict_batch_start\": [on_predict_batch_start],\n \"on_predict_postprocess_end\": [on_predict_postprocess_end],\n \"on_predict_batch_end\": [on_predict_batch_end],\n \"on_predict_end\": [on_predict_end],\n # Run in exporter\n \"on_export_start\": [on_export_start],\n \"on_export_end\": [on_export_end],\n}",
"chunk_type": "variable",
"name": "default_callbacks",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 144,
"end_line": 174,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_default_callbacks_53b1012c"
},
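Because every event in default_callbacks maps to a list, user hooks can be appended alongside the built-in no-op defaults; log_epoch below is a hypothetical example.

from ultralytics.utils.callbacks.base import get_default_callbacks

callbacks = get_default_callbacks()

def log_epoch(trainer):  # hypothetical user hook
    print(f"finished epoch {getattr(trainer, 'epoch', '?')}")

callbacks["on_train_epoch_end"].append(log_epoch)  # runs after the default stub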
{
"content": "def get_default_callbacks():\n \"\"\"\n Get the default callbacks for Ultralytics training, validation, prediction, and export processes.\n\n Returns:\n (dict): Dictionary of default callbacks for various training events. Each key represents an event during the\n training process, and the corresponding value is a list of callback functions executed when that event\n occurs.\n\n Examples:\n >>> callbacks = get_default_callbacks()\n >>> print(list(callbacks.keys())) # show all available callback events\n ['on_pretrain_routine_start', 'on_pretrain_routine_end', ...]\n \"\"\"\n return defaultdict(list, deepcopy(default_callbacks))",
"chunk_type": "function",
"name": "get_default_callbacks",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 177,
"end_line": 191,
"start_col": 0,
"end_col": 57,
"parent_name": null,
"docstring": "Get the default callbacks for Ultralytics training, validation, prediction, and export processes.\n\nReturns:\n (dict): Dictionary of default callbacks for various training events. Each key represents an event during the\n training process, and the corresponding value is a list of callback functions executed when that event\n occurs.\n\nExamples:\n >>> callbacks = get_default_callbacks()\n >>> print(list(callbacks.keys())) # show all available callback events\n ['on_pretrain_routine_start', 'on_pretrain_routine_end', ...]",
"parameters": [],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_get_default_callbacks_eea37bf9"
},
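A small sketch of why get_default_callbacks wraps the dictionary in deepcopy plus defaultdict(list): callers can append to unseen events without mutating the module-level defaults. _defaults below is a stand-in, not the real dictionary.

from collections import defaultdict
from copy import deepcopy

_defaults = {"on_train_start": [print]}        # stand-in for default_callbacks
cbs = defaultdict(list, deepcopy(_defaults))
cbs["my_custom_event"].append(print)           # no KeyError on unseen keys
assert "my_custom_event" not in _defaults      # module-level defaults stay untouched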
{
"content": "def add_integration_callbacks(instance):\n \"\"\"\n Add integration callbacks to the instance's callbacks dictionary.\n\n This function loads and adds various integration callbacks to the provided instance. The specific callbacks added\n depend on the type of instance provided. All instances receive HUB callbacks, while Trainer instances also receive\n additional callbacks for various integrations like ClearML, Comet, DVC, MLflow, Neptune, Ray Tune, TensorBoard,\n and Weights & Biases.\n\n Args:\n instance (Trainer | Predictor | Validator | Exporter): The object instance to which callbacks will be added.\n The type of instance determines which callbacks are loaded.\n\n Examples:\n >>> from ultralytics.engine.trainer import BaseTrainer\n >>> trainer = BaseTrainer()\n >>> add_integration_callbacks(trainer)\n \"\"\"\n # Load HUB callbacks\n from .hub import callbacks as hub_cb\n\n callbacks_list = [hub_cb]\n\n # Load training callbacks\n if \"Trainer\" in instance.__class__.__name__:\n from .clearml import callbacks as clear_cb\n from .comet import callbacks as comet_cb\n from .dvc import callbacks as dvc_cb\n from .mlflow import callbacks as mlflow_cb\n from .neptune import callbacks as neptune_cb\n from .raytune import callbacks as tune_cb\n from .tensorboard import callbacks as tb_cb\n from .wb import callbacks as wb_cb\n\n callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])\n\n # Add the callbacks to the callbacks dictionary\n for callbacks in callbacks_list:\n for k, v in callbacks.items():\n if v not in instance.callbacks[k]:\n instance.callbacks[k].append(v)",
"chunk_type": "function",
"name": "add_integration_callbacks",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\base.py",
"start_line": 194,
"end_line": 234,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": "Add integration callbacks to the instance's callbacks dictionary.\n\nThis function loads and adds various integration callbacks to the provided instance. The specific callbacks added\ndepend on the type of instance provided. All instances receive HUB callbacks, while Trainer instances also receive\nadditional callbacks for various integrations like ClearML, Comet, DVC, MLflow, Neptune, Ray Tune, TensorBoard,\nand Weights & Biases.\n\nArgs:\n instance (Trainer | Predictor | Validator | Exporter): The object instance to which callbacks will be added.\n The type of instance determines which callbacks are loaded.\n\nExamples:\n >>> from ultralytics.engine.trainer import BaseTrainer\n >>> trainer = BaseTrainer()\n >>> add_integration_callbacks(trainer)",
"parameters": [
"instance"
],
"return_type": null,
"decorators": [],
"complexity_score": 5,
"dependencies": [
"collections.defaultdict",
"copy.deepcopy",
"hub.callbacks",
"clearml.callbacks",
"comet.callbacks",
"dvc.callbacks",
"mlflow.callbacks",
"neptune.callbacks",
"raytune.callbacks",
"tensorboard.callbacks",
"wb.callbacks"
],
"chunk_id": "function_add_integration_callbacks_5049c2f7"
},
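A hedged sketch of add_integration_callbacks using a minimal stand-in object; a class name containing "Trainer" is what triggers the extra integration imports, and only installed integrations contribute non-empty callback dictionaries.

from types import SimpleNamespace
from ultralytics.utils.callbacks.base import add_integration_callbacks, get_default_callbacks

class DemoTrainer(SimpleNamespace):  # hypothetical stand-in; only .callbacks is accessed
    pass

trainer = DemoTrainer(callbacks=get_default_callbacks())
add_integration_callbacks(trainer)
print(len(trainer.callbacks["on_train_end"]))  # default hook plus any installed integrations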
{
"content": "from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING",
"chunk_type": "import",
"name": "LOGGER, SETTINGS, TESTS_RUNNING",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\clearml.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 61,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER, SETTINGS, TESTS_RUNNING_f460cad7"
},
{
"content": "def _log_debug_samples(files, title: str = \"Debug Samples\") -> None:\n \"\"\"\n Log files (images) as debug samples in the ClearML task.\n\n Args:\n files (List[Path]): A list of file paths in PosixPath format.\n title (str): A title that groups together images with the same values.\n \"\"\"\n import re\n\n if task := Task.current_task():\n for f in files:\n if f.exists():\n it = re.search(r\"_batch(\\d+)\", f.name)\n iteration = int(it.groups()[0]) if it else 0\n task.get_logger().report_image(\n title=title, series=f.name.replace(it.group(), \"\"), local_path=str(f), iteration=iteration\n )",
"chunk_type": "function",
"name": "_log_debug_samples",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\clearml.py",
"start_line": 17,
"end_line": 34,
"start_col": 0,
"end_col": 17,
"parent_name": null,
"docstring": "Log files (images) as debug samples in the ClearML task.\n\nArgs:\n files (List[Path]): A list of file paths in PosixPath format.\n title (str): A title that groups together images with the same values.",
"parameters": [
"files",
"title: str"
],
"return_type": "None",
"decorators": [],
"complexity_score": 4,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"clearml",
"clearml.Task",
"re",
"matplotlib.image",
"matplotlib.pyplot",
"clearml.binding.frameworks.pytorch_bind.PatchPyTorchModelIO",
"clearml.binding.matplotlib_bind.PatchedMatplotlib",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__log_debug_samples_f955ae87"
},
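_log_debug_samples recovers the logging iteration from filenames such as train_batch2.jpg via the regex shown above; a standalone sketch with example names only.

import re
from pathlib import Path

for f in [Path("train_batch0.jpg"), Path("val_batch3_pred.jpg")]:
    it = re.search(r"_batch(\d+)", f.name)
    iteration = int(it.groups()[0]) if it else 0
    series = f.name.replace(it.group(), "") if it else f.name
    print(iteration, series)  # 0 train.jpg, then 3 val_pred.jpg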
{
"content": "def _log_plot(title: str, plot_path: str) -> None:\n \"\"\"\n Log an image as a plot in the plot section of ClearML.\n\n Args:\n title (str): The title of the plot.\n plot_path (str): The path to the saved image file.\n \"\"\"\n import matplotlib.image as mpimg\n import matplotlib.pyplot as plt\n\n img = mpimg.imread(plot_path)\n fig = plt.figure()\n ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect=\"auto\", xticks=[], yticks=[]) # no ticks\n ax.imshow(img)\n\n Task.current_task().get_logger().report_matplotlib_figure(\n title=title, series=\"\", figure=fig, report_interactive=False\n )",
"chunk_type": "function",
"name": "_log_plot",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\clearml.py",
"start_line": 37,
"end_line": 55,
"start_col": 0,
"end_col": 5,
"parent_name": null,
"docstring": "Log an image as a plot in the plot section of ClearML.\n\nArgs:\n title (str): The title of the plot.\n plot_path (str): The path to the saved image file.",
"parameters": [
"title: str",
"plot_path: str"
],
"return_type": "None",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"clearml",
"clearml.Task",
"re",
"matplotlib.image",
"matplotlib.pyplot",
"clearml.binding.frameworks.pytorch_bind.PatchPyTorchModelIO",
"clearml.binding.matplotlib_bind.PatchedMatplotlib",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__log_plot_0104fe1a"
},
{
"content": "def on_pretrain_routine_start(trainer) -> None:\n \"\"\"Initialize and connect ClearML task at the start of pretraining routine.\"\"\"\n try:\n if task := Task.current_task():\n # WARNING: make sure the automatic pytorch and matplotlib bindings are disabled!\n # We are logging these plots and model files manually in the integration\n from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO\n from clearml.binding.matplotlib_bind import PatchedMatplotlib\n\n PatchPyTorchModelIO.update_current_task(None)\n PatchedMatplotlib.update_current_task(None)\n else:\n task = Task.init(\n project_name=trainer.args.project or \"Ultralytics\",\n task_name=trainer.args.name,\n tags=[\"Ultralytics\"],\n output_uri=True,\n reuse_last_task_id=False,\n auto_connect_frameworks={\"pytorch\": False, \"matplotlib\": False},\n )\n LOGGER.warning(\n \"ClearML Initialized a new task. If you want to run remotely, \"\n \"please add clearml-init and connect your arguments before initializing YOLO.\"\n )\n task.connect(vars(trainer.args), name=\"General\")\n except Exception as e:\n LOGGER.warning(f\"ClearML installed but not initialized correctly, not logging this run. {e}\")",
"chunk_type": "function",
"name": "on_pretrain_routine_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\clearml.py",
"start_line": 58,
"end_line": 84,
"start_col": 0,
"end_col": 101,
"parent_name": null,
"docstring": "Initialize and connect ClearML task at the start of pretraining routine.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"clearml",
"clearml.Task",
"re",
"matplotlib.image",
"matplotlib.pyplot",
"clearml.binding.frameworks.pytorch_bind.PatchPyTorchModelIO",
"clearml.binding.matplotlib_bind.PatchedMatplotlib",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_pretrain_routine_start_5ce93a7c"
},
{
"content": "def on_train_epoch_end(trainer) -> None:\n \"\"\"Log debug samples for the first epoch and report current training progress.\"\"\"\n if task := Task.current_task():\n # Log debug samples for first epoch only\n if trainer.epoch == 1:\n _log_debug_samples(sorted(trainer.save_dir.glob(\"train_batch*.jpg\")), \"Mosaic\")\n # Report the current training progress\n for k, v in trainer.label_loss_items(trainer.tloss, prefix=\"train\").items():\n task.get_logger().report_scalar(\"train\", k, v, iteration=trainer.epoch)\n for k, v in trainer.lr.items():\n task.get_logger().report_scalar(\"lr\", k, v, iteration=trainer.epoch)",
"chunk_type": "function",
"name": "on_train_epoch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\clearml.py",
"start_line": 87,
"end_line": 97,
"start_col": 0,
"end_col": 80,
"parent_name": null,
"docstring": "Log debug samples for the first epoch and report current training progress.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 5,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"clearml",
"clearml.Task",
"re",
"matplotlib.image",
"matplotlib.pyplot",
"clearml.binding.frameworks.pytorch_bind.PatchPyTorchModelIO",
"clearml.binding.matplotlib_bind.PatchedMatplotlib",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_train_epoch_end_152ca3b7"
},
{
"content": "def on_fit_epoch_end(trainer) -> None:\n \"\"\"Report model information and metrics to logger at the end of an epoch.\"\"\"\n if task := Task.current_task():\n # Report epoch time and validation metrics\n task.get_logger().report_scalar(\n title=\"Epoch Time\", series=\"Epoch Time\", value=trainer.epoch_time, iteration=trainer.epoch\n )\n for k, v in trainer.metrics.items():\n task.get_logger().report_scalar(\"val\", k, v, iteration=trainer.epoch)\n if trainer.epoch == 0:\n from ultralytics.utils.torch_utils import model_info_for_loggers\n\n for k, v in model_info_for_loggers(trainer).items():\n task.get_logger().report_single_value(k, v)",
"chunk_type": "function",
"name": "on_fit_epoch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\clearml.py",
"start_line": 100,
"end_line": 113,
"start_col": 0,
"end_col": 59,
"parent_name": null,
"docstring": "Report model information and metrics to logger at the end of an epoch.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 5,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"clearml",
"clearml.Task",
"re",
"matplotlib.image",
"matplotlib.pyplot",
"clearml.binding.frameworks.pytorch_bind.PatchPyTorchModelIO",
"clearml.binding.matplotlib_bind.PatchedMatplotlib",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_fit_epoch_end_3d64d955"
},
{
"content": "def on_val_end(validator) -> None:\n \"\"\"Log validation results including labels and predictions.\"\"\"\n if Task.current_task():\n # Log validation labels and predictions\n _log_debug_samples(sorted(validator.save_dir.glob(\"val*.jpg\")), \"Validation\")",
"chunk_type": "function",
"name": "on_val_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\clearml.py",
"start_line": 116,
"end_line": 120,
"start_col": 0,
"end_col": 85,
"parent_name": null,
"docstring": "Log validation results including labels and predictions.",
"parameters": [
"validator"
],
"return_type": "None",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"clearml",
"clearml.Task",
"re",
"matplotlib.image",
"matplotlib.pyplot",
"clearml.binding.frameworks.pytorch_bind.PatchPyTorchModelIO",
"clearml.binding.matplotlib_bind.PatchedMatplotlib",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_val_end_9a3d3651"
},
{
"content": "def on_train_end(trainer) -> None:\n \"\"\"Log final model and training results on training completion.\"\"\"\n if task := Task.current_task():\n # Log final results, confusion matrix and PR plots\n files = [\n \"results.png\",\n \"confusion_matrix.png\",\n \"confusion_matrix_normalized.png\",\n *(f\"{x}_curve.png\" for x in (\"F1\", \"PR\", \"P\", \"R\")),\n ]\n files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter existing files\n for f in files:\n _log_plot(title=f.stem, plot_path=f)\n # Report final metrics\n for k, v in trainer.validator.metrics.results_dict.items():\n task.get_logger().report_single_value(k, v)\n # Log the final model\n task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)",
"chunk_type": "function",
"name": "on_train_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\clearml.py",
"start_line": 123,
"end_line": 140,
"start_col": 0,
"end_col": 116,
"parent_name": null,
"docstring": "Log final model and training results on training completion.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 6,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"clearml",
"clearml.Task",
"re",
"matplotlib.image",
"matplotlib.pyplot",
"clearml.binding.frameworks.pytorch_bind.PatchPyTorchModelIO",
"clearml.binding.matplotlib_bind.PatchedMatplotlib",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_train_end_9539f420"
},
{
"content": "callbacks = (\n {\n \"on_pretrain_routine_start\": on_pretrain_routine_start,\n \"on_train_epoch_end\": on_train_epoch_end,\n \"on_fit_epoch_end\": on_fit_epoch_end,\n \"on_val_end\": on_val_end,\n \"on_train_end\": on_train_end,\n }\n if clearml\n else {}\n)",
"chunk_type": "variable",
"name": "callbacks",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\clearml.py",
"start_line": 143,
"end_line": 153,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_callbacks_03e51358"
},
{
"content": "from collections.abc import Callable",
"chunk_type": "import",
"name": "Callable",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 36,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Callable_c4d6c13a"
},
{
"content": "from types import SimpleNamespace",
"chunk_type": "import",
"name": "SimpleNamespace",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 33,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SimpleNamespace_dda040c9"
},
{
"content": "from typing import Any, List, Optional",
"chunk_type": "import",
"name": "Any, List, Optional",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 38,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Any, List, Optional_2b5fd40b"
},
{
"content": "import cv2",
"chunk_type": "import",
"name": "cv2",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 10,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_cv2_93c76c4c"
},
{
"content": "import numpy as np",
"chunk_type": "import",
"name": "numpy",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 8,
"end_line": 8,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_numpy_0d19a4cc"
},
{
"content": "from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops",
"chunk_type": "import",
"name": "LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 10,
"end_line": 10,
"start_col": 0,
"end_col": 72,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops_07aeb2f7"
},
{
"content": "from ultralytics.utils.metrics import ClassifyMetrics, DetMetrics, OBBMetrics, PoseMetrics, SegmentMetrics",
"chunk_type": "import",
"name": "ClassifyMetrics, DetMetrics, OBBMetrics, PoseMetrics, SegmentMetrics",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 11,
"end_line": 11,
"start_col": 0,
"end_col": 106,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_ClassifyMetrics, DetMetrics, OBBMetrics, PoseMetrics, SegmentMetrics_11aefd35"
},
{
"content": "def _get_comet_mode() -> str:\n \"\"\"Return the Comet mode from environment variables, defaulting to 'online'.\"\"\"\n comet_mode = os.getenv(\"COMET_MODE\")\n if comet_mode is not None:\n LOGGER.warning(\n \"The COMET_MODE environment variable is deprecated. \"\n \"Please use COMET_START_ONLINE to set the Comet experiment mode. \"\n \"To start an offline Comet experiment, use 'export COMET_START_ONLINE=0'. \"\n \"If COMET_START_ONLINE is not set or is set to '1', an online Comet experiment will be created.\"\n )\n return comet_mode\n\n return \"online\"",
"chunk_type": "function",
"name": "_get_comet_mode",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 39,
"end_line": 51,
"start_col": 0,
"end_col": 19,
"parent_name": null,
"docstring": "Return the Comet mode from environment variables, defaulting to 'online'.",
"parameters": [],
"return_type": "str",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__get_comet_mode_98d487f2"
},
{
"content": "def _get_comet_model_name() -> str:\n \"\"\"Return the Comet model name from environment variable or default to 'Ultralytics'.\"\"\"\n return os.getenv(\"COMET_MODEL_NAME\", \"Ultralytics\")",
"chunk_type": "function",
"name": "_get_comet_model_name",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 54,
"end_line": 56,
"start_col": 0,
"end_col": 55,
"parent_name": null,
"docstring": "Return the Comet model name from environment variable or default to 'Ultralytics'.",
"parameters": [],
"return_type": "str",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__get_comet_model_name_1b12a748"
},
{
"content": "def _get_eval_batch_logging_interval() -> int:\n \"\"\"Get the evaluation batch logging interval from environment variable or use default value 1.\"\"\"\n return int(os.getenv(\"COMET_EVAL_BATCH_LOGGING_INTERVAL\", 1))",
"chunk_type": "function",
"name": "_get_eval_batch_logging_interval",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 59,
"end_line": 61,
"start_col": 0,
"end_col": 65,
"parent_name": null,
"docstring": "Get the evaluation batch logging interval from environment variable or use default value 1.",
"parameters": [],
"return_type": "int",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__get_eval_batch_logging_interval_67f6c443"
},
{
"content": "def _get_max_image_predictions_to_log() -> int:\n \"\"\"Get the maximum number of image predictions to log from environment variables.\"\"\"\n return int(os.getenv(\"COMET_MAX_IMAGE_PREDICTIONS\", 100))",
"chunk_type": "function",
"name": "_get_max_image_predictions_to_log",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 64,
"end_line": 66,
"start_col": 0,
"end_col": 61,
"parent_name": null,
"docstring": "Get the maximum number of image predictions to log from environment variables.",
"parameters": [],
"return_type": "int",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__get_max_image_predictions_to_log_ded60425"
},
{
"content": "def _scale_confidence_score(score: float) -> float:\n \"\"\"Scale the confidence score by a factor specified in environment variable.\"\"\"\n scale = float(os.getenv(\"COMET_MAX_CONFIDENCE_SCORE\", 100.0))\n return score * scale",
"chunk_type": "function",
"name": "_scale_confidence_score",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 69,
"end_line": 72,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": "Scale the confidence score by a factor specified in environment variable.",
"parameters": [
"score: float"
],
"return_type": "float",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__scale_confidence_score_bee04db0"
},
{
"content": "def _should_log_confusion_matrix() -> bool:\n \"\"\"Determine if the confusion matrix should be logged based on environment variable settings.\"\"\"\n return os.getenv(\"COMET_EVAL_LOG_CONFUSION_MATRIX\", \"false\").lower() == \"true\"",
"chunk_type": "function",
"name": "_should_log_confusion_matrix",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 75,
"end_line": 77,
"start_col": 0,
"end_col": 82,
"parent_name": null,
"docstring": "Determine if the confusion matrix should be logged based on environment variable settings.",
"parameters": [],
"return_type": "bool",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__should_log_confusion_matrix_88937404"
},
{
"content": "def _should_log_image_predictions() -> bool:\n \"\"\"Determine whether to log image predictions based on environment variable.\"\"\"\n return os.getenv(\"COMET_EVAL_LOG_IMAGE_PREDICTIONS\", \"true\").lower() == \"true\"",
"chunk_type": "function",
"name": "_should_log_image_predictions",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 80,
"end_line": 82,
"start_col": 0,
"end_col": 82,
"parent_name": null,
"docstring": "Determine whether to log image predictions based on environment variable.",
"parameters": [],
"return_type": "bool",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__should_log_image_predictions_2ec21984"
},
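The four helpers above share one pattern: read a COMET_* environment variable, fall back to a default, and coerce to the needed type. A minimal standalone sketch of that pattern, assuming nothing beyond the standard library (the env_int/env_flag names are illustrative, not part of comet.py):

import os

def env_int(name: str, default: int) -> int:
    # Integer-valued setting, e.g. COMET_EVAL_BATCH_LOGGING_INTERVAL
    return int(os.getenv(name, default))

def env_flag(name: str, default: str = "false") -> bool:
    # Boolean toggle parsed case-insensitively, e.g. COMET_EVAL_LOG_CONFUSION_MATRIX
    return os.getenv(name, default).lower() == "true"

os.environ["COMET_EVAL_LOG_CONFUSION_MATRIX"] = "TRUE"
print(env_int("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))  # 1 (default, variable unset)
print(env_flag("COMET_EVAL_LOG_CONFUSION_MATRIX"))      # True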
{
"content": "def _resume_or_create_experiment(args: SimpleNamespace) -> None:\n \"\"\"\n Resume CometML experiment or create a new experiment based on args.\n\n Ensures that the experiment object is only created in a single process during distributed training.\n\n Args:\n args (SimpleNamespace): Training arguments containing project configuration and other parameters.\n \"\"\"\n if RANK not in {-1, 0}:\n return\n\n # Set environment variable (if not set by the user) to configure the Comet experiment's online mode under the hood.\n # IF COMET_START_ONLINE is set by the user it will override COMET_MODE value.\n if os.getenv(\"COMET_START_ONLINE\") is None:\n comet_mode = _get_comet_mode()\n os.environ[\"COMET_START_ONLINE\"] = \"1\" if comet_mode != \"offline\" else \"0\"\n\n try:\n _project_name = os.getenv(\"COMET_PROJECT_NAME\", args.project)\n experiment = comet_ml.start(project_name=_project_name)\n experiment.log_parameters(vars(args))\n experiment.log_others(\n {\n \"eval_batch_logging_interval\": _get_eval_batch_logging_interval(),\n \"log_confusion_matrix_on_eval\": _should_log_confusion_matrix(),\n \"log_image_predictions\": _should_log_image_predictions(),\n \"max_image_predictions\": _get_max_image_predictions_to_log(),\n }\n )\n experiment.log_other(\"Created from\", \"ultralytics\")\n\n except Exception as e:\n LOGGER.warning(f\"Comet installed but not initialized correctly, not logging this run. {e}\")",
"chunk_type": "function",
"name": "_resume_or_create_experiment",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 85,
"end_line": 118,
"start_col": 0,
"end_col": 99,
"parent_name": null,
"docstring": "Resume CometML experiment or create a new experiment based on args.\n\nEnsures that the experiment object is only created in a single process during distributed training.\n\nArgs:\n args (SimpleNamespace): Training arguments containing project configuration and other parameters.",
"parameters": [
"args: SimpleNamespace"
],
"return_type": "None",
"decorators": [],
"complexity_score": 4,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__resume_or_create_experiment_bf13f87e"
},
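Before comet_ml.start() is called, _resume_or_create_experiment translates the deprecated COMET_MODE value into COMET_START_ONLINE unless the user already set the latter. A minimal sketch of just that mapping, kept standalone so it runs without the comet_ml dependency:

import os

def resolve_start_online(comet_mode: str = "online") -> str:
    # A user-provided COMET_START_ONLINE always wins over COMET_MODE.
    if os.getenv("COMET_START_ONLINE") is None:
        os.environ["COMET_START_ONLINE"] = "1" if comet_mode != "offline" else "0"
    return os.environ["COMET_START_ONLINE"]

print(resolve_start_online("offline"))  # "0" -> an offline Comet experiment would be created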
{
"content": "def _fetch_trainer_metadata(trainer) -> dict:\n \"\"\"\n Return metadata for YOLO training including epoch and asset saving status.\n\n Args:\n trainer (ultralytics.engine.trainer.BaseTrainer): The YOLO trainer object containing training state and config.\n\n Returns:\n (dict): Dictionary containing current epoch, step, save assets flag, and final epoch flag.\n \"\"\"\n curr_epoch = trainer.epoch + 1\n\n train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size\n curr_step = curr_epoch * train_num_steps_per_epoch\n final_epoch = curr_epoch == trainer.epochs\n\n save = trainer.args.save\n save_period = trainer.args.save_period\n save_interval = curr_epoch % save_period == 0\n save_assets = save and save_period > 0 and save_interval and not final_epoch\n\n return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch)",
"chunk_type": "function",
"name": "_fetch_trainer_metadata",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 121,
"end_line": 142,
"start_col": 0,
"end_col": 109,
"parent_name": null,
"docstring": "Return metadata for YOLO training including epoch and asset saving status.\n\nArgs:\n trainer (ultralytics.engine.trainer.BaseTrainer): The YOLO trainer object containing training state and config.\n\nReturns:\n (dict): Dictionary containing current epoch, step, save assets flag, and final epoch flag.",
"parameters": [
"trainer"
],
"return_type": "dict",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__fetch_trainer_metadata_010b2d25"
},
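The step and save-interval arithmetic in _fetch_trainer_metadata is easy to check by hand. A small illustration with made-up numbers (dataset of 1000 images, batch size 16, save_period 5; these values are hypothetical):

dataset_len, batch_size = 1000, 16   # hypothetical dataset and batch size
epochs, save_period = 100, 5
save = True                          # stands in for trainer.args.save

curr_epoch = 7 + 1                   # trainer.epoch is 0-based, so epoch index 7 -> epoch 8
steps_per_epoch = dataset_len // batch_size   # 62
curr_step = curr_epoch * steps_per_epoch      # 496
final_epoch = curr_epoch == epochs            # False
save_assets = save and save_period > 0 and curr_epoch % save_period == 0 and not final_epoch
print(curr_step, save_assets)                 # 496 False (8 is not a multiple of 5)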
{
"content": "def _scale_bounding_box_to_original_image_shape(\n box, resized_image_shape, original_image_shape, ratio_pad\n) -> List[float]:\n \"\"\"\n Scale bounding box from resized image coordinates to original image coordinates.\n\n YOLO resizes images during training and the label values are normalized based on this resized shape.\n This function rescales the bounding box labels to the original image shape.\n\n Args:\n box (torch.Tensor): Bounding box in normalized xywh format.\n resized_image_shape (tuple): Shape of the resized image (height, width).\n original_image_shape (tuple): Shape of the original image (height, width).\n ratio_pad (tuple): Ratio and padding information for scaling.\n\n Returns:\n (List[float]): Scaled bounding box coordinates in xywh format with top-left corner adjustment.\n \"\"\"\n resized_image_height, resized_image_width = resized_image_shape\n\n # Convert normalized xywh format predictions to xyxy in resized scale format\n box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width)\n # Scale box predictions from resized image scale back to original image scale\n box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad)\n # Convert bounding box format from xyxy to xywh for Comet logging\n box = ops.xyxy2xywh(box)\n # Adjust xy center to correspond top-left corner\n box[:2] -= box[2:] / 2\n box = box.tolist()\n\n return box",
"chunk_type": "function",
"name": "_scale_bounding_box_to_original_image_shape",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 145,
"end_line": 175,
"start_col": 0,
"end_col": 14,
"parent_name": null,
"docstring": "Scale bounding box from resized image coordinates to original image coordinates.\n\nYOLO resizes images during training and the label values are normalized based on this resized shape.\nThis function rescales the bounding box labels to the original image shape.\n\nArgs:\n box (torch.Tensor): Bounding box in normalized xywh format.\n resized_image_shape (tuple): Shape of the resized image (height, width).\n original_image_shape (tuple): Shape of the original image (height, width).\n ratio_pad (tuple): Ratio and padding information for scaling.\n\nReturns:\n (List[float]): Scaled bounding box coordinates in xywh format with top-left corner adjustment.",
"parameters": [
"box",
"resized_image_shape",
"original_image_shape",
"ratio_pad"
],
"return_type": "List[float]",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__scale_bounding_box_to_original_image_shape_c86c2d8f"
},
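The conversion chain above (normalized xywh -> absolute xyxy -> xywh anchored at the top-left corner) can be reproduced with plain Python; the letterbox rescaling handled by ops.scale_boxes is omitted, so this is only a sketch of the coordinate bookkeeping, not the full function, and the helper name is illustrative:

def xywhn_to_topleft_xywh(box, h, w):
    # box = (x_center, y_center, width, height), all normalized to [0, 1]
    xc, yc, bw, bh = box
    x1, y1 = (xc - bw / 2) * w, (yc - bh / 2) * h   # absolute top-left (the xyxy step)
    x2, y2 = (xc + bw / 2) * w, (yc + bh / 2) * h   # absolute bottom-right
    return [x1, y1, x2 - x1, y2 - y1]               # xywh with top-left origin, as Comet expects

print(xywhn_to_topleft_xywh((0.5, 0.5, 0.2, 0.4), h=640, w=640))  # [256.0, 192.0, 128.0, 256.0]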
{
"content": "def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None) -> Optional[dict]:\n \"\"\"\n Format ground truth annotations for object detection.\n\n This function processes ground truth annotations from a batch of images for object detection tasks. It extracts\n bounding boxes, class labels, and other metadata for a specific image in the batch, and formats them for\n visualization or evaluation.\n\n Args:\n img_idx (int): Index of the image in the batch to process.\n image_path (str | Path): Path to the image file.\n batch (dict): Batch dictionary containing detection data with keys:\n - 'batch_idx': Tensor of batch indices\n - 'bboxes': Tensor of bounding boxes in normalized xywh format\n - 'cls': Tensor of class labels\n - 'ori_shape': Original image shapes\n - 'resized_shape': Resized image shapes\n - 'ratio_pad': Ratio and padding information\n class_name_map (dict, optional): Mapping from class indices to class names.\n\n Returns:\n (dict | None): Formatted ground truth annotations with the following structure:\n - 'boxes': List of box coordinates [x, y, width, height]\n - 'label': Label string with format \"gt_{class_name}\"\n - 'score': Confidence score (always 1.0, scaled by _scale_confidence_score)\n Returns None if no bounding boxes are found for the image.\n \"\"\"\n indices = batch[\"batch_idx\"] == img_idx\n bboxes = batch[\"bboxes\"][indices]\n if len(bboxes) == 0:\n LOGGER.debug(f\"Comet Image: {image_path} has no bounding boxes labels\")\n return None\n\n cls_labels = batch[\"cls\"][indices].squeeze(1).tolist()\n if class_name_map:\n cls_labels = [str(class_name_map[label]) for label in cls_labels]\n\n original_image_shape = batch[\"ori_shape\"][img_idx]\n resized_image_shape = batch[\"resized_shape\"][img_idx]\n ratio_pad = batch[\"ratio_pad\"][img_idx]\n\n data = []\n for box, label in zip(bboxes, cls_labels):\n box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)\n data.append(\n {\n \"boxes\": [box],\n \"label\": f\"gt_{label}\",\n \"score\": _scale_confidence_score(1.0),\n }\n )\n\n return {\"name\": \"ground_truth\", \"data\": data}",
"chunk_type": "function",
"name": "_format_ground_truth_annotations_for_detection",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 178,
"end_line": 230,
"start_col": 0,
"end_col": 49,
"parent_name": null,
"docstring": "Format ground truth annotations for object detection.\n\nThis function processes ground truth annotations from a batch of images for object detection tasks. It extracts\nbounding boxes, class labels, and other metadata for a specific image in the batch, and formats them for\nvisualization or evaluation.\n\nArgs:\n img_idx (int): Index of the image in the batch to process.\n image_path (str | Path): Path to the image file.\n batch (dict): Batch dictionary containing detection data with keys:\n - 'batch_idx': Tensor of batch indices\n - 'bboxes': Tensor of bounding boxes in normalized xywh format\n - 'cls': Tensor of class labels\n - 'ori_shape': Original image shapes\n - 'resized_shape': Resized image shapes\n - 'ratio_pad': Ratio and padding information\n class_name_map (dict, optional): Mapping from class indices to class names.\n\nReturns:\n (dict | None): Formatted ground truth annotations with the following structure:\n - 'boxes': List of box coordinates [x, y, width, height]\n - 'label': Label string with format \"gt_{class_name}\"\n - 'score': Confidence score (always 1.0, scaled by _scale_confidence_score)\n Returns None if no bounding boxes are found for the image.",
"parameters": [
"img_idx",
"image_path",
"batch",
"class_name_map"
],
"return_type": "Optional[dict]",
"decorators": [],
"complexity_score": 5,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__format_ground_truth_annotations_for_detection_ea29e256"
},
{
"content": "def _format_prediction_annotations(image_path, metadata, class_label_map=None, class_map=None) -> Optional[dict]:\n \"\"\"\n Format YOLO predictions for object detection visualization.\n\n Args:\n image_path (Path): Path to the image file.\n metadata (dict): Prediction metadata containing bounding boxes and class information.\n class_label_map (dict, optional): Mapping from class indices to class names.\n class_map (dict, optional): Additional class mapping for label conversion.\n\n Returns:\n (dict | None): Formatted prediction annotations or None if no predictions exist.\n \"\"\"\n stem = image_path.stem\n image_id = int(stem) if stem.isnumeric() else stem\n\n predictions = metadata.get(image_id)\n if not predictions:\n LOGGER.debug(f\"Comet Image: {image_path} has no bounding boxes predictions\")\n return None\n\n # apply the mapping that was used to map the predicted classes when the JSON was created\n if class_label_map and class_map:\n class_label_map = {class_map[k]: v for k, v in class_label_map.items()}\n try:\n # import pycotools utilities to decompress annotations for various tasks, e.g. segmentation\n from faster_coco_eval.core.mask import decode # noqa\n except ImportError:\n decode = None\n\n data = []\n for prediction in predictions:\n boxes = prediction[\"bbox\"]\n score = _scale_confidence_score(prediction[\"score\"])\n cls_label = prediction[\"category_id\"]\n if class_label_map:\n cls_label = str(class_label_map[cls_label])\n\n annotation_data = {\"boxes\": [boxes], \"label\": cls_label, \"score\": score}\n\n if decode is not None:\n # do segmentation processing only if we are able to decode it\n segments = prediction.get(\"segmentation\", None)\n if segments is not None:\n segments = _extract_segmentation_annotation(segments, decode)\n if segments is not None:\n annotation_data[\"points\"] = segments\n\n data.append(annotation_data)\n\n return {\"name\": \"prediction\", \"data\": data}",
"chunk_type": "function",
"name": "_format_prediction_annotations",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 233,
"end_line": 283,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": "Format YOLO predictions for object detection visualization.\n\nArgs:\n image_path (Path): Path to the image file.\n metadata (dict): Prediction metadata containing bounding boxes and class information.\n class_label_map (dict, optional): Mapping from class indices to class names.\n class_map (dict, optional): Additional class mapping for label conversion.\n\nReturns:\n (dict | None): Formatted prediction annotations or None if no predictions exist.",
"parameters": [
"image_path",
"metadata",
"class_label_map",
"class_map"
],
"return_type": "Optional[dict]",
"decorators": [],
"complexity_score": 10,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__format_prediction_annotations_f01ea5f2"
},
{
"content": "def _extract_segmentation_annotation(segmentation_raw: str, decode: Callable) -> Optional[List[List[Any]]]:\n \"\"\"\n Extract segmentation annotation from compressed segmentations as list of polygons.\n\n Args:\n segmentation_raw (str): Raw segmentation data in compressed format.\n decode (Callable): Function to decode the compressed segmentation data.\n\n Returns:\n (List[List[Any]] | None): List of polygon points or None if extraction fails.\n \"\"\"\n try:\n mask = decode(segmentation_raw)\n contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)\n annotations = [np.array(polygon).squeeze() for polygon in contours if len(polygon) >= 3]\n return [annotation.ravel().tolist() for annotation in annotations]\n except Exception as e:\n LOGGER.warning(f\"Comet Failed to extract segmentation annotation: {e}\")\n return None",
"chunk_type": "function",
"name": "_extract_segmentation_annotation",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 286,
"end_line": 304,
"start_col": 0,
"end_col": 15,
"parent_name": null,
"docstring": "Extract segmentation annotation from compressed segmentations as list of polygons.\n\nArgs:\n segmentation_raw (str): Raw segmentation data in compressed format.\n decode (Callable): Function to decode the compressed segmentation data.\n\nReturns:\n (List[List[Any]] | None): List of polygon points or None if extraction fails.",
"parameters": [
"segmentation_raw: str",
"decode: Callable"
],
"return_type": "Optional[List[List[Any]]]",
"decorators": [],
"complexity_score": 4,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__extract_segmentation_annotation_62de6e1a"
},
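The contour-to-polygon step in _extract_segmentation_annotation can be exercised with a synthetic binary mask instead of a COCO-RLE payload; the decode step via faster_coco_eval is skipped here, so this only demonstrates the OpenCV/NumPy part:

import cv2
import numpy as np

mask = np.zeros((64, 64), dtype=np.uint8)
mask[10:30, 20:50] = 1  # a single rectangular blob

contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
polygons = [np.array(c).squeeze() for c in contours if len(c) >= 3]
points = [p.ravel().tolist() for p in polygons]   # flat [x1, y1, x2, y2, ...] per polygon
print(points)                                     # one polygon tracing the rectangle's corners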
{
"content": "def _fetch_annotations(\n img_idx, image_path, batch, prediction_metadata_map, class_label_map, class_map\n) -> Optional[List]:\n \"\"\"\n Join the ground truth and prediction annotations if they exist.\n\n Args:\n img_idx (int): Index of the image in the batch.\n image_path (Path): Path to the image file.\n batch (dict): Batch data containing ground truth annotations.\n prediction_metadata_map (dict): Map of prediction metadata by image ID.\n class_label_map (dict): Mapping from class indices to class names.\n class_map (dict): Additional class mapping for label conversion.\n\n Returns:\n (List | None): List of annotation dictionaries or None if no annotations exist.\n \"\"\"\n ground_truth_annotations = _format_ground_truth_annotations_for_detection(\n img_idx, image_path, batch, class_label_map\n )\n prediction_annotations = _format_prediction_annotations(\n image_path, prediction_metadata_map, class_label_map, class_map\n )\n\n annotations = [\n annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None\n ]\n return [annotations] if annotations else None",
"chunk_type": "function",
"name": "_fetch_annotations",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 307,
"end_line": 334,
"start_col": 0,
"end_col": 49,
"parent_name": null,
"docstring": "Join the ground truth and prediction annotations if they exist.\n\nArgs:\n img_idx (int): Index of the image in the batch.\n image_path (Path): Path to the image file.\n batch (dict): Batch data containing ground truth annotations.\n prediction_metadata_map (dict): Map of prediction metadata by image ID.\n class_label_map (dict): Mapping from class indices to class names.\n class_map (dict): Additional class mapping for label conversion.\n\nReturns:\n (List | None): List of annotation dictionaries or None if no annotations exist.",
"parameters": [
"img_idx",
"image_path",
"batch",
"prediction_metadata_map",
"class_label_map",
"class_map"
],
"return_type": "Optional[List]",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__fetch_annotations_979a6a1f"
},
{
"content": "def _create_prediction_metadata_map(model_predictions) -> dict:\n \"\"\"Create metadata map for model predictions by grouping them based on image ID.\"\"\"\n pred_metadata_map = {}\n for prediction in model_predictions:\n pred_metadata_map.setdefault(prediction[\"image_id\"], [])\n pred_metadata_map[prediction[\"image_id\"]].append(prediction)\n\n return pred_metadata_map",
"chunk_type": "function",
"name": "_create_prediction_metadata_map",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 337,
"end_line": 344,
"start_col": 0,
"end_col": 28,
"parent_name": null,
"docstring": "Create metadata map for model predictions by grouping them based on image ID.",
"parameters": [
"model_predictions"
],
"return_type": "dict",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__create_prediction_metadata_map_287d87d4"
},
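A quick illustration of the grouping performed by _create_prediction_metadata_map, using two toy COCO-style prediction dicts (the values are made up):

preds = [
    {"image_id": 42, "category_id": 0, "bbox": [10, 10, 5, 5], "score": 0.9},
    {"image_id": 42, "category_id": 1, "bbox": [20, 20, 8, 8], "score": 0.7},
    {"image_id": 7,  "category_id": 0, "bbox": [0, 0, 3, 3],   "score": 0.5},
]

grouped = {}
for p in preds:
    grouped.setdefault(p["image_id"], []).append(p)

print(sorted(grouped))   # [7, 42]
print(len(grouped[42]))  # 2 -> two predictions share image 42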
{
"content": "def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch) -> None:\n \"\"\"Log the confusion matrix to Comet experiment.\"\"\"\n conf_mat = trainer.validator.confusion_matrix.matrix\n names = list(trainer.data[\"names\"].values()) + [\"background\"]\n experiment.log_confusion_matrix(\n matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step\n )",
"chunk_type": "function",
"name": "_log_confusion_matrix",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 347,
"end_line": 353,
"start_col": 0,
"end_col": 5,
"parent_name": null,
"docstring": "Log the confusion matrix to Comet experiment.",
"parameters": [
"experiment",
"trainer",
"curr_step",
"curr_epoch"
],
"return_type": "None",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__log_confusion_matrix_af670c19"
},
{
"content": "def _log_images(experiment, image_paths, curr_step, annotations=None) -> None:\n \"\"\"\n Log images to the experiment with optional annotations.\n\n This function logs images to a Comet ML experiment, optionally including annotation data for visualization\n such as bounding boxes or segmentation masks.\n\n Args:\n experiment (comet_ml.Experiment): The Comet ML experiment to log images to.\n image_paths (List[Path]): List of paths to images that will be logged.\n curr_step (int): Current training step/iteration for tracking in the experiment timeline.\n annotations (List[List[dict]], optional): Nested list of annotation dictionaries for each image. Each\n annotation contains visualization data like bounding boxes, labels, and confidence scores.\n \"\"\"\n if annotations:\n for image_path, annotation in zip(image_paths, annotations):\n experiment.log_image(image_path, name=image_path.stem, step=curr_step, annotations=annotation)\n\n else:\n for image_path in image_paths:\n experiment.log_image(image_path, name=image_path.stem, step=curr_step)",
"chunk_type": "function",
"name": "_log_images",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 356,
"end_line": 376,
"start_col": 0,
"end_col": 82,
"parent_name": null,
"docstring": "Log images to the experiment with optional annotations.\n\nThis function logs images to a Comet ML experiment, optionally including annotation data for visualization\nsuch as bounding boxes or segmentation masks.\n\nArgs:\n experiment (comet_ml.Experiment): The Comet ML experiment to log images to.\n image_paths (List[Path]): List of paths to images that will be logged.\n curr_step (int): Current training step/iteration for tracking in the experiment timeline.\n annotations (List[List[dict]], optional): Nested list of annotation dictionaries for each image. Each\n annotation contains visualization data like bounding boxes, labels, and confidence scores.",
"parameters": [
"experiment",
"image_paths",
"curr_step",
"annotations"
],
"return_type": "None",
"decorators": [],
"complexity_score": 4,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__log_images_630cee76"
},
{
"content": "def _log_image_predictions(experiment, validator, curr_step) -> None:\n \"\"\"\n Log predicted boxes for a single image during training.\n\n This function logs image predictions to a Comet ML experiment during model validation. It processes\n validation data and formats both ground truth and prediction annotations for visualization in the Comet\n dashboard. The function respects configured limits on the number of images to log.\n\n Args:\n experiment (comet_ml.Experiment): The Comet ML experiment to log to.\n validator (BaseValidator): The validator instance containing validation data and predictions.\n curr_step (int): The current training step for logging timeline.\n\n Notes:\n This function uses global state to track the number of logged predictions across calls.\n It only logs predictions for supported tasks defined in COMET_SUPPORTED_TASKS.\n The number of logged images is limited by the COMET_MAX_IMAGE_PREDICTIONS environment variable.\n \"\"\"\n global _comet_image_prediction_count\n\n task = validator.args.task\n if task not in COMET_SUPPORTED_TASKS:\n return\n\n jdict = validator.jdict\n if not jdict:\n return\n\n predictions_metadata_map = _create_prediction_metadata_map(jdict)\n dataloader = validator.dataloader\n class_label_map = validator.names\n class_map = getattr(validator, \"class_map\", None)\n\n batch_logging_interval = _get_eval_batch_logging_interval()\n max_image_predictions = _get_max_image_predictions_to_log()\n\n for batch_idx, batch in enumerate(dataloader):\n if (batch_idx + 1) % batch_logging_interval != 0:\n continue\n\n image_paths = batch[\"im_file\"]\n for img_idx, image_path in enumerate(image_paths):\n if _comet_image_prediction_count >= max_image_predictions:\n return\n\n image_path = Path(image_path)\n annotations = _fetch_annotations(\n img_idx,\n image_path,\n batch,\n predictions_metadata_map,\n class_label_map,\n class_map=class_map,\n )\n _log_images(\n experiment,\n [image_path],\n curr_step,\n annotations=annotations,\n )\n _comet_image_prediction_count += 1",
"chunk_type": "function",
"name": "_log_image_predictions",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 379,
"end_line": 439,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": "Log predicted boxes for a single image during training.\n\nThis function logs image predictions to a Comet ML experiment during model validation. It processes\nvalidation data and formats both ground truth and prediction annotations for visualization in the Comet\ndashboard. The function respects configured limits on the number of images to log.\n\nArgs:\n experiment (comet_ml.Experiment): The Comet ML experiment to log to.\n validator (BaseValidator): The validator instance containing validation data and predictions.\n curr_step (int): The current training step for logging timeline.\n\nNotes:\n This function uses global state to track the number of logged predictions across calls.\n It only logs predictions for supported tasks defined in COMET_SUPPORTED_TASKS.\n The number of logged images is limited by the COMET_MAX_IMAGE_PREDICTIONS environment variable.",
"parameters": [
"experiment",
"validator",
"curr_step"
],
"return_type": "None",
"decorators": [],
"complexity_score": 7,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__log_image_predictions_0eafc348"
},
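The two throttles inside _log_image_predictions, a batch sampling interval and a hard cap on logged images, can be simulated without Comet to see which batches contribute; the numbers below are arbitrary:

batch_logging_interval = 2   # stands in for COMET_EVAL_BATCH_LOGGING_INTERVAL
max_image_predictions = 5    # stands in for COMET_MAX_IMAGE_PREDICTIONS
batch_size = 3
logged = 0

for batch_idx in range(6):                            # six validation batches
    if (batch_idx + 1) % batch_logging_interval != 0:
        continue                                      # only every 2nd batch is considered
    for img_idx in range(batch_size):
        if logged >= max_image_predictions:
            break                                     # cap reached, stop logging
        logged += 1

print(logged)  # 5: batches 1 and 3 contribute 3 + 2 images before the cap is hit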
{
"content": "def _log_plots(experiment, trainer) -> None:\n \"\"\"\n Log evaluation plots and label plots for the experiment.\n\n This function logs various evaluation plots and confusion matrices to the experiment tracking system. It handles\n different types of metrics (SegmentMetrics, PoseMetrics, DetMetrics, OBBMetrics) and logs the appropriate plots\n for each type.\n\n Args:\n experiment (comet_ml.Experiment): The Comet ML experiment to log plots to.\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing validation metrics and save\n directory information.\n\n Examples:\n >>> from ultralytics.utils.callbacks.comet import _log_plots\n >>> _log_plots(experiment, trainer)\n \"\"\"\n plot_filenames = None\n if isinstance(trainer.validator.metrics, SegmentMetrics):\n plot_filenames = [\n trainer.save_dir / f\"{prefix}{plots}.png\"\n for plots in EVALUATION_PLOT_NAMES\n for prefix in SEGMENT_METRICS_PLOT_PREFIX\n ]\n elif isinstance(trainer.validator.metrics, PoseMetrics):\n plot_filenames = [\n trainer.save_dir / f\"{prefix}{plots}.png\"\n for plots in EVALUATION_PLOT_NAMES\n for prefix in POSE_METRICS_PLOT_PREFIX\n ]\n elif isinstance(trainer.validator.metrics, (DetMetrics, OBBMetrics)):\n plot_filenames = [trainer.save_dir / f\"{plots}.png\" for plots in EVALUATION_PLOT_NAMES]\n\n if plot_filenames is not None:\n _log_images(experiment, plot_filenames, None)\n\n confusion_matrix_filenames = [trainer.save_dir / f\"{plots}.png\" for plots in CONFUSION_MATRIX_PLOT_NAMES]\n _log_images(experiment, confusion_matrix_filenames, None)\n\n if not isinstance(trainer.validator.metrics, ClassifyMetrics):\n label_plot_filenames = [trainer.save_dir / f\"{labels}.jpg\" for labels in LABEL_PLOT_NAMES]\n _log_images(experiment, label_plot_filenames, None)",
"chunk_type": "function",
"name": "_log_plots",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 442,
"end_line": 483,
"start_col": 0,
"end_col": 59,
"parent_name": null,
"docstring": "Log evaluation plots and label plots for the experiment.\n\nThis function logs various evaluation plots and confusion matrices to the experiment tracking system. It handles\ndifferent types of metrics (SegmentMetrics, PoseMetrics, DetMetrics, OBBMetrics) and logs the appropriate plots\nfor each type.\n\nArgs:\n experiment (comet_ml.Experiment): The Comet ML experiment to log plots to.\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing validation metrics and save\n directory information.\n\nExamples:\n >>> from ultralytics.utils.callbacks.comet import _log_plots\n >>> _log_plots(experiment, trainer)",
"parameters": [
"experiment",
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 11,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__log_plots_05fa3549"
},
{
"content": "def _log_model(experiment, trainer) -> None:\n \"\"\"Log the best-trained model to Comet.ml.\"\"\"\n model_name = _get_comet_model_name()\n experiment.log_model(model_name, file_or_folder=str(trainer.best), file_name=\"best.pt\", overwrite=True)",
"chunk_type": "function",
"name": "_log_model",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 486,
"end_line": 489,
"start_col": 0,
"end_col": 107,
"parent_name": null,
"docstring": "Log the best-trained model to Comet.ml.",
"parameters": [
"experiment",
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__log_model_f0ac61d6"
},
{
"content": "def _log_image_batches(experiment, trainer, curr_step: int) -> None:\n \"\"\"Log samples of image batches for train, validation, and test.\"\"\"\n _log_images(experiment, trainer.save_dir.glob(\"train_batch*.jpg\"), curr_step)\n _log_images(experiment, trainer.save_dir.glob(\"val_batch*.jpg\"), curr_step)",
"chunk_type": "function",
"name": "_log_image_batches",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 492,
"end_line": 495,
"start_col": 0,
"end_col": 79,
"parent_name": null,
"docstring": "Log samples of image batches for train, validation, and test.",
"parameters": [
"experiment",
"trainer",
"curr_step: int"
],
"return_type": "None",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__log_image_batches_061ac9f1"
},
{
"content": "def on_pretrain_routine_start(trainer) -> None:\n \"\"\"Create or resume a CometML experiment at the start of a YOLO pre-training routine.\"\"\"\n _resume_or_create_experiment(trainer.args)",
"chunk_type": "function",
"name": "on_pretrain_routine_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 498,
"end_line": 500,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": "Create or resume a CometML experiment at the start of a YOLO pre-training routine.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_pretrain_routine_start_3223cb38"
},
{
"content": "def on_train_epoch_end(trainer) -> None:\n \"\"\"Log metrics and save batch images at the end of training epochs.\"\"\"\n experiment = comet_ml.get_running_experiment()\n if not experiment:\n return\n\n metadata = _fetch_trainer_metadata(trainer)\n curr_epoch = metadata[\"curr_epoch\"]\n curr_step = metadata[\"curr_step\"]\n\n experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix=\"train\"), step=curr_step, epoch=curr_epoch)",
"chunk_type": "function",
"name": "on_train_epoch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 503,
"end_line": 513,
"start_col": 0,
"end_col": 117,
"parent_name": null,
"docstring": "Log metrics and save batch images at the end of training epochs.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_train_epoch_end_b5a2fb54"
},
{
"content": "def on_fit_epoch_end(trainer) -> None:\n \"\"\"\n Log model assets at the end of each epoch during training.\n\n This function is called at the end of each training epoch to log metrics, learning rates, and model information\n to a Comet ML experiment. It also logs model assets, confusion matrices, and image predictions based on\n configuration settings.\n\n The function retrieves the current Comet ML experiment and logs various training metrics. If it's the first epoch,\n it also logs model information. On specified save intervals, it logs the model, confusion matrix (if enabled),\n and image predictions (if enabled).\n\n Args:\n trainer (BaseTrainer): The YOLO trainer object containing training state, metrics, and configuration.\n\n Examples:\n >>> # Inside a training loop\n >>> on_fit_epoch_end(trainer) # Log metrics and assets to Comet ML\n \"\"\"\n experiment = comet_ml.get_running_experiment()\n if not experiment:\n return\n\n metadata = _fetch_trainer_metadata(trainer)\n curr_epoch = metadata[\"curr_epoch\"]\n curr_step = metadata[\"curr_step\"]\n save_assets = metadata[\"save_assets\"]\n\n experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)\n experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)\n if curr_epoch == 1:\n from ultralytics.utils.torch_utils import model_info_for_loggers\n\n experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)\n\n if not save_assets:\n return\n\n _log_model(experiment, trainer)\n if _should_log_confusion_matrix():\n _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)\n if _should_log_image_predictions():\n _log_image_predictions(experiment, trainer.validator, curr_step)",
"chunk_type": "function",
"name": "on_fit_epoch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 516,
"end_line": 558,
"start_col": 0,
"end_col": 72,
"parent_name": null,
"docstring": "Log model assets at the end of each epoch during training.\n\nThis function is called at the end of each training epoch to log metrics, learning rates, and model information\nto a Comet ML experiment. It also logs model assets, confusion matrices, and image predictions based on\nconfiguration settings.\n\nThe function retrieves the current Comet ML experiment and logs various training metrics. If it's the first epoch,\nit also logs model information. On specified save intervals, it logs the model, confusion matrix (if enabled),\nand image predictions (if enabled).\n\nArgs:\n trainer (BaseTrainer): The YOLO trainer object containing training state, metrics, and configuration.\n\nExamples:\n >>> # Inside a training loop\n >>> on_fit_epoch_end(trainer) # Log metrics and assets to Comet ML",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 6,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_fit_epoch_end_9ea290e4"
},
{
"content": "def on_train_end(trainer) -> None:\n \"\"\"Perform operations at the end of training.\"\"\"\n experiment = comet_ml.get_running_experiment()\n if not experiment:\n return\n\n metadata = _fetch_trainer_metadata(trainer)\n curr_epoch = metadata[\"curr_epoch\"]\n curr_step = metadata[\"curr_step\"]\n plots = trainer.args.plots\n\n _log_model(experiment, trainer)\n if plots:\n _log_plots(experiment, trainer)\n\n _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)\n _log_image_predictions(experiment, trainer.validator, curr_step)\n _log_image_batches(experiment, trainer, curr_step)\n experiment.end()\n\n global _comet_image_prediction_count\n _comet_image_prediction_count = 0",
"chunk_type": "function",
"name": "on_train_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 561,
"end_line": 582,
"start_col": 0,
"end_col": 37,
"parent_name": null,
"docstring": "Perform operations at the end of training.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"collections.abc.Callable",
"types.SimpleNamespace",
"typing.Any",
"typing.List",
"typing.Optional",
"cv2",
"numpy",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.ops",
"ultralytics.utils.metrics.ClassifyMetrics",
"ultralytics.utils.metrics.DetMetrics",
"ultralytics.utils.metrics.OBBMetrics",
"ultralytics.utils.metrics.PoseMetrics",
"ultralytics.utils.metrics.SegmentMetrics",
"comet_ml",
"os",
"pathlib.Path",
"faster_coco_eval.core.mask.decode",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_train_end_49d0684a"
},
{
"content": "callbacks = (\n {\n \"on_pretrain_routine_start\": on_pretrain_routine_start,\n \"on_train_epoch_end\": on_train_epoch_end,\n \"on_fit_epoch_end\": on_fit_epoch_end,\n \"on_train_end\": on_train_end,\n }\n if comet_ml\n else {}\n)",
"chunk_type": "variable",
"name": "callbacks",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\comet.py",
"start_line": 585,
"end_line": 594,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_callbacks_65d87ac9"
},
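The callbacks mapping above is only built when comet_ml imported successfully; an empty dict means the integration silently disables itself. A generic sketch of how a trainer loop might dispatch such a mapping (run_callbacks and DummyTrainer are illustrative stand-ins, not the Ultralytics engine API):

def run_callbacks(callbacks: dict, event: str, trainer) -> None:
    # Invoke the registered hook for this event, if any.
    hook = callbacks.get(event)
    if hook is not None:
        hook(trainer)

class DummyTrainer:
    args = None

cbs = {"on_train_end": lambda trainer: print("training finished")}
run_callbacks(cbs, "on_train_end", DummyTrainer())  # prints "training finished"
run_callbacks({}, "on_train_end", DummyTrainer())   # no-op when the integration is disabled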
{
"content": "from pathlib import Path",
"chunk_type": "import",
"name": "Path",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 24,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_Path_35b05885"
},
{
"content": "from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks",
"chunk_type": "import",
"name": "LOGGER, SETTINGS, TESTS_RUNNING, checks",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 69,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER, SETTINGS, TESTS_RUNNING, checks_e208c880"
},
{
"content": "def _log_images(path: Path, prefix: str = \"\") -> None:\n \"\"\"\n Log images at specified path with an optional prefix using DVCLive.\n\n This function logs images found at the given path to DVCLive, organizing them by batch to enable slider\n functionality in the UI. It processes image filenames to extract batch information and restructures the path\n accordingly.\n\n Args:\n path (Path): Path to the image file to be logged.\n prefix (str, optional): Optional prefix to add to the image name when logging.\n\n Examples:\n >>> from pathlib import Path\n >>> _log_images(Path(\"runs/train/exp/val_batch0_pred.jpg\"), prefix=\"validation\")\n \"\"\"\n if live:\n name = path.name\n\n # Group images by batch to enable sliders in UI\n if m := re.search(r\"_batch(\\d+)\", name):\n ni = m[1]\n new_stem = re.sub(r\"_batch(\\d+)\", \"_batch\", path.stem)\n name = (Path(new_stem) / ni).with_suffix(path.suffix)\n\n live.log_image(os.path.join(prefix, name), path)",
"chunk_type": "function",
"name": "_log_images",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py",
"start_line": 29,
"end_line": 54,
"start_col": 0,
"end_col": 56,
"parent_name": null,
"docstring": "Log images at specified path with an optional prefix using DVCLive.\n\nThis function logs images found at the given path to DVCLive, organizing them by batch to enable slider\nfunctionality in the UI. It processes image filenames to extract batch information and restructures the path\naccordingly.\n\nArgs:\n path (Path): Path to the image file to be logged.\n prefix (str, optional): Optional prefix to add to the image name when logging.\n\nExamples:\n >>> from pathlib import Path\n >>> _log_images(Path(\"runs/train/exp/val_batch0_pred.jpg\"), prefix=\"validation\")",
"parameters": [
"path: Path",
"prefix: str"
],
"return_type": "None",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"pathlib.Path",
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.checks",
"dvclive",
"os",
"re",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__log_images_76e8deb2"
},
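{
"content": "# Illustrative sketch (not repository source): the path rewrite done by _log_images above, standalone.\n# A plot named val_batch0_pred.jpg is regrouped as val_batch_pred/0.jpg so DVCLive's UI can show every\n# batch of the same plot under a single slider.\nimport re\nfrom pathlib import Path\n\npath = Path(\"runs/train/exp/val_batch0_pred.jpg\")\nname = path.name\nif m := re.search(r\"_batch(\\d+)\", name):\n    ni = m[1]  # batch index, e.g. \"0\"\n    new_stem = re.sub(r\"_batch(\\d+)\", \"_batch\", path.stem)  # \"val_batch_pred\"\n    name = (Path(new_stem) / ni).with_suffix(path.suffix)\nprint(name)  # val_batch_pred/0.jpg (separator is OS-dependent)",
"chunk_type": "example",
"name": "example_dvc_batch_grouping",
"file_path": null,
"start_line": null,
"end_line": null,
"start_col": null,
"end_col": null,
"parent_name": null,
"docstring": "Editor's illustrative sketch, not extracted from the repository: the batch-grouping path rewrite performed inside _log_images above, shown with only the standard library so the resulting DVCLive image name is easy to verify.",
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "example_dvc_batch_grouping_illustrative"
},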
{
"content": "def _log_plots(plots: dict, prefix: str = \"\") -> None:\n \"\"\"\n Log plot images for training progress if they have not been previously processed.\n\n Args:\n plots (dict): Dictionary containing plot information with timestamps.\n prefix (str, optional): Optional prefix to add to the logged image paths.\n \"\"\"\n for name, params in plots.items():\n timestamp = params[\"timestamp\"]\n if _processed_plots.get(name) != timestamp:\n _log_images(name, prefix)\n _processed_plots[name] = timestamp",
"chunk_type": "function",
"name": "_log_plots",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py",
"start_line": 57,
"end_line": 69,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": "Log plot images for training progress if they have not been previously processed.\n\nArgs:\n plots (dict): Dictionary containing plot information with timestamps.\n prefix (str, optional): Optional prefix to add to the logged image paths.",
"parameters": [
"plots: dict",
"prefix: str"
],
"return_type": "None",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"pathlib.Path",
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.checks",
"dvclive",
"os",
"re",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__log_plots_55416a93"
},
{
"content": "def _log_confusion_matrix(validator) -> None:\n \"\"\"\n Log confusion matrix for a validator using DVCLive.\n\n This function processes the confusion matrix from a validator object and logs it to DVCLive by converting\n the matrix into lists of target and prediction labels.\n\n Args:\n validator (BaseValidator): The validator object containing the confusion matrix and class names. Must have\n attributes: confusion_matrix.matrix, confusion_matrix.task, and names.\n \"\"\"\n targets = []\n preds = []\n matrix = validator.confusion_matrix.matrix\n names = list(validator.names.values())\n if validator.confusion_matrix.task == \"detect\":\n names += [\"background\"]\n\n for ti, pred in enumerate(matrix.T.astype(int)):\n for pi, num in enumerate(pred):\n targets.extend([names[ti]] * num)\n preds.extend([names[pi]] * num)\n\n live.log_sklearn_plot(\"confusion_matrix\", targets, preds, name=\"cf.json\", normalized=True)",
"chunk_type": "function",
"name": "_log_confusion_matrix",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py",
"start_line": 72,
"end_line": 95,
"start_col": 0,
"end_col": 94,
"parent_name": null,
"docstring": "Log confusion matrix for a validator using DVCLive.\n\nThis function processes the confusion matrix from a validator object and logs it to DVCLive by converting\nthe matrix into lists of target and prediction labels.\n\nArgs:\n validator (BaseValidator): The validator object containing the confusion matrix and class names. Must have\n attributes: confusion_matrix.matrix, confusion_matrix.task, and names.",
"parameters": [
"validator"
],
"return_type": "None",
"decorators": [],
"complexity_score": 4,
"dependencies": [
"pathlib.Path",
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.checks",
"dvclive",
"os",
"re",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__log_confusion_matrix_9a6afa08"
},
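{
"content": "# Illustrative sketch (not repository source): expanding a confusion matrix into flat label lists, as\n# _log_confusion_matrix above does before calling live.log_sklearn_plot. Each cell count becomes that many\n# (target, prediction) pairs; rows are treated as predictions and columns as true classes, matching the\n# matrix.T iteration in the callback.\nimport numpy as np\n\nnames = [\"car\", \"person\"]\nmatrix = np.array([[3, 1], [0, 2]])\n\ntargets, preds = [], []\nfor ti, pred_counts in enumerate(matrix.T.astype(int)):  # iterate over true classes\n    for pi, num in enumerate(pred_counts):\n        targets.extend([names[ti]] * num)\n        preds.extend([names[pi]] * num)\n\nprint(targets)  # ['car', 'car', 'car', 'person', 'person', 'person']\nprint(preds)  # ['car', 'car', 'car', 'car', 'person', 'person']",
"chunk_type": "example",
"name": "example_confusion_matrix_expansion",
"file_path": null,
"start_line": null,
"end_line": null,
"start_col": null,
"end_col": null,
"parent_name": null,
"docstring": "Editor's illustrative sketch, not extracted from the repository: the matrix-to-label expansion used by _log_confusion_matrix above, applied to a tiny hand-made 2x2 matrix so the flattened target and prediction lists can be checked by hand.",
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "example_confusion_matrix_expansion_illustrative"
},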
{
"content": "def on_pretrain_routine_start(trainer) -> None:\n \"\"\"Initialize DVCLive logger for training metadata during pre-training routine.\"\"\"\n try:\n global live\n live = dvclive.Live(save_dvc_exp=True, cache_images=True)\n LOGGER.info(\"DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).\")\n except Exception as e:\n LOGGER.warning(f\"DVCLive installed but not initialized correctly, not logging this run. {e}\")",
"chunk_type": "function",
"name": "on_pretrain_routine_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py",
"start_line": 98,
"end_line": 105,
"start_col": 0,
"end_col": 101,
"parent_name": null,
"docstring": "Initialize DVCLive logger for training metadata during pre-training routine.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"pathlib.Path",
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.checks",
"dvclive",
"os",
"re",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_pretrain_routine_start_78643cf6"
},
{
"content": "def on_pretrain_routine_end(trainer) -> None:\n \"\"\"Log plots related to the training process at the end of the pretraining routine.\"\"\"\n _log_plots(trainer.plots, \"train\")",
"chunk_type": "function",
"name": "on_pretrain_routine_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py",
"start_line": 108,
"end_line": 110,
"start_col": 0,
"end_col": 38,
"parent_name": null,
"docstring": "Log plots related to the training process at the end of the pretraining routine.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"pathlib.Path",
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.checks",
"dvclive",
"os",
"re",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_pretrain_routine_end_3e9575ea"
},
{
"content": "def on_train_start(trainer) -> None:\n \"\"\"Log the training parameters if DVCLive logging is active.\"\"\"\n if live:\n live.log_params(trainer.args)",
"chunk_type": "function",
"name": "on_train_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py",
"start_line": 113,
"end_line": 116,
"start_col": 0,
"end_col": 37,
"parent_name": null,
"docstring": "Log the training parameters if DVCLive logging is active.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"pathlib.Path",
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.checks",
"dvclive",
"os",
"re",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_train_start_af29dda1"
},
{
"content": "def on_train_epoch_start(trainer) -> None:\n \"\"\"Set the global variable _training_epoch value to True at the start of training each epoch.\"\"\"\n global _training_epoch\n _training_epoch = True",
"chunk_type": "function",
"name": "on_train_epoch_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py",
"start_line": 119,
"end_line": 122,
"start_col": 0,
"end_col": 26,
"parent_name": null,
"docstring": "Set the global variable _training_epoch value to True at the start of training each epoch.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"pathlib.Path",
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.checks",
"dvclive",
"os",
"re",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_train_epoch_start_02ab2515"
},
{
"content": "def on_fit_epoch_end(trainer) -> None:\n \"\"\"\n Log training metrics, model info, and advance to next step at the end of each fit epoch.\n\n This function is called at the end of each fit epoch during training. It logs various metrics including\n training loss items, validation metrics, and learning rates. On the first epoch, it also logs model\n information. Additionally, it logs training and validation plots and advances the DVCLive step counter.\n\n Args:\n trainer (BaseTrainer): The trainer object containing training state, metrics, and plots.\n\n Notes:\n This function only performs logging operations when DVCLive logging is active and during a training epoch.\n The global variable _training_epoch is used to track whether the current epoch is a training epoch.\n \"\"\"\n global _training_epoch\n if live and _training_epoch:\n all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix=\"train\"), **trainer.metrics, **trainer.lr}\n for metric, value in all_metrics.items():\n live.log_metric(metric, value)\n\n if trainer.epoch == 0:\n from ultralytics.utils.torch_utils import model_info_for_loggers\n\n for metric, value in model_info_for_loggers(trainer).items():\n live.log_metric(metric, value, plot=False)\n\n _log_plots(trainer.plots, \"train\")\n _log_plots(trainer.validator.plots, \"val\")\n\n live.next_step()\n _training_epoch = False",
"chunk_type": "function",
"name": "on_fit_epoch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py",
"start_line": 125,
"end_line": 156,
"start_col": 0,
"end_col": 31,
"parent_name": null,
"docstring": "Log training metrics, model info, and advance to next step at the end of each fit epoch.\n\nThis function is called at the end of each fit epoch during training. It logs various metrics including\ntraining loss items, validation metrics, and learning rates. On the first epoch, it also logs model\ninformation. Additionally, it logs training and validation plots and advances the DVCLive step counter.\n\nArgs:\n trainer (BaseTrainer): The trainer object containing training state, metrics, and plots.\n\nNotes:\n This function only performs logging operations when DVCLive logging is active and during a training epoch.\n The global variable _training_epoch is used to track whether the current epoch is a training epoch.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 5,
"dependencies": [
"pathlib.Path",
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.checks",
"dvclive",
"os",
"re",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_fit_epoch_end_88ab55a6"
},
{
"content": "def on_train_end(trainer) -> None:\n \"\"\"\n Log best metrics, plots, and confusion matrix at the end of training.\n\n This function is called at the conclusion of the training process to log final metrics, visualizations, and\n model artifacts if DVCLive logging is active. It captures the best model performance metrics, training plots,\n validation plots, and confusion matrix for later analysis.\n\n Args:\n trainer (BaseTrainer): The trainer object containing training state, metrics, and validation results.\n\n Examples:\n >>> # Inside a custom training loop\n >>> from ultralytics.utils.callbacks.dvc import on_train_end\n >>> on_train_end(trainer) # Log final metrics and artifacts\n \"\"\"\n if live:\n # At the end log the best metrics. It runs validator on the best model internally.\n all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix=\"train\"), **trainer.metrics, **trainer.lr}\n for metric, value in all_metrics.items():\n live.log_metric(metric, value, plot=False)\n\n _log_plots(trainer.plots, \"val\")\n _log_plots(trainer.validator.plots, \"val\")\n _log_confusion_matrix(trainer.validator)\n\n if trainer.best.exists():\n live.log_artifact(trainer.best, copy=True, type=\"model\")\n\n live.end()",
"chunk_type": "function",
"name": "on_train_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py",
"start_line": 159,
"end_line": 188,
"start_col": 0,
"end_col": 18,
"parent_name": null,
"docstring": "Log best metrics, plots, and confusion matrix at the end of training.\n\nThis function is called at the conclusion of the training process to log final metrics, visualizations, and\nmodel artifacts if DVCLive logging is active. It captures the best model performance metrics, training plots,\nvalidation plots, and confusion matrix for later analysis.\n\nArgs:\n trainer (BaseTrainer): The trainer object containing training state, metrics, and validation results.\n\nExamples:\n >>> # Inside a custom training loop\n >>> from ultralytics.utils.callbacks.dvc import on_train_end\n >>> on_train_end(trainer) # Log final metrics and artifacts",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 4,
"dependencies": [
"pathlib.Path",
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.checks",
"dvclive",
"os",
"re",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_train_end_10251516"
},
{
"content": "callbacks = (\n {\n \"on_pretrain_routine_start\": on_pretrain_routine_start,\n \"on_pretrain_routine_end\": on_pretrain_routine_end,\n \"on_train_start\": on_train_start,\n \"on_train_epoch_start\": on_train_epoch_start,\n \"on_fit_epoch_end\": on_fit_epoch_end,\n \"on_train_end\": on_train_end,\n }\n if dvclive\n else {}\n)",
"chunk_type": "variable",
"name": "callbacks",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\dvc.py",
"start_line": 191,
"end_line": 202,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_callbacks_063bbfb2"
},
{
"content": "import json",
"chunk_type": "import",
"name": "json",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 11,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_json_cbf8f6e7"
},
{
"content": "from time import time",
"chunk_type": "import",
"name": "time",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 21,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_time_661e5670"
},
{
"content": "from ultralytics.hub import HUB_WEB_ROOT, PREFIX, HUBTrainingSession, events",
"chunk_type": "import",
"name": "HUB_WEB_ROOT, PREFIX, HUBTrainingSession, events",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py",
"start_line": 6,
"end_line": 6,
"start_col": 0,
"end_col": 76,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_HUB_WEB_ROOT, PREFIX, HUBTrainingSession, events_fee7c9ca"
},
{
"content": "from ultralytics.utils import LOGGER, RANK, SETTINGS",
"chunk_type": "import",
"name": "LOGGER, RANK, SETTINGS",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py",
"start_line": 7,
"end_line": 7,
"start_col": 0,
"end_col": 52,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER, RANK, SETTINGS_f6dce9b9"
},
{
"content": "def on_pretrain_routine_start(trainer):\n \"\"\"Create a remote Ultralytics HUB session to log local model training.\"\"\"\n if RANK in {-1, 0} and SETTINGS[\"hub\"] is True and SETTINGS[\"api_key\"] and trainer.hub_session is None:\n trainer.hub_session = HUBTrainingSession.create_session(trainer.args.model, trainer.args)",
"chunk_type": "function",
"name": "on_pretrain_routine_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py",
"start_line": 10,
"end_line": 13,
"start_col": 0,
"end_col": 97,
"parent_name": null,
"docstring": "Create a remote Ultralytics HUB session to log local model training.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"json",
"time.time",
"ultralytics.hub.HUB_WEB_ROOT",
"ultralytics.hub.PREFIX",
"ultralytics.hub.HUBTrainingSession",
"ultralytics.hub.events",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_pretrain_routine_start_d36e6ce3"
},
{
"content": "def on_pretrain_routine_end(trainer):\n \"\"\"Initialize timers for upload rate limiting before training begins.\"\"\"\n if session := getattr(trainer, \"hub_session\", None):\n # Start timer for upload rate limit\n session.timers = {\"metrics\": time(), \"ckpt\": time()} # start timer for session rate limiting",
"chunk_type": "function",
"name": "on_pretrain_routine_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py",
"start_line": 16,
"end_line": 20,
"start_col": 0,
"end_col": 60,
"parent_name": null,
"docstring": "Initialize timers for upload rate limiting before training begins.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"json",
"time.time",
"ultralytics.hub.HUB_WEB_ROOT",
"ultralytics.hub.PREFIX",
"ultralytics.hub.HUBTrainingSession",
"ultralytics.hub.events",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_pretrain_routine_end_8df60b02"
},
{
"content": "def on_fit_epoch_end(trainer):\n \"\"\"Upload training progress metrics to Ultralytics HUB at the end of each epoch.\"\"\"\n if session := getattr(trainer, \"hub_session\", None):\n # Upload metrics after validation ends\n all_plots = {\n **trainer.label_loss_items(trainer.tloss, prefix=\"train\"),\n **trainer.metrics,\n }\n if trainer.epoch == 0:\n from ultralytics.utils.torch_utils import model_info_for_loggers\n\n all_plots = {**all_plots, **model_info_for_loggers(trainer)}\n\n session.metrics_queue[trainer.epoch] = json.dumps(all_plots)\n\n # If any metrics failed to upload previously, add them to the queue to attempt uploading again\n if session.metrics_upload_failed_queue:\n session.metrics_queue.update(session.metrics_upload_failed_queue)\n\n if time() - session.timers[\"metrics\"] > session.rate_limits[\"metrics\"]:\n session.upload_metrics()\n session.timers[\"metrics\"] = time() # reset timer\n session.metrics_queue = {} # reset queue",
"chunk_type": "function",
"name": "on_fit_epoch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py",
"start_line": 23,
"end_line": 45,
"start_col": 0,
"end_col": 38,
"parent_name": null,
"docstring": "Upload training progress metrics to Ultralytics HUB at the end of each epoch.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 5,
"dependencies": [
"json",
"time.time",
"ultralytics.hub.HUB_WEB_ROOT",
"ultralytics.hub.PREFIX",
"ultralytics.hub.HUBTrainingSession",
"ultralytics.hub.events",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_fit_epoch_end_03530e00"
},
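{
"content": "# Illustrative sketch (not repository source): the timer-based rate limiting used by the HUB callbacks above,\n# reduced to plain Python. Payloads accumulate in a queue and an upload is only attempted once the configured\n# interval has elapsed, after which the timer is reset. Names and the interval are placeholders.\nfrom time import sleep, time\n\nrate_limit_s = 2.0  # minimum seconds between uploads\ntimer = time()  # reset after each upload\nqueue = {}  # epoch -> metrics payload\n\n\ndef maybe_upload(epoch, payload):\n    \"\"\"Queue the payload and upload only when the rate limit allows it.\"\"\"\n    global timer\n    queue[epoch] = payload\n    if time() - timer > rate_limit_s:\n        print(f\"uploading {len(queue)} queued item(s)\")  # stand-in for session.upload_metrics()\n        queue.clear()\n        timer = time()\n\n\nfor epoch in range(5):\n    maybe_upload(epoch, {\"loss\": 1.0 / (epoch + 1)})\n    sleep(1.0)",
"chunk_type": "example",
"name": "example_hub_rate_limiting",
"file_path": null,
"start_line": null,
"end_line": null,
"start_col": null,
"end_col": null,
"parent_name": null,
"docstring": "Editor's illustrative sketch, not extracted from the repository: the timer-based rate limiting pattern used by the HUB on_fit_epoch_end and on_model_save callbacks above, reduced to plain Python. Variable names and the interval are placeholders.",
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "example_hub_rate_limiting_illustrative"
},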
{
"content": "def on_model_save(trainer):\n \"\"\"Upload model checkpoints to Ultralytics HUB with rate limiting.\"\"\"\n if session := getattr(trainer, \"hub_session\", None):\n # Upload checkpoints with rate limiting\n is_best = trainer.best_fitness == trainer.fitness\n if time() - session.timers[\"ckpt\"] > session.rate_limits[\"ckpt\"]:\n LOGGER.info(f\"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model.id}\")\n session.upload_model(trainer.epoch, trainer.last, is_best)\n session.timers[\"ckpt\"] = time() # reset timer",
"chunk_type": "function",
"name": "on_model_save",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py",
"start_line": 48,
"end_line": 56,
"start_col": 0,
"end_col": 43,
"parent_name": null,
"docstring": "Upload model checkpoints to Ultralytics HUB with rate limiting.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 3,
"dependencies": [
"json",
"time.time",
"ultralytics.hub.HUB_WEB_ROOT",
"ultralytics.hub.PREFIX",
"ultralytics.hub.HUBTrainingSession",
"ultralytics.hub.events",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_model_save_76486db7"
},
{
"content": "def on_train_end(trainer):\n \"\"\"Upload final model and metrics to Ultralytics HUB at the end of training.\"\"\"\n if session := getattr(trainer, \"hub_session\", None):\n # Upload final model and metrics with exponential standoff\n LOGGER.info(f\"{PREFIX}Syncing final model...\")\n session.upload_model(\n trainer.epoch,\n trainer.best,\n map=trainer.metrics.get(\"metrics/mAP50-95(B)\", 0),\n final=True,\n )\n session.alive = False # stop heartbeats\n LOGGER.info(f\"{PREFIX}Done ✅\\n{PREFIX}View model at {session.model_url} 🚀\")",
"chunk_type": "function",
"name": "on_train_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py",
"start_line": 59,
"end_line": 71,
"start_col": 0,
"end_col": 88,
"parent_name": null,
"docstring": "Upload final model and metrics to Ultralytics HUB at the end of training.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"json",
"time.time",
"ultralytics.hub.HUB_WEB_ROOT",
"ultralytics.hub.PREFIX",
"ultralytics.hub.HUBTrainingSession",
"ultralytics.hub.events",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_train_end_82e21297"
},
{
"content": "def on_train_start(trainer):\n \"\"\"Run events on train start.\"\"\"\n events(trainer.args, trainer.device)",
"chunk_type": "function",
"name": "on_train_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py",
"start_line": 74,
"end_line": 76,
"start_col": 0,
"end_col": 40,
"parent_name": null,
"docstring": "Run events on train start.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"json",
"time.time",
"ultralytics.hub.HUB_WEB_ROOT",
"ultralytics.hub.PREFIX",
"ultralytics.hub.HUBTrainingSession",
"ultralytics.hub.events",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_train_start_91192112"
},
{
"content": "def on_val_start(validator):\n \"\"\"Run events on validation start.\"\"\"\n if not validator.training:\n events(validator.args, validator.device)",
"chunk_type": "function",
"name": "on_val_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py",
"start_line": 79,
"end_line": 82,
"start_col": 0,
"end_col": 48,
"parent_name": null,
"docstring": "Run events on validation start.",
"parameters": [
"validator"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"json",
"time.time",
"ultralytics.hub.HUB_WEB_ROOT",
"ultralytics.hub.PREFIX",
"ultralytics.hub.HUBTrainingSession",
"ultralytics.hub.events",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_val_start_16876252"
},
{
"content": "def on_predict_start(predictor):\n \"\"\"Run events on predict start.\"\"\"\n events(predictor.args, predictor.device)",
"chunk_type": "function",
"name": "on_predict_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py",
"start_line": 85,
"end_line": 87,
"start_col": 0,
"end_col": 44,
"parent_name": null,
"docstring": "Run events on predict start.",
"parameters": [
"predictor"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"json",
"time.time",
"ultralytics.hub.HUB_WEB_ROOT",
"ultralytics.hub.PREFIX",
"ultralytics.hub.HUBTrainingSession",
"ultralytics.hub.events",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_predict_start_6613bf11"
},
{
"content": "def on_export_start(exporter):\n \"\"\"Run events on export start.\"\"\"\n events(exporter.args, exporter.device)",
"chunk_type": "function",
"name": "on_export_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py",
"start_line": 90,
"end_line": 92,
"start_col": 0,
"end_col": 42,
"parent_name": null,
"docstring": "Run events on export start.",
"parameters": [
"exporter"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"json",
"time.time",
"ultralytics.hub.HUB_WEB_ROOT",
"ultralytics.hub.PREFIX",
"ultralytics.hub.HUBTrainingSession",
"ultralytics.hub.events",
"ultralytics.utils.LOGGER",
"ultralytics.utils.RANK",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_export_start_c3b5094d"
},
{
"content": "callbacks = (\n {\n \"on_pretrain_routine_start\": on_pretrain_routine_start,\n \"on_pretrain_routine_end\": on_pretrain_routine_end,\n \"on_fit_epoch_end\": on_fit_epoch_end,\n \"on_model_save\": on_model_save,\n \"on_train_end\": on_train_end,\n \"on_train_start\": on_train_start,\n \"on_val_start\": on_val_start,\n \"on_predict_start\": on_predict_start,\n \"on_export_start\": on_export_start,\n }\n if SETTINGS[\"hub\"] is True\n else {}\n) # verify hub is enabled before registering callbacks",
"chunk_type": "variable",
"name": "callbacks",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\hub.py",
"start_line": 95,
"end_line": 109,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_callbacks_88371912"
},
{
"content": "from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr",
"chunk_type": "import",
"name": "LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\mlflow.py",
"start_line": 24,
"end_line": 24,
"start_col": 0,
"end_col": 81,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr_d1eb9590"
},
{
"content": "def sanitize_dict(x: dict) -> dict:\n \"\"\"Sanitize dictionary keys by removing parentheses and converting values to floats.\"\"\"\n return {k.replace(\"(\", \"\").replace(\")\", \"\"): float(v) for k, v in x.items()}",
"chunk_type": "function",
"name": "sanitize_dict",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\mlflow.py",
"start_line": 42,
"end_line": 44,
"start_col": 0,
"end_col": 80,
"parent_name": null,
"docstring": "Sanitize dictionary keys by removing parentheses and converting values to floats.",
"parameters": [
"x: dict"
],
"return_type": "dict",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.RUNS_DIR",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.colorstr",
"os",
"mlflow",
"pathlib.Path"
],
"chunk_id": "function_sanitize_dict_3cc1ce93"
},
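{
"content": "# Illustrative sketch (not repository source): effect of sanitize_dict above on typical metric keys.\n# Parentheses such as the \"(B)\" suffix are stripped and values are cast to float before logging, since\n# parentheses are not accepted in MLflow metric names.\nmetrics = {\"metrics/mAP50(B)\": 0.52, \"metrics/precision(B)\": \"0.61\"}\nsanitized = {k.replace(\"(\", \"\").replace(\")\", \"\"): float(v) for k, v in metrics.items()}\nprint(sanitized)  # {'metrics/mAP50B': 0.52, 'metrics/precisionB': 0.61}",
"chunk_type": "example",
"name": "example_sanitize_dict",
"file_path": null,
"start_line": null,
"end_line": null,
"start_col": null,
"end_col": null,
"parent_name": null,
"docstring": "Editor's illustrative sketch, not extracted from the repository: what sanitize_dict above produces for typical Ultralytics metric keys, whose parentheses are not accepted in MLflow metric names. The metric values shown are placeholders.",
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "example_sanitize_dict_illustrative"
},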
{
"content": "def on_pretrain_routine_end(trainer):\n \"\"\"\n Log training parameters to MLflow at the end of the pretraining routine.\n\n This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI,\n experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters\n from the trainer.\n\n Args:\n trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.\n\n Environment Variables:\n MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.\n MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.\n MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.\n MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after training ends.\n \"\"\"\n global mlflow\n\n uri = os.environ.get(\"MLFLOW_TRACKING_URI\") or str(RUNS_DIR / \"mlflow\")\n LOGGER.debug(f\"{PREFIX} tracking uri: {uri}\")\n mlflow.set_tracking_uri(uri)\n\n # Set experiment and run names\n experiment_name = os.environ.get(\"MLFLOW_EXPERIMENT_NAME\") or trainer.args.project or \"/Shared/Ultralytics\"\n run_name = os.environ.get(\"MLFLOW_RUN\") or trainer.args.name\n mlflow.set_experiment(experiment_name)\n\n mlflow.autolog()\n try:\n active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)\n LOGGER.info(f\"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}\")\n if Path(uri).is_dir():\n LOGGER.info(f\"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'\")\n LOGGER.info(f\"{PREFIX}disable with 'yolo settings mlflow=False'\")\n mlflow.log_params(dict(trainer.args))\n except Exception as e:\n LOGGER.warning(f\"{PREFIX}Failed to initialize: {e}\")\n LOGGER.warning(f\"{PREFIX}Not tracking this run\")",
"chunk_type": "function",
"name": "on_pretrain_routine_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\mlflow.py",
"start_line": 47,
"end_line": 85,
"start_col": 0,
"end_col": 56,
"parent_name": null,
"docstring": "Log training parameters to MLflow at the end of the pretraining routine.\n\nThis function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI,\nexperiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters\nfrom the trainer.\n\nArgs:\n trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.\n\nEnvironment Variables:\n MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.\n MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.\n MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.\n MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after training ends.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 3,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.RUNS_DIR",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.colorstr",
"os",
"mlflow",
"pathlib.Path"
],
"chunk_id": "function_on_pretrain_routine_end_aabdf4e2"
},
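{
"content": "# Illustrative sketch (not repository source): driving on_pretrain_routine_end above via the environment\n# variables it reads. Assumes mlflow is installed and the mlflow setting is enabled; all values below are\n# placeholders.\nimport os\n\nos.environ[\"MLFLOW_TRACKING_URI\"] = \"http://127.0.0.1:5000\"  # default would be RUNS_DIR/mlflow\nos.environ[\"MLFLOW_EXPERIMENT_NAME\"] = \"yolo-experiments\"  # default would be trainer.args.project\nos.environ[\"MLFLOW_RUN\"] = \"baseline-run\"  # default would be trainer.args.name\nos.environ[\"MLFLOW_KEEP_RUN_ACTIVE\"] = \"False\"  # let on_train_end close the run\n\nfrom ultralytics import YOLO\n\nYOLO(\"yolo11n.pt\").train(data=\"coco8.yaml\", epochs=1, imgsz=320)",
"chunk_type": "example",
"name": "example_mlflow_env_config",
"file_path": null,
"start_line": null,
"end_line": null,
"start_col": null,
"end_col": null,
"parent_name": null,
"docstring": "Editor's illustrative sketch, not extracted from the repository: configuring the MLflow callback above through the environment variables its docstring lists, before launching a training run. The URI, experiment, run, model and dataset names are placeholders.",
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "example_mlflow_env_config_illustrative"
},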
{
"content": "def on_train_epoch_end(trainer):\n \"\"\"Log training metrics at the end of each train epoch to MLflow.\"\"\"\n if mlflow:\n mlflow.log_metrics(\n metrics={\n **sanitize_dict(trainer.lr),\n **sanitize_dict(trainer.label_loss_items(trainer.tloss, prefix=\"train\")),\n },\n step=trainer.epoch,\n )",
"chunk_type": "function",
"name": "on_train_epoch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\mlflow.py",
"start_line": 88,
"end_line": 97,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "Log training metrics at the end of each train epoch to MLflow.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.RUNS_DIR",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.colorstr",
"os",
"mlflow",
"pathlib.Path"
],
"chunk_id": "function_on_train_epoch_end_96d2c478"
},
{
"content": "def on_fit_epoch_end(trainer):\n \"\"\"Log training metrics at the end of each fit epoch to MLflow.\"\"\"\n if mlflow:\n mlflow.log_metrics(metrics=sanitize_dict(trainer.metrics), step=trainer.epoch)",
"chunk_type": "function",
"name": "on_fit_epoch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\mlflow.py",
"start_line": 100,
"end_line": 103,
"start_col": 0,
"end_col": 86,
"parent_name": null,
"docstring": "Log training metrics at the end of each fit epoch to MLflow.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.RUNS_DIR",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.colorstr",
"os",
"mlflow",
"pathlib.Path"
],
"chunk_id": "function_on_fit_epoch_end_c86aae8a"
},
{
"content": "def on_train_end(trainer):\n \"\"\"Log model artifacts at the end of training.\"\"\"\n if not mlflow:\n return\n mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt\n for f in trainer.save_dir.glob(\"*\"): # log all other files in save_dir\n if f.suffix in {\".png\", \".jpg\", \".csv\", \".pt\", \".yaml\"}:\n mlflow.log_artifact(str(f))\n keep_run_active = os.environ.get(\"MLFLOW_KEEP_RUN_ACTIVE\", \"False\").lower() == \"true\"\n if keep_run_active:\n LOGGER.info(f\"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()\")\n else:\n mlflow.end_run()\n LOGGER.debug(f\"{PREFIX}mlflow run ended\")\n\n LOGGER.info(\n f\"{PREFIX}results logged to {mlflow.get_tracking_uri()}\\n{PREFIX}disable with 'yolo settings mlflow=False'\"\n )",
"chunk_type": "function",
"name": "on_train_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\mlflow.py",
"start_line": 106,
"end_line": 123,
"start_col": 0,
"end_col": 5,
"parent_name": null,
"docstring": "Log model artifacts at the end of training.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 5,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.RUNS_DIR",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.colorstr",
"os",
"mlflow",
"pathlib.Path"
],
"chunk_id": "function_on_train_end_d16b581b"
},
{
"content": "callbacks = (\n {\n \"on_pretrain_routine_end\": on_pretrain_routine_end,\n \"on_train_epoch_end\": on_train_epoch_end,\n \"on_fit_epoch_end\": on_fit_epoch_end,\n \"on_train_end\": on_train_end,\n }\n if mlflow\n else {}\n)",
"chunk_type": "variable",
"name": "callbacks",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\mlflow.py",
"start_line": 126,
"end_line": 135,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_callbacks_dd07e529"
},
{
"content": "from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING",
"chunk_type": "import",
"name": "LOGGER, SETTINGS, TESTS_RUNNING",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 61,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER, SETTINGS, TESTS_RUNNING_e1381ec7"
},
{
"content": "def _log_scalars(scalars: dict, step: int = 0) -> None:\n \"\"\"\n Log scalars to the NeptuneAI experiment logger.\n\n Args:\n scalars (dict): Dictionary of scalar values to log to NeptuneAI.\n step (int, optional): The current step or iteration number for logging.\n\n Examples:\n >>> metrics = {\"mAP\": 0.85, \"loss\": 0.32}\n >>> _log_scalars(metrics, step=100)\n \"\"\"\n if run:\n for k, v in scalars.items():\n run[k].append(value=v, step=step)",
"chunk_type": "function",
"name": "_log_scalars",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py",
"start_line": 20,
"end_line": 34,
"start_col": 0,
"end_col": 45,
"parent_name": null,
"docstring": "Log scalars to the NeptuneAI experiment logger.\n\nArgs:\n scalars (dict): Dictionary of scalar values to log to NeptuneAI.\n step (int, optional): The current step or iteration number for logging.\n\nExamples:\n >>> metrics = {\"mAP\": 0.85, \"loss\": 0.32}\n >>> _log_scalars(metrics, step=100)",
"parameters": [
"scalars: dict",
"step: int"
],
"return_type": "None",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"neptune",
"neptune.types.File",
"matplotlib.image",
"matplotlib.pyplot",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__log_scalars_775f6960"
},
{
"content": "def _log_images(imgs_dict: dict, group: str = \"\") -> None:\n \"\"\"\n Log images to the NeptuneAI experiment logger.\n\n This function logs image data to Neptune.ai when a valid Neptune run is active. Images are organized\n under the specified group name.\n\n Args:\n imgs_dict (dict): Dictionary of images to log, with keys as image names and values as image data.\n group (str, optional): Group name to organize images under in the Neptune UI.\n\n Examples:\n >>> # Log validation images\n >>> _log_images({\"val_batch\": img_tensor}, group=\"validation\")\n \"\"\"\n if run:\n for k, v in imgs_dict.items():\n run[f\"{group}/{k}\"].upload(File(v))",
"chunk_type": "function",
"name": "_log_images",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py",
"start_line": 37,
"end_line": 54,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": "Log images to the NeptuneAI experiment logger.\n\nThis function logs image data to Neptune.ai when a valid Neptune run is active. Images are organized\nunder the specified group name.\n\nArgs:\n imgs_dict (dict): Dictionary of images to log, with keys as image names and values as image data.\n group (str, optional): Group name to organize images under in the Neptune UI.\n\nExamples:\n >>> # Log validation images\n >>> _log_images({\"val_batch\": img_tensor}, group=\"validation\")",
"parameters": [
"imgs_dict: dict",
"group: str"
],
"return_type": "None",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"neptune",
"neptune.types.File",
"matplotlib.image",
"matplotlib.pyplot",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__log_images_a7b4daf0"
},
{
"content": "def _log_plot(title: str, plot_path: str) -> None:\n \"\"\"Log plots to the NeptuneAI experiment logger.\"\"\"\n import matplotlib.image as mpimg\n import matplotlib.pyplot as plt\n\n img = mpimg.imread(plot_path)\n fig = plt.figure()\n ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect=\"auto\", xticks=[], yticks=[]) # no ticks\n ax.imshow(img)\n run[f\"Plots/{title}\"].upload(fig)",
"chunk_type": "function",
"name": "_log_plot",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py",
"start_line": 57,
"end_line": 66,
"start_col": 0,
"end_col": 37,
"parent_name": null,
"docstring": "Log plots to the NeptuneAI experiment logger.",
"parameters": [
"title: str",
"plot_path: str"
],
"return_type": "None",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"neptune",
"neptune.types.File",
"matplotlib.image",
"matplotlib.pyplot",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function__log_plot_80d7545a"
},
{
"content": "def on_pretrain_routine_start(trainer) -> None:\n \"\"\"Initialize NeptuneAI run and log hyperparameters before training starts.\"\"\"\n try:\n global run\n run = neptune.init_run(\n project=trainer.args.project or \"Ultralytics\",\n name=trainer.args.name,\n tags=[\"Ultralytics\"],\n )\n run[\"Configuration/Hyperparameters\"] = {k: \"\" if v is None else v for k, v in vars(trainer.args).items()}\n except Exception as e:\n LOGGER.warning(f\"NeptuneAI installed but not initialized correctly, not logging this run. {e}\")",
"chunk_type": "function",
"name": "on_pretrain_routine_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py",
"start_line": 69,
"end_line": 80,
"start_col": 0,
"end_col": 103,
"parent_name": null,
"docstring": "Initialize NeptuneAI run and log hyperparameters before training starts.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"neptune",
"neptune.types.File",
"matplotlib.image",
"matplotlib.pyplot",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_pretrain_routine_start_3ee0c2b7"
},
{
"content": "def on_train_epoch_end(trainer) -> None:\n \"\"\"Log training metrics and learning rate at the end of each training epoch.\"\"\"\n _log_scalars(trainer.label_loss_items(trainer.tloss, prefix=\"train\"), trainer.epoch + 1)\n _log_scalars(trainer.lr, trainer.epoch + 1)\n if trainer.epoch == 1:\n _log_images({f.stem: str(f) for f in trainer.save_dir.glob(\"train_batch*.jpg\")}, \"Mosaic\")",
"chunk_type": "function",
"name": "on_train_epoch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py",
"start_line": 83,
"end_line": 88,
"start_col": 0,
"end_col": 98,
"parent_name": null,
"docstring": "Log training metrics and learning rate at the end of each training epoch.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"neptune",
"neptune.types.File",
"matplotlib.image",
"matplotlib.pyplot",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_train_epoch_end_6ce6fa22"
},
{
"content": "def on_fit_epoch_end(trainer) -> None:\n \"\"\"Log model info and validation metrics at the end of each fit epoch.\"\"\"\n if run and trainer.epoch == 0:\n from ultralytics.utils.torch_utils import model_info_for_loggers\n\n run[\"Configuration/Model\"] = model_info_for_loggers(trainer)\n _log_scalars(trainer.metrics, trainer.epoch + 1)",
"chunk_type": "function",
"name": "on_fit_epoch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py",
"start_line": 91,
"end_line": 97,
"start_col": 0,
"end_col": 52,
"parent_name": null,
"docstring": "Log model info and validation metrics at the end of each fit epoch.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"neptune",
"neptune.types.File",
"matplotlib.image",
"matplotlib.pyplot",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_fit_epoch_end_366f762e"
},
{
"content": "def on_val_end(validator) -> None:\n \"\"\"Log validation images at the end of validation.\"\"\"\n if run:\n # Log val_labels and val_pred\n _log_images({f.stem: str(f) for f in validator.save_dir.glob(\"val*.jpg\")}, \"Validation\")",
"chunk_type": "function",
"name": "on_val_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py",
"start_line": 100,
"end_line": 104,
"start_col": 0,
"end_col": 96,
"parent_name": null,
"docstring": "Log validation images at the end of validation.",
"parameters": [
"validator"
],
"return_type": "None",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"neptune",
"neptune.types.File",
"matplotlib.image",
"matplotlib.pyplot",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_val_end_80ca1e49"
},
{
"content": "def on_train_end(trainer) -> None:\n \"\"\"Log final results, plots, and model weights at the end of training.\"\"\"\n if run:\n # Log final results, CM matrix + PR plots\n files = [\n \"results.png\",\n \"confusion_matrix.png\",\n \"confusion_matrix_normalized.png\",\n *(f\"{x}_curve.png\" for x in (\"F1\", \"PR\", \"P\", \"R\")),\n ]\n files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter\n for f in files:\n _log_plot(title=f.stem, plot_path=f)\n # Log the final model\n run[f\"weights/{trainer.args.name or trainer.args.task}/{trainer.best.name}\"].upload(File(str(trainer.best)))",
"chunk_type": "function",
"name": "on_train_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py",
"start_line": 107,
"end_line": 121,
"start_col": 0,
"end_col": 116,
"parent_name": null,
"docstring": "Log final results, plots, and model weights at the end of training.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 5,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"neptune",
"neptune.types.File",
"matplotlib.image",
"matplotlib.pyplot",
"ultralytics.utils.torch_utils.model_info_for_loggers"
],
"chunk_id": "function_on_train_end_4d68251b"
},
{
"content": "callbacks = (\n {\n \"on_pretrain_routine_start\": on_pretrain_routine_start,\n \"on_train_epoch_end\": on_train_epoch_end,\n \"on_fit_epoch_end\": on_fit_epoch_end,\n \"on_val_end\": on_val_end,\n \"on_train_end\": on_train_end,\n }\n if neptune\n else {}\n)",
"chunk_type": "variable",
"name": "callbacks",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\neptune.py",
"start_line": 124,
"end_line": 134,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_callbacks_e915f40c"
},
{
"content": "from ultralytics.utils import SETTINGS",
"chunk_type": "import",
"name": "SETTINGS",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\raytune.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 38,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SETTINGS_4a5b48e5"
},
{
"content": "def on_fit_epoch_end(trainer):\n \"\"\"\n Report training metrics to Ray Tune at epoch end when a Ray session is active.\n\n Captures metrics from the trainer object and sends them to Ray Tune with the current epoch number,\n enabling hyperparameter tuning optimization. Only executes when within an active Ray Tune session.\n\n Args:\n trainer (ultralytics.engine.trainer.BaseTrainer): The Ultralytics trainer object containing metrics and epochs.\n\n Examples:\n >>> # Called automatically by the Ultralytics training loop\n >>> on_fit_epoch_end(trainer)\n\n References:\n Ray Tune docs: https://docs.ray.io/en/latest/tune/index.html\n \"\"\"\n if ray.train._internal.session.get_session(): # check if Ray Tune session is active\n metrics = trainer.metrics\n session.report({**metrics, **{\"epoch\": trainer.epoch + 1}})",
"chunk_type": "function",
"name": "on_fit_epoch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\raytune.py",
"start_line": 15,
"end_line": 34,
"start_col": 0,
"end_col": 67,
"parent_name": null,
"docstring": "Report training metrics to Ray Tune at epoch end when a Ray session is active.\n\nCaptures metrics from the trainer object and sends them to Ray Tune with the current epoch number,\nenabling hyperparameter tuning optimization. Only executes when within an active Ray Tune session.\n\nArgs:\n trainer (ultralytics.engine.trainer.BaseTrainer): The Ultralytics trainer object containing metrics and epochs.\n\nExamples:\n >>> # Called automatically by the Ultralytics training loop\n >>> on_fit_epoch_end(trainer)\n\nReferences:\n Ray Tune docs: https://docs.ray.io/en/latest/tune/index.html",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"ultralytics.utils.SETTINGS",
"ray",
"ray.tune",
"ray.air.session"
],
"chunk_id": "function_on_fit_epoch_end_2399644c"
},
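{
"content": "# Illustrative sketch (not repository source): on_fit_epoch_end above only reports metrics when a Ray Tune\n# session is active; the Ray-backed tuner in Ultralytics starts such sessions. Requires the 'ray[tune]' extra.\n# Model and dataset names are placeholders; extra keyword arguments are assumed to be forwarded to training.\nfrom ultralytics import YOLO\n\nmodel = YOLO(\"yolo11n.pt\")\nmodel.tune(data=\"coco8.yaml\", use_ray=True)",
"chunk_type": "example",
"name": "example_ray_tune_session",
"file_path": null,
"start_line": null,
"end_line": null,
"start_col": null,
"end_col": null,
"parent_name": null,
"docstring": "Editor's illustrative sketch, not extracted from the repository: the Ray Tune callback above only reports metrics inside an active Ray session, which the Ray-backed tuner starts. Model and dataset names are placeholders, and forwarding of extra keyword arguments to training is an assumption.",
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "example_ray_tune_session_illustrative"
},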
{
"content": "callbacks = (\n {\n \"on_fit_epoch_end\": on_fit_epoch_end,\n }\n if tune\n else {}\n)",
"chunk_type": "variable",
"name": "callbacks",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\raytune.py",
"start_line": 37,
"end_line": 43,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_callbacks_ceb3edc8"
},
{
"content": "from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr, torch_utils",
"chunk_type": "import",
"name": "LOGGER, SETTINGS, TESTS_RUNNING, colorstr, torch_utils",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\tensorboard.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 84,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_LOGGER, SETTINGS, TESTS_RUNNING, colorstr, torch_utils_2bd75001"
},
{
"content": "def _log_scalars(scalars: dict, step: int = 0) -> None:\n \"\"\"\n Log scalar values to TensorBoard.\n\n Args:\n scalars (dict): Dictionary of scalar values to log to TensorBoard. Keys are scalar names and values are the\n corresponding scalar values.\n step (int): Global step value to record with the scalar values. Used for x-axis in TensorBoard graphs.\n\n Examples:\n Log training metrics\n >>> metrics = {\"loss\": 0.5, \"accuracy\": 0.95}\n >>> _log_scalars(metrics, step=100)\n \"\"\"\n if WRITER:\n for k, v in scalars.items():\n WRITER.add_scalar(k, v, step)",
"chunk_type": "function",
"name": "_log_scalars",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\tensorboard.py",
"start_line": 24,
"end_line": 40,
"start_col": 0,
"end_col": 41,
"parent_name": null,
"docstring": "Log scalar values to TensorBoard.\n\nArgs:\n scalars (dict): Dictionary of scalar values to log to TensorBoard. Keys are scalar names and values are the\n corresponding scalar values.\n step (int): Global step value to record with the scalar values. Used for x-axis in TensorBoard graphs.\n\nExamples:\n Log training metrics\n >>> metrics = {\"loss\": 0.5, \"accuracy\": 0.95}\n >>> _log_scalars(metrics, step=100)",
"parameters": [
"scalars: dict",
"step: int"
],
"return_type": "None",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.colorstr",
"ultralytics.utils.torch_utils",
"warnings",
"copy.deepcopy",
"torch",
"torch.utils.tensorboard.SummaryWriter"
],
"chunk_id": "function__log_scalars_1970cfe8"
},
{
"content": "def _log_tensorboard_graph(trainer) -> None:\n \"\"\"\n Log model graph to TensorBoard.\n\n This function attempts to visualize the model architecture in TensorBoard by tracing the model with a dummy input\n tensor. It first tries a simple method suitable for YOLO models, and if that fails, falls back to a more complex\n approach for models like RTDETR that may require special handling.\n\n Args:\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing the model to visualize.\n Must have attributes model and args with imgsz.\n\n Notes:\n This function requires TensorBoard integration to be enabled and the global WRITER to be initialized.\n It handles potential warnings from the PyTorch JIT tracer and attempts to gracefully handle different\n model architectures.\n \"\"\"\n # Input image\n imgsz = trainer.args.imgsz\n imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz\n p = next(trainer.model.parameters()) # for device, type\n im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)\n\n with warnings.catch_warnings():\n warnings.simplefilter(\"ignore\", category=UserWarning) # suppress jit trace warning\n warnings.simplefilter(\"ignore\", category=torch.jit.TracerWarning) # suppress jit trace warning\n\n # Try simple method first (YOLO)\n try:\n trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes\n WRITER.add_graph(torch.jit.trace(torch_utils.de_parallel(trainer.model), im, strict=False), [])\n LOGGER.info(f\"{PREFIX}model graph visualization added ✅\")\n return\n\n except Exception:\n # Fallback to TorchScript export steps (RTDETR)\n try:\n model = deepcopy(torch_utils.de_parallel(trainer.model))\n model.eval()\n model = model.fuse(verbose=False)\n for m in model.modules():\n if hasattr(m, \"export\"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class)\n m.export = True\n m.format = \"torchscript\"\n model(im) # dry run\n WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])\n LOGGER.info(f\"{PREFIX}model graph visualization added ✅\")\n except Exception as e:\n LOGGER.warning(f\"{PREFIX}TensorBoard graph visualization failure {e}\")",
"chunk_type": "function",
"name": "_log_tensorboard_graph",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\tensorboard.py",
"start_line": 43,
"end_line": 91,
"start_col": 0,
"end_col": 86,
"parent_name": null,
"docstring": "Log model graph to TensorBoard.\n\nThis function attempts to visualize the model architecture in TensorBoard by tracing the model with a dummy input\ntensor. It first tries a simple method suitable for YOLO models, and if that fails, falls back to a more complex\napproach for models like RTDETR that may require special handling.\n\nArgs:\n trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing the model to visualize.\n Must have attributes model and args with imgsz.\n\nNotes:\n This function requires TensorBoard integration to be enabled and the global WRITER to be initialized.\n It handles potential warnings from the PyTorch JIT tracer and attempts to gracefully handle different\n model architectures.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 5,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.colorstr",
"ultralytics.utils.torch_utils",
"warnings",
"copy.deepcopy",
"torch",
"torch.utils.tensorboard.SummaryWriter"
],
"chunk_id": "function__log_tensorboard_graph_c2b5db29"
},
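{
"content": "# Illustrative sketch (not repository source): the graph-tracing step used by _log_tensorboard_graph above,\n# applied to a small stand-in module instead of a YOLO model. The module is traced once on a dummy zero tensor\n# and the traced graph is written for TensorBoard to render. The log directory name is a placeholder.\nimport torch\nfrom torch import nn\nfrom torch.utils.tensorboard import SummaryWriter\n\nmodel = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2))\nmodel.eval()\nim = torch.zeros(1, 3, 64, 64)  # dummy input, zeros as in the callback above\n\nwriter = SummaryWriter(\"runs/graph_demo\")\nwriter.add_graph(torch.jit.trace(model, im, strict=False), [])\nwriter.close()",
"chunk_type": "example",
"name": "example_tensorboard_graph_trace",
"file_path": null,
"start_line": null,
"end_line": null,
"start_col": null,
"end_col": null,
"parent_name": null,
"docstring": "Editor's illustrative sketch, not extracted from the repository: the core of the TensorBoard graph logging above, applied to a small stand-in torch module rather than a YOLO model. The log directory name is a placeholder.",
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "example_tensorboard_graph_trace_illustrative"
},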
{
"content": "def on_pretrain_routine_start(trainer) -> None:\n \"\"\"Initialize TensorBoard logging with SummaryWriter.\"\"\"\n if SummaryWriter:\n try:\n global WRITER\n WRITER = SummaryWriter(str(trainer.save_dir))\n LOGGER.info(f\"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/\")\n except Exception as e:\n LOGGER.warning(f\"{PREFIX}TensorBoard not initialized correctly, not logging this run. {e}\")",
"chunk_type": "function",
"name": "on_pretrain_routine_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\tensorboard.py",
"start_line": 94,
"end_line": 102,
"start_col": 0,
"end_col": 103,
"parent_name": null,
"docstring": "Initialize TensorBoard logging with SummaryWriter.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 3,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.colorstr",
"ultralytics.utils.torch_utils",
"warnings",
"copy.deepcopy",
"torch",
"torch.utils.tensorboard.SummaryWriter"
],
"chunk_id": "function_on_pretrain_routine_start_f2c2a491"
},
{
"content": "def on_train_start(trainer) -> None:\n \"\"\"Log TensorBoard graph.\"\"\"\n if WRITER:\n _log_tensorboard_graph(trainer)",
"chunk_type": "function",
"name": "on_train_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\tensorboard.py",
"start_line": 105,
"end_line": 108,
"start_col": 0,
"end_col": 39,
"parent_name": null,
"docstring": "Log TensorBoard graph.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 2,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.colorstr",
"ultralytics.utils.torch_utils",
"warnings",
"copy.deepcopy",
"torch",
"torch.utils.tensorboard.SummaryWriter"
],
"chunk_id": "function_on_train_start_4bde3c8d"
},
{
"content": "def on_train_epoch_end(trainer) -> None:\n \"\"\"Log scalar statistics at the end of a training epoch.\"\"\"\n _log_scalars(trainer.label_loss_items(trainer.tloss, prefix=\"train\"), trainer.epoch + 1)\n _log_scalars(trainer.lr, trainer.epoch + 1)",
"chunk_type": "function",
"name": "on_train_epoch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\tensorboard.py",
"start_line": 111,
"end_line": 114,
"start_col": 0,
"end_col": 47,
"parent_name": null,
"docstring": "Log scalar statistics at the end of a training epoch.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.colorstr",
"ultralytics.utils.torch_utils",
"warnings",
"copy.deepcopy",
"torch",
"torch.utils.tensorboard.SummaryWriter"
],
"chunk_id": "function_on_train_epoch_end_9d19a031"
},
{
"content": "def on_fit_epoch_end(trainer) -> None:\n \"\"\"Log epoch metrics at end of training epoch.\"\"\"\n _log_scalars(trainer.metrics, trainer.epoch + 1)",
"chunk_type": "function",
"name": "on_fit_epoch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\tensorboard.py",
"start_line": 117,
"end_line": 119,
"start_col": 0,
"end_col": 52,
"parent_name": null,
"docstring": "Log epoch metrics at end of training epoch.",
"parameters": [
"trainer"
],
"return_type": "None",
"decorators": [],
"complexity_score": 1,
"dependencies": [
"ultralytics.utils.LOGGER",
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.colorstr",
"ultralytics.utils.torch_utils",
"warnings",
"copy.deepcopy",
"torch",
"torch.utils.tensorboard.SummaryWriter"
],
"chunk_id": "function_on_fit_epoch_end_1d966ecf"
},
{
"content": "callbacks = (\n {\n \"on_pretrain_routine_start\": on_pretrain_routine_start,\n \"on_train_start\": on_train_start,\n \"on_fit_epoch_end\": on_fit_epoch_end,\n \"on_train_epoch_end\": on_train_epoch_end,\n }\n if SummaryWriter\n else {}\n)",
"chunk_type": "variable",
"name": "callbacks",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\tensorboard.py",
"start_line": 122,
"end_line": 131,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_callbacks_df9c659b"
},
{
"content": "from ultralytics.utils import SETTINGS, TESTS_RUNNING",
"chunk_type": "import",
"name": "SETTINGS, TESTS_RUNNING",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 53,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_SETTINGS, TESTS_RUNNING_d2d8e81b"
},
{
"content": "from ultralytics.utils.torch_utils import model_info_for_loggers",
"chunk_type": "import",
"name": "model_info_for_loggers",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py",
"start_line": 4,
"end_line": 4,
"start_col": 0,
"end_col": 64,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_model_info_for_loggers_b3f062da"
},
{
"content": "def _custom_table(x, y, classes, title=\"Precision Recall Curve\", x_title=\"Recall\", y_title=\"Precision\"):\n \"\"\"\n Create and log a custom metric visualization to wandb.plot.pr_curve.\n\n This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall\n curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across\n different classes.\n\n Args:\n x (list): Values for the x-axis; expected to have length N.\n y (list): Corresponding values for the y-axis; also expected to have length N.\n classes (list): Labels identifying the class of each point; length N.\n title (str, optional): Title for the plot.\n x_title (str, optional): Label for the x-axis.\n y_title (str, optional): Label for the y-axis.\n\n Returns:\n (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.\n \"\"\"\n import pandas # scope for faster 'import ultralytics'\n\n df = pandas.DataFrame({\"class\": classes, \"y\": y, \"x\": x}).round(3)\n fields = {\"x\": \"x\", \"y\": \"y\", \"class\": \"class\"}\n string_fields = {\"title\": title, \"x-axis-title\": x_title, \"y-axis-title\": y_title}\n return wb.plot_table(\n \"wandb/area-under-curve/v0\", wb.Table(dataframe=df), fields=fields, string_fields=string_fields\n )",
"chunk_type": "function",
"name": "_custom_table",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py",
"start_line": 18,
"end_line": 44,
"start_col": 0,
"end_col": 5,
"parent_name": null,
"docstring": "Create and log a custom metric visualization to wandb.plot.pr_curve.\n\nThis function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall\ncurve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across\ndifferent classes.\n\nArgs:\n x (list): Values for the x-axis; expected to have length N.\n y (list): Corresponding values for the y-axis; also expected to have length N.\n classes (list): Labels identifying the class of each point; length N.\n title (str, optional): Title for the plot.\n x_title (str, optional): Label for the x-axis.\n y_title (str, optional): Label for the y-axis.\n\nReturns:\n (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.",
"parameters": [
"x",
"y",
"classes",
"title",
"x_title",
"y_title"
],
"return_type": null,
"decorators": [],
"complexity_score": 1,
"dependencies": [
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.torch_utils.model_info_for_loggers",
"wandb",
"pandas",
"numpy"
],
"chunk_id": "function__custom_table_535745f6"
},
{
"content": "def _plot_curve(\n x,\n y,\n names=None,\n id=\"precision-recall\",\n title=\"Precision Recall Curve\",\n x_title=\"Recall\",\n y_title=\"Precision\",\n num_x=100,\n only_mean=False,\n):\n \"\"\"\n Log a metric curve visualization.\n\n This function generates a metric curve based on input data and logs the visualization to wandb.\n The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.\n\n Args:\n x (np.ndarray): Data points for the x-axis with length N.\n y (np.ndarray): Corresponding data points for the y-axis with shape (C, N), where C is the number of classes.\n names (list, optional): Names of the classes corresponding to the y-axis data; length C.\n id (str, optional): Unique identifier for the logged data in wandb.\n title (str, optional): Title for the visualization plot.\n x_title (str, optional): Label for the x-axis.\n y_title (str, optional): Label for the y-axis.\n num_x (int, optional): Number of interpolated data points for visualization.\n only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted.\n\n Notes:\n The function leverages the '_custom_table' function to generate the actual visualization.\n \"\"\"\n import numpy as np\n\n # Create new x\n if names is None:\n names = []\n x_new = np.linspace(x[0], x[-1], num_x).round(5)\n\n # Create arrays for logging\n x_log = x_new.tolist()\n y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()\n\n if only_mean:\n table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])\n wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})\n else:\n classes = [\"mean\"] * len(x_log)\n for i, yi in enumerate(y):\n x_log.extend(x_new) # add new x\n y_log.extend(np.interp(x_new, x, yi)) # interpolate y to new x\n classes.extend([names[i]] * len(x_new)) # add class names\n wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False)",
"chunk_type": "function",
"name": "_plot_curve",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py",
"start_line": 47,
"end_line": 98,
"start_col": 0,
"end_col": 97,
"parent_name": null,
"docstring": "Log a metric curve visualization.\n\nThis function generates a metric curve based on input data and logs the visualization to wandb.\nThe curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.\n\nArgs:\n x (np.ndarray): Data points for the x-axis with length N.\n y (np.ndarray): Corresponding data points for the y-axis with shape (C, N), where C is the number of classes.\n names (list, optional): Names of the classes corresponding to the y-axis data; length C.\n id (str, optional): Unique identifier for the logged data in wandb.\n title (str, optional): Title for the visualization plot.\n x_title (str, optional): Label for the x-axis.\n y_title (str, optional): Label for the y-axis.\n num_x (int, optional): Number of interpolated data points for visualization.\n only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted.\n\nNotes:\n The function leverages the '_custom_table' function to generate the actual visualization.",
"parameters": [
"x",
"y",
"names",
"id",
"title",
"x_title",
"y_title",
"num_x",
"only_mean"
],
"return_type": null,
"decorators": [],
"complexity_score": 4,
"dependencies": [
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.torch_utils.model_info_for_loggers",
"wandb",
"pandas",
"numpy"
],
"chunk_id": "function__plot_curve_7bb47a16"
},
{
"content": "def _log_plots(plots, step):\n \"\"\"\n Log plots to WandB at a specific step if they haven't been logged already.\n\n This function checks each plot in the input dictionary against previously processed plots and logs\n new or updated plots to WandB at the specified step.\n\n Args:\n plots (dict): Dictionary of plots to log, where keys are plot names and values are dictionaries\n containing plot metadata including timestamps.\n step (int): The step/epoch at which to log the plots in the WandB run.\n\n Notes:\n The function uses a shallow copy of the plots dictionary to prevent modification during iteration.\n Plots are identified by their stem name (filename without extension).\n Each plot is logged as a WandB Image object.\n \"\"\"\n for name, params in plots.copy().items(): # shallow copy to prevent plots dict changing during iteration\n timestamp = params[\"timestamp\"]\n if _processed_plots.get(name) != timestamp:\n wb.run.log({name.stem: wb.Image(str(name))}, step=step)\n _processed_plots[name] = timestamp",
"chunk_type": "function",
"name": "_log_plots",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py",
"start_line": 101,
"end_line": 122,
"start_col": 0,
"end_col": 46,
"parent_name": null,
"docstring": "Log plots to WandB at a specific step if they haven't been logged already.\n\nThis function checks each plot in the input dictionary against previously processed plots and logs\nnew or updated plots to WandB at the specified step.\n\nArgs:\n plots (dict): Dictionary of plots to log, where keys are plot names and values are dictionaries\n containing plot metadata including timestamps.\n step (int): The step/epoch at which to log the plots in the WandB run.\n\nNotes:\n The function uses a shallow copy of the plots dictionary to prevent modification during iteration.\n Plots are identified by their stem name (filename without extension).\n Each plot is logged as a WandB Image object.",
"parameters": [
"plots",
"step"
],
"return_type": null,
"decorators": [],
"complexity_score": 3,
"dependencies": [
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.torch_utils.model_info_for_loggers",
"wandb",
"pandas",
"numpy"
],
"chunk_id": "function__log_plots_6342a3c7"
},
{
"content": "def on_pretrain_routine_start(trainer):\n \"\"\"Initialize and start wandb project if module is present.\"\"\"\n if not wb.run:\n wb.init(\n project=str(trainer.args.project).replace(\"/\", \"-\") if trainer.args.project else \"Ultralytics\",\n name=str(trainer.args.name).replace(\"/\", \"-\"),\n config=vars(trainer.args),\n )",
"chunk_type": "function",
"name": "on_pretrain_routine_start",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py",
"start_line": 125,
"end_line": 132,
"start_col": 0,
"end_col": 9,
"parent_name": null,
"docstring": "Initialize and start wandb project if module is present.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.torch_utils.model_info_for_loggers",
"wandb",
"pandas",
"numpy"
],
"chunk_id": "function_on_pretrain_routine_start_e7123d05"
},
{
"content": "def on_fit_epoch_end(trainer):\n \"\"\"Log training metrics and model information at the end of an epoch.\"\"\"\n wb.run.log(trainer.metrics, step=trainer.epoch + 1)\n _log_plots(trainer.plots, step=trainer.epoch + 1)\n _log_plots(trainer.validator.plots, step=trainer.epoch + 1)\n if trainer.epoch == 0:\n wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1)",
"chunk_type": "function",
"name": "on_fit_epoch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py",
"start_line": 135,
"end_line": 141,
"start_col": 0,
"end_col": 75,
"parent_name": null,
"docstring": "Log training metrics and model information at the end of an epoch.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.torch_utils.model_info_for_loggers",
"wandb",
"pandas",
"numpy"
],
"chunk_id": "function_on_fit_epoch_end_9fd9944a"
},
{
"content": "def on_train_epoch_end(trainer):\n \"\"\"Log metrics and save images at the end of each training epoch.\"\"\"\n wb.run.log(trainer.label_loss_items(trainer.tloss, prefix=\"train\"), step=trainer.epoch + 1)\n wb.run.log(trainer.lr, step=trainer.epoch + 1)\n if trainer.epoch == 1:\n _log_plots(trainer.plots, step=trainer.epoch + 1)",
"chunk_type": "function",
"name": "on_train_epoch_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py",
"start_line": 144,
"end_line": 149,
"start_col": 0,
"end_col": 57,
"parent_name": null,
"docstring": "Log metrics and save images at the end of each training epoch.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 2,
"dependencies": [
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.torch_utils.model_info_for_loggers",
"wandb",
"pandas",
"numpy"
],
"chunk_id": "function_on_train_epoch_end_61965413"
},
{
"content": "def on_train_end(trainer):\n \"\"\"Save the best model as an artifact and log final plots at the end of training.\"\"\"\n _log_plots(trainer.validator.plots, step=trainer.epoch + 1)\n _log_plots(trainer.plots, step=trainer.epoch + 1)\n art = wb.Artifact(type=\"model\", name=f\"run_{wb.run.id}_model\")\n if trainer.best.exists():\n art.add_file(trainer.best)\n wb.run.log_artifact(art, aliases=[\"best\"])\n # Check if we actually have plots to save\n if trainer.args.plots and hasattr(trainer.validator.metrics, \"curves_results\"):\n for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):\n x, y, x_title, y_title = curve_values\n _plot_curve(\n x,\n y,\n names=list(trainer.validator.metrics.names.values()),\n id=f\"curves/{curve_name}\",\n title=curve_name,\n x_title=x_title,\n y_title=y_title,\n )\n wb.run.finish() # required or run continues on dashboard",
"chunk_type": "function",
"name": "on_train_end",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py",
"start_line": 152,
"end_line": 173,
"start_col": 0,
"end_col": 19,
"parent_name": null,
"docstring": "Save the best model as an artifact and log final plots at the end of training.",
"parameters": [
"trainer"
],
"return_type": null,
"decorators": [],
"complexity_score": 4,
"dependencies": [
"ultralytics.utils.SETTINGS",
"ultralytics.utils.TESTS_RUNNING",
"ultralytics.utils.torch_utils.model_info_for_loggers",
"wandb",
"pandas",
"numpy"
],
"chunk_id": "function_on_train_end_2b067edb"
},
{
"content": "callbacks = (\n {\n \"on_pretrain_routine_start\": on_pretrain_routine_start,\n \"on_train_epoch_end\": on_train_epoch_end,\n \"on_fit_epoch_end\": on_fit_epoch_end,\n \"on_train_end\": on_train_end,\n }\n if wb\n else {}\n)",
"chunk_type": "variable",
"name": "callbacks",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\wb.py",
"start_line": 176,
"end_line": 185,
"start_col": 0,
"end_col": 1,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable_callbacks_0344b950"
},
{
"content": "from .base import add_integration_callbacks, default_callbacks, get_default_callbacks",
"chunk_type": "import",
"name": "add_integration_callbacks, default_callbacks, get_default_callbacks",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\__init__.py",
"start_line": 3,
"end_line": 3,
"start_col": 0,
"end_col": 85,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "import_add_integration_callbacks, default_callbacks, get_default_callbacks_ba9532b6"
},
{
"content": "__all__ = \"add_integration_callbacks\", \"default_callbacks\", \"get_default_callbacks\"",
"chunk_type": "variable",
"name": "__all__",
"file_path": "ultralytics\\ultralytics\\utils\\callbacks\\__init__.py",
"start_line": 5,
"end_line": 5,
"start_col": 0,
"end_col": 83,
"parent_name": null,
"docstring": null,
"parameters": null,
"return_type": null,
"decorators": null,
"complexity_score": null,
"dependencies": null,
"chunk_id": "variable___all___c2d15a05"
}
]