---
license: mit
language:
- en
metrics:
- accuracy  # 0.9860
- roc_auc  # 0.9979
pipeline_tag: image-feature-extraction
tags:
- gnn
- link-prediction
- pytorch
- pytorch-geometric
- graph-neural-networks
- biology
- microscopy
- cell-tracking
- lineage-tracking
---

# DuMM Bacteria Tracker Model

## Overview

This repository contains the trained weights and documentation for the **DuMM Bacteria Tracker Model**. It is a **Graph Neural Network (GNN)** for **cell lineage link prediction** in time-lapse microscopy data. The model uses a custom implementation of a **Parameter Decoupled Network (PDN)** variant within an **Edge-Propagation Message Passing Neural Network (EP-MPNN)** architecture, inspired by recent work in dynamic graph representation learning.

## Model Architecture and Implementation Details

* **Framework:** PyTorch and PyTorch Geometric (PyG).
* **Architecture:** `LineageLinkPredictionGNN` (custom `nn.Module`).
* **Core Layers:** Custom `EP_MPNN_Block` layers that incorporate **Distance & Similarity (DS)** features. These are combined with a **Jumping Knowledge (JK)** network that aggregates representations across layers.
* **Task:** Binary classification (link prediction) of candidate edges (continuation or division links) between adjacent time frames.
* **Loss Function:** `BCEWithLogitsLoss`.
* **Input Features:** Node features are scaled with a `StandardScalerTransform`. The `DS_block` then uses the scaled features of each edge's two endpoint cells to compute edge attributes on the fly (absolute difference plus cosine similarity).

### Input Features

The model uses a **10-dimensional feature vector** for each node (cell), derived from image analysis. These features capture morphological, positional, and intensity properties of each bacterial cell across two imaging channels (Phase Contrast and Fluorescence).

| Feature Name | Type | Description |
| :--- | :--- | :--- |
| `area` | Morphological | Area of the cell segment. |
| `centroid_y` | Positional | Y-coordinate of the cell's centroid (critical for 1D growth systems). |
| `axis_major_length` | Morphological | Length of the cell's major axis. |
| `axis_minor_length` | Morphological | Length of the cell's minor axis. |
| `intensity_mean/max/min_phase` | Intensity | Mean, max, and min pixel intensity in the **Phase Contrast** channel (3 features). |
| `intensity_mean/max/min_fluor` | Intensity | Mean, max, and min pixel intensity in the **Fluorescence** channel (3 features). |

The model was trained on microscopy images from the duplex mother machine developed by the Jun lab (https://jun.ucsd.edu/mother_machine.php).

### Data Preprocessing and Splitting

#### Splitting Strategy

The data were split **by time** rather than at random. This tests whether the model generalizes to later, unseen time points:

1. **Training Set:** First 60% of unique time frames (`sorted_time_frames[:train_split_idx]`).
2. **Validation Set:** Next 20% of unique time frames (`sorted_time_frames[train_split_idx:val_split_idx]`).
3. **Test Set:** Final 20% of unique time frames (`sorted_time_frames[val_split_idx:]`).

#### Normalization

* **Method:** Node features were normalized with **standard scaling (`sklearn.preprocessing.StandardScaler`)**.
* **Fit:** The scaler was **fitted *only* on the training-set features** (`all_train_node_features_df`).
* **Application:** The fitted scaler was then applied to the Training, Validation, and Test sets via `StandardScalerTransform`. Because the mean and variance come from training frames only, no statistics from the validation or test frames leak into training.

#### Graph Creation (Candidate Generation)

Candidate edges (links between cells in adjacent time frames) were generated with a custom set of geometric and morphological heuristics:

* **Distance Constraint:** The distance between centroids must not exceed `max_dist_link` (default 50.0).
* **Area Ratio Constraints:**
  * **Continuation (1-to-1):** area ratio between `min_area_ratio_continuation` (0.8) and `max_area_ratio_continuation` (1.2).
  * **Division (1-to-2):** area ratio between `min_area_ratio_division` (1.8) and `max_area_ratio_division` (2.2).
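The splitting and candidate-generation steps can be illustrated with a small, dependency-free Python sketch. This is not the repository's implementation. `temporal_split`, `Cell`, `in_ratio_window`, `generate_candidates`, and `ds_edge_features` are hypothetical names, although the split indices mirror the slicing expressions listed above. Several details are assumptions:

* The centroid distance is measured along `centroid_y` only, since that is the only positional feature listed.
* The area ratio is taken as parent area divided by child area. Under this convention, the 1.8 to 2.2 window corresponds to a mother cell producing a daughter of roughly half its size.
* The DS absolute difference is computed element-wise.

In the actual model, the DS attributes are computed on tensors inside `DS_block`.

```python
import math
from dataclasses import dataclass


def temporal_split(time_frames, train_frac=0.6, val_frac=0.2):
    """Chronological 60/20/20 split of unique time frames (no shuffling)."""
    sorted_time_frames = sorted(set(time_frames))
    n = len(sorted_time_frames)
    train_split_idx = int(n * train_frac)
    val_split_idx = train_split_idx + int(n * val_frac)
    return (sorted_time_frames[:train_split_idx],
            sorted_time_frames[train_split_idx:val_split_idx],
            sorted_time_frames[val_split_idx:])


@dataclass
class Cell:
    """Hypothetical per-cell record holding only the fields the heuristics use."""
    cell_id: int
    area: float
    centroid_y: float


def in_ratio_window(parent: Cell, child: Cell, lo: float, hi: float) -> bool:
    # Assumed convention: ratio = parent area / child area.
    return lo <= parent.area / child.area <= hi


def generate_candidates(frame_t, frame_t1, max_dist_link=50.0,
                        continuation=(0.8, 1.2), division=(1.8, 2.2)):
    """Return (parent_id, child_id, link_type) candidates between adjacent frames."""
    candidates = []
    for parent in frame_t:
        for child in frame_t1:
            # Assumed: distance measured along the channel axis (centroid_y) only.
            if abs(parent.centroid_y - child.centroid_y) > max_dist_link:
                continue
            if in_ratio_window(parent, child, *continuation):
                candidates.append((parent.cell_id, child.cell_id, "continuation"))
            elif in_ratio_window(parent, child, *division):
                candidates.append((parent.cell_id, child.cell_id, "division"))
    return candidates


def ds_edge_features(x_i, x_j):
    """DS-style edge attributes: element-wise |x_i - x_j| plus cosine similarity."""
    abs_diff = [abs(a - b) for a, b in zip(x_i, x_j)]
    dot = sum(a * b for a, b in zip(x_i, x_j))
    norm = math.sqrt(sum(a * a for a in x_i)) * math.sqrt(sum(b * b for b in x_j))
    cos_sim = dot / norm if norm > 0 else 0.0  # fallback for zero vectors (assumed)
    return abs_diff + [cos_sim]


print(temporal_split(range(10)))  # → ([0, 1, 2, 3, 4, 5], [6, 7], [8, 9])

frame_0 = [Cell(0, area=200.0, centroid_y=100.0)]
frame_1 = [Cell(10, area=210.0, centroid_y=105.0),   # same cell, slightly grown
           Cell(11, area=100.0, centroid_y=90.0),    # plausible daughter
           Cell(12, area=100.0, centroid_y=400.0)]   # too far away
print(generate_candidates(frame_0, frame_1))
# → [(0, 10, 'continuation'), (0, 11, 'division')]
```

Each daughter is paired with its mother independently, so a division appears as two separate parent-child candidate edges. The model then classifies each candidate edge on its own.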
### Training Protocol

| Hyperparameter | Value | Description |
| :--- | :--- | :--- |
| **GNN Layers (`num_blocks`)** | 2 | Number of sequential EP-MPNN blocks. |
| **Hidden Channels** | 128 | Dimension of the node and edge embeddings. |
| **Optimizer** | Adam | Adaptive first-order optimizer. |
| **Learning Rate** | 0.001 | Base learning rate. |
| **Weight Decay** | 0.0005 | L2 penalty applied to reduce overfitting. |
| **Batch Size** | 32 | Number of graphs per mini-batch. |
| **Checkpoint Metric** | Validation Accuracy (`val_acc`) | The best-scoring model on `val_acc` is saved as `best_link_prediction_model.pt`. |
| **Early Stopping** | Yes | Monitors Validation Loss (`val_loss`) with a **patience of 10 epochs**. |
| **Max Epochs** | 500 | Maximum number of training epochs. |
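The protocol uses two different signals. Early stopping is driven by `val_loss`, while checkpoint selection is driven by `val_acc`. The minimal sketch below shows that bookkeeping. `EarlyStoppingTracker` and `run` are hypothetical helpers, and the validation history in the usage example is made-up toy data, not a training log.

In a real PyTorch loop, the optimizer row would correspond to `torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005)`. A checkpoint could be written with `torch.save(model.state_dict(), "best_link_prediction_model.pt")`, although the card does not state whether the original code saves a `state_dict` or the full model. Strict comparisons with no minimum improvement threshold are also an assumption.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class EarlyStoppingTracker:
    """Hypothetical helper: stop on val_loss, checkpoint on val_acc."""
    patience: int = 10
    best_val_loss: float = float("inf")
    best_val_acc: float = float("-inf")
    epochs_without_improvement: int = 0
    best_epoch: int = -1

    def update(self, epoch: int, val_loss: float, val_acc: float) -> tuple[bool, bool]:
        """Return (save_checkpoint, stop_training) for one epoch."""
        # Checkpoint criterion: strictly higher validation accuracy (assumed, no tolerance).
        save_checkpoint = val_acc > self.best_val_acc
        if save_checkpoint:
            self.best_val_acc = val_acc
            self.best_epoch = epoch
        # Early-stopping criterion: strictly lower validation loss resets the counter.
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        stop_training = self.epochs_without_improvement >= self.patience
        return save_checkpoint, stop_training


def run(history, patience=10, max_epochs=500):
    """Replay (val_loss, val_acc) pairs; return (last_epoch_run, best_checkpoint_epoch)."""
    tracker = EarlyStoppingTracker(patience=patience)
    epochs = history[:max_epochs]
    for epoch, (val_loss, val_acc) in enumerate(epochs):
        save_checkpoint, stop_training = tracker.update(epoch, val_loss, val_acc)
        if save_checkpoint:
            pass  # real loop: torch.save(model.state_dict(), "best_link_prediction_model.pt")
        if stop_training:
            return epoch, tracker.best_epoch
    return len(epochs) - 1, tracker.best_epoch


# Toy validation history (made-up numbers): loss bottoms out at epoch 2,
# accuracy peaks at epoch 4, then both plateau.
history = [(0.50, 0.80), (0.40, 0.85), (0.30, 0.90), (0.31, 0.91), (0.32, 0.93)]
history += [(0.35, 0.92)] * 20
print(run(history))  # → (12, 4)
```

Here training stops at epoch 12, ten epochs after the last `val_loss` improvement at epoch 2. The saved checkpoint, however, comes from epoch 4. Because the two criteria differ, the checkpointed model is generally not the last one trained.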