diff --git a/mmf/common/meter.py b/mmf/common/meter.py index e4b26f2c6..403cbdf2c 100644 --- a/mmf/common/meter.py +++ b/mmf/common/meter.py @@ -60,8 +60,8 @@ def update(self, update_dict, batch_size): if isinstance(v, torch.Tensor): if v.dim() != 0: v = v.mean() - v = v.item() - assert isinstance(v, (float, int)) + #v = v.item() + #assert isinstance(v, (float, int)) self.meters[k].update(v, batch_size) def update_from_meter(self, meter): diff --git a/mmf/common/sample.py b/mmf/common/sample.py index 40c535918..45089e87b 100644 --- a/mmf/common/sample.py +++ b/mmf/common/sample.py @@ -396,7 +396,11 @@ def to_device( if isinstance(device, str): device = torch.device(device) - if not torch.cuda.is_available(): + # default value of device_type is cuda + # since other device types such as xla can be passed + # falling back to cpu should only happen when device_type + # is set to cuda but cuda is not available. + if not torch.cuda.is_available() and device.type == "cuda": device = torch.device("cpu") # to_device is specifically for SampleList # if user is passing something custom built diff --git a/mmf/datasets/multi_dataset_loader.py b/mmf/datasets/multi_dataset_loader.py index 37a78add6..21b0518fe 100644 --- a/mmf/datasets/multi_dataset_loader.py +++ b/mmf/datasets/multi_dataset_loader.py @@ -8,7 +8,7 @@ import numpy as np from mmf.common.registry import registry from mmf.utils.build import build_dataloader_and_sampler, build_dataset -from mmf.utils.distributed import broadcast_scalar, is_dist_initialized, is_master +from mmf.utils.distributed import broadcast_scalar, is_dist_initialized, is_master, is_xla from mmf.utils.general import get_batch_size @@ -186,9 +186,15 @@ def _infer_dataset_probabilities(self): def __len__(self): # Since, this is iterator, we need to return total length == number of batches batch_size = get_batch_size() + # Changed the length to accommodate drop_last == True + # drop_last is required if the batch is split into multiple cores + # some of
the cores may not have enough examples. + if is_xla(): + return (self._total_length) // batch_size + else: # This assumes drop_last=False for all loaders. See also # build_dataloader_and_sampler(). - return (self._total_length + batch_size - 1) // batch_size + return (self._total_length + batch_size - 1) // batch_size def __iter__(self): if self._num_datasets == 1: diff --git a/mmf/trainers/callbacks/logistics.py b/mmf/trainers/callbacks/logistics.py index 9211e269e..27666302c 100644 --- a/mmf/trainers/callbacks/logistics.py +++ b/mmf/trainers/callbacks/logistics.py @@ -5,7 +5,7 @@ import torch from mmf.trainers.callbacks.base import Callback from mmf.utils.configuration import get_mmf_env -from mmf.utils.distributed import is_master +from mmf.utils.distributed import is_master, is_xla from mmf.utils.logger import TensorboardLogger, log_progress, setup_output_folder from mmf.utils.timer import Timer @@ -105,7 +105,7 @@ def on_test_end(self, **kwargs): def _summarize_report(self, meter, should_print=True, extra=None): if extra is None: extra = {} - if not is_master(): + if not is_master() and not is_xla(): return if self.training_config.tensorboard: diff --git a/mmf/trainers/core/device.py b/mmf/trainers/core/device.py index a89bfccb6..37864435e 100644 --- a/mmf/trainers/core/device.py +++ b/mmf/trainers/core/device.py @@ -20,9 +20,17 @@ def configure_seed(self) -> None: torch.backends.cudnn.benchmark = False def configure_device(self) -> None: - self.local_rank = self.config.device_id - self.device = self.local_rank - self.distributed = False + if getattr(self.config.training, 'device', 'cuda') == 'xla': + import torch_xla.core.xla_model as xm + self.device = xm.xla_device() + self.distributed = True + self.local_rank = xm.get_local_ordinal() + self.tpu = True + else: + self.tpu = False + self.local_rank = self.config.device_id + self.device = self.local_rank + self.distributed = False # Will be updated later based on distributed setup 
registry.register("global_device", self.device) @@ -30,9 +38,9 @@ def configure_device(self) -> None: if self.config.distributed.init_method is not None: self.distributed = True self.device = torch.device("cuda", self.local_rank) - elif torch.cuda.is_available(): + elif torch.cuda.is_available() and not self.tpu: self.device = torch.device("cuda") - else: + elif not self.tpu: self.device = torch.device("cpu") registry.register("current_device", self.device) diff --git a/mmf/trainers/core/evaluation_loop.py b/mmf/trainers/core/evaluation_loop.py index ca088a66a..9ca258b89 100644 --- a/mmf/trainers/core/evaluation_loop.py +++ b/mmf/trainers/core/evaluation_loop.py @@ -10,6 +10,7 @@ from mmf.common.report import Report from mmf.common.sample import to_device from mmf.utils.distributed import is_master +from mmf.utils.metsumm import metsumm logger = logging.getLogger(__name__) @@ -25,6 +26,7 @@ def evaluation_loop( self.model.eval() disable_tqdm = not use_tqdm or not is_master() combined_report = None + metsumm("Before Validation Start:") for batch in tqdm.tqdm(loader, disable=disable_tqdm): report = self._forward(batch) @@ -44,6 +46,8 @@ def evaluation_loop( combined_report.metrics = self.metrics(combined_report, combined_report) self.update_meter(combined_report, meter, eval_mode=True) + logger.info("Validation Done") + metsumm("After Validation Complete") # enable train mode again self.model.train() diff --git a/mmf/trainers/core/reporting.py b/mmf/trainers/core/reporting.py index 8f6458399..41ed794c9 100644 --- a/mmf/trainers/core/reporting.py +++ b/mmf/trainers/core/reporting.py @@ -51,8 +51,8 @@ def update_dict(self, meter_update_dict, values_dict): if val.dim() == 1: val = val.mean() - if hasattr(val, "item"): - val = val.item() + #if hasattr(val, "item"): + # val = val.item() meter_update_dict.update({key: val}) total_val += val diff --git a/mmf/trainers/core/training_loop.py b/mmf/trainers/core/training_loop.py index 0d779a7e2..17af8bc05 100644 --- 
a/mmf/trainers/core/training_loop.py +++ b/mmf/trainers/core/training_loop.py @@ -12,7 +12,7 @@ from mmf.common.sample import to_device from mmf.utils.general import clip_gradients from torch import Tensor - +from mmf.utils.metsumm import metsumm logger = logging.getLogger(__name__) @@ -72,8 +72,8 @@ def run_training_epoch(self) -> None: combined_report = None num_batches_for_this_update = 1 - for idx, batch in enumerate(self.train_loader): + for idx, batch in enumerate(self.train_loader): if (idx + 1) % self.training_config.update_frequency == 0: combined_report = None num_batches_for_this_update = min( @@ -84,7 +84,6 @@ def run_training_epoch(self) -> None: # batch execution starts here self.on_batch_start() - self.profile("Batch load time") report = self.run_training_batch(batch, num_batches_for_this_update) @@ -129,7 +128,6 @@ def run_training_epoch(self) -> None: # Validation begin callbacks self.on_validation_start() - logger.info("Evaluation time. Running on full validation set...") # Validation and Early stopping # Create a new meter for this case report, meter = self.evaluation_loop(self.val_loader) @@ -146,7 +144,6 @@ def run_training_epoch(self) -> None: torch.cuda.empty_cache() if stop is True: - logger.info("Early stopping activated") should_break = True if self.num_updates >= self.max_updates: should_break = True @@ -168,7 +165,7 @@ def run_training_batch(self, batch: Tensor, loss_divisor: int) -> None: def _forward(self, batch: Tensor) -> Dict[str, Any]: prepared_batch = self.dataset_loader.prepare_batch(batch) # Move the sample list to device if it isn't as of now. 
- prepared_batch = to_device(prepared_batch, torch.device("cuda")) + prepared_batch = to_device(prepared_batch, self.device) self.profile("Batch prepare time") # Arguments should be a dict at this point @@ -188,6 +185,7 @@ def _start_update(self): def _backward(self, loss: Tensor) -> None: self.scaler.scale(loss).backward() + self.profile("Backward time") def _finish_update(self): @@ -199,6 +197,12 @@ def _finish_update(self): self.config, scale=self.scaler.get_scale(), ) + if getattr(self.config.training, 'device', 'cuda') == 'xla' and self.config.distributed.world_size > 1: + import torch_xla.core.xla_model as xm + #gradients = xm._fetch_gradients(self.optimizer) + # Assumes no model parallel + #xm.all_reduce('sum', gradients, scale=1.0 / self.config.distributed.world_size) + xm.reduce_gradients(self.optimizer) self.scaler.step(self.optimizer) self.scaler.update() diff --git a/mmf/utils/build.py b/mmf/utils/build.py index a7470eea9..78a3e1904 100644 --- a/mmf/utils/build.py +++ b/mmf/utils/build.py @@ -10,10 +10,16 @@ from mmf.common.registry import registry from mmf.datasets.processors.processors import Processor from mmf.utils.configuration import Configuration -from mmf.utils.distributed import is_dist_initialized +from mmf.utils.distributed import is_dist_initialized, is_xla, is_master, synchronize from mmf.utils.general import get_optimizer_parameters from omegaconf import DictConfig, OmegaConf +try: + import torch_xla.core.xla_model as xm + import torch_xla.distributed.parallel_loader as pl +except ImportError: + xm = None + pl = None ProcessorType = Type[Processor] ProcessorDict = Dict[str, ProcessorType] @@ -76,7 +82,12 @@ def build_model( if hasattr(model, "build"): model.load_requirements() - model.build() + if is_master(): + model.build() + synchronize() + else: + synchronize() + model.build() model.init_losses() return model @@ -152,6 +163,18 @@ def build_dataloader_and_sampler( if not isinstance(dataset_instance, torch.utils.data.IterableDataset): 
other_args = _add_extra_args_for_dataloader(dataset_instance, other_args) + if is_xla(): + dataset_type = dataset_instance.dataset_type + shuffle=True + other_args["sampler"] = torch.utils.data.DistributedSampler( + dataset_instance, + num_replicas=xm.xrt_world_size(), + rank=xm.get_ordinal(), + shuffle=shuffle + ) + other_args.pop("shuffle") + + loader = torch.utils.data.DataLoader( dataset=dataset_instance, pin_memory=pin_memory, @@ -163,6 +186,10 @@ def build_dataloader_and_sampler( **other_args, ) + if is_xla(): + device = xm.xla_device() + loader = pl.MpDeviceLoader(loader, device) + if num_workers >= 0: # Suppress leaking semaphore warning os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" diff --git a/mmf/utils/checkpoint.py b/mmf/utils/checkpoint.py index 4cdf88b8c..b6072b027 100644 --- a/mmf/utils/checkpoint.py +++ b/mmf/utils/checkpoint.py @@ -10,7 +10,7 @@ import torch from mmf.common.registry import registry from mmf.utils.configuration import get_mmf_env, load_yaml -from mmf.utils.distributed import is_master, synchronize +from mmf.utils.distributed import is_master, synchronize, is_xla from mmf.utils.download import download_pretrained_model from mmf.utils.file_io import PathManager from mmf.utils.general import updir @@ -22,6 +22,11 @@ except ImportError: git = None +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + logger = logging.getLogger(__name__) @@ -379,15 +384,26 @@ def _get_vcs_fields(self): "git/diff": self.git_repo.git.diff("--no-prefix"), } + + + def save_func(self): + if is_xla(): + return xm.save + else: + return torch.save + def save(self, update, iteration=None, update_best=False): # Only save in main process - if not is_master(): + if not is_master() and not is_xla(): return + logger.info("Checkpoint save operation started!") + if not iteration: iteration = update ckpt_filepath = os.path.join(self.models_foldername, "model_%d.ckpt" % update) + best_ckpt_filepath = os.path.join( 
self.ckpt_foldername, self.ckpt_prefix + "best.ckpt" ) @@ -437,16 +453,20 @@ def save(self, update, iteration=None, update_best=False): git_metadata_dict = self._get_vcs_fields() ckpt.update(git_metadata_dict) + logger.info("Saving checkpoint") with PathManager.open(ckpt_filepath, "wb") as f: - torch.save(ckpt, f) + self.save_func()(ckpt, f) if update_best: + logger.info("Saving best checkpoint") with PathManager.open(best_ckpt_filepath, "wb") as f: - torch.save(ckpt, f) + self.save_func()(ckpt, f) # Save current always + + logger.info("Saving Current checkpoint") with PathManager.open(current_ckpt_filepath, "wb") as f: - torch.save(ckpt, f) + self.save_func()(ckpt, f) # Remove old checkpoints if max_to_keep is set if self.max_to_keep > 0: @@ -454,6 +474,8 @@ def save(self, update, iteration=None, update_best=False): self.remove(self.saved_iterations.pop(0)) self.saved_iterations.append(update) + logger.info("Checkpoint save operation finished!") + def remove(self, update): ckpt_filepath = os.path.join(self.models_foldername, "model_%d.ckpt" % update) if PathManager.isfile(ckpt_filepath): @@ -468,6 +490,6 @@ def restore(self): self._load(best_path, force=True) def finalize(self): - if is_master(): + if is_master() or is_xla(): with PathManager.open(self.pth_filepath, "wb") as f: - torch.save(self.trainer.model.state_dict(), f) + self.save_func()(self.trainer.model.state_dict(), f) diff --git a/mmf/utils/configuration.py b/mmf/utils/configuration.py index 6a9ba8bec..5d3f88351 100644 --- a/mmf/utils/configuration.py +++ b/mmf/utils/configuration.py @@ -531,6 +531,9 @@ def _update_specific(self, config): lr = config.learning_rate config.optimizer.params.lr = lr + # TODO: Correct the following issue + # This check is triggered before the config override from commandline is effective + # even after setting training.device = 'xla', it gets triggered. 
if not torch.cuda.is_available() and "cuda" in config.training.device: warnings.warn( "Device specified is 'cuda' but cuda is not present. " diff --git a/mmf/utils/distributed.py b/mmf/utils/distributed.py index 104b52b8f..07f02b55c 100644 --- a/mmf/utils/distributed.py +++ b/mmf/utils/distributed.py @@ -9,15 +9,21 @@ import torch from torch import distributed as dist +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None +_USE_XLA = False MAX_SIZE_LIMIT = 65533 BYTE_SIZE = 256 logger = logging.getLogger(__name__) - -def synchronize(): - if not dist.is_available(): +def synchronize(message='sync-workers'): + if is_xla(): + xm.rendezvous(message) + elif not dist.is_available(): return if not dist.is_nccl_available(): return @@ -31,8 +37,13 @@ def synchronize(): dist.barrier() +def is_xla(): + global _USE_XLA + return _USE_XLA def get_rank(): + if is_xla(): + return xm.get_ordinal() if not dist.is_available(): return 0 if not dist.is_nccl_available(): @@ -51,6 +62,8 @@ def is_dist_initialized(): def get_world_size(): + if is_xla(): + return xm.xrt_world_size() if not dist.is_available(): return 1 if not dist.is_nccl_available(): @@ -66,7 +79,14 @@ def broadcast_tensor(tensor, src=0): return tensor with torch.no_grad(): - dist.broadcast(tensor, src=0) + if is_xla(): + tensor = xm.all_to_all( + tensor.repeat([world_size,1]), + split_dimension=0, + concat_dimension=0, + split_count=world_size)[0] + else: + dist.broadcast(tensor, src=0) return tensor @@ -105,7 +125,11 @@ def gather_tensor(tensor): for _ in range(world_size): tensor_list.append(torch.zeros_like(tensor)) - dist.all_gather(tensor_list, tensor) + if is_xla(): + tensor_list = xm.all_gather(tensor) + tensor_list = tensor_list.view(world_size, *tensor.size()).unbind(0) # per-rank tuple so the shared torch.stack below works + else: + dist.all_gather(tensor_list, tensor) tensor_list = torch.stack(tensor_list, dim=0) return tensor_list @@ -122,12 +146,17 @@ def reduce_dict(dictionary): keys, values = zip(*sorted(dictionary.items())) values =
torch.stack(values, dim=0) - dist.reduce(values, dst=0) - - if dist.get_rank() == 0: - # only main process gets accumulated, so only divide by - # world_size in this case - values /= world_size + if is_xla(): + values = xm.all_reduce('sum', + [values], + scale=1.0/world_size + )[0] + else: + dist.reduce(values, dst=0) + if dist.get_rank() == 0: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size reduced_dict = {k: v for k, v in zip(keys, values)} return reduced_dict @@ -169,6 +198,11 @@ def byte_tensor_to_object(byte_tensor, max_size=4094): def infer_init_method(config): if config.distributed.init_method is not None: return + + if getattr(config.training, 'device', 'cuda') == 'xla': + global _USE_XLA + _USE_XLA = True + # support torch.distributed.launch if all( key in os.environ @@ -221,9 +255,14 @@ def infer_init_method(config): def distributed_init(config): if config.distributed.world_size == 1: raise ValueError("Cannot initialize distributed with distributed_world_size=1") + logger.info("XLA Mode:{}".format(is_xla())) - if dist.is_initialized(): + if is_xla(): + config.device_id = xm.get_local_ordinal() + config.distributed.rank = xm.get_ordinal() + elif dist.is_initialized(): warnings.warn("Distributed is already initialized, cannot initialize twice!") + config.distributed.rank = dist.get_rank() else: logger.info( f"Distributed Init (Rank {config.distributed.rank}): " @@ -244,8 +283,7 @@ def distributed_init(config): dist.all_reduce(torch.zeros(1).cuda()) suppress_output(is_master()) - - config.distributed.rank = dist.get_rank() + config.distributed.rank = dist.get_rank() return config.distributed.rank diff --git a/mmf/utils/early_stopping.py b/mmf/utils/early_stopping.py index 35f80f65c..971934893 100644 --- a/mmf/utils/early_stopping.py +++ b/mmf/utils/early_stopping.py @@ -1,7 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. 
import numpy as np import torch -from mmf.utils.distributed import is_master +from mmf.utils.distributed import is_master, is_xla class EarlyStopping: @@ -46,7 +46,7 @@ def __call__(self, update, iteration, meter): Returns: bool -- Tells whether early stopping occurred or not """ - if not is_master(): + if not is_master() and not is_xla(): return False value = meter.meters.get(self.early_stop_criteria, None) diff --git a/mmf/utils/metsumm.py b/mmf/utils/metsumm.py new file mode 100644 index 000000000..3dc5c7b4a --- /dev/null +++ b/mmf/utils/metsumm.py @@ -0,0 +1,13 @@ +def metsumm(stepno=''): + """Print the CompileTime and aten:: entries of the torch_xla metrics report.""" + try: + import torch_xla.debug.metrics as met + except ImportError: + # torch_xla is absent on CPU/GPU hosts, yet the evaluation loop calls this unconditionally. + return + x = met.metrics_report().split('\n') + for i, line in enumerate(x): + if 'CompileTime' in line or 'aten::' in line: + key = line.split()[-1] + value = x[i+1].split()[-1] + print('step {}, key {}, value {}'.format(stepno, key, value)) diff --git a/mmf_cli/run.py b/mmf_cli/run.py index 9ba4b6f12..9f970438a 100644 --- a/mmf_cli/run.py +++ b/mmf_cli/run.py @@ -9,7 +9,7 @@ from mmf.common.registry import registry from mmf.utils.build import build_config, build_trainer from mmf.utils.configuration import Configuration -from mmf.utils.distributed import distributed_init, get_rank, infer_init_method +from mmf.utils.distributed import distributed_init, get_rank, infer_init_method, is_xla from mmf.utils.env import set_seed, setup_imports from mmf.utils.flags import flags from mmf.utils.general import log_device_names @@ -32,6 +32,7 @@ def main(configuration, init_distributed=False, predict=False): if init_distributed: distributed_init(config) + seed = config.training.seed config.training.seed = set_seed(seed if seed == -1 else seed + get_rank()) registry.register("seed", config.training.seed) @@ -96,6 +97,7 @@ def run(opts: typing.Optional[typing.List[str]] = None, predict: bool = False): if config.distributed.init_method is None: infer_init_method(config) + if config.distributed.init_method is not None: if torch.cuda.device_count() > 1 and not
config.distributed.no_spawn: config.start_rank = config.distributed.rank @@ -108,15 +110,25 @@ def run(opts: typing.Optional[typing.List[str]] = None, predict: bool = False): else: distributed_main(0, configuration, predict) elif config.distributed.world_size > 1: - assert config.distributed.world_size <= torch.cuda.device_count() - port = random.randint(10000, 20000) - config.distributed.init_method = f"tcp://localhost:{port}" - config.distributed.rank = None - torch.multiprocessing.spawn( - fn=distributed_main, - args=(configuration, predict), - nprocs=config.distributed.world_size, - ) + if is_xla(): + import torch_xla.distributed.xla_multiprocessing as xmp + torch.multiprocessing.set_sharing_strategy("file_system") + xmp.spawn( + fn=distributed_main, + args=(configuration, predict), + nprocs=8, # use all 8 TPU cores + start_method='fork' + ) + else: + assert config.distributed.world_size <= torch.cuda.device_count() + port = random.randint(10000, 20000) + config.distributed.init_method = f"tcp://localhost:{port}" + config.distributed.rank = None + torch.multiprocessing.spawn( + fn=distributed_main, + args=(configuration, predict), + nprocs=config.distributed.world_size, + ) else: config.device_id = 0 main(configuration, predict=predict) diff --git a/requirements.txt b/requirements.txt index 329a6a556..363dbed34 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,3 @@ -torch==1.6.0 -torchvision==0.7.0 numpy>=1.16.6 tqdm>=4.43.0 demjson==2.2.4