Spaces:
Running
Running
import argparse
import logging
import os
import sys
import warnings
from pathlib import Path

import numpy as np
import torch

# Append the parent of this file's directory to sys.path so that the
# project-local `scripts` and `models` packages imported below can be found.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from scripts.preprocess_dataset import resample_spectrum, label_file  # noqa: E402
from models.registry import choices as model_choices, build as build_model  # noqa: E402

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# =============================================
# Raman-Only Inference Script
# Runs prediction on a single Raman spectrum (.txt file).
# FTIR inference has been deprecated and removed for scientific integrity.
# See: @raman-pipeline-focus-milestone
# =============================================
# Hide any FutureWarning whose message mentions `weights_only=False`
# (presumably the one torch.load emits) so it does not clutter CLI output.
# NOTE(review): that warning is a security advisory about unpickling
# checkpoints; only load .pth files from trusted sources.
warnings.filterwarnings(
    action="ignore",
    category=FutureWarning,
    message=r".*weights_only=False.*",
)
def load_raman_spectrum(filepath):
    """Load a two-column Raman spectrum from a text file.

    Each data line must contain exactly two numeric fields (x, y), separated
    by whitespace, commas, or semicolons. Lines that do not parse that way
    (headers, comments, blank lines, rows with a different number of columns)
    are skipped. Rows in which either value is NaN or infinite are also
    skipped, since ``float()`` happily parses "nan"/"inf" and those values
    would otherwise be passed on to resampling.

    Args:
        filepath: Path (``str`` or ``os.PathLike``) to a UTF-8 text file.

    Returns:
        Tuple ``(x, y)`` of 1-D float NumPy arrays of equal length, in file
        order. Both arrays are empty if no line could be parsed.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    x_vals, y_vals = [], []
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            # Treat "," and ";" like whitespace so CSV-style exports parse too.
            parts = line.replace(",", " ").replace(";", " ").split()
            if len(parts) != 2:
                continue
            try:
                x, y = float(parts[0]), float(parts[1])
            except ValueError:
                # Non-numeric line (e.g. a header); skip it.
                continue
            if not (np.isfinite(x) and np.isfinite(y)):
                continue
            x_vals.append(x)
            y_vals.append(y)
    return np.array(x_vals, dtype=float), np.array(y_vals, dtype=float)
def _build_parser():
    """Return the command-line parser for single-spectrum Raman inference."""
    parser = argparse.ArgumentParser(
        description="Run inference on a single Raman spectrum (.txt file)."
    )
    parser.add_argument(
        "--arch", type=str, default="figure2", choices=model_choices(),
        help="Model architecture (must match the provided weights).",
    )
    parser.add_argument(
        "--target-len", type=int, required=True,
        help="Target length to match model input",
    )
    parser.add_argument(
        "--input", required=True,
        help="Path to Raman .txt file.",
    )
    parser.add_argument(
        "--model", default="random",
        help="Path to .pth model file, or specify 'random' to use untrained weights.",
    )
    parser.add_argument(
        "--output", default=None,
        help="Where to write prediction result. If omitted, prints to stdout.",
    )
    verbosity = parser.add_mutually_exclusive_group()
    verbosity.add_argument(
        "--quiet", action="store_true",
        help="Show only warnings and errors",
    )
    verbosity.add_argument(
        "--verbose", action="store_true",
        help="Show debug-level logging",
    )
    return parser


def _configure_logging(args):
    """Configure root logging: DEBUG with --verbose, WARNING with --quiet, else INFO."""
    if args.verbose:
        level = logging.DEBUG
    elif args.quiet:
        level = logging.WARNING
    else:
        level = logging.INFO
    logging.basicConfig(level=level, format="%(levelname)s: %(message)s")


def _load_model(arch, target_len, weights):
    """Build ``arch`` from the model registry on DEVICE and put it in eval mode.

    ``weights`` is either a path to a saved state dict or the literal string
    ``"random"``, which keeps the freshly initialised (untrained) parameters.
    """
    model = build_model(arch, target_len).to(DEVICE)
    if weights != "random":
        # Load onto the CPU first for broad compatibility; load_state_dict then
        # copies the tensors into the model's (possibly CUDA) parameters.
        # NOTE(security): depending on the installed torch version, torch.load
        # may fully unpickle the file (weights_only=False). Only load trusted
        # checkpoints.
        state = torch.load(weights, map_location="cpu")
        model.load_state_dict(state)
    model.eval()
    return model


def main(argv=None):
    """Run single-spectrum Raman inference from the command line.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]`` (default None).

    Returns:
        Process exit status: 0 on success, 1 if an exception is raised while
        loading, running inference, or saving. Invalid arguments exit through
        ``parser.error`` (argparse exits with status 2).
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    # A non-positive length cannot be a model input size; fail fast with usage.
    if args.target_len <= 0:
        parser.error(f"--target-len must be a positive integer, got {args.target_len}")
    _configure_logging(args)

    try:
        # 1. Load & preprocess the Raman spectrum. parser.error raises
        #    SystemExit, which is not an Exception and so is not caught below.
        if os.path.isdir(args.input):
            parser.error(f"Input must be a single Raman .txt file, got a directory: {args.input}")
        x_raw, y_raw = load_raman_spectrum(args.input)
        if len(x_raw) < 10:
            parser.error("Spectrum too short for inference.")
        data = resample_spectrum(x_raw, y_raw, target_len=args.target_len)
        # Prepend two singleton dims. This gives shape (1, 1, target_len),
        # assuming resample_spectrum returns a 1-D array of length target_len.
        input_tensor = torch.tensor(data, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)

        # 2. Load the model via the shared registry.
        model = _load_model(args.arch, args.target_len, args.model)

        # 3. Inference.
        with torch.no_grad():
            logits = model(input_tensor)
        pred = torch.argmax(logits, dim=1).item()

        # 4. True label, when label_file can derive one for this path.
        try:
            label_str = f"True Label: {label_file(args.input)}"
        except FileNotFoundError:
            label_str = "True Label: Unknown"

        result = f"Predicted Label: {pred} {label_str}\nRaw Logits: {logits.tolist()}"

        # 5. Save to a file, or print to stdout as the --output help promises.
        #    logging writes to stderr and is silenced by --quiet, so it cannot
        #    be the only channel for the result.
        if args.output:
            logging.info(result)
            # Create the parent directory if needed (e.g. outputs/inference/).
            Path(args.output).parent.mkdir(parents=True, exist_ok=True)
            with open(args.output, "w", encoding="utf-8") as fout:
                fout.write(result)
            logging.info("Result saved to %s", args.output)
        else:
            print(result)
    except Exception as e:  # top-level CLI boundary: report and exit non-zero
        logging.error(e, exc_info=args.verbose)
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())