From ee5a5a07b619a1bbf4e8efeadf9ab2651104026f Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Fri, 4 Apr 2025 17:13:21 +0000 Subject: [PATCH 01/12] . --- run_train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/run_train.py b/run_train.py index dac6c39e2..8afe6e583 100644 --- a/run_train.py +++ b/run_train.py @@ -188,7 +188,9 @@ def get_dataloader_from_data_stage( tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) eos_token_id = tokenizer.eos_token_id - assert eos_token_id is not None and data.dataset.return_positions is True, "Tokenizer must have an eos token if return_positions is True" + assert ( + eos_token_id is not None or data.dataset.return_positions is False + ), "Tokenizer must have an eos token if return_positions is True" log_rank( f"[Nanoset] Creating Nanoset with {len(data.dataset.dataset_folder)} dataset folders and {trainer.config.tokens.train_steps * trainer.global_batch_size} train samples", logger=logger, From 4c7d5d67acf62f75bd8cc9b24c59ea17249439ac Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Fri, 4 Apr 2025 17:28:54 +0000 Subject: [PATCH 02/12] logs --- src/nanotron/data/nanoset.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/nanotron/data/nanoset.py b/src/nanotron/data/nanoset.py index a8f0599b2..4c41a8553 100644 --- a/src/nanotron/data/nanoset.py +++ b/src/nanotron/data/nanoset.py @@ -4,6 +4,7 @@ import warnings from typing import Dict, List, Tuple, Union +import numba import numpy as np import torch from datatrove.utils.dataset import DatatroveFolderDataset @@ -14,6 +15,8 @@ from nanotron.data.utils import count_dataset_indexes, normalize from nanotron.logging import log_rank +numba.config.NUMBA_DEBUG_CACHE = 1 + logger = logging.get_logger(__name__) @@ -52,7 +55,9 @@ def __init__( self.sequence_length = sequence_length self.eos_token_id = eos_token_id self.return_positions = return_positions - assert self.return_positions or self.eos_token_id is not None, "If return_positions is True, eos_token_id must be defined" + assert ( + self.return_positions or self.eos_token_id is not None + ), "If return_positions is True, eos_token_id must be defined" # Number of bytes for the tokens stored in the processed dataset files. 2 for vocab sizes < 65535, 4 otherwise self.token_size = token_size self.train_split_num_samples = train_split_num_samples @@ -69,7 +74,7 @@ def __init__( recursive=False, token_size=self.token_size, shuffle=True, - return_positions=self.return_positions, # if set to True, the position ids are directly build datatrove + return_positions=self.return_positions, # if set to True, the position ids are directly build datatrove eos_token_id=self.eos_token_id, ) ) @@ -186,6 +191,18 @@ def build_nanoset_index(self) -> np.ndarray: # Compute samples per epoch and number of epochs samples_per_epoch = sum(self.dataset_lengths) num_epochs = int(self.train_split_num_samples / samples_per_epoch) + 1 + + # Debug Numba cache + logger.info( + f"[Nanoset] Building random access index with {samples_per_epoch} samples per epoch and `dataset_weights` and `dataset_lengths` in config" + ) + if not numba.config.CACHE_DIR: + logger.warning("[Nanoset] Numba cache is disabled") + elif os.path.exists(numba.config.CACHE_DIR): + logger.info( + f"[Nanoset] Cache dir is set to: {numba.config.CACHE_DIR} | Numba cache files: {os.listdir(numba.config.CACHE_DIR)}" + ) + # Build the dataset indexes for 1 epoch dataset_index, dataset_sample_index = build_nanoset_index_helper( n_samples=samples_per_epoch, weights=self.dataset_weights, dataset_sizes=self.dataset_lengths From b41155806ef78f00cc2424bac04c93efcbee0aee Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Sat, 5 Apr 2025 15:23:37 +0000 Subject: [PATCH 03/12] . --- src/nanotron/serialize/main.py | 44 +++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index b1445b481..27f4fdbe1 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Optional, cast +import psutil import torch from datasets.download.streaming_download_manager import xPath from torch import nn @@ -58,10 +59,20 @@ def save( should_save_model: bool = True, should_save_optimizer: bool = True, should_save_lr_scheduler: bool = True, - sanity_checks: bool = True, + sanity_checks: bool = False, ) -> None: assert isinstance(training_metadata, TrainingMetadata) + process = psutil.Process() + # Only log on rank 0 to avoid flooding logs + if dist.get_rank(parallel_context.world_pg) == 0: + log_rank( + f"SAVE - Initial memory: {process.memory_info().rss / (1024 * 1024):.2f} MB", + logger=logger, + level=logging.INFO, + rank=0, + ) + try: if should_save_config: config.save_as_yaml(root_folder / "config.yaml") @@ -201,6 +212,37 @@ def save( msg=lambda msg: f"tensor at {current_state_dict['names'][index]} doesn't match with our reference. Optimizer key: {name}\nCur: {tensor}\nRef: {reference_tensor}\n{msg}", ) + # TODO: do we need this + # Try more aggressive memory cleanup at the end + if dist.get_rank(parallel_context.world_pg) == 0: + # Perform memory cleanup + # gc.collect() + # torch.cuda.empty_cache() # Clear CUDA cache + + # Try to free memory back to OS on Linux + # try: + # import ctypes + # libc = ctypes.CDLL("libc.so.6") + # libc.malloc_trim(0) + # except Exception as e: + # log_rank(f"Failed to release memory to OS: {e}", logger=logger, level=logging.WARNING, rank=0) + + # log_rank( + # f"SAVE - Final memory after cleanup: {process.memory_info().rss / (1024 * 1024):.2f} MB", + # logger=logger, + # level=logging.INFO, + # rank=0 + # ) + pass + + if dist.get_rank(parallel_context.world_pg) == 0: + log_rank( + f"[Save checkpoint] CPU memory (RSS) after save: {process.memory_info().rss / (1024 * 1024):.2f} MB", + logger=logger, + level=logging.INFO, + rank=0, + ) + dist.barrier(parallel_context.world_pg) From a31694d3588e6d3545e670c1bd2b3835e3caa909 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Sat, 5 Apr 2025 15:38:23 +0000 Subject: [PATCH 04/12] . --- slurm_launcher.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/slurm_launcher.py b/slurm_launcher.py index 4cc583b3f..1a17bacd8 100644 --- a/slurm_launcher.py +++ b/slurm_launcher.py @@ -6,7 +6,9 @@ It handles configuration generation, resource allocation, and job submission. Usage: - python slurm_launcher.py --run_name my_experiment --nodes 4 [other options] + python slurm_launcher.py --run my_experiment --nodes 4 [other options] + python slurm_launcher.py --config /fsx/elie_bakouch/smollm3_training/0304-begin-nanotron/smollm3/0304-0B/launcher/configs-dump/isolate-s3.yaml --enable-wandb --run 32k-newindex + python slurm_launcher.py --config /fsx/elie_bakouch/smollm3_training/0304-begin-nanotron/smollm3/0304-0B/launcher/configs-dump/_isolate-fw-edu/fw-edu-3200k.yaml --enable-wandb --run fw-edu-3200k The script will: 1. Generate a Nanotron config based on your parameters @@ -476,6 +478,8 @@ def create_slurm_script( #SBATCH --partition={args.partition} #SBATCH --output={logs_path}/{timestamp}-%x-%j.out #SBATCH --qos={args.qos} +#SBATCH --reservation=smollm +#SBATCH --exclude=ip-26-0-160-[225,242],ip-26-0-161-138,ip-26-0-162-233,ip-26-0-163-[134,147],ip-26-0-164-18,ip-26-0-165-213,ip-26-0-166-[36,68],ip-26-0-167-[51,177,217,245],ip-26-0-168-[95,238],ip-26-0-169-[86,132,207,239,247],ip-26-0-170-[31,132,143,160],ip-26-0-171-[62,88,102,168,230],ip-26-0-172-[73,116,252],ip-26-0-173-7,ip-26-0-174-240,ip-26-0-175-[19,132,165,170,241] #SBATCH --wait-all-nodes=1 # fail if any node is not ready {f"#SBATCH --time={args.time_limit}" if args.time_limit else ""} """ From ba987ed2171b78c5fdab09cdb843a87b3d685330 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Sun, 6 Apr 2025 04:07:25 +0000 Subject: [PATCH 05/12] timers --- src/nanotron/data/dataloader.py | 94 ++++---- src/nanotron/data/nanoset.py | 6 +- src/nanotron/logging/__init__.py | 59 +++++ src/nanotron/{logging.py => logging/base.py} | 0 src/nanotron/logging/timers.py | 220 ++++++++++++++++++ .../parallel/pipeline_parallel/engine.py | 5 + src/nanotron/trainer.py | 22 +- 7 files changed, 360 insertions(+), 46 deletions(-) create mode 100644 src/nanotron/logging/__init__.py rename src/nanotron/{logging.py => logging/base.py} (100%) create mode 100644 src/nanotron/logging/timers.py diff --git a/src/nanotron/data/dataloader.py b/src/nanotron/data/dataloader.py index 448830307..0a6185163 100644 --- a/src/nanotron/data/dataloader.py +++ b/src/nanotron/data/dataloader.py @@ -10,6 +10,7 @@ from nanotron.config import Config from nanotron.data.clm_collator import DataCollatorForCLM, DataCollatorForCLMWithPositionIds from nanotron.data.samplers import EmptyInfiniteDataset, get_sampler +from nanotron.logging.timers import nanotron_timer from nanotron.parallel import ParallelContext from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.random import set_random_seed @@ -38,49 +39,62 @@ def sanity_check_dataloader( The same batches after performing sanity checks """ # WARNING: This is called in the middle of the training loop, so make sure it's optimized - for batch in dataloader: - # maybe numpy to torch - batch = {k: torch.from_numpy(v) if isinstance(v, np.ndarray) else v for k, v in batch.items()} - - # non_blocking=True seems to be fine? https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/4 - micro_batch = { - k: v - if isinstance(v, TensorPointer) - else v.to("cuda", memory_format=torch.contiguous_format, non_blocking=True) - for k, v in batch.items() - } - - if not config.general.ignore_sanity_checks: - # SANITY CHECK: Check input are not the same across DP - for key, value in sorted(micro_batch.items(), key=lambda x: x[0]): - if isinstance(value, TensorPointer): - continue - - if "mask" in key or "position_ids" in key: - # It's fine if mask is the same across DP - continue - - with assert_fail_except_rank_with(AssertionError, rank_exception=0, pg=parallel_context.dp_pg): + dataloader_iter = iter(dataloader) + while True: + try: + # Measure just the time to fetch the next batch + nanotron_timer("dataloader_fetch").start() + batch = next(dataloader_iter) + nanotron_timer("dataloader_fetch").end() + + # Process the batch (this part is not timed as part of the dataloader fetch) + # maybe numpy to torch + batch = {k: torch.from_numpy(v) if isinstance(v, np.ndarray) else v for k, v in batch.items()} + + # non_blocking=True seems to be fine? https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/4 + micro_batch = { + k: v + if isinstance(v, TensorPointer) + else v.to("cuda", memory_format=torch.contiguous_format, non_blocking=True) + for k, v in batch.items() + } + + if not config.general.ignore_sanity_checks: + # SANITY CHECK: Check input are not the same across DP + for key, value in sorted(micro_batch.items(), key=lambda x: x[0]): + if isinstance(value, TensorPointer): + continue + + if "mask" in key or "position_ids" in key: + # It's fine if mask is the same across DP + continue + + with assert_fail_except_rank_with(AssertionError, rank_exception=0, pg=parallel_context.dp_pg): + assert_tensor_synced_across_pg( + tensor=value, pg=parallel_context.dp_pg, msg=lambda err: f"{key} {err}" + ) + + # SANITY CHECK: Check input are synchronized throughout TP + for key, value in sorted(micro_batch.items(), key=lambda x: x[0]): + if isinstance(value, TensorPointer): + continue assert_tensor_synced_across_pg( - tensor=value, pg=parallel_context.dp_pg, msg=lambda err: f"{key} {err}" + tensor=value, + pg=parallel_context.tp_pg, + msg=lambda err: f"{key} are not synchronized throughout TP {err}", ) - # SANITY CHECK: Check input are synchronized throughout TP - for key, value in sorted(micro_batch.items(), key=lambda x: x[0]): - if isinstance(value, TensorPointer): - continue - assert_tensor_synced_across_pg( - tensor=value, - pg=parallel_context.tp_pg, - msg=lambda err: f"{key} are not synchronized throughout TP {err}", - ) - - # SANITY CHECK: Check that input are synchronized throughout PP - # TODO @thomasw21: That's really hard to test as input gets sharded across the PP, let's assume it works for now. - - # SANITY CHECK: Check that an input only exists on the PP rank responsible for it - # TODO @nouamanetazi: add this test - yield micro_batch + # SANITY CHECK: Check that input are synchronized throughout PP + # TODO @thomasw21: That's really hard to test as input gets sharded across the PP, let's assume it works for now. + + # SANITY CHECK: Check that an input only exists on the PP rank responsible for it + # TODO @nouamanetazi: add this test + + yield micro_batch + + except StopIteration: + # End of dataloader + raise StopIteration def dummy_infinite_data_generator( diff --git a/src/nanotron/data/nanoset.py b/src/nanotron/data/nanoset.py index 4c41a8553..05cfede10 100644 --- a/src/nanotron/data/nanoset.py +++ b/src/nanotron/data/nanoset.py @@ -15,8 +15,6 @@ from nanotron.data.utils import count_dataset_indexes, normalize from nanotron.logging import log_rank -numba.config.NUMBA_DEBUG_CACHE = 1 - logger = logging.get_logger(__name__) @@ -93,8 +91,8 @@ def __init__( self.dataset_weights ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." ## Build dataset index and dataset sample index - self.dataset_index, self.dataset_sample_index = self.build_nanoset_index() - # self.dataset_index, self.dataset_sample_index = self.new_build_nanoset_index() # TODO: Fix this + # self.dataset_index, self.dataset_sample_index = self.build_nanoset_index() + self.dataset_index, self.dataset_sample_index = self.new_build_nanoset_index() # TODO: Fix this self.print_nanoset_info() diff --git a/src/nanotron/logging/__init__.py b/src/nanotron/logging/__init__.py new file mode 100644 index 000000000..5529e9f69 --- /dev/null +++ b/src/nanotron/logging/__init__.py @@ -0,0 +1,59 @@ +# Export logging functionality from base.py +from nanotron.logging.base import ( + # Constants + CRITICAL, + DEBUG, + ERROR, + FATAL, + INFO, + NOTSET, + WARNING, + CategoryFilter, + LoggerWriter, + LogItem, + # Classes + NewLineStreamHandler, + # Functions + get_logger, + get_verbosity, + human_format, + log_libraries_versions, + log_memory, + log_rank, + set_formatter, + set_logger_verbosity_format, + set_ranks_logging_level, + set_verbosity, + warn_once, +) + +# Export timer functionality +from nanotron.logging.timers import TimerRecord, Timers, nanotron_timer + +__all__ = [ + "CRITICAL", + "DEBUG", + "ERROR", + "FATAL", + "INFO", + "NOTSET", + "WARNING", + "CategoryFilter", + "LoggerWriter", + "LogItem", + "NewLineStreamHandler", + "get_logger", + "get_verbosity", + "human_format", + "log_libraries_versions", + "log_memory", + "log_rank", + "set_formatter", + "set_logger_verbosity_format", + "set_ranks_logging_level", + "set_verbosity", + "warn_once", + "TimerRecord", + "Timers", + "nanotron_timer", +] diff --git a/src/nanotron/logging.py b/src/nanotron/logging/base.py similarity index 100% rename from src/nanotron/logging.py rename to src/nanotron/logging/base.py diff --git a/src/nanotron/logging/timers.py b/src/nanotron/logging/timers.py new file mode 100644 index 000000000..5ff10c5a1 --- /dev/null +++ b/src/nanotron/logging/timers.py @@ -0,0 +1,220 @@ +import time +from dataclasses import dataclass +from enum import Enum +from typing import Dict, Optional, Union + +import torch + +from nanotron import distributed as dist +from nanotron import logging + +logger = logging.get_logger(__name__) + + +class TimerType(Enum): + CPU = "cpu" # Regular CPU timer (uses time.time()) + CUDA = "cuda" # CUDA-aware timer (uses CUDA events) + + +@dataclass +class TimerRecord: + """Records timing information for a single timer.""" + + name: str + timer_type: TimerType = TimerType.CPU + start_time: float = 0.0 + end_time: float = 0.0 + running: bool = False + call_count: int = 0 + total_time: float = 0.0 + + # CUDA specific fields + _start_event: Optional[torch.cuda.Event] = None + _end_event: Optional[torch.cuda.Event] = None + _last_elapsed_time: float = 0.0 + + def start(self) -> "TimerRecord": + """Start the timer.""" + if self.running: + logger.warning(f"Timer '{self.name}' already running. Restarting.") + + if self.timer_type == TimerType.CUDA: + if torch.cuda.is_available(): + # Create CUDA events with timing enabled + self._start_event = torch.cuda.Event(enable_timing=True) + self._end_event = torch.cuda.Event(enable_timing=True) + + # Record the start event + self._start_event.record() + else: + logger.warning("CUDA timer requested but CUDA is not available. Falling back to CPU timer.") + self.timer_type = TimerType.CPU + self.start_time = time.time() + else: + self.start_time = time.time() + + self.running = True + return self + + def end(self) -> float: + """End the timer and return elapsed time in seconds.""" + if not self.running: + logger.warning(f"Timer '{self.name}' was not running. Ignoring end call.") + return 0.0 + + elapsed = 0.0 + if self.timer_type == TimerType.CUDA: + if torch.cuda.is_available() and self._start_event is not None and self._end_event is not None: + # Record the end event + self._end_event.record() + + # Waits for all preceding CUDA operations to complete + self._end_event.synchronize() + + # Get the elapsed time in milliseconds and convert to seconds + elapsed = self._start_event.elapsed_time(self._end_event) / 1000.0 + self._last_elapsed_time = elapsed + else: + logger.warning("CUDA timer end called but CUDA events are not available.") + self.timer_type = TimerType.CPU + self.end_time = time.time() + elapsed = self.end_time - self.start_time + else: + self.end_time = time.time() + elapsed = self.end_time - self.start_time + + self.total_time += elapsed + self.call_count += 1 + self.running = False + return elapsed + + def reset(self) -> None: + """Reset the timer.""" + self.start_time = 0.0 + self.end_time = 0.0 + self.running = False + self.call_count = 0 + self.total_time = 0.0 + self._start_event = None + self._end_event = None + self._last_elapsed_time = 0.0 + + @property + def elapsed(self) -> float: + """Get elapsed time in seconds.""" + if not self.running: + if self.timer_type == TimerType.CUDA: + return self._last_elapsed_time + return self.end_time - self.start_time + + # Timer is still running + if self.timer_type == TimerType.CUDA: + if torch.cuda.is_available() and self._start_event is not None: + # Create a temporary end event + tmp_end_event = torch.cuda.Event(enable_timing=True) + tmp_end_event.record() + tmp_end_event.synchronize() + return self._start_event.elapsed_time(tmp_end_event) / 1000.0 + else: + logger.warning("CUDA timer elapsed called but CUDA events are not available.") + return time.time() - self.start_time + else: + return time.time() - self.start_time + + @property + def average_time(self) -> float: + """Get average time per call in seconds.""" + if self.call_count == 0: + return 0.0 + return self.total_time / self.call_count + + +class Timers: + """A collection of timers for tracking execution time in Nanotron.""" + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(Timers, cls).__new__(cls) + cls._instance._timers: Dict[str, TimerRecord] = {} + return cls._instance + + def __call__(self, name: str, timer_type: Union[TimerType, str] = TimerType.CPU) -> TimerRecord: + """Get or create a timer with the given name. + + Args: + name: Name of the timer + timer_type: Type of timer, either TimerType.CPU or TimerType.CUDA + (or 'cpu'/'cuda' strings) + """ + if isinstance(timer_type, str): + timer_type = TimerType(timer_type) + + if name not in self._timers: + self._timers[name] = TimerRecord(name=name, timer_type=timer_type) + elif self._timers[name].timer_type != timer_type: + logger.warning( + f"Timer '{name}' already exists with type {self._timers[name].timer_type}. " + f"Requested type {timer_type} will be ignored." + ) + return self._timers[name] + + def reset_all(self) -> None: + """Reset all timers.""" + for timer in self._timers.values(): + timer.reset() + + def reset(self, name: str) -> None: + """Reset a specific timer.""" + if name in self._timers: + self._timers[name].reset() + + def log(self, name: str, logger=None, rank: Optional[int] = 0, group=None) -> None: + """Log a specific timer on the specified rank.""" + if name not in self._timers: + return + + if logger is None: + logger = logging.get_logger(__name__) + + world_rank = dist.get_rank() if group is None else dist.get_rank(group) + if rank is not None and world_rank != rank: + return + + timer = self._timers[name] + if timer.call_count > 0: + avg_time = timer.average_time * 1000 # Convert to ms + total_time = timer.total_time * 1000 # Convert to ms + logger.info( + f"Timer '{name}' ({timer.timer_type.value}): {total_time:.2f}ms total, " + f"{avg_time:.2f}ms avg, {timer.call_count} calls" + ) + + def log_all(self, logger=None, rank: Optional[int] = 0, group=None) -> None: + """Log all timers on the specified rank.""" + if logger is None: + logger = logging.get_logger(__name__) + + world_rank = dist.get_rank() if group is None else dist.get_rank(group) + if rank is not None and world_rank != rank: + return + + # Sort timers by name for consistent output + sorted_timers = sorted(self._timers.items()) + + if sorted_timers: + logger.info("---- Timing Information ----") + for name, timer in sorted_timers: + if timer.call_count > 0: + avg_time = timer.average_time * 1000 # Convert to ms + total_time = timer.total_time * 1000 # Convert to ms + logger.info( + f"Timer '{name}' ({timer.timer_type.value}): {total_time:.2f}ms total, " + f"{avg_time:.2f}ms avg, {timer.call_count} calls" + ) + logger.info("----------------------------") + + +# Create a singleton instance +nanotron_timer = Timers() diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index 9ee523ddc..ad224d9e0 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -9,6 +9,7 @@ from nanotron import logging from nanotron.distributed import ProcessGroup from nanotron.logging import log_rank +from nanotron.logging.timers import nanotron_timer from nanotron.optim.gradient_accumulator import GradientAccumulator from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model @@ -289,7 +290,9 @@ def train_batch_iter( for micro_batch in batch: context = self._get_fwd_context(model=model) + nanotron_timer("forward", timer_type="cuda").start() output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model) + nanotron_timer("forward", timer_type="cuda").end() # We make `output` a dict if not isinstance(output, dict): @@ -306,7 +309,9 @@ def train_batch_iter( nb_backwards=state.nb_backwards, grad_accumulator=grad_accumulator, ) + nanotron_timer("backward", timer_type="cuda").start() self.backward(context=context, state=state, grad_accumulator=grad_accumulator) + nanotron_timer("backward", timer_type="cuda").end() # Check figure in paper: The remain blocks are all backward and there is only `pg.size() - current_pp_rank - 1` blocks left assert len(state.microbatches_activations_requiring_backward) == pg.size() - current_pp_rank - 1 diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 549f6d9e6..276a2494b 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -56,6 +56,7 @@ log_rank, set_ranks_logging_level, ) +from nanotron.logging.timers import nanotron_timer from nanotron.metrics_logging import MetricsLogger from nanotron.models import NanotronModel, build_model from nanotron.models.base import check_model_has_grad @@ -548,6 +549,7 @@ def training_step( if self.iteration_step < self.initial_iter_step + 5: log_memory(logger=logger, msg="Before train_batch_iter") + nanotron_timer("train_batch_iter").start() with torch.profiler.record_function("train_batch_iter"): outputs = self.pipeline_engine.train_batch_iter( model=self.model, @@ -556,6 +558,7 @@ def training_step( nb_microbatches=self.n_micro_batches_per_batch, grad_accumulator=self.grad_accumulator, ) + nanotron_timer("train_batch_iter").end() if self.iteration_step < self.initial_iter_step + 5: log_memory(logger=logger, msg="After train_batch_iter") @@ -569,6 +572,7 @@ def training_step( ), "No fp32_grads_allreduce_handle maybe you're using only a single training process" self.grad_accumulator.fp32_grads_allreduce_handle.wait() + nanotron_timer("sync_gradients").start() # Sync tied weights if not isinstance(self.model, DistributedDataParallel): # Manually sync across DP if it's not handled by DDP @@ -587,8 +591,10 @@ def training_step( parallel_context=self.parallel_context, grad_accumulator=self.grad_accumulator, ) + nanotron_timer("sync_gradients").end() # Clip gradients + nanotron_timer("clip_gradients").start() if self.config.optimizer.clip_grad is not None: # Unwrap DDP named_parameters = [ @@ -602,6 +608,7 @@ def training_step( grad_accumulator=self.grad_accumulator, max_norm=self.config.optimizer.clip_grad, ) + nanotron_timer("clip_gradients").end() # Compute DP average loss and overlap with optimizer step if isinstance(outputs[0]["loss"], torch.Tensor): @@ -635,8 +642,10 @@ def training_step( ) # Apply gradient + nanotron_timer("optimizer_step").start() self.optimizer.step() self.optimizer.zero_grad() + nanotron_timer("optimizer_step").end() # Update the learning rate self.lr_scheduler.step() @@ -668,6 +677,8 @@ def train_step_logs( dist.barrier() torch.cuda.synchronize() elapsed_time_per_iteration_ms = (time.time() - self.iteration_start_time) * 1000 + # if elapsed_time_per_iteration_ms > 600 and self.iteration_step >10: + # print(f"elapsed_time_per_iteration_ms: {elapsed_time_per_iteration_ms}") tokens_per_sec = ( self.global_batch_size * self.sequence_length / (elapsed_time_per_iteration_ms / 1000) ) # tokens_per_sec is calculated using sequence_length @@ -705,13 +716,20 @@ def train_step_logs( LogItem("model_tflops_per_gpu", model_tflops, "human_format"), # , ".2f"), # LogItem("hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"), LogItem("eta", str(datetime.timedelta(seconds=eta_seconds))), + LogItem("timers/tbi", nanotron_timer("train_batch_iter").elapsed, ".2f"), + LogItem("timers/forward", nanotron_timer("forward").elapsed, ".2f"), + LogItem("timers/backward", nanotron_timer("backward").elapsed, ".2f"), + LogItem("timers/sync_gradients", nanotron_timer("sync_gradients").elapsed, ".2f"), + LogItem("timers/clip_gradients", nanotron_timer("clip_gradients").elapsed, ".2f"), + LogItem("timers/optimizer_step", nanotron_timer("optimizer_step").elapsed, ".2f"), + LogItem("timers/dataloader_fetch", nanotron_timer("dataloader_fetch").elapsed, ".2f"), ] if z_loss_avg is not None: basic_log_entries.insert(6, LogItem("z_loss", z_loss_avg.item(), "human_format")) # , "1.6E"), if self.config.optimizer.clip_grad is not None: - basic_log_entries.append( - LogItem("grad_norm", self.grad_norm_unclipped.item(), "human_format") + basic_log_entries.insert( + 5, LogItem("grad_norm", self.grad_norm_unclipped.item(), "human_format") ) # , ".3f")) # Console logging only on logger ranks From 79a748d0b6819a986a58164c09ae301919208dc5 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Sun, 6 Apr 2025 04:17:44 +0000 Subject: [PATCH 06/12] . --- src/nanotron/logging/timers.py | 111 +++++++++++++++++++-------------- 1 file changed, 63 insertions(+), 48 deletions(-) diff --git a/src/nanotron/logging/timers.py b/src/nanotron/logging/timers.py index 5ff10c5a1..1827c1b3e 100644 --- a/src/nanotron/logging/timers.py +++ b/src/nanotron/logging/timers.py @@ -1,7 +1,7 @@ import time -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum -from typing import Dict, Optional, Union +from typing import Dict, List, Optional, Union import torch @@ -26,12 +26,13 @@ class TimerRecord: end_time: float = 0.0 running: bool = False call_count: int = 0 - total_time: float = 0.0 + + # For CPU timers we still track total_time + _cpu_total_time: float = 0.0 # CUDA specific fields - _start_event: Optional[torch.cuda.Event] = None - _end_event: Optional[torch.cuda.Event] = None - _last_elapsed_time: float = 0.0 + _cuda_events: List[tuple[torch.cuda.Event, torch.cuda.Event]] = field(default_factory=list) + _current_start_event: Optional[torch.cuda.Event] = None def start(self) -> "TimerRecord": """Start the timer.""" @@ -40,12 +41,9 @@ def start(self) -> "TimerRecord": if self.timer_type == TimerType.CUDA: if torch.cuda.is_available(): - # Create CUDA events with timing enabled - self._start_event = torch.cuda.Event(enable_timing=True) - self._end_event = torch.cuda.Event(enable_timing=True) - - # Record the start event - self._start_event.record() + # Create a new start event - we'll create the end event when end() is called + self._current_start_event = torch.cuda.Event(enable_timing=True) + self._current_start_event.record() else: logger.warning("CUDA timer requested but CUDA is not available. Falling back to CPU timer.") self.timer_type = TimerType.CPU @@ -56,37 +54,32 @@ def start(self) -> "TimerRecord": self.running = True return self - def end(self) -> float: - """End the timer and return elapsed time in seconds.""" + def end(self) -> None: + """End the timer, but don't compute elapsed time yet.""" if not self.running: logger.warning(f"Timer '{self.name}' was not running. Ignoring end call.") - return 0.0 + return - elapsed = 0.0 if self.timer_type == TimerType.CUDA: - if torch.cuda.is_available() and self._start_event is not None and self._end_event is not None: - # Record the end event - self._end_event.record() - - # Waits for all preceding CUDA operations to complete - self._end_event.synchronize() - - # Get the elapsed time in milliseconds and convert to seconds - elapsed = self._start_event.elapsed_time(self._end_event) / 1000.0 - self._last_elapsed_time = elapsed + if torch.cuda.is_available() and self._current_start_event is not None: + # Create and record an end event + end_event = torch.cuda.Event(enable_timing=True) + end_event.record() + + # Store the start/end event pair for later querying + self._cuda_events.append((self._current_start_event, end_event)) + self._current_start_event = None else: logger.warning("CUDA timer end called but CUDA events are not available.") self.timer_type = TimerType.CPU self.end_time = time.time() - elapsed = self.end_time - self.start_time + self._cpu_total_time += self.end_time - self.start_time else: self.end_time = time.time() - elapsed = self.end_time - self.start_time + self._cpu_total_time += self.end_time - self.start_time - self.total_time += elapsed self.call_count += 1 self.running = False - return elapsed def reset(self) -> None: """Reset the timer.""" @@ -94,36 +87,61 @@ def reset(self) -> None: self.end_time = 0.0 self.running = False self.call_count = 0 - self.total_time = 0.0 - self._start_event = None - self._end_event = None - self._last_elapsed_time = 0.0 + self._cpu_total_time = 0.0 + self._cuda_events = [] + self._current_start_event = None @property def elapsed(self) -> float: - """Get elapsed time in seconds.""" + """Get elapsed time in seconds for the current timer.""" if not self.running: - if self.timer_type == TimerType.CUDA: - return self._last_elapsed_time - return self.end_time - self.start_time + if self.timer_type == TimerType.CPU: + return self.end_time - self.start_time + + # For CUDA timers, we need to synchronize to get the last elapsed time + if not self._cuda_events: + return 0.0 + + # Get the last event pair + start_event, end_event = self._cuda_events[-1] + end_event.synchronize() # Make sure the event is complete + return start_event.elapsed_time(end_event) / 1000.0 # Convert ms to sec # Timer is still running if self.timer_type == TimerType.CUDA: - if torch.cuda.is_available() and self._start_event is not None: - # Create a temporary end event + if torch.cuda.is_available() and self._current_start_event is not None: + # Create a temporary end event to measure elapsed time so far tmp_end_event = torch.cuda.Event(enable_timing=True) tmp_end_event.record() tmp_end_event.synchronize() - return self._start_event.elapsed_time(tmp_end_event) / 1000.0 + return self._current_start_event.elapsed_time(tmp_end_event) / 1000.0 else: - logger.warning("CUDA timer elapsed called but CUDA events are not available.") return time.time() - self.start_time else: return time.time() - self.start_time + @property + def total_time(self) -> float: + """ + Get total time in seconds across all calls. + Warning: For CUDA timers, this will synchronize all events! + """ + if self.timer_type == TimerType.CPU: + return self._cpu_total_time + + # For CUDA timers, we need to sum up all the event pairs + total = 0.0 + for start_event, end_event in self._cuda_events: + end_event.synchronize() # Make sure the event is complete + total += start_event.elapsed_time(end_event) / 1000.0 # Convert ms to sec + return total + @property def average_time(self) -> float: - """Get average time per call in seconds.""" + """ + Get average time per call in seconds. + Warning: For CUDA timers, this will synchronize all events! + """ if self.call_count == 0: return 0.0 return self.total_time / self.call_count @@ -153,11 +171,6 @@ def __call__(self, name: str, timer_type: Union[TimerType, str] = TimerType.CPU) if name not in self._timers: self._timers[name] = TimerRecord(name=name, timer_type=timer_type) - elif self._timers[name].timer_type != timer_type: - logger.warning( - f"Timer '{name}' already exists with type {self._timers[name].timer_type}. " - f"Requested type {timer_type} will be ignored." - ) return self._timers[name] def reset_all(self) -> None: @@ -184,6 +197,7 @@ def log(self, name: str, logger=None, rank: Optional[int] = 0, group=None) -> No timer = self._timers[name] if timer.call_count > 0: + # This will trigger synchronization for CUDA timers avg_time = timer.average_time * 1000 # Convert to ms total_time = timer.total_time * 1000 # Convert to ms logger.info( @@ -207,6 +221,7 @@ def log_all(self, logger=None, rank: Optional[int] = 0, group=None) -> None: logger.info("---- Timing Information ----") for name, timer in sorted_timers: if timer.call_count > 0: + # This will trigger synchronization for CUDA timers avg_time = timer.average_time * 1000 # Convert to ms total_time = timer.total_time * 1000 # Convert to ms logger.info( From 64eacf3b046fa55b2dfa1a1b8ffb30e824311ac3 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Sun, 6 Apr 2025 11:10:06 +0000 Subject: [PATCH 07/12] . --- src/nanotron/trainer.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 276a2494b..553b34f2f 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -350,6 +350,7 @@ def pre_training(self, *args, **kwargs): project=self.config.general.project, name=run_name, config={"nanotron_config": self.config.as_dict()}, + x_disable_meta=True, ) log_rank( f"Initialized wandb run '{run_name}' for TP rank {tp_rank}", @@ -364,6 +365,13 @@ def pre_training(self, *args, **kwargs): project=self.config.general.project, name=run_name, config={"nanotron_config": self.config.as_dict()}, + settings=wandb.Settings( + x_stats_sampling_interval=1.0, # TODO: put back to default 15.0 + x_stats_disk_paths=("/scratch", "/fsx/nouamane/"), + x_stats_open_metrics_endpoints={"dcgm": "http://localhost:9104/metrics"}, + x_stats_open_metrics_filters=["DCGM_FI_"], + x_file_stream_transmit_interval=1.0, + ), ) log_rank( f"Initialized wandb run '{run_name}' for TP rank {tp_rank}", @@ -716,13 +724,6 @@ def train_step_logs( LogItem("model_tflops_per_gpu", model_tflops, "human_format"), # , ".2f"), # LogItem("hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"), LogItem("eta", str(datetime.timedelta(seconds=eta_seconds))), - LogItem("timers/tbi", nanotron_timer("train_batch_iter").elapsed, ".2f"), - LogItem("timers/forward", nanotron_timer("forward").elapsed, ".2f"), - LogItem("timers/backward", nanotron_timer("backward").elapsed, ".2f"), - LogItem("timers/sync_gradients", nanotron_timer("sync_gradients").elapsed, ".2f"), - LogItem("timers/clip_gradients", nanotron_timer("clip_gradients").elapsed, ".2f"), - LogItem("timers/optimizer_step", nanotron_timer("optimizer_step").elapsed, ".2f"), - LogItem("timers/dataloader_fetch", nanotron_timer("dataloader_fetch").elapsed, ".2f"), ] if z_loss_avg is not None: basic_log_entries.insert(6, LogItem("z_loss", z_loss_avg.item(), "human_format")) # , "1.6E"), @@ -737,6 +738,9 @@ def train_step_logs( assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks" self.loggerwriter.add_scalars_from_list(basic_log_entries, self.iteration_step) + for timer_name, timer in nanotron_timer.items(): + basic_log_entries.append(LogItem(f"timers/{timer_name}", timer.elapsed, ".2f")) + # WandB logging - determine if this rank should log to wandb should_log_to_wandb = wandb is not None and ( (tp_size > 1 and dp_rank == 0 and self.metrics_logging.log_level > 0) From e13130f178d2afa6d78e0c2b5d52a909b87661bf Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Sun, 6 Apr 2025 11:49:33 +0000 Subject: [PATCH 08/12] . --- src/nanotron/data/clm_collator.py | 4 ++ src/nanotron/logging/timers.py | 49 ++++++++++++- src/nanotron/models/qwen.py | 114 +++++++++++++++++++++--------- 3 files changed, 133 insertions(+), 34 deletions(-) diff --git a/src/nanotron/data/clm_collator.py b/src/nanotron/data/clm_collator.py index 67103e0f2..1b59f0b66 100644 --- a/src/nanotron/data/clm_collator.py +++ b/src/nanotron/data/clm_collator.py @@ -282,4 +282,8 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Uni # assert v.is_contiguous(), f"{k} is not contiguous" # assert not v.is_cuda, f"{k} is in cuda. Bad for pinning memory" + # debug: import tokenizer and print tokenizer.decode(result["input_ids"][0]) + # from transformers import AutoTokenizer + # tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") + # print(tokenizer.decode(result["input_ids"][0])) return result diff --git a/src/nanotron/logging/timers.py b/src/nanotron/logging/timers.py index 1827c1b3e..b4d4a5e14 100644 --- a/src/nanotron/logging/timers.py +++ b/src/nanotron/logging/timers.py @@ -34,6 +34,16 @@ class TimerRecord: _cuda_events: List[tuple[torch.cuda.Event, torch.cuda.Event]] = field(default_factory=list) _current_start_event: Optional[torch.cuda.Event] = None + def __enter__(self): + """Context manager support: Start the timer when entering a context.""" + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager support: End the timer when exiting a context.""" + self.end() + return False # Don't suppress exceptions + def start(self) -> "TimerRecord": """Start the timer.""" if self.running: @@ -161,6 +171,11 @@ def __new__(cls): def __call__(self, name: str, timer_type: Union[TimerType, str] = TimerType.CPU) -> TimerRecord: """Get or create a timer with the given name. + Can be used as a decorator, context manager, or directly: + - @nanotron_timer("name") # As decorator + - with nanotron_timer("name"): ... # As context manager + - nanotron_timer("name").start(); ...; nanotron_timer("name").end() # Direct use + Args: name: Name of the timer timer_type: Type of timer, either TimerType.CPU or TimerType.CUDA @@ -169,9 +184,38 @@ def __call__(self, name: str, timer_type: Union[TimerType, str] = TimerType.CPU) if isinstance(timer_type, str): timer_type = TimerType(timer_type) + if callable(name) and timer_type == TimerType.CPU: + # Being used as a decorator with default settings + func = name + timer_name = func.__name__ + return self._create_timer_decorator(timer_name, TimerType.CPU)(func) + if name not in self._timers: self._timers[name] = TimerRecord(name=name, timer_type=timer_type) - return self._timers[name] + + # Check if we're being called as a decorator + if not callable(name): + timer_record = self._timers[name] + # Return the timer which can be used directly or as a context manager + return timer_record + + # If we get here, we're being called as @nanotron_timer("name", timer_type) + return self._create_timer_decorator(name, timer_type) + + def _create_timer_decorator(self, name, timer_type): + """Create a decorator that times the execution of a function.""" + + def decorator(func): + import functools + + @functools.wraps(func) + def wrapper(*args, **kwargs): + with self(name, timer_type): + return func(*args, **kwargs) + + return wrapper + + return decorator def reset_all(self) -> None: """Reset all timers.""" @@ -230,6 +274,9 @@ def log_all(self, logger=None, rank: Optional[int] = 0, group=None) -> None: ) logger.info("----------------------------") + def items(self): + return self._timers.items() + # Create a singleton instance nanotron_timer = Timers() diff --git a/src/nanotron/models/qwen.py b/src/nanotron/models/qwen.py index 4ac163690..fdb8eeb42 100644 --- a/src/nanotron/models/qwen.py +++ b/src/nanotron/models/qwen.py @@ -10,7 +10,7 @@ from nanotron import logging from nanotron.config import Config, ParallelismArgs from nanotron.config.models_config import Qwen2Config, RandomInit, SpectralMupInit -from nanotron.logging import log_rank +from nanotron.logging import log_rank, nanotron_timer from nanotron.models import NanotronModel from nanotron.nn.activations import ACT2FN from nanotron.nn.attention import ALL_ATTENTION_FUNCTIONS, get_attention_mask @@ -209,13 +209,15 @@ def __init__( self.attention = CoreAttention(config, tp_pg, cp_pg, layer_idx) self.simple_causal_mask = True self._use_qkv_packed = config._use_qkv_packed - + self.layer_idx = layer_idx # TODO: support doc masking / SWA / SFT / inference def forward( self, hidden_states: torch.Tensor, # [batch_size*seq_length, hidden_size] position_ids: torch.Tensor, # [batch_size, seq_length] where -1 is padding + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, ): # [0, 1, 2, 3, 4, 0, 1, 2, -1, -1, -1] # 2 documents with 5 and 3 tokens then padding # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # 1 document with 11 tokens @@ -225,9 +227,10 @@ def forward( seq_length = position_ids.shape[1] position_ids = position_ids.view(-1) # [batch_size*seq_length] - qkv = self.qkv_proj(hidden_states) + with nanotron_timer(f"qkv_proj_{self.layer_idx}", "cuda"): + qkv = self.qkv_proj(hidden_states) if self._use_qkv_packed: - attn_output = self._forward_packed(qkv, seq_length, position_ids) + attn_output = self._forward_packed(qkv, seq_length, position_ids, cu_seqlens, max_seqlen) else: q, k, v = qkv.split( [self.local_q_size, self.local_kv_size, self.local_kv_size], dim=-1 @@ -248,31 +251,29 @@ def forward( ) # [b*s, num_kv_heads, head_dim] attn_output = self.attention(q, k, v, position_ids=position_ids, seq_length=seq_length) - output = self.o_proj(attn_output) - return {"hidden_states": output, "position_ids": position_ids.view(-1, seq_length)} - - def _forward_packed(self, qkv, seq_length, position_ids): - q = qkv[..., : self.local_num_heads * self.head_dim] # Not contiguous, similar to flash_attn - kv = qkv[..., self.local_num_heads * self.head_dim :] # Not contiguous, similar to flash_attn - q = q.view(-1, seq_length, self.local_num_heads, self.head_dim) - kv = kv.view(-1, seq_length, 2, self.local_num_kv_heads, self.head_dim) - q, kv = self.rotary_emb( - q, kv, seqlen_offset=0, max_seqlen=None - ) # TODO: should we use position_ids here? flash_attn doesn't - q = q.view(-1, self.local_num_heads, self.head_dim) - kv = kv.view(-1, 2, self.local_num_kv_heads, self.head_dim) - # Compute cu_seqlens - start_indices = torch.where(position_ids == 0)[0] - cu_seqlens = torch.cat( - [start_indices, torch.tensor([position_ids.numel()], dtype=torch.int32, device=start_indices.device)] - ).to(torch.int32) + with nanotron_timer(f"o_proj_{self.layer_idx}", "cuda"): + output = self.o_proj(attn_output) + return {"hidden_states": output} + + def _forward_packed(self, qkv, seq_length, position_ids, cu_seqlens, max_seqlen): + with nanotron_timer(f"rotary_{self.layer_idx}", "cuda"): + q = qkv[..., : self.local_num_heads * self.head_dim] # Not contiguous, similar to flash_attn + kv = qkv[..., self.local_num_heads * self.head_dim :] # Not contiguous, similar to flash_attn + q = q.view(-1, seq_length, self.local_num_heads, self.head_dim) + kv = kv.view(-1, seq_length, 2, self.local_num_kv_heads, self.head_dim) + q, kv = self.rotary_emb( + q, kv, seqlen_offset=0, max_seqlen=None + ) # TODO: should we use position_ids here? flash_attn doesn't + q = q.view(-1, self.local_num_heads, self.head_dim) + kv = kv.view(-1, 2, self.local_num_kv_heads, self.head_dim) - max_seqlen = seq_length # TODO: should this be max position_ids? + # max_seqlen = seq_length # TODO: should this be max position_ids? assert cu_seqlens.dtype == torch.int32 assert max_seqlen is not None assert isinstance(max_seqlen, int) + nanotron_timer(f"flash_attn_{self.layer_idx}", "cuda").start() attn_output = flash_attn_varlen_kvpacked_func( q, kv, @@ -287,8 +288,12 @@ def _forward_packed(self, qkv, seq_length, position_ids): window_size=(-1, -1), # TODO: fix deterministic=False, ) # Not contiguous, similar to flash_attn + nanotron_timer(f"flash_attn_{self.layer_idx}", "cuda").end() # flash_attn use rearrange instead of reshape https://github.com/Dao-AILab/flash-attention/blob/1a58058a6da83bd7baaf4c512e8a1abe0240bb77/flash_attn/modules/mha.py#L730 - return attn_output.reshape(-1, self.local_num_heads * self.head_dim) # [b*s, num_heads*head_dim] + nanotron_timer(f"reshape_{self.layer_idx}", "cuda").start() + attn_output = attn_output.reshape(-1, self.local_num_heads * self.head_dim) # [b*s, num_heads*head_dim] + nanotron_timer(f"reshape_{self.layer_idx}", "cuda").end() + return attn_output # [b*s, num_heads*head_dim] class Qwen2MLP(nn.Module): @@ -567,46 +572,65 @@ def __init__( ) self.recompute_layer = parallel_config.recompute_layer + self.layer_idx = layer_idx def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], # [batch_size*seq_length, hidden_size] position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] where -1 is padding + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, ) -> List[Union[torch.Tensor, TensorPointer]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - output = self.attn(hidden_states=hidden_states, position_ids=position_ids) + nanotron_timer(f"attn_{self.layer_idx}", "cuda").start() + output = self.attn( + hidden_states=hidden_states, position_ids=position_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + nanotron_timer(f"attn_{self.layer_idx}", "cuda").end() hidden_states = output["hidden_states"] hidden_states = hidden_states + residual residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) + nanotron_timer(f"mlp_{self.layer_idx}", "cuda").start() hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] + nanotron_timer(f"mlp_{self.layer_idx}", "cuda").end() hidden_states = hidden_states + residual - return hidden_states, output["position_ids"] + return hidden_states, position_ids, cu_seqlens, max_seqlen def _checkpointed_forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - return CheckpointFunction.apply(self._core_forward, True, hidden_states, position_ids) + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return CheckpointFunction.apply(self._core_forward, True, hidden_states, position_ids, cu_seqlens, max_seqlen) def forward( self, hidden_states: Union[torch.Tensor, TensorPointer], position_ids: Union[torch.Tensor, TensorPointer], + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: if self.recompute_layer and not isinstance(hidden_states, TensorPointer): - hidden_states, position_ids = self._checkpointed_forward(hidden_states, position_ids) + hidden_states, position_ids, cu_seqlens, max_seqlen = self._checkpointed_forward( + hidden_states, position_ids, cu_seqlens, max_seqlen + ) else: - hidden_states, position_ids = self._core_forward(hidden_states, position_ids) + hidden_states, position_ids, cu_seqlens, max_seqlen = self._core_forward( + hidden_states, position_ids, cu_seqlens, max_seqlen + ) return { "hidden_states": hidden_states, "position_ids": position_ids, + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, } @@ -674,8 +698,8 @@ def __init__( "cp_pg": parallel_context.cp_pg, "layer_idx": layer_idx, }, - module_input_keys={"hidden_states", "position_ids"}, - module_output_keys={"hidden_states", "position_ids"}, + module_input_keys={"hidden_states", "position_ids", "cu_seqlens", "max_seqlen"}, + module_output_keys={"hidden_states", "position_ids", "cu_seqlens", "max_seqlen"}, ) for layer_idx in range(config.num_hidden_layers) ] @@ -711,18 +735,38 @@ def forward( input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] where -1 is padding ): + nanotron_timer("token_position_embeddings", "cuda").start() output = self.token_position_embeddings(input_ids=input_ids, position_ids=position_ids) + nanotron_timer("token_position_embeddings", "cuda").end() + + # Compute cu_seqlens + start_indices = (position_ids.view(-1) == 0).nonzero(as_tuple=True)[ + 0 + ] # equivalent to torch.where(position_ids == 0)[0] + cu_seqlens = torch.cat( + [start_indices, torch.tensor([position_ids.numel()], dtype=torch.int32, device=start_indices.device)], + dim=0, + ).to(torch.int32) decoder_states = { "hidden_states": output["input_embeds"], "position_ids": output["position_ids"], + "cu_seqlens": cu_seqlens, + # "max_seqlen": position_ids.max().item() + 1, + "max_seqlen": int( + position_ids.shape[1] + ), # TODO @nouamane: check which one flash_attn uses. I found no tput difference } - for decoder_layer in self.decoder: + for i, decoder_layer in enumerate(self.decoder): + nanotron_timer(f"decoder_{i}", "cuda").start() decoder_states = decoder_layer(**decoder_states) + nanotron_timer(f"decoder_{i}", "cuda").end() hidden_states = self.final_layer_norm(input=decoder_states["hidden_states"])["hidden_states"] + nanotron_timer("lm_head", "cuda").start() sharded_logits = self.lm_head(x=hidden_states)["logits"] + nanotron_timer("lm_head", "cuda").end() return sharded_logits @@ -849,15 +893,19 @@ def forward( label_ids: Union[torch.Tensor, TensorPointer], label_mask: Union[torch.Tensor, TensorPointer], ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + nanotron_timer("model", "cuda").start() sharded_logits = self.model( input_ids=input_ids, position_ids=position_ids, ) + nanotron_timer("model", "cuda").end() + nanotron_timer("loss", "cuda").start() loss = self.loss( sharded_logits=sharded_logits, label_ids=label_ids, label_mask=label_mask, ) + nanotron_timer("loss", "cuda").end() if self.config.z_loss_enabled: return {"loss": loss["loss"], "z_loss": loss["z_loss"]} else: From d0fa217abd077d1eb332ddbc4a7f1256a2da2717 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Mon, 7 Apr 2025 11:24:19 +0000 Subject: [PATCH 09/12] deprecate reduce_scatter_coalesced --- src/nanotron/optim/gradient_accumulator.py | 54 +++++++++++++--------- src/nanotron/parallel/tied_parameters.py | 4 +- 2 files changed, 34 insertions(+), 24 deletions(-) diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index 8107b46e6..9d1b59b6a 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -138,20 +138,23 @@ def sync_gradients_across_dp(self, dp_pg: dist.ProcessGroup, reduce_op: dist.Red assert hasattr(self, "param_name_to_offsets") named_offsets = sorted(self.param_name_to_offsets.items(), key=lambda x: x[0]) flat_grad_buffers = [self.fp32_grad_buffers[name]["fp32_grad"].view(-1) for name, _ in named_offsets] - dist.reduce_scatter_coalesced( - output_tensor_list=[ - flat_grad_buffer[start_offset:end_offset] - for (_, (start_offset, end_offset)), flat_grad_buffer in zip(named_offsets, flat_grad_buffers) - ], - input_tensor_lists=[ - torch.split( - flat_grad_buffer, - split_size_or_sections=len(self.fp32_grad_buffers[name]["fp32_grad"].view(-1)) // dp_pg.size(), + with dist._coalescing_manager(group=dp_pg, async_ops=True): + for output_tensor, input_tensor_list in zip( + flat_grad_buffers, + [ + torch.split( + flat_grad_buffer, + split_size_or_sections=len(flat_grad_buffer) // dp_pg.size(), + ) + for flat_grad_buffer in flat_grad_buffers + ], + ): + dist.reduce_scatter_tensor( + output=output_tensor, + input=torch.cat(input_tensor_list), # Stack the split tensors back + op=reduce_op, + group=dp_pg, ) - for (name, _), flat_grad_buffer in zip(named_offsets, flat_grad_buffers) - ], - group=dp_pg, - ) else: dist.all_reduce(self._contiguous_fp32_grad_buffer, op=reduce_op, group=dp_pg) @@ -367,20 +370,25 @@ def fp32_accum_hook(state: FP32GradBucketManager, bucket: GradBucket) -> torch.f torch.split(grad_buffer, split_size_or_sections=len(grad_buffer) // dp_pg.size()) for grad_buffer in grad_buffer_tensor_list ] - dist.reduce_scatter_coalesced( - output_tensor_list=output_tensor_list, - input_tensor_lists=input_tensor_lists, - op=reduce_op, - group=dp_pg, - async_op=True, - ) + with dist._coalescing_manager(group=dp_pg, async_ops=True): + for output_tensor, input_tensor_list in zip(output_tensor_list, input_tensor_lists): + dist.reduce_scatter_tensor( + output=output_tensor, + input=torch.cat(input_tensor_list), # Stack the split tensors back + op=reduce_op, + group=dp_pg, + ) else: grad_buffer_tensor_list = [ accumulator.get_grad_buffer(param_id_to_name[id(param)]).view(-1) for param in bucket.parameters() ] - accumulator.fp32_grads_allreduce_handle = dist.all_reduce_coalesced( - grad_buffer_tensor_list, group=dp_pg, async_op=True, op=reduce_op - ) + with dist._coalescing_manager(group=dp_pg, async_ops=True) as cm: + for tensor in grad_buffer_tensor_list: + dist.all_reduce(tensor, op=reduce_op, group=dp_pg) + + # Store the last work handle which will complete after all previous ones + accumulator.fp32_grads_allreduce_handle = cm.works[-1] if cm.works else None + # we shouldn't wait for this future for the rest of the backward # with torch.cuda.stream(s): diff --git a/src/nanotron/parallel/tied_parameters.py b/src/nanotron/parallel/tied_parameters.py index b2a1f7f52..ee6e22bef 100644 --- a/src/nanotron/parallel/tied_parameters.py +++ b/src/nanotron/parallel/tied_parameters.py @@ -164,4 +164,6 @@ def sync_tied_weights_gradients( group_ranks_and_reduce_op_to_tensors_to_reduce[(group_ranks, tied_info.reduce_op)] = [tied_grad] for (group_ranks, reduce_op), tensors in group_ranks_and_reduce_op_to_tensors_to_reduce.items(): - dist.all_reduce_coalesced(tensors=tensors, op=reduce_op, group=parallel_context.world_ranks_to_pg[group_ranks]) + with dist._coalescing_manager(group=parallel_context.world_ranks_to_pg[group_ranks], async_ops=False): + for tensor in tensors: + dist.all_reduce(tensor, op=reduce_op, group=parallel_context.world_ranks_to_pg[group_ranks]) From a238b77968938e5787351ba026a8bedf379d9da0 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Mon, 7 Apr 2025 11:45:53 +0000 Subject: [PATCH 10/12] deprecate reduce_scatter_coalesced --- src/nanotron/distributed.py | 5 +++++ src/nanotron/optim/gradient_accumulator.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/nanotron/distributed.py b/src/nanotron/distributed.py index 3aa484aa7..f96eafae4 100644 --- a/src/nanotron/distributed.py +++ b/src/nanotron/distributed.py @@ -17,6 +17,11 @@ # Note: When debugging communication hangs, try decreasing this timeout. default_pg_timeout = datetime.timedelta(minutes=20) +try: + from torch.distributed.distributed_c10d import _coalescing_manager +except ImportError: + _coalescing_manager = None + def new_group( # pylint: disable=function-redefined ranks=None, timeout=default_pg_timeout, backend=None, pg_options=None diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index 9d1b59b6a..1b4123a7d 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -380,14 +380,14 @@ def fp32_accum_hook(state: FP32GradBucketManager, bucket: GradBucket) -> torch.f ) else: grad_buffer_tensor_list = [ - accumulator.get_grad_buffer(param_id_to_name[id(param)]).view(-1) for param in bucket.parameters() + accumulator.get_grad_buffer(param_id_to_name[id(param)]) for param in bucket.parameters() ] with dist._coalescing_manager(group=dp_pg, async_ops=True) as cm: for tensor in grad_buffer_tensor_list: dist.all_reduce(tensor, op=reduce_op, group=dp_pg) # Store the last work handle which will complete after all previous ones - accumulator.fp32_grads_allreduce_handle = cm.works[-1] if cm.works else None + accumulator.fp32_grads_allreduce_handle = cm.works # we shouldn't wait for this future for the rest of the backward From 8029ccf2a901e9fd7dfee43bc57b228053283f79 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Mon, 7 Apr 2025 14:07:28 +0000 Subject: [PATCH 11/12] . --- examples/config_qwen.yaml | 193 ++-- run_train.py | 138 ++- src/nanotron/config/config.py | 8 +- src/nanotron/data/clm_collator.py | 12 +- src/nanotron/data/dataloader_builder.py | 9 +- src/nanotron/data/nanoset.py | 4 + src/nanotron/data/nemo_dataset/Makefile | 23 + src/nanotron/data/nemo_dataset/__init__.py | 940 ++++++++++++++++++ .../data/nemo_dataset/blendable_dataset.py | 210 ++++ .../data/nemo_dataset/dataset_utils.py | 105 ++ src/nanotron/data/nemo_dataset/helpers.cpp | 730 ++++++++++++++ .../data/nemo_dataset/indexed_dataset.py | 344 +++++++ src/nanotron/data/s3_utils.py | 84 ++ src/nanotron/data/samplers.py | 255 +++++ src/nanotron/data/tokenized_bytes.py | 447 +++++++++ src/nanotron/logging/timers.py | 27 +- 16 files changed, 3384 insertions(+), 145 deletions(-) create mode 100644 src/nanotron/data/nemo_dataset/Makefile create mode 100644 src/nanotron/data/nemo_dataset/__init__.py create mode 100644 src/nanotron/data/nemo_dataset/blendable_dataset.py create mode 100644 src/nanotron/data/nemo_dataset/dataset_utils.py create mode 100644 src/nanotron/data/nemo_dataset/helpers.cpp create mode 100644 src/nanotron/data/nemo_dataset/indexed_dataset.py create mode 100644 src/nanotron/data/s3_utils.py create mode 100644 src/nanotron/data/tokenized_bytes.py diff --git a/examples/config_qwen.yaml b/examples/config_qwen.yaml index b78820afa..bd94110d0 100644 --- a/examples/config_qwen.yaml +++ b/examples/config_qwen.yaml @@ -1,79 +1,112 @@ checkpoints: - checkpoint_interval: 100000 - checkpoints_path: checkpoints + checkpoint_interval: 10 + checkpoints_path: checkpoints/smollm3-test-tps-48nn-elie-config checkpoints_path_is_shared_file_system: false load_lr_scheduler: true load_optimizer: true - resume_checkpoint_path: null - save_final_state: false + resume_checkpoint_path: s3://smollm3/pre-training-final/tests/smollm3-test-tps-48nn-elie-config + save_final_state: true save_initial_state: false data_stages: - data: - # dataset: null dataset: dataset_folder: - # - /fsx/loubna/datasets/llama_tokenized/fineweb-edu/merged - # - /fsx/loubna/datasets/llama_tokenized/other_sources/dclm/ - # - /fsx/loubna/datasets/llama_tokenized/pes2o/standard + - /fsx/loubna/datasets/llama_tokenized/fineweb-edu/merged + - /fsx/loubna/datasets/llama_tokenized/dclm_merged/ + - /fsx/loubna/datasets/llama_tokenized/pes2o/standard - /fsx/loubna/datasets/llama_tokenized/other_sources/wiki - # - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-hq-fra_Latn/ - # - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-hq-spa_Latn/ - # - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-hq-deu_Latn/ - # - /fsx/loubna/datasets/llama_tokenized/fw2-hq-ita_Latn/standard - # - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-hq-por_Latn/ - # - /fsx/loubna/datasets/llama_tokenized/fw2-hq-cmn_Hani/standard - # - /fsx/loubna/datasets/llama_tokenized/fw2-hq-rus_Cyrl/standard - # - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-hq-fas_Arab/ - # - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-hq-jpn_Jpan/ - # - /fsx/loubna/datasets/llama_tokenized/fw2-kor_Hang/standard - # - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-hin_Deva/ - # - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-tha_Thai/ - # - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-hq-vie_Latn/ - # - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-hq-ell_Grek/ - # - /fsx/loubna/datasets/llama_tokenized/other_sources/infiwebmath-3plus/ - # - /fsx/loubna/datasets/llama_tokenized/other_sources/finemath-3plus/ - # - /fsx/loubna/datasets/llama_tokenized/other_sources/stack-edu-Python/ - # - /fsx/loubna/datasets/llama_tokenized/other_sources/stack-edu-Java/ - # - /fsx/loubna/datasets/llama_tokenized/other_sources/stack-edu-JavaScript/ - # - /fsx/loubna/datasets/llama_tokenized/kaggle/standard - # dataset_weights: - # - 0.307 - # - 0.307 - # - 0.024 - # - 0.002 - # - 0.018 - # - 0.018 - # - 0.018 - # - 0.012 - # - 0.012 - # - 0.013 - # - 0.012 - # - 0.003 - # - 0.0026 - # - 0.0026 - # - 0.0026 - # - 0.0026 - # - 0.0026 - # - 0.0026 - # - 0.02 - # - 0.02 - # - 0.069 - # - 0.069 - # - 0.059 - # - 0.003 + - /fsx/loubna/datasets/llama_tokenized/stackexchange/standard + - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-hq-fra_Latn/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-hq-spa_Latn/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-hq-deu_Latn/ + - /fsx/loubna/datasets/llama_tokenized/fw2-hq-ita_Latn/standard + - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-hq-por_Latn/ + - /fsx/loubna/datasets/llama_tokenized/fw2-hq-cmn_Hani/standard + - /fsx/loubna/datasets/llama_tokenized/fw2-hq-rus_Cyrl/standard + - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-hq-fas_Arab/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-hq-jpn_Jpan/ + - /fsx/loubna/datasets/llama_tokenized/fw2-kor_Hang/standard + - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-hin_Deva/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-tha_Thai/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-hq-vie_Latn/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/fw2-hq-ell_Grek/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/infiwebmath-3plus/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/finemath-3plus/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/stack-edu-Python/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/stack-edu-Java/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/stack-edu-JavaScript/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/stack-edu-C/ + - /fsx/loubna/datasets/llama_tokenized/stack-edu-Cpp/standard + - /fsx/loubna/datasets/llama_tokenized/other_sources/stack-edu-C-Sharp/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/stack-edu-PHP/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/stack-edu-TypeScript/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/stack-edu-Swift/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/stack-edu-SQL/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/stack-edu-Ruby/ + - /fsx/loubna/datasets/llama_tokenized/stack-edu-Markdown/standard + - /fsx/loubna/datasets/llama_tokenized/stack-edu-HTML/standard + - /fsx/loubna/datasets/llama_tokenized/other_sources/stack-edu-Rust/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/stack-edu-Go/ + - /fsx/loubna/datasets/llama_tokenized/other_sources/stack-edu-Shell/ + - /fsx/loubna/datasets/llama_tokenized/pull-requests/standard + - /fsx/loubna/datasets/llama_tokenized/kaggle/standard + - /fsx/loubna/datasets/llama_tokenized/jupyter-scripts/standard + - /fsx/loubna/datasets/llama_tokenized/github-issues/standard + dataset_weights: + - 0.333 + - 0.38 + - 0.02 + - 0.001 + - 0.004 + - 0.016 + - 0.02 + - 0.022 + - 0.0105 + - 0.01 + - 0.01 + - 0.01 + - 0.003 + - 0.00325 + - 0.00325 + - 0.00325 + - 0.00325 + - 0.00325 + - 0.00225 + - 0.008 + - 0.014 + - 0.022 + - 0.013 + - 0.013 + - 0.007 + - 0.016 + - 0.006 + - 0.006 + - 0.003 + - 0.001 + - 0.004 + - 0.0008 + - 0.005 + - 0.006 + - 0.0008 + - 0.0005 + - 0.0007 + - 0.006 + - 0.0005 + - 0.0055 + - 0.0032 token_size_in_bytes: 4 tokenizer_name: meta-llama/Llama-3.2-1B vocab_size: 128256 - num_loading_workers: 8 - seed: 42 - name: Training Stage + num_loading_workers: 0 + seed: 6 + name: training stage start_training_step: 1 general: - # benchmark_csv_path: benchmark.csv + benchmark_csv_path: null consumed_train_samples: null ignore_sanity_checks: true - project: smollm3-benchmarks - run: qwen-3B-nn8-mbs3-tp2-not-fused + project: smollm3-training + run: smollm3-test-tps-48nn-elie-config seed: 6 step: null lighteval: null @@ -82,7 +115,7 @@ logging: log_level: info log_level_replica: info model: - ddp_bucket_cap_mb: 25 + ddp_bucket_cap_mb: 128 dtype: bfloat16 init_method: std: 0.02 @@ -105,26 +138,26 @@ model: moe_config: null num_attention_heads: 16 num_hidden_layers: 36 - num_key_value_heads: 2 + num_key_value_heads: 4 pad_token_id: null pretraining_tp: 2 rms_norm_eps: 1.0e-06 rope_interleaved: false rope_scaling: null - rope_theta: 10000.0 + rope_theta: 50000.0 sliding_window_size: null tie_word_embeddings: true use_cache: true vocab_size: 128256 - z_loss_coefficient: 0.0001 + z_loss_coefficient: 1.0e-05 z_loss_enabled: false optimizer: accumulate_grad_in_fp32: true clip_grad: 1.0 learning_rate_scheduler: learning_rate: 0.0002 - lr_decay_starting_step: 26000 - lr_decay_steps: 6000 + lr_decay_starting_step: 2600000 + lr_decay_steps: 600000 lr_decay_style: linear lr_warmup_steps: 2000 lr_warmup_style: linear @@ -135,12 +168,13 @@ optimizer: adam_eps: 1.0e-08 name: adamW torch_adam_is_fused: true - weight_decay: 0.01 - weight_decay_exclude_named_params: [] + weight_decay: 0.1 + weight_decay_exclude_named_params: + - .*token_embedding.* zero_stage: 0 parallelism: context_parallel_size: 1 - dp: 4 + dp: 1 expert_parallel_size: 1 moe_layer_recompute: false pp: 1 @@ -150,27 +184,22 @@ parallelism: tp_linear_async_communication: true tp_mode: REDUCE_SCATTER tp_recompute_allgather: true -# profiler: -# active: 1 -# export_chrome_trace: false -# profile_memory: false -# profiler_export_path: tb_logs -# record_shapes: false -# repeat: 1 -# skip_first: 3 -# wait: 1 -# warmup: 1 -# with_stack: true +profiler: null s3_upload: null + # remove_after_upload: true + # s5cmd_concurrency: 5 + # s5cmd_numworkers: 16 + # s5cmd_path: /fsx/elie_bakouch/smollm3_training/0304-begin-nanotron/cu124-0304/bin/s5cmd + # upload_s3_path: tests/smollm3-test-tps-48nn-elie-config tokenizer: - tokenizer_max_length: null + tokenizer_max_length: 4096 tokenizer_name_or_path: meta-llama/Llama-3.2-1B tokenizer_revision: null tokens: - batch_accumulation_per_replica: 8 + batch_accumulation_per_replica: 1 limit_test_batches: 0 limit_val_batches: 0 micro_batch_size: 3 sequence_length: 4096 train_steps: 32000 - val_check_interval: -1 + val_check_interval: 100 diff --git a/run_train.py b/run_train.py index 8afe6e583..dbe5452c2 100644 --- a/run_train.py +++ b/run_train.py @@ -25,7 +25,6 @@ dummy_infinite_data_generator, get_train_dataloader, ) -from nanotron.data.dataloader_builder import build_nanoset_dataloader from nanotron.data.processing import ( clm_process, get_datasets, @@ -181,55 +180,106 @@ def get_dataloader_from_data_stage( # Case 3: Nanosets elif isinstance(data.dataset, NanosetDatasetsArgs): - # Create Nanoset - from nanotron.data.nanoset import Nanoset + log_rank("Using TokenizedBytes Dataloader", logger=logger, level=logging.INFO, rank=0) + from nanotron.data.tokenized_bytes import get_tb_dataloader, get_tb_datasets - with main_rank_first(trainer.parallel_context.world_pg): - tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) - eos_token_id = tokenizer.eos_token_id - assert ( - eos_token_id is not None or data.dataset.return_positions is False - ), "Tokenizer must have an eos token if return_positions is True" - log_rank( - f"[Nanoset] Creating Nanoset with {len(data.dataset.dataset_folder)} dataset folders and {trainer.config.tokens.train_steps * trainer.global_batch_size} train samples", - logger=logger, - level=logging.INFO, - rank=0, - ) - start_time = time.time() - train_dataset = Nanoset( - dataset_folders=data.dataset.dataset_folder, - sequence_length=trainer.sequence_length, - token_size=data.dataset.token_size_in_bytes, - train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, - dataset_weights=data.dataset.dataset_weights, - random_seed=data.seed, - return_positions=data.dataset.return_positions, - eos_token_id=eos_token_id, - ) - end_time = time.time() - log_rank( - f"[Nanoset] Time taken to create Nanoset: {time.strftime('%M:%S', time.gmtime(end_time - start_time))} (MM:SS)", - logger=logger, - level=logging.INFO, - rank=0, - ) - # Prepare dataloader - train_dataloader = build_nanoset_dataloader( - train_dataset, - trainer.sequence_length, + tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + assert ( + len(tokenizer) == trainer.model_config.vocab_size + ), f"Tokenizer vocab size ({len(tokenizer)}) does not match model config vocab size ({trainer.model_config.vocab_size}). " + log_rank( + f"[TokenizedBytes] Creating TokenizedBytes with {len(data.dataset.dataset_folder)} dataset folders and {trainer.config.tokens.train_steps * trainer.global_batch_size} train samples", + logger=logger, + level=logging.INFO, + rank=0, + ) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + start_time = time.time() + train_dataset, data_log = get_tb_datasets( + config=data.dataset, + global_batch_size=trainer.global_batch_size, + sequence_length=trainer.sequence_length, + train_steps=trainer.config.tokens.train_steps, + parallel_context=trainer.parallel_context, + shuffle=True, + seed=data.seed, + ) + train_dataloader = get_tb_dataloader( + dataset=train_dataset, + sequence_length=trainer.sequence_length, + micro_batch_size=trainer.micro_batch_size, + global_batch_size=trainer.global_batch_size, + num_workers=data.num_loading_workers, + cfg=data.dataset, + consumed_samples=consumed_train_samples, + num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, parallel_context=trainer.parallel_context, input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, - micro_batch_size=trainer.micro_batch_size, - consumed_train_samples=consumed_train_samples, - dataloader_num_workers=data.num_loading_workers, dataloader_drop_last=True, - use_position_ids=isinstance(trainer.model_config, Qwen2Config), + ) + log_rank( + f"[TokenizedBytes] Time taken to create TokenizedBytes: {time.strftime('%M:%S', time.gmtime(time.time() - start_time))} (MM:SS)", + logger=logger, + level=logging.INFO, + rank=0, ) dist.barrier() + # Create Nanoset + # from nanotron.data.nanoset import Nanoset + + # with main_rank_first(trainer.parallel_context.world_pg): + # tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path + # tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + # eos_token_id = tokenizer.eos_token_id + # assert ( + # eos_token_id is not None or data.dataset.return_positions is False + # ), "Tokenizer must have an eos token if return_positions is True" + # log_rank( + # f"[Nanoset] Creating Nanoset with {len(data.dataset.dataset_folder)} dataset folders and {trainer.config.tokens.train_steps * trainer.global_batch_size} train samples", + # logger=logger, + # level=logging.INFO, + # rank=0, + # ) + # start_time = time.time() + # train_dataset = Nanoset( + # dataset_folders=data.dataset.dataset_folder, + # sequence_length=trainer.sequence_length, + # token_size=data.dataset.token_size_in_bytes, + # train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, + # dataset_weights=data.dataset.dataset_weights, + # random_seed=data.seed, + # return_positions=data.dataset.return_positions, + # eos_token_id=eos_token_id, + # ) + # end_time = time.time() + # log_rank( + # f"[Nanoset] Time taken to create Nanoset: {time.strftime('%M:%S', time.gmtime(end_time - start_time))} (MM:SS)", + # logger=logger, + # level=logging.INFO, + # rank=0, + # ) + # # Prepare dataloader + # train_dataloader = build_nanoset_dataloader( + # train_dataset, + # trainer.sequence_length, + # parallel_context=trainer.parallel_context, + # input_pp_rank=input_pp_rank, + # output_pp_rank=output_pp_rank, + # micro_batch_size=trainer.micro_batch_size, + # consumed_train_samples=consumed_train_samples, + # dataloader_num_workers=data.num_loading_workers, + # dataloader_drop_last=True, + # use_position_ids=isinstance(trainer.model_config, Qwen2Config), + # use_doc_masking=False, + # dataloader_pin_memory=True, + # ) + # dist.barrier() + return train_dataloader else: raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}") @@ -293,6 +343,10 @@ def get_args(): args = get_args() config_file = args.config_file + import torch + + print("Using allocator:", torch.cuda.get_allocator_backend()) + # Load trainer and data trainer = DistributedTrainer(config_file) dataloader = get_dataloader(trainer) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 693da2fc1..61e78a92f 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -151,6 +151,12 @@ class NanosetDatasetsArgs: token_size_in_bytes: Optional[int] = None return_positions: Optional[bool] = False + # Tokenized bytes dataset config + skip_in_stream: Optional[bool] = True + pad_samples_to_global_batch_size: Optional[bool] = True + dataloader_type: Optional[str] = "single" # single or cyclic + dataset_max_tokens: Optional[List[int]] = None + def __post_init__(self): if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset folder self.dataset_folder = [self.dataset_folder] @@ -442,7 +448,7 @@ def __post_init__(self): if self.s3_upload is not None: self.s3_upload.__post_init__() - + # Some final sanity checks across separate arguments sections: if self.profiler is not None and self.profiler.profiler_export_path is not None: total_profiling_steps = self.profiler.skip_first + self.profiler.repeat * ( diff --git a/src/nanotron/data/clm_collator.py b/src/nanotron/data/clm_collator.py index 1b59f0b66..e691447a2 100644 --- a/src/nanotron/data/clm_collator.py +++ b/src/nanotron/data/clm_collator.py @@ -47,9 +47,6 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Uni "label_mask": TensorPointer(group_rank=self.output_pp_rank), } - # Make sure we load only what's necessary, ie we only load a `input_ids` column. - assert all(list(example.keys()) == ["input_ids"] for example in examples) - # TODO @nouamanetazi: Is it better to have examples as np.array or torch.Tensor? input_ids = vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s) batch_size, expanded_input_length = input_ids.shape @@ -150,10 +147,7 @@ class DataCollatorForCLMWithPositionIds: input_pp_rank: int output_pp_rank: int parallel_context: ParallelContext - sequence_sep_tokens: List[ - int - ] # [tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token, tokenizer.unk_token] - # cumul_doc_lens: List[int] # Cumulative length of each datatrove_dataset in the Nanoset + use_doc_masking: bool = True def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]: # Process the case when current rank doesn't require data @@ -207,7 +201,7 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Uni if current_pp_rank == self.input_pp_rank: result["input_ids"] = input_ids[:, :-1] - if "positions" in examples[0]: + if "positions" in examples[0] and self.use_doc_masking: # Use provided position_ids if available position_ids = np.vstack([examples[i]["positions"] for i in range(len(examples))]) # Simply drop the last position ID for each example @@ -230,7 +224,7 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Uni result["label_ids"] = input_ids[:, 1:] # Create label mask based on position_ids - if "positions" in examples[0]: + if "positions" in examples[0] and self.use_doc_masking: # Get position_ids for the labels (shifted right by 1 to align with label_ids) position_ids = np.vstack([examples[i]["positions"] for i in range(len(examples))]) position_ids = position_ids[:, 1:] # Shift right to align with labels diff --git a/src/nanotron/data/dataloader_builder.py b/src/nanotron/data/dataloader_builder.py index 2e84adaab..b4ab15254 100644 --- a/src/nanotron/data/dataloader_builder.py +++ b/src/nanotron/data/dataloader_builder.py @@ -31,7 +31,7 @@ def build_nanoset_dataloader( dataloader_drop_last: bool = True, dataloader_pin_memory: bool = True, use_position_ids: bool = True, - sequence_sep_tokens: List[int] = None, + use_doc_masking: List[int] = True, ) -> DataLoader: # Case of ranks not requiring data. We give them a dummy dataset, then the collator will do his job @@ -47,8 +47,7 @@ def build_nanoset_dataloader( input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, parallel_context=parallel_context, - sequence_sep_tokens=sequence_sep_tokens, - # cumul_doc_lens=dataset.cumul_doc_lens, + use_doc_masking=use_doc_masking, ) else: data_collator = DataCollatorForCLM( @@ -80,7 +79,7 @@ def build_nanoset_dataloader( num_workers=dataloader_num_workers, pin_memory=dataloader_pin_memory, worker_init_fn=get_dataloader_worker_init(dp_rank=dp_rank), - pin_memory_device="cuda", + # pin_memory_device="cuda", persistent_workers=True if dataloader_num_workers > 0 else False, - prefetch_factor=dataloader_num_workers if dataloader_num_workers > 0 else None, + prefetch_factor=micro_batch_size * 2, ) diff --git a/src/nanotron/data/nanoset.py b/src/nanotron/data/nanoset.py index 05cfede10..658714d52 100644 --- a/src/nanotron/data/nanoset.py +++ b/src/nanotron/data/nanoset.py @@ -251,6 +251,10 @@ def build_nanoset_index_helper( # Initialize buffer for number of samples used for each dataset current_samples = np.zeros((len(weights),), dtype="long") + # TODO: Add 0.5% (the 1.005 factor) so in case the bleding dataset does + # not uniformly distribute the number of samples, we still have + # samples left to feed to the network + # Iterate over all samples for sample_idx in range(n_samples): # Convert sample index to float for comparison against weights diff --git a/src/nanotron/data/nemo_dataset/Makefile b/src/nanotron/data/nemo_dataset/Makefile new file mode 100644 index 000000000..150939026 --- /dev/null +++ b/src/nanotron/data/nemo_dataset/Makefile @@ -0,0 +1,23 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color +CPPFLAGS += $(shell python3 -m pybind11 --includes) +LIBNAME = helpers +LIBEXT = $(shell python3-config --extension-suffix) + +default: $(LIBNAME)$(LIBEXT) + +%$(LIBEXT): %.cpp + $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ diff --git a/src/nanotron/data/nemo_dataset/__init__.py b/src/nanotron/data/nemo_dataset/__init__.py new file mode 100644 index 000000000..0d063c896 --- /dev/null +++ b/src/nanotron/data/nemo_dataset/__init__.py @@ -0,0 +1,940 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPT style dataset.""" + +import os +import time +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizerBase + +from nanotron import distributed as dist +from nanotron import logging +from nanotron.logging import log_rank +from nanotron.parallel import ParallelContext +from nanotron.utils import main_rank_first + +from .blendable_dataset import BlendableDataset +from .dataset_utils import ( + get_datasets_weights_and_num_samples, + get_train_valid_test_split_, +) +from .indexed_dataset import MMapIndexedDataset, make_indexed_dataset + +logger = logging.get_logger(__name__) + + +def build_dataset( + cfg: Any, + tokenizer: PreTrainedTokenizerBase, + data_prefix: List[str], + num_samples: int, + seq_length: int, + seed: Any, + skip_warmup: bool, + name: str, + parallel_context: ParallelContext, +) -> Union["GPTDataset", BlendableDataset]: + def _build_dataset(current_data_prefix: str, current_num_samples: int) -> "GPTDataset": + indexed_dataset = get_indexed_dataset(current_data_prefix, skip_warmup) + total_num_of_documents = indexed_dataset.sizes.shape[0] + # Print stats about the splits. + + log_rank(" > dataset split:", logger=logger, level=logging.INFO, rank=0) + log_rank( + " Total {} documents is : {} ".format(name, total_num_of_documents), + logger=logger, + level=logging.INFO, + rank=0, + ) + drop_last = True + if name == "valid": + drop_last = cfg.validation_drop_last + dataset = GPTDataset( + cfg, + tokenizer, + name, + current_data_prefix, + np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32), + indexed_dataset, + current_num_samples, + seq_length, + seed, + parallel_context, + drop_last=drop_last, + ) + return dataset + + if len(data_prefix) == 1: + return _build_dataset(data_prefix[0], num_samples) + + else: + output = get_datasets_weights_and_num_samples(data_prefix, num_samples) + prefixes, weights, datasets_num_samples = output + datasets = [] + for i in range(len(prefixes)): + dataset = _build_dataset(prefixes[i], datasets_num_samples[i]) + datasets.append(dataset) + return BlendableDataset(datasets, weights, num_samples, parallel_context=parallel_context) + + +def build_train_valid_test_datasets( + cfg: Any, + tokenizer: PreTrainedTokenizerBase, + data_prefix: Union[Dict, List], + splits_string: str, + train_valid_test_num_samples: Tuple[int, int, int], + seq_length: int, + seed: Any, + parallel_context: ParallelContext, + skip_warmup: bool, +) -> Tuple[ + Union["GPTDataset", "BlendableDataset", None], + Union["GPTDataset", "BlendableDataset", None], + Union["GPTDataset", "BlendableDataset", None], +]: + if isinstance(data_prefix, dict): + assert ( + data_prefix.get("train") is not None + and data_prefix.get("test") is not None + and data_prefix.get("validation") is not None + ), f"Data prefix dictionary should have train, test and validation keys. data_prefix currently has only {data_prefix.keys()}" + if cfg.splits_string is not None: + log_rank( + cfg.splits_string + " ignored since data prefix is of type dictionary.", + logger=logger, + level=logging.WARNING, + rank=0, + ) + train_ds = build_dataset( + cfg, + tokenizer, + data_prefix["train"], + int(train_valid_test_num_samples[0]), + seq_length, + seed, + skip_warmup, + "train", + parallel_context, + ) + validation_ds = build_dataset( + cfg, + tokenizer, + data_prefix["validation"], + int(train_valid_test_num_samples[1]), + seq_length, + seed, + skip_warmup, + "valid", + parallel_context, + ) + test_ds = build_dataset( + cfg, + tokenizer, + data_prefix["test"], + int(train_valid_test_num_samples[2]), + seq_length, + seed, + skip_warmup, + "test", + parallel_context, + ) + return train_ds, validation_ds, test_ds + + else: + # No data + if len(data_prefix) == 0: + return (None, None, None), [] + # Single dataset. + if len(data_prefix) == 1: + return _build_train_valid_test_datasets( + cfg, + tokenizer, + data_prefix[0], + splits_string, + train_valid_test_num_samples, + seq_length, + seed, + parallel_context, + skip_warmup, + ) + + # Blending dataset. + # Parse the values. + output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. + train_datasets: List["GPTDataset"] = [] + valid_datasets: List["GPTDataset"] = [] + test_datasets: List["GPTDataset"] = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + cfg, + tokenizer, + prefixes[i], + splits_string, + datasets_train_valid_test_num_samples[i], + seq_length, + seed, + parallel_context, + skip_warmup, + ) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + train_n, valid_n, test_n = map(sum, zip(*datasets_train_valid_test_num_samples)) + + # Blend. + blending_train_dataset = None + if train_datasets: + blending_train_dataset = BlendableDataset(train_datasets, weights, train_n, parallel_context) + blending_valid_dataset = None + if valid_datasets: + blending_valid_dataset = BlendableDataset(valid_datasets, weights, valid_n, parallel_context) + blending_test_dataset = None + if test_datasets: + blending_test_dataset = BlendableDataset(test_datasets, weights, test_n, parallel_context) + + return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) + + +def _build_train_valid_test_datasets( + cfg: Any, + tokenizer: PreTrainedTokenizerBase, + data_prefix: str, + splits_string: str, + train_valid_test_num_samples: int, + seq_length: int, + seed: Any, + parallel_context: ParallelContext, + skip_warmup: bool, +) -> Tuple["GPTDataset", Optional["GPTDataset"], Optional["GPTDataset"]]: + """Build train, valid, and test datasets.""" + + # Indexed dataset. + indexed_dataset = get_indexed_dataset(data_prefix, skip_warmup) + + total_num_of_documents = indexed_dataset.sizes.shape[0] + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + + # Print stats about the splits. + log_rank(" > dataset split:", logger=logger, level=logging.INFO, rank=0) + + def print_split_stats(name, index): + log_rank(" {}:".format(name), logger=logger, level=logging.INFO, rank=0) + log_rank( + " document indices in [{}, {}) total of {} " + "documents".format(splits[index], splits[index + 1], splits[index + 1] - splits[index]), + logger=logger, + level=logging.INFO, + rank=0, + ) + + print_split_stats("train", 0) + print_split_stats("validation", 1) + print_split_stats("test", 2) + + def build_dataset(index: int, name: str) -> GPTDataset: + dataset = None + if splits[index + 1] > splits[index]: + documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32) + drop_last = True + if name == "valid": + drop_last = cfg.validation_drop_last + dataset = GPTDataset( + cfg, + tokenizer, + name, + data_prefix, + documents, + indexed_dataset, + train_valid_test_num_samples[index], + seq_length, + seed, + parallel_context, + drop_last=drop_last, + ) + return dataset + + train_dataset = build_dataset(0, "train") + valid_dataset = build_dataset(1, "valid") + test_dataset = build_dataset(2, "test") + + return (train_dataset, valid_dataset, test_dataset) + + +def get_indexed_dataset(data_prefix: str, skip_warmup: bool) -> MMapIndexedDataset: + """Build indexed dataset.""" + log_rank(" > building dataset index ...", logger=logger, level=logging.INFO, rank=0) + + start_time = time.time() + indexed_dataset = make_indexed_dataset(data_prefix, skip_warmup) + log_rank( + " > finished creating indexed dataset in {:4f} " "seconds".format(time.time() - start_time), + logger=logger, + level=logging.INFO, + rank=0, + ) + log_rank( + " number of documents: {}".format(indexed_dataset.sizes.shape[0]), logger=logger, level=logging.INFO, rank=0 + ) + + return indexed_dataset + + +FIM_PREFIX = "" +FIM_MIDDLE = "" +FIM_SUFFIX = "" +FIM_PAD = "" +EOD = "<|endoftext|>" + + +class GPTDataset(Dataset): + def __init__( + self, + cfg: Any, + tokenizer: PreTrainedTokenizerBase, + name: str, + data_prefix: str, + documents: np.ndarray, + indexed_dataset: MMapIndexedDataset, + num_samples: int, + seq_length: int, + seed: int, + parallel_context: ParallelContext, + drop_last: bool = True, + ): + super().__init__() + self.name = name + self.indexed_dataset = indexed_dataset + self.drop_last = drop_last + self.seq_length = seq_length + + # Checks + assert np.min(documents) >= 0 + assert np.max(documents) < indexed_dataset.sizes.shape[0] + + self.eod_mask_loss = cfg.eod_mask_loss + self.no_seqlen_plus_one_input_tokens = cfg.no_seqlen_plus_one_input_tokens + self.add_extra_token = 1 + if self.no_seqlen_plus_one_input_tokens: + self.add_extra_token = 0 + + # save index mappings to a configurable dir + self.index_mapping_dir = cfg.index_mapping_dir + + # For FIM + self.fim_rate = cfg.fim_rate + self.fim_spm_rate = cfg.fim_spm_rate + if self.fim_rate > 1 or self.fim_rate < 0: + raise ValueError("FIM rate must be a probability 0 <= rate <= 1") + if self.fim_spm_rate > 1 or self.fim_spm_rate < 0: + raise ValueError("SPM rate must be a probability 0 <= rate <= 1") + self.tokenizer = tokenizer + self.suffix_tok_id, self.prefix_tok_id, self.middle_tok_id, self.pad_tok_id, self.eod_tok_id = ( + self.tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD, EOD] + ) + self.fim_split_sample = ( + self.tokenizer.vocab[cfg.fim_split_sample] if cfg.fim_split_sample is not None else None + ) + self.fragment_fim_rate = cfg.fragment_fim_rate + self.no_fim_prefix = cfg.no_fim_prefix + self.np_rng = np.random.RandomState(seed=seed) # rng state for FIM + + # create index_mapping_dir on rank 0 + if dist.is_available() and dist.is_initialized(): + with main_rank_first(parallel_context.world_pg): + if self.index_mapping_dir is not None and not os.path.isdir(self.index_mapping_dir): + os.makedirs(self.index_mapping_dir) + + # Build index mappings. + arrays, subset_log = _build_index_mappings( + self.name, + data_prefix, + documents, + self.indexed_dataset.sizes, + num_samples, + seq_length, + seed, + parallel_context, + index_mapping_dir=self.index_mapping_dir, + drop_last=drop_last, + add_extra_token=self.add_extra_token, + ) + self.doc_idx, self.sample_idx, self.shuffle_idx = arrays + self.indexed_dataset.deallocate_indexed_dataset_memory() + + self.subset_log = subset_log + + def __len__(self): + # -1 is due to data structure used to retrieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + return self.sample_idx.shape[0] - 1 + + def _get_text(self, idx: int) -> np.ndarray: + # Get the shuffled index. + idx = self.shuffle_idx[idx] + # Start and end documents and offsets. + doc_index_f, offset_f = self.sample_idx[idx] + doc_index_l, offset_l = self.sample_idx[idx + 1] + # offset_f = self.sample_idx[idx][1] + # offset_l = self.sample_idx[idx + 1][1] + # If we are within the same document, just extract the chunk. + if doc_index_f == doc_index_l: + sample = self.indexed_dataset.get( + self.doc_idx[doc_index_f], offset=offset_f, length=offset_l - offset_f + self.add_extra_token + ) + else: + # Otherwise, get the rest of the initial document. + sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)] + # Loop over all in between documents and add the entire document. + for i in range(doc_index_f + 1, doc_index_l): + sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) + # And finally add the relevant portion of last document. + sample_list.append( + self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + self.add_extra_token) + ) + sample = np.concatenate(sample_list) + if len(sample) != (self.seq_length + self.add_extra_token): + log_rank( + f" > WARNING: Got sample of length: {len(sample)} for sequence length={self.seq_length+self.add_extra_token}, padding the sample to match sequence length", + logger=logger, + level=logging.WARNING, + rank=0, + ) + + sample = np.array(sample, dtype=np.int64) + sample = np.pad( + sample, (0, self.seq_length + self.add_extra_token - len(sample)), mode="constant", constant_values=-1 + ) + + if self.fim_rate == 0: + return sample.astype(np.int64) + + # Code from: https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py#L109 + # TODO(Hailey): can merge the code below this line with code above this line. + # TODO(Hailey), cont: above already iterates through loop, so just add the permuting in there? + sample = np.array(sample, dtype=np.int64) + sample_len = sample.shape[0] + # # print(sample, sample.shape) + # # do FIM here, if enabled + # TODO: Do we handle the following point from FIM paper? + # To transform data in the character space for context-level FIM, the tokenized documents have to be decoded back into strings before FIM augmentation. Depending on the vocabulary, some care has to be given to ensure decoding does not introduce any spurious characters into training. For example, utf-8 characters are encoded as multiple tokens with a BPE vocabulary; they can result in fragments from chunking and fail to decode. To prevent unforeseen errors midway through training, we encourage checking for these fragments at the beginning or end of a context and removing them. + + segment_breaks = np.argwhere(sample == self.eod_tok_id) # split sample by document + + def fim_permute_sequence(sequence, rate): + return permute( + sequence, + self.np_rng, + rate, + self.fim_spm_rate, + self.tokenizer, + truncate_or_pad=False, + suffix_tok_id=self.suffix_tok_id, + prefix_tok_id=self.prefix_tok_id, + middle_tok_id=self.middle_tok_id, + pad_tok_id=self.pad_tok_id, + no_fim_prefix=self.no_fim_prefix, + ) + + def fim_split_and_permute_sequence(sequence): + """ + If self.fim_split_sample is not None, split the sequence. + Then apply FIM on the fragments, or the whole sequence if self.fim_split_sample is None. + """ + if self.fim_split_sample is None: + return fim_permute_sequence(sequence, self.fim_rate) + # fim_split_sample is set: split the sample on this token and permute each fragment separately. + # Typically, if each sample is a repository, then we split again on the file level. + # Each fragment is a file, and we permute the files. + fragment_breaks = np.argwhere(sequence == self.fim_split_sample) + if fragment_breaks.shape == (0, 1): + # no split token in this sample + return fim_permute_sequence(sequence, self.fim_rate) + if not self.np_rng.binomial(1, self.fim_rate): + # don't do FIM preproc + return sequence + # Do FIM on each fragment + curr_start_position = 0 + new_samples = [] + for loc in np.nditer(fragment_breaks): + if loc - curr_start_position > 0: + permuted = fim_permute_sequence(sequence[curr_start_position:loc], self.fragment_fim_rate) + new_samples += [permuted, [self.fim_split_sample]] + curr_start_position = loc + 1 # Jump over the split token + # Permute the segment after the last split token + permuted = fim_permute_sequence(sequence[curr_start_position:], self.fragment_fim_rate) + new_samples.append(permuted) + return np.concatenate(new_samples) + + if segment_breaks.shape != (0, 1): # then there is an EOD token in this example + curr_start_position = 0 + new_samples = [] + for loc in np.nditer(segment_breaks): + # Only permute non-empty segments. + if loc - curr_start_position > 0: + # permute {prefix, suffix, middle} or {suffix, prefix, middle} + permuted = fim_split_and_permute_sequence(sample[curr_start_position:loc]) + new_samples += [permuted, [self.eod_tok_id]] + + curr_start_position = loc + 1 # jump over the EOD token + # Permute the segment after the last EOD + permuted = fim_split_and_permute_sequence(sample[curr_start_position:]) + new_samples.append(permuted) + + sample = np.concatenate(new_samples) + else: + sample = fim_split_and_permute_sequence(sample) + + # Truncate or pad sequence to max-length + diff = sample.shape[0] - sample_len + if diff > 0: # too long + sample = sample[:sample_len] + elif diff < 0: # too short + sample = np.concatenate([sample, np.full((-1 * diff), self.pad_tok_id)]) + + assert sample.shape[0] == sample_len + # end FIM-specific code + return sample.astype(np.int64) + + def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: + text = self._get_text(idx) + return {"input_ids": text} + + +@dataclass +class SubsetSplitLog: + name: str + data_prefix: str + doc_idx_filename: str + sample_idx_filename: str + shuffle_idx_filename: str + tokens_per_epoch: int + num_epochs: int + num_samples: int + seq_length: int + + +def _build_index_mappings( + name: str, + data_prefix: str, + documents: np.ndarray, + sizes: np.ndarray, + num_samples: int, + seq_length: int, + seed: Any, + parallel_context: ParallelContext, + index_mapping_dir: str = None, + drop_last: bool = True, + add_extra_token: int = 1, +) -> Tuple[Tuple[np.ndarray, np.ndarray, np.ndarray], SubsetSplitLog]: + """Build doc-idx, sample-idx, and shuffle-idx. + doc-idx: is an array (ordered) of documents to be used in training. + sample-idx: is the start document index and document offset for each + training sample. + shuffle-idx: maps the sample index into a random index into sample-idx. + """ + # Number of tokens in each epoch and number of required epochs. + tokens_per_epoch = _num_tokens(documents, sizes) + num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples, add_extra_token) + + # rng state + np_rng = np.random.RandomState(seed=seed) + + # Filename of the index mappings. + if index_mapping_dir is not None: + _filename = os.path.join(index_mapping_dir, os.path.basename(data_prefix)) + else: + _filename = data_prefix + _filename += "_{}_indexmap".format(name) + _filename += "_{}ns".format(num_samples) + _filename += "_{}sl".format(seq_length) + _filename += "_{}s".format(seed) + doc_idx_filename = _filename + "_doc_idx.npy" + sample_idx_filename = _filename + "_sample_idx.npy" + shuffle_idx_filename = _filename + "_shuffle_idx.npy" + + # Build the indexed mapping if not exist. + with main_rank_first(parallel_context.world_pg): + if ( + (not os.path.isfile(doc_idx_filename)) + or (not os.path.isfile(sample_idx_filename)) + or (not os.path.isfile(shuffle_idx_filename)) + ): + log_rank( + " > WARNING: could not find index map files, building " "the indices on rank 0 ...", + logger=logger, + level=logging.INFO, + rank=0, + ) + + # For the last epoch, decide whether include the entire epoch + # in the global shuffle or not. + + # If we need only one epoch, then separating last epoch does + # not mean anything. + if num_epochs == 1: + separate_last_epoch = False + log_rank( + " > only one epoch required, setting " "separate_last_epoch to False", + logger=logger, + level=logging.INFO, + rank=0, + ) + else: + # Get the number of samples for the last epoch + num_samples_from_epochs_minus_one = ( + (num_epochs - 1) * tokens_per_epoch - add_extra_token + ) // seq_length + last_epoch_num_samples = num_samples - num_samples_from_epochs_minus_one + assert last_epoch_num_samples >= 0, "last epoch number of samples should be non-negative." + num_samples_per_epoch = (tokens_per_epoch - add_extra_token) // seq_length + # For very small datasets, `last_epoch_num_samples` can be equal to + # (num_samples_per_epoch + 1). + # TODO: check that this is not problematic indeed + # https://github.com/bigcode-project/Megatron-LM/commit/3a6286ba11181899cccfb11d2e508eca9fd15bea + assert last_epoch_num_samples <= ( + num_samples_per_epoch + 1 + ), "last epoch number of samples exceeded max value." + # If we have less than 80% of the samples for the last epoch, + # separate out the epoch and treat it differently. + # Note: the 80% number is just based on common sense and can + # be adjusted if needed. + separate_last_epoch = last_epoch_num_samples < int(0.80 * num_samples_per_epoch) + if separate_last_epoch: + string = ( + " > last epoch number of samples ({}) is smaller " + "than 80% of number of samples per epoch ({}), " + "setting separate_last_epoch to True" + ) + else: + string = ( + " > last epoch number of samples ({}) is larger " + "than 80% of number of samples per epoch ({}), " + "setting separate_last_epoch to False" + ) + log_rank( + string.format(last_epoch_num_samples, num_samples_per_epoch), + logger=logger, + level=logging.INFO, + rank=0, + ) + + # doc-idx. + start_time = time.time() + doc_idx = _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + log_rank( + " > elasped time to build and save doc-idx mapping " + "(seconds): {:4f}".format(time.time() - start_time), + logger=logger, + level=logging.INFO, + rank=0, + ) + + # sample-idx. + start_time = time.time() + # Use C++ implementation for speed. + # First compile and then import. + assert doc_idx.dtype == np.int32 + assert sizes.dtype == np.int32 + + try: + from . import helpers + except ImportError: + try: + from .dataset_utils import compile_helper + + compile_helper() + from . import helpers + except ImportError: + raise ImportError( + "Could not compile megatron dataset C++ helper functions and therefore cannot import helpers python file." + ) + + sample_idx = helpers.build_sample_idx( + sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch, drop_last, add_extra_token + ) + # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length, + # num_epochs, tokens_per_epoch, drop_last, add_extra_token) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + log_rank( + " > elasped time to build and save sample-idx mapping " + "(seconds): {:4f}".format(time.time() - start_time), + logger=logger, + level=logging.INFO, + rank=0, + ) + + # shuffle-idx. + start_time = time.time() + # -1 is due to data structure used to retrieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + if separate_last_epoch: + num_samples_ = num_samples_from_epochs_minus_one + else: + num_samples_ = sample_idx.shape[0] - 1 + shuffle_idx = _build_shuffle_idx(num_samples_, sample_idx.shape[0] - 1, np_rng) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + log_rank( + " > elasped time to build and save shuffle-idx mapping" + " (seconds): {:4f}".format(time.time() - start_time), + logger=logger, + level=logging.INFO, + rank=0, + ) + # counts = torch.cuda.LongTensor([1]) + # dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=parallel_context.dp_pg) + # dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=parallel_context.pp_pg) + # assert counts[0].item() == ( + # dist.get_world_size() + # // dist.get_world_size(group=parallel_context.tp_pg) + # ) + + # Load mappings. + start_time = time.time() + log_rank(" > loading doc-idx mapping from {}".format(doc_idx_filename), logger=logger, level=logging.INFO, rank=0) + doc_idx: np.ndarray = np.load(doc_idx_filename, allow_pickle=True, mmap_mode="r") + log_rank( + " > loading sample-idx mapping from {}".format(sample_idx_filename), logger=logger, level=logging.INFO, rank=0 + ) + sample_idx: np.ndarray = np.load(sample_idx_filename, allow_pickle=True, mmap_mode="r") + log_rank( + " > loading shuffle-idx mapping from {}".format(shuffle_idx_filename), + logger=logger, + level=logging.INFO, + rank=0, + ) + shuffle_idx: np.ndarray = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode="r") + log_rank( + " loaded indexed file in {:3.3f} seconds".format(time.time() - start_time), + logger=logger, + level=logging.INFO, + rank=0, + ) + log_rank(" total number of samples: {}".format(sample_idx.shape[0]), logger=logger, level=logging.INFO, rank=0) + log_rank(" total number of epochs: {}".format(num_epochs), logger=logger, level=logging.INFO, rank=0) + + subset_log = SubsetSplitLog( + name=name, + data_prefix=data_prefix, + doc_idx_filename=doc_idx_filename, + sample_idx_filename=sample_idx_filename, + shuffle_idx_filename=shuffle_idx_filename, + tokens_per_epoch=tokens_per_epoch, + num_epochs=num_epochs, + num_samples=sample_idx.shape[0], + seq_length=seq_length, + ) + + return (doc_idx, sample_idx, shuffle_idx), subset_log + + +def _num_tokens(documents: np.ndarray, sizes: np.ndarray) -> int: + """Total number of tokens in the dataset.""" + return np.sum(sizes[documents]) + + +def _num_epochs(tokens_per_epoch: int, seq_length: int, num_samples: int, add_extra_token: int = 1) -> int: + """Based on number of samples and sequence length, calculate how many + epochs will be needed.""" + num_epochs = 0 + total_tokens = 0 + while True: + num_epochs += 1 + total_tokens += tokens_per_epoch + # -1 is because we need to retrieve seq_length + 1 token each time + # but the last token will overlap with the first token of the next + # sample except for the last sample. + if ((total_tokens - add_extra_token) // seq_length) >= num_samples: + return num_epochs + + +def _build_doc_idx(documents: np.ndarray, num_epochs: int, np_rng: Any, separate_last_epoch: bool) -> np.ndarray: + """Build an array with length = number-of-epochs * number-of-dcuments. + Each index is mapped to a corresponding document.""" + if not separate_last_epoch or num_epochs == 1: + doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1] + doc_idx[:] = documents + doc_idx = doc_idx.reshape(-1) + doc_idx = doc_idx.astype(np.int32) + np_rng.shuffle(doc_idx) + return doc_idx + + doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False) + doc_idx_last = _build_doc_idx(documents, 1, np_rng, False) + return np.concatenate((doc_idx_first, doc_idx_last)) + + +def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch, drop_last=True, add_extra_token=1): + """Sample index mapping is a 2D array with sizes + [number-of-samples + 1, 2] where [..., 0] contains + the index into `doc_idx` and [..., 1] is the + starting offset in that document.""" + + # Total number of samples. For -1 see comments in `_num_epochs`. + if not drop_last: + num_samples = -(-(num_epochs * tokens_per_epoch - add_extra_token) // seq_length) + else: + num_samples = (num_epochs * tokens_per_epoch - add_extra_token) // seq_length + sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32) + + # Index into sample_idx. + sample_index = 0 + # Index into doc_idx. + doc_idx_index = 0 + # Beginning offset for each document. + doc_offset = 0 + # Start with first document and no offset. + sample_idx[sample_index][0] = doc_idx_index + sample_idx[sample_index][1] = doc_offset + sample_index += 1 + while sample_index <= num_samples: + # Start with a fresh sequence. + remaining_seq_length = seq_length + add_extra_token + while remaining_seq_length != 0: + # Get the document length. + doc_id = doc_idx[doc_idx_index] + doc_length = sizes[doc_id] - doc_offset + # And add it to the current sequence. + remaining_seq_length -= doc_length + # If we have more than a full sequence, adjust offset and set + # remaining length to zero so we return from the while loop. + # Note that -1 here is for the same reason we have -1 in + # `_num_epochs` calculations. + if remaining_seq_length <= 0: + doc_offset += remaining_seq_length + doc_length - add_extra_token + remaining_seq_length = 0 + else: + # Otherwise, start from the beginning of the next document. + if doc_idx_index == (len(doc_idx) - 1): + assert ( + sample_index == num_samples + ), f"sample_index={sample_index} and num_samples={num_samples} should be the same" + doc_offset = sizes[doc_idx[doc_idx_index]] - add_extra_token + break + doc_idx_index += 1 + doc_offset = 0 + # Record the sequence. + sample_idx[sample_index][0] = doc_idx_index + sample_idx[sample_index][1] = doc_offset + sample_index += 1 + + return sample_idx + + +def _build_shuffle_idx(num_samples: int, total_size: int, np_rng: Any) -> np.ndarray: + """Build the range [0, size) and shuffle.""" + logger.info( + " > building shuffle index with split [0, {}) and [{}, {}) " + "...".format(num_samples, num_samples, total_size), + ) + + dtype_ = np.uint32 + if total_size >= (np.iinfo(np.uint32).max - 1): + dtype_ = np.int64 + + shuffle_idx_first = np.arange(start=0, stop=num_samples, step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx_first) + if num_samples == total_size: + return shuffle_idx_first + + shuffle_idx_last = np.arange(start=num_samples, stop=total_size, step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx_last) + + return np.concatenate((shuffle_idx_first, shuffle_idx_last)) + + +# From https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py#L339 +def permute( + sample: np.ndarray, + np_rng: np.random.Generator, + fim_rate: float, + fim_spm_rate: float, + tokenizer: PreTrainedTokenizerBase, + truncate_or_pad: bool = True, + suffix_tok_id: int = None, + prefix_tok_id: int = None, + middle_tok_id: int = None, + pad_tok_id: int = None, + no_fim_prefix: str = None, +): + """ + Take in a sample (np array w/ size (0,chunklength)) and perform a FIM transformation on it. + Maintain the same sample length (if transform creates a few extra tokens, drop them). + """ + if np_rng.binomial(1, fim_rate): # sample bernoulli dist + contents = tokenizer.decode(sample) + + # Do not apply FIM if the sample starts with no_fim_prefix + if no_fim_prefix is not None and contents.startswith(no_fim_prefix): + return sample + + try: + # A boundary can be =0 (prefix will be empty) + # a boundary can be =len(contents) (suffix will be empty) + # The two boundaries can be equal (middle will be empty) + boundaries = list(np_rng.randint(low=0, high=len(contents) + 1, size=2)) + boundaries.sort() + except ValueError as e: + print(len(contents), contents) + print(e) + raise e + + prefix = contents[: boundaries[0]] + middle = contents[boundaries[0] : boundaries[1]] + suffix = contents[boundaries[1] :] + + prefix = tokenizer.encode(prefix, return_tensors="np").squeeze(axis=0) + middle = tokenizer.encode(middle, return_tensors="np").squeeze(axis=0) + suffix = tokenizer.encode(suffix, return_tensors="np").squeeze(axis=0) + + # here we truncate each given segment to fit the same length as it was before + # A consequence is that we never reach the end of a file? + # we should rather truncate at the context-level + if truncate_or_pad: + # need to make same length as the input. Take the 3 sentinel tokens into account + new_length = suffix.shape[0] + prefix.shape[0] + middle.shape[0] + 3 + diff = new_length - sample.shape[0] + if diff > 0: # too long + if ( + suffix.shape[0] <= diff + ): # if there's no space to truncate the suffix: stop and report it. atm i should have stopped this from happening + return sample, np_rng + suffix = suffix[: suffix.shape[0] - diff] + elif diff < 0: # too short + suffix = np.concatenate([suffix, np.full((-1 * diff), pad_tok_id)]) + + if np_rng.binomial(1, fim_spm_rate): + # SPM (variant 2 from FIM paper) + new_sample = np.concatenate([[prefix_tok_id, suffix_tok_id], suffix, [middle_tok_id], prefix, middle]) + else: + # PSM + new_sample = np.concatenate([[prefix_tok_id], prefix, [suffix_tok_id], suffix, [middle_tok_id], middle]) + + else: + # don't do FIM preproc + new_sample = sample + + return new_sample diff --git a/src/nanotron/data/nemo_dataset/blendable_dataset.py b/src/nanotron/data/nemo_dataset/blendable_dataset.py new file mode 100644 index 000000000..aa06cda9d --- /dev/null +++ b/src/nanotron/data/nemo_dataset/blendable_dataset.py @@ -0,0 +1,210 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Blendable dataset.""" + +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, List + +import numpy as np +import torch + +from nanotron import logging +from nanotron.logging import log_rank +from nanotron.parallel import ParallelContext +from nanotron.utils import main_rank_first + +if TYPE_CHECKING: + from . import GPTDataset, SubsetSplitLog + +logger = logging.get_logger(__name__) + + +@dataclass +class BlendedSubsetSplitLog: + blended_total_num_samples: int + blended_per_subset_samples: List[int] + blended_subset: List["SubsetSplitLog"] + + +class BlendableDataset(torch.utils.data.Dataset): + def __init__( + self, datasets: List["GPTDataset"], weights: List[float], size: int, parallel_context: ParallelContext + ): + self.datasets = datasets + num_datasets = len(datasets) + assert num_datasets == len(weights) + + self.size = size + + # Normalize weights. + weights = np.array(weights, dtype=np.float64) + sum_weights = np.sum(weights) + assert sum_weights > 0.0 + weights /= sum_weights + + # Build indices. + start_time = time.time() + # from https://github.com/NVIDIA/Megatron-LM/commit/c6e65b2e96e8376ccc84225dd1a9b60dd242fc48 + assert num_datasets < 32767 + self.dataset_index = np.zeros(self.size, dtype=np.int16) + self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) + self.dataset_num_samples = np.zeros(num_datasets, dtype=np.int64) + + with main_rank_first(parallel_context.world_pg): + try: + from . import helpers + except ImportError: + try: + from .dataset_utils import compile_helper + + compile_helper() + from . import helpers + except ImportError: + raise ImportError( + "Could not compile megatron dataset C++ helper functions and therefore cannot import helpers python file." + ) + + helpers.build_blending_indices( + self.dataset_index, + self.dataset_sample_index, + self.dataset_num_samples, + weights, + num_datasets, + self.size, + torch.distributed.get_rank() == 0, + ) + + log_rank( + "> elapsed time for building blendable dataset indices: " "{:.2f} (sec)".format(time.time() - start_time), + logger=logger, + level=logging.INFO, + rank=0, + ) + + self.subset_log = BlendedSubsetSplitLog( + blended_total_num_samples=self.size, + blended_per_subset_samples=self.dataset_num_samples.tolist(), + blended_subset=[d.subset_log for d in self.datasets], + ) + + def __len__(self): + return self.size + + def __getitem__(self, idx): + dataset_idx = self.dataset_index[idx] + sample_idx = self.dataset_sample_index[idx] + return self.datasets[dataset_idx][sample_idx] + + +class MemoryEfficientBlendableDataset(torch.utils.data.Dataset): + """ + A BlendableDataset implementation that uses less memory than the original implementation. + Indices are computed algorithmically instead of storing them in memory. + + To test call: MemoryEfficientBlendableDataset.test_index_blending() + """ + + def __init__(self, datasets, weights, size, weight_bins=100): + self.datasets = datasets + num_datasets = len(datasets) + assert num_datasets == len(weights) + + weight_bins = min(weight_bins, size) + + self.size = size + self.weight_bins = weight_bins + + # Normalize weights. + weights = np.array(weights, dtype=np.float64) + assert (weights > 0.0).all() + sum_weights = np.sum(weights) + assert sum_weights > 0.0 + self.weights = weights / sum_weights + + # create ds index based on weights + ds_index = [] + ds_bias = [] + for i, w in enumerate(self.weights): + n = int(w * weight_bins) + ds_index.extend([i] * n) + ds_bias.extend(range(n)) + # make sure arrays have length of weight_bins + n = weight_bins - len(ds_index) + ds_index.extend([i] * n) + ds_bias.extend(range(ds_bias[-1], ds_bias[-1] + n)) + + self.ds_index = np.array(ds_index, dtype=np.uint32) + self.ds_index_size = np.array([(self.ds_index == i).sum() for i in range(num_datasets)], dtype=np.uint32) + assert ( + self.ds_index_size > 0 + ).all(), f"Some datasets have no samples in the blendable dataset, increase weight_bins or the offending weight. ds_index_size = {self.ds_index_size}" + self.ds_bias = np.array(ds_bias, dtype=np.uint32) + + self.ds_size = np.array([len(ds) for ds in datasets], dtype=np.uint32) + + def get_ds_sample_idx(self, idx): + """Returns ds index and sample index (within the ds) for the given index in the blendable dataset.""" + + bin = idx % self.weight_bins + ds_idx = self.ds_index[bin] + sample_idx = (self.ds_bias[bin] + (idx // self.weight_bins) * self.ds_index_size[ds_idx]) % self.ds_size[ + ds_idx + ] + + return ds_idx, sample_idx + + def __len__(self): + return self.size + + def __getitem__(self, idx): + ds_idx, sample_idx = self.get_ds_sample_idx(idx) + + return self.datasets[ds_idx][sample_idx] + + @classmethod + def test_index_blending(cls): + """Visualize indices of blended dataset""" + + import matplotlib.pyplot as plt + + plt.ion() + + class DS(torch.utils.data.Dataset): + def __init__(self, size, data): + self.size = size + self.data = data + + def __len__(self): + return self.size + + def __getitem__(self, idx): + return self.data[idx] + + for weight_bins in [10, 100]: + blend_ds = MemoryEfficientBlendableDataset( + [DS(10, "a"), DS(10, "b"), DS(10, "c")], [0.5, 0.3, 0.2], 50, weight_bins=weight_bins + ) + + ds_sample_idx_list = [blend_ds.get_ds_sample_idx(i) for i in range(50)] + ds_list = list(zip(*ds_sample_idx_list))[0] + sample_list = list(zip(*ds_sample_idx_list))[1] + + plt.figure() + plt.plot(ds_list, label="ds idx") + plt.plot(sample_list, label="sample") + plt.legend() + plt.grid() + plt.title(f"weight_bins={weight_bins}") diff --git a/src/nanotron/data/nemo_dataset/dataset_utils.py b/src/nanotron/data/nemo_dataset/dataset_utils.py new file mode 100644 index 000000000..f99e63d42 --- /dev/null +++ b/src/nanotron/data/nemo_dataset/dataset_utils.py @@ -0,0 +1,105 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import os +import subprocess +from typing import List, Tuple, Union + +from nanotron import logging +from nanotron.logging import log_rank + +logger = logging.get_logger(__name__) + + +def compile_helper(): + """Compile helper function ar runtime. Make sure this + is invoked on a single process.""" + + path = os.path.abspath(os.path.dirname(__file__)) + ret = subprocess.run(["make", "-C", path]) + if ret.returncode != 0: + log_rank("Making C++ dataset helpers module failed, exiting.", logger=logger, level=logging.ERROR, rank=0) + import sys + + sys.exit(1) + + +def get_datasets_weights_and_num_samples( + data_prefix: List[str], num_samples: Union[int, List[int]] +) -> Tuple[List[str], List[float], List[int]]: + """Return tuple of: + - list of prefixes + - list of associated normalized weights + - list of associated number of samples from the total num_samples + """ + if len(data_prefix) % 2 != 0: + raise ValueError( + "The data prefix should be in the format of: weight-1, data-prefix-1, weight-2, data-prefix-2, .." + ) + num_datasets = len(data_prefix) // 2 + weights = [0] * num_datasets + prefixes = [0] * num_datasets + for i in range(num_datasets): + weights[i] = float(data_prefix[2 * i]) + prefixes[i] = (data_prefix[2 * i + 1]).strip() + # Normalize weights + weight_sum = 0.0 + for weight in weights: + weight_sum += weight + if weight_sum <= 0.0: + raise ValueError("Total sum of the weights should be > 0") + weights = [weight / weight_sum for weight in weights] + + # Add 0.5% (the 1.005 factor) so in case the bleding dataset does + # not uniformly distribute the number of samples, we still have + # samples left to feed to the network. + # TODO: check data leakage between train/val/test? + datasets_train_valid_test_num_samples = [] + for weight in weights: + # Comes here when we have separate train,test and validation datasets. + if isinstance(num_samples, int): + datasets_train_valid_test_num_samples.append(int(math.ceil(num_samples * weight * 1.005))) + else: + datasets_train_valid_test_num_samples.append([int(math.ceil(val * weight * 1.005)) for val in num_samples]) + + return prefixes, weights, datasets_train_valid_test_num_samples + + +def get_train_valid_test_split_(splits_string: str, size: int) -> Tuple[int]: + """Get dataset splits from comma or '/' separated string list.""" + + splits = [] + if splits_string.find(",") != -1: + splits = [float(s) for s in splits_string.split(",")] + elif splits_string.find("/") != -1: + splits = [float(s) for s in splits_string.split("/")] + else: + splits = [float(splits_string)] + if len(splits) != 3: + raise ValueError(f"Invalid splits string: {splits_string}. Expected 3 comma separated values.") + while len(splits) < 3: + splits.append(0.0) + splits = splits[:3] + splits_sum = sum(splits) + assert splits_sum > 0.0 + splits = [split / splits_sum for split in splits] + splits_index = [0] + for index, split in enumerate(splits): + splits_index.append(splits_index[index] + int(round(split * float(size)))) + diff = splits_index[-1] - size + for index in range(1, len(splits_index)): + splits_index[index] -= diff + assert len(splits_index) == 4 + assert splits_index[-1] == size + return splits_index diff --git a/src/nanotron/data/nemo_dataset/helpers.cpp b/src/nanotron/data/nemo_dataset/helpers.cpp new file mode 100644 index 000000000..dbbfdd9d8 --- /dev/null +++ b/src/nanotron/data/nemo_dataset/helpers.cpp @@ -0,0 +1,730 @@ +/* +Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. + +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +/* Helper methods for fast index mapping builds */ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; +using namespace std; + +const int32_t LONG_SENTENCE_LEN = 512; + + +void build_blending_indices(py::array_t& dataset_index, + py::array_t& dataset_sample_index, + py::array_t& dataset_num_samples, + const py::array_t& weights, + const int32_t num_datasets, + const int64_t size, const bool verbose) { + /* Given multiple datasets and a weighting array, build samples + such that it follows those weights.*/ + + if (verbose) { + std::cout << "> building indices for blendable datasets ..." << std::endl; + } + + // Get the pointer access without the checks. + auto dataset_index_ptr = dataset_index.mutable_unchecked<1>(); + auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>(); + auto dataset_num_samples_ptr = dataset_num_samples.mutable_unchecked<1>(); + auto weights_ptr = weights.unchecked<1>(); + + // Initialize buffer for number of samples used for each dataset. + // int64_t current_samples[num_datasets]; + // for(int64_t i = 0; i < num_datasets; ++i) { + // current_samples[i] = 0; + // } + + // For each sample: + for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { + + // Determine where the max error in sampling is happening. + auto sample_idx_double = std::max(static_cast(sample_idx), 1.0); + int64_t max_error_index = 0; + double max_error = weights_ptr[0] * sample_idx_double - + static_cast(dataset_num_samples_ptr[0]); + for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) { + double error = weights_ptr[dataset_idx] * sample_idx_double - + static_cast(dataset_num_samples_ptr[dataset_idx]); + if (error > max_error) { + max_error = error; + max_error_index = dataset_idx; + } + } + + // Populate the indices. + dataset_index_ptr[sample_idx] = static_cast(max_error_index); + dataset_sample_index_ptr[sample_idx] = dataset_num_samples_ptr[max_error_index]; + + // Update the total samples. + dataset_num_samples_ptr[max_error_index] += 1; + + } + + // print info + if (verbose) { + std::cout << " > sample ratios:" << size << std::endl; + for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) { + auto ratio = static_cast(dataset_num_samples_ptr[dataset_idx]) / + static_cast(size); + std::cout << " dataset " << dataset_idx << ", input: " << + weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl; + } + } + +} + + +py::array build_sample_idx(const py::array_t& sizes_, + const py::array_t& doc_idx_, + const int32_t seq_length, + const int32_t num_epochs, + const int64_t tokens_per_epoch, + const bool drop_last = true, + const int add_extra_token = 1) { + /* Sample index (sample_idx) is used for gpt2 like dataset for which + the documents are flattened and the samples are built based on this + 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] + where [..., 0] contains the index into `doc_idx` and [..., 1] is the + starting offset in that document.*/ + + // Consistency checks. + assert(seq_length > 1); + assert(num_epochs > 0); + assert(tokens_per_epoch > 1); + + // Remove bound checks. + auto sizes = sizes_.unchecked<1>(); + auto doc_idx = doc_idx_.unchecked<1>(); + + // Mapping and it's length (1D). + int64_t num_samples = 0; + if (drop_last == false) { + num_samples = ceil(float(num_epochs * tokens_per_epoch - add_extra_token) / seq_length); + } else { + num_samples = (num_epochs * tokens_per_epoch - add_extra_token) / seq_length; + } + int32_t* sample_idx = new int32_t[2*(num_samples+1)]; + + cout << " using:" << endl << std::flush; + cout << " number of documents: " << + doc_idx_.shape(0) / num_epochs << endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " sequence length: " << seq_length << + endl << std::flush; + cout << " total number of samples: " << num_samples << + endl << std::flush; + + // Index into sample_idx. + int64_t sample_index = 0; + // Index into doc_idx. + int64_t doc_idx_index = 0; + // Beginning offset for each document. + int32_t doc_offset = 0; + // Start with first document and no offset. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + + while (sample_index <= num_samples) { + // Start with a fresh sequence. + int32_t remaining_seq_length = seq_length + add_extra_token; + while (remaining_seq_length != 0) { + // Get the document length. + auto doc_id = doc_idx[doc_idx_index]; + auto doc_length = sizes[doc_id] - doc_offset; + // And add it to the current sequence. + remaining_seq_length -= doc_length; + // If we have more than a full sequence, adjust offset and set + // remaining length to zero so we return from the while loop. + // Note that -1 here is for the same reason we have -1 in + // `_num_epochs` calculations. + if (remaining_seq_length <= 0) { + doc_offset += (remaining_seq_length + doc_length - add_extra_token); + remaining_seq_length = 0; + } else { + // Otherwise, start from the beginning of the next document. + if (doc_idx_index == (doc_idx_.shape(0) - 1)) { + assert(sample_index == num_samples); + doc_offset = sizes[doc_idx[doc_idx_index]] - add_extra_token; + break; + } + ++doc_idx_index; + doc_offset = 0; + } + } + // Record the sequence. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + } + + // Method to deallocate memory. + py::capsule free_when_done(sample_idx, [](void *mem_) { + int32_t *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(int32_t); + return py::array(std::vector{num_samples+1, 2}, // shape + {2*byte_size, byte_size}, // C-style contiguous strides + sample_idx, // the data pointer + free_when_done); // numpy array references + +} + + +inline int32_t get_target_sample_len(const int32_t short_seq_ratio, + const int32_t max_length, + std::mt19937& rand32_gen) { + /* Training sample length. */ + if (short_seq_ratio == 0) { + return max_length; + } + const auto random_number = rand32_gen(); + if ((random_number % short_seq_ratio) == 0) { + return 2 + random_number % (max_length - 1); + } + return max_length; +} + + +template +py::array build_mapping_impl(const py::array_t& docs_, + const py::array_t& sizes_, + const int32_t num_epochs, + const uint64_t max_num_samples, + const int32_t max_seq_length, + const double short_seq_prob, + const int32_t seed, + const bool verbose, + const int32_t min_num_sent) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(short_seq_prob >= 0.0); + assert(short_seq_prob <= 1.0); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + + // For efficiency, convert probability to ratio. Note: rand() generates int. + int32_t short_seq_ratio = 0; + if (short_seq_prob > 0) { + short_seq_ratio = static_cast(round(1.0 / short_seq_prob)); + } + + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << + endl << std::flush; + cout << " sentences range: [" << sent_start_index << + ", " << sent_end_index << ")" << endl << std::flush; + cout << " total number of sentences: " << num_sentences << + endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << + endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << + endl << std::flush; + cout << " short sequence probability: " << short_seq_prob << + endl << std::flush; + cout << " short sequence ration (1/prob): " << short_seq_ratio << + endl << std::flush; + cout << " seed: " << seed << endl << + std::flush; + } + + // Mapping and it's length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration=0; iteration<2; ++iteration) { + + // Set the seed so both iterations produce the same results. + std::mt19937 rand32_gen(seed); + + // Set the flag on second iteration. + second = (iteration == 1); + + // Counters: + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + + // Current map index. + uint64_t map_index = 0; + + // For each epoch: + for (int32_t epoch=0; epoch= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl << std::flush; + } + break; + } + // For each document: + for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { + + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + + // At the beginning of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent > 1) { + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN){ + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; + } + } + } + + // If we have more than two sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + auto target_seq_len = get_target_sample_len(short_seq_ratio, + max_seq_length, + rand32_gen); + + // Loop through sentences. + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and if not only one sentence is left in the document. + // and if we have at least two sentneces. + // and if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent > 1) && + (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { + + // Check for overflow. + if ((3 * map_index + 2) > + std::numeric_limits::max()) { + cout << "number of samples exceeded maximum " + << "allowed by type int64: " + << std::numeric_limits::max() + << endl; + throw std::overflow_error("Number of samples"); + } + + // Populate the map. + if (second) { + const auto map_index_0 = 3 * map_index; + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(target_seq_len); + } + + // Update indices / counters. + ++map_index; + prev_start_index = sent_index + 1; + target_seq_len = get_target_sample_len(short_seq_ratio, + max_seq_length, + rand32_gen); + seq_len = 0; + num_sent = 0; + } + + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << + endl << std::flush; + cout << " number of documents with one sentence: " << + one_sent_docs << endl << std::flush; + cout << " number of documents with long sentences: " << + long_sent_docs << endl << std::flush; + cout << " will create mapping for " << map_index << + " samples" << endl << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[3*map_index]; + num_samples = static_cast(map_index); + } + + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i=(num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 3 * i; + const auto j0 = 3 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + } + + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void *mem_) { + DocIdx *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 3}, // shape + {3*byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references + +} + + +py::array build_mapping(const py::array_t& docs_, + const py::array_t& sizes_, + const int num_epochs, + const uint64_t max_num_samples, + const int max_seq_length, + const double short_seq_prob, + const int seed, + const bool verbose, + const int32_t min_num_sent) { + + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_mapping_impl(docs_, sizes_, num_epochs, + max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, + min_num_sent); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; + } + return build_mapping_impl(docs_, sizes_, num_epochs, + max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, + min_num_sent); + } +} + +template +py::array build_blocks_mapping_impl(const py::array_t& docs_, + const py::array_t& sizes_, + const py::array_t& titles_sizes_, + const int32_t num_epochs, + const uint64_t max_num_samples, + const int32_t max_seq_length, + const int32_t seed, + const bool verbose, + const bool use_one_sent_blocks) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + auto titles_sizes = titles_sizes_.unchecked<1>(); + + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << + endl << std::flush; + cout << " sentences range: [" << sent_start_index << + ", " << sent_end_index << ")" << endl << std::flush; + cout << " total number of sentences: " << num_sentences << + endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << + endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << + endl << std::flush; + cout << " seed: " << seed << endl << + std::flush; + } + + // Mapping and its length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Acceptable number of sentences per block. + int min_num_sent = 2; + if (use_one_sent_blocks) { + min_num_sent = 1; + } + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration=0; iteration<2; ++iteration) { + + // Set the flag on second iteration. + second = (iteration == 1); + + // Current map index. + uint64_t map_index = 0; + + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + // For each epoch: + for (int32_t epoch=0; epoch= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl << std::flush; + } + break; + } + // For each document: + for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { + + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + const auto target_seq_len = max_seq_length - titles_sizes[doc]; + + // At the beginning of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent >= min_num_sent) { + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN){ + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; + } + } + } + // If we have enough sentences and no long sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + + // Loop through sentences. + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and there are an acceptable number of sentences left + // and if we have at least the minimum number of sentences. + // or if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent >= min_num_sent) && + (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { + + // Populate the map. + if (second) { + const auto map_index_0 = 4 * map_index; + // Each sample has 4 items: the starting sentence index, ending sentence index, + // the index of the document from which the block comes (used for fetching titles) + // and the unique id of the block (used for creating block indexes) + + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(doc); + maps[map_index_0 + 3] = static_cast(block_id); + } + + // Update indices / counters. + ++map_index; + ++block_id; + prev_start_index = sent_index + 1; + seq_len = 0; + num_sent = 0; + } + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << + endl << std::flush; + cout << " number of documents with one sentence: " << + one_sent_docs << endl << std::flush; + cout << " number of documents with long sentences: " << + long_sent_docs << endl << std::flush; + cout << " will create mapping for " << map_index << + " samples" << endl << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[4*map_index]; + num_samples = static_cast(map_index); + } + + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i=(num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 4 * i; + const auto j0 = 4 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + swap(maps[i0 + 3], maps[j0 + 3]); + } + + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void *mem_) { + DocIdx *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 4}, // shape + {4*byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references + +} + +py::array build_blocks_mapping(const py::array_t& docs_, + const py::array_t& sizes_, + const py::array_t& titles_sizes_, + const int num_epochs, + const uint64_t max_num_samples, + const int max_seq_length, + const int seed, + const bool verbose, + const bool use_one_sent_blocks) { + + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, + num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, + num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + } +} + +PYBIND11_MODULE(helpers, m) { + m.def("build_mapping", &build_mapping); + m.def("build_blocks_mapping", &build_blocks_mapping); + m.def("build_sample_idx", &build_sample_idx); + m.def("build_blending_indices", &build_blending_indices); +} diff --git a/src/nanotron/data/nemo_dataset/indexed_dataset.py b/src/nanotron/data/nemo_dataset/indexed_dataset.py new file mode 100644 index 000000000..e785b308e --- /dev/null +++ b/src/nanotron/data/nemo_dataset/indexed_dataset.py @@ -0,0 +1,344 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Most of the code here has been copied from: +# fairseq/fairseq/data/indexed_dataset.py + +# with some modifications: + +# Removed IndexedRawTextDataset since it relied on Fairseq dictionary +# other slight modifications to remove fairseq dependencies +# Added document index to index file and made it accessible. +# An empty sentence no longer separates documents. + +import os +import shutil +import struct +from functools import lru_cache +from itertools import accumulate +from typing import Tuple + +import numpy as np +import torch + +from nanotron import logging +from nanotron.logging import log_rank + +logger = logging.get_logger(__name__) +# logging.getLogger('botocore').setLevel(logging.WARNING) + + +def __best_fitting_dtype(vocab_size=None): + if vocab_size is not None and vocab_size < 65500: + return np.uint16 + else: + return np.int32 + + +def make_builder(out_file, vocab_size=None, chunk_size=64, pad_id=0, retrieval_db=False, stride=64): + return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) + + +def deallocate_indexed_dataset_memory(indexed_dataset): + """Deallocate memory of an IndexedDataset.""" + if isinstance(indexed_dataset, MMapIndexedDataset): + # for MMapIndexedDataset we cannot release any memory of sizes + indexed_dataset._index._doc_idx = None + else: + indexed_dataset.sizes = None + indexed_dataset.doc_idx = None + + +def make_indexed_dataset(path: str, skip_warmup: bool = False): + # now handle bin memap + if not MMapIndexedDataset.exists(path): + raise ValueError( + f"Dataset does not exist: {path}" + f"Path should be a basename that both .idx and .bin can be appended to get full filenames." + ) + if MMapIndexedDataset.exists(path): + return MMapIndexedDataset(path, skip_warmup) + raise ValueError("MMapIndexedDataset doesn't exist") + + +dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float64, 7: np.double, 8: np.uint16} + + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path: str) -> str: + return prefix_path + ".idx" + + +def data_file_path(prefix_path: str) -> str: + return prefix_path + ".bin" + + +def _warmup_mmap_file(path: str): + with open(path, "rb") as stream: + while stream.read(100 * 1024 * 1024): + pass + + +class MMapIndexedDataset(torch.utils.data.Dataset): + class Index(object): + _HDR_MAGIC = b"MMIDIDX\x00\x00" + + @classmethod + def writer(cls, path, dtype): + class _Writer(object): + def __enter__(self): + self._file = open(path, "wb") + + self._file.write(cls._HDR_MAGIC) + self._file.write(struct.pack(" Tuple[int, int]: + return self._pointers[i], self._sizes[i] + + def __len__(self): + return self._len + + def __init__(self, path: str, skip_warmup: bool = False): + super().__init__() + + self._path = None + self._index = None + self._bin_buffer = None + + self._do_init(path, skip_warmup) + + def __getstate__(self): + return self._path + + # def __setstate__(self, state): + # self._do_init(state) + + def _do_init(self, path: str, skip_warmup: bool): + self._path = path + self._index = self.Index(index_file_path(self._path), skip_warmup) + + if not skip_warmup: + log_rank(" warming up data mmap file...", logger=logger, level=logging.INFO, rank=0) + _warmup_mmap_file(data_file_path(self._path)) + log_rank(" creating numpy buffer of mmap...", logger=logger, level=logging.INFO, rank=0) + self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode="r", order="C") + log_rank(" creating memory view of numpy buffer...", logger=logger, level=logging.INFO, rank=0) + self._bin_buffer = memoryview(self._bin_buffer_mmap) + + def __del__(self): + if self._bin_buffer_mmap is not None: + self._bin_buffer_mmap._mmap.close() + del self._bin_buffer_mmap + del self._index + + def __len__(self): + return len(self._index) + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if isinstance(idx, int): + ptr, size = self._index[idx] + np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr) + return np_array + elif isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + if step != 1: + raise ValueError("Slices into indexed_dataset must be contiguous") + ptr = self._index._pointers[start] + sizes = self._index._sizes[idx] + offsets = list(accumulate(sizes)) + total_size = sum(sizes) + np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr) + sents = np.split(np_array, offsets[:-1]) + return sents + + def get(self, idx, offset=0, length=None): + """Retrieves a single item from the dataset with the option to only + return a portion of the item. + + get(idx) is the same as [idx] but get() does not support slicing. + """ + ptr, size = self._index[idx] + if length is None: + length = size - offset + ptr += offset * np.dtype(self._index.dtype).itemsize + np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr) + return np_array + + @property + def sizes(self): + return self._index.sizes + + @property + def doc_idx(self): + return self._index.doc_idx + + def get_doc_idx(self): + return self._index._doc_idx + + def set_doc_idx(self, doc_idx_): + self._index._doc_idx = doc_idx_ + + @property + def supports_prefetch(self): + return False + + @staticmethod + def exists(path): + logger.debug( + f"Checking file path: {path}, index_file_path: {index_file_path(path)}, data_file_path: {data_file_path(path)}" + ) + return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) + + def deallocate_indexed_dataset_memory(self): + """Deallocate memory of an IndexedDataset.""" + self._index._doc_idx = None + + +class MMapIndexedDatasetBuilder(object): + def __init__(self, out_file, dtype=np.int64): + self._data_file = open(out_file, "wb") + self._dtype = dtype + self._sizes = [] + self._doc_idx = [0] + + def add_item(self, tensor): + np_array = np.array(tensor.numpy(), dtype=self._dtype) + self._data_file.write(np_array.tobytes(order="C")) + self._sizes.append(np_array.size) + + def end_document(self): + self._doc_idx.append(len(self._sizes)) + + def merge_file_(self, another_file): + # Concatenate index + index = MMapIndexedDataset.Index(index_file_path(another_file)) + assert index.dtype == self._dtype + + for size in index.sizes: + self._sizes.append(size) + + # Concatenate data + with open(data_file_path(another_file), "rb") as f: + shutil.copyfileobj(f, self._data_file) + + def finalize(self, index_file): + self._data_file.close() + + with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index: + index.write(self._sizes, self._doc_idx) diff --git a/src/nanotron/data/s3_utils.py b/src/nanotron/data/s3_utils.py new file mode 100644 index 000000000..075154796 --- /dev/null +++ b/src/nanotron/data/s3_utils.py @@ -0,0 +1,84 @@ +import os +import re +from collections import deque +from re import Pattern +from typing import Union + +from datasets.download.streaming_download_manager import xPath + +try: + import boto3 + + BOTO3_AVAILABLE = True +except ImportError: + BOTO3_AVAILABLE = False + + +# mostly borrowed from datatrove: https://github.com/huggingface/datatrove/blob/main/src/datatrove/io/cloud/s3.py + + +def _get_s3_path_components(s3_path: Union[str, xPath]): + s3_path = str(s3_path) + bucket_name, _, prefix = s3_path[len("s3://") :].replace("//", "/").partition(os.sep) + return bucket_name, prefix + + +def _get_s3_object(s3_path: Union[str, xPath]): + s3_path = str(s3_path) + bucket_name, prefix = _get_s3_path_components(s3_path) + s3_resource = boto3.resource("s3") + s3_object = s3_resource.Object(bucket_name=bucket_name, key=prefix) + return s3_object + + +def _stream_file(file_path: Union[str, xPath], chunk_size, offset): + file_path = str(file_path) + if file_path.startswith("s3://"): + s3_object = _get_s3_object(file_path) + yield from s3_object.get(Range=f"bytes={chunk_size * offset}-")["Body"].iter_chunks(chunk_size) + else: + with open(file_path, "rb") as f: + f.seek(chunk_size * offset) + while True: + chunk = f.read(chunk_size) # Read a chunk of the specified size + if not chunk: + break # If the chunk is empty, end of file is reached + yield chunk + + +def _get_s3_file_list( + s3_path: Union[str, xPath], pattern: Union[str, Pattern] = None, recursive: bool = True, max_recursion: int = -1 +): + """Get list of relative paths to files in a cloud folder with a given (optional) pattern + + Args: + s3_path: path to the cloud folder (e.g. s3://bucket/prefix) + pattern: optional pattern to filter files (str or re.Pattern) + recursive: whether to recursively search for files (default: True) + max_recursion: how many levels to recursively search for files (-1 means no limit) + """ + s3_path = str(s3_path) + + if isinstance(pattern, str): + pattern = re.compile(pattern) + + s3_client = boto3.client("s3") + bucket, main_prefix = _get_s3_path_components(s3_path) + + paginator = s3_client.get_paginator("list_objects_v2") + objects = [] + prefixes = deque() + + prefixes.append((0, main_prefix)) + while prefixes: + level, prefix = prefixes.popleft() + for resp in paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter="/"): + if recursive and (max_recursion == -1 or level < max_recursion): + prefixes.extend([(level + 1, next_prefix["Prefix"]) for next_prefix in resp.get("CommonPrefixes", [])]) + filtered_objects = [x for x in resp.get("Contents", []) if x["Key"] != prefix] + if pattern is not None: + filtered_objects = [ + x for x in filtered_objects if pattern.fullmatch(os.path.relpath(x["Key"], main_prefix)) + ] + objects.extend([f"s3://{bucket}/{x['Key']}" for x in filtered_objects]) + return sorted(objects) diff --git a/src/nanotron/data/samplers.py b/src/nanotron/data/samplers.py index 5c4f2673e..da2257f96 100644 --- a/src/nanotron/data/samplers.py +++ b/src/nanotron/data/samplers.py @@ -1,3 +1,19 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from dataclasses import dataclass from typing import Optional, Union import datasets @@ -7,6 +23,11 @@ from torch.utils.data.distributed import DistributedSampler from transformers.trainer_pt_utils import DistributedSamplerWithLoop +from nanotron import logging +from nanotron.logging import log_rank + +logger = logging.get_logger(__name__) + class SkipBatchSampler(BatchSampler): """ @@ -109,3 +130,237 @@ def __getitem__(self, item) -> dict: def __len__(self) -> int: return self._length + + +@dataclass +class BaseMegatronSampler: + total_samples: int + consumed_samples: int + micro_batch_size: int + data_parallel_rank: int + data_parallel_size: int + global_batch_size: int + drop_last: bool = True + pad_samples_to_global_batch_size: Optional[bool] = False + + def __post_init__(self): + self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size + + # Sanity checks. + if self.total_samples <= 0: + raise RuntimeError("no sample to consume: {}".format(self.total_samples)) + if self.consumed_samples >= self.total_samples: + raise RuntimeError("no samples left to consume: {}, {}".format(self.consumed_samples, self.total_samples)) + if self.micro_batch_size <= 0: + raise RuntimeError(f"micro_batch_size size must be greater than 0, but {self.micro_batch_size}") + if self.data_parallel_size <= 0: + raise RuntimeError(f"data parallel size must be greater than 0, but {self.data_parallel_size}") + if self.data_parallel_rank >= self.data_parallel_size: + raise RuntimeError( + "data_parallel_rank should be smaller than data size, but {} >= {}".format( + self.data_parallel_rank, self.data_parallel_size + ) + ) + if self.global_batch_size % (self.micro_batch_size * self.data_parallel_size) != 0: + raise RuntimeError( + f"`global_batch_size` ({self.global_batch_size}) is not divisible by " + f"`micro_batch_size ({self.micro_batch_size}) x data_parallel_size " + f"({self.data_parallel_size})`" + ) + if self.pad_samples_to_global_batch_size and self.global_batch_size is None: + raise RuntimeError( + "`pad_samples_to_global_batch_size` can be `True` only when " + "`global_batch_size` is set to an integer value" + ) + log_rank( + f"Instantiating MegatronPretrainingSampler with total_samples: {self.total_samples} and consumed_samples: {self.consumed_samples}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + @abc.abstractmethod + def __iter__(self): + ... + + +@dataclass +class MegatronPretrainingSampler(BaseMegatronSampler): + def get_start_end_idx(self): + start_idx = self.data_parallel_rank * self.micro_batch_size + end_idx = start_idx + self.micro_batch_size + return start_idx, end_idx + + def __len__(self): + num_available_samples: int = self.total_samples - self.consumed_samples + if self.global_batch_size is not None: + if self.drop_last: + return num_available_samples // self.global_batch_size + else: + return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size + else: + if self.drop_last: + return num_available_samples // self.micro_batch_times_data_parallel_size + else: + return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1 + + def __iter__(self): + batch = [] + batch_idx = 0 + # Last batch will be dropped if drop_last is not set False + for idx in range(self.consumed_samples, self.total_samples): + batch.append(idx) + if len(batch) == self.micro_batch_times_data_parallel_size: + start_idx, end_idx = self.get_start_end_idx() + log_rank( + f"DLrank {self.data_parallel_rank} batch {batch_idx} {batch[start_idx:end_idx]} self.consumed_samples {self.consumed_samples}", + logger=logger, + level=logging.DEBUG, + ) + yield batch[start_idx:end_idx] + batch = [] + batch_idx += 1 + + # Check the last partial batch and see drop_last is set + if len(batch) > 0 and not self.drop_last: + if self.pad_samples_to_global_batch_size: + for i in range( + self.data_parallel_rank, self.global_batch_size, self.micro_batch_times_data_parallel_size + ): + indices = [batch[j] for j in range(i, max(len(batch), i + self.micro_batch_size))] + num_pad = self.micro_batch_size - len(indices) + indices = indices + [-1] * num_pad + yield indices + else: + start_idx, end_idx = self.get_start_end_idx() + yield batch[start_idx:end_idx] + + +@dataclass +class MegatronPretrainingRandomSampler(BaseMegatronSampler): + def __len__(self): + num_available_samples: int = self.total_samples - self.consumed_samples + if self.global_batch_size is not None: + if self.drop_last: + return num_available_samples // self.global_batch_size + else: + return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size + else: + if self.drop_last: + return num_available_samples // self.micro_batch_times_data_parallel_size + else: + return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1 + + def __iter__(self): + self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size + active_total_samples = self.total_samples - self.last_batch_size + self.epoch = self.consumed_samples // active_total_samples + current_epoch_samples = self.consumed_samples % active_total_samples + assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 + + # data sharding and random sampling + bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size + bucket_offset = current_epoch_samples // self.data_parallel_size + start_idx = self.data_parallel_rank * bucket_size + + g = torch.Generator() + g.manual_seed(self.epoch) + random_idx = torch.randperm(bucket_size, generator=g).tolist() + idx_range = [start_idx + x for x in random_idx[bucket_offset:]] + + batch = [] + # Last batch if not complete will be dropped. + for idx in idx_range: + batch.append(idx) + if len(batch) == self.micro_batch_size: + self.consumed_samples += self.micro_batch_times_data_parallel_size + yield batch + batch = [] + + # Check the last partial batch and see drop_last is set + if len(batch) > 0 and not self.drop_last: + yield batch + + +class MegatronPretrainingCyclicSampler(BaseMegatronSampler): + """Cyclic sampler + + This sampler is used for the cyclic pretraining. It will go through the dataset + once and then start over again without shuffling. + + For data parallelism, the dataset is sharded into `data_parallel_size` chunks at the full dataset level. + Each rank will then sample from its own shard starting from a different offset in the full dataset. + + Args: + total_samples (int): total number of samples in the dataset + consumed_samples (int): number of samples already consumed across all dataparallel ranks + micro_batch_size (int): number of samples in a micro batch + data_parallel_rank (int): rank of the data parallel group + data_parallel_size (int): size of the data parallel group + drop_last (bool): drop the last batch if it is not complete + global_batch_size (int): global batch size + pad_samples_to_global_batch_size (bool): pad the last batch to global batch size + """ + + def __init__( + self, + total_samples: int, + consumed_samples: int, + micro_batch_size: int, + data_parallel_rank: int, + data_parallel_size: int, + global_batch_size: int, + drop_last: bool = True, + pad_samples_to_global_batch_size: Optional[bool] = False, + ) -> None: + super().__init__( + total_samples=total_samples, + consumed_samples=consumed_samples, + micro_batch_size=micro_batch_size, + data_parallel_rank=data_parallel_rank, + data_parallel_size=data_parallel_size, + drop_last=drop_last, + global_batch_size=global_batch_size, + pad_samples_to_global_batch_size=pad_samples_to_global_batch_size, + ) + assert ( + pad_samples_to_global_batch_size is False + ), "`MegatronPretrainingCyclicSampler` does not support sample padding" + self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size + + def __len__(self): + num_available_samples: int = self.total_samples - self.consumed_samples + if self.global_batch_size is not None: + if self.drop_last: + return num_available_samples // self.global_batch_size + else: + return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size + else: + if self.drop_last: + return num_available_samples // self.micro_batch_times_data_parallel_size + else: + return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1 + + def __iter__(self): + active_total_samples = self.total_samples - self.last_batch_size + self.epoch = self.consumed_samples // active_total_samples + current_epoch_samples = self.consumed_samples % active_total_samples + assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 + + # data sharding and random sampling + bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size + bucket_offset = current_epoch_samples // self.data_parallel_size + start_idx = self.data_parallel_rank * bucket_size + + batch = [] + # Last batch if not complete will be dropped. + for idx in range(bucket_offset, bucket_size): + batch.append(idx + start_idx) + if len(batch) == self.micro_batch_size: + self.consumed_samples += self.micro_batch_times_data_parallel_size + yield batch + batch = [] + + # Check the last partial batch and see drop_last is set + if len(batch) > 0 and not self.drop_last: + yield batch diff --git a/src/nanotron/data/tokenized_bytes.py b/src/nanotron/data/tokenized_bytes.py new file mode 100644 index 000000000..25836f1e9 --- /dev/null +++ b/src/nanotron/data/tokenized_bytes.py @@ -0,0 +1,447 @@ +import os +import re +import time +from bisect import bisect +from dataclasses import dataclass +from re import Pattern +from typing import Optional, Tuple, Union + +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset + +from nanotron import distributed as dist +from nanotron import logging +from nanotron.config import NanosetDatasetsArgs +from nanotron.data import DataCollatorForCLM, EmptyInfiniteDataset +from nanotron.data.dataloader import get_dataloader_worker_init +from nanotron.data.nemo_dataset import BlendableDataset +from nanotron.data.nemo_dataset.dataset_utils import compile_helper +from nanotron.data.s3_utils import BOTO3_AVAILABLE, _get_s3_file_list, _get_s3_object, _stream_file +from nanotron.data.samplers import MegatronPretrainingSampler +from nanotron.logging import human_format, log_rank +from nanotron.parallel import ParallelContext + +try: + tb_logger_available = True +except ImportError: + tb_logger_available = False + +logger = logging.get_logger(__name__) + + +@dataclass +class TBFileDatasetLog: + dataset_type: str + file_path: str + seq_len: int + dtype: str + skip_in_stream: bool + num_samples: int + num_tokens: int + human_format_num_tokens: str + num_epochs: Optional[int] + + +@dataclass +class TBFolderDatasetLog: + dataset_type: str + folder_path: str + filename_pattern: str + recursive: bool + seq_len: int + dtype: str + skip_in_stream: bool + num_samples: int + num_tokens: int + human_format_num_tokens: str + shuffle: Optional[bool] + seed: Optional[int] + num_epochs: Optional[int] + files_order: Optional[list] + + +@dataclass +class TrainDataLog: + global_batch_size: int + sequence_length: int + total_training_tokens: int + human_total_train_tokens: str + train_num_samples: int + eval_num_samples: int + test_num_samples: int + train_subset: Union[TBFileDatasetLog, TBFolderDatasetLog] + eval_subset: Union[TBFileDatasetLog, TBFolderDatasetLog] + test_subset: Union[TBFileDatasetLog, TBFolderDatasetLog] + + +class TokenizedBytesFileDataset(Dataset): + def __init__( + self, + file_path: str, + seq_len: int, + dtype: np.dtype = np.uint16, + skip_in_stream: bool = True, + num_samples: Optional[int] = None, + max_tokens: Optional[int] = None, + skip_tokens: Optional[int] = None, + ): + """Streaming dataset for a single TokenizedByte file + We loop on the dataset if asking for an index larger than the dataset size + + Args: + file_path (str): path to file on s3 or locally + seq_len (int): sequence length + dtype (np.dtype, optional): numpy dtype. Defaults to np.uint16. + skip_in_stream (bool, optional): skip ahead in stream. Defaults to True. + num_samples (Optional[int], optional): number of samples. Defaults to None. Only indicative for the number of epoch + """ + self.file_path = file_path + self.seq_len = seq_len + self.dtype = dtype + self.dtype_size = np.dtype(dtype).itemsize + self.skip_in_stream = skip_in_stream + self.skip_tokens = skip_tokens or 0 + # total number of full contexts in this file + if file_path.startswith("s3://"): + if not BOTO3_AVAILABLE: + raise ImportError("boto3 is required: pip install boto3") + num_tokens = _get_s3_object(file_path).content_length // self.dtype_size - self.skip_tokens + else: + num_tokens = os.path.getsize(file_path) // self.dtype_size - self.skip_tokens + self._len = (min(max_tokens, num_tokens) if max_tokens else num_tokens) // (seq_len + 1) + self._stream = None + self._last_item_requested = None + + self.subset_log = TBFileDatasetLog( + dataset_type=self.__class__.__name__, + file_path=file_path, + seq_len=seq_len, + dtype=np.dtype(dtype).name, + skip_in_stream=skip_in_stream, + num_samples=self._len, + num_tokens=self._len * (seq_len + 1), + human_format_num_tokens=human_format(self._len * (seq_len + 1)), + num_epochs=num_samples // self._len if num_samples else 0, + ) + + def _get_new_stream(self, index): + """Get a new stream starting from index (in sequence length contexts) + + Note: we pick chunks of seq_len + 1 to account for the label/target of the last tokens + This means that we drop one token of training per sample. + """ + chunk_size = self.dtype_size * (self.seq_len + 1) + index += self.skip_tokens + for chunk in _stream_file(self.file_path, chunk_size, index): + assert len(chunk) == self.dtype_size * (self.seq_len + 1), ( + f"Expected {chunk_size} bytes from file but got " f"{len(chunk)}" + ) + # careful with type conversions here + yield torch.as_tensor(np.frombuffer(chunk, self.dtype).astype(np.int64), dtype=torch.int64) + + def __getitem__(self, item): + # We loop on the dataset if asking for an index larger than the dataset size + epoch_item = item % len(self) + # if item >= len(self): + # raise IndexError(f"Index {item} requested for file {self.file_path} but it only has size {len(self)}") + # skip ahead without creating a new stream + if self._stream and epoch_item > self._last_item_requested and self.skip_in_stream: + while self._last_item_requested < epoch_item - 1: + self._last_item_requested += 1 + self._get_next_from_stream() # consume stream + # new stream starting from "epoch_item" + elif not self._stream or epoch_item != self._last_item_requested + 1: + self._stream = self._get_new_stream(epoch_item) + + self._last_item_requested = epoch_item + + return {"input_ids": self._get_next_from_stream()} + + def _get_next_from_stream(self): + sleep_time = 0.01 + while True: + try: + return next(self._stream) + except Exception as e: + if sleep_time >= 2.0: + logger.error("Giving up on re-establishing stream.") + raise e + + time.sleep(sleep_time) + self._stream = self._get_new_stream(self._last_item_requested) + sleep_time *= 2 + + def __len__(self): + return self._len + + +class TokenizedBytesFolderDataset(Dataset): + def __init__( + self, + folder_path: str, + seq_len: int, + filename_pattern: Union[Pattern, str] = None, + recursive: bool = True, + dtype: np.dtype = np.uint16, + skip_in_stream: bool = True, + num_samples: Optional[int] = None, + max_tokens: Optional[int] = None, + skip_tokens: Optional[int] = None, + shuffle: bool = False, + seed: int = 42, + ): + """Dataset for a folder of TokenizedBytes files + We loop on the dataset if asking for an index larger than the dataset size + + Args: + folder_path (str): path to folder on S3 or locally + seq_len (int): sequence length + filename_pattern (Union[Pattern, str], optional): filename pattern. Defaults to None. + recursive (bool, optional): search recursively. Defaults to True. + dtype (np.dtype, optional): numpy dtype. Defaults to np.uint16. + skip_in_stream (bool, optional): skip ahead in stream. Defaults to True. + num_samples (Optional[int], optional): number of samples. Defaults to None. Only indicative for the number of epoch + shuffle (bool, optional): shuffle the files in the folder. Defaults to False. + seed (int, optional): seed for shuffling. Defaults to 42. + """ + self.folder_path = folder_path + if isinstance(filename_pattern, str): + filename_pattern = re.compile(filename_pattern) + self.filename_pattern = filename_pattern + if folder_path.startswith("s3://"): + matched_file_paths = _get_s3_file_list(folder_path, filename_pattern, recursive) + else: + matched_file_paths = [ + os.path.join(root, file) + for root, _, files in os.walk(folder_path) + for file in files + if filename_pattern.match(os.path.join(root, file)) + ] + if not matched_file_paths: + raise FileNotFoundError(f'No files matching "{filename_pattern}" found in {folder_path}') + + self.files = [] + remaining_tokens = max_tokens + remaining_skip_tokens = skip_tokens or 0 + for path in matched_file_paths: + file_data = TokenizedBytesFileDataset( + path, + seq_len, + dtype=dtype, + skip_in_stream=skip_in_stream, + max_tokens=remaining_tokens, + skip_tokens=remaining_skip_tokens, + ) + if remaining_skip_tokens: + remaining_skip_tokens -= len(file_data) * (seq_len + 1) + if remaining_skip_tokens <= 0: + remaining_skip_tokens = 0 + elif remaining_skip_tokens > 0: + continue # We skip this file entirely + self.files.append(file_data) + if remaining_tokens: + remaining_tokens -= len(file_data) * (seq_len + 1) + if remaining_tokens <= 0: + break + + log_rank(f"Found {len(self.files)} files.", logger=logger, level=logging.INFO, rank=0) + if shuffle: + log_rank("Shuffling...", logger=logger, level=logging.INFO, rank=0) + rand = np.random.default_rng(seed) + ordering = rand.permutation(range(len(self.files))) + self.files = [self.files[i] for i in ordering] + + self.lens = np.cumsum([0] + [len(f) for f in self.files]).tolist() + + self.current_file = 0 + + self.subset_log = TBFolderDatasetLog( + dataset_type=self.__class__.__name__, + folder_path=folder_path, + filename_pattern=str(filename_pattern), + recursive=recursive, + seq_len=seq_len, + dtype=np.dtype(dtype).name, + skip_in_stream=skip_in_stream, + num_samples=self.lens[-1] if self.lens else 0, + num_tokens=self.lens[-1] * (seq_len + 1), + human_format_num_tokens=human_format(self.lens[-1] * (seq_len + 1)), + shuffle=shuffle, + seed=seed, + num_epochs=num_samples // self.lens[-1] if num_samples and self.lens else 0, + files_order=[str(f.file_path) for f in self.files], + ) + + def __getitem__(self, item): + epoch_item = item % len(self) + # if item >= len(self): + # raise IndexError( + # f"Index {item} requested for dataset {self.folder_path} (pattern: {self.filename_pattern}) " + # f"but it only has size {len(self)}" + # ) + # check if we are in the same file as before + if not (self.lens[self.current_file] <= epoch_item < self.lens[self.current_file + 1]): + # figure out current file + self.current_file = bisect(self.lens, epoch_item) - 1 + # subtract file starting offset + return self.files[self.current_file][epoch_item - self.lens[self.current_file]] + + def __len__(self): + return self.lens[-1] if self.lens else 0 + + +def build_dataset( + dataset_folder: str, + seq_length: int, + skip_in_stream: bool = True, + num_samples: Optional[int] = None, + max_tokens: Optional[int] = None, + filename_pattern: Optional[str] = ".*\\.ds$", + skip_tokens: Optional[int] = None, + shuffle: Optional[bool] = False, + seed: Optional[int] = 6, +) -> "TokenizedBytesFolderDataset": + """Build one TokenizedBytes dataset from a file or a folder on S3 or locally + + Args: + dataset_args (Union[TokenizedBytesDatasetFileArgs, TokenizedBytesDatasetFolderArgs]): dataset config + seq_length ([type]): sequence length + skip_in_stream (bool, optional): skip ahead in stream. Defaults to True. + """ + return TokenizedBytesFolderDataset( + dataset_folder, + seq_length, + filename_pattern, + skip_in_stream=skip_in_stream, + max_tokens=max_tokens, + num_samples=num_samples, + skip_tokens=skip_tokens, # # Optional number of tokens to skip at the beginning (We'll only train on the rest) + shuffle=shuffle, + seed=seed, + ) + + +def get_tb_datasets( + config: NanosetDatasetsArgs, + sequence_length: int, + global_batch_size: int, + train_steps: int, + parallel_context: ParallelContext, + shuffle: bool = False, + seed: int = 6, +) -> Tuple[DataLoader, TrainDataLog]: + """Build TokenizedBytes datasets + + Args: + config (NanosetDatasetsArgs): dataset config + sequence_length (int): sequence length + global_batch_size (int): global batch size + train_steps (int): number of training steps + parallel_context (ParallelContext): distributed process groups + """ + log_rank("Building Streamable datasets.", logger=logger, level=logging.INFO, rank=0) + dataset_max_tokens = config.dataset_max_tokens + if dataset_max_tokens is None: + dataset_max_tokens = [None] * len(config.dataset_folder) + train_num_samples = train_steps * global_batch_size + + datasets = [ + build_dataset( + dataset_folder, + sequence_length, + config.skip_in_stream, + max_tokens=max_tokens, + num_samples=train_num_samples, + shuffle=shuffle, + seed=seed, + ) + for dataset_folder, max_tokens in zip(config.dataset_folder, dataset_max_tokens) + ] + + if len(datasets) == 1: + outputs_dataset = datasets[0] + else: + if dist.get_rank(parallel_context.world_pg) == 0: + try: + compile_helper() + except ImportError: + raise ImportError( + "Could not compile megatron dataset C++ helper functions and therefore cannot import helpers python file." + ) + dist.barrier(parallel_context.world_pg) + weights = config.dataset_weights + if not weights: + weights = [1] * len(datasets) + + outputs_dataset = BlendableDataset(datasets, weights, train_num_samples, parallel_context=parallel_context) + + log_rank("Streamable datasets ready.", logger=logger, level=logging.INFO, rank=0) + train_data_log = TrainDataLog( + train_num_samples=train_num_samples, + eval_num_samples=None, + test_num_samples=None, + global_batch_size=global_batch_size, + sequence_length=sequence_length, + total_training_tokens=train_num_samples * sequence_length, + human_total_train_tokens=human_format(train_num_samples * sequence_length), + train_subset=outputs_dataset.subset_log, + eval_subset=None, + test_subset=None, + ) + return outputs_dataset, train_data_log + + +def get_tb_dataloader( + dataset: Union[TokenizedBytesFolderDataset, TokenizedBytesFileDataset, BlendableDataset, Dataset], + sequence_length: int, + micro_batch_size: int, + global_batch_size: int, + cfg: NanosetDatasetsArgs, + num_workers: int, + consumed_samples: int, + num_samples: int, + parallel_context: ParallelContext, + input_pp_rank: int, + output_pp_rank: int, + dataloader_drop_last: bool = True, + dataloader_pin_memory: bool = True, +) -> DataLoader: + # Only some rank require to run the dataloader. + if dist.get_rank(parallel_context.pp_pg) not in [ + input_pp_rank, + output_pp_rank, + ]: + dataset = EmptyInfiniteDataset(length=len(dataset)) + + log_rank( + f"Building dataloader with consumed samples: {consumed_samples}", logger=logger, level=logging.INFO, rank=0 + ) + # Megatron sampler + batch_sampler = MegatronPretrainingSampler( + total_samples=num_samples, + consumed_samples=consumed_samples, + micro_batch_size=micro_batch_size, + data_parallel_rank=dist.get_rank(parallel_context.dp_pg), + data_parallel_size=parallel_context.dp_pg.size(), + drop_last=dataloader_drop_last, + global_batch_size=global_batch_size, + pad_samples_to_global_batch_size=cfg.pad_samples_to_global_batch_size, + ) + + # We use the data collator to put the tensors on the right pipeline parallelism rank + data_collator = DataCollatorForCLM( + sequence_length=sequence_length, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + parallel_context=parallel_context, + ) + + return DataLoader( + dataset, + batch_sampler=batch_sampler, + num_workers=num_workers, + collate_fn=data_collator, + pin_memory=dataloader_pin_memory, + worker_init_fn=get_dataloader_worker_init(dp_rank=dist.get_rank(parallel_context.dp_pg)), + ) diff --git a/src/nanotron/logging/timers.py b/src/nanotron/logging/timers.py index b4d4a5e14..eec241bb5 100644 --- a/src/nanotron/logging/timers.py +++ b/src/nanotron/logging/timers.py @@ -26,6 +26,7 @@ class TimerRecord: end_time: float = 0.0 running: bool = False call_count: int = 0 + cuda_sync: bool = False # Option to add CUDA synchronization for more accurate timings # For CPU timers we still track total_time _cpu_total_time: float = 0.0 @@ -51,6 +52,9 @@ def start(self) -> "TimerRecord": if self.timer_type == TimerType.CUDA: if torch.cuda.is_available(): + # Synchronize before starting timing if requested + if self.cuda_sync: + torch.cuda.synchronize() # Create a new start event - we'll create the end event when end() is called self._current_start_event = torch.cuda.Event(enable_timing=True) self._current_start_event.record() @@ -72,6 +76,9 @@ def end(self) -> None: if self.timer_type == TimerType.CUDA: if torch.cuda.is_available() and self._current_start_event is not None: + # Synchronize before ending timing if requested + if self.cuda_sync: + torch.cuda.synchronize() # Create and record an end event end_event = torch.cuda.Event(enable_timing=True) end_event.record() @@ -121,6 +128,8 @@ def elapsed(self) -> float: if self.timer_type == TimerType.CUDA: if torch.cuda.is_available() and self._current_start_event is not None: # Create a temporary end event to measure elapsed time so far + if self.cuda_sync: + torch.cuda.synchronize() tmp_end_event = torch.cuda.Event(enable_timing=True) tmp_end_event.record() tmp_end_event.synchronize() @@ -168,7 +177,9 @@ def __new__(cls): cls._instance._timers: Dict[str, TimerRecord] = {} return cls._instance - def __call__(self, name: str, timer_type: Union[TimerType, str] = TimerType.CPU) -> TimerRecord: + def __call__( + self, name: str, timer_type: Union[TimerType, str] = TimerType.CPU, cuda_sync: bool = True + ) -> TimerRecord: """Get or create a timer with the given name. Can be used as a decorator, context manager, or directly: @@ -180,6 +191,7 @@ def __call__(self, name: str, timer_type: Union[TimerType, str] = TimerType.CPU) name: Name of the timer timer_type: Type of timer, either TimerType.CPU or TimerType.CUDA (or 'cpu'/'cuda' strings) + cuda_sync: Whether to perform torch.cuda.synchronize() for more accurate CUDA timing """ if isinstance(timer_type, str): timer_type = TimerType(timer_type) @@ -188,10 +200,13 @@ def __call__(self, name: str, timer_type: Union[TimerType, str] = TimerType.CPU) # Being used as a decorator with default settings func = name timer_name = func.__name__ - return self._create_timer_decorator(timer_name, TimerType.CPU)(func) + return self._create_timer_decorator(timer_name, TimerType.CPU, cuda_sync)(func) if name not in self._timers: - self._timers[name] = TimerRecord(name=name, timer_type=timer_type) + self._timers[name] = TimerRecord(name=name, timer_type=timer_type, cuda_sync=cuda_sync) + else: + # Update the cuda_sync option if the timer already exists + self._timers[name].cuda_sync = cuda_sync # Check if we're being called as a decorator if not callable(name): @@ -200,9 +215,9 @@ def __call__(self, name: str, timer_type: Union[TimerType, str] = TimerType.CPU) return timer_record # If we get here, we're being called as @nanotron_timer("name", timer_type) - return self._create_timer_decorator(name, timer_type) + return self._create_timer_decorator(name, timer_type, cuda_sync) - def _create_timer_decorator(self, name, timer_type): + def _create_timer_decorator(self, name, timer_type, cuda_sync=False): """Create a decorator that times the execution of a function.""" def decorator(func): @@ -210,7 +225,7 @@ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - with self(name, timer_type): + with self(name, timer_type, cuda_sync): return func(*args, **kwargs) return wrapper From faa155a2b6956169bfcdffcb3f8cb4b6e28433a0 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Mon, 7 Apr 2025 16:59:48 +0000 Subject: [PATCH 12/12] . --- run_train.py | 2 + src/nanotron/config/config.py | 7 +- src/nanotron/data/tokenized_bytes.py | 25 ++++-- src/nanotron/logging/base.py | 6 +- src/nanotron/logging/timers.py | 21 +++++ src/nanotron/trainer.py | 126 ++++++++++++++++++++++----- 6 files changed, 153 insertions(+), 34 deletions(-) diff --git a/run_train.py b/run_train.py index dbe5452c2..5870d608e 100644 --- a/run_train.py +++ b/run_train.py @@ -220,6 +220,8 @@ def get_dataloader_from_data_stage( input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, dataloader_drop_last=True, + use_position_ids=isinstance(trainer.model_config, Qwen2Config), + use_doc_masking=False, ) log_rank( f"[TokenizedBytes] Time taken to create TokenizedBytes: {time.strftime('%M:%S', time.gmtime(time.time() - start_time))} (MM:SS)", diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 61e78a92f..e2eef39df 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -153,8 +153,7 @@ class NanosetDatasetsArgs: # Tokenized bytes dataset config skip_in_stream: Optional[bool] = True - pad_samples_to_global_batch_size: Optional[bool] = True - dataloader_type: Optional[str] = "single" # single or cyclic + pad_samples_to_global_batch_size: Optional[bool] = False dataset_max_tokens: Optional[List[int]] = None def __post_init__(self): @@ -539,6 +538,10 @@ def save_as_yaml(self, file_path: str): # Sanity test config can be reloaded _ = get_config_from_file(file_path, config_class=self.__class__) + def get_yaml(self): + config_dict = serialize(self) + return yaml.dump(config_dict) + @classmethod def load_from_yaml(cls, file_path: str): config_dict = yaml.load(open(file_path), Loader=SafeLoader) diff --git a/src/nanotron/data/tokenized_bytes.py b/src/nanotron/data/tokenized_bytes.py index 25836f1e9..b46804089 100644 --- a/src/nanotron/data/tokenized_bytes.py +++ b/src/nanotron/data/tokenized_bytes.py @@ -13,7 +13,7 @@ from nanotron import distributed as dist from nanotron import logging from nanotron.config import NanosetDatasetsArgs -from nanotron.data import DataCollatorForCLM, EmptyInfiniteDataset +from nanotron.data import DataCollatorForCLM, DataCollatorForCLMWithPositionIds, EmptyInfiniteDataset from nanotron.data.dataloader import get_dataloader_worker_init from nanotron.data.nemo_dataset import BlendableDataset from nanotron.data.nemo_dataset.dataset_utils import compile_helper @@ -406,6 +406,8 @@ def get_tb_dataloader( output_pp_rank: int, dataloader_drop_last: bool = True, dataloader_pin_memory: bool = True, + use_position_ids: bool = False, + use_doc_masking: bool = False, ) -> DataLoader: # Only some rank require to run the dataloader. if dist.get_rank(parallel_context.pp_pg) not in [ @@ -430,12 +432,21 @@ def get_tb_dataloader( ) # We use the data collator to put the tensors on the right pipeline parallelism rank - data_collator = DataCollatorForCLM( - sequence_length=sequence_length, - input_pp_rank=input_pp_rank, - output_pp_rank=output_pp_rank, - parallel_context=parallel_context, - ) + if use_position_ids: + data_collator = DataCollatorForCLMWithPositionIds( + sequence_length=sequence_length, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + parallel_context=parallel_context, + use_doc_masking=use_doc_masking, + ) + else: + data_collator = DataCollatorForCLM( + sequence_length=sequence_length, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + parallel_context=parallel_context, + ) return DataLoader( dataset, diff --git a/src/nanotron/logging/base.py b/src/nanotron/logging/base.py index e84554ee1..a873e86a2 100644 --- a/src/nanotron/logging/base.py +++ b/src/nanotron/logging/base.py @@ -429,9 +429,9 @@ def log_libraries_versions(logger: logging.Logger): log_rank(f"datasets version: {datasets.__version__}", logger=logger, level=logging.INFO, rank=0) log_rank(f"flash-attn version: {flash_attn.__version__}", logger=logger, level=logging.INFO, rank=0) log_rank(f"numpy version: {numpy.__version__}", logger=logger, level=logging.INFO, rank=0) - log_rank( - f"\ntorch.utils.collect_env: {torch.utils.collect_env.main()}", logger=logger, level=logging.INFO, rank=0 - ) + # log_rank( + # f"\ntorch.utils.collect_env: {torch.utils.collect_env.main()}", logger=logger, level=logging.INFO, rank=0 + # ) _configure_library_root_logger() diff --git a/src/nanotron/logging/timers.py b/src/nanotron/logging/timers.py index eec241bb5..1a921e467 100644 --- a/src/nanotron/logging/timers.py +++ b/src/nanotron/logging/timers.py @@ -1,3 +1,4 @@ +import os import time from dataclasses import dataclass, field from enum import Enum @@ -170,6 +171,7 @@ class Timers: """A collection of timers for tracking execution time in Nanotron.""" _instance = None + _enabled = os.environ.get("ENABLE_TIMERS", "1") == "1" # Add global enable/disable flag def __new__(cls): if cls._instance is None: @@ -177,6 +179,21 @@ def __new__(cls): cls._instance._timers: Dict[str, TimerRecord] = {} return cls._instance + @classmethod + def enable(cls) -> None: + """Enable all timing operations.""" + cls._enabled = True + + @classmethod + def disable(cls) -> None: + """Disable all timing operations.""" + cls._enabled = False + + @classmethod + def is_enabled(cls) -> bool: + """Check if timers are enabled.""" + return cls._enabled + def __call__( self, name: str, timer_type: Union[TimerType, str] = TimerType.CPU, cuda_sync: bool = True ) -> TimerRecord: @@ -193,6 +210,10 @@ def __call__( (or 'cpu'/'cuda' strings) cuda_sync: Whether to perform torch.cuda.synchronize() for more accurate CUDA timing """ + if not self._enabled: + # Return a dummy timer that does nothing when timing is disabled + return TimerRecord(name="dummy", timer_type=TimerType.CPU) + if isinstance(timer_type, str): timer_type = TimerType(timer_type) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 553b34f2f..53dc050de 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -3,6 +3,7 @@ import json import os import shutil +import tempfile import time from dataclasses import asdict from pathlib import Path @@ -20,6 +21,7 @@ cast, ) +import psutil import torch from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader @@ -118,6 +120,14 @@ wandb = None +def get_size(bytes): + """Convert bytes to human readable format""" + for unit in ["", "K", "M", "G", "T", "P"]: + if bytes < 1024: + return f"{bytes:.2f}{unit}B" + bytes /= 1024 + + class DistributedTrainer: def __init__( self, @@ -358,27 +368,30 @@ def pre_training(self, *args, **kwargs): level=logging.INFO, rank=world_rank, ) - else: - if world_rank == self.logger_ranks[0]: - run_name = f"{current_time}_{self.config.general.run}" - wandb.init( - project=self.config.general.project, - name=run_name, - config={"nanotron_config": self.config.as_dict()}, - settings=wandb.Settings( - x_stats_sampling_interval=1.0, # TODO: put back to default 15.0 - x_stats_disk_paths=("/scratch", "/fsx/nouamane/"), - x_stats_open_metrics_endpoints={"dcgm": "http://localhost:9104/metrics"}, - x_stats_open_metrics_filters=["DCGM_FI_"], - x_file_stream_transmit_interval=1.0, - ), - ) - log_rank( - f"Initialized wandb run '{run_name}' for TP rank {tp_rank}", - logger=logger, - level=logging.INFO, - rank=world_rank, - ) + elif world_rank == self.logger_ranks[0]: + run_name = f"{current_time}_{self.config.general.run}" + wandb.init( + project=self.config.general.project, + name=run_name, + config={"nanotron_config": self.config.as_dict()}, + settings=wandb.Settings( + x_stats_sampling_interval=1.0, # TODO: put back to default 15.0 + x_stats_disk_paths=("/scratch", "/fsx/nouamane/"), + x_stats_open_metrics_endpoints={"dcgm": "http://localhost:9104/metrics"}, + x_stats_open_metrics_filters=["DCGM_FI_"], + x_file_stream_transmit_interval=1.0, + ), + ) + # save config file + temp_config_path = tempfile.mktemp(suffix=".yaml", prefix="config") + self.config.save_as_yaml(temp_config_path) + wandb.save(temp_config_path, base_path=os.path.dirname(temp_config_path), policy="now") + log_rank( + f"Initialized wandb run '{run_name}' for TP rank {tp_rank}", + logger=logger, + level=logging.INFO, + rank=world_rank, + ) def post_train_step(self): @@ -557,6 +570,9 @@ def training_step( if self.iteration_step < self.initial_iter_step + 5: log_memory(logger=logger, msg="Before train_batch_iter") + # DEBUG: sleep for 29800ms + # time.sleep(29800 / 1000) + nanotron_timer("train_batch_iter").start() with torch.profiler.record_function("train_batch_iter"): outputs = self.pipeline_engine.train_batch_iter( @@ -578,7 +594,11 @@ def training_step( assert ( self.grad_accumulator.fp32_grads_allreduce_handle is not None ), "No fp32_grads_allreduce_handle maybe you're using only a single training process" - self.grad_accumulator.fp32_grads_allreduce_handle.wait() + if isinstance(self.grad_accumulator.fp32_grads_allreduce_handle, list): + for handle in self.grad_accumulator.fp32_grads_allreduce_handle: + handle.wait() + else: + self.grad_accumulator.fp32_grads_allreduce_handle.wait() nanotron_timer("sync_gradients").start() # Sync tied weights @@ -725,6 +745,67 @@ def train_step_logs( # LogItem("hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"), LogItem("eta", str(datetime.timedelta(seconds=eta_seconds))), ] + + def get_cpu_logitems(): + # Add CPU memory usage metrics + memory = psutil.virtual_memory() + cpu_memory_log_entries = [ + LogItem("cpu_memory/total", memory.total, "human_format"), + LogItem("cpu_memory/available_bytes", memory.available, "human_format"), + LogItem("cpu_memory/used_bytes", memory.used, "human_format"), + LogItem("cpu_memory/percent", memory.percent, "human_format"), + ] + + # Add swap memory usage metrics + swap = psutil.swap_memory() + swap_memory_log_entries = [ + LogItem("swap_memory/total", swap.total, "human_format"), + LogItem("swap_memory/free", swap.free, "human_format"), + LogItem("swap_memory/used", swap.used, "human_format"), + LogItem("swap_memory/percent", swap.percent, "human_format"), + ] + + # Add detailed process memory info for main process and workers + process = psutil.Process() + worker_processes = [] + # Get all child processes + try: + worker_processes = process.children(recursive=True) + except psutil.NoSuchProcess: + pass + + # Log main process memory + mem_info = process.memory_info() + process_memory_log_entries = [ + LogItem("process_memory/rss", mem_info.rss, "human_format"), + LogItem("process_memory/shared", mem_info.shared, "human_format"), + LogItem("process_memory/vms", mem_info.vms, "human_format"), + LogItem("process_memory/text", mem_info.text, "human_format"), + LogItem("process_memory/data", mem_info.data, "human_format"), + LogItem("process_memory/lib", mem_info.lib, "human_format"), + LogItem("process_memory/dirty", mem_info.dirty, "human_format"), + ] + + # Log worker process memory + for idx, worker in enumerate(worker_processes): + try: + worker_mem = worker.memory_info() + process_memory_log_entries.extend( + [ + LogItem(f"worker_{idx}_memory/rss", worker_mem.rss, "human_format"), + LogItem(f"worker_{idx}_memory/shared", worker_mem.shared, "human_format"), + LogItem(f"worker_{idx}_memory/vms", worker_mem.vms, "human_format"), + LogItem(f"worker_{idx}_memory/text", worker_mem.text, "human_format"), + LogItem(f"worker_{idx}_memory/data", worker_mem.data, "human_format"), + LogItem(f"worker_{idx}_memory/lib", worker_mem.lib, "human_format"), + LogItem(f"worker_{idx}_memory/dirty", worker_mem.dirty, "human_format"), + ] + ) + except (psutil.NoSuchProcess, psutil.AccessDenied): + continue + + return cpu_memory_log_entries + swap_memory_log_entries + process_memory_log_entries + if z_loss_avg is not None: basic_log_entries.insert(6, LogItem("z_loss", z_loss_avg.item(), "human_format")) # , "1.6E"), @@ -738,6 +819,7 @@ def train_step_logs( assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks" self.loggerwriter.add_scalars_from_list(basic_log_entries, self.iteration_step) + basic_log_entries.extend(get_cpu_logitems()) for timer_name, timer in nanotron_timer.items(): basic_log_entries.append(LogItem(f"timers/{timer_name}", timer.elapsed, ".2f"))