---
library_name: project-lighter
tags:
- lighter
- model_hub_mixin
- pytorch_model_hub_mixin
language: en
license: apache-2.0
arxiv: 2501.09001
---
# Whole Body Segmentation
This model is a whole body segmentation model based on the SegResNet architecture. It was fine-tuned from CT-FM, a pre-trained CT foundation model.
## Running instructions
This guide demonstrates how to:
1. Load a pre-trained whole body segmentation model from HuggingFace Hub
2. Set up preprocessing and postprocessing pipelines
3. Perform sliding window inference on CT volumes
4. Save the segmentation results
The model segments 118 different anatomical structures from CT scans.
## Setup
Install requirements and import necessary packages
```python
# Install lighter_zoo package
%pip install lighter_zoo -U -qq
```
Note: you may need to restart the kernel to use updated packages.
```python
# Imports
import torch
from lighter_zoo import SegResNet
from monai.transforms import (
    Compose, LoadImage, EnsureType, Orientation,
    ScaleIntensityRange, CropForeground, Invert,
    Activations, AsDiscrete, KeepLargestConnectedComponent,
    SaveImage
)
from monai.inferers import SlidingWindowInferer
```
## Load Model
Download and initialize the pre-trained model from HuggingFace Hub
```python
# Load pre-trained model
model = SegResNet.from_pretrained(
    "project-lighter/whole_body_segmentation",
    force_download=True
)
```
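Before running inference on large volumes, it helps to move the model to a GPU when one is available and switch it to eval mode. A minimal sketch (the `nn.Conv3d` is only a stand-in; in practice apply this to the SegResNet loaded above):

```python
import torch
from torch import nn

# Pick a GPU when available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in module; replace with the SegResNet loaded above
model = nn.Conv3d(1, 2, kernel_size=3, padding=1)

# Move parameters to the device and disable training-time behaviors
model = model.to(device).eval()
```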
## Configure Inference
Set up sliding window inference for processing large volumes
```python
# Configure sliding window inference
inferer = SlidingWindowInferer(
    roi_size=[96, 160, 160],  # Size of patches to process
    sw_batch_size=2,          # Number of windows to process in parallel
    overlap=0.625,            # Overlap between windows (reduces boundary artifacts)
    mode="gaussian"           # Gaussian weighting for overlap regions
)
```
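With `overlap=0.625`, neighboring windows are shifted by `roi_size * (1 - overlap)` voxels along each axis, so every voxel is covered by several Gaussian-weighted predictions. A quick check of the resulting strides:

```python
# Stride between neighboring sliding windows along each axis
roi_size = [96, 160, 160]
overlap = 0.625

strides = [int(r * (1 - overlap)) for r in roi_size]
print(strides)  # [36, 60, 60]
```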
## Setup Processing Pipelines
Define preprocessing and postprocessing transforms
```python
# Preprocessing pipeline
preprocess = Compose([
    LoadImage(ensure_channel_first=True),  # Load image and ensure channel dimension
    EnsureType(),                          # Ensure correct data type
    Orientation(axcodes="SPL"),            # Standardize orientation
    # Scale intensity to [0, 1] range, clipping outliers
    ScaleIntensityRange(
        a_min=-1024,  # Min HU value
        a_max=2048,   # Max HU value
        b_min=0,      # Target min
        b_max=1,      # Target max
        clip=True     # Clip values outside range
    ),
    CropForeground()  # Remove background to reduce computation
])

# Postprocessing pipeline
postprocess = Compose([
    Activations(softmax=True),                   # Apply softmax to get probabilities
    AsDiscrete(argmax=True, dtype=torch.int32),  # Convert to class labels
    KeepLargestConnectedComponent(),             # Remove small disconnected regions
    Invert(transform=preprocess),                # Restore original space
    SaveImage(output_dir="./segmentations")      # Save the result
])
```
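Two of these transforms are easy to sanity-check numerically: `ScaleIntensityRange` is the clipped linear map `(x - a_min) / (a_max - a_min)`, and `Activations(softmax=True)` followed by `AsDiscrete(argmax=True)` collapses per-class probabilities into a single label map. A NumPy sketch of both (toy values, not MONAI itself):

```python
import numpy as np

# --- ScaleIntensityRange: clip to [a_min, a_max], rescale to [0, 1] ---
a_min, a_max = -1024, 2048
hu = np.array([-2000.0, -1024.0, 0.0, 3000.0])  # raw HU values
scaled = (np.clip(hu, a_min, a_max) - a_min) / (a_max - a_min)
print(scaled)  # approximately [0. 0. 0.333 1.]

# --- Activations(softmax) + AsDiscrete(argmax): logits -> label map ---
logits = np.array([[2.0, 0.1],   # class 0 scores for 2 voxels
                   [0.5, 3.0],   # class 1
                   [0.1, 0.2]])  # class 2
probs = np.exp(logits) / np.exp(logits).sum(axis=0, keepdims=True)
labels = probs.argmax(axis=0)
print(labels)  # [0 1]
```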
## Run Inference
Process an input CT scan and generate segmentation
```python
# Input path
input_path = "/home/suraj/Repositories/lighter-ct-fm/semantic-search-app/assets/scans/s0114.nii.gz"
# Preprocess input
input_tensor = preprocess(input_path)
# Run inference
with torch.no_grad():
    output = inferer(input_tensor.unsqueeze(dim=0), model)[0]
# Copy metadata from input
output.applied_operations = input_tensor.applied_operations
output.affine = input_tensor.affine
# Postprocess and save result
result = postprocess(output[0])
print("✅ Segmentation completed and saved")
```
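As a quick sanity check, you can count which of the 118 labels actually appear in the output. A minimal NumPy sketch (the toy `mask` stands in for the postprocessed result, e.g. `np.asarray(result)` from the cell above):

```python
import numpy as np

# Toy stand-in for the postprocessed label volume
mask = np.array([0, 0, 5, 5, 5, 42], dtype=np.int32)

labels, counts = np.unique(mask, return_counts=True)
for label, count in zip(labels, counts):
    print(f"label {label}: {count} voxels")
```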