---
license: mit
language:
- en
metrics:
- accuracy 0.9860
- roc_auc 0.9979
pipeline_tag: image-feature-extraction
tags:
- gnn
- link-prediction
- pytorch
- pytorch-geometric
- graph-neural-networks
- biology
- microscopy
- cell-tracking
- lineage-tracking
---

# DuMM Bacteria Tracker Model

## Overview

This repository contains the trained weights and documentation for the **DuMM Bacteria Tracker Model**, a **Graph Neural Network (GNN)** designed for **cell lineage link prediction** in time-lapse microscopy data.

The model uses a custom implementation of a **Parameter Decoupled Network (PDN)** variant within an **Edge-Propagation Message Passing Neural Network (EP-MPNN)** architecture, inspired by recent advances in dynamic graph representation learning.

## Model Architecture and Implementation Details

* **Framework:** PyTorch and PyTorch Geometric (PyG).
* **Architecture:** `LineageLinkPredictionGNN` (custom `nn.Module`).
* **Core Layers:** Custom `EP_MPNN_Block` layers that incorporate **Distance & Similarity (DS)** edge features, combined with a **Jumping Knowledge (JK)** network that aggregates features across multiple layers.
* **Task:** Binary classification (link prediction) of candidate edges between adjacent time frames. Each candidate is a possible continuation or division link.
* **Loss Function:** `BCEWithLogitsLoss`.
* **Input Features:** Node features are standardized with a `StandardScalerTransform`. The `DS_block` then computes each edge's attributes on the fly from the scaled features of its two endpoint cells (absolute difference plus cosine similarity).
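
The edge-attribute computation above can be sketched in plain Python. This is a minimal illustration, not the repository's `DS_block`. It assumes that the attribute is the element-wise absolute difference of the two scaled node vectors, followed by their cosine similarity as one extra scalar. The function name `ds_edge_attr` and the `1e-8` guard are hypothetical.

```python
import math
from typing import List, Sequence


def ds_edge_attr(x_u: Sequence[float], x_v: Sequence[float]) -> List[float]:
    """Hypothetical Distance & Similarity (DS) edge attribute.

    Assumption: attribute = |x_u - x_v| (element-wise) + [cos(x_u, x_v)].
    """
    if len(x_u) != len(x_v):
        raise ValueError("node feature vectors must have the same length")
    # Distance part: per-feature absolute difference.
    abs_diff = [abs(a - b) for a, b in zip(x_u, x_v)]
    # Similarity part: cosine similarity of the two whole vectors.
    dot = sum(a * b for a, b in zip(x_u, x_v))
    norm_u = math.sqrt(sum(a * a for a in x_u))
    norm_v = math.sqrt(sum(b * b for b in x_v))
    cos_sim = dot / max(norm_u * norm_v, 1e-8)  # guard against zero vectors
    return abs_diff + [cos_sim]


# Two scaled 3-dimensional feature vectors (toy values).
attr = ds_edge_attr([1.0, 0.0, 2.0], [1.0, 1.0, 2.0])
print(attr)
```

Under this assumption, the 10-dimensional node features yield an 11-dimensional edge attribute. Because the attributes are derived from node features rather than stored, they stay consistent with whatever scaling is applied to the nodes.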

### Input Features

The model uses a **10-dimensional feature vector** for each node (cell), derived from image analysis. These features capture morphological, positional, and intensity properties of each bacterial cell across two channels (phase contrast and fluorescence).

| Feature Name | Type | Description |
| :--- | :--- | :--- |
| `area` | Morphological | Area of the cell segment. |
| `centroid_y` | Positional | Y-coordinate of the cell's centroid (critical for 1D growth systems). |
| `axis_major_length` | Morphological | Length of the cell's major axis. |
| `axis_minor_length` | Morphological | Length of the cell's minor axis. |
| `intensity_mean/max/min_phase` | Intensity | Mean, max, and min pixel intensity in the **Phase Contrast** channel (3 features). |
| `intensity_mean/max/min_fluor` | Intensity | Mean, max, and min pixel intensity in the **Fluorescence** channel (3 features). |
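
The table abbreviates the six intensity features. The following minimal sketch assembles the 10-dimensional node vector in a fixed order. The expanded names (for example `intensity_mean_phase`) and the column order are assumptions read off the table, and `FEATURE_COLUMNS` and `node_feature_vector` are hypothetical names rather than part of this repository.

```python
from typing import Dict, List

# Assumed expansion and order of the 10 features listed in the table.
FEATURE_COLUMNS: List[str] = [
    "area",
    "centroid_y",
    "axis_major_length",
    "axis_minor_length",
    "intensity_mean_phase",
    "intensity_max_phase",
    "intensity_min_phase",
    "intensity_mean_fluor",
    "intensity_max_fluor",
    "intensity_min_fluor",
]


def node_feature_vector(cell: Dict[str, float]) -> List[float]:
    """Return one cell's features in FEATURE_COLUMNS order; raise if any is missing."""
    missing = [name for name in FEATURE_COLUMNS if name not in cell]
    if missing:
        raise KeyError(f"missing features: {missing}")
    return [float(cell[name]) for name in FEATURE_COLUMNS]


# Toy cell whose feature values equal their column index.
cell = {name: float(i) for i, name in enumerate(FEATURE_COLUMNS)}
print(node_feature_vector(cell))
```

Fixing the order in one place matters because the scaler and the model both see features only by position.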

The model was trained on microscopy images from the duplex mother machine developed by the Jun lab (https://jun.ucsd.edu/mother_machine.php).

### Data Preprocessing and Splitting

#### Splitting Strategy

To evaluate how well the model generalizes to future, unseen data, a **time-based (temporal) split** was used:

1. **Training Set:** First 60% of unique time frames (`sorted_time_frames[:train_split_idx]`).
2. **Validation Set:** Next 20% of unique time frames (`sorted_time_frames[train_split_idx:val_split_idx]`).
3. **Test Set:** Final 20% of unique time frames (`sorted_time_frames[val_split_idx:]`).
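
The three slices above can be reproduced with a short sketch. It reuses the card's names `sorted_time_frames`, `train_split_idx`, and `val_split_idx`. The card does not state how the indices are rounded, so the floor (integer) division here is an assumption, and `temporal_split` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def temporal_split(
    time_frames: Sequence[int],
    train_pct: int = 60,
    val_pct: int = 20,
) -> Tuple[List[int], List[int], List[int]]:
    """Split unique, sorted time frames into contiguous train/val/test blocks (no shuffling)."""
    sorted_time_frames = sorted(set(time_frames))
    n = len(sorted_time_frames)
    # Integer arithmetic avoids float rounding; rounding down is an assumption.
    train_split_idx = n * train_pct // 100
    val_split_idx = n * (train_pct + val_pct) // 100
    return (
        sorted_time_frames[:train_split_idx],
        sorted_time_frames[train_split_idx:val_split_idx],
        sorted_time_frames[val_split_idx:],
    )


# One entry per segmented cell, so frames repeat; only unique frames are split.
frames = [0, 0, 1, 2, 2, 3, 4, 5, 6, 7, 8, 9, 9]
train, val, test = temporal_split(frames)
print(train, val, test)
```

Splitting on unique frames, not on individual cells, keeps all cells of a given frame in the same subset. Every validation and test frame therefore lies strictly later in time than every training frame.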

#### Normalization

* **Method:** Node features were normalized using **standard scaling (`sklearn.preprocessing.StandardScaler`)**.
* **Fit:** The scaler was **fitted *only* on the training-set features** (`all_train_node_features_df`).
* **Application:** The fitted scaler was then applied via `StandardScalerTransform` to the features of the training, validation, and test sets. Because the scaling statistics come from the training frames alone, no information from the validation or test sets leaks into training.
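
The fit-on-train, apply-everywhere pattern can be illustrated without scikit-learn. This dependency-free sketch is not the card's `StandardScalerTransform`. It mirrors the arithmetic of `StandardScaler`, which uses the population standard deviation and leaves zero-variance features unscaled. The helpers `fit_standard_scaler` and `transform` are hypothetical.

```python
import statistics
from typing import List, Sequence, Tuple

Matrix = Sequence[Sequence[float]]


def fit_standard_scaler(train_rows: Matrix) -> Tuple[List[float], List[float]]:
    """Compute per-feature mean and population std from the TRAINING rows only."""
    columns = list(zip(*train_rows))
    means = [statistics.fmean(col) for col in columns]
    # Zero-variance features keep a scale of 1.0 (centred, never divided by zero).
    scales = [statistics.pstdev(col) or 1.0 for col in columns]
    return means, scales


def transform(rows: Matrix, means: List[float], scales: List[float]) -> List[List[float]]:
    """Apply (x - mean) / scale feature-wise, using statistics fitted on the training set."""
    return [[(x - m) / s for x, m, s in zip(row, means, scales)] for row in rows]


train_rows = [[1.0, 10.0], [3.0, 10.0]]
val_rows = [[5.0, 12.0]]
means, scales = fit_standard_scaler(train_rows)
val_scaled = transform(val_rows, means, scales)
print(means, scales, val_scaled)
```

The validation row is scaled with the training means and scales, so its values can fall outside the training range. That is expected and is exactly what avoids leakage.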

#### Graph Creation (Candidate Generation)

Candidate edges (links between cells in adjacent time frames) were generated using a custom set of geometric and morphological heuristics:

* **Distance Constraint:** The distance between the centroids of two linked cells may not exceed `max_dist_link` (default 50.0).
* **Area Ratio Constraints:**
    * **Continuation (1-to-1):** Area ratio between `min_area_ratio_continuation` (0.8) and `max_area_ratio_continuation` (1.2).
    * **Division (1-to-2):** Area ratio between `min_area_ratio_division` (1.8) and `max_area_ratio_division` (2.2).
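
A minimal sketch of these heuristics follows, using the card's parameter names and defaults. Two details are assumptions because the card does not specify them. First, distance is measured along the growth axis only (`centroid_y`). Second, the area ratio is parent area divided by child area, which gives roughly 1 for a continuation and roughly 2 for a division. `generate_candidates` is a hypothetical helper.

```python
from typing import Dict, List, Tuple

Cell = Dict[str, float]


def generate_candidates(
    cells_t: List[Cell],
    cells_t1: List[Cell],
    max_dist_link: float = 50.0,
    min_area_ratio_continuation: float = 0.8,
    max_area_ratio_continuation: float = 1.2,
    min_area_ratio_division: float = 1.8,
    max_area_ratio_division: float = 2.2,
) -> List[Tuple[int, int, str]]:
    """Return (index in frame t, index in frame t+1, link type) for every candidate pair."""
    candidates = []
    for i, parent in enumerate(cells_t):
        for j, child in enumerate(cells_t1):
            # Assumption: 1D distance along the channel (centroid_y only).
            if abs(parent["centroid_y"] - child["centroid_y"]) > max_dist_link:
                continue
            # Assumption: ratio = parent area / child area.
            ratio = parent["area"] / child["area"]
            if min_area_ratio_continuation <= ratio <= max_area_ratio_continuation:
                candidates.append((i, j, "continuation"))
            elif min_area_ratio_division <= ratio <= max_area_ratio_division:
                candidates.append((i, j, "division"))
    return candidates


cells_t = [{"area": 100.0, "centroid_y": 10.0}]
cells_t1 = [
    {"area": 105.0, "centroid_y": 12.0},   # similar size, nearby: continuation
    {"area": 50.0, "centroid_y": 30.0},    # half size, nearby: division
    {"area": 50.0, "centroid_y": 100.0},   # too far away: rejected
]
print(generate_candidates(cells_t, cells_t1))
```

Each pair is tested independently. The heuristics only prune the candidate set, and deciding which candidates are real links is left to the GNN classifier.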

### Training Protocol

| Hyperparameter | Value | Description |
| :--- | :--- | :--- |
| **GNN Layers (`num_blocks`)** | 2 | Number of sequential EP-MPNN blocks. |
| **Hidden Channels** | 128 | Dimension of the node and edge embeddings. |
| **Optimizer** | Adam | Standard optimization algorithm. |
| **Learning Rate** | 0.001 | Base learning rate. |
| **Weight Decay** | 0.0005 | L2 regularization to reduce overfitting. |
| **Batch Size** | 32 | Number of graphs processed per iteration. |
| **Model Selection Metric** | Validation accuracy (`val_acc`) | The checkpoint with the best `val_acc` is saved as `best_link_prediction_model.pt`. |
| **Early Stopping** | Yes | Monitors validation loss (`val_loss`) with a **patience of 10 epochs**. |
| **Max Epochs** | 500 | Maximum number of training epochs. |
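
The table uses two different validation signals: `val_acc` selects the saved checkpoint, while `val_loss` drives early stopping. The following framework-free sketch shows how the two interact, using recorded metric values in place of real training. `EarlyStopping` and `simulate_training` are hypothetical names. Treating a tie in `val_loss` as "no improvement" is an assumption.

```python
from typing import Optional, Sequence, Tuple


class EarlyStopping:
    """Signal a stop once val_loss has not improved for `patience` consecutive epochs."""

    def __init__(self, patience: int = 10) -> None:
        self.patience = patience
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def simulate_training(
    val_losses: Sequence[float],
    val_accs: Sequence[float],
    patience: int = 10,
    max_epochs: int = 500,
) -> Tuple[Optional[int], Optional[int]]:
    """Return (last epoch run, epoch whose checkpoint would be saved)."""
    stopper = EarlyStopping(patience)
    best_acc, best_epoch, last_epoch = float("-inf"), None, None
    for epoch, (loss, acc) in enumerate(zip(val_losses, val_accs)):
        if epoch >= max_epochs:
            break
        last_epoch = epoch
        if acc > best_acc:
            # In real training, best_link_prediction_model.pt would be written here.
            best_acc, best_epoch = acc, epoch
        if stopper.step(loss):
            break
    return last_epoch, best_epoch


losses = [1.0, 0.9, 0.8] + [0.85] * 10 + [0.5] * 5
accs = [0.5, 0.6, 0.7, 0.8] + [0.75] * 14
print(simulate_training(losses, accs))
```

In this toy run, the lowest loss occurs at epoch 2, but the saved checkpoint comes from epoch 3, which has the highest accuracy. Training stops at epoch 12, after 10 epochs without a loss improvement. The later, lower losses are never reached.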