Spaces:
Running
Running
""" | |
π MODEL DESIGNATION: | |
Figure2CNN is validated ONLY for RAMAN spectra input. | |
Any use for FTIR modeling is invalid and deprecated. | |
See milestone: @figure2cnn-raman-only-milestone | |
""" | |
import torch | |
import torch.nn as nn | |
class ResidualBlock1D(nn.Module): | |
""" | |
Basic 1-D residual block: | |
Conv1d -> ReLU -> Conv1d (+ skip connection). | |
If channel count changes, a 1x1 Conv aligns the skip path. | |
""" | |
def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3): | |
super().__init__() | |
padding = kernel_size // 2 | |
self.conv1 = nn.Conv1d(in_channels, out_channels, | |
kernel_size, padding=padding) | |
self.relu = nn.ReLU(inplace=True) | |
self.conv2 = nn.Conv1d(out_channels, out_channels, | |
kernel_size, padding=padding) | |
self.skip = ( | |
nn.Identity() | |
if in_channels == out_channels | |
else nn.Conv1d(in_channels, out_channels, kernel_size=1) | |
) | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
identity = self.skip(x) | |
out = self.relu(self.conv1(x)) | |
out = self.conv2(out) | |
return self.relu(out + identity) | |
def describe_model(self): | |
"""Print architecture and flattened size (for debug). """ | |
print(r"\n Model Summary:") | |
print(r" - Conv Block: 4 Layers") | |
print(f" - Input length: {self.flattened_size} after conv/pool") | |
print(f" - Classifier: {self.classifier}\n") | |
class ResNet1D(nn.Module): | |
""" | |
Lightweight 1-D ResNet for Raman spectra (length 500, single channel). | |
""" | |
def __init__(self, input_length: int = 500, num_classes: int = 2): | |
super().__init__() | |
# Three residual stages | |
self.stage1 = ResidualBlock1D(1, 16) | |
self.stage2 = ResidualBlock1D(16, 32) | |
self.stage3 = ResidualBlock1D(32, 64) | |
# Global aggregation + classifier | |
self.global_pool = nn.AdaptiveAvgPool1d(1) # -> [B, 64, 1] | |
self.fc = nn.Linear(64, num_classes) | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
x = self.stage1(x) | |
x = self.stage2(x) | |
x = self.stage3(x) | |
x = self.global_pool(x).squeeze(-1) # -> [B, 64] | |
return self.fc(x) | |