---
license: mit
language:
- en
metrics:
- accuracy 0.9860
- roc_auc 0.9979
pipeline_tag: image-feature-extraction
tags:
- gnn
- link-prediction
- pytorch
- pytorch-geometric
- graph-neural-networks
- biology
- microscopy
- cell-tracking
- lineage-tracking
---
# DuMM Bacteria Tracker Model
## Overview
This repository contains the trained weights and documentation for the **DuMM Bacteria Tracker Model**, a **Graph Neural Network (GNN)** designed for **cell lineage link prediction** in time-lapse microscopy data.
The model uses a custom implementation of a **Parameter Decoupled Network (PDN)** variant within an **Edge-Propagation Message Passing Neural Network (EP-MPNN)** architecture, inspired by recent advancements in dynamic graph representation learning.
## Model Architecture and Implementation Details
* **Framework:** PyTorch and PyTorch Geometric (PyG).
* **Architecture:** `LineageLinkPredictionGNN` (custom `nn.Module`).
* **Core Layers:** Uses a custom `EP_MPNN_Block`, which incorporates **Distance & Similarity (DS)** features and a **Jumping Knowledge (JK)** network that aggregates features across layers.
* **Task:** Binary classification (Link Prediction) on candidate edges between adjacent time frames (continuation or division links).
* **Loss Function:** `BCEWithLogitsLoss`.
* **Input Features:** Node features are scaled with a `StandardScalerTransform`; the `DS_block` then computes edge attributes from them on the fly (Absolute Difference + Cosine Similarity).
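
The edge attributes described above can be illustrated with a minimal, framework-free sketch. It assumes the DS features are the element-wise absolute difference of the two node feature vectors followed by a single cosine-similarity value; the function name `ds_edge_features`, the concatenation order, and the `eps` guard are hypothetical and may differ from the actual `DS_block`, which operates on tensors.

```python
import math

def ds_edge_features(x_src, x_dst, eps=1e-8):
    """Hypothetical helper: [|x_src - x_dst| per dimension..., cosine similarity]."""
    # Element-wise absolute difference between the two node feature vectors.
    abs_diff = [abs(a - b) for a, b in zip(x_src, x_dst)]
    # Cosine similarity; eps guards against division by zero for all-zero vectors.
    dot = sum(a * b for a, b in zip(x_src, x_dst))
    norm_src = math.sqrt(sum(a * a for a in x_src))
    norm_dst = math.sqrt(sum(b * b for b in x_dst))
    cos_sim = dot / max(norm_src * norm_dst, eps)
    return abs_diff + [cos_sim]

# Toy 3-dimensional features (the model itself uses 10 dimensions per node).
edge_attr = ds_edge_features([1.0, 0.0, 2.0], [1.0, 1.0, 0.0])
```

Under this assumption, a 10-dimensional node vector would give an 11-dimensional edge attribute before any learned projection.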
### Input Features
The model uses a **10-dimensional feature vector** for each node (cell) derived from image analysis. These features capture morphological and intensity properties of the bacterial cells across two channels (Phase Contrast and Fluorescence).
| Feature Name | Type | Description |
| :--- | :--- | :--- |
| `area` | Morphological | Area of the cell segment. |
| `centroid_y` | Positional | Y-coordinate of the cell's centroid (critical for 1D growth systems). |
| `axis_major_length` | Morphological | Length of the cell's major axis. |
| `axis_minor_length` | Morphological | Length of the cell's minor axis. |
| `intensity_mean/max/min_phase` | Intensity | Mean, max, and min pixel intensity in the **Phase Contrast** channel. |
| `intensity_mean/max/min_fluor` | Intensity | Mean, max, and min pixel intensity in the **Fluorescence** channel. |

The model was trained on microscopy images of the duplex mother machine developed by the Jun lab (https://jun.ucsd.edu/mother_machine.php).
### Data Preprocessing and Splitting
#### Splitting Strategy
To test how well the model generalizes to later, unseen frames, a **temporal split** over the unique time frames was used:
1. **Training Set:** First 60% of unique time frames (`sorted_time_frames[:train_split_idx]`).
2. **Validation Set:** Next 20% of unique time frames (`sorted_time_frames[train_split_idx:val_split_idx]`).
3. **Test Set:** Final 20% of unique time frames (`sorted_time_frames[val_split_idx:]`).
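
The 60/20/20 split can be sketched as follows. The variable names `sorted_time_frames`, `train_split_idx`, and `val_split_idx` come from the description above; the function `temporal_split` and its integer-percentage arguments are assumptions made for illustration, not the repository's code.

```python
def temporal_split(time_frames, train_pct=60, val_pct=20):
    """Split unique time frames chronologically into train/val/test lists."""
    sorted_time_frames = sorted(set(time_frames))
    n = len(sorted_time_frames)
    # Integer arithmetic avoids floating-point rounding at the boundaries.
    train_split_idx = n * train_pct // 100
    val_split_idx = n * (train_pct + val_pct) // 100
    train = sorted_time_frames[:train_split_idx]
    val = sorted_time_frames[train_split_idx:val_split_idx]
    test = sorted_time_frames[val_split_idx:]
    return train, val, test

# Ten unique frames, each observed twice (e.g. two cells per frame).
train_frames, val_frames, test_frames = temporal_split(list(range(10)) * 2)
```

Because the split is over frames rather than individual cells, every cell from a given frame lands in exactly one subset.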
#### Normalization
* **Method:** Node features were normalized using **Standard Scaling (`sklearn.preprocessing.StandardScaler`)**.
* **Fit:** The scaler was **fitted *only* on the training set features** (`all_train_node_features_df`).
* **Application:** The fitted scaler was then applied to transform the features in the Training, Validation, and Test sets via the `StandardScalerTransform`. This avoids data leakage.
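
The fit-on-train-only pattern is shown below in a dependency-free form. The source uses `sklearn.preprocessing.StandardScaler`; this sketch only mirrors its formula (subtract the mean, divide by the population standard deviation, with zero-variance columns left unscaled), and the helper names `fit_standard_scaler` and `transform` are hypothetical.

```python
import statistics

def fit_standard_scaler(rows):
    """Compute per-column mean and population std from the training rows only."""
    columns = list(zip(*rows))
    means = [statistics.fmean(col) for col in columns]
    # A zero standard deviation is replaced by 1.0 so constant columns do not divide by zero.
    scales = [statistics.pstdev(col) or 1.0 for col in columns]
    return means, scales

def transform(rows, means, scales):
    """Apply previously fitted statistics to any split."""
    return [[(v - m) / s for v, m, s in zip(row, means, scales)] for row in rows]

train_rows = [[1.0, 10.0], [3.0, 10.0]]   # toy training features
test_rows = [[5.0, 12.0]]                 # toy held-out features

means, scales = fit_standard_scaler(train_rows)   # statistics from training data only
train_scaled = transform(train_rows, means, scales)
test_scaled = transform(test_rows, means, scales)  # reuses the training statistics
```

The held-out rows never influence `means` or `scales`, which is what keeps validation and test statistics from leaking into training.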
#### Graph Creation (Candidate Generation)
Candidate edges (links between cells in adjacent time frames) were generated based on a custom set of geometric and morphological heuristics:
* **Distance Constraint:** Max distance between centroids is limited by `max_dist_link` (default 50.0).
* **Area Ratio Constraints:**
  * **Continuation (1-to-1):** `min_area_ratio_continuation` (0.8) to `max_area_ratio_continuation` (1.2).
  * **Division (1-to-2):** `min_area_ratio_division` (1.8) to `max_area_ratio_division` (2.2).
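
A rough sketch of how such heuristics could filter candidate edges is given below. The parameter names and defaults are taken from the list above, but several details are assumptions: distance is measured along `centroid_y` only, the area ratio is taken as earlier-frame area divided by later-frame area, and the function `candidate_edges` and the dictionary layout of each cell are hypothetical.

```python
def candidate_edges(cells_t, cells_t1,
                    max_dist_link=50.0,
                    min_area_ratio_continuation=0.8,
                    max_area_ratio_continuation=1.2,
                    min_area_ratio_division=1.8,
                    max_area_ratio_division=2.2):
    """Return (src_index, dst_index, kind) tuples between two adjacent frames."""
    edges = []
    for i, src in enumerate(cells_t):
        for j, dst in enumerate(cells_t1):
            # Assumption: 1D distance along the channel axis (centroid_y).
            if abs(src["centroid_y"] - dst["centroid_y"]) > max_dist_link:
                continue
            # Assumption: ratio = area in frame t / area in frame t+1.
            ratio = src["area"] / dst["area"]
            if min_area_ratio_continuation <= ratio <= max_area_ratio_continuation:
                edges.append((i, j, "continuation"))
            elif min_area_ratio_division <= ratio <= max_area_ratio_division:
                edges.append((i, j, "division"))
    return edges

frame_t = [{"centroid_y": 100.0, "area": 200.0}]
frame_t1 = [
    {"centroid_y": 105.0, "area": 190.0},  # similar size, close by
    {"centroid_y": 120.0, "area": 100.0},  # about half the size, close by
    {"centroid_y": 300.0, "area": 200.0},  # too far away
]
edges = candidate_edges(frame_t, frame_t1)
```

These edges are only candidates; the GNN then classifies each one as a true or false link.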
### Training Protocol
| Hyperparameter | Value | Description |
| :--- | :--- | :--- |
| **GNN Layers (`num_blocks`)** | 2 | Number of sequential EP-MPNN blocks. |
| **Hidden Channels** | 128 | Dimension for node and edge embeddings. |
| **Optimizer** | Adam | Standard optimization algorithm. |
| **Learning Rate** | 0.001 | Base learning rate. |
| **Weight Decay** | 0.0005 | L2 regularization applied to prevent overfitting. |
| **Batch Size** | 32 | Number of graphs processed per iteration. |
| **Evaluation Metric** | Validation Accuracy (`val_acc`) | Used for saving the `best_link_prediction_model.pt`. |
| **Early Stopping** | Yes | Monitors Validation Loss (`val_loss`) with a **patience of 10 epochs**. |
| **Max Epochs** | 500 | Maximum number of training epochs. |
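
The interaction between checkpointing and early stopping in the table can be sketched without any training framework. The sketch assumes that the best checkpoint is chosen by highest `val_acc` while early stopping counts epochs without a new lowest `val_loss`; the function `run_training` and the pre-recorded metric list are hypothetical stand-ins for a real training loop.

```python
def run_training(val_history, patience=10, max_epochs=500):
    """Replay (val_loss, val_acc) pairs and report where training would stop."""
    best_val_acc = float("-inf")
    best_val_loss = float("inf")
    best_epoch = None
    epochs_without_improvement = 0
    stopped_epoch = 0
    for epoch, (val_loss, val_acc) in enumerate(val_history[:max_epochs], start=1):
        stopped_epoch = epoch
        # Checkpoint criterion: a real loop would save best_link_prediction_model.pt here.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
        # Early-stopping criterion: validation loss must reach a new minimum.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    return {"stopped_epoch": stopped_epoch,
            "best_epoch": best_epoch,
            "best_val_acc": best_val_acc}

# Toy metrics; patience is shortened to 2 so the example stops quickly.
history = [(0.50, 0.80), (0.40, 0.85), (0.45, 0.90), (0.46, 0.88), (0.30, 0.95)]
result = run_training(history, patience=2)
```

In this toy run the loss stops improving after epoch 2, so training halts at epoch 4 while the saved checkpoint comes from epoch 3, where accuracy peaked. The two criteria can therefore select different epochs.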