diff --git a/src/fairchem/core/components/train/train_runner.py b/src/fairchem/core/components/train/train_runner.py index 78b49e829d..be6108a476 100644 --- a/src/fairchem/core/components/train/train_runner.py +++ b/src/fairchem/core/components/train/train_runner.py @@ -111,12 +111,11 @@ def on_train_step_start(self, state: State, unit: TTrainUnit) -> None: shutil.rmtree(dir) def on_train_end(self, state: State, unit: TTrainUnit) -> None: - if self.checkpoint_every_n_steps is not None: - # also always checkpoint on train end - assert ( - self.save_callback - ), "Must initialize set_checkpoint_call_backs from Runner!" - self.save_callback(os.path.join(self.checkpoint_dir, "final")) + # always checkpoint on train end + assert ( + self.save_callback + ), "Must initialize set_checkpoint_call_backs from Runner!" + self.save_callback(os.path.join(self.checkpoint_dir, "final")) class TrainEvalRunner(Runner):