|
---
license: mit
tags:
- molecular-property-prediction
- graph-neural-network
- chemistry
- pytorch
- molecular-dynamics
- force-fields
datasets:
- qm9
- spice
- pfas
metrics:
- mse
- mae
pipeline_tag: graph-ml
library_name: moml
---
|
|
|
# MoML-CA: Molecular Machine Learning for Coarse-grained Applications |
|
|
|
This repository contains the **DJMGNN** (Dense Jump Multi-Graph Neural Network) models from the MoML-CA project, designed for molecular property prediction and coarse-grained molecular modeling applications. |
|
|
|
## Models Available
|
|
|
### 1. Base Model (`base_model/`) |
|
- **Pre-trained DJMGNN** trained jointly on multiple molecular datasets
|
- **Datasets**: QM9, SPICE, PFAS |
|
- **Task**: General molecular property prediction |
|
- **Use case**: Starting point for transfer learning or direct molecular property prediction |
|
|
|
### 2. Fine-tuned Model (`finetuned_model/`) |
|
- **PFAS-specialized DJMGNN** model fine-tuned for PFAS molecular properties |
|
- **Base**: Initialized from the base model above
|
- **Specialization**: Per- and polyfluoroalkyl substances (PFAS) |
|
- **Use case**: Optimized for PFAS molecular property prediction |
|
|
|
## Architecture
|
|
|
**DJMGNN** (Dense Jump Multi-Graph Neural Network) features: |
|
- **Multi-task learning**: Simultaneous node-level and graph-level predictions |
|
- **Jump connections**: Enhanced information flow between layers |
|
- **Dense blocks**: Improved gradient flow and feature reuse |
|
- **Supernode aggregation**: Global graph representation |
|
- **RBF features**: Radial basis function encoding for distance information (see the sketch below)
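
To make the RBF encoding concrete, here is a minimal sketch of a Gaussian radial basis expansion of interatomic distances. The helper name `rbf_expand`, the cutoff, and the basis widths are illustrative assumptions rather than the exact implementation in `moml`; only the basis size mirrors the `rbf_K=32` argument used in the Usage section below.

```python
import torch

def rbf_expand(dist: torch.Tensor, K: int = 32, cutoff: float = 5.0) -> torch.Tensor:
    """Expand interatomic distances into K Gaussian radial basis functions (sketch)."""
    centers = torch.linspace(0.0, cutoff, K)       # evenly spaced RBF centers
    gamma = 1.0 / (centers[1] - centers[0]) ** 2   # width tied to the center spacing
    return torch.exp(-gamma * (dist.unsqueeze(-1) - centers) ** 2)

# One distance per edge becomes one K-dimensional feature vector per edge.
dist = torch.tensor([0.96, 1.41, 2.30])  # hypothetical distances in Angstrom
edge_rbf = rbf_expand(dist)              # shape: [3, 32]
```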
|
|
|
### Architecture Details |
|
- **Hidden Dimensions**: 128 |
|
- **Number of Blocks**: 3-4 |
|
- **Layers per Block**: 6 |
|
- **Input Node Dimensions**: 11-29 (depending on featurization) |
|
- **Node Output Dimensions**: 3 (forces/properties per atom) |
|
- **Graph Output Dimensions**: 19 (molecular descriptors) |
|
- **Energy Output Dimensions**: 1 (total energy) |
|
|
|
## Training Details
|
|
|
### Datasets |
|
- **QM9**: ~130k small organic molecules with quantum mechanical properties |
|
- **SPICE**: Molecular dynamics trajectories with forces and energies |
|
- **PFAS**: Per- and polyfluoroalkyl substances dataset with specialized descriptors |
|
|
|
### Training Configuration |
|
- **Optimizer**: Adam |
|
- **Learning Rate**: 3e-5 (fine-tuning), 1e-3 (base training) |
|
- **Batch Size**: 4-8 (node tasks), 8-32 (graph tasks) |
|
- **Loss Functions**: MSE for regression, weighted multi-task loss |
|
- **Regularization**: Dropout (0.2), gradient clipping (a minimal loop sketch follows this list)
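
As a concrete illustration of these settings, the sketch below shows a minimal fine-tuning step using the model constructed in the Usage section. It is an assumption, not the project's training script: `loader`, the target `batch.y`, the single-task MSE objective, and the clipping `max_norm` are placeholders.

```python
import torch

optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)  # 1e-3 for base training
loss_fn = torch.nn.MSELoss()

model.train()
for batch in loader:  # assumed: a torch_geometric DataLoader with batch size 4-8
    optimizer.zero_grad()
    out = model(
        x=batch.x,
        edge_index=batch.edge_index,
        edge_attr=batch.edge_attr,
        batch=batch.batch,
    )
    # Single-task example; the full objective is a weighted multi-task loss.
    loss = loss_fn(out["graph_pred"], batch.y)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # assumed value
    optimizer.step()
```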
|
|
|
## Usage
|
|
|
### Loading the Base Model |
|
|
|
```python
import torch
from moml.models.mgnn.djmgnn import DJMGNN

# Initialize model architecture
model = DJMGNN(
    in_node_dim=29,  # Adjust based on your featurization
    in_edge_dim=0,
    hidden_dim=128,
    n_blocks=4,
    layers_per_block=6,
    node_output_dims=3,
    graph_output_dims=19,
    energy_output_dims=1,
    jk_mode="attention",
    dropout=0.2,
    use_supernode=True,
    use_rbf=True,
    rbf_K=32,
)

# Load base model checkpoint
checkpoint = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/saketh11/MoML-CA/resolve/main/base_model/pytorch_model.pt"
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
```
|
|
|
### Loading the Fine-tuned Model |
|
|
|
```python
# Same architecture setup as above, then:
checkpoint = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/saketh11/MoML-CA/resolve/main/finetuned_model/pytorch_model.pt"
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
```
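
If you prefer to download and cache the checkpoint through the Hugging Face Hub client instead of `torch.hub`, the following sketch uses the standard `huggingface_hub` API (swap the subfolder to `base_model` for the pre-trained weights):

```python
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="saketh11/MoML-CA",
    filename="pytorch_model.pt",
    subfolder="finetuned_model",  # or "base_model" for the pre-trained checkpoint
)
checkpoint = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
```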
|
|
|
### Making Predictions |
|
|
|
```python
# Assuming you have a molecular graph 'data' (torch_geometric.data.Data)
with torch.no_grad():
    output = model(
        x=data.x,
        edge_index=data.edge_index,
        edge_attr=data.edge_attr,
        batch=data.batch,
    )

# Extract predictions
node_predictions = output["node_pred"]      # Per-atom properties/forces
graph_predictions = output["graph_pred"]    # Molecular descriptors
energy_predictions = output["energy_pred"]  # Total energy
```
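
If you do not yet have featurized graphs, the snippet below shows the tensor shapes the model expects, using a hypothetical random example; real node features should come from the featurization pipeline in the MoML-CA repository.

```python
import torch
from torch_geometric.data import Batch, Data

# Hypothetical molecule: 3 atoms, 2 bonds stored in both directions, 29-dim node features.
x = torch.randn(3, 29)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
data = Batch.from_data_list([Data(x=x, edge_index=edge_index)])

with torch.no_grad():
    output = model(x=data.x, edge_index=data.edge_index,
                   edge_attr=data.edge_attr, batch=data.batch)
```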
|
|
|
## Performance
|
|
|
### Base Model |
|
- Trained on diverse molecular datasets for robust generalization |
|
- Multi-task learning across node and graph-level properties |
|
- Suitable for transfer learning to specialized domains |
|
|
|
### Fine-tuned Model |
|
- Specialized for PFAS molecular properties |
|
- Improved accuracy on fluorinated compounds |
|
- Optimized for environmental and toxicological applications |
|
|
|
## Applications
|
|
|
- **Molecular Property Prediction**: HOMO/LUMO, dipole moments, polarizability |
|
- **Force Field Development**: Atomic forces and energies for MD simulations |
|
- **Environmental Chemistry**: PFAS behavior and properties |
|
- **Drug Discovery**: Molecular screening and optimization |
|
- **Materials Science**: Polymer and surface properties |
|
|
|
## Links
|
|
|
- **GitHub Repository**: [SAKETH11111/MoML-CA](https://github.com/SAKETH11111/MoML-CA) |
|
- **Documentation**: See repository README and docs/ |
|
- **Issues**: Report bugs and request features on GitHub |
|
|
|
## License
|
|
|
This project is licensed under the MIT License. See the LICENSE file for details. |
|
|
|
## Contributing
|
|
|
Contributions are welcome! Please see the contributing guidelines in the GitHub repository. |
|
|
|
--- |
|
|
|
*For questions or support, please open an issue in the [GitHub repository](https://github.com/SAKETH11111/MoML-CA).* |