license: mit
tags:
  - molecular-property-prediction
  - graph-neural-network
  - chemistry
  - pytorch
datasets:
  - qm9
  - spice
  - pfas
metrics:
  - mse
  - mae
pipeline_tag: graph-ml
library_name: moml

djmgnn-pfas-finetuned

Model Description

This is a fine-tuned DJMGNN (Dense Jump Multi-Graph Neural Network) model for molecular property prediction. It takes a graph representation of a molecule and produces three kinds of output through separate heads: node-level predictions, graph-level (molecule-level) predictions, and a molecular energy.

Architecture

  • Model Type: Dense Jump Multi-Graph Neural Network (DJMGNN)
  • Framework: PyTorch
  • Library: MoML (Molecular Machine Learning)
  • Task: Molecular Property Prediction

Model Architecture Details

  • Hidden Dimensions: 128
  • Number of Blocks: 3
  • Layers per Block: 6
  • Input Node Dimensions: 11
  • Input Edge Dimensions: 0
  • Node Output Dimensions: 3
  • Graph Output Dimensions: 19
  • Energy Output Dimensions: 1
  • Jumping Knowledge Mode: cat
  • Dropout Rate: 0.2
  • Uses Supernode: True
  • Uses RBF Features: True
  • RBF K: 32
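This card does not document the RBF and supernode options any further. The sketch below is only a rough illustration and assumes the common conventions: "RBF K" is the number of Gaussian basis functions used to expand an interatomic distance (centers evenly spaced up to a hypothetical 5 Å cutoff), and the supernode is a virtual node linked in both directions to every atom. MoML's actual implementation may differ. `gaussian_rbf` and `add_supernode` are hypothetical helpers, not part of the library.

```python
import math


def gaussian_rbf(distance, k=32, cutoff=5.0):
    """Expand one interatomic distance into k Gaussian basis values.

    Assumption (not from MoML): centers are evenly spaced on [0, cutoff]
    and the Gaussian width is tied to the spacing between centers.
    """
    if k < 2:
        raise ValueError("k must be at least 2")
    spacing = cutoff / (k - 1)
    gamma = 1.0 / (2.0 * spacing ** 2)
    centers = [i * spacing for i in range(k)]
    return [math.exp(-gamma * (distance - mu) ** 2) for mu in centers]


def add_supernode(num_nodes, edge_index):
    """Append a virtual node connected to every existing node in both directions.

    edge_index is [sources, targets], mirroring PyTorch Geometric's COO layout.
    Returns the new node count and a new edge_index; the input is not modified.
    """
    sources, targets = list(edge_index[0]), list(edge_index[1])
    super_idx = num_nodes  # the supernode takes the next free index
    for i in range(num_nodes):
        sources += [i, super_idx]
        targets += [super_idx, i]
    return num_nodes + 1, [sources, targets]


# Usage with K = 32, matching the configuration above
features = gaussian_rbf(1.0, k=32)
print(len(features))  # 32

n, edges = add_supernode(3, [[0, 1], [1, 0]])
print(n, len(edges[0]))  # 4 8
```

A Gaussian expansion turns a single scalar distance into a smooth, fixed-length feature vector that a network can weight. A supernode gives every atom a two-hop path to every other atom, which helps pass information across the whole molecule.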

Training Details

Datasets

The model was trained on the following datasets:

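  • QM9: Quantum-mechanical (DFT) properties of roughly 134k small organic molecules with up to nine heavy atoms (C, N, O, F)
  • SPICE: DFT-computed energies and forces for conformations of small molecules, peptides, and molecular dimers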
  • PFAS: Per- and polyfluoroalkyl substances dataset

Training Configuration

batch_size: 32
early_stopping: true
epochs: 100
learning_rate: 0.001
optimizer: Adam
patience: 10
validation_split: 0.2
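The configuration above lists a validation split, early stopping, and a patience value, but it does not include the training script. Below is a minimal, framework-agnostic sketch of how such settings are commonly applied. `split_indices`, `train_with_early_stopping`, and the two callbacks are hypothetical names, not MoML functions. In a real PyTorch loop, the Adam optimizer (learning rate 0.001) and the batch size of 32 would live inside the `train_one_epoch` callback.

```python
import math
import random


def split_indices(n, validation_split=0.2, seed=0):
    """Shuffle indices 0..n-1 and hold out a fraction for validation."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    n_val = int(round(n * validation_split))
    return indices[n_val:], indices[:n_val]  # (train, validation)


def train_with_early_stopping(train_one_epoch, evaluate, epochs=100, patience=10):
    """Run up to `epochs` epochs, stopping after `patience` consecutive epochs
    without a strict improvement in validation loss.

    Returns (best_epoch, best_loss, last_epoch_run).
    """
    best_loss = math.inf
    best_epoch = -1
    last_epoch = -1
    stale_epochs = 0
    for epoch in range(epochs):
        last_epoch = epoch
        train_one_epoch(epoch)
        val_loss = evaluate(epoch)
        if val_loss < best_loss:
            best_loss, best_epoch = val_loss, epoch
            stale_epochs = 0  # a real loop would save a checkpoint here
        else:
            stale_epochs += 1
            if stale_epochs >= patience:
                break
    return best_epoch, best_loss, last_epoch


# Usage with made-up validation losses that plateau after epoch 4
train_idx, val_idx = split_indices(100, validation_split=0.2)
print(len(train_idx), len(val_idx))  # 80 20

fake_losses = [1.0, 0.9, 0.8, 0.7, 0.6] + [0.6] * 95
result = train_with_early_stopping(
    train_one_epoch=lambda epoch: None,
    evaluate=lambda epoch: fake_losses[epoch],
)
print(result)  # (4, 0.6, 14)
```

With patience 10, a run whose validation loss stops improving at epoch 4 ends after epoch 14, well before the 100-epoch limit. The weights you keep are the ones from the best epoch.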

Usage

Loading the Model

import torch
from moml.models.mgnn.djmgnn import DJMGNN

# Load the model
model = DJMGNN(
    in_node_dim=11,
    in_edge_dim=0,
    hidden_dim=128,
    n_blocks=3,
    layers_per_block=6,
    node_output_dims=3,
    graph_output_dims=19,
    energy_output_dims=1,
    jk_mode="cat",
    dropout=0.2,
    use_supernode=True,
    use_rbf=True,
    rbf_K=32
)

# Load the checkpoint
checkpoint = torch.load("path/to/pytorch_model.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

Making Predictions

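# Assuming you have a molecular graph `data` (a torch_geometric.data.Data object).
# A standalone Data object returns None for data.batch (and for data.edge_attr if
# no edge features were set); graphs batched by torch_geometric.loader.DataLoader
# carry a `batch` vector that maps each node to its graph.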
with torch.no_grad():
    output = model(
        x=data.x,
        edge_index=data.edge_index,
        edge_attr=data.edge_attr,
        batch=data.batch
    )
    
    # Extract predictions
    node_predictions = output["node_pred"]      # Node-level predictions
    graph_predictions = output["graph_pred"]    # Graph-level predictions
    energy_predictions = output["energy_pred"]  # Energy predictions
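The card's metadata lists MSE and MAE as metrics but reports no values. The sketch below computes both for each graph-level target. It assumes predictions and ground-truth targets have already been converted to nested Python lists, for example with `graph_predictions.tolist()`. Inside PyTorch itself, `torch.nn.functional.mse_loss` and `torch.nn.functional.l1_loss` compute the same quantities. This card does not say whether targets were normalized during training, so compare values in the units the model was trained on. `per_target_errors` is a hypothetical helper.

```python
def per_target_errors(preds, targets):
    """Return (mse, mae) lists with one value per target column.

    preds and targets are lists of rows: one row per molecule,
    one column per graph-level target (19 for this model's graph head).
    """
    if not preds or len(preds) != len(targets):
        raise ValueError("preds and targets must be non-empty and the same length")
    n_targets = len(preds[0])
    sq_sums = [0.0] * n_targets
    abs_sums = [0.0] * n_targets
    for pred_row, target_row in zip(preds, targets):
        for j, (p, t) in enumerate(zip(pred_row, target_row)):
            err = p - t
            sq_sums[j] += err * err
            abs_sums[j] += abs(err)
    n = len(preds)
    return [s / n for s in sq_sums], [s / n for s in abs_sums]


# Usage with two molecules and two illustrative targets
mse, mae = per_target_errors([[1.0, 2.0], [3.0, 4.0]], [[0.0, 2.0], [3.0, 6.0]])
print(mse, mae)  # [0.5, 2.0] [0.5, 1.0]
```

Reporting errors per target matters because the graph-level targets can live on very different scales. A single averaged MSE across all 19 columns would be dominated by the largest-valued ones.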

Model Performance

This model was fine-tuned from a base DJMGNN model on PFAS-specific data.

Citation

If you use this model in your research, please cite:

@misc{djmgnn_model,
  title={DJMGNN: Dense Jump Multi-Graph Neural Network for Molecular Property Prediction},
  author={Your Name},
  year={2024},
  url={https://github.com/SAKETH11111/MoML-CA}
}

License

This model is released under the MIT License.

Contact

For questions or issues, please contact sakethbaddam10@gmail.com or open an issue in the GitHub repository.