diff --git a/examples/config_qwen.py b/examples/config_qwen.py index 8ca8487b3..a5d901b24 100644 --- a/examples/config_qwen.py +++ b/examples/config_qwen.py @@ -30,7 +30,7 @@ "410m": (24, 1024, 16, 16, 4096), # ~410M params # Small to medium models "1b": (16, 2048, 16, 16, 5632), # ~1B params - "3b": (28, 2048, 16, 2, 11008), # ~3B params + "3b": (36, 2048, 16, 4, 11008), # ~3B params # Standard sizes "7b": (32, 4096, 32, 32, 11008), # ~7B params "13b": (40, 5120, 40, 40, 13824), # ~13B params @@ -47,7 +47,7 @@ def get_args(): parser.add_argument( "--model", choices=MODEL_SIZES.keys(), - default="custom", + default="3b", help="Model size to generate config for (e.g., 7b, 13b)", ) parser.add_argument( @@ -76,6 +76,10 @@ def get_args(): tokens_group.add_argument("--mbs", type=int, default=3, help="Micro batch size") tokens_group.add_argument("--acc", type=int, default=1, help="Batch accumulation per replica") + # checkpoints + checkpoints_group = parser.add_argument_group("checkpoints") + checkpoints_group.add_argument("--ckpt-save", type=int, default=10, help="Checkpoint save interval") + args = parser.parse_args() return args @@ -108,7 +112,7 @@ def get_model_config(model_size: str) -> Qwen2Config: is_qwen2_config=True, pad_token_id=None, _attn_implementation="flash_attention_2", - # sliding_window_size=20, + _use_doc_masking=True, ) @@ -154,7 +158,7 @@ def calculate_parameters(model_config: Qwen2Config) -> str: def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config: learning_rate = LRSchedulerArgs( - learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5 + learning_rate=3e-4, lr_warmup_steps=2000, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=0 ) parallelism = ParallelismArgs( dp=args.dp, @@ -175,7 +179,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config ) optimizer = OptimizerArgs( zero_stage=args.zero, - weight_decay=0.01, + weight_decay=0.1, clip_grad=1.0, accumulate_grad_in_fp32=True, learning_rate_scheduler=learning_rate, @@ -192,7 +196,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config return Config( general=GeneralArgs(project="debug", run=args.run, seed=seed, ignore_sanity_checks=args.no_sanity), - checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10), + checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=args.ckpt_save), parallelism=parallelism, model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config), # tokenizer=TokenizerArgs("HuggingFaceTB/cosmo2-tokenizer"), @@ -219,7 +223,11 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config world_size = args.dp * args.tp * args.pp * args.cp if world_size <= 8: print( - f"CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={world_size} run_train.py --config-file {args.out}" + f"ENABLE_TIMERS=1 DEBUG_CPU=1 STATS_SAMPLING_INTERVAL_IN_SEC=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={world_size} run_train.py --config-file {args.out}" ) + print("You can also use environment variables for more debugging:") + print(" - ENABLE_TIMERS=1: Enable detailed timing information") + print(" - DEBUG_CPU=1: Log CPU and memory usage statistics") + print(" - STATS_SAMPLING_INTERVAL_IN_SEC=1: Set sampling interval for metrics collection") else: print("Checkout slurm_launcher.py to launch a multi-node job") diff --git a/examples/config_qwen.yaml b/examples/config_qwen.yaml index a2ce9bd14..cf6f40fac 100644 --- a/examples/config_qwen.yaml +++ b/examples/config_qwen.yaml @@ -1,5 +1,5 @@ checkpoints: - checkpoint_interval: 10 + checkpoint_interval: 100000 checkpoints_path: checkpoints checkpoints_path_is_shared_file_system: false load_lr_scheduler: true @@ -30,9 +30,9 @@ data_stages: general: benchmark_csv_path: null consumed_train_samples: null - ignore_sanity_checks: false + ignore_sanity_checks: true project: debug - run: qwen_20250423_201000_16423158 + run: qwen_20250424_120835_16423158 seed: 42 step: null lighteval: null @@ -50,24 +50,24 @@ model: make_vocab_size_divisible_by: 1 model_config: _attn_implementation: flash_attention_2 - _fused_rms_norm: false - _fused_rotary_emb: false - _use_doc_masking: false - _use_qkv_packed: false + _fused_rms_norm: true + _fused_rotary_emb: true + _use_doc_masking: true + _use_qkv_packed: true attention_bias: false bos_token_id: 1 eos_token_id: 2 flex_attention_mask: null hidden_act: silu - hidden_size: 256 + hidden_size: 2048 initializer_range: 0.02 - intermediate_size: 768 + intermediate_size: 11008 is_qwen2_config: true max_position_embeddings: 4096 moe_config: null no_rope_layer: null - num_attention_heads: 4 - num_hidden_layers: 12 + num_attention_heads: 16 + num_hidden_layers: 36 num_key_value_heads: 4 pad_token_id: null pretraining_tp: 1 @@ -108,7 +108,7 @@ parallelism: pp: 1 pp_engine: 1f1b recompute_layer: false - tp: 1 + tp: 2 tp_linear_async_communication: true tp_mode: REDUCE_SCATTER tp_recompute_allgather: true diff --git a/run_generate.py b/run_generate.py index e21fe7e22..405748e3b 100644 --- a/run_generate.py +++ b/run_generate.py @@ -5,6 +5,14 @@ ``` export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations torchrun --nproc_per_node=1 run_generate.py --ckpt-path checkpoints/10 +torchrun --nproc_per_node=2 run_generate.py --ckpt-path /scratch/1044000 + +torchrun --rdzv_endpoint=127.0.0.1:12357 --nproc_per_node=2 run_generate.py --ckpt-path /scratch/1044000 --use-cache --max-micro-batch-size 2 +export CUDA_VISIBLE_DEVICES=2,3 +torchrun --rdzv_endpoint=127.0.0.1:12356 --nproc_per_node=2 run_generate.py --ckpt-path /scratch/1044000 --use-cache +export CUDA_VISIBLE_DEVICES=4,5 +torchrun --rdzv_endpoint=127.0.0.1:12355 --nproc_per_node=2 run_generate.py --ckpt-path /scratch/1044000 --max-micro-batch-size 2 --use-decode-tokenized +torchrun --rdzv_endpoint=127.0.0.1:12355 --nproc_per_node=2 run_generate.py --ckpt-path /scratch/1044000 --use-decode-tokenized ``` """ @@ -45,10 +53,7 @@ from nanotron.serialize import load_weights from nanotron.trainer import CONFIG_TO_MODEL_CLASS, mark_tied_parameters -try: - from transformers import AutoTokenizer -except ImportError: - AutoTokenizer = None +from transformers import AutoTokenizer # import lovely_tensors as lt @@ -65,6 +70,10 @@ def get_args(): parser.add_argument("--tp", type=int, default=0) parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate") parser.add_argument("--use-cache", action="store_true", help="Use KV cache to speed up generation") + parser.add_argument( + "--max-micro-batch-size", type=int, default=1, help="Maximum number of micro batches to generate" + ) + parser.add_argument("--use-decode-tokenized", action="store_true", help="Use decode_tokenized to generate text") return parser.parse_args() @@ -73,6 +82,17 @@ def main(): assert args.ckpt_path.exists(), f"Checkpoint path {args.ckpt_path} does not exist" + dummy_inputs = [ + # "The future of AI is", + # "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", + "def fib(n)", + # 'Here is an extract from a webpage: "Have you ever experienced heel pain after a heavy physical activity, or even right after a long period of standing? If you regard this as something usual and normal, then think again. Miscalled as heel pain, plantar fasciitis causes these frequent mild pains experienced in the soles of the feet. It is the inflammation and enlargement the plantar fascia tissue that is located in the heels of the feet, stretching to the base of the toes. This tissue is responsible for absorbing shock in the feet and for supporting the arches. It also plays a vital role in foot movements during walking and standing. Many factors such as excessive walking, standing, and running trigger heel pain and plantar fasciitis. A sudden increase in intensity of activities, increase in weight, and abrupt change of footwear also cause the swelling of the ligament. Non-supportive footwear lacking arch cushions and improper and worn out running or training can also lead to the problem. It is also most evident among those". Write an extensive and detailed course unit suitable for a textbook targeted at college students, related to the given extract, within the context of "Medicine". Do not just list concepts, but develop each one in detail before moving to the next, as we prioritize depth of understanding and comprehensive exploration of the subject matter over breadth. Focus on: - Rigor: Ensure in-depth coverage of the concepts/sections. - Engagement: Write with an academic, professional and engaging tone that captivates interest. - Application: Incorporate specific, practical examples, such as proofs in calculus or critical dates and figures in history. Do not include a title or an introduction, simply write the content without headlines and introductory phrases. Do not use images.', + # "Advancements in technology will lead to", + # "Tomorrow's world is shaped by", + # "What is the meaning of the word chutzpah?\nThe word chutzpah means", + ] + + config = get_config_from_file((args.ckpt_path / "config.yaml").as_posix()) model_config = config.model.model_config tokenizer_path = config.tokenizer.tokenizer_name_or_path @@ -154,36 +174,29 @@ def main(): load_weights(model=model, parallel_context=parallel_context, root_folder=checkpoint_path) model.eval() - if AutoTokenizer is not None: - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) - # tokenizer.pad_token_id = tokenizer.eos_token_id - if tokenizer.pad_token_id is None: - if tokenizer.eos_token_id is not None: - tokenizer.pad_token_id = tokenizer.eos_token_id - elif getattr(model.config, "pad_token_id", None) is not None: - tokenizer.pad_token_id = int(model.config.pad_token_id) - elif getattr(model.config, "eos_token_id", None) is not None: - tokenizer.pad_token_id = int(model.config.eos_token_id) - else: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - tokenizer.padding_side = "left" - tokenizer.truncation_side = "left" # TODO @nouamane: do we want this? - dummy_inputs = [ - # "The future of AI is", - "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", - "def fib(n)", - # 'Here is an extract from a webpage: "Have you ever experienced heel pain after a heavy physical activity, or even right after a long period of standing? If you regard this as something usual and normal, then think again. Miscalled as heel pain, plantar fasciitis causes these frequent mild pains experienced in the soles of the feet. It is the inflammation and enlargement the plantar fascia tissue that is located in the heels of the feet, stretching to the base of the toes. This tissue is responsible for absorbing shock in the feet and for supporting the arches. It also plays a vital role in foot movements during walking and standing. Many factors such as excessive walking, standing, and running trigger heel pain and plantar fasciitis. A sudden increase in intensity of activities, increase in weight, and abrupt change of footwear also cause the swelling of the ligament. Non-supportive footwear lacking arch cushions and improper and worn out running or training can also lead to the problem. It is also most evident among those". Write an extensive and detailed course unit suitable for a textbook targeted at college students, related to the given extract, within the context of "Medicine". Do not just list concepts, but develop each one in detail before moving to the next, as we prioritize depth of understanding and comprehensive exploration of the subject matter over breadth. Focus on: - Rigor: Ensure in-depth coverage of the concepts/sections. - Engagement: Write with an academic, professional and engaging tone that captivates interest. - Application: Incorporate specific, practical examples, such as proofs in calculus or critical dates and figures in history. Do not include a title or an introduction, simply write the content without headlines and introductory phrases. Do not use images.', - # "Advancements in technology will lead to", - # "Tomorrow's world is shaped by", - ] + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + # tokenizer.pad_token_id = tokenizer.eos_token_id + if tokenizer.pad_token_id is None: + if tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + elif getattr(model.config, "pad_token_id", None) is not None: + tokenizer.pad_token_id = int(model.config.pad_token_id) + elif getattr(model.config, "eos_token_id", None) is not None: + tokenizer.pad_token_id = int(model.config.eos_token_id) + else: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + tokenizer.padding_side = "left" + tokenizer.truncation_side = "left" # TODO @nouamane: do we want this? + + if not args.use_decode_tokenized: outputs = decode_text( input_iter=(GenerationInput(text=text) for text in dummy_inputs), tokenizer=tokenizer, model=model.model, parallel_context=parallel_context, max_new_tokens=args.max_new_tokens, - max_micro_batch_size=2, + max_micro_batch_size=args.max_micro_batch_size, generation_config=GenerationArgs(sampler="greedy", use_cache=args.use_cache), tokenizer_config=TokenizerConfig(max_input_length=None), is_bench=os.environ.get("USE_BENCH", "0") == "1", @@ -217,15 +230,27 @@ def main(): rank=0, ) else: + # Tokenize dummy inputs + tokenized_inputs = tokenizer( + dummy_inputs, + padding=True, + truncation=True, + return_tensors="pt", + add_special_tokens=False, # TODO: this is important to avoid adding bos token to the input + ) + input_ids = tokenized_inputs["input_ids"].to(device="cuda") + attention_mask = tokenized_inputs["attention_mask"].to(device="cuda") + outputs = decode_tokenized( - input_ids=torch.zeros(1, 1).to(dtype=torch.int64, device="cuda"), - input_mask=torch.ones(1, 1).to(dtype=torch.bool, device="cuda"), + input_ids=input_ids, + input_mask=attention_mask, model=model.model, parallel_context=parallel_context, - generation_config=GenerationArgs(sampler="greedy", use_cache=True), - max_micro_batch_size=1, - max_new_tokens=12, + generation_config=GenerationArgs(sampler="greedy", use_cache=args.use_cache), + max_micro_batch_size=args.max_micro_batch_size, + max_new_tokens=args.max_new_tokens, returns_logits=False, + bos_token_id=tokenizer.bos_token_id, ) for output in outputs: input_ids = output.input_ids @@ -234,8 +259,16 @@ def main(): assert isinstance(generated_ids, TensorPointer) continue assert isinstance(generated_ids, torch.Tensor) + + log_rank( + f"input: {tokenizer.decode(input_ids, clean_up_tokenization_spaces=False)[:1000]}", + logger=logger, + level=logging.INFO, + rank=0, + ) + log_rank( - f"generation: {generated_ids[len(input_ids) :]}", + f"generation: {tokenizer.decode(generated_ids[len(input_ids):], clean_up_tokenization_spaces=False)}", logger=logger, level=logging.INFO, rank=0, diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index c16f076c1..a5e8e6eb5 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -632,7 +632,7 @@ def get_config_from_dict( InitScalingMethod: lambda x: InitScalingMethod[x.upper()], SamplerType: lambda x: SamplerType[x.upper()], }, - # strict_unions_match=True, + strict_unions_match=True, strict=True, ), ) diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index 363ee9887..72852831e 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -79,11 +79,11 @@ class LightEvalSlurm: gpus_per_node: int = 8 partition: str = "hopper-prod" - hf_cache: str = "~/.cache/huggingface" + hf_cache: Optional[str] = None cpus_per_task: int = 88 qos: str = "low" time: str = "24:00:00" - reservation: Optional[str] = "smollm" + reservation: Optional[str] = None def __post_init__(self): self.hf_cache = str(Path(self.hf_cache).expanduser()) @@ -109,11 +109,15 @@ class LightEvalConfig: logging: Optional[LightEvalLoggingArgs] = None wandb: Optional[LightEvalWandbLoggerConfig] = None slurm: Optional[LightEvalSlurm] = None - s3_save_path: Optional[str] = None # should not be dependent of the run_name - output_dir: Optional[str] = None # we should sanity check that it's the same as the one in the eval_config_override + s3_save_path: Optional[str] = None # should not be dependent of the run_name + upload_to_wandb: Optional[bool] = False + wandb_project: Optional[str] = None + wandb_entity: Optional[str] = None + output_dir: Optional[ + str + ] = None # we should sanity check that it's the same as the one in the eval_config_override nanotron_path: Optional[str] = "./" - eval_config_override: str = None - eval_config_override: Path = None # Previously hardcoded in run_slurm_one_job + lighteval_config_path: Path = None # Previously hardcoded in run_slurm_one_job eval_interval: Optional[ int ] = None # Must be multiple of checkpoint_interval. If None, eval will be done after each checkpoint upload to s3 @@ -127,6 +131,12 @@ def __post_init__(self): if self.slurm is None: self.slurm = LightEvalSlurm() self.local_checkpoint_dir = str(Path(self.local_checkpoint_dir).expanduser()) + if self.upload_to_wandb: + assert ( + self.s3_save_path is not None + ), " We should have a s3_save_path if we want to upload to wandb" # todo: add the option to read from local folder i guess + assert self.wandb_project is not None, "wandb_project must be specified if upload_to_wandb is True" + assert self.wandb_entity is not None, "wandb_entity must be specified if upload_to_wandb is True" if self.eval_interval_file is not None and Path(self.eval_interval_file).exists(): logger.warning( f"Eval interval file {self.eval_interval_file} exists. `eval_interval` will be replaced by the value in the file upon the next evaluation. You should probably delete this file if that's not what you want." diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 999d13379..8d0b6b005 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -279,4 +279,4 @@ def n_inner(self): return self.intermediate_size -NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Qwen2Config, Any] +NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Qwen2Config] diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py index 40d95119a..46e500fba 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -35,6 +35,8 @@ class ParallelismArgs: recompute_layer: bool = False tp_recompute_allgather: bool = True + moe_layer_recompute: bool = False # TODO: legacy config for smollm + expert_parallel_size: int = 1 context_parallel_size: int = 1 diff --git a/src/nanotron/data/nanoset.py b/src/nanotron/data/nanoset.py index 026970ccb..057b2e129 100644 --- a/src/nanotron/data/nanoset.py +++ b/src/nanotron/data/nanoset.py @@ -65,14 +65,14 @@ def __init__( for dataset_folder in self.dataset_folders: self.datatrove_datasets.append( DatatroveFolderDataset( - folder_path=dataset_folder, + data_folder=dataset_folder, filename_pattern=os.path.join(dataset_folder, "*.ds"), seq_len=sequence_length, recursive=False, token_size=self.token_size, shuffle=True, return_positions=self.return_positions, # if set to True, the position ids are directly build datatrove - eos_token_id=self.eos_token_id, + positions_from_eos_token_id=self.eos_token_id, ) ) diff --git a/src/nanotron/data/nemo_dataset/Makefile b/src/nanotron/data/nemo_dataset/Makefile index 150939026..47805f6a9 100644 --- a/src/nanotron/data/nemo_dataset/Makefile +++ b/src/nanotron/data/nemo_dataset/Makefile @@ -15,7 +15,8 @@ CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color CPPFLAGS += $(shell python3 -m pybind11 --includes) LIBNAME = helpers -LIBEXT = $(shell python3-config --extension-suffix) +# Works with uv too (unlike python3-config --extension-suffix) +LIBEXT = $(shell python3 -c "import sysconfig; print(sysconfig.get_config_var('EXT_SUFFIX'))") default: $(LIBNAME)$(LIBEXT) diff --git a/src/nanotron/data/sft_processing.py b/src/nanotron/data/sft_processing.py index 3133a749d..40ccfbd33 100644 --- a/src/nanotron/data/sft_processing.py +++ b/src/nanotron/data/sft_processing.py @@ -60,7 +60,9 @@ def process_sft(examples, tokenizer, trainer_sequence_length): # Set position ids for all tokens (prompt and completion) to sequential values # But only where attention_mask is True (non-padding tokens) valid_length = attention_mask[i].sum().item() - position_ids[i, :valid_length] = torch.arange(valid_length) + position_ids[i, :valid_length] = torch.arange( + valid_length + ) # TODO: better to pad left, although in modeling we remove pads so it doesnt matter if left or right # Set label_mask to True only for completion tokens # If prompt consumes the entire sequence, no tokens are used for loss diff --git a/src/nanotron/data/tokenized_bytes.py b/src/nanotron/data/tokenized_bytes.py index 4f9063eb6..a99981c9d 100644 --- a/src/nanotron/data/tokenized_bytes.py +++ b/src/nanotron/data/tokenized_bytes.py @@ -23,6 +23,8 @@ from nanotron.logging import human_format, log_rank from nanotron.parallel import ParallelContext from nanotron.utils import main_rank_first +from huggingface_hub import cached_assets_path + try: tb_logger_available = True @@ -302,15 +304,13 @@ def __init__( filename_pattern: str = None, recursive: bool = True, token_size: int = 2, - max_tokens: int | None = None, shuffle: bool = False, seed: int = 42, return_positions: bool = False, eos_token_id: int | None = None, skip_in_stream: bool = True, num_samples: Optional[int] = None, - folder_read_path: Optional[str] = None, - force_update_cache: bool = os.environ.get("FORCE_UPDATE_CACHE_S3", 0) == "1", + paths_file: str | None = None, ): log_rank("Using DatatroveFolderDataset", logger=logger, level=logging.INFO, rank=0) if return_positions and not eos_token_id: @@ -369,12 +369,13 @@ def __init__( ) from datatrove.utils.dataset import url_to_fs - fs_folder, folder_path = url_to_fs(folder_path) + fs_folder, stripped_folder_path = url_to_fs(folder_path) matched_files = ( - fs_folder.find(folder_path, detail=False, maxdepth=1 if not recursive else None) + fs_folder.find(stripped_folder_path, detail=False, maxdepth=1 if not recursive else None) if not filename_pattern else fs_folder.glob( - os.path.join(folder_path, filename_pattern), maxdepth=1 if not recursive else None + os.path.join(stripped_folder_path, filename_pattern), + maxdepth=1 if not recursive else None, ) ) matched_files = sorted(matched_files) @@ -417,19 +418,16 @@ def __init__( raise RuntimeError(f"Failed to read cache file on rank {dist.get_rank()}: {e}") super().__init__( - folder_path=folder_path, + data_folder=folder_path, seq_len=seq_len, filename_pattern=filename_pattern, recursive=recursive, token_size=token_size, - max_tokens=max_tokens, shuffle=shuffle, seed=seed, return_positions=return_positions, - eos_token_id=eos_token_id, - read_path=folder_read_path, - matched_files=matched_files, - file_sizes=file_sizes, + positions_from_eos_token_id=eos_token_id, + paths_file=paths_file, ) self.subset_log = TBFolderDatasetLog( @@ -499,20 +497,42 @@ def build_dataset( dtype=np.uint16 if token_size == 2 else np.uint32, ) + paths_file = None + if folder_read_path: + paths_file = ( + cached_assets_path(library_name="nanotron", namespace="path_files") + / f"{dataset_folder.replace('/', '_')}_paths.json" + ).as_posix() + # This ensures that the paths file is created BASED on the datasef_folder + DatatroveFolderDataset( + data_folder=dataset_folder, + filename_pattern="*.ds", + seq_len=seq_length, + recursive=False, + token_size=token_size, + shuffle=False, + return_positions=return_positions, + positions_from_eos_token_id=eos_token_id, + seed=seed, + paths_file=paths_file, + ) + + # From now on, we use the folder_read_path to read the dataset + dataset_folder = folder_read_path + return TokenizedBytesFolderDataset( folder_path=dataset_folder, filename_pattern="*.ds", seq_len=seq_length, recursive=False, token_size=token_size, - max_tokens=max_tokens, shuffle=shuffle, return_positions=return_positions, # if set to True, the position ids are directly read from datatrove eos_token_id=eos_token_id, seed=seed, skip_in_stream=skip_in_stream, num_samples=num_samples, - folder_read_path=folder_read_path, + paths_file=paths_file, ) @@ -560,29 +580,26 @@ def get_tb_datasets( for i, (dataset_folder, max_tokens) in enumerate(zip(config.dataset_folder, dataset_max_tokens)) ] - if len(datasets) == 1 and False: - outputs_dataset = datasets[0] - else: - if dist.get_rank(parallel_context.world_pg) == 0: - try: - compile_helper() - except ImportError: - raise ImportError( - "Could not compile megatron dataset C++ helper functions and therefore cannot import helpers python file." - ) - dist.barrier(parallel_context.world_pg) - weights = config.dataset_weights - if not weights: - weights = [1] * len(datasets) - - outputs_dataset = BlendableDataset( - datasets, - weights, - train_num_samples, - parallel_context=parallel_context, - seed=seed, - consumed_tokens_per_dataset_folder=consumed_tokens_per_dataset_folder, - ) + if dist.get_rank(parallel_context.world_pg) == 0: + try: + compile_helper() + except ImportError: + raise ImportError( + "Could not compile megatron dataset C++ helper functions and therefore cannot import helpers python file." + ) + dist.barrier(parallel_context.world_pg) + weights = config.dataset_weights + if not weights: + weights = [1] * len(datasets) + + outputs_dataset = BlendableDataset( + datasets, + weights, + train_num_samples, + parallel_context=parallel_context, + seed=seed, + consumed_tokens_per_dataset_folder=consumed_tokens_per_dataset_folder, + ) log_rank("Streamable datasets ready.", logger=logger, level=logging.INFO, rank=0) train_data_log = TrainDataLog( diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py index 43d1a7653..96bbf90d0 100644 --- a/src/nanotron/eval/one_job_runner.py +++ b/src/nanotron/eval/one_job_runner.py @@ -60,13 +60,18 @@ def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: logger.warning( f"Lighteval Runner got {len(uploaded_files)} files. Using {checkpoint_path} as checkpoint path." ) - - slurm_job_id, slurm_log = run_slurm_one_job( - config=self.config, - lighteval_config=self.lighteval_config, - model_checkpoint_path=checkpoint_path, - current_step=self.config.general.step, - ) + if self.config.general.step % self.lighteval_config.eval_interval == 0: + slurm_job_id, slurm_log = run_slurm_one_job( + config=self.config, + lighteval_config=self.lighteval_config, + model_checkpoint_path=checkpoint_path, + current_step=self.config.general.step, + ) + else: + logger.warning( + f"Skipping evaluation at step {self.config.general.step} because it's not a multiple of {self.lighteval_config.eval_interval}" + ) + return None, None return slurm_job_id, slurm_log @@ -130,7 +135,8 @@ def run_slurm_one_job( #SBATCH --exclusive #SBATCH --qos={slurm_config.qos} #SBATCH --time={slurm_config.time} -#SBATCH --output={eval_logs_path}/%j-{timestamp}.out""" +#SBATCH --output={eval_logs_path}/%j-{timestamp}.out +#SBATCH --requeue""" if slurm_config.reservation: slurm_script += f"\n#SBATCH --reservation={slurm_config.reservation}" @@ -154,25 +160,17 @@ def run_slurm_one_job( export MASTER_PORT=6000 export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` -# Hugging Face token setup -if [ -z "$HUGGING_FACE_HUB_TOKEN" ]; then - if TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null); then - export HUGGING_FACE_HUB_TOKEN=$TOKEN - else - echo "Error: The environment variable HUGGING_FACE_HUB_TOKEN is not set and the token cache could not be read." - exit 1 - fi -fi - # Set environment variables export CUDA_DEVICE_MAX_CONNECTIONS=1 # export CUBLAS_WORKSPACE_CONFIG=":4096:8" # Set HuggingFace cache locations -export HUGGINGFACE_HUB_CACHE={slurm_config.hf_cache} -export HF_DATASETS_CACHE={slurm_config.hf_cache} -export HF_MODULES_CACHE={slurm_config.hf_cache} -export HF_HOME={slurm_config.hf_cache} +if [ -n "{slurm_config.hf_cache}" ] && [ "{slurm_config.hf_cache}" != "None" ]; then + export HUGGINGFACE_HUB_CACHE={slurm_config.hf_cache} + export HF_DATASETS_CACHE={slurm_config.hf_cache} + export HF_MODULES_CACHE={slurm_config.hf_cache} + export HF_HOME={slurm_config.hf_cache} +fi echo "Running on $COUNT_NODE nodes: $HOSTNAMES" @@ -244,13 +242,29 @@ def run_slurm_one_job( --node_rank $SLURM_PROCID \\ --master_addr $MASTER_ADDR \\ --master_port $MASTER_PORT \\ - {nanotron_path}/run_evals.py \\ + -m lighteval nanotron \\ --checkpoint-config-path $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml \\ - --lighteval-override {lighteval_config.eval_config_override} - --cache-dir {slurm_config.hf_cache}""" + --lighteval-config-path {lighteval_config.lighteval_config_path} + """ if lighteval_config.output_dir is not None and lighteval_config.s3_save_path is not None: slurm_script += f""" -s5cmd cp --if-size-differ "{lighteval_config.output_dir}*" {lighteval_config.s3_save_path} +s5cmd cp --if-size-differ "{lighteval_config.output_dir}*" {lighteval_config.s3_save_path}/ +""" + if lighteval_config.upload_to_wandb: + gbs_tok = ( + config.parallelism.dp + * config.tokens.micro_batch_size + * config.tokens.sequence_length + * config.tokens.batch_accumulation_per_replica + ) + slurm_script += f""" +python {nanotron_path}/src/nanotron/eval/upload_to_wandb.py \\ + --wandb_project {lighteval_config.wandb_project} \\ + --wandb_entity {lighteval_config.wandb_entity} \\ + --model_name {general_run_name} \\ + --results_path {lighteval_config.s3_save_path}/results/results/{general_run_name}/{current_step}/ \\ + --train_step {current_step} \\ + --consumed_tokens {current_step*gbs_tok} """ slurm_script += """ echo "Cleaning up downloaded checkpoints..." diff --git a/src/nanotron/eval/upload_to_wandb.py b/src/nanotron/eval/upload_to_wandb.py new file mode 100644 index 000000000..aa8c12d41 --- /dev/null +++ b/src/nanotron/eval/upload_to_wandb.py @@ -0,0 +1,87 @@ +import json +import s3fs +import wandb +import re +import argparse +from wandb.sdk.lib.runid import generate_id + + +def push_to_wandb(wandb_project, wandb_entity, model_name, results_path, train_step, consumed_tokens): + s3 = s3fs.S3FileSystem(anon=False) + all_metrics = { + # basic X axis replacements for all metrics + "consumed_tokens": consumed_tokens, + "train_step": train_step, + } + + for result_file in sorted(s3.ls(results_path)): + if not result_file.endswith(".json"): + continue + + with s3.open(result_file, "r") as f: + results = json.loads(f.read())["results"] + + for benchmark, metrics in results.items(): + if benchmark == "all": + continue + + # extract dataset and config name + match = re.search(r"\|(.*?)(?::(.*?))?\|", benchmark) + if match: + dataset, subtask = match.groups() + + for metric_name, metric_value in metrics.items(): + if "_stderr" in metric_name: + continue + # wandb-friendly metric name + wandb_metric = f"{dataset}/{subtask}/{metric_name}" if subtask else f"{dataset}/{metric_name}" + all_metrics[wandb_metric] = metric_value + + run_id = f"{model_name}-{generate_id()}" + + # try to find the run in wandb and resume it + api = wandb.Api() + runs = api.runs(f"{wandb_entity}/{wandb_project}") + for run in runs: + if run.name == model_name: + run_id = run.id + break + + wandb.init( + project=wandb_project, + entity=wandb_entity, + name=model_name, + id=run_id, + config={ + "model_name": model_name, + }, + resume="allow", + ) + + # log all metrics for this checkpoint + wandb.log(all_metrics) + + wandb.finish() + +if __name__ == "__main__": + # Setup argument parser + parser = argparse.ArgumentParser(description="Upload evaluation results to Weights & Biases.") + parser.add_argument("--wandb_project", type=str, required=True, help="WandB project name.") + parser.add_argument("--wandb_entity", type=str, required=True, help="WandB entity name.") + parser.add_argument("--model_name", type=str, required=True, help="Name of the model.") + parser.add_argument("--results_path", type=str, required=True, help="S3 path to the results directory.") + parser.add_argument("--train_step", type=int, required=True, help="Training step corresponding to the checkpoint.") + parser.add_argument("--consumed_tokens", type=int, required=True, help="Total consumed tokens up to this checkpoint.") + + # Parse arguments + args = parser.parse_args() + + # Call the main function with parsed arguments + push_to_wandb( + wandb_project=args.wandb_project, + wandb_entity=args.wandb_entity, + model_name=args.model_name, + results_path=args.results_path, + train_step=args.train_step, + consumed_tokens=args.consumed_tokens + ) diff --git a/src/nanotron/generation/decode.py b/src/nanotron/generation/decode.py index 338801100..6d87161ff 100644 --- a/src/nanotron/generation/decode.py +++ b/src/nanotron/generation/decode.py @@ -103,6 +103,17 @@ def micro_batcher( # Each dp is responsible for its own micro batches continue + if tokenizer.bos_token_id is not None: + # TODO: do we want this always? + add_special_tokens = False + logging.warn_once( + f"Tokenizer {tokenizer.name_or_path} has a bos_token_id={tokenizer.bos_token_id}, we set add_special_tokens to False to avoid adding bos token to input ids", + main_rank_only=True, + logger=logger, + ) + else: + add_special_tokens = True + if dist.get_rank(parallel_context.pp_pg) == input_rank: encodings = tokenizer( [elt.text for elt in micro_batch], @@ -111,6 +122,7 @@ def micro_batcher( padding=tokenizer_config.padding, max_length=tokenizer_config.max_input_length, truncation=tokenizer_config.truncation, + add_special_tokens=add_special_tokens, # pad_to_multiple_of=8 ) @@ -157,13 +169,17 @@ def micro_splitter( @torch.inference_mode() -def get_position_ids(input_ids, tokenizer): +def get_position_ids(input_ids, padding_token_id=None, input_mask=None): + assert padding_token_id is not None or input_mask is not None, "Either padding_token_id or input_mask must be provided" + # Find where padding ends for each sequence batch_size, seq_length = input_ids.shape - padding_token_id = tokenizer.eos_token_id # Create a mask of padding tokens - padding_mask = input_ids == padding_token_id + if padding_token_id is not None: + padding_mask = input_ids == padding_token_id + else: + padding_mask = input_mask # Find indices where non-padding tokens start # For sequences with no padding, this will be 0 @@ -305,7 +321,12 @@ def decode_text( else: batch_generated_ids = state.new_input_ids batch_generated_mask = state.new_input_mask - position_ids = get_position_ids(batch_generated_ids, tokenizer) + # assert first token isn't bos + assert ( + batch_generated_ids[0, 0] != tokenizer.bos_token_id + ), "First token is bos. Make sure you're using add_special_tokens=True when initializing the tokenizer." + padding_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id + position_ids = get_position_ids(batch_generated_ids, padding_token_id=padding_token_id) sharded_logits = model( input_ids=batch_generated_ids, position_ids=position_ids, # [batch_size, seq_len] @@ -541,8 +562,9 @@ def decode_tokenized( generation_config: GenerationArgs, max_micro_batch_size: int, max_new_tokens: int, - tokenizer: "PreTrainedTokenizer", returns_logits: Optional[bool] = False, + logits_are_batch_first: bool = True, + bos_token_id: Optional[int] = None, ) -> Generator[GenerationOutput, None, None]: """We assume the following: - Everyone receives ALL the input text. # TODO @thomasw21: technically only specific ranks need to receive input. @@ -606,7 +628,7 @@ def decode_tokenized( for batch in batches ) - for generation_iter in range(max_new_tokens): + for generation_iter in tqdm(range(max_new_tokens), desc="Generating"): all_new_decoder_input_ids_and_mask_same_rank: List[ Tuple[Union[torch.LongTensor, TensorPointer], Union[torch.BoolTensor, TensorPointer]] ] = [] @@ -614,17 +636,37 @@ def decode_tokenized( for state_id, state in enumerate(decoder_states): new_decoder_states.append(state) # Get the new logits - with attach_store(model=model, store=state.store): - position_ids = get_position_ids(state.new_input_ids, tokenizer) + if generation_config.use_cache: + raise NotImplementedError("Use-cache is not supported for now") + with attach_store(model=model, store=state.store): + position_ids = get_position_ids(state.new_input_ids, tokenizer) + sharded_logits = model( + input_ids=state.new_input_ids, + position_ids=position_ids, # [batch_size, seq_len] + ) + else: + if isinstance(state.new_input_ids, torch.Tensor): + batch_generated_ids = torch.cat(state.generation_ids, dim=-1) + batch_generated_mask = torch.cat(state.generation_mask, dim=-1) + else: + batch_generated_ids = state.new_input_ids + batch_generated_mask = state.new_input_mask + # assert first token isn't bos + if bos_token_id is not None: + assert ( + batch_generated_ids[0, 0] != bos_token_id + ), "First token is bos. Make sure you're using add_special_tokens=True when initializing the tokenizer." + position_ids = get_position_ids(batch_generated_ids, input_mask=batch_generated_mask) sharded_logits = model( - input_ids=state.new_input_ids, + input_ids=batch_generated_ids, position_ids=position_ids, # [batch_size, seq_len] - ) - if isinstance(sharded_logits, torch.Tensor): - sharded_logits = sharded_logits.transpose(0, 1) + ) # [batch_size*seq_len, vocab_size] + + sharded_logits = sharded_logits.view(*position_ids.shape, -1) # [batch_size, seq_len, vocab_size] + if isinstance(sharded_logits, torch.Tensor) and not logits_are_batch_first: + sharded_logits = sharded_logits.transpose(0, 1) # Communicate - # TODO @thomasw21: Make a diagram to show how this works nb_send: int = 0 if is_decoder_input_rank: if is_max_nb_microbatches: diff --git a/src/nanotron/models/qwen.py b/src/nanotron/models/qwen.py index 8aa6eb46a..3f35d25e3 100644 --- a/src/nanotron/models/qwen.py +++ b/src/nanotron/models/qwen.py @@ -192,7 +192,7 @@ def __init__( async_communication=tp_linear_async_communication, ) if config._use_qkv_packed: - from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding + from nanotron.nn.rotary import FlashRotaryEmbedding self.rotary_emb = FlashRotaryEmbedding( dim=self.head_dim, @@ -214,66 +214,146 @@ def __init__( # TODO: support doc masking / SWA / SFT / inference - def forward( + def _determine_seq_length_and_flat_ids(self, position_ids: torch.Tensor, inference_max_seqlen: Optional[int]): + if position_ids.ndim == 2: + seq_length = position_ids.shape[1] + flat_position_ids = position_ids.view(-1) # [batch_size*seq_length] + else: + assert ( + inference_max_seqlen is not None + ), "inference_max_seqlen must be provided if position_ids is a 1D tensor" + seq_length = inference_max_seqlen + flat_position_ids = position_ids + return seq_length, flat_position_ids + + def _forward_train_attn( self, - hidden_states: torch.Tensor, # [batch_size*seq_length, hidden_size] - position_ids: torch.Tensor, # [batch_size, seq_length] where -1 is padding - cu_seqlens: Optional[torch.Tensor] = None, # Added cu_seqlens argument + qkv: torch.Tensor, + position_ids: torch.Tensor, # Original position_ids, likely [batch_size, seq_length] + cu_seqlens: Optional[torch.Tensor], ): - # [0, 1, 2, 3, 4, 0, 1, 2, -1, -1, -1] # 2 documents with 5 and 3 tokens then padding - # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # 1 document with 11 tokens - # [0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1] # 1 document with 10 tokens then padding - # Replace -1 with 0 in position_ids to mark every padding token as a separate sequence. Ideally we want to get rid of padding tokens from qkv - # position_ids = position_ids.masked_fill(position_ids == -1, 0) - seq_length = position_ids.shape[1] - # Keep original position_ids shape for return, flatten for internal use - position_ids = position_ids.view(-1) # [batch_size*seq_length] + seq_length, flat_position_ids = self._determine_seq_length_and_flat_ids(position_ids, None) # inference_max_seqlen is None for training - qkv = self.qkv_proj(hidden_states) + if self._use_qkv_packed: + # _forward_packed uses self.training internally + attn_output = self._forward_packed(qkv, seq_length, flat_position_ids, cu_seqlens) + else: + q, k, v = qkv.split( + [self.local_q_size, self.local_kv_size, self.local_kv_size], dim=-1 + ) + q = q.view(-1, self.local_num_heads, self.head_dim) + k = k.view(-1, self.local_num_kv_heads, self.head_dim) + v = v.view(-1, self.local_num_kv_heads, self.head_dim) + + if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0: + rotary_pos_emb = self.rotary_emb( + position_ids=flat_position_ids if not self.simple_causal_mask else None, seq_length=seq_length + ) + q = self.rotary_emb.apply_rotary_pos_emb(q, rotary_pos_emb, seq_length=seq_length) + k = self.rotary_emb.apply_rotary_pos_emb(k, rotary_pos_emb, seq_length=seq_length) + else: + log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0) + + attn_output = self.attention( + q, k, v, position_ids=flat_position_ids, seq_length=seq_length, cu_seqlens=cu_seqlens + ) + + output = self.o_proj(attn_output) + return {"hidden_states": output, "position_ids": position_ids} # Return original position_ids + + def _forward_inference_attn( + self, + qkv: torch.Tensor, + position_ids: torch.Tensor, # Unpadded position_ids, likely [total_tokens] + cu_seqlens: Optional[torch.Tensor], + inference_max_seqlen: int, + ): + seq_length, flat_position_ids = self._determine_seq_length_and_flat_ids(position_ids, inference_max_seqlen) if self._use_qkv_packed: - attn_output = self._forward_packed(qkv, seq_length, position_ids, cu_seqlens) + # _forward_packed uses self.training internally + attn_output = self._forward_packed(qkv, seq_length, flat_position_ids, cu_seqlens) else: q, k, v = qkv.split( [self.local_q_size, self.local_kv_size, self.local_kv_size], dim=-1 - ) # [batch_size*seq_length, q_size], [batch_size*seq_length, kv_size] - q = q.view(-1, self.local_num_heads, self.head_dim) # [b*s, num_heads, head_dim] - k = k.view(-1, self.local_num_kv_heads, self.head_dim) # [b*s, num_kv_heads, head_dim] - v = v.view(-1, self.local_num_kv_heads, self.head_dim) # [b*s, num_kv_heads, head_dim] + ) + q = q.view(-1, self.local_num_heads, self.head_dim) + k = k.view(-1, self.local_num_kv_heads, self.head_dim) + v = v.view(-1, self.local_num_kv_heads, self.head_dim) + if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0: + # For inference, rotary_emb might be configured differently (e.g., for KV caching) + # Currently, using flat_position_ids which are unpadded and continuous for each token. rotary_pos_emb = self.rotary_emb( - position_ids=position_ids if not self.simple_causal_mask else None, seq_length=seq_length - ) # [b*s, dim] or [seq_length, dim] - q = self.rotary_emb.apply_rotary_pos_emb( - q, rotary_pos_emb, seq_length=seq_length - ) # [b*s, num_heads, head_dim] - k = self.rotary_emb.apply_rotary_pos_emb( - k, rotary_pos_emb, seq_length=seq_length - ) # [b*s, num_kv_heads, head_dim] + position_ids=flat_position_ids, seq_length=seq_length + ) + q = self.rotary_emb.apply_rotary_pos_emb(q, rotary_pos_emb, seq_length=seq_length) + k = self.rotary_emb.apply_rotary_pos_emb(k, rotary_pos_emb, seq_length=seq_length) else: log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0) + + # cu_seqlens for inference are based on unpadded data attn_output = self.attention( - q, k, v, position_ids=position_ids, seq_length=seq_length, cu_seqlens=cu_seqlens + q, k, v, position_ids=flat_position_ids, seq_length=seq_length, cu_seqlens=cu_seqlens ) + # TODO: KV Caching: return k, v here + output = self.o_proj(attn_output) - # Return original position_ids shape - return {"hidden_states": output, "position_ids": position_ids.view(-1, seq_length)} + return {"hidden_states": output, "position_ids": position_ids} # Return unpadded position_ids for inference consistency + + def forward( + self, + hidden_states: torch.Tensor, # [batch_size*seq_length, hidden_size] or [total_tokens, hidden_size] + position_ids: torch.Tensor, # [batch_size, seq_length] or [total_tokens] + cu_seqlens: Optional[torch.Tensor] = None, + inference_max_seqlen: Optional[int] = None, + ): + qkv = self.qkv_proj(hidden_states) + + if self.training: + return self._forward_train_attn( + qkv, + position_ids, # Original position_ids + cu_seqlens, + ) + else: + assert inference_max_seqlen is not None, "inference_max_seqlen must be provided for inference" + return self._forward_inference_attn( + qkv, + position_ids, # Unpadded position_ids + cu_seqlens, + inference_max_seqlen, + ) def _forward_packed(self, qkv, seq_length, position_ids, cu_seqlens): assert cu_seqlens is not None, "cu_seqlens must be provided for packed attention" q = qkv[..., : self.local_num_heads * self.head_dim] # Not contiguous, similar to flash_attn kv = qkv[..., self.local_num_heads * self.head_dim :] # Not contiguous, similar to flash_attn - q = q.view(-1, seq_length, self.local_num_heads, self.head_dim) - kv = kv.view(-1, seq_length, 2, self.local_num_kv_heads, self.head_dim) + if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0: - q, kv = self.rotary_emb( - q, kv, seqlen_offset=0, max_seqlen=None - ) # TODO: should we use position_ids here? flash_attn doesn't + if self.training: + q = q.view(-1, seq_length, self.local_num_heads, self.head_dim) + kv = kv.view(-1, seq_length, 2, self.local_num_kv_heads, self.head_dim) + q, kv = self.rotary_emb( + q, kv, seqlen_offset=0, max_seqlen=None + ) # TODO: should we use position_ids here? flash_attn doesn't + else: + # TODO: support seqlen_offsets in case of use_cache + # qkv = qkv.view(-1, self.local_num_heads + 2 * self.local_num_kv_heads, self.head_dim) + # self.rotary_emb.varlen_forward(qkv, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=seq_length) + # qkv = qkv.view(-1, (self.local_num_heads + 2 * self.local_num_kv_heads) * self.head_dim) + # q = qkv[..., : self.local_num_heads * self.head_dim] + # kv = qkv[..., self.local_num_heads * self.head_dim :] + q = q.view(-1, self.local_num_heads, self.head_dim) + kv = kv.view(-1, 2, self.local_num_kv_heads, self.head_dim) + k = kv[:, 0] + self.rotary_emb.varlen_forward(q, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=seq_length) + self.rotary_emb.varlen_forward(k, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=seq_length) else: log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0) q = q.view(-1, self.local_num_heads, self.head_dim) kv = kv.view(-1, 2, self.local_num_kv_heads, self.head_dim) - max_seqlen = seq_length # TODO: should this be max position_ids? + max_seqlen = seq_length # TODO: should this be max position_ids? As long as it doesn't change often it and not too big should be fine assert cu_seqlens.dtype == torch.int32 assert max_seqlen is not None @@ -405,11 +485,17 @@ def _core_forward( hidden_states: Union[torch.Tensor, TensorPointer], # [batch_size*seq_length, hidden_size] position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] where -1 is padding cu_seqlens: Union[torch.Tensor, TensorPointer], + inference_max_seqlen: Optional[int] = None, ) -> List[Union[torch.Tensor, TensorPointer]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - output = self.attn(hidden_states=hidden_states, position_ids=position_ids, cu_seqlens=cu_seqlens) + output = self.attn( + hidden_states=hidden_states, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + inference_max_seqlen=inference_max_seqlen, + ) hidden_states = output["hidden_states"] hidden_states = hidden_states + residual @@ -433,18 +519,22 @@ def forward( hidden_states: Union[torch.Tensor, TensorPointer], position_ids: Union[torch.Tensor, TensorPointer], cu_seqlens: Union[torch.Tensor, TensorPointer], + inference_max_seqlen: Optional[int] = None, ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: if self.recompute_layer and not isinstance(hidden_states, TensorPointer): hidden_states, position_ids, cu_seqlens = self._checkpointed_forward( - hidden_states, position_ids, cu_seqlens + hidden_states, position_ids, cu_seqlens, inference_max_seqlen ) else: - hidden_states, position_ids, cu_seqlens = self._core_forward(hidden_states, position_ids, cu_seqlens) + hidden_states, position_ids, cu_seqlens = self._core_forward( + hidden_states, position_ids, cu_seqlens, inference_max_seqlen + ) return { "hidden_states": hidden_states, "position_ids": position_ids, "cu_seqlens": cu_seqlens, + "inference_max_seqlen": inference_max_seqlen, } @@ -460,9 +550,8 @@ def __init__(self, tp_pg: dist.ProcessGroup, config: Qwen2Config, parallel_confi ) self.pg = tp_pg - def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): # [batch_size, seq_length] - input_ids = input_ids.view(-1) # [batch_size*seq_length] - input_embeds = self.token_embedding(input_ids) # [batch_size*seq_length, hidden_size] + def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): # [...] + input_embeds = self.token_embedding(input_ids) # [..., hidden_size] return {"input_embeds": input_embeds, "position_ids": position_ids} @@ -512,8 +601,8 @@ def __init__( "cp_pg": parallel_context.cp_pg, "layer_idx": layer_idx, }, - module_input_keys={"hidden_states", "position_ids", "cu_seqlens"}, - module_output_keys={"hidden_states", "position_ids", "cu_seqlens"}, + module_input_keys={"hidden_states", "position_ids", "cu_seqlens", "inference_max_seqlen"}, + module_output_keys={"hidden_states", "position_ids", "cu_seqlens", "inference_max_seqlen"}, ) for layer_idx in range(config.num_hidden_layers) ] @@ -544,36 +633,119 @@ def __init__( module_output_keys={"logits"}, ) - def forward( + def _forward_train( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] where -1 is padding ): + # Training case (handles potential packing) + # Get embeddings for the original (potentially padded/packed) sequence output = self.token_position_embeddings(input_ids=input_ids, position_ids=position_ids) - # Compute cu_seqlens + # output["position_ids"] is the original position_ids [batch_size, seq_length] + + # Compute cu_seqlens based on document starts (position_id == 0) if data is packed if position_ids.numel() > 0: start_indices = torch.where(position_ids.view(-1) == 0)[0] cu_seqlens = torch.cat( - [start_indices, torch.tensor([position_ids.numel()], dtype=torch.int32, device=start_indices.device)] + [ + start_indices, + torch.tensor([position_ids.numel()], dtype=torch.int32, device=start_indices.device), + ] ).to(torch.int32) else: - cu_seqlens = None + cu_seqlens = None # Or handle empty tensor case appropriately + # Prepare state for decoder layers using original/padded/packed data decoder_states = { - "hidden_states": output["input_embeds"], - "position_ids": output["position_ids"], - "cu_seqlens": cu_seqlens, + "hidden_states": output["input_embeds"], # Padded embeds [batch*seq_len, hidden_size] + "position_ids": output["position_ids"], # Original pos_ids [batch_size, seq_len] + "cu_seqlens": cu_seqlens, # Based on packing, might be None + # "inference_max_seqlen" is not needed for training } + # Pass the prepared decoder_states dictionary to the decoder layers for decoder_layer in self.decoder: + # Decoder layers need to handle both inference (unpadded) and training (padded/packed) states decoder_states = decoder_layer(**decoder_states) + # Final layer norm and LM head operate on the output hidden_states from the last decoder layer hidden_states = self.final_layer_norm(input=decoder_states["hidden_states"])["hidden_states"] + sharded_logits = self.lm_head(x=hidden_states)["logits"] + return sharded_logits + + def _forward_inference( + self, + input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] where -1 is padding + ): + assert ( + position_ids.ndim == 2 + ), "position_ids must be 2D for inference, otherwise how do we know when to separate samples?" + inference_max_seqlen = position_ids.shape[1] + inference_batch_size = position_ids.shape[0] + # This gives the number of non-padding tokens per sequence in the batch + seqlens_in_batch = (position_ids != -1).sum(dim=-1, dtype=torch.int32) + input_ids = input_ids.view(-1) + position_ids = position_ids.view(-1) + # Find indices of non-padding tokens using the flattened position_ids + unpad_indices = torch.nonzero(position_ids != -1, as_tuple=False).flatten() + cu_seqlens = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=seqlens_in_batch.device), + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), + ] + ) + unpadded_input_ids = input_ids[unpad_indices] # (total_tokens) = (total_unpadded_tokens) + unpadded_position_ids = position_ids[unpad_indices] # (total_tokens) + + # TODO: compute this in dataloader to avoid cpu-gpu sync + cu_seqlens = cu_seqlens.to(unpadded_input_ids.device if isinstance(unpadded_input_ids, torch.Tensor) else input_ids.device) + + output = self.token_position_embeddings( + input_ids=unpadded_input_ids, position_ids=unpadded_position_ids + ) # (total_tokens, hidden_size) + decoder_states = { + "hidden_states": output["input_embeds"], # Unpadded embeds [total_tokens, hidden_size] + "position_ids": output["position_ids"], # Unpadded pos_ids [total_tokens] + "cu_seqlens": cu_seqlens, # cu_seqlens for unpadded sequence [batch_size + 1] + "inference_max_seqlen": inference_max_seqlen, # original seq_length using for inference + } + + # Pass the prepared decoder_states dictionary to the decoder layers + for decoder_layer in self.decoder: + # Decoder layers need to handle both inference (unpadded) and training (padded/packed) states + decoder_states = decoder_layer(**decoder_states) + + # Final layer norm and LM head operate on the output hidden_states from the last decoder layer + hidden_states = self.final_layer_norm(input=decoder_states["hidden_states"])["hidden_states"] sharded_logits = self.lm_head(x=hidden_states)["logits"] + # Pad logits back to original shape + assert inference_batch_size is not None and inference_max_seqlen is not None and unpad_indices is not None + # Create zero tensor with the full padded shape (flattened batch/seq) + padded_sharded_logits = torch.zeros( + inference_batch_size * inference_max_seqlen, + sharded_logits.shape[-1], # vocab_shard_size + dtype=sharded_logits.dtype, + device=sharded_logits.device, + ) + # Scatter the unpadded logits back into the zero tensor + padded_sharded_logits[unpad_indices] = sharded_logits + # Reshape to (batch_size, sequence_length, vocab_shard_size) + sharded_logits = padded_sharded_logits.view(inference_batch_size, inference_max_seqlen, -1) return sharded_logits + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] where -1 is padding + ): + if self.training: + return self._forward_train(input_ids=input_ids, position_ids=position_ids) + else: + return self._forward_inference(input_ids=input_ids, position_ids=position_ids) + def get_block_compute_costs(self): """Computes the compute cost of each block in the model for load balancing.""" model_config = self.config diff --git a/src/nanotron/nn/rotary.py b/src/nanotron/nn/rotary.py index 4e78849f9..6b9ae18de 100644 --- a/src/nanotron/nn/rotary.py +++ b/src/nanotron/nn/rotary.py @@ -1,8 +1,37 @@ import torch +from flash_attn.layers.rotary import RotaryEmbedding as OriginalFlashRotaryEmbedding from flash_attn.layers.rotary import apply_rotary_emb as flash_apply_rotary_emb from torch import nn +class FlashRotaryEmbedding(OriginalFlashRotaryEmbedding): + """ + This is a modified version of the FlashRotaryEmbedding class that supports variable length sequences in case of inference. + """ + + def varlen_forward(self, x, seqlen_offsets, cu_seqlens, max_seqlen): + """ + x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + seqlen_offsets: (batch_size) + cu_seqlens: (batch_size + 1) + max_seqlen: int + """ + assert not self.training, "This is supposed to be used only in inference" + assert max_seqlen is not None, "max_seqlen must be provided" + self._update_cos_sin_cache(max_seqlen, device=x.device, dtype=x.dtype) + flash_apply_rotary_emb( + x, + cos=self._cos_cached, + sin=self._sin_cached, + interleaved=self.interleaved, + inplace=True, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + class RotaryEmbedding(nn.Module): def __init__( self,