File size: 3,724 Bytes
e484a46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
This script preprocesses a dataset of spectra by resampling and labeling the data.

Functions:
- resample_spectrum(x, y, target_len): Resamples a spectrum to a fixed number of points.
- preprocess_dataset(...): Loads, resamples, and applies optional preprocessing steps:
  - baseline correction
  - Savitzky-Golay smoothing
  - min-max normalization

The script expects the dataset directory to contain text files representing spectra.
Each file is:
1. Listed using `list_txt_files()`
2. Labeled using `label_file()`
3. Loaded using `load_spectrum()`
4. Resampled and optionally cleaned
5. Returned as arrays suitable for ML training

Dependencies:
- numpy
- scipy.interpolate, scipy.signal
- sklearn.preprocessing
- scripts.discover_raman_files (custom: list_txt_files, label_file)
- scripts.plot_spectrum (custom: load_spectrum)
"""

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import numpy as np
from scipy.interpolate import interp1d
from scipy.signal import savgol_filter
from sklearn.preprocessing import minmax_scale
from scripts.discover_raman_files import list_txt_files, label_file
from scripts.plot_spectrum import load_spectrum

# Default resample target
TARGET_LENGTH = 500

# Optional preprocessing steps
def remove_baseline(y):
    """Subtract an order-2 polynomial baseline fitted over the point index.

    Args:
        y (array-like): Intensity values of a spectrum.

    Returns:
        np.ndarray: Baseline-corrected intensities (input minus fitted curve).
    """
    positions = np.arange(len(y))
    # Least-squares quadratic fit along the index axis approximates the
    # slowly-varying background; subtracting it leaves the peaks.
    baseline = np.polyval(np.polyfit(positions, y, deg=2), positions)
    return y - baseline

def normalize_spectrum(y):
    """Min-max normalize a spectrum to the [0, 1] range.

    Pure-numpy replacement for sklearn's ``minmax_scale``: pulling in
    scikit-learn for a one-line rescale is an unnecessarily heavy
    dependency for this operation. Matches sklearn's behavior for a
    constant input (zero range), which maps to all zeros instead of
    dividing by zero.

    Args:
        y (array-like): Intensity values of a spectrum.

    Returns:
        np.ndarray: Float array scaled so min -> 0 and max -> 1
        (all zeros if the input is constant).
    """
    y = np.asarray(y, dtype=float)
    y_min = y.min()
    span = y.max() - y_min
    if span == 0:
        # Constant spectrum: sklearn's minmax_scale treats a zero range
        # as 1, yielding an all-zero output; replicate that here.
        return np.zeros_like(y)
    return (y - y_min) / span

def smooth_spectrum(y, window_length=11, polyorder=2):
    """Denoise a spectrum with a Savitzky-Golay filter.

    Args:
        y (array-like): Intensity values of a spectrum.
        window_length (int): Size of the sliding fit window (must be odd
            and larger than ``polyorder``).
        polyorder (int): Order of the local polynomial fit.

    Returns:
        np.ndarray: Smoothed intensities, same length as the input.
    """
    smoothed = savgol_filter(y, window_length=window_length, polyorder=polyorder)
    return smoothed

def resample_spectrum(x, y, target_len=TARGET_LENGTH):
    """Resample a spectrum onto a uniform axis with a fixed point count.

    Args:
        x (array-like): Original (possibly non-uniform) x-axis values.
        y (array-like): Intensities corresponding to ``x``.
        target_len (int): Number of evenly spaced output points.

    Returns:
        np.ndarray: Intensities linearly interpolated at ``target_len``
        evenly spaced positions spanning [min(x), max(x)].
    """
    grid = np.linspace(min(x), max(x), target_len)
    # Linear interpolation; extrapolation guards against grid endpoints
    # landing a hair outside the input range due to float rounding.
    interpolator = interp1d(x, y, kind='linear', fill_value='extrapolate')
    return interpolator(grid)

def preprocess_dataset(
    dataset_dir,
    target_len=500,
    baseline_correction=False,
    apply_smoothing=False,
    normalize=False
):
    """
    Load, resample, and preprocess all valid spectra in the dataset.

    Args:
        dataset_dir (str): Path to the dataset
        target_len (int): Number of points to resample to
        baseline_correction (bool): Whether to apply baseline removal
        apply_smoothing (bool): Whether to apply Savitzky-Golay smoothing
        normalize (bool): Whether to apply min-max normalization

    Returns:
        X (np.ndarray): Preprocessed spectra
        y (np.ndarray): Corresponding labels
    """
    spectra, labels = [], []

    for file_path in list_txt_files(dataset_dir):
        # Skip files whose name doesn't yield a usable label.
        file_label = label_file(file_path)
        if file_label is None:
            continue

        x_raw, y_raw = load_spectrum(file_path)
        # Too few samples to interpolate meaningfully.
        if len(x_raw) < 10:
            continue

        # Always resample onto a fixed-length uniform grid first.
        spectrum = resample_spectrum(x_raw, y_raw, target_len=target_len)

        # Optional cleanup steps, applied in order: baseline -> smooth -> scale.
        if baseline_correction:
            spectrum = remove_baseline(spectrum)
        if apply_smoothing:
            spectrum = smooth_spectrum(spectrum)
        if normalize:
            spectrum = normalize_spectrum(spectrum)

        spectra.append(spectrum)
        labels.append(file_label)

    return np.array(spectra), np.array(labels)

# Optional: Run directly for testing
if __name__ == "__main__":
    # Smoke test against the default dataset location.
    X, y = preprocess_dataset(os.path.join("datasets", "rdwp"))

    print(f"X shape: {X.shape}")
    print(f"y shape: {y.shape}")
    print(f"Label distribution: {np.bincount(y)}")