Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 25 additions & 23 deletions five-et/README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
# Using the Models

This directory contains the five equivariant transformer models described in (insert reference when available).
They were created with TorchMD-Net 0.2.2. They might work with later versions as well, but that is not guaranteed.
They were created with [TorchMD-Net](https://github.com/openmm/spice-dataset/releases/download/1.1/SPICE.hdf5) 0.2.2.
They might work with later versions as well, but that is not guaranteed.

To use them, first install TorchMD-Net by following the instructions at https://github.com/torchmd/torchmd-net.
That involves checking out the source code, creating a conda environment using the provided environment file,
and running `pip` to install it into the environment.
To use them, first install TorchMD-Net with the provided `environment.yml`:

```bash
mamba env create -f environment.yml
conda activate spice-models
```

The files ending in `.ckpt` are checkpoint files containing the trained models. They can be loaded like this:

Expand Down Expand Up @@ -43,27 +47,25 @@ energy, forces = model.forward(types, pos)

# Training New Models

If you want to train new models on the same data, follow these steps.

1. Install OpenFF-Toolkit into the conda environment by executing the command
If you want to train new models on the same data, follow these steps:

1. Create an environment containing torchmd-net 0.2.2, its dependencies, and the openff-toolkit if you haven't done so already:
```bash
mamba env create -f environment.yml
conda activate spice-models
```
conda install -c conda-forge openff-toolkit=0.10.6
2. Run the `createSpiceDataset.py` script, which will download and convert the SPICE dataset to the format used by TorchMD-Net:
```bash
python createSpiceDataset.py
```

2. Download the `SPICE.hdf5` file from https://github.com/openmm/spice-dataset/releases/tag/1.1 and place it
in this directory.
3. Run the `createSpiceDataset.py` script, which converts the dataset to the format used by TorchMD-Net. It
generates a new file `SPICE-processed.hdf5` to use for training.
4. Run the `train.py` script provided by TorchMD-Net. The command will be something like

It generates a new file `SPICE-processed.hdf5` to use for training.
4. Run the `train.py` script to create a model:
```bash
MODELNAME='model1'; mkdir $MODELNAME ; python train.py --conf hparams.yaml --log-dir $MODELNAME
```
python <path to torchmd-net>/scripts/train.py --conf hparams.yaml
```

The file `hparams.yaml` contains the configuration used for training the models. All models here used identical settings
except that `seed` was set to a different value for each one (the numbers 1 through 5). Be sure to use TorchMD-Net 0.2.2,
since later versions made incompatible changes to some of the parameter definitions. Note that although the file
specifies `num_epochs: 1000`, training was halted after 24 hours (when the training job reached the end of its allocated
time). This corresponded to 118 epochs. You can edit the file to try different hyperparameters, or override them with
command line arguments to `train.py`.
except that `seed` was set to a different value for each one (the numbers 1 through 5).

Note that this script uses TorchMD-Net 0.2.2, since later versions made incompatible changes to some of the parameter definitions.

You can edit the file to try different hyperparameters, or override them with command line arguments to `train.py`.
7 changes: 7 additions & 0 deletions five-et/createSpiceDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
('Li', 1): 14, ('Mg', 2): 15, ('N', -1): 16, ('N', 0): 17, ('N', 1): 18, ('Na', 1): 19, ('O', -1): 20,
('O', 0): 21, ('O', 1): 22, ('P', 0): 23, ('P', 1): 24, ('S', -1): 25, ('S', 0): 26, ('S', 1): 27}

# Download the SPICE dataset if it is not already available
import os
if not os.path.exists('SPICE.hdf5'):
print('Downloading SPICE dataset...')
import urllib.request
urllib.request.urlretrieve("https://github.com/openmm/spice-dataset/releases/download/1.1/SPICE.hdf5", "SPICE.hdf5")

infile = h5py.File('SPICE.hdf5')

# First pass: group the samples by total number of atoms.
Expand Down
28 changes: 28 additions & 0 deletions five-et/environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: spice-models
channels:
- conda-forge
dependencies:
# torchmd-net dependencies
- ase
- h5py
- matplotlib
- nnpops==0.2
- pip
- pytorch==1.11.0
- pytorch_cluster==1.5.9
- pytorch_geometric==2.0.3
- pytorch_scatter==2.0.8
- pytorch_sparse==0.6.10
- pytorch-lightning==1.6.3
- torchmetrics==0.8.2
- tqdm
# dev tools
- flake8
- pytest
- psutil
# spice data reformatting
- openff-toolkit<0.11.0
# torchmd-net
- pip:
- git+https://github.com/torchmd/torchmd-net.git@0.2.2

3 changes: 1 addition & 2 deletions five-et/hparams.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ dataset_arg: null
dataset_root: SPICE-processed.hdf5
derivative: true
distance_influence: both
distributed_backend: ddp
early_stopping_patience: 20
ema_alpha_dy: 1.0
ema_alpha_y: 1.0
Expand All @@ -37,7 +36,7 @@ max_z: 28
model: equivariant-transformer
neighbor_embedding: true
ngpus: -1
num_epochs: 1000
num_epochs: 118
num_heads: 8
num_layers: 6
num_nodes: 1
Expand Down
170 changes: 170 additions & 0 deletions five-et/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import sys
import os
import argparse
import logging
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.strategies.ddp import DDPStrategy
from torchmdnet.module import LNNP
from torchmdnet import datasets, priors, models
from torchmdnet.data import DataModule
from torchmdnet.models import output_modules
from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping
from torchmdnet.utils import LoadFromFile, LoadFromCheckpoint, save_argparse, number


def get_args():
# fmt: off
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('--load-model', action=LoadFromCheckpoint, help='Restart training using a model checkpoint') # keep first
parser.add_argument('--conf', '-c', type=open, action=LoadFromFile, help='Configuration yaml file') # keep second
parser.add_argument('--num-epochs', default=300, type=int, help='number of epochs')
parser.add_argument('--batch-size', default=32, type=int, help='batch size')
parser.add_argument('--inference-batch-size', default=None, type=int, help='Batchsize for validation and tests.')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
parser.add_argument('--lr-patience', type=int, default=10, help='Patience for lr-schedule. Patience per eval-interval of validation')
parser.add_argument('--lr-metric', type=str, default='val_loss', choices=['train_loss', 'val_loss'], help='Metric to monitor when deciding whether to reduce learning rate')
parser.add_argument('--lr-min', type=float, default=1e-6, help='Minimum learning rate before early stop')
parser.add_argument('--lr-factor', type=float, default=0.8, help='Factor by which to multiply the learning rate when the metric stops improving')
parser.add_argument('--lr-warmup-steps', type=int, default=0, help='How many steps to warm-up over. Defaults to 0 for no warm-up')
parser.add_argument('--early-stopping-patience', type=int, default=30, help='Stop training after this many epochs without improvement')
parser.add_argument('--reset-trainer', type=bool, default=False, help='Reset training metrics (e.g. early stopping, lr) when loading a model checkpoint')
parser.add_argument('--weight-decay', type=float, default=0.0, help='Weight decay strength')
parser.add_argument('--ema-alpha-y', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of y')
parser.add_argument('--ema-alpha-dy', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of dy')
parser.add_argument('--ngpus', type=int, default=-1, help='Number of GPUs, -1 use all available. Use CUDA_VISIBLE_DEVICES=1, to decide gpus')
parser.add_argument('--num-nodes', type=int, default=1, help='Number of nodes')
parser.add_argument('--precision', type=int, default=32, choices=[16, 32], help='Floating point precision')
parser.add_argument('--log-dir', '-l', default='/tmp/logs', help='log file')
parser.add_argument('--splits', default=None, help='Npz with splits idx_train, idx_val, idx_test')
parser.add_argument('--train-size', type=number, default=None, help='Percentage/number of samples in training set (None to use all remaining samples)')
parser.add_argument('--val-size', type=number, default=0.05, help='Percentage/number of samples in validation set (None to use all remaining samples)')
parser.add_argument('--test-size', type=number, default=0.1, help='Percentage/number of samples in test set (None to use all remaining samples)')
parser.add_argument('--test-interval', type=int, default=10, help='Test interval, one test per n epochs (default: 10)')
parser.add_argument('--save-interval', type=int, default=10, help='Save interval, one save per n epochs (default: 10)')
parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
parser.add_argument('--num-workers', type=int, default=4, help='Number of workers for data prefetch')
parser.add_argument('--redirect', type=bool, default=False, help='Redirect stdout and stderr to log_dir/log')

# dataset specific
parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset')
parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")')
parser.add_argument('--dataset-arg', default=None, type=str, help='Additional dataset argument, e.g. target property for QM9 or molecule for MD17')
parser.add_argument('--coord-files', default=None, type=str, help='Custom coordinate files glob')
parser.add_argument('--embed-files', default=None, type=str, help='Custom embedding files glob')
parser.add_argument('--energy-files', default=None, type=str, help='Custom energy files glob')
parser.add_argument('--force-files', default=None, type=str, help='Custom force files glob')
parser.add_argument('--energy-weight', default=1.0, type=float, help='Weighting factor for energies in the loss function')
parser.add_argument('--force-weight', default=1.0, type=float, help='Weighting factor for forces in the loss function')

# model architecture
parser.add_argument('--model', type=str, default='graph-network', choices=models.__all__, help='Which model to train')
parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model')
parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use')

# architectural args
parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge')
parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state')
parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension')
parser.add_argument('--num-layers', type=int, default=6, help='Number of interaction layers in the model')
parser.add_argument('--num-rbf', type=int, default=64, help='Number of radial basis functions in model')
parser.add_argument('--activation', type=str, default='silu', choices=list(act_class_mapping.keys()), help='Activation function')
parser.add_argument('--rbf-type', type=str, default='expnorm', choices=list(rbf_class_mapping.keys()), help='Type of distance expansion')
parser.add_argument('--trainable-rbf', type=bool, default=False, help='If distance expansion functions should be trainable')
parser.add_argument('--neighbor-embedding', type=bool, default=False, help='If a neighbor embedding should be applied before interactions')
parser.add_argument('--aggr', type=str, default='add', help='Aggregation operation for CFConv filter output. Must be one of \'add\', \'mean\', or \'max\'')

# Transformer specific
parser.add_argument('--distance-influence', type=str, default='both', choices=['keys', 'values', 'both', 'none'], help='Where distance information is included inside the attention')
parser.add_argument('--attn-activation', default='silu', choices=list(act_class_mapping.keys()), help='Attention activation function')
parser.add_argument('--num-heads', type=int, default=8, help='Number of attention heads')

# other args
parser.add_argument('--derivative', default=False, type=bool, help='If true, take the derivative of the prediction w.r.t coordinates')
parser.add_argument('--cutoff-lower', type=float, default=0.0, help='Lower cutoff in model')
parser.add_argument('--cutoff-upper', type=float, default=5.0, help='Upper cutoff in model')
parser.add_argument('--atom-filter', type=int, default=-1, help='Only sum over atoms with Z > atom_filter')
parser.add_argument('--max-z', type=int, default=100, help='Maximum atomic number that fits in the embedding matrix')
parser.add_argument('--max-num-neighbors', type=int, default=32, help='Maximum number of neighbors to consider in the network')
parser.add_argument('--standardize', type=bool, default=False, help='If true, multiply prediction by dataset std and add mean')
parser.add_argument('--reduce-op', type=str, default='add', choices=['add', 'mean'], help='Reduce operation to apply to atomic predictions')
# fmt: on

args = parser.parse_args()

if args.redirect:
sys.stdout = open(os.path.join(args.log_dir, "log"), "w")
sys.stderr = sys.stdout
logging.getLogger("pytorch_lightning").addHandler(
logging.StreamHandler(sys.stdout)
)

if args.inference_batch_size is None:
args.inference_batch_size = args.batch_size

save_argparse(args, os.path.join(args.log_dir, "input.yaml"), exclude=["conf"])

return args


def main():
args = get_args()
pl.seed_everything(args.seed, workers=True)

# initialize data module
data = DataModule(args)
data.prepare_data()
data.setup("fit")

prior = None
if args.prior_model:
assert hasattr(priors, args.prior_model), (
f"Unknown prior model {args['prior_model']}. "
f"Available models are {', '.join(priors.__all__)}"
)
# initialize the prior model
prior = getattr(priors, args.prior_model)(dataset=data.dataset)
args.prior_args = prior.get_init_args()

# initialize lightning module
model = LNNP(args, prior_model=prior, mean=data.mean, std=data.std)

checkpoint_callback = ModelCheckpoint(
dirpath=args.log_dir,
monitor="val_loss",
save_top_k=10, # -1 to save all
every_n_epochs=args.save_interval,
filename="{epoch}-{val_loss:.4f}-{test_loss:.4f}",
)
early_stopping = EarlyStopping("val_loss", patience=args.early_stopping_patience)

tb_logger = pl.loggers.TensorBoardLogger(
args.log_dir, name="tensorbord", version="", default_hp_metric=False
)
csv_logger = CSVLogger(args.log_dir, name="", version="")

trainer = pl.Trainer(
strategy=DDPStrategy(find_unused_parameters=False),
max_epochs=args.num_epochs,
gpus=args.ngpus,
num_nodes=args.num_nodes,
default_root_dir=args.log_dir,
auto_lr_find=False,
resume_from_checkpoint=None if args.reset_trainer else args.load_model,
callbacks=[early_stopping, checkpoint_callback],
logger=[tb_logger, csv_logger],
precision=args.precision,
)

trainer.fit(model, data)

# run test set after completing the fit
model = LNNP.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
trainer = pl.Trainer(logger=[tb_logger, csv_logger])
trainer.test(model, data)


if __name__ == "__main__":
main()