import os
import sys

# Make the project root importable when running this script directly
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import argparse
import json
import random
from datetime import datetime

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix
# Reproducibility: fix all RNG seeds and force deterministic cuDNN behavior
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
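# Note: deterministic cuDNN kernels plus a disabled autotuner (benchmark=False)
# trade some GPU speed for run-to-run reproducibility.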
# Project-specific imports
from scripts.preprocess_dataset import preprocess_dataset
from models.registry import choices as model_choices, build as build_model
# Argument parser for CLI usage | |
parser = argparse.ArgumentParser( | |
description="Run 10-fold CV on Raman data with optional preprocessing.") | |
parser.add_argument("--target-len", type=int, default=500) | |
parser.add_argument("--baseline", action="store_true") | |
parser.add_argument("--smooth", action="store_true") | |
parser.add_argument("--normalize", action="store_true") | |
parser.add_argument("--batch-size", type=int, default=16) | |
parser.add_argument("--epochs", type=int, default=10) | |
parser.add_argument("--learning-rate", type=float, default=1e-3) | |
parser.add_argument("--model", type=str, default="figure2", choices=model_choices()) | |
args = parser.parse_args() | |
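# Example invocation (script name assumed; adjust to the actual file path):
#   python train.py --model figure2 --baseline --smooth --normalize \
#       --epochs 20 --batch-size 32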
# Constants
DATASET_PATH = "datasets/rdwp"  # Raman-only dataset (RDWP)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_FOLDS = 10

# Ensure output dirs exist
os.makedirs("outputs", exist_ok=True)
os.makedirs("outputs/logs", exist_ok=True)
print("Preprocessing Configuration:") | |
print(f" Resample to : {args.target_len}") | |
print(f" Baseline Correct: {'β ' if args.baseline else 'β'}") | |
print(f" Smoothing : {'β ' if args.smooth else 'β'}") | |
print(f" Normalization : {'β ' if args.normalize else 'β'}") | |
# Load + Preprocess data
print("Loading and preprocessing data ...")
X, y = preprocess_dataset(
    DATASET_PATH,
    target_len=args.target_len,
    baseline_correction=args.baseline,
    apply_smoothing=args.smooth,
    normalize=args.normalize,
)
X, y = np.array(X, np.float32), np.array(y, np.int64)
print(f"Data loaded: {X.shape[0]} samples, {X.shape[1]} features each.")
print(f"Using model: {args.model}")
# Cross-validation: StratifiedKFold keeps class proportions roughly constant
# across folds, which matters for imbalanced class counts.
skf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=SEED)
fold_accuracies = []
all_conf_matrices = []
for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), 1):
    print(f"\nFold {fold}/{NUM_FOLDS}")
    X_train, X_val = X[train_idx], X[val_idx]
    y_train, y_val = y[train_idx], y[val_idx]
    train_loader = DataLoader(
        TensorDataset(torch.tensor(X_train, dtype=torch.float32),
                      torch.tensor(y_train, dtype=torch.long)),
        batch_size=args.batch_size, shuffle=True)
    # Batched validation loader; evaluation metrics do not depend on batch size
    val_loader = DataLoader(
        TensorDataset(torch.tensor(X_val, dtype=torch.float32),
                      torch.tensor(y_val, dtype=torch.long)),
        batch_size=args.batch_size)
    # Model selection via the shared registry
    model = build_model(args.model, args.target_len).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in range(args.epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs = inputs.unsqueeze(1).to(DEVICE)  # add channel dim: (B, 1, L)
            labels = labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
    # After the epoch loop (still inside the fold loop), print one line;
    # running_loss is the summed batch loss of the final epoch.
    print(f"Fold {fold} done. Final loss: {running_loss:.4f}")
    # Evaluation
    model.eval()
    all_true, all_pred = [], []
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.unsqueeze(1).to(DEVICE)
            labels = labels.to(DEVICE)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            all_true.extend(labels.cpu().numpy())
            all_pred.extend(predicted.cpu().numpy())
    acc = 100 * np.mean(np.array(all_true) == np.array(all_pred))
    fold_accuracies.append(acc)
    all_conf_matrices.append(confusion_matrix(all_true, all_pred))
    print(f"Fold {fold} Accuracy: {acc:.2f}%")
# Save a model checkpoint after the final fold (note: only the weights from
# the last fold's model are kept)
model_path = f"outputs/{args.model}_model.pth"
torch.save(model.state_dict(), model_path)
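# Reload sketch (assumes the same registry and args are available):
#   model = build_model(args.model, args.target_len)
#   model.load_state_dict(torch.load(model_path, map_location=DEVICE))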
# Summary
mean_acc, std_acc = np.mean(fold_accuracies), np.std(fold_accuracies)
print("\nCross-Validation Results:")
for i, a in enumerate(fold_accuracies, 1):
    print(f"Fold {i}: {a:.2f}%")
print(f"\nMean Accuracy: {mean_acc:.2f}% ± {std_acc:.2f}%")
print(f"Model saved to {model_path}")
# Save diagnostics
def save_diagnostics_log(fold_acc, confs, args_param, output_path):
    """Write per-fold metrics and the run configuration to a JSON log."""
    fold_metrics = [
        {"fold": i + 1, "accuracy": float(a), "confusion_matrix": c.tolist()}
        for i, (a, c) in enumerate(zip(fold_acc, confs))
    ]
    log = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "preprocessing": {
            "target_len": args_param.target_len,
            "baseline": args_param.baseline,
            "smooth": args_param.smooth,
            "normalize": args_param.normalize,
        },
        "fold_metrics": fold_metrics,
        "overall": {
            "mean_accuracy": float(np.mean(fold_acc)),
            "std_accuracy": float(np.std(fold_acc)),
            "num_folds": len(fold_acc),
            "batch_size": args_param.batch_size,
            "epochs": args_param.epochs,
            "learning_rate": args_param.learning_rate,
            "device": str(DEVICE),
        },
    }
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(log, f, indent=2)
    print(f"Diagnostics written to {output_path}")
log_path = f"outputs/logs/raman_{args.model}_diagnostics.json"
save_diagnostics_log(fold_accuracies, all_conf_matrices, args, log_path)
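# Quick check of the written log (path shown for the default model "figure2"):
#   python -c "import json; print(json.load(open('outputs/logs/raman_figure2_diagnostics.json'))['overall'])"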