From 70d1a25c7b2f179015e76aeb47421c52ed3eb12d Mon Sep 17 00:00:00 2001 From: John Chodera Date: Fri, 16 Sep 2022 23:56:07 -0400 Subject: [PATCH 1/2] Attempt to make model production more reproducible --- five-et/README.md | 45 +++++---- five-et/createSpiceDataset.py | 7 ++ five-et/environment.yml | 28 ++++++ five-et/hparams.yaml | 2 +- five-et/train.py | 170 ++++++++++++++++++++++++++++++++++ 5 files changed, 228 insertions(+), 24 deletions(-) create mode 100644 five-et/environment.yml create mode 100644 five-et/train.py diff --git a/five-et/README.md b/five-et/README.md index 502281a..a2937d2 100644 --- a/five-et/README.md +++ b/five-et/README.md @@ -1,11 +1,15 @@ # Using the Models This directory contains the five equivariant transformer models described in (insert reference when available). -They were created with TorchMD-Net 0.2.2. They might work with later versions as well, but that is not guaranteed. +They were created with [TorchMD-Net](https://github.com/torchmd/torchmd-net) 0.2.2. +They might work with later versions as well, but that is not guaranteed. -To use them, first install TorchMD-Net by following the instructions at https://github.com/torchmd/torchmd-net. -That involves checking out the source code, creating a conda environment using the provided environment file, -and running `pip` to install it into the environment. +To use them, first install TorchMD-Net with the provided `environment.yml`: + +```bash +mamba env create -f environment.yml +conda activate spice-models +``` The files ending in `.ckpt` are checkpoint files containing the trained models. They can be loaded like this: @@ -43,27 +47,22 @@ energy, forces = model.forward(types, pos) # Training New Models -If you want to train new models on the same data, follow these steps. - -1. Install OpenFF-Toolkit into the conda environment by executing the command +If you want to train new models on the same data, follow these steps: +1. 
Create an environment containing torchmd-net 0.2.2, its dependencies, and the openff-toolkit if you haven't done so already: +```bash +mamba env create -f environment.yml +conda activate spice-models ``` -conda install -c conda-forge openff-toolkit=0.10.6 +2. Run the `createSpiceDataset.py` script, which will download and convert the SPICE dataset to the format used by TorchMD-Net. + It generates a new file `SPICE-processed.hdf5` to use for training. +4. Run the `train.py` script: +```bash +python train.py --conf hparams.yaml ``` +The file `hparams.yaml` contains the configuration used for training the models. All models here used identical settings +except that `seed` was set to a different value for each one (the numbers 1 through 5). -2. Download the `SPICE.hdf5` file from https://github.com/openmm/spice-dataset/releases/tag/1.1 and place it - in this directory. -3. Run the `createSpiceDataset.py` script, which converts the dataset to the format used by TorchMD-Net. It - generates a new file `SPICE-processed.hdf5` to use for training. -4. Run the `train.py` script provided by TorchMD-Net. The command will be something like - -``` -python /scripts/train.py --conf hparams.yaml -``` +Note that this script uses TorchMD-Net 0.2.2, since later versions made incompatible changes to some of the parameter definitions. -The file `hparams.yaml` contains the configuration used for training the models. All models here used identical settings -except that `seed` was set to a different value for each one (the numbers 1 through 5). Be sure to use TorchMD-Net 0.2.2, -since later versions made incompatible changes to some of the parameter definitions. Note that although the file -specifies `num_epochs: 1000`, training was halted after 24 hours (when the training job reached the end of its allocated -time). This corresponded to 118 epochs. You can edit the file to try different hyperparameters, or override them with -command line arguments to `train.py`. 
+You can edit the file to try different hyperparameters, or override them with command line arguments to `train.py`. diff --git a/five-et/createSpiceDataset.py b/five-et/createSpiceDataset.py index 8e9b606..50f55a7 100644 --- a/five-et/createSpiceDataset.py +++ b/five-et/createSpiceDataset.py @@ -9,6 +9,13 @@ ('Li', 1): 14, ('Mg', 2): 15, ('N', -1): 16, ('N', 0): 17, ('N', 1): 18, ('Na', 1): 19, ('O', -1): 20, ('O', 0): 21, ('O', 1): 22, ('P', 0): 23, ('P', 1): 24, ('S', -1): 25, ('S', 0): 26, ('S', 1): 27} +# Download the SPICE dataset if it is not already available +import os +if not os.path.exists('SPICE.hdf5'): + print('Downloading SPICE dataset...') + import urllib.request + urllib.request.urlretrieve("https://github.com/openmm/spice-dataset/releases/download/1.1/SPICE.hdf5", "SPICE.hdf5") + infile = h5py.File('SPICE.hdf5') # First pass: group the samples by total number of atoms. diff --git a/five-et/environment.yml b/five-et/environment.yml new file mode 100644 index 0000000..3cd3727 --- /dev/null +++ b/five-et/environment.yml @@ -0,0 +1,28 @@ +name: spice-models +channels: + - conda-forge +dependencies: + # torchmd-net dependencies + - ase + - h5py + - matplotlib + - nnpops==0.2 + - pip + - pytorch==1.11.0 + - pytorch_cluster==1.5.9 + - pytorch_geometric==2.0.3 + - pytorch_scatter==2.0.8 + - pytorch_sparse==0.6.10 + - pytorch-lightning==1.6.3 + - torchmetrics==0.8.2 + - tqdm + # dev tools + - flake8 + - pytest + - psutil + # spice data reformatting + - openff-toolkit<0.11.0 + # torchmd-net + - pip: + - git+https://github.com/torchmd/torchmd-net.git@0.2.2 + diff --git a/five-et/hparams.yaml b/five-et/hparams.yaml index 10f703d..51c79a4 100644 --- a/five-et/hparams.yaml +++ b/five-et/hparams.yaml @@ -37,7 +37,7 @@ max_z: 28 model: equivariant-transformer neighbor_embedding: true ngpus: -1 -num_epochs: 1000 +num_epochs: 118 num_heads: 8 num_layers: 6 num_nodes: 1 diff --git a/five-et/train.py b/five-et/train.py new file mode 100644 index 0000000..73b3227 --- 
/dev/null +++ b/five-et/train.py @@ -0,0 +1,170 @@ +import sys +import os +import argparse +import logging +import pytorch_lightning as pl +from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.loggers import CSVLogger +from pytorch_lightning.strategies.ddp import DDPStrategy +from torchmdnet.module import LNNP +from torchmdnet import datasets, priors, models +from torchmdnet.data import DataModule +from torchmdnet.models import output_modules +from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping +from torchmdnet.utils import LoadFromFile, LoadFromCheckpoint, save_argparse, number + + +def get_args(): + # fmt: off + parser = argparse.ArgumentParser(description='Training') + parser.add_argument('--load-model', action=LoadFromCheckpoint, help='Restart training using a model checkpoint') # keep first + parser.add_argument('--conf', '-c', type=open, action=LoadFromFile, help='Configuration yaml file') # keep second + parser.add_argument('--num-epochs', default=300, type=int, help='number of epochs') + parser.add_argument('--batch-size', default=32, type=int, help='batch size') + parser.add_argument('--inference-batch-size', default=None, type=int, help='Batchsize for validation and tests.') + parser.add_argument('--lr', default=1e-4, type=float, help='learning rate') + parser.add_argument('--lr-patience', type=int, default=10, help='Patience for lr-schedule. 
Patience per eval-interval of validation') + parser.add_argument('--lr-metric', type=str, default='val_loss', choices=['train_loss', 'val_loss'], help='Metric to monitor when deciding whether to reduce learning rate') + parser.add_argument('--lr-min', type=float, default=1e-6, help='Minimum learning rate before early stop') + parser.add_argument('--lr-factor', type=float, default=0.8, help='Factor by which to multiply the learning rate when the metric stops improving') + parser.add_argument('--lr-warmup-steps', type=int, default=0, help='How many steps to warm-up over. Defaults to 0 for no warm-up') + parser.add_argument('--early-stopping-patience', type=int, default=30, help='Stop training after this many epochs without improvement') + parser.add_argument('--reset-trainer', type=bool, default=False, help='Reset training metrics (e.g. early stopping, lr) when loading a model checkpoint') + parser.add_argument('--weight-decay', type=float, default=0.0, help='Weight decay strength') + parser.add_argument('--ema-alpha-y', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of y') + parser.add_argument('--ema-alpha-dy', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of dy') + parser.add_argument('--ngpus', type=int, default=-1, help='Number of GPUs, -1 use all available. 
Use CUDA_VISIBLE_DEVICES=1, to decide gpus') + parser.add_argument('--num-nodes', type=int, default=1, help='Number of nodes') + parser.add_argument('--precision', type=int, default=32, choices=[16, 32], help='Floating point precision') + parser.add_argument('--log-dir', '-l', default='/tmp/logs', help='log file') + parser.add_argument('--splits', default=None, help='Npz with splits idx_train, idx_val, idx_test') + parser.add_argument('--train-size', type=number, default=None, help='Percentage/number of samples in training set (None to use all remaining samples)') + parser.add_argument('--val-size', type=number, default=0.05, help='Percentage/number of samples in validation set (None to use all remaining samples)') + parser.add_argument('--test-size', type=number, default=0.1, help='Percentage/number of samples in test set (None to use all remaining samples)') + parser.add_argument('--test-interval', type=int, default=10, help='Test interval, one test per n epochs (default: 10)') + parser.add_argument('--save-interval', type=int, default=10, help='Save interval, one save per n epochs (default: 10)') + parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)') + parser.add_argument('--num-workers', type=int, default=4, help='Number of workers for data prefetch') + parser.add_argument('--redirect', type=bool, default=False, help='Redirect stdout and stderr to log_dir/log') + + # dataset specific + parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset') + parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")') + parser.add_argument('--dataset-arg', default=None, type=str, help='Additional dataset argument, e.g. 
target property for QM9 or molecule for MD17') + parser.add_argument('--coord-files', default=None, type=str, help='Custom coordinate files glob') + parser.add_argument('--embed-files', default=None, type=str, help='Custom embedding files glob') + parser.add_argument('--energy-files', default=None, type=str, help='Custom energy files glob') + parser.add_argument('--force-files', default=None, type=str, help='Custom force files glob') + parser.add_argument('--energy-weight', default=1.0, type=float, help='Weighting factor for energies in the loss function') + parser.add_argument('--force-weight', default=1.0, type=float, help='Weighting factor for forces in the loss function') + + # model architecture + parser.add_argument('--model', type=str, default='graph-network', choices=models.__all__, help='Which model to train') + parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model') + parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use') + + # architectural args + parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge') + parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state') + parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension') + parser.add_argument('--num-layers', type=int, default=6, help='Number of interaction layers in the model') + parser.add_argument('--num-rbf', type=int, default=64, help='Number of radial basis functions in model') + parser.add_argument('--activation', type=str, default='silu', choices=list(act_class_mapping.keys()), help='Activation function') + parser.add_argument('--rbf-type', type=str, default='expnorm', choices=list(rbf_class_mapping.keys()), help='Type of distance expansion') + parser.add_argument('--trainable-rbf', type=bool, default=False, help='If distance expansion functions should be 
trainable') + parser.add_argument('--neighbor-embedding', type=bool, default=False, help='If a neighbor embedding should be applied before interactions') + parser.add_argument('--aggr', type=str, default='add', help='Aggregation operation for CFConv filter output. Must be one of \'add\', \'mean\', or \'max\'') + + # Transformer specific + parser.add_argument('--distance-influence', type=str, default='both', choices=['keys', 'values', 'both', 'none'], help='Where distance information is included inside the attention') + parser.add_argument('--attn-activation', default='silu', choices=list(act_class_mapping.keys()), help='Attention activation function') + parser.add_argument('--num-heads', type=int, default=8, help='Number of attention heads') + + # other args + parser.add_argument('--derivative', default=False, type=bool, help='If true, take the derivative of the prediction w.r.t coordinates') + parser.add_argument('--cutoff-lower', type=float, default=0.0, help='Lower cutoff in model') + parser.add_argument('--cutoff-upper', type=float, default=5.0, help='Upper cutoff in model') + parser.add_argument('--atom-filter', type=int, default=-1, help='Only sum over atoms with Z > atom_filter') + parser.add_argument('--max-z', type=int, default=100, help='Maximum atomic number that fits in the embedding matrix') + parser.add_argument('--max-num-neighbors', type=int, default=32, help='Maximum number of neighbors to consider in the network') + parser.add_argument('--standardize', type=bool, default=False, help='If true, multiply prediction by dataset std and add mean') + parser.add_argument('--reduce-op', type=str, default='add', choices=['add', 'mean'], help='Reduce operation to apply to atomic predictions') + # fmt: on + + args = parser.parse_args() + + if args.redirect: + sys.stdout = open(os.path.join(args.log_dir, "log"), "w") + sys.stderr = sys.stdout + logging.getLogger("pytorch_lightning").addHandler( + logging.StreamHandler(sys.stdout) + ) + + if 
args.inference_batch_size is None: + args.inference_batch_size = args.batch_size + + save_argparse(args, os.path.join(args.log_dir, "input.yaml"), exclude=["conf"]) + + return args + + +def main(): + args = get_args() + pl.seed_everything(args.seed, workers=True) + + # initialize data module + data = DataModule(args) + data.prepare_data() + data.setup("fit") + + prior = None + if args.prior_model: + assert hasattr(priors, args.prior_model), ( + f"Unknown prior model {args.prior_model}. " + f"Available models are {', '.join(priors.__all__)}" + ) + # initialize the prior model + prior = getattr(priors, args.prior_model)(dataset=data.dataset) + args.prior_args = prior.get_init_args() + + # initialize lightning module + model = LNNP(args, prior_model=prior, mean=data.mean, std=data.std) + + checkpoint_callback = ModelCheckpoint( + dirpath=args.log_dir, + monitor="val_loss", + save_top_k=10, # -1 to save all + every_n_epochs=args.save_interval, + filename="{epoch}-{val_loss:.4f}-{test_loss:.4f}", + ) + early_stopping = EarlyStopping("val_loss", patience=args.early_stopping_patience) + + tb_logger = pl.loggers.TensorBoardLogger( + args.log_dir, name="tensorbord", version="", default_hp_metric=False + ) + csv_logger = CSVLogger(args.log_dir, name="", version="") + + trainer = pl.Trainer( + strategy=DDPStrategy(find_unused_parameters=False), + max_epochs=args.num_epochs, + gpus=args.ngpus, + num_nodes=args.num_nodes, + default_root_dir=args.log_dir, + auto_lr_find=False, + resume_from_checkpoint=None if args.reset_trainer else args.load_model, + callbacks=[early_stopping, checkpoint_callback], + logger=[tb_logger, csv_logger], + precision=args.precision, + ) + + trainer.fit(model, data) + + # run test set after completing the fit + model = LNNP.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) + trainer = pl.Trainer(logger=[tb_logger, csv_logger]) + trainer.test(model, data) + + +if __name__ == "__main__": + main() From 
c0025ac3a39e59361723ba3d7c6c2208bb9e5117 Mon Sep 17 00:00:00 2001 From: John Chodera Date: Sat, 17 Sep 2022 14:49:11 -0400 Subject: [PATCH 2/2] Update README to eliminate errors; remove deprecated parameter from hparams.yaml --- five-et/README.md | 9 ++++++--- five-et/hparams.yaml | 1 - 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/five-et/README.md b/five-et/README.md index a2937d2..9668eae 100644 --- a/five-et/README.md +++ b/five-et/README.md @@ -54,11 +54,14 @@ If you want to train new models on the same data, follow these steps: mamba env create -f environment.yml conda activate spice-models ``` -2. Run the `createSpiceDataset.py` script, which will download and convert the SPICE dataset to the format used by TorchMD-Net. +2. Run the `createSpiceDataset.py` script, which will download and convert the SPICE dataset to the format used by TorchMD-Net: +```bash +python createSpiceDataset.py +``` It generates a new file `SPICE-processed.hdf5` to use for training. -4. Run the `train.py` script: +3. Run the `train.py` script to create a model: ```bash -python train.py --conf hparams.yaml +MODELNAME='model1'; mkdir $MODELNAME ; python train.py --conf hparams.yaml --log-dir $MODELNAME ``` The file `hparams.yaml` contains the configuration used for training the models. All models here used identical settings except that `seed` was set to a different value for each one (the numbers 1 through 5). diff --git a/five-et/hparams.yaml b/five-et/hparams.yaml index 51c79a4..5763a0b 100644 --- a/five-et/hparams.yaml +++ b/five-et/hparams.yaml @@ -13,7 +13,6 @@ dataset_arg: null dataset_root: SPICE-processed.hdf5 derivative: true distance_influence: both -distributed_backend: ddp early_stopping_patience: 20 ema_alpha_dy: 1.0 ema_alpha_y: 1.0