[WIP] Attempt to make model training environment more reproducible #4

Open: jchodera wants to merge 2 commits into main from streamline-reproduction



@jchodera jchodera commented Sep 17, 2022

Member

This adds an environment.yml file that creates a spice-models environment in a single line that should install torchmd-net 0.2.2, its dependencies, and the dependencies for converting the dataset.

This also modifies the conversion script to automatically download the SPICE 1.1 dataset if it does not already exist locally.
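A minimal sketch of the download-if-missing behaviour described above. The helper name `ensure_dataset`, the `.part` suffix, and the injectable `download` argument are illustrative assumptions, not the names used in the actual conversion script, and the dataset URL is deliberately left as a parameter.

```python
import os
import urllib.request


def ensure_dataset(local_path, url, download=urllib.request.urlretrieve):
    """Download `url` to `local_path` unless the file is already present.

    Returns True if a download happened, False if the local copy was reused.
    """
    if os.path.exists(local_path):
        return False
    # Download to a temporary name first so that an interrupted transfer
    # does not leave a partial file that later looks like a valid dataset.
    partial_path = local_path + ".part"
    download(url, partial_path)
    os.replace(partial_path, local_path)
    return True
```

Calling `ensure_dataset(path, url)` a second time is a no-op, so the conversion script can call it unconditionally before opening the file.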

The train.py script is also imported from torchmd-net 0.2.2 for ease of reproducibility.

The number of epochs is also specified to match what was used in the paper.

Finally, the README.md is updated.

@jchodera jchodera requested a review from peastman September 17, 2022 03:58
@jchodera

Member Author

@peastman : Even with torchmd-net 0.2.2, I am seeing this error:

(spice-models) [chodera@lilac:five-et]$ python train.py --conf hparams.yaml
Traceback (most recent call last):
  File "/lila/home/chodera/github/torchmd/spice-models/five-et/train.py", line 170, in <module>
    main()
  File "/lila/home/chodera/github/torchmd/spice-models/five-et/train.py", line 113, in main
    args = get_args()
  File "/lila/home/chodera/github/torchmd/spice-models/five-et/train.py", line 95, in get_args
    args = parser.parse_args()
  File "/lila/home/chodera/miniconda/envs/spice-models/lib/python3.9/argparse.py", line 1825, in parse_args
    args, argv = self.parse_known_args(args, namespace)
  File "/lila/home/chodera/miniconda/envs/spice-models/lib/python3.9/argparse.py", line 1858, in parse_known_args
    namespace, args = self._parse_known_args(args, namespace)
  File "/lila/home/chodera/miniconda/envs/spice-models/lib/python3.9/argparse.py", line 2067, in _parse_known_args
    start_index = consume_optional(start_index)
  File "/lila/home/chodera/miniconda/envs/spice-models/lib/python3.9/argparse.py", line 2007, in consume_optional
    take_action(action, args, option_string)
  File "/lila/home/chodera/miniconda/envs/spice-models/lib/python3.9/argparse.py", line 1935, in take_action
    action(self, namespace, argument_values, option_string)
  File "/lila/home/chodera/miniconda/envs/spice-models/lib/python3.9/site-packages/torchmdnet/utils.py", line 105, in __call__
    raise ValueError(f"Unknown argument in config file: {key}")
ValueError: Unknown argument in config file: distributed_backend

I'm running on a machine with an A100.
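The traceback shows the config loader in torchmdnet/utils.py rejecting a key that the argument parser does not define. A simplified sketch of that kind of check follows. It is not the torchmd-net implementation: `apply_config` is a hypothetical helper, and the two options and their defaults are only illustrative.

```python
import argparse


def apply_config(namespace, config):
    """Copy config values onto an argparse namespace, rejecting unknown keys."""
    known = vars(namespace)
    for key, value in config.items():
        if key not in known:
            raise ValueError(f"Unknown argument in config file: {key}")
        setattr(namespace, key, value)
    return namespace


# Illustrative parser: option names follow the thread, defaults are made up.
parser = argparse.ArgumentParser()
parser.add_argument("--num-epochs", type=int, default=300)
parser.add_argument("--log-dir", default=None)
args = parser.parse_args([])

# A key the parser knows about is accepted and overrides the default.
apply_config(args, {"num_epochs": 118})
```

With a check like this, any key in hparams.yaml that the installed version no longer defines (here, `distributed_backend`) stops the run before training starts.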

@jchodera

Member Author

Removing the distributed_backend option yields a new error:

(spice-models) [chodera@lilac:five-et]$ python train.py --conf hparams.yaml
Traceback (most recent call last):
  File "/lila/home/chodera/github/torchmd/spice-models/five-et/train.py", line 170, in <module>
    main()
  File "/lila/home/chodera/github/torchmd/spice-models/five-et/train.py", line 113, in main
    args = get_args()
  File "/lila/home/chodera/github/torchmd/spice-models/five-et/train.py", line 98, in get_args
    sys.stdout = open(os.path.join(args.log_dir, "log"), "w")
FileNotFoundError: [Errno 2] No such file or directory: 'model1b/log'

@jchodera

Member Author

It looks like it was expecting the model1b directory to exist so it could put the log files there.

We actually want to tell the user to run something like:

MODELNAME='model1'; mkdir $MODELNAME ; python train.py --conf hparams.yaml --log-dir $MODELNAME
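An alternative to asking users to create the directory by hand would be to create it just before the log file is opened. A sketch, assuming a hypothetical helper `open_log` that is not part of train.py:

```python
import os


def open_log(log_dir, name="log"):
    """Create `log_dir` if needed and open the log file inside it for writing."""
    # exist_ok=True makes this safe to call when the directory is already there.
    os.makedirs(log_dir, exist_ok=True)
    return open(os.path.join(log_dir, name), "w")
```

The line from the traceback, `sys.stdout = open(os.path.join(args.log_dir, "log"), "w")`, could then become `sys.stdout = open_log(args.log_dir)`.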

@jchodera

Member Author

Training is running now, and using most of the GPU!

(base) [chodera@lilac:chodera]$ nvidia-smi
Sat Sep 17 14:46:06 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.43.04    Driver Version: 515.43.04    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
...
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A100 80G...  On   | 00000000:CA:00.0 Off |                    0 |
| N/A   69C    P0   301W / 300W |  77055MiB / 81920MiB |     84%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

I've updated the README.md and hparams.yaml file to reflect the working environment.

I can update this to the latest release of torchmd-net once @raimis cuts a new release.

@raimis : Are there any other version numbers I should pin for pytorch or pytorch-lightning?
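One way to record what is currently installed before deciding what to pin is to query the package metadata. This is a sketch only: `installed_versions` is a hypothetical helper, and the distribution names passed in the example (`torch`, `pytorch-lightning`, `torchmd-net`) are assumptions about how the packages are registered in this environment.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Distribution names below are assumptions; adjust them to the environment.
print(installed_versions(["torch", "pytorch-lightning", "torchmd-net"]))
```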

@jchodera

Member Author

@peastman: This is only doing ~1 epoch/hour on an A100. Is there something else needed to make this run decently quickly? I'm only giving it one CPU thread and one GPU. Does it need more CPU threads? Is it necessary to specify an alternative to distributed_backend?

@peastman

Member

Four CPU threads per GPU is a good rule of thumb. But if it's already using most of the GPU, that will only have a small impact. I trained on four A100s and it completed 118 epochs in 24 hours, or 4.9 epochs per hour. So your speed sounds about right.
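The comparison can be made explicit by dividing the four-GPU throughput by the number of GPUs. This assumes roughly linear scaling across GPUs, which the thread does not verify, and `epochs_per_hour` is just an illustrative helper.

```python
def epochs_per_hour(epochs, hours, gpus=1):
    """Return (total, per-GPU) training throughput in epochs per hour."""
    total = epochs / hours
    return total, total / gpus


# Figures quoted above: 118 epochs in 24 hours on four A100s.
total, per_gpu = epochs_per_hour(118, 24, gpus=4)
print(f"4x A100: {total:.1f} epochs/hour total, {per_gpu:.2f} per GPU")
```

Under that assumption, the per-GPU rate is close to the ~1 epoch/hour reported on a single A100.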

@jchodera

Member Author

@peastman: I made it to epoch 30 in ~24 hours before it terminated. Is there a way to resume, perhaps on more GPUs?

(base) [chodera@lilac:five-et]$ cat model1b/metrics.csv 
epoch,lr,train_loss,val_loss,train_loss_y,train_loss_dy,val_loss_y,val_loss_dy,step
0.0,0.0005000000237487257,415951.59375,43753.390625,408118.65625,7832.9248046875,40474.96875,3278.423095703125,8239
1.0,0.0005000000237487257,9220.044921875,4110.337890625,7084.13916015625,2135.90576171875,2577.056396484375,1533.2818603515625,16479
2.0,0.0005000000237487257,5993.22900390625,6763.59619140625,4465.41357421875,1527.814697265625,5661.66064453125,1101.9351806640625,24719
3.0,0.0005000000237487257,3752.89990234375,3246.296630859375,2761.4912109375,991.4083862304688,1363.682373046875,1882.6143798828125,32959
4.0,0.0005000000237487257,2082.106201171875,7021.27783203125,1245.3057861328125,836.8003540039062,4397.51318359375,2623.7646484375,41199
5.0,0.0005000000237487257,1662.6856689453125,1590.697021484375,973.4273071289062,689.2582397460938,872.2218017578125,718.4752197265625,49439
6.0,0.0005000000237487257,1400.7320556640625,1132.1680908203125,841.2183227539062,559.513671875,476.9352722167969,655.2328491210938,57679
7.0,0.0005000000237487257,1208.6385498046875,955.9788208007812,716.5443725585938,492.09405517578125,375.7664794921875,580.21240234375,65919
8.0,0.0005000000237487257,1373.7999267578125,1653.5787353515625,751.698486328125,622.101318359375,879.0201416015625,774.55859375,74159
9.0,0.0002500000118743628,740.4225463867188,858.6006469726562,308.6726989746094,431.7498474121094,346.9961853027344,511.60455322265625,82399
10.0,0.0002500000118743628,691.2056884765625,777.0108032226562,309.3612060546875,381.8445129394531,284.3540954589844,492.6567077636719,90639
11.0,0.0002500000118743628,649.3666381835938,694.3455200195312,287.1024475097656,362.2643127441406,221.6646728515625,472.6808776855469,98879
12.0,0.0002500000118743628,620.4614868164062,755.1241455078125,270.81304931640625,349.6484375,294.8038330078125,460.3202819824219,107119
13.0,0.0002500000118743628,614.3432006835938,928.512939453125,274.3605041503906,339.982666015625,455.21722412109375,473.2958068847656,115359
14.0,0.0002500000118743628,597.1185913085938,645.1339111328125,265.7243347167969,331.3943176269531,201.91043090820312,443.2235107421875,123599
15.0,0.0002500000118743628,577.0795288085938,670.9802856445312,252.25747680664062,324.82208251953125,231.13743591308594,439.8428955078125,131839
16.0,0.0002500000118743628,570.9285278320312,1005.7330932617188,251.0818328857422,319.8466796875,568.791259765625,436.9417419433594,140079
17.0,0.0002500000118743628,557.9736328125,705.8413696289062,242.2953338623047,315.67828369140625,275.7255859375,430.11578369140625,148319
18.0,0.0002500000118743628,546.3439331054688,631.070068359375,234.75704956054688,311.5869140625,197.27853393554688,433.7914733886719,156559
19.0,0.0002500000118743628,530.5957641601562,1010.63134765625,223.96826171875,306.6274719238281,584.8029174804688,425.828369140625,164799
20.0,0.0002500000118743628,529.35546875,755.9890747070312,225.8995819091797,303.45587158203125,330.8993835449219,425.0897216796875,173039
21.0,0.0002500000118743628,522.5003662109375,970.5034790039062,222.3240966796875,300.17626953125,522.2012939453125,448.3022155761719,181279
22.0,0.0002500000118743628,511.585693359375,713.2323608398438,213.98883056640625,297.59686279296875,295.4289245605469,417.8034362792969,189519
23.0,0.0002500000118743628,510.05242919921875,672.0818481445312,215.08908081054688,294.9633483886719,259.2350769042969,412.8467712402344,197759
24.0,0.0002500000118743628,501.7392883300781,595.1604614257812,210.0992431640625,291.6400451660156,183.89414978027344,411.26629638671875,205999
25.0,0.0002500000118743628,498.0483093261719,648.014404296875,207.76629638671875,290.28204345703125,237.3606719970703,410.6536865234375,214239
26.0,0.0002500000118743628,490.3211364746094,586.6843872070312,202.3558349609375,287.9653015136719,179.64987182617188,407.0345764160156,222479
27.0,0.0002500000118743628,493.4549865722656,583.164794921875,207.41969299316406,286.0353088378906,176.25076293945312,406.9140319824219,230719
28.0,0.0001250000059371814,409.5025939941406,616.0693969726562,135.19017028808594,274.3123779296875,213.1492156982422,402.9201354980469,238959
29.0,0.0001250000059371814,403.29412841796875,593.2405395507812,132.00363159179688,271.2904968261719,192.54238891601562,400.6981201171875,247199
30.0,0.0001250000059371814,402.2795104980469,559.7166748046875,132.01614379882812,270.26336669921875,164.86419677734375,394.8524475097656,255439

In any case, I'm fairly certain this PR contains fixes sufficient to get this to easily run with torchmd-net 0.2.2, so we will likely want to merge it or update it to a new release of torchmd-net if that is essential.
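For checking where a run stopped without reading the file by eye, a short sketch that summarizes a metrics.csv with the columns shown above. `summarize_metrics` is a hypothetical helper, not part of torchmd-net.

```python
import csv


def summarize_metrics(lines):
    """Return (last_epoch, best_epoch, best_val_loss) from metrics.csv rows.

    `lines` is any iterable of text lines, such as an open file.
    """
    rows = list(csv.DictReader(lines))
    last_epoch = int(float(rows[-1]["epoch"]))
    best = min(rows, key=lambda row: float(row["val_loss"]))
    return last_epoch, int(float(best["epoch"])), float(best["val_loss"])
```

It could be used as `summarize_metrics(open("model1b/metrics.csv", newline=""))`, with the path taken from the command above.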

@jchodera jchodera marked this pull request as ready for review September 18, 2022 20:53
@peastman

Member

It saves checkpoint files as it runs. To resume training from a checkpoint file, add the command line argument --load-model <path to checkpoint>.ckpt. I recommend pointing to a new log directory when you resume, since otherwise it will overwrite the existing log and metrics files.
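A sketch of how the resume command could be assembled. It uses only the flags mentioned in this thread (`--conf`, `--load-model`, `--log-dir`). The helper `resume_command` is hypothetical, and it assumes that checkpoints are `.ckpt` files sitting directly in the old log directory and that the most recently modified one is the one to resume from.

```python
import glob
import os


def resume_command(old_log_dir, new_log_dir, conf="hparams.yaml"):
    """Build the argument list for resuming from the newest checkpoint."""
    checkpoints = glob.glob(os.path.join(old_log_dir, "*.ckpt"))
    if not checkpoints:
        raise FileNotFoundError(f"no .ckpt files found in {old_log_dir}")
    # Assumption: the most recently written checkpoint is the right one.
    latest = max(checkpoints, key=os.path.getmtime)
    # A fresh log directory keeps the earlier log and metrics files intact.
    os.makedirs(new_log_dir, exist_ok=True)
    return ["python", "train.py", "--conf", conf,
            "--load-model", latest, "--log-dir", new_log_dir]
```

The returned list can be passed to `subprocess.run`, or joined with spaces and pasted into a shell.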

@peastman

Member

I'm not sure it's a good idea to reproduce the train.py script here. It's part of TorchMD-Net. Putting it here on its own is almost certain to cause problems for people, for example when they upgrade TorchMD-Net but try to use the incompatible training script. @giadefa what's the best way to locate where the installed version of the script is located?

I'm also not sure about modifying hparams.yaml. My goal was to document the exact settings I used. But if we do change it, then probably we should also upgrade to the latest TorchMD-Net. The only reason for requiring an older version is the no longer supported distributed_backend setting, which you removed. And specifying num_epochs: 118 makes it look like I pulled an arbitrary number out of thin air (as you once claimed I had done!). It's better to explain what I actually did.

@jchodera

Member Author

> I'm not sure it's a good idea to reproduce the train.py script here. It's part of TorchMD-Net. Putting it here on its own is almost certain to cause problems for people, for example when they upgrade TorchMD-Net but try to use the incompatible training script. @giadefa what's the best way to locate where the installed version of the script is located?

Ideally, we could cut a new release of TorchMD-Net that adds an entry point, so that installation puts a command-line tool called torchmdnet-train on the path (or, to be more sophisticated, a torchmdnet train subcommand handled via a click CLI).
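A rough sketch of what such a subcommand-style command line could look like. It uses the standard-library argparse module in place of click to stay dependency-free, and the module layout, `build_parser`, and `main` are hypothetical, not TorchMD-Net code. In a setuptools-based package, a `console_scripts` entry point mapping the command name to `main` would install it on the path.

```python
import argparse


def build_parser():
    """Parser for a hypothetical `torchmdnet` command with a `train` subcommand."""
    parser = argparse.ArgumentParser(prog="torchmdnet")
    subcommands = parser.add_subparsers(dest="command", required=True)
    train = subcommands.add_parser("train", help="train a model")
    train.add_argument("--conf", required=True, help="path to hparams.yaml")
    train.add_argument("--log-dir", default=".", help="directory for logs")
    return parser


def main(argv=None):
    """Parse the command line; argv=None means use sys.argv."""
    args = build_parser().parse_args(argv)
    # A real entry point would hand these arguments to the training code here.
    return args
```

Because `main` accepts an explicit argument list, it can be exercised in tests without touching `sys.argv`.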

> I'm also not sure about modifying hparams.yaml. My goal was to document the exact settings I used. But if we do change it, then probably we should also upgrade to the latest TorchMD-Net. The only reason for requiring an older version is the no longer supported distributed_backend setting, which you removed. And specifying num_epochs: 118 makes it look like I pulled an arbitrary number out of thin air (as you once claimed I had done!). It's better to explain what I actually did.

I'm happy to restore these (including distributed_backend) if you want to provide a dump of your conda environment. I could then pin the pytorch-lightning version so that it works with the fully original hparams.yaml.

@raimis

raimis commented Sep 19, 2022


@jchodera On Wednesday, I'll have time to look at all of this. Meanwhile, I'll try to get a new release out.

@peastman

Member

Adding an entrypoint for training sounds like a good idea to me.

@raimis

raimis commented Sep 29, 2022


In progress: torchmd/torchmd-net#127

@raimis

raimis commented Oct 3, 2022


@jchodera torchmd/torchmd-net#127 is done! Training can be run with tmn-train.

@peastman

peastman commented Oct 3, 2022

Member

A later PR changed the name to torchmd-train.
