File size: 4,633 Bytes
e484a46
 
 
86d081c
e484a46
 
 
 
 
 
 
ba24c6a
e484a46
 
ba24c6a
 
e484a46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba24c6a
 
e484a46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba24c6a
 
e484a46
ba24c6a
 
e484a46
ba24c6a
 
e484a46
ba24c6a
 
e484a46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86d081c
 
e484a46
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from pathlib import Path

import argparse
import warnings
import logging

import numpy as np
import torch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from scripts.preprocess_dataset import resample_spectrum, label_file
from models.registry import choices as model_choices, build as build_model



# =============================================
# ✅ Raman-Only Inference Script
# This script supports prediction on a single Raman spectrum (.txt file).
# FTIR inference has been deprecated and removed for scientific integrity.
# See: @raman-pipeline-focus-milestone
# =============================================


# Hide only FutureWarnings whose message mentions "weights_only=False" (the
# deprecation notice newer torch.load releases emit when loading checkpoints).
# All other warnings, including unrelated FutureWarnings, are still shown.
warnings.filterwarnings(
    action="ignore",
    category=FutureWarning,
    message=".*weights_only=False.*",
)


def load_raman_spectrum(filepath):
    """Read a two-column spectrum from a whitespace-delimited text file.

    A line is kept only when it splits into exactly two whitespace-separated
    fields and both parse as floats. Headers, comment lines, blank lines,
    unparsable rows and rows with any other field count are skipped silently.

    Args:
        filepath: Path to the ``.txt`` file; it is read as UTF-8.

    Returns:
        Tuple ``(x, y)`` of 1-D NumPy arrays, in file order. Both arrays are
        empty when no valid rows are found.
    """
    xs = []
    ys = []
    with open(filepath, encoding='utf-8') as handle:
        for raw in handle:
            # str.split() with no argument already ignores surrounding whitespace.
            fields = raw.split()
            if len(fields) != 2:
                continue
            try:
                x_val = float(fields[0])
                y_val = float(fields[1])
            except ValueError:
                # Non-numeric row (e.g. a header); drop the whole line.
                continue
            xs.append(x_val)
            ys.append(y_val)
    return np.array(xs), np.array(ys)


if __name__ == "__main__":
    # Command-line entry point: classify a single Raman spectrum file with a
    # registry-built model and report the prediction (stdout or --output file).
    parser = argparse.ArgumentParser(
        description="Run inference on a single Raman spectrum (.txt file)."
    )
    parser.add_argument(
        "--arch", type=str, default="figure2", choices=model_choices(),
        help="Model architecture (must match the provided weights)."
    )
    parser.add_argument(
        "--target-len", type=int, required=True,
        help="Target length to match model input"
    )
    parser.add_argument(
        "--input", required=True,
        help="Path to Raman .txt file."
    )
    parser.add_argument(
        "--model", default="random",
        help="Path to .pth model file, or specify 'random' to use untrained weights."
    )
    parser.add_argument(
        "--output", default=None,
        help="Where to write prediction result. If omitted, prints to stdout."
    )
    verbosity = parser.add_mutually_exclusive_group()
    verbosity.add_argument(
        "--quiet", action="store_true",
        help="Show only warnings and errors"
    )
    verbosity.add_argument(
        "--verbose", action="store_true",
        help="Show debug-level logging"
    )

    args = parser.parse_args()

    # argparse only checks that --target-len is an int. A zero or negative
    # length cannot describe a model input, so reject it as a usage error
    # before it reaches resample_spectrum / build_model.
    if args.target_len <= 0:
        parser.error(f"--target-len must be a positive integer, got {args.target_len}")

    # configure logging
    level = logging.INFO
    if args.verbose:
        level = logging.DEBUG
    elif args.quiet:
        level = logging.WARNING
    logging.basicConfig(level=level, format="%(levelname)s: %(message)s")

    try:
        # 1. Load & preprocess Raman spectrum
        if os.path.isdir(args.input):
            parser.error(f"Input must be a single Raman .txt file, got a directory: {args.input}")

        x_raw, y_raw = load_raman_spectrum(args.input)
        if len(x_raw) < 10:
            parser.error("Spectrum too short for inference.")

        data = resample_spectrum(x_raw, y_raw, target_len=args.target_len)
        # Add batch and channel dimensions -> (1, 1, target_len), assuming
        # resample_spectrum returns a 1-D array of length target_len.
        input_tensor = torch.tensor(data, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)

        # 2. Load model (via shared model registry)
        model = build_model(args.arch, args.target_len).to(DEVICE)
        if args.model != "random":
            # NOTE(security): without weights_only=True (not the default in
            # older torch releases), torch.load unpickles the file, which can
            # run arbitrary code. Only pass weight files from trusted sources.
            # map_location="cpu" lets checkpoints saved on GPU load anywhere;
            # load_state_dict then copies the tensors into the model's params.
            state = torch.load(args.model, map_location="cpu")
            model.load_state_dict(state)
        model.eval()

        # 3. Inference
        with torch.no_grad():
            logits = model(input_tensor)
            pred = torch.argmax(logits, dim=1).item()

        # 4. True label (falls back to "Unknown" when label_file raises FileNotFoundError)
        try:
            true_label = label_file(args.input)
            label_str = f"True Label: {true_label}"
        except FileNotFoundError:
            label_str = "True Label: Unknown"

        result = f"Predicted Label: {pred} {label_str}\nRaw Logits: {logits.tolist()}"

        # 5. Save to file, or print to stdout as promised by --output's help.
        if args.output:
            logging.info(result)
            # ensure parent dir exists (e.g., outputs/inference/)
            Path(args.output).parent.mkdir(parents=True, exist_ok=True)
            with open(args.output, "w", encoding="utf-8") as fout:
                fout.write(result)
            logging.info("Result saved to %s", args.output)
        else:
            # print (not logging) so the result reaches stdout even with --quiet.
            print(result)

        sys.exit(0)

    except Exception as e:
        # SystemExit (raised by parser.error and sys.exit) is not an Exception
        # subclass, so usage errors and the success exit are not caught here.
        # Include the traceback only when --verbose was requested.
        logging.error(e, exc_info=args.verbose)
        sys.exit(1)