# train_model.py
"""Persist a trained model's parameters to disk."""


def save_model(model, path="mnist_model.pth"):
    """Save ``model.state_dict()`` to ``path`` using ``torch.save``.

    Only the state dict (parameters and buffers) is written, not the module
    object itself, so reloading means constructing the same model class first
    and calling ``load_state_dict`` on it with the loaded dict.

    Args:
        model: The trained model; anything exposing ``state_dict()``
            (normally a ``torch.nn.Module``).
        path: Destination file path (str or path-like). Defaults to
            ``"mnist_model.pth"``, the filename the original script used.
    """
    # Local import keeps importing this module free of a hard torch dependency;
    # torch is only needed when a save actually happens.
    import torch

    torch.save(model.state_dict(), path)