DuMM Bacteria Tracker Model
Overview
This repository contains the trained weights and documentation for the DuMM Bacteria Tracker Model, a Graph Neural Network (GNN) designed for cell lineage link prediction in time-lapse microscopy data.
The model uses a custom implementation of a Parameter Decoupled Network (PDN) variant within an Edge-Propagation Message Passing Neural Network (EP-MPNN) architecture, inspired by recent advancements in dynamic graph representation learning.
Model Architecture and Implementation Details
- Framework: PyTorch and PyTorch Geometric (PyG).
- Architecture: `LineageLinkPredictionGNN` (custom `nn.Module`).
- Core Layers: Utilizes custom `EP_MPNN_Block` modules, which incorporate Distance & Similarity (DS) features, and a Jumping Knowledge (JK) network for aggregating features across multiple layers.
- Task: Binary classification (link prediction) on candidate edges between adjacent time frames (continuation or division links).
- Loss Function: `BCEWithLogitsLoss`.
- Input Features: Node features are scaled using a `StandardScalerTransform` and dynamically used to compute edge attributes (absolute difference + cosine similarity) via the `DS_block`.
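A minimal NumPy sketch of the DS edge attributes described above. It assumes that each candidate edge receives the element-wise absolute difference of its two endpoint feature vectors, followed by their cosine similarity. The function name `ds_edge_features`, the PyG-style `(2, num_edges)` `edge_index` layout, and the `eps` guard are illustrative assumptions. The repository's `DS_block` operates on PyTorch tensors and may differ in detail.

```python
import numpy as np


def ds_edge_features(x: np.ndarray, edge_index: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Build DS-style edge attributes from scaled node features.

    x: (num_nodes, num_features) array of scaled node features.
    edge_index: (2, num_edges) integer array of (source, target) node indices,
        following the PyG convention (an assumption for this sketch).
    Returns a (num_edges, num_features + 1) array: |x_src - x_dst| followed by
    the cosine similarity of x_src and x_dst.
    """
    src, dst = edge_index
    x_src, x_dst = x[src], x[dst]
    # Distance part: per-feature absolute difference between the two cells.
    abs_diff = np.abs(x_src - x_dst)
    # Similarity part: cosine similarity; eps avoids division by zero norms.
    dot = np.sum(x_src * x_dst, axis=1)
    norms = np.linalg.norm(x_src, axis=1) * np.linalg.norm(x_dst, axis=1)
    cos_sim = dot / np.maximum(norms, eps)
    return np.concatenate([abs_diff, cos_sim[:, None]], axis=1)
```

With the 10 node features listed under Input Features, this layout gives an 11-dimensional attribute per candidate edge (10 absolute differences plus one similarity score), assuming the two parts are concatenated as shown.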
Input Features
The model uses a 10-dimensional feature vector for each node (cell), derived from image analysis. These features capture the morphological, positional, and intensity properties of the bacterial cells across two channels (Phase Contrast and Fluorescence).
| Feature Name | Type | Description |
|---|---|---|
| `area` | Morphological | Area of the cell segment. |
| `centroid_y` | Positional | Y-coordinate of the cell's centroid (critical for 1D growth systems). |
| `axis_major_length` | Morphological | Length of the cell's major axis. |
| `axis_minor_length` | Morphological | Length of the cell's minor axis. |
| `intensity_mean/max/min_phase` | Intensity | Mean, max, and min pixel intensity in the Phase Contrast channel. |
| `intensity_mean/max/min_fluor` | Intensity | Mean, max, and min pixel intensity in the Fluorescence channel. |
The model was trained on time-lapse microscopy images from the duplex mother machine developed by the Jun lab (https://jun.ucsd.edu/mother_machine.php).
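The features in the table above can be computed per cell from a labeled segmentation mask and the two co-registered channel images. The column names resemble `skimage.measure.regionprops` property names, so this NumPy-only sketch assumes the same moment-based definitions. Under that assumption, the axis lengths are 4 times the square roots of the eigenvalues of the pixel-coordinate covariance. The function `cell_features`, the expanded names in `FEATURE_NAMES`, and the mask/channel inputs are hypothetical. The actual extraction pipeline is not described in this repository.

```python
import numpy as np

# Hypothetical expanded column order for the 10-dimensional node feature vector.
FEATURE_NAMES = [
    "area", "centroid_y", "axis_major_length", "axis_minor_length",
    "intensity_mean_phase", "intensity_max_phase", "intensity_min_phase",
    "intensity_mean_fluor", "intensity_max_fluor", "intensity_min_fluor",
]


def cell_features(labels: np.ndarray, label_id: int,
                  phase: np.ndarray, fluor: np.ndarray) -> np.ndarray:
    """Return the 10-element feature vector for one labeled cell."""
    rows, cols = np.nonzero(labels == label_id)
    if rows.size == 0:
        raise ValueError(f"label {label_id} not present in mask")
    coords = np.stack([rows, cols]).astype(float)  # shape (2, n_pixels)
    area = float(rows.size)                          # pixel count
    centroid_y = float(rows.mean())                  # row index = image y-coordinate
    # Population covariance of pixel coordinates (ddof=0), as in image moments.
    cov = np.cov(coords, bias=True)
    # Clip tiny negative eigenvalues from rounding, then sort descending.
    eigvals = np.sort(np.clip(np.linalg.eigvalsh(cov), 0.0, None))[::-1]
    major, minor = 4.0 * np.sqrt(eigvals)
    ph, fl = phase[rows, cols], fluor[rows, cols]
    return np.array([area, centroid_y, major, minor,
                     ph.mean(), ph.max(), ph.min(),
                     fl.mean(), fl.max(), fl.min()], dtype=float)
```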
Data Preprocessing and Splitting
Splitting Strategy
To measure how well the model generalizes to later, unseen time points, a time-based (temporal) split was employed:
- Training Set: First 60% of unique time frames (`sorted_time_frames[:train_split_idx]`).
- Validation Set: Next 20% of unique time frames (`sorted_time_frames[train_split_idx:val_split_idx]`).
- Test Set: Final 20% of unique time frames (`sorted_time_frames[val_split_idx:]`).
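The split can be reproduced in a few lines of Python. This sketch assumes that the split indices are computed by flooring 60% and 80% of the number of unique frames. `temporal_split` and its percentage arguments are hypothetical. `sorted_time_frames`, `train_split_idx`, and `val_split_idx` mirror the names above.

```python
from typing import Iterable, List, Tuple


def temporal_split(time_frames: Iterable[int], train_pct: int = 60,
                   val_pct: int = 20) -> Tuple[List[int], List[int], List[int]]:
    """Split unique time frames chronologically into train/val/test lists."""
    sorted_time_frames = sorted(set(time_frames))
    n = len(sorted_time_frames)
    # Integer arithmetic avoids floating-point surprises at the boundaries.
    train_split_idx = n * train_pct // 100
    val_split_idx = n * (train_pct + val_pct) // 100
    train = sorted_time_frames[:train_split_idx]
    val = sorted_time_frames[train_split_idx:val_split_idx]
    test = sorted_time_frames[val_split_idx:]
    return train, val, test
```

Because the frames are sorted before slicing, every validation frame comes after every training frame, and every test frame comes after every validation frame.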
Normalization
- Method: Node features were normalized using standard scaling (`sklearn.preprocessing.StandardScaler`).
- Fit: The scaler was fitted only on the training set features (`all_train_node_features_df`).
- Application: The fitted scaler was then applied to transform the features in the training, validation, and test sets via the `StandardScalerTransform`. Because the scaling statistics come only from the training frames, no information from the validation or test sets leaks into training.
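The fit-on-train, apply-everywhere pattern can be sketched as follows. `fit_scaler` reproduces `StandardScaler`'s default statistics in NumPy: the mean and the population standard deviation, with zero-variance features left unscaled. This `StandardScalerTransform` is a hypothetical reconstruction that rescales a graph's `x` attribute. The repository's class of the same name may be implemented differently, for example by operating on PyTorch tensors.

```python
import numpy as np


class StandardScalerTransform:
    """Hypothetical transform: standardize `data.x` with pre-fitted statistics."""

    def __init__(self, mean: np.ndarray, scale: np.ndarray):
        self.mean = mean
        self.scale = scale

    def __call__(self, data):
        # Works on any object with an `x` array attribute (e.g., a graph sample).
        data.x = (data.x - self.mean) / self.scale
        return data


def fit_scaler(train_features: np.ndarray) -> StandardScalerTransform:
    """Fit per-feature mean/std on training features only."""
    mean = train_features.mean(axis=0)
    scale = train_features.std(axis=0)  # population std (ddof=0), like StandardScaler
    scale[scale == 0.0] = 1.0           # constant features are centered but not scaled
    return StandardScalerTransform(mean, scale)
```

Only the training features reach `fit_scaler`. The returned transform is then applied unchanged to training, validation, and test graphs.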
Graph Creation (Candidate Generation)
Candidate edges (links between cells in adjacent time frames) were generated based on a custom set of geometric and morphological heuristics:
- Distance Constraint: Max distance between centroids is limited by `max_dist_link` (default 50.0).
- Area Ratio Constraints:
  - Continuation (1-to-1): `min_area_ratio_continuation` (0.8) to `max_area_ratio_continuation` (1.2).
  - Division (1-to-2): `min_area_ratio_division` (1.8) to `max_area_ratio_division` (2.2).
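The following brute-force sketch applies the candidate heuristics to one pair of adjacent frames. It makes three assumptions. First, each cell is a dict with hypothetical keys `id`, `centroid`, and `area`. Second, distance is Euclidean between centroids. Third, and most importantly, the area ratio is the parent area divided by the child area. Under that convention, a daughter at roughly half its parent's size falls in the 1.8 to 2.2 division window. The threshold defaults match the values listed above.

```python
import math
from typing import List, Sequence, Tuple


def generate_candidate_edges(
    cells_t: Sequence[dict],
    cells_next: Sequence[dict],
    max_dist_link: float = 50.0,
    min_area_ratio_continuation: float = 0.8,
    max_area_ratio_continuation: float = 1.2,
    min_area_ratio_division: float = 1.8,
    max_area_ratio_division: float = 2.2,
) -> List[Tuple[object, object, str]]:
    """Return (source_id, target_id, kind) candidates between two adjacent frames."""
    edges = []
    for parent in cells_t:
        for child in cells_next:
            # Distance constraint; math.dist accepts 1-D (y,) or 2-D (y, x) tuples.
            if math.dist(parent["centroid"], child["centroid"]) > max_dist_link:
                continue
            # ASSUMPTION: ratio = parent area / child area (areas assumed > 0).
            ratio = parent["area"] / child["area"]
            if min_area_ratio_continuation <= ratio <= max_area_ratio_continuation:
                edges.append((parent["id"], child["id"], "continuation"))
            elif min_area_ratio_division <= ratio <= max_area_ratio_division:
                edges.append((parent["id"], child["id"], "division"))
    return edges
```

Each daughter is proposed as its own parent-to-daughter edge. The sketch does not check that a dividing parent receives exactly two division candidates, so a later step (or the GNN's link scores) would have to resolve that.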
Training Protocol
| Hyperparameter | Value | Description |
|---|---|---|
| GNN Layers (`num_blocks`) | 2 | Number of sequential EP-MPNN blocks. |
| Hidden Channels | 128 | Dimension for node and edge embeddings. |
| Optimizer | Adam | Standard optimization algorithm. |
| Learning Rate | 0.001 | Base learning rate. |
| Weight Decay | 0.0005 | L2 regularization applied to prevent overfitting. |
| Batch Size | 32 | Number of graphs processed per iteration. |
| Evaluation Metric | Validation Accuracy (`val_acc`) | Used for saving `best_link_prediction_model.pt`. |
| Early Stopping | Yes | Monitors Validation Loss (val_loss) with a patience of 10 epochs. |
| Max Epochs | 500 | Maximum number of training epochs. |
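The protocol uses two validation signals. Checkpoints are selected on `val_acc`, while early stopping watches `val_loss` with a patience of 10. The following pure-Python sketch shows how the two interact. It assumes the usual rule of stopping after `patience` consecutive epochs without a strict decrease in validation loss. The class `EarlyStopping`, the helper `run_schedule`, and the precomputed `history` of `(val_loss, val_acc)` pairs are hypothetical stand-ins for a real training loop.

```python
import math
from dataclasses import dataclass


@dataclass
class EarlyStopping:
    """Signal a stop once val_loss fails to improve for `patience` epochs in a row."""
    patience: int = 10
    best_loss: float = math.inf
    bad_epochs: int = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def run_schedule(history, max_epochs=500, patience=10):
    """Simulate the protocol on precomputed (val_loss, val_acc) pairs.

    Returns (epochs_run, best_epoch): best_epoch is chosen by val_acc,
    while the stopping decision is driven by val_loss.
    """
    stopper = EarlyStopping(patience=patience)
    best_acc, best_epoch = -math.inf, None
    epochs_run = 0
    for epoch, (val_loss, val_acc) in enumerate(history[:max_epochs]):
        epochs_run = epoch + 1
        if val_acc > best_acc:
            best_acc, best_epoch = val_acc, epoch
            # In a real loop, the checkpoint would be written here,
            # e.g. torch.save(model.state_dict(), "best_link_prediction_model.pt").
        if stopper.step(val_loss):
            break
    return epochs_run, best_epoch
```

Because the two criteria are decoupled, the saved checkpoint can come from an epoch after the best `val_loss` epoch, as long as training has not yet stopped. In a PyTorch training loop, the optimizer row of the table corresponds to `torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005)`.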