From a6bc64046bd230fe706d190063df411063a390d9 Mon Sep 17 00:00:00 2001 From: Vaibhav Singh Date: Tue, 27 Oct 2020 18:51:37 +0000 Subject: [PATCH 1/2] initial changes to support training on tpus initial changes to support training on tpus changed tpu configuration to use training.device replace parallelLoader with mpLoader to solved loader exhaust issue. removed debug message. updated the comment added comments for drop_last change. removed pdb lines removed redundant device config added comments for pending changes default init not applicable for xla device type moved wrapping of dataloader to build added line-debug function metsumm removed some .item calls from reporting xla equivalents in the distributed module, earlier eval was failing at the metrics all reduce step implemented broadcast in terms of all_to_all changes for checkpoint saving change to make execution even across cores corrected the is_master logic one more fix for is_master clean up of debug messages --- mmf/common/meter.py | 4 +- mmf/common/sample.py | 6 ++- mmf/datasets/multi_dataset_loader.py | 10 ++++- mmf/trainers/callbacks/logistics.py | 4 +- mmf/trainers/core/device.py | 18 ++++++--- mmf/trainers/core/evaluation_loop.py | 4 ++ mmf/trainers/core/reporting.py | 4 +- mmf/trainers/core/training_loop.py | 16 +++++--- mmf/utils/build.py | 24 ++++++++++- mmf/utils/checkpoint.py | 36 +++++++++++++---- mmf/utils/configuration.py | 3 ++ mmf/utils/distributed.py | 60 ++++++++++++++++++++++------ mmf/utils/early_stopping.py | 4 +- mmf/utils/metsumm.py | 8 ++++ mmf_cli/run.py | 32 ++++++++++----- 15 files changed, 181 insertions(+), 52 deletions(-) create mode 100644 mmf/utils/metsumm.py diff --git a/mmf/common/meter.py b/mmf/common/meter.py index e4b26f2c6..403cbdf2c 100644 --- a/mmf/common/meter.py +++ b/mmf/common/meter.py @@ -60,8 +60,8 @@ def update(self, update_dict, batch_size): if isinstance(v, torch.Tensor): if v.dim() != 0: v = v.mean() - v = v.item() - assert isinstance(v, (float, int)) + #v = 
v.item() + #assert isinstance(v, (float, int)) self.meters[k].update(v, batch_size) def update_from_meter(self, meter): diff --git a/mmf/common/sample.py b/mmf/common/sample.py index 40c535918..45089e87b 100644 --- a/mmf/common/sample.py +++ b/mmf/common/sample.py @@ -396,7 +396,11 @@ def to_device( if isinstance(device, str): device = torch.device(device) - if not torch.cuda.is_available(): + # default value of device_type is cuda + # since other device types such as xla can be passed + # falling back to cpu should only happen when device_type + # is set to cuda but cuda is not available. + if not torch.cuda.is_available() and device == "cuda": device = torch.device("cpu") # to_device is specifically for SampleList # if user is passing something custom built diff --git a/mmf/datasets/multi_dataset_loader.py b/mmf/datasets/multi_dataset_loader.py index 37a78add6..21b0518fe 100644 --- a/mmf/datasets/multi_dataset_loader.py +++ b/mmf/datasets/multi_dataset_loader.py @@ -8,7 +8,7 @@ import numpy as np from mmf.common.registry import registry from mmf.utils.build import build_dataloader_and_sampler, build_dataset -from mmf.utils.distributed import broadcast_scalar, is_dist_initialized, is_master +from mmf.utils.distributed import broadcast_scalar, is_dist_initialized, is_master, is_xla from mmf.utils.general import get_batch_size @@ -186,9 +186,15 @@ def _infer_dataset_probabilities(self): def __len__(self): # Since, this is iterator, we need to return total length == number of batches batch_size = get_batch_size() + # Changed the length to accommodate drop_last == True + # drop_last is required if the batch is split into multiple cores + # some of the cores may not have enough examples. + if is_xla(): + return (self._total_length) // batch_size + else: # This assumes drop_last=False for all loaders. See also # build_dataloader_and_sampler(). 
- return (self._total_length + batch_size - 1) // batch_size + return (self._total_length + batch_size - 1) // batch_size def __iter__(self): if self._num_datasets == 1: diff --git a/mmf/trainers/callbacks/logistics.py b/mmf/trainers/callbacks/logistics.py index 9211e269e..27666302c 100644 --- a/mmf/trainers/callbacks/logistics.py +++ b/mmf/trainers/callbacks/logistics.py @@ -5,7 +5,7 @@ import torch from mmf.trainers.callbacks.base import Callback from mmf.utils.configuration import get_mmf_env -from mmf.utils.distributed import is_master +from mmf.utils.distributed import is_master, is_xla from mmf.utils.logger import TensorboardLogger, log_progress, setup_output_folder from mmf.utils.timer import Timer @@ -105,7 +105,7 @@ def on_test_end(self, **kwargs): def _summarize_report(self, meter, should_print=True, extra=None): if extra is None: extra = {} - if not is_master(): + if not is_master() and not is_xla(): return if self.training_config.tensorboard: diff --git a/mmf/trainers/core/device.py b/mmf/trainers/core/device.py index a89bfccb6..37864435e 100644 --- a/mmf/trainers/core/device.py +++ b/mmf/trainers/core/device.py @@ -20,9 +20,17 @@ def configure_seed(self) -> None: torch.backends.cudnn.benchmark = False def configure_device(self) -> None: - self.local_rank = self.config.device_id - self.device = self.local_rank - self.distributed = False + if getattr(self.config.training, 'device', 'cuda') == 'xla': + import torch_xla.core.xla_model as xm + self.device = xm.xla_device() + self.distributed = True + self.local_rank = xm.get_local_ordinal() + self.tpu = True + else: + self.tpu = False + self.local_rank = self.config.device_id + self.device = self.local_rank + self.distributed = False # Will be updated later based on distributed setup registry.register("global_device", self.device) @@ -30,9 +38,9 @@ def configure_device(self) -> None: if self.config.distributed.init_method is not None: self.distributed = True self.device = torch.device("cuda", 
self.local_rank) - elif torch.cuda.is_available(): + elif torch.cuda.is_available() and not self.tpu: self.device = torch.device("cuda") - else: + elif not self.tpu: self.device = torch.device("cpu") registry.register("current_device", self.device) diff --git a/mmf/trainers/core/evaluation_loop.py b/mmf/trainers/core/evaluation_loop.py index ca088a66a..9ca258b89 100644 --- a/mmf/trainers/core/evaluation_loop.py +++ b/mmf/trainers/core/evaluation_loop.py @@ -10,6 +10,7 @@ from mmf.common.report import Report from mmf.common.sample import to_device from mmf.utils.distributed import is_master +from mmf.utils.metsumm import metsumm logger = logging.getLogger(__name__) @@ -25,6 +26,7 @@ def evaluation_loop( self.model.eval() disable_tqdm = not use_tqdm or not is_master() combined_report = None + metsumm("Before Validation Start:") for batch in tqdm.tqdm(loader, disable=disable_tqdm): report = self._forward(batch) @@ -44,6 +46,8 @@ def evaluation_loop( combined_report.metrics = self.metrics(combined_report, combined_report) self.update_meter(combined_report, meter, eval_mode=True) + logger.info("Validation Done") + metsumm("After Validation Complete") # enable train mode again self.model.train() diff --git a/mmf/trainers/core/reporting.py b/mmf/trainers/core/reporting.py index 8f6458399..41ed794c9 100644 --- a/mmf/trainers/core/reporting.py +++ b/mmf/trainers/core/reporting.py @@ -51,8 +51,8 @@ def update_dict(self, meter_update_dict, values_dict): if val.dim() == 1: val = val.mean() - if hasattr(val, "item"): - val = val.item() + #if hasattr(val, "item"): + # val = val.item() meter_update_dict.update({key: val}) total_val += val diff --git a/mmf/trainers/core/training_loop.py b/mmf/trainers/core/training_loop.py index 0d779a7e2..17af8bc05 100644 --- a/mmf/trainers/core/training_loop.py +++ b/mmf/trainers/core/training_loop.py @@ -12,7 +12,7 @@ from mmf.common.sample import to_device from mmf.utils.general import clip_gradients from torch import Tensor - +from 
mmf.utils.metsumm import metsumm logger = logging.getLogger(__name__) @@ -72,8 +72,8 @@ def run_training_epoch(self) -> None: combined_report = None num_batches_for_this_update = 1 - for idx, batch in enumerate(self.train_loader): + for idx, batch in enumerate(self.train_loader): if (idx + 1) % self.training_config.update_frequency == 0: combined_report = None num_batches_for_this_update = min( @@ -84,7 +84,6 @@ def run_training_epoch(self) -> None: # batch execution starts here self.on_batch_start() - self.profile("Batch load time") report = self.run_training_batch(batch, num_batches_for_this_update) @@ -129,7 +128,6 @@ def run_training_epoch(self) -> None: # Validation begin callbacks self.on_validation_start() - logger.info("Evaluation time. Running on full validation set...") # Validation and Early stopping # Create a new meter for this case report, meter = self.evaluation_loop(self.val_loader) @@ -146,7 +144,6 @@ def run_training_epoch(self) -> None: torch.cuda.empty_cache() if stop is True: - logger.info("Early stopping activated") should_break = True if self.num_updates >= self.max_updates: should_break = True @@ -168,7 +165,7 @@ def run_training_batch(self, batch: Tensor, loss_divisor: int) -> None: def _forward(self, batch: Tensor) -> Dict[str, Any]: prepared_batch = self.dataset_loader.prepare_batch(batch) # Move the sample list to device if it isn't as of now. 
- prepared_batch = to_device(prepared_batch, torch.device("cuda")) + prepared_batch = to_device(prepared_batch, self.device) self.profile("Batch prepare time") # Arguments should be a dict at this point @@ -188,6 +185,7 @@ def _start_update(self): def _backward(self, loss: Tensor) -> None: self.scaler.scale(loss).backward() + self.profile("Backward time") def _finish_update(self): @@ -199,6 +197,12 @@ def _finish_update(self): self.config, scale=self.scaler.get_scale(), ) + if getattr(self.config.training, 'device', 'cuda') == 'xla' and self.config.distributed.world_size > 1: + import torch_xla.core.xla_model as xm + #gradients = xm._fetch_gradients(self.optimizer) + # Assumes no model parallel + #xm.all_reduce('sum', gradients, scale=1.0 / self.config.distributed.world_size) + xm.reduce_gradients(self.optimizer) self.scaler.step(self.optimizer) self.scaler.update() diff --git a/mmf/utils/build.py b/mmf/utils/build.py index a7470eea9..6d651279d 100644 --- a/mmf/utils/build.py +++ b/mmf/utils/build.py @@ -10,10 +10,16 @@ from mmf.common.registry import registry from mmf.datasets.processors.processors import Processor from mmf.utils.configuration import Configuration -from mmf.utils.distributed import is_dist_initialized +from mmf.utils.distributed import is_dist_initialized, is_xla from mmf.utils.general import get_optimizer_parameters from omegaconf import DictConfig, OmegaConf +try: + import torch_xla.core.xla_model as xm + import torch_xla.distributed.parallel_loader as pl +except ImportError: + xm = None + pl = None ProcessorType = Type[Processor] ProcessorDict = Dict[str, ProcessorType] @@ -152,6 +158,18 @@ def build_dataloader_and_sampler( if not isinstance(dataset_instance, torch.utils.data.IterableDataset): other_args = _add_extra_args_for_dataloader(dataset_instance, other_args) + if is_xla(): + dataset_type = dataset_instance.dataset_type + shuffle=True + other_args["sampler"] = torch.utils.data.DistributedSampler( + dataset_instance, + 
num_replicas=xm.xrt_world_size(), + rank=xm.get_ordinal(), + shuffle=shuffle + ) + other_args.pop("shuffle") + + loader = torch.utils.data.DataLoader( dataset=dataset_instance, pin_memory=pin_memory, @@ -163,6 +181,10 @@ def build_dataloader_and_sampler( **other_args, ) + if is_xla(): + device = xm.xla_device() + loader = pl.MpDeviceLoader(loader, device) + if num_workers >= 0: # Suppress leaking semaphore warning os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" diff --git a/mmf/utils/checkpoint.py b/mmf/utils/checkpoint.py index 4cdf88b8c..b6072b027 100644 --- a/mmf/utils/checkpoint.py +++ b/mmf/utils/checkpoint.py @@ -10,7 +10,7 @@ import torch from mmf.common.registry import registry from mmf.utils.configuration import get_mmf_env, load_yaml -from mmf.utils.distributed import is_master, synchronize +from mmf.utils.distributed import is_master, synchronize, is_xla from mmf.utils.download import download_pretrained_model from mmf.utils.file_io import PathManager from mmf.utils.general import updir @@ -22,6 +22,11 @@ except ImportError: git = None +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + logger = logging.getLogger(__name__) @@ -379,15 +384,26 @@ def _get_vcs_fields(self): "git/diff": self.git_repo.git.diff("--no-prefix"), } + + + def save_func(self): + if is_xla(): + return xm.save + else: + return torch.save + def save(self, update, iteration=None, update_best=False): # Only save in main process - if not is_master(): + if not is_master() and not is_xla(): return + logger.info("Checkpoint save operation started!") + if not iteration: iteration = update ckpt_filepath = os.path.join(self.models_foldername, "model_%d.ckpt" % update) + best_ckpt_filepath = os.path.join( self.ckpt_foldername, self.ckpt_prefix + "best.ckpt" ) @@ -437,16 +453,20 @@ def save(self, update, iteration=None, update_best=False): git_metadata_dict = self._get_vcs_fields() ckpt.update(git_metadata_dict) + logger.info("Saving 
checkpoint") with PathManager.open(ckpt_filepath, "wb") as f: - torch.save(ckpt, f) + self.save_func()(ckpt, f) if update_best: + logger.info("Saving best checkpoint") with PathManager.open(best_ckpt_filepath, "wb") as f: - torch.save(ckpt, f) + self.save_func()(ckpt, f) # Save current always + + logger.info("Saving Current checkpoint") with PathManager.open(current_ckpt_filepath, "wb") as f: - torch.save(ckpt, f) + self.save_func()(ckpt, f) # Remove old checkpoints if max_to_keep is set if self.max_to_keep > 0: @@ -454,6 +474,8 @@ def save(self, update, iteration=None, update_best=False): self.remove(self.saved_iterations.pop(0)) self.saved_iterations.append(update) + logger.info("Checkpoint save operation finished!") + def remove(self, update): ckpt_filepath = os.path.join(self.models_foldername, "model_%d.ckpt" % update) if PathManager.isfile(ckpt_filepath): @@ -468,6 +490,6 @@ def restore(self): self._load(best_path, force=True) def finalize(self): - if is_master(): + if is_master() or is_xla(): with PathManager.open(self.pth_filepath, "wb") as f: - torch.save(self.trainer.model.state_dict(), f) + self.save_func()(self.trainer.model.state_dict(), f) diff --git a/mmf/utils/configuration.py b/mmf/utils/configuration.py index 6a9ba8bec..5d3f88351 100644 --- a/mmf/utils/configuration.py +++ b/mmf/utils/configuration.py @@ -531,6 +531,9 @@ def _update_specific(self, config): lr = config.learning_rate config.optimizer.params.lr = lr + # TODO: Correct the following issue + # This check is triggered before the config override from commandline is effective + # even after setting training.device = 'xla', it gets triggered. if not torch.cuda.is_available() and "cuda" in config.training.device: warnings.warn( "Device specified is 'cuda' but cuda is not present. 
" diff --git a/mmf/utils/distributed.py b/mmf/utils/distributed.py index 104b52b8f..91b53d5d2 100644 --- a/mmf/utils/distributed.py +++ b/mmf/utils/distributed.py @@ -9,13 +9,17 @@ import torch from torch import distributed as dist +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None +_USE_XLA = False MAX_SIZE_LIMIT = 65533 BYTE_SIZE = 256 logger = logging.getLogger(__name__) - def synchronize(): if not dist.is_available(): return @@ -31,8 +35,13 @@ def synchronize(): dist.barrier() +def is_xla(): + global _USE_XLA + return _USE_XLA def get_rank(): + if is_xla(): + return xm.get_ordinal() if not dist.is_available(): return 0 if not dist.is_nccl_available(): @@ -51,6 +60,8 @@ def is_dist_initialized(): def get_world_size(): + if is_xla(): + return xm.xrt_world_size() if not dist.is_available(): return 1 if not dist.is_nccl_available(): @@ -66,7 +77,14 @@ def broadcast_tensor(tensor, src=0): return tensor with torch.no_grad(): - dist.broadcast(tensor, src=0) + if is_xla(): + tensor = xm.all_to_all( + tensor.repeat([world_size,1]), + split_dimension=0, + concat_dimension=0, + split_count=world_size)[0] + else: + dist.broadcast(tensor, src=0) return tensor @@ -105,7 +123,11 @@ def gather_tensor(tensor): for _ in range(world_size): tensor_list.append(torch.zeros_like(tensor)) - dist.all_gather(tensor_list, tensor) + if is_xla(): + tensor_list = xm.all_gather(tensor) + tensor_list = tensor_list.view(world_size, *tensor.size()) + else: + dist.all_gather(tensor_list, tensor) tensor_list = torch.stack(tensor_list, dim=0) return tensor_list @@ -122,12 +144,17 @@ def reduce_dict(dictionary): keys, values = zip(*sorted(dictionary.items())) values = torch.stack(values, dim=0) - dist.reduce(values, dst=0) - - if dist.get_rank() == 0: - # only main process gets accumulated, so only divide by - # world_size in this case - values /= world_size + if is_xla(): + values = xm.all_reduce('sum', + [values], + scale=1.0/world_size + )[0] + else: + 
dist.reduce(values, dst=0) + if dist.get_rank() == 0: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size reduced_dict = {k: v for k, v in zip(keys, values)} return reduced_dict @@ -169,6 +196,11 @@ def byte_tensor_to_object(byte_tensor, max_size=4094): def infer_init_method(config): if config.distributed.init_method is not None: return + + if getattr(config.training, 'device', 'cuda') == 'xla': + global _USE_XLA + _USE_XLA = True + # support torch.distributed.launch if all( key in os.environ @@ -221,9 +253,14 @@ def infer_init_method(config): def distributed_init(config): if config.distributed.world_size == 1: raise ValueError("Cannot initialize distributed with distributed_world_size=1") + logger.info("XLA Mode:{}".format(is_xla())) - if dist.is_initialized(): + if is_xla(): + config.device_id = xm.get_local_ordinal() + config.distributed.rank = xm.get_ordinal() + elif dist.is_initialized(): warnings.warn("Distributed is already initialized, cannot initialize twice!") + config.distributed.rank = dist.get_rank() else: logger.info( f"Distributed Init (Rank {config.distributed.rank}): " @@ -244,8 +281,7 @@ def distributed_init(config): dist.all_reduce(torch.zeros(1).cuda()) suppress_output(is_master()) - - config.distributed.rank = dist.get_rank() + config.distributed.rank = dist.get_rank() return config.distributed.rank diff --git a/mmf/utils/early_stopping.py b/mmf/utils/early_stopping.py index 35f80f65c..971934893 100644 --- a/mmf/utils/early_stopping.py +++ b/mmf/utils/early_stopping.py @@ -1,7 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. 
import numpy as np import torch -from mmf.utils.distributed import is_master +from mmf.utils.distributed import is_master, is_xla class EarlyStopping: @@ -46,7 +46,7 @@ def __call__(self, update, iteration, meter): Returns: bool -- Tells whether early stopping occurred or not """ - if not is_master(): + if not is_master() and not is_xla(): return False value = meter.meters.get(self.early_stop_criteria, None) diff --git a/mmf/utils/metsumm.py b/mmf/utils/metsumm.py new file mode 100644 index 000000000..3dc5c7b4a --- /dev/null +++ b/mmf/utils/metsumm.py @@ -0,0 +1,8 @@ +def metsumm(stepno=''): + import torch_xla.debug.metrics as met + x = met.metrics_report().split('\n') + for i, line in enumerate(x): + if 'CompileTime' in line or 'aten::' in line: + key = line.split()[-1] + value = x[i+1].split()[-1] + print('step {}, key {}, value {}'.format(stepno, key, value)) diff --git a/mmf_cli/run.py b/mmf_cli/run.py index 9ba4b6f12..9f970438a 100644 --- a/mmf_cli/run.py +++ b/mmf_cli/run.py @@ -9,7 +9,7 @@ from mmf.common.registry import registry from mmf.utils.build import build_config, build_trainer from mmf.utils.configuration import Configuration -from mmf.utils.distributed import distributed_init, get_rank, infer_init_method +from mmf.utils.distributed import distributed_init, get_rank, infer_init_method, is_xla from mmf.utils.env import set_seed, setup_imports from mmf.utils.flags import flags from mmf.utils.general import log_device_names @@ -32,6 +32,7 @@ def main(configuration, init_distributed=False, predict=False): if init_distributed: distributed_init(config) + seed = config.training.seed config.training.seed = set_seed(seed if seed == -1 else seed + get_rank()) registry.register("seed", config.training.seed) @@ -96,6 +97,7 @@ def run(opts: typing.Optional[typing.List[str]] = None, predict: bool = False): if config.distributed.init_method is None: infer_init_method(config) + if config.distributed.init_method is not None: if torch.cuda.device_count() > 1 and not 
config.distributed.no_spawn: config.start_rank = config.distributed.rank @@ -108,15 +110,25 @@ def run(opts: typing.Optional[typing.List[str]] = None, predict: bool = False): else: distributed_main(0, configuration, predict) elif config.distributed.world_size > 1: - assert config.distributed.world_size <= torch.cuda.device_count() - port = random.randint(10000, 20000) - config.distributed.init_method = f"tcp://localhost:{port}" - config.distributed.rank = None - torch.multiprocessing.spawn( - fn=distributed_main, - args=(configuration, predict), - nprocs=config.distributed.world_size, - ) + if is_xla(): + import torch_xla.distributed.xla_multiprocessing as xmp + torch.multiprocessing.set_sharing_strategy("file_system") + xmp.spawn( + fn=distributed_main, + args=(configuration, predict), + nprocs=8, # use all 8 TPU cores + start_method='fork' + ) + else: + assert config.distributed.world_size <= torch.cuda.device_count() + port = random.randint(10000, 20000) + config.distributed.init_method = f"tcp://localhost:{port}" + config.distributed.rank = None + torch.multiprocessing.spawn( + fn=distributed_main, + args=(configuration, predict), + nprocs=config.distributed.world_size, + ) else: config.device_id = 0 main(configuration, predict=predict) From d06f077588f81a496171a94e8147bb84b581f270 Mon Sep 17 00:00:00 2001 From: Vaibhav Singh Date: Thu, 17 Dec 2020 23:45:49 +0000 Subject: [PATCH 2/2] fix for initial data/checkpoint download --- mmf/utils/build.py | 9 +++++++-- mmf/utils/distributed.py | 6 ++++-- requirements.txt | 2 -- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/mmf/utils/build.py b/mmf/utils/build.py index 6d651279d..78a3e1904 100644 --- a/mmf/utils/build.py +++ b/mmf/utils/build.py @@ -10,7 +10,7 @@ from mmf.common.registry import registry from mmf.datasets.processors.processors import Processor from mmf.utils.configuration import Configuration -from mmf.utils.distributed import is_dist_initialized, is_xla +from mmf.utils.distributed import 
is_dist_initialized, is_xla, is_master, synchronize from mmf.utils.general import get_optimizer_parameters from omegaconf import DictConfig, OmegaConf @@ -82,7 +82,12 @@ def build_model( if hasattr(model, "build"): model.load_requirements() - model.build() + if is_master(): + model.build() + synchronize() + else: + synchronize() + model.build() model.init_losses() return model diff --git a/mmf/utils/distributed.py b/mmf/utils/distributed.py index 91b53d5d2..07f02b55c 100644 --- a/mmf/utils/distributed.py +++ b/mmf/utils/distributed.py @@ -20,8 +20,10 @@ BYTE_SIZE = 256 logger = logging.getLogger(__name__) -def synchronize(): - if not dist.is_available(): +def synchronize(message='sync-workers'): + if is_xla(): + xm.rendezvous(message) + elif not dist.is_available(): return if not dist.is_nccl_available(): return diff --git a/requirements.txt b/requirements.txt index 329a6a556..363dbed34 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,3 @@ -torch==1.6.0 -torchvision==0.7.0 numpy>=1.16.6 tqdm>=4.43.0 demjson==2.2.4