From d4e9daf7cc554082be16479f9f3d603de30ae5bc Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Mon, 14 Apr 2025 16:41:51 +0000 Subject: [PATCH 01/31] InitScalingMethod --- src/nanotron/config/models_config.py | 13 ++++++++++- src/nanotron/models/qwen.py | 2 +- src/nanotron/scaling/parametrization.py | 31 +++++++++++++++++++------ 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 410634b87..cb52e8d1b 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from enum import Enum from pathlib import Path from typing import Any, List, Optional, Union @@ -8,9 +9,17 @@ DEFAULT_ATTENTION_IMPLEMENTATION = "flash_attention_2" +class InitScalingMethod(Enum): + NONE = "none" # No scaling applied (factor = 1.0) + NUM_LAYERS = "num_layers" # Scale by sqrt(2 * total_layers) + LAYER_INDEX = "layer_index" # Scale by sqrt(2 * current_layer) + MODEL_SCALE = "model_scale" # Scale by hidden_dim/base_dim + + @dataclass class RandomInit: std: float + scaling_method: InitScalingMethod = InitScalingMethod.NUM_LAYERS @dataclass @@ -141,7 +150,9 @@ class Qwen2Config: sliding_window_size: Optional[int] = None z_loss_enabled: bool = False # Z-loss regularization https://www.jmlr.org/papers/volume24/22-1144/22-1144.pdf z_loss_coefficient: float = 0.0001 # Default from the paper (10^-4) - no_rope_layer: Optional[int] = None # Skip rope every no_rope_layer layers (see https://arxiv.org/abs/2501.18795 https://arxiv.org/abs/2305.19466 and Llama4) + no_rope_layer: Optional[ + int + ] = None # Skip rope every no_rope_layer layers (see https://arxiv.org/abs/2501.18795 https://arxiv.org/abs/2305.19466 and Llama4) _fused_rotary_emb: bool = True _fused_rms_norm: bool = True _use_qkv_packed: bool = True diff --git a/src/nanotron/models/qwen.py b/src/nanotron/models/qwen.py index 8115a9bb9..eee5cba38 100644 --- a/src/nanotron/models/qwen.py +++ b/src/nanotron/models/qwen.py @@ -896,7 +896,7 @@ def init_model_randomly(self, config: Config): else: raise ValueError(f"Unknown init method {init_method}") - parametrizator = parametrizator_cls(config=config.model) + parametrizator = parametrizator_cls(config=config) log_rank( f"Parametrizing model parameters using {parametrizator.__class__.__name__}", diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py index 187e76e09..4cf15eed8 100644 --- a/src/nanotron/scaling/parametrization.py +++ b/src/nanotron/scaling/parametrization.py @@ -3,7 +3,8 @@ from enum import Enum, auto from typing import Dict -from nanotron.config import ModelArgs +from nanotron.config import Config, ModelArgs +from nanotron.config.models_config import InitScalingMethod from nanotron.nn.layer_norm import LlamaRMSNorm, TritonRMSNorm from nanotron.parallel.tensor_parallel.nn import ( TensorParallelColumnLinear, @@ -31,7 +32,7 @@ def parametrize(self, param_name: str, module: nn.Module): class StandardParametrizator(Parametrizator): - def __init__(self, config: ModelArgs): + def __init__(self, config: Config): super().__init__(config) self.MODULE_TO_PARAMETRIZE = { TensorParallelColumnLinear: self._parametrize_column_linear, @@ -41,8 +42,11 @@ def __init__(self, config: ModelArgs): TensorParallelEmbedding: self._parametrize_embedding, } - self.std = config.init_method.std - self.num_layers = config.model_config.num_hidden_layers + self.std = config.model.init_method.std + self.num_layers = config.model.model_config.num_hidden_layers + self.tp = config.parallelism.tp + self.scaling_method = config.model.init_method.scaling_method + self.hidden_size = config.model.model_config.hidden_size def _parametrize_column_linear(self, param_name: str, module: nn.Module): assert param_name in ["weight", "bias"] @@ -52,12 +56,26 @@ def _parametrize_column_linear(self, param_name: str, module: nn.Module): elif "bias" == param_name: module.bias.zero_() + def _compute_scaling_factor(self) -> float: + """Compute initialization scaling based on selected method""" + if self.scaling_method == InitScalingMethod.NONE: + return 1.0 + elif self.scaling_method == InitScalingMethod.NUM_LAYERS: + # Scale based on total network depth + return math.sqrt(2 * self.num_layers) + elif self.scaling_method == InitScalingMethod.LAYER_INDEX: + # Scale based on layer position + raise NotImplementedError("Layer position scaling not yet implemented") + else: + raise ValueError(f"Invalid scaling method: {self.scaling_method}") + def _parametrize_row_linear(self, param_name: str, module: nn.Module): assert param_name in ["weight", "bias"] if "weight" == param_name: - std = self.std / math.sqrt(2 * self.num_layers) - init.normal_(module.weight, mean=0.0, std=std) + scaling = self._compute_scaling_factor() + adjusted_std = self.std / scaling + init.normal_(module.weight, mean=0.0, std=adjusted_std) elif "bias" == param_name: module.bias.zero_() @@ -65,7 +83,6 @@ def _parametrize_layer_norm(self, param_name: str, module: nn.Module): assert param_name in ["weight", "bias"] if "weight" == param_name: - # TODO @thomasw21: Sometimes we actually want 0 module.weight.fill_(1) elif "bias" == param_name: module.bias.zero_() From 6e7f0fa0cf49748bedef1744e8498828e0fbdb62 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Mon, 14 Apr 2025 18:04:13 +0000 Subject: [PATCH 02/31] InitScalingMethod --- src/nanotron/config/config.py | 2 ++ src/nanotron/config/models_config.py | 9 +-------- src/nanotron/config/utils_config.py | 9 +++++++++ src/nanotron/scaling/parametrization.py | 2 ++ src/nanotron/trainer.py | 4 +++- 5 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 4a8472097..02c1067bc 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -17,6 +17,7 @@ from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit, SpectralMupInit from nanotron.config.parallelism_config import ParallelismArgs from nanotron.config.utils_config import ( + InitScalingMethod, RecomputeGranularity, cast_str_to_pipeline_engine, cast_str_to_torch_dtype, @@ -620,6 +621,7 @@ def get_config_from_dict( PipelineEngine: cast_str_to_pipeline_engine, TensorParallelLinearMode: lambda x: TensorParallelLinearMode[x.upper()], RecomputeGranularity: lambda x: RecomputeGranularity[x.upper()], + InitScalingMethod: lambda x: InitScalingMethod[x.upper()], SamplerType: lambda x: SamplerType[x.upper()], }, # strict_unions_match=True, diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index cb52e8d1b..03b5ba71d 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -1,21 +1,14 @@ from dataclasses import dataclass, field -from enum import Enum from pathlib import Path from typing import Any, List, Optional, Union +from nanotron.config.utils_config import InitScalingMethod from nanotron.nn.attention import ALL_ATTENTION_FUNCTIONS, AttentionImplementation # The default attention implementation to use DEFAULT_ATTENTION_IMPLEMENTATION = "flash_attention_2" -class InitScalingMethod(Enum): - NONE = "none" # No scaling applied (factor = 1.0) - NUM_LAYERS = "num_layers" # Scale by sqrt(2 * total_layers) - LAYER_INDEX = "layer_index" # Scale by sqrt(2 * current_layer) - MODEL_SCALE = "model_scale" # Scale by hidden_dim/base_dim - - @dataclass class RandomInit: std: float diff --git a/src/nanotron/config/utils_config.py b/src/nanotron/config/utils_config.py index f4c071462..84e8079a4 100644 --- a/src/nanotron/config/utils_config.py +++ b/src/nanotron/config/utils_config.py @@ -18,6 +18,13 @@ class RecomputeGranularity(Enum): FULL = auto() +class InitScalingMethod(Enum): + NONE = auto() + NUM_LAYERS = auto() + LAYER_INDEX = auto() + MODEL_SCALE = auto() + + def serialize(data) -> dict: """Recursively serialize a nested dataclass to a dict - do some type conversions along the way""" if data is None: @@ -39,6 +46,8 @@ def serialize(data) -> dict: result[field.name] = value.name elif isinstance(value, RecomputeGranularity): result[field.name] = value.name + elif isinstance(value, InitScalingMethod): + result[field.name] = value.name elif isinstance(value, SamplerType): result[field.name] = value.name elif isinstance(value, torch.dtype): diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py index 4cf15eed8..8f3062a93 100644 --- a/src/nanotron/scaling/parametrization.py +++ b/src/nanotron/scaling/parametrization.py @@ -52,6 +52,7 @@ def _parametrize_column_linear(self, param_name: str, module: nn.Module): assert param_name in ["weight", "bias"] if "weight" == param_name: + # TODO @nouamane: should we use trunc_normal_ init.normal_(module.weight, mean=0.0, std=self.std) elif "bias" == param_name: module.bias.zero_() @@ -75,6 +76,7 @@ def _parametrize_row_linear(self, param_name: str, module: nn.Module): if "weight" == param_name: scaling = self._compute_scaling_factor() adjusted_std = self.std / scaling + # TODO @nouamane: should we use trunc_normal_ init.normal_(module.weight, mean=0.0, std=adjusted_std) elif "bias" == param_name: module.bias.zero_() diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 5110d6eb2..330b82463 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -185,7 +185,9 @@ def __init__( ######################################## # Set random states - set_random_seed(self.config.general.seed) + # Set different random seed for each TP rank to ensure diversity (especially at weight init) + tp_rank = dist.get_rank(self.parallel_context.tp_pg) + set_random_seed(self.config.general.seed + tp_rank) # Init model and build on pp ranks self.random_states = init_random_states( From 24d07e5015ff3009a4323fd9177dff854df2c3f7 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Wed, 16 Apr 2025 17:47:49 +0000 Subject: [PATCH 03/31] eval --- src/nanotron/config/lighteval_config.py | 4 + src/nanotron/eval/README.md | 13 + src/nanotron/eval/__init__.py | 3 + src/nanotron/eval/evaluation_tasks.py | 368 ++++++++++++++++++++++++ src/nanotron/eval/one_job_runner.py | 330 +++++++++++++++++++++ 5 files changed, 718 insertions(+) create mode 100644 src/nanotron/eval/README.md create mode 100644 src/nanotron/eval/__init__.py create mode 100644 src/nanotron/eval/evaluation_tasks.py create mode 100644 src/nanotron/eval/one_job_runner.py diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index b5f12059a..9a656b2fd 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -91,3 +91,7 @@ class LightEvalConfig: tasks: Optional[LightEvalTasksArgs] = None logging: Optional[LightEvalLoggingArgs] = None wandb: Optional[LightEvalWandbLoggerConfig] = None + + def __post_init__(self): + if self.parallelism is None: + self.parallelism = ParallelismArgs(dp=1, pp=1, tp=1, tp_linear_async_communication=True) diff --git a/src/nanotron/eval/README.md b/src/nanotron/eval/README.md new file mode 100644 index 000000000..05bfe1623 --- /dev/null +++ b/src/nanotron/eval/README.md @@ -0,0 +1,13 @@ +# Nanotron Evaluation + +This directory contains code for evaluating models trained with Nanotron. + +## Installation + +To use the evaluation functionality, you need to install the `lighteval` package: + +```bash +uv pip install lighteval[dev] +``` + +## Usage diff --git a/src/nanotron/eval/__init__.py b/src/nanotron/eval/__init__.py new file mode 100644 index 000000000..d7ea002c5 --- /dev/null +++ b/src/nanotron/eval/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa: F401 + +from .one_job_runner import LightEvalRunner diff --git a/src/nanotron/eval/evaluation_tasks.py b/src/nanotron/eval/evaluation_tasks.py new file mode 100644 index 000000000..2543df313 --- /dev/null +++ b/src/nanotron/eval/evaluation_tasks.py @@ -0,0 +1,368 @@ +from functools import partial + +from lighteval.metrics.dynamic_metrics import ( + loglikelihood_acc_metric, +) +from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm +from lighteval.tasks.default_prompts import LETTER_INDICES +from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.multilingual.adapters import ( + winogrand_adapter, +) +from lighteval.tasks.multilingual.tasks import TASKS_TABLE as ML_TASKS_TABLE +from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation +from lighteval.tasks.templates.continuation import get_continuation_prompt_function +from lighteval.tasks.templates.hellaswag import get_hellaswag_prompt_function +from lighteval.tasks.templates.multichoice import get_mcq_prompt_function +from lighteval.tasks.templates.utils.formulation import ( + CFFormulation, + HybridFormulation, + MCFFormulation, +) +from lighteval.utils.language import Language + +TASKS_TABLE = [] + +TASKS_TABLE.extend(ML_TASKS_TABLE) + +arc_tasks = [ + LightevalTaskConfig( + name=f"arc_{formulation.name.lower()}:{subset.lower()}", + prompt_function=get_mcq_prompt_function( + Language.ENGLISH, + lambda line: { + "question": line["question"], + "choices": line["choices"]["text"], + "gold_idx": int(line["answerKey"]) - 1 + if line["answerKey"].isdigit() + else LETTER_INDICES.index(line["answerKey"]), + }, + formulation=formulation, + ), + suite=("custom",), + hf_repo="allenai/ai2_arc", + hf_subset=f"ARC-{subset}", + hf_revision="210d026faf9955653af8916fad021475a3f00453", + trust_dataset=True, + evaluation_splits=("test",), + few_shots_split="train", + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + loglikelihood_acc_metric(normalization=LogProbPMINorm()), + ], + ), + ) + for subset in ["Easy", "Challenge"] + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(arc_tasks) + +hellaswag_tasks = [ + LightevalTaskConfig( + name=f"hellaswag_{formulation.name.lower()}", + suite=["custom"], + prompt_function=get_hellaswag_prompt_function( + language=Language.ENGLISH, + adapter=lambda line: { + "activity_label": line["activity_label"], + "ctx_a": line["ctx_a"], + "ctx_b": line["ctx_b"], + "continuations": line["endings"], + "gold_idx": int(line["label"]), + }, + formulation=formulation, + ), + hf_repo="Rowan/hellaswag", + hf_subset="default", + hf_revision="6002345709e0801764318f06bf06ce1e7d1a1fe3", + evaluation_splits=["validation"], + hf_avail_splits=["validation"], + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + ], + ), + trust_dataset=True, + ) + for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] +] + +TASKS_TABLE.extend(hellaswag_tasks) + +commonsense_qa_tasks = [ + LightevalTaskConfig( + name=f"commonsenseqa_{formulation.name.lower()}", + prompt_function=get_mcq_prompt_function( + Language.ENGLISH, + lambda line: { + "question": line["question"], + "choices": line["choices"]["text"], + "gold_idx": line["choices"]["label"].index(line["answerKey"].strip()), + }, + formulation=formulation, + ), + suite=("custom",), + hf_repo="tau/commonsense_qa", + hf_subset="default", + hf_revision="94630fe30dad47192a8546eb75f094926d47e155", + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + loglikelihood_acc_metric(normalization=LogProbPMINorm()), + ], + ), + ) + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(commonsense_qa_tasks) + +openbook_qa_tasks = [ + LightevalTaskConfig( + name=f"openbookqa_{formulation.name.lower()}", + prompt_function=get_mcq_prompt_function( + Language.ENGLISH, + lambda line: { + "question": line["question_stem"], + "choices": line["choices"]["text"], + "gold_idx": LETTER_INDICES.index(line["answerKey"]), + }, + formulation=formulation, + ), + suite=["custom"], + hf_repo="allenai/openbookqa", + hf_subset="main", + hf_revision="388097ea7776314e93a529163e0fea805b8a6454", + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + ], + ), + ) + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(openbook_qa_tasks) + +winogrande_tasks = [ + LightevalTaskConfig( + name=f"winogrande_{formulation.name.lower()}", + suite=("custom",), + prompt_function=get_continuation_prompt_function( + Language.ENGLISH, partial(winogrand_adapter, Language.ENGLISH), formulation=formulation + ), + hf_repo="allenai/winogrande", + hf_subset="winogrande_xl", + trust_dataset=True, + hf_revision="85ac5b5a3b7a930e22d590176e39460400d19e41", + metric=[ + loglikelihood_acc_metric(normalization=None), + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + ], + ) + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(winogrande_tasks) + +piqa_tasks = [ + LightevalTaskConfig( + name=f"piqa_{formulation.name.lower()}", + prompt_function=get_mcq_prompt_function( + Language.ENGLISH, + lambda line: { + "question": line["goal"], + "choices": [line["sol1"], line["sol2"]], + "gold_idx": int(line["label"]), + }, + formulation=formulation, + ), + suite=["custom"], + hf_repo="ybisk/piqa", + hf_revision="2e8ac2dffd59bac8c3c6714948f4c551a0848bb0", + hf_subset="plain_text", + trust_dataset=True, + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + ], + ), + ) + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(piqa_tasks) + + +MMLU_SUBSETS = [ + "abstract_algebra", + "anatomy", + "astronomy", + "business_ethics", + "clinical_knowledge", + "college_biology", + "college_chemistry", + "college_computer_science", + "college_mathematics", + "college_medicine", + "college_physics", + "computer_security", + "conceptual_physics", + "econometrics", + "electrical_engineering", + "elementary_mathematics", + "formal_logic", + "global_facts", + "high_school_biology", + "high_school_chemistry", + "high_school_computer_science", + "high_school_european_history", + "high_school_geography", + "high_school_government_and_politics", + "high_school_macroeconomics", + "high_school_mathematics", + "high_school_microeconomics", + "high_school_physics", + "high_school_psychology", + "high_school_statistics", + "high_school_us_history", + "high_school_world_history", + "human_aging", + "human_sexuality", + "international_law", + "jurisprudence", + "logical_fallacies", + "machine_learning", + "management", + "marketing", + "medical_genetics", + "miscellaneous", + "moral_disputes", + "moral_scenarios", + "nutrition", + "philosophy", + "prehistory", + "professional_accounting", + "professional_law", + "professional_medicine", + "professional_psychology", + "public_relations", + "security_studies", + "sociology", + "us_foreign_policy", + "virology", + "world_religions", +] + +mmlu_tasks = [ + LightevalTaskConfig( + name=f"mmlu_{formulation.name.lower()}:{subset}", + prompt_function=get_mcq_prompt_function( + Language.ENGLISH, + lambda line: { + "question": line["question"], + "choices": line["choices"], + "gold_idx": int(line["answer"]), + }, + formulation=formulation, + ), + suite=("custom",), + hf_repo="cais/mmlu", + hf_subset=subset, + hf_revision="c30699e8356da336a370243923dbaf21066bb9fe", + trust_dataset=True, + evaluation_splits=("test",), + few_shots_split="dev", + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + loglikelihood_acc_metric(normalization=LogProbPMINorm()), + ], + ), + ) + for subset in MMLU_SUBSETS + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(mmlu_tasks) + +mmlu_pro_tasks = [ + LightevalTaskConfig( + name=f"mmlu_pro_{formulation.name.lower()}", + prompt_function=get_mcq_prompt_function( + Language.ENGLISH, + lambda line: { + "question": line["question"], + "choices": line["options"], + "gold_idx": line["answer_index"], + }, + formulation=formulation, + ), + suite=("custom",), + hf_repo="TIGER-Lab/MMLU-Pro", + hf_subset="default", + hf_revision="3373e0b32277875b8db2aa555a333b78a08477ea", + trust_dataset=True, + evaluation_splits=("test",), + few_shots_split="validation", + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + loglikelihood_acc_metric(normalization=LogProbPMINorm()), + ], + ), + ) + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(mmlu_pro_tasks) + + +if __name__ == "__main__": + print(t.name for t in TASKS_TABLE) + print(len(TASKS_TABLE)) diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py new file mode 100644 index 000000000..3574ea5d7 --- /dev/null +++ b/src/nanotron/eval/one_job_runner.py @@ -0,0 +1,330 @@ +""" Mostly complete a SLURM template with a link to a single checkpoint on s3 and launch it +""" +import datetime +import math +import os +import subprocess +from typing import List, Optional, Tuple + +from nanotron import logging +from nanotron.config import Config, LightEvalConfig +from nanotron.logging import log_rank +from nanotron.parallel import ParallelContext + +logger = logging.get_logger(__name__) + + +class LightEvalRunner: + def __init__(self, config: Config, parallel_context: Optional[ParallelContext] = None): + self.config = config + assert config.lighteval is not None, "LightEval config is required" + self.lighteval_config = config.lighteval + self.parallel_context = parallel_context + + def eval_single_checkpoint_no_s3(self, checkpoint_path: str) -> Tuple[str, str]: + raise NotImplementedError("Not implemented") + + def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: + """Run light evaluation on uploaded files.""" + logger.warning(f"Lighteval Runner got {len(uploaded_files)} files. Checking configs.") + config_files = [ + f for f in uploaded_files if "config.py" in f["destination"] or "config.yaml" in f["destination"] + ] + # Sanity check on the config files len (we want only one) + if len(config_files) == 0: + log_rank( + "No config files founds in uploaded checkpoints. Not running evaluation.", + logger=logger, + level=logging.ERROR, + group=self.parallel_context.dp_pg if self.parallel_context is not None else None, + rank=0, + ) + return + if len(config_files) > 1: + log_rank( + "Found multiple config files in uploaded checkpoints.", + logger=logger, + level=logging.ERROR, + group=self.parallel_context.dp_pg if self.parallel_context is not None else None, + rank=0, + ) + return + checkpoint_path = config_files[0]["destination"].replace("config.yaml", "") + + slurm_job_id, slurm_log = run_slurm_one_job( + config=self.config, + lighteval_config=self.lighteval_config, + model_checkpoint_path=checkpoint_path, + current_step=self.config.general.step, + ) + + return slurm_job_id, slurm_log + + +def run_slurm_one_job( + config: Config, + lighteval_config: LightEvalConfig, + model_checkpoint_path: str, + current_step: int, +): + """Launch a single job on Slurm with the given mapping""" + # Default evaluation config + default_slurm_config = { + "gpus_per_node": 8, + "partition": "hopper-prod", + "hf_cache": "/fsx/nouamane/.cache/huggingface", + "cpus_per_task": 88, + "qos": "high", + "time": "24:00:00", + "reservation": "smollm", + } + + # Use lighteval config paths if available, otherwise use defaults + eval_launch_script_path = os.path.join( + lighteval_config.slurm_script_dir + if lighteval_config.slurm_script_dir + else "/fsx/nouamane/projects/nanotron/eval_results/launch-config", + str(current_step), + ) + eval_logs_path = os.path.join( + lighteval_config.checkpoints_path + if lighteval_config.checkpoints_path + else "/fsx/nouamane/projects/nanotron/eval_results/logs", + str(current_step), + ) + + # Create directories + os.makedirs(eval_launch_script_path, exist_ok=True) + os.makedirs(eval_logs_path, exist_ok=True) + + # Calculate the number of nodes based on parallelism config + total_gpus_needed = ( + lighteval_config.parallelism.dp * lighteval_config.parallelism.pp * lighteval_config.parallelism.tp + ) + gpus_per_node = default_slurm_config["gpus_per_node"] + nodes = math.ceil(total_gpus_needed / gpus_per_node) + + # Get timestamp for log files + timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + run_name = f"eval_{current_step}".replace(" ", "_") + + # Create log directory with run name subdirectory + logs_path = os.path.join(eval_logs_path, run_name) + os.makedirs(logs_path, exist_ok=True) + + # Create the SLURM script content + slurm_script = f"""#!/bin/bash +#SBATCH --job-name={run_name} +#SBATCH --partition={default_slurm_config["partition"]} +#SBATCH --nodes={nodes} +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task={default_slurm_config["cpus_per_task"]} +#SBATCH --gpus={gpus_per_node} +#SBATCH --exclusive +#SBATCH --qos={default_slurm_config["qos"]} +#SBATCH --time={default_slurm_config["time"]} +#SBATCH --output={logs_path}/{timestamp}-%x-%j.out""" + + if default_slurm_config.get("reservation"): + slurm_script += f"\n#SBATCH --reservation={default_slurm_config['reservation']}" + + # Add the rest of the script content + local_path = os.path.join("/tmp", f"eval_{config.general.run}", str(current_step)) + + slurm_script += f""" + +set -x -e + +LOCAL_DOWNLOAD_CHECKPOINT_FOLDER={local_path} + +echo "START TIME: $(date)" +#Show some environment variables +echo python3 version = `python3 --version` +echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" +echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" + +# SLURM stuff +export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export MASTER_PORT=6000 +export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` + +# Hugging Face token setup +if [ -z "$HUGGING_FACE_HUB_TOKEN" ]; then + if TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null); then + export HUGGING_FACE_HUB_TOKEN=$TOKEN + else + echo "Error: The environment variable HUGGING_FACE_HUB_TOKEN is not set and the token cache could not be read." + exit 1 + fi +fi + +# Set environment variables +export CUDA_DEVICE_MAX_CONNECTIONS=1 +# export CUBLAS_WORKSPACE_CONFIG=":4096:8" + +# Set HuggingFace cache locations +export HUGGINGFACE_HUB_CACHE={default_slurm_config["hf_cache"]} +export HF_DATASETS_CACHE={default_slurm_config["hf_cache"]} +export HF_MODULES_CACHE={default_slurm_config["hf_cache"]} +export HF_HOME={default_slurm_config["hf_cache"]} + +echo "Running on $COUNT_NODE nodes: $HOSTNAMES" + +# Create checkpoint directory +mkdir -p $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER + +# Handle S3 paths +if [[ "{model_checkpoint_path}" == s3://* ]]; then + echo "Downloading checkpoint from S3: {model_checkpoint_path}" + s5cmd sync \ + --concurrency=50 \ + --size-only \ + --exclude "optimizer/*" \ + --exclude "random/*" \ + --exclude "lr_scheduler/*" \ + --part-size 100 \ + "{model_checkpoint_path}/*" "$LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/" +else + echo "Copying checkpoint files from {model_checkpoint_path} to $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER" + rsync -av --progress --inplace --no-whole-file \ + --exclude 'optimizer/' \ + --exclude 'random/' \ + --exclude 'lr_scheduler/' \ + {model_checkpoint_path} $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/ +fi + +echo "Contents of checkpoint directory:" +ls -la $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/ + +# Add random sleep to avoid hub request conflicts +# sleep $(( RANDOM % 300 )) + +CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun \\ + --nproc_per_node {gpus_per_node} \\ + --nnodes $COUNT_NODE \\ + --node_rank $SLURM_PROCID \\ + --master_addr $MASTER_ADDR \\ + --master_port $MASTER_PORT \\ + run_evals.py \\ + --checkpoint-config-path $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml \\ + --lighteval-override smollm3_eval.yaml""" + + # if lighteval_config.batch_size: + # slurm_script += f" \\\n --batch-size {lighteval_config.batch_size}" + + # if lighteval_config.tasks: + # slurm_script += """ + # if [ -n "${TASKS}" ]; then + # CMD="$CMD --tasks ${TASKS}" + # fi + # if [ -n "${CUSTOM_TASKS}" ]; then + # CMD="$CMD --custom-tasks ${CUSTOM_TASKS}" + # fi + # if [ -n "${MAX_SAMPLES}" ]; then + # CMD="$CMD --max-samples ${MAX_SAMPLES}" + # fi""" + + slurm_script += """ + +echo "END TIME: $(date)" +""" + + # Write the script to file + current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + launch_script_path = os.path.join(eval_launch_script_path, f"launch_script-{current_time}.slurm") + os.makedirs(os.path.dirname(launch_script_path), exist_ok=True) + + with open(launch_script_path, "w") as f: + f.write(slurm_script) + + # Preserve important environment variables + env = { + "PATH": os.environ["PATH"], + "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", ""), + "HOME": os.path.expanduser("~"), + } + + try: + # Use subprocess.run instead of check_output for better error handling + result = subprocess.run(["sbatch", launch_script_path], env=env, check=True, capture_output=True, text=True) + output = result.stdout + job_ids = output.split()[-1] + + output_log = os.path.join(logs_path, f"{timestamp}-{run_name}-{job_ids}.out") + + logger.warning( + f"""🚀 Slurm job launched successfully: + Job name: {run_name} + Job ID: {job_ids} + Launch script: {launch_script_path} + Log file: {output_log}""" + ) + except subprocess.CalledProcessError as e: + logger.error(f"Error while launching Slurm job: {e}") + logger.error(f"Command output: {e.output}") + logger.error(f"Command stderr: {e.stderr}") + job_ids = None + output_log = None + + return job_ids, output_log + + +if __name__ == "__main__": + + from nanotron.config.config import Config + + # Load existing config from checkpoint + # checkpoint_path = "/fsx/nouamane/projects/nanotron/checkpoints/smollm3-training-test-tps-48nn-seed-6-/10" + # config_path = os.path.join(checkpoint_path, "config.yaml") + checkpoint_path = "s3://smollm3/smollm3-3B-final/3B-final-GQA-noTP-2k-seq/20000/" + config_path = "/fsx/nouamane/projects/nanotron/checkpoints/smollm3-training-test-tps-48nn-seed-6-/10/config.yaml" + try: + # Load the existing config + print(f"\nLoading config from: {config_path}") + config = Config.load_from_yaml(config_path) + + # Print config details + print("\nConfig details:") + print(f"Project: {config.general.project}") + print(f"Run: {config.general.run}") + print(f"Step: {config.general.step}") + + if config.lighteval: + print("\nLightEval config:") + print( + f"Parallelism: dp={config.lighteval.parallelism.dp}, tp={config.lighteval.parallelism.tp}, pp={config.lighteval.parallelism.pp}" + ) + print(f"Batch size: {config.lighteval.batch_size}") + print(f"Slurm template: {config.lighteval.slurm_template}") + print(f"Checkpoints path: {config.lighteval.checkpoints_path}") + if config.lighteval.tasks: + print(f"Tasks: {config.lighteval.tasks.tasks}") + print(f"Custom tasks: {config.lighteval.tasks.custom_tasks}") + print(f"Max samples: {config.lighteval.tasks.max_samples}") + + # Create test files structure + test_files = [ + { + "destination": os.path.join(checkpoint_path, "config.yaml"), + "source": "existing_config", + } + ] + + if config.lighteval is None: + config.lighteval = LightEvalConfig() + + print("\nInitializing LightEvalRunner...") + runner = LightEvalRunner(config=config) + + print("\nTesting LightEvalRunner.eval_single_checkpoint()...") + job_id, log_path = runner.eval_single_checkpoint(test_files) + + except Exception as e: + print(f"\nError during test: {str(e)}") + import traceback + + traceback.print_exc() + + finally: + print("\nTest completed") From 438257abafd9c163827b74e92290a0a30ba07c52 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Wed, 16 Apr 2025 18:00:13 +0000 Subject: [PATCH 04/31] try adding lightevalrunner to trainer --- src/nanotron/trainer.py | 59 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 52 insertions(+), 7 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 330b82463..7be3e361e 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -39,6 +39,7 @@ ) from nanotron.constants import MODEL_CONFIG_FILE_NAME from nanotron.data.dataloader import sanity_check_dataloader +from nanotron.eval import LightEvalRunner from nanotron.helpers import ( _vocab_size_with_padding, compute_remain_train_steps_of_a_data_stage_from_ckp, @@ -122,7 +123,7 @@ def get_size(bytes): """Convert bytes to human readable format""" - for unit in ["", "K", "M", "G", "T", "P"]: + for unit in ["", "K", "M", "B", "T", "P"]: if bytes < 1024: return f"{bytes:.2f}{unit}B" bytes /= 1024 @@ -314,6 +315,21 @@ def post_init(self): else: self.s3_mover = None + # Initialize LightEval runner on rank 0 + if dist.get_rank(self.parallel_context.world_pg) == 0: + if self.config.lighteval is not None: + self.lighteval_runner = LightEvalRunner(config=self.config, parallel_context=self.parallel_context) + if self.s3_mover is not None: + # If we have S3 upload enabled, use the eval_single_checkpoint as post-upload callback + self.s3_mover.post_upload_callback = self.lighteval_runner.eval_single_checkpoint + else: + # If no S3 upload, use the no_s3 version directly after checkpoint save + self.post_checkpoint_callback = self.lighteval_runner.eval_single_checkpoint_no_s3 + else: + self.post_checkpoint_callback = None + else: + self.post_checkpoint_callback = None + def pre_training(self, *args, **kwargs): if not self.config.general.ignore_sanity_checks: log_rank( @@ -364,6 +380,14 @@ def pre_training(self, *args, **kwargs): name=run_name, config={"nanotron_config": self.config.as_dict()}, ) + # Define tokens metric as x-axis for all metrics + wandb.define_metric("Tokens") + wandb.define_metric("*", step_metric="Tokens") + + # Handle resuming from a previous run + initial_tokens = self.initial_iter_step * self.global_batch_size + # Log initial tokens to set the starting point + wandb.log({"Tokens": initial_tokens}) log_rank( f"Initialized wandb run '{run_name}' for TP rank {tp_rank}", logger=logger, @@ -525,7 +549,6 @@ def train( ], **kwargs, ) -> None: - self.pre_training(**kwargs) if self.config.checkpoints.save_initial_state and self.init_checkpoint_path is None: self.save_checkpoint() @@ -545,6 +568,7 @@ def train( self.initial_iter_step = self.metadata.last_train_step + 1 self.last_iter_step = self.config.tokens.train_steps + self.pre_training(**kwargs) prof = get_profiler(config=self.config) # free memory @@ -925,6 +949,8 @@ def get_cpu_logitems(): **{log_item.tag: log_item.scalar_value for log_item in all_log_entries}, **tp_group_info, "iteration_step": self.iteration_step, + "Tokens": self.metadata.consumed_train_samples + * self.config.tokens.sequence_length, # TODO: this is not true if we change seqlen }, step=self.iteration_step, ) @@ -939,6 +965,8 @@ def get_cpu_logitems(): { **{log_item.tag: log_item.scalar_value for log_item in basic_log_entries}, "iteration_step": self.iteration_step, + "Tokens": self.metadata.consumed_train_samples + * self.config.tokens.sequence_length, # TODO: this is not true if we change seqlen }, step=self.iteration_step, ) @@ -1166,22 +1194,28 @@ def pre_save_checkpoint(self) -> Path: self.s3_mover.distributed_wait_for_completion(self.parallel_context.world_pg) if self.s3_mover.post_upload_callback_outputs is not None: slurm_job_id, slurm_log = self.s3_mover.post_upload_callback_outputs - self.log_object({"job_id": slurm_job_id, "log": slurm_log}, "slurm_eval") + log_rank( + f"launching eval job: job_id={slurm_job_id} log at {slurm_log} slurm_eval", + logger=logger, + level=logging.INFO, + rank=0, + ) def post_save_checkpoint(self): # Upload to S3 if self.s3_mover is not None: self.s3_mover.start_uploading() - # free memory TODO: do we need this? - # gc.collect() - # torch.cuda.empty_cache() + elif self.post_checkpoint_callback is not None: + # If we're not using S3, but we have a post-checkpoint callback for evals + checkpoint_path = Path(self.config.checkpoints.checkpoints_path) / f"{self.config.general.step}" + self.post_checkpoint_callback(checkpoint_path) def save_checkpoint(self) -> Path: self.pre_save_checkpoint() checkpoints_path = self.config.checkpoints.checkpoints_path - checkpoint_path = checkpoints_path / f"{self.iteration_step}" + checkpoint_path = Path(checkpoints_path) / f"{self.iteration_step}" if self.config.checkpoints.checkpoints_path_is_shared_file_system: should_mkdir = dist.get_rank(self.parallel_context.world_pg) == 0 else: @@ -1227,6 +1261,17 @@ def save_checkpoint(self) -> Path: self.post_save_checkpoint() + # Handle post-checkpoint evaluation if configured + if self.post_checkpoint_callback is not None: + job_id, log_path = self.post_checkpoint_callback(str(checkpoint_path)) + if job_id is not None and log_path is not None: + log_rank( + f"launching eval job: job_id={job_id} log at {log_path} slurm_eval", + logger=logger, + level=logging.INFO, + rank=0, + ) + return checkpoint_path def _mark_tied_parameters( From 4f8a35032b80ae93c4d0ab21d0356825c2ab1914 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Wed, 16 Apr 2025 18:04:54 +0000 Subject: [PATCH 05/31] amend --- src/nanotron/trainer.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 7be3e361e..b930238c9 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -380,14 +380,6 @@ def pre_training(self, *args, **kwargs): name=run_name, config={"nanotron_config": self.config.as_dict()}, ) - # Define tokens metric as x-axis for all metrics - wandb.define_metric("Tokens") - wandb.define_metric("*", step_metric="Tokens") - - # Handle resuming from a previous run - initial_tokens = self.initial_iter_step * self.global_batch_size - # Log initial tokens to set the starting point - wandb.log({"Tokens": initial_tokens}) log_rank( f"Initialized wandb run '{run_name}' for TP rank {tp_rank}", logger=logger, @@ -433,6 +425,9 @@ def pre_training(self, *args, **kwargs): level=logging.INFO, rank=world_rank, ) + # Define tokens metric as x-axis for all metrics + wandb.define_metric("consumed_tokens") + wandb.define_metric("*", step_metric="consumed_tokens") def post_train_step(self): @@ -789,7 +784,8 @@ def train_step_logs( # LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"), LogItem( "consumed_tokens", - self.metadata.consumed_train_samples * self.config.tokens.sequence_length, + self.metadata.consumed_train_samples + * self.config.tokens.sequence_length, # TODO: not true if we change seqlen "human_format", ), # , "12d"), LogItem("time_per_iteration_ms", elapsed_time_per_iteration_ms, "human_format"), # , ".1f"), @@ -948,9 +944,6 @@ def get_cpu_logitems(): { **{log_item.tag: log_item.scalar_value for log_item in all_log_entries}, **tp_group_info, - "iteration_step": self.iteration_step, - "Tokens": self.metadata.consumed_train_samples - * self.config.tokens.sequence_length, # TODO: this is not true if we change seqlen }, step=self.iteration_step, ) @@ -965,8 +958,6 @@ def get_cpu_logitems(): { **{log_item.tag: log_item.scalar_value for log_item in basic_log_entries}, "iteration_step": self.iteration_step, - "Tokens": self.metadata.consumed_train_samples - * self.config.tokens.sequence_length, # TODO: this is not true if we change seqlen }, step=self.iteration_step, ) From c9c479dea6ed763fc1948b108e17d2ff8a726a98 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Wed, 16 Apr 2025 18:07:31 +0000 Subject: [PATCH 06/31] amend --- src/nanotron/trainer.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index b930238c9..600b8b20c 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -386,6 +386,9 @@ def pre_training(self, *args, **kwargs): level=logging.INFO, rank=world_rank, ) + # Define tokens metric as x-axis for all metrics + wandb.define_metric("consumed_tokens") + wandb.define_metric("*", step_metric="consumed_tokens") elif world_rank == self.logger_ranks[0]: run_name = f"{current_time}_{self.config.general.run}" x_stats_sampling_interval = os.environ.get("STATS_SAMPLING_INTERVAL_IN_SEC", None) @@ -425,9 +428,9 @@ def pre_training(self, *args, **kwargs): level=logging.INFO, rank=world_rank, ) - # Define tokens metric as x-axis for all metrics - wandb.define_metric("consumed_tokens") - wandb.define_metric("*", step_metric="consumed_tokens") + # Define tokens metric as x-axis for all metrics + wandb.define_metric("consumed_tokens") + wandb.define_metric("*", step_metric="consumed_tokens") def post_train_step(self): From 190a6b9e97f7bff9c0844aeb4ae857a29f34d30e Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 17 Apr 2025 11:51:15 +0000 Subject: [PATCH 07/31] amend --- src/nanotron/trainer.py | 36 ++++++------------------------------ 1 file changed, 6 insertions(+), 30 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 600b8b20c..a6de78d39 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -322,13 +322,6 @@ def post_init(self): if self.s3_mover is not None: # If we have S3 upload enabled, use the eval_single_checkpoint as post-upload callback self.s3_mover.post_upload_callback = self.lighteval_runner.eval_single_checkpoint - else: - # If no S3 upload, use the no_s3 version directly after checkpoint save - self.post_checkpoint_callback = self.lighteval_runner.eval_single_checkpoint_no_s3 - else: - self.post_checkpoint_callback = None - else: - self.post_checkpoint_callback = None def pre_training(self, *args, **kwargs): if not self.config.general.ignore_sanity_checks: @@ -386,9 +379,6 @@ def pre_training(self, *args, **kwargs): level=logging.INFO, rank=world_rank, ) - # Define tokens metric as x-axis for all metrics - wandb.define_metric("consumed_tokens") - wandb.define_metric("*", step_metric="consumed_tokens") elif world_rank == self.logger_ranks[0]: run_name = f"{current_time}_{self.config.general.run}" x_stats_sampling_interval = os.environ.get("STATS_SAMPLING_INTERVAL_IN_SEC", None) @@ -428,9 +418,6 @@ def pre_training(self, *args, **kwargs): level=logging.INFO, rank=world_rank, ) - # Define tokens metric as x-axis for all metrics - wandb.define_metric("consumed_tokens") - wandb.define_metric("*", step_metric="consumed_tokens") def post_train_step(self): @@ -547,7 +534,6 @@ def train( ], **kwargs, ) -> None: - if self.config.checkpoints.save_initial_state and self.init_checkpoint_path is None: self.save_checkpoint() @@ -947,6 +933,7 @@ def get_cpu_logitems(): { **{log_item.tag: log_item.scalar_value for log_item in all_log_entries}, **tp_group_info, + "iteration_step": self.iteration_step, }, step=self.iteration_step, ) @@ -1191,7 +1178,7 @@ def pre_save_checkpoint(self) -> Path: log_rank( f"launching eval job: job_id={slurm_job_id} log at {slurm_log} slurm_eval", logger=logger, - level=logging.INFO, + level=logging.WARNING, rank=0, ) @@ -1200,10 +1187,10 @@ def post_save_checkpoint(self): if self.s3_mover is not None: self.s3_mover.start_uploading() - elif self.post_checkpoint_callback is not None: - # If we're not using S3, but we have a post-checkpoint callback for evals - checkpoint_path = Path(self.config.checkpoints.checkpoints_path) / f"{self.config.general.step}" - self.post_checkpoint_callback(checkpoint_path) + if dist.get_rank(self.parallel_context.world_pg) == 0: + if self.config.lighteval is not None and self.s3_mover is None: + checkpoint_path = Path(self.config.checkpoints.checkpoints_path) / f"{self.config.general.step}" + self.lighteval_runner.eval_single_checkpoint(checkpoint_path) def save_checkpoint(self) -> Path: self.pre_save_checkpoint() @@ -1255,17 +1242,6 @@ def save_checkpoint(self) -> Path: self.post_save_checkpoint() - # Handle post-checkpoint evaluation if configured - if self.post_checkpoint_callback is not None: - job_id, log_path = self.post_checkpoint_callback(str(checkpoint_path)) - if job_id is not None and log_path is not None: - log_rank( - f"launching eval job: job_id={job_id} log at {log_path} slurm_eval", - logger=logger, - level=logging.INFO, - rank=0, - ) - return checkpoint_path def _mark_tied_parameters( From 004a89cae000589c8b7e2c9af60f2717af57ec0f Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 17 Apr 2025 11:53:09 +0000 Subject: [PATCH 08/31] amend --- src/nanotron/trainer.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index a6de78d39..459eeded0 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -571,12 +571,13 @@ def train( outputs, loss_avg, z_loss_avg = self.training_step(dataloader=self.current_dataloader) # Update consumption tracking for current batch - self.current_base_dl.dataset.update_consumption_metrics( - start_idx=(self.iteration_step - 1) - * self.global_batch_size, # assumes we start from iteration_step=1 - end_idx=self.iteration_step * self.global_batch_size, - sequence_length=self.sequence_length, - ) + if hasattr(self.current_base_dl, "dataset"): + self.current_base_dl.dataset.update_consumption_metrics( + start_idx=(self.iteration_step - 1) + * self.global_batch_size, # assumes we start from iteration_step=1 + end_idx=self.iteration_step * self.global_batch_size, + sequence_length=self.sequence_length, + ) # Training Logs # Track consumed tokens for all dataset folders in current stage From b4cbb55d9bc562da24add2db286c852e528a538b Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 17 Apr 2025 11:54:06 +0000 Subject: [PATCH 09/31] amend --- src/nanotron/trainer.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 459eeded0..65908df8c 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -581,12 +581,13 @@ def train( # Training Logs # Track consumed tokens for all dataset folders in current stage - consumption_stats = self.current_base_dl.dataset.get_consumption_stats() - current_stage = self.metadata.data_stages[self.metadata.last_stage_idx] + if hasattr(self.current_base_dl, "dataset"): + consumption_stats = self.current_base_dl.dataset.get_consumption_stats() + current_stage = self.metadata.data_stages[self.metadata.last_stage_idx] - # Update consumed tokens for all folders in the consumption stats - for folder_path, stats in consumption_stats.items(): - current_stage.consumed_tokens_per_dataset_folder[folder_path] = stats["tokens"] + # Update consumed tokens for all folders in the consumption stats + for folder_path, stats in consumption_stats.items(): + current_stage.consumed_tokens_per_dataset_folder[folder_path] = stats["tokens"] # Original consumption tracking self.metadata.consumed_train_samples += self.global_batch_size From d39872b028f6402cf7d9af6b956730b3cf0a4df8 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 17 Apr 2025 11:55:15 +0000 Subject: [PATCH 10/31] amend --- src/nanotron/trainer.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 65908df8c..743627835 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -876,12 +876,13 @@ def get_cpu_logitems(): assert self.current_base_dl is not None, "current_base_dl should be defined" # Log consumption statistics - for dataset_name, stats in self.current_base_dl.dataset.get_consumption_stats().items(): - basic_log_entries.extend( - [ - LogItem(f"dataloader/consumed_tokens/{dataset_name}", stats["tokens"], "human_format"), - ] - ) + if hasattr(self.current_base_dl, "dataset"): + for dataset_name, stats in self.current_base_dl.dataset.get_consumption_stats().items(): + basic_log_entries.extend( + [ + LogItem(f"dataloader/consumed_tokens/{dataset_name}", stats["tokens"], "human_format"), + ] + ) # WandB logging - determine if this rank should log to wandb should_log_to_wandb = wandb is not None and ( From feb818ab9d3f38a849e0996d091fdc165af166b9 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 17 Apr 2025 12:01:20 +0000 Subject: [PATCH 11/31] . --- src/nanotron/eval/one_job_runner.py | 3 --- src/nanotron/logging/base.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py index 3574ea5d7..d30c50a24 100644 --- a/src/nanotron/eval/one_job_runner.py +++ b/src/nanotron/eval/one_job_runner.py @@ -21,9 +21,6 @@ def __init__(self, config: Config, parallel_context: Optional[ParallelContext] = self.lighteval_config = config.lighteval self.parallel_context = parallel_context - def eval_single_checkpoint_no_s3(self, checkpoint_path: str) -> Tuple[str, str]: - raise NotImplementedError("Not implemented") - def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: """Run light evaluation on uploaded files.""" logger.warning(f"Lighteval Runner got {len(uploaded_files)} files. Checking configs.") diff --git a/src/nanotron/logging/base.py b/src/nanotron/logging/base.py index e84554ee1..b14b94aab 100644 --- a/src/nanotron/logging/base.py +++ b/src/nanotron/logging/base.py @@ -265,7 +265,7 @@ def warn_once( def human_format(num: float, billions: bool = False, divide_by_1024: bool = False) -> str: if abs(num) < 1: return "{:.3g}".format(num) - SIZES = ["", "K", "M", "G", "T", "P", "E"] + SIZES = ["", "K", "M", "B", "T", "P", "E"] num = float("{:.3g}".format(num)) magnitude = 0 i = 0 From 025f31413c6ea6a0fb6a5e35063b780ee91aba79 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 17 Apr 2025 12:29:18 +0000 Subject: [PATCH 12/31] amend --- src/nanotron/config/lighteval_config.py | 22 +++++++- src/nanotron/config/models_config.py | 8 +-- src/nanotron/eval/one_job_runner.py | 73 ++++++++----------------- 3 files changed, 47 insertions(+), 56 deletions(-) diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index 9a656b2fd..2811c83f7 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -73,6 +73,22 @@ def __post_init__(self): assert self.wandb_project != "", "Please specify a wandb_project" +@dataclass +class LightEvalSlurm: + """Arguments related to SLURM configuration for LightEval""" + + gpus_per_node: int = 8 + partition: str = "hopper-prod" + hf_cache: str = "~/.cache/huggingface" + cpus_per_task: int = 88 + qos: str = "high" + time: str = "24:00:00" + reservation: Optional[str] = "smollm" + + def __post_init__(self): + self.hf_cache = str(Path(self.hf_cache).expanduser()) + + @dataclass class LightEvalConfig: """Arguments related to running LightEval on checkpoints. @@ -81,9 +97,7 @@ class LightEvalConfig: the saved config when running LightEval after training. """ - slurm_template: Optional[str] = None slurm_script_dir: Optional[str] = None - checkpoints_path: Optional[str] = None parallelism: Optional[ParallelismArgs] = None batch_size: Optional[int] = None @@ -91,7 +105,11 @@ class LightEvalConfig: tasks: Optional[LightEvalTasksArgs] = None logging: Optional[LightEvalLoggingArgs] = None wandb: Optional[LightEvalWandbLoggerConfig] = None + slurm: Optional[LightEvalSlurm] = None + eval_config_override: str = "smollm3_eval.yaml" # Previously hardcoded in run_slurm_one_job def __post_init__(self): if self.parallelism is None: self.parallelism = ParallelismArgs(dp=1, pp=1, tp=1, tp_linear_async_communication=True) + if self.slurm is None: + self.slurm = LightEvalSlurm() diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 03b5ba71d..dd575e399 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -146,10 +146,10 @@ class Qwen2Config: no_rope_layer: Optional[ int ] = None # Skip rope every no_rope_layer layers (see https://arxiv.org/abs/2501.18795 https://arxiv.org/abs/2305.19466 and Llama4) - _fused_rotary_emb: bool = True - _fused_rms_norm: bool = True - _use_qkv_packed: bool = True - _use_doc_masking: bool = True + _fused_rotary_emb: bool = False + _fused_rms_norm: bool = False + _use_qkv_packed: bool = False + _use_doc_masking: bool = False # MoE configuration moe_config: Optional[MoEConfig] = None diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py index d30c50a24..3327bded6 100644 --- a/src/nanotron/eval/one_job_runner.py +++ b/src/nanotron/eval/one_job_runner.py @@ -23,7 +23,6 @@ def __init__(self, config: Config, parallel_context: Optional[ParallelContext] = def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: """Run light evaluation on uploaded files.""" - logger.warning(f"Lighteval Runner got {len(uploaded_files)} files. Checking configs.") config_files = [ f for f in uploaded_files if "config.py" in f["destination"] or "config.yaml" in f["destination"] ] @@ -47,6 +46,9 @@ def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: ) return checkpoint_path = config_files[0]["destination"].replace("config.yaml", "") + logger.warning( + f"Lighteval Runner got {len(uploaded_files)} files. Using {checkpoint_path} as checkpoint path." + ) slurm_job_id, slurm_log = run_slurm_one_job( config=self.config, @@ -65,28 +67,16 @@ def run_slurm_one_job( current_step: int, ): """Launch a single job on Slurm with the given mapping""" - # Default evaluation config - default_slurm_config = { - "gpus_per_node": 8, - "partition": "hopper-prod", - "hf_cache": "/fsx/nouamane/.cache/huggingface", - "cpus_per_task": 88, - "qos": "high", - "time": "24:00:00", - "reservation": "smollm", - } + # Use config values instead of hardcoded defaults + slurm_config = lighteval_config.slurm # Use lighteval config paths if available, otherwise use defaults eval_launch_script_path = os.path.join( - lighteval_config.slurm_script_dir - if lighteval_config.slurm_script_dir - else "/fsx/nouamane/projects/nanotron/eval_results/launch-config", + lighteval_config.slurm_script_dir if lighteval_config.slurm_script_dir else "eval_results/launch-config", str(current_step), ) eval_logs_path = os.path.join( - lighteval_config.checkpoints_path - if lighteval_config.checkpoints_path - else "/fsx/nouamane/projects/nanotron/eval_results/logs", + lighteval_config.checkpoints_path if lighteval_config.checkpoints_path else "eval_results/logs", str(current_step), ) @@ -98,8 +88,7 @@ def run_slurm_one_job( total_gpus_needed = ( lighteval_config.parallelism.dp * lighteval_config.parallelism.pp * lighteval_config.parallelism.tp ) - gpus_per_node = default_slurm_config["gpus_per_node"] - nodes = math.ceil(total_gpus_needed / gpus_per_node) + nodes = math.ceil(total_gpus_needed / slurm_config.gpus_per_node) # Get timestamp for log files timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") @@ -112,18 +101,18 @@ def run_slurm_one_job( # Create the SLURM script content slurm_script = f"""#!/bin/bash #SBATCH --job-name={run_name} -#SBATCH --partition={default_slurm_config["partition"]} +#SBATCH --partition={slurm_config.partition} #SBATCH --nodes={nodes} #SBATCH --ntasks-per-node=1 -#SBATCH --cpus-per-task={default_slurm_config["cpus_per_task"]} -#SBATCH --gpus={gpus_per_node} +#SBATCH --cpus-per-task={slurm_config.cpus_per_task} +#SBATCH --gpus={slurm_config.gpus_per_node} #SBATCH --exclusive -#SBATCH --qos={default_slurm_config["qos"]} -#SBATCH --time={default_slurm_config["time"]} +#SBATCH --qos={slurm_config.qos} +#SBATCH --time={slurm_config.time} #SBATCH --output={logs_path}/{timestamp}-%x-%j.out""" - if default_slurm_config.get("reservation"): - slurm_script += f"\n#SBATCH --reservation={default_slurm_config['reservation']}" + if slurm_config.reservation: + slurm_script += f"\n#SBATCH --reservation={slurm_config.reservation}" # Add the rest of the script content local_path = os.path.join("/tmp", f"eval_{config.general.run}", str(current_step)) @@ -161,10 +150,10 @@ def run_slurm_one_job( # export CUBLAS_WORKSPACE_CONFIG=":4096:8" # Set HuggingFace cache locations -export HUGGINGFACE_HUB_CACHE={default_slurm_config["hf_cache"]} -export HF_DATASETS_CACHE={default_slurm_config["hf_cache"]} -export HF_MODULES_CACHE={default_slurm_config["hf_cache"]} -export HF_HOME={default_slurm_config["hf_cache"]} +export HUGGINGFACE_HUB_CACHE={slurm_config.hf_cache} +export HF_DATASETS_CACHE={slurm_config.hf_cache} +export HF_MODULES_CACHE={slurm_config.hf_cache} +export HF_HOME={slurm_config.hf_cache} echo "Running on $COUNT_NODE nodes: $HOSTNAMES" @@ -198,30 +187,14 @@ def run_slurm_one_job( # sleep $(( RANDOM % 300 )) CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun \\ - --nproc_per_node {gpus_per_node} \\ + --nproc_per_node {slurm_config.gpus_per_node} \\ --nnodes $COUNT_NODE \\ --node_rank $SLURM_PROCID \\ --master_addr $MASTER_ADDR \\ --master_port $MASTER_PORT \\ run_evals.py \\ --checkpoint-config-path $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml \\ - --lighteval-override smollm3_eval.yaml""" - - # if lighteval_config.batch_size: - # slurm_script += f" \\\n --batch-size {lighteval_config.batch_size}" - - # if lighteval_config.tasks: - # slurm_script += """ - # if [ -n "${TASKS}" ]; then - # CMD="$CMD --tasks ${TASKS}" - # fi - # if [ -n "${CUSTOM_TASKS}" ]; then - # CMD="$CMD --custom-tasks ${CUSTOM_TASKS}" - # fi - # if [ -n "${MAX_SAMPLES}" ]; then - # CMD="$CMD --max-samples ${MAX_SAMPLES}" - # fi""" - + --lighteval-override {lighteval_config.eval_config_override}""" slurm_script += """ echo "END TIME: $(date)" @@ -272,10 +245,10 @@ def run_slurm_one_job( from nanotron.config.config import Config # Load existing config from checkpoint - # checkpoint_path = "/fsx/nouamane/projects/nanotron/checkpoints/smollm3-training-test-tps-48nn-seed-6-/10" + # checkpoint_path = "checkpoints/smollm3-training-test-tps-48nn-seed-6-/10" # config_path = os.path.join(checkpoint_path, "config.yaml") checkpoint_path = "s3://smollm3/smollm3-3B-final/3B-final-GQA-noTP-2k-seq/20000/" - config_path = "/fsx/nouamane/projects/nanotron/checkpoints/smollm3-training-test-tps-48nn-seed-6-/10/config.yaml" + config_path = "checkpoints/smollm3-training-test-tps-48nn-seed-6-/10/config.yaml" try: # Load the existing config print(f"\nLoading config from: {config_path}") From abe75af3aa80cc814a1613723323bec1f2a8f7a7 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 17 Apr 2025 16:50:17 +0000 Subject: [PATCH 13/31] amend --- src/nanotron/config/lighteval_config.py | 2 + src/nanotron/eval/one_job_runner.py | 78 ++++++++++++++++++++----- 2 files changed, 67 insertions(+), 13 deletions(-) diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index 2811c83f7..98cf2d025 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -99,6 +99,7 @@ class LightEvalConfig: slurm_script_dir: Optional[str] = None checkpoints_path: Optional[str] = None + local_checkpoint_dir: str = "/scratch" # Base directory for temporary checkpoint storage, will store under {local_checkpoint_dir}/{run_name}/{step} parallelism: Optional[ParallelismArgs] = None batch_size: Optional[int] = None generation: Optional[Union[GenerationArgs, Dict[str, GenerationArgs]]] = None @@ -113,3 +114,4 @@ def __post_init__(self): self.parallelism = ParallelismArgs(dp=1, pp=1, tp=1, tp_linear_async_communication=True) if self.slurm is None: self.slurm = LightEvalSlurm() + self.local_checkpoint_dir = str(Path(self.local_checkpoint_dir).expanduser()) diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py index 3327bded6..d6d603e4a 100644 --- a/src/nanotron/eval/one_job_runner.py +++ b/src/nanotron/eval/one_job_runner.py @@ -6,8 +6,11 @@ import subprocess from typing import List, Optional, Tuple +from datasets.download.streaming_download_manager import xPath + from nanotron import logging from nanotron.config import Config, LightEvalConfig +from nanotron.data.s3_utils import _get_s3_path_components from nanotron.logging import log_rank from nanotron.parallel import ParallelContext @@ -60,6 +63,15 @@ def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: return slurm_job_id, slurm_log +def normalize_s3_path(path: str) -> str: + """Normalize S3 path using existing s3_utils""" + # Use existing utility to normalize path components + path = xPath(path) + bucket, prefix = _get_s3_path_components(path) + # Reconstruct normalized path + return f"s3://{bucket}/{prefix}".rstrip("/") + + def run_slurm_one_job( config: Config, lighteval_config: LightEvalConfig, @@ -67,6 +79,11 @@ def run_slurm_one_job( current_step: int, ): """Launch a single job on Slurm with the given mapping""" + # Normalize S3 path if needed + if model_checkpoint_path.startswith(("s3:/", "s3://")): + model_checkpoint_path = normalize_s3_path(model_checkpoint_path) + logger.info(f"Normalized S3 path: {model_checkpoint_path}") + # Use config values instead of hardcoded defaults slurm_config = lighteval_config.slurm @@ -98,6 +115,9 @@ def run_slurm_one_job( logs_path = os.path.join(eval_logs_path, run_name) os.makedirs(logs_path, exist_ok=True) + # Use configured local path instead of hardcoded /tmp + local_path = os.path.join(lighteval_config.local_checkpoint_dir, f"eval_{config.general.run}", str(current_step)) + # Create the SLURM script content slurm_script = f"""#!/bin/bash #SBATCH --job-name={run_name} @@ -115,8 +135,6 @@ def run_slurm_one_job( slurm_script += f"\n#SBATCH --reservation={slurm_config.reservation}" # Add the rest of the script content - local_path = os.path.join("/tmp", f"eval_{config.general.run}", str(current_step)) - slurm_script += f""" set -x -e @@ -163,21 +181,55 @@ def run_slurm_one_job( # Handle S3 paths if [[ "{model_checkpoint_path}" == s3://* ]]; then echo "Downloading checkpoint from S3: {model_checkpoint_path}" - s5cmd sync \ - --concurrency=50 \ - --size-only \ - --exclude "optimizer/*" \ - --exclude "random/*" \ - --exclude "lr_scheduler/*" \ - --part-size 100 \ + + # First check if the S3 path exists + if ! s5cmd ls "{model_checkpoint_path}" &>/dev/null; then + echo "Error: S3 path {model_checkpoint_path} does not exist" + exit 1 + fi + + # Try sync command and check its exit status + s5cmd sync \\ + --concurrency=50 \\ + --size-only \\ + --exclude "optimizer/*" \\ + --exclude "random/*" \\ + --exclude "lr_scheduler/*" \\ + --part-size 100 \\ "{model_checkpoint_path}/*" "$LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/" + + if [ $? -ne 0 ]; then + echo "Error: Failed to sync files from S3" + exit 1 + fi + + # Verify that config.yaml was downloaded + if [ ! -f "$LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml" ]; then + echo "Error: config.yaml not found in downloaded checkpoint" + echo "Contents of $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER:" + ls -la $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/ + exit 1 + fi else echo "Copying checkpoint files from {model_checkpoint_path} to $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER" - rsync -av --progress --inplace --no-whole-file \ - --exclude 'optimizer/' \ - --exclude 'random/' \ - --exclude 'lr_scheduler/' \ + rsync -av --progress --inplace --no-whole-file \\ + --exclude 'optimizer/' \\ + --exclude 'random/' \\ + --exclude 'lr_scheduler/' \\ {model_checkpoint_path} $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/ + + if [ $? -ne 0 ]; then + echo "Error: Failed to copy files using rsync" + exit 1 + fi + + # Verify that config.yaml was copied + if [ ! -f "$LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml" ]; then + echo "Error: config.yaml not found in copied checkpoint" + echo "Contents of $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER:" + ls -la $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/ + exit 1 + fi fi echo "Contents of checkpoint directory:" From bd50c667e98ff26803ede776883443f2269e9151 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 17 Apr 2025 18:04:48 +0000 Subject: [PATCH 14/31] . --- src/nanotron/eval/one_job_runner.py | 47 +++++++++++++++-------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py index d6d603e4a..316c6c953 100644 --- a/src/nanotron/eval/one_job_runner.py +++ b/src/nanotron/eval/one_job_runner.py @@ -41,7 +41,7 @@ def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: return if len(config_files) > 1: log_rank( - "Found multiple config files in uploaded checkpoints.", + f"Found multiple config files in uploaded checkpoints: {config_files}", logger=logger, level=logging.ERROR, group=self.parallel_context.dp_pg if self.parallel_context is not None else None, @@ -87,20 +87,6 @@ def run_slurm_one_job( # Use config values instead of hardcoded defaults slurm_config = lighteval_config.slurm - # Use lighteval config paths if available, otherwise use defaults - eval_launch_script_path = os.path.join( - lighteval_config.slurm_script_dir if lighteval_config.slurm_script_dir else "eval_results/launch-config", - str(current_step), - ) - eval_logs_path = os.path.join( - lighteval_config.checkpoints_path if lighteval_config.checkpoints_path else "eval_results/logs", - str(current_step), - ) - - # Create directories - os.makedirs(eval_launch_script_path, exist_ok=True) - os.makedirs(eval_logs_path, exist_ok=True) - # Calculate the number of nodes based on parallelism config total_gpus_needed = ( lighteval_config.parallelism.dp * lighteval_config.parallelism.pp * lighteval_config.parallelism.tp @@ -109,18 +95,30 @@ def run_slurm_one_job( # Get timestamp for log files timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - run_name = f"eval_{current_step}".replace(" ", "_") + run_name = f"{timestamp}-eval_{config.general.run}".replace(" ", "_") + + # Use lighteval config paths if available, otherwise use defaults + eval_launch_script_path = ( + lighteval_config.slurm_script_dir if lighteval_config.slurm_script_dir else "eval_results/launch-config" + ) + eval_logs_path = lighteval_config.checkpoints_path if lighteval_config.checkpoints_path else "eval_results/logs" + eval_launch_script_path = os.path.join(eval_launch_script_path, run_name) + eval_logs_path = os.path.join(eval_logs_path, run_name) + + # Create directories + os.makedirs(eval_launch_script_path, exist_ok=True) + os.makedirs(eval_logs_path, exist_ok=True) # Create log directory with run name subdirectory logs_path = os.path.join(eval_logs_path, run_name) os.makedirs(logs_path, exist_ok=True) # Use configured local path instead of hardcoded /tmp - local_path = os.path.join(lighteval_config.local_checkpoint_dir, f"eval_{config.general.run}", str(current_step)) + local_path = os.path.join(lighteval_config.local_checkpoint_dir, run_name, str(current_step)) # Create the SLURM script content slurm_script = f"""#!/bin/bash -#SBATCH --job-name={run_name} +#SBATCH --job-name=eval_{current_step}_{run_name} #SBATCH --partition={slurm_config.partition} #SBATCH --nodes={nodes} #SBATCH --ntasks-per-node=1 @@ -134,10 +132,10 @@ def run_slurm_one_job( if slurm_config.reservation: slurm_script += f"\n#SBATCH --reservation={slurm_config.reservation}" - # Add the rest of the script content + # Rest of the script content slurm_script += f""" -set -x -e +set -x LOCAL_DOWNLOAD_CHECKPOINT_FOLDER={local_path} @@ -246,15 +244,20 @@ def run_slurm_one_job( --master_port $MASTER_PORT \\ run_evals.py \\ --checkpoint-config-path $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml \\ - --lighteval-override {lighteval_config.eval_config_override}""" + --lighteval-override {lighteval_config.eval_config_override} + --cache-dir {slurm_config.hf_cache}""" slurm_script += """ +echo "Cleaning up downloaded checkpoints..." +rm -rf "$LOCAL_DOWNLOAD_CHECKPOINT_FOLDER" +echo "Cleanup completed" + echo "END TIME: $(date)" """ # Write the script to file current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - launch_script_path = os.path.join(eval_launch_script_path, f"launch_script-{current_time}.slurm") + launch_script_path = os.path.join(eval_launch_script_path, f"launch_script-{current_time}-.slurm") os.makedirs(os.path.dirname(launch_script_path), exist_ok=True) with open(launch_script_path, "w") as f: From 2227432c8038749979b0e857a30ac7519506ea1c Mon Sep 17 00:00:00 2001 From: elie <97572401+eliebak@users.noreply.github.com> Date: Thu, 17 Apr 2025 20:20:38 +0200 Subject: [PATCH 15/31] qos to low --- src/nanotron/config/lighteval_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index 98cf2d025..4ecd113bd 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -81,7 +81,7 @@ class LightEvalSlurm: partition: str = "hopper-prod" hf_cache: str = "~/.cache/huggingface" cpus_per_task: int = 88 - qos: str = "high" + qos: str = "low" time: str = "24:00:00" reservation: Optional[str] = "smollm" From b62cacd3681ea9438d67692e6dd0de039593f184 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Thu, 17 Apr 2025 18:52:43 +0000 Subject: [PATCH 16/31] add nanotron_path --- src/nanotron/config/lighteval_config.py | 1 + src/nanotron/eval/one_job_runner.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index 4ecd113bd..e35f29c61 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -107,6 +107,7 @@ class LightEvalConfig: logging: Optional[LightEvalLoggingArgs] = None wandb: Optional[LightEvalWandbLoggerConfig] = None slurm: Optional[LightEvalSlurm] = None + nanotron_path: Optional[str] = "./" eval_config_override: str = "smollm3_eval.yaml" # Previously hardcoded in run_slurm_one_job def __post_init__(self): diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py index 316c6c953..b01c98ef5 100644 --- a/src/nanotron/eval/one_job_runner.py +++ b/src/nanotron/eval/one_job_runner.py @@ -115,7 +115,7 @@ def run_slurm_one_job( # Use configured local path instead of hardcoded /tmp local_path = os.path.join(lighteval_config.local_checkpoint_dir, run_name, str(current_step)) - + nanotron_path = lighteval_config.nanotron_path # Create the SLURM script content slurm_script = f"""#!/bin/bash #SBATCH --job-name=eval_{current_step}_{run_name} @@ -242,7 +242,7 @@ def run_slurm_one_job( --node_rank $SLURM_PROCID \\ --master_addr $MASTER_ADDR \\ --master_port $MASTER_PORT \\ - run_evals.py \\ + {nanotron_path}/run_evals.py \\ --checkpoint-config-path $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml \\ --lighteval-override {lighteval_config.eval_config_override} --cache-dir {slurm_config.hf_cache}""" From 802fad67afd5122179646cbf30d2c0a5cccf14e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Thu, 17 Apr 2025 19:26:26 +0000 Subject: [PATCH 17/31] some fix: logs, and config --- src/nanotron/config/lighteval_config.py | 4 ++-- src/nanotron/eval/one_job_runner.py | 23 +++++++++-------------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index e35f29c61..3e919976f 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -97,8 +97,8 @@ class LightEvalConfig: the saved config when running LightEval after training. """ - slurm_script_dir: Optional[str] = None - checkpoints_path: Optional[str] = None + slurm_script_dir: Optional[str] = "eval_results/launch-config" + logs_path: Optional[str] = "eval_results/logs" local_checkpoint_dir: str = "/scratch" # Base directory for temporary checkpoint storage, will store under {local_checkpoint_dir}/{run_name}/{step} parallelism: Optional[ParallelismArgs] = None batch_size: Optional[int] = None diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py index b01c98ef5..12442fa98 100644 --- a/src/nanotron/eval/one_job_runner.py +++ b/src/nanotron/eval/one_job_runner.py @@ -95,24 +95,19 @@ def run_slurm_one_job( # Get timestamp for log files timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - run_name = f"{timestamp}-eval_{config.general.run}".replace(" ", "_") + general_run_name = config.general.run + run_name = f"{timestamp}-eval_{general_run_name}".replace(" ", "_") # Use lighteval config paths if available, otherwise use defaults - eval_launch_script_path = ( - lighteval_config.slurm_script_dir if lighteval_config.slurm_script_dir else "eval_results/launch-config" - ) - eval_logs_path = lighteval_config.checkpoints_path if lighteval_config.checkpoints_path else "eval_results/logs" - eval_launch_script_path = os.path.join(eval_launch_script_path, run_name) - eval_logs_path = os.path.join(eval_logs_path, run_name) + eval_launch_script_path = lighteval_config.slurm_script_dir + eval_logs_path = lighteval_config.logs_path + eval_launch_script_path = os.path.join(eval_launch_script_path, general_run_name, f"step-{current_step}") + eval_logs_path = os.path.join(eval_logs_path, general_run_name, f"step-{current_step}") # Create directories os.makedirs(eval_launch_script_path, exist_ok=True) os.makedirs(eval_logs_path, exist_ok=True) - # Create log directory with run name subdirectory - logs_path = os.path.join(eval_logs_path, run_name) - os.makedirs(logs_path, exist_ok=True) - # Use configured local path instead of hardcoded /tmp local_path = os.path.join(lighteval_config.local_checkpoint_dir, run_name, str(current_step)) nanotron_path = lighteval_config.nanotron_path @@ -127,7 +122,7 @@ def run_slurm_one_job( #SBATCH --exclusive #SBATCH --qos={slurm_config.qos} #SBATCH --time={slurm_config.time} -#SBATCH --output={logs_path}/{timestamp}-%x-%j.out""" +#SBATCH --output={eval_logs_path}/%j.out""" if slurm_config.reservation: slurm_script += f"\n#SBATCH --reservation={slurm_config.reservation}" @@ -257,7 +252,7 @@ def run_slurm_one_job( # Write the script to file current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - launch_script_path = os.path.join(eval_launch_script_path, f"launch_script-{current_time}-.slurm") + launch_script_path = os.path.join(eval_launch_script_path, f"launch_script-{current_time}.slurm") os.makedirs(os.path.dirname(launch_script_path), exist_ok=True) with open(launch_script_path, "w") as f: @@ -276,7 +271,7 @@ def run_slurm_one_job( output = result.stdout job_ids = output.split()[-1] - output_log = os.path.join(logs_path, f"{timestamp}-{run_name}-{job_ids}.out") + output_log = os.path.join(eval_logs_path, f"{timestamp}-{run_name}-{job_ids}.out") logger.warning( f"""🚀 Slurm job launched successfully: From 895354a8a330072b0886bd092be9c3534a37468e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Thu, 17 Apr 2025 19:47:00 +0000 Subject: [PATCH 18/31] cp instead of sync --- src/nanotron/eval/one_job_runner.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py index 12442fa98..13f773d29 100644 --- a/src/nanotron/eval/one_job_runner.py +++ b/src/nanotron/eval/one_job_runner.py @@ -122,7 +122,7 @@ def run_slurm_one_job( #SBATCH --exclusive #SBATCH --qos={slurm_config.qos} #SBATCH --time={slurm_config.time} -#SBATCH --output={eval_logs_path}/%j.out""" +#SBATCH --output={eval_logs_path}/%j-{timestamp}.out""" if slurm_config.reservation: slurm_script += f"\n#SBATCH --reservation={slurm_config.reservation}" @@ -182,9 +182,8 @@ def run_slurm_one_job( fi # Try sync command and check its exit status - s5cmd sync \\ + s5cmd cp \\ --concurrency=50 \\ - --size-only \\ --exclude "optimizer/*" \\ --exclude "random/*" \\ --exclude "lr_scheduler/*" \\ From 55a5d3e66fa64433cdd0c02c3f674da09c42c5d4 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 17 Apr 2025 20:13:29 +0000 Subject: [PATCH 19/31] eval_interval --- src/nanotron/config/config.py | 7 +++++ src/nanotron/config/lighteval_config.py | 22 ++++++++++--- src/nanotron/trainer.py | 41 +++++++++++++++++++++++-- 3 files changed, 63 insertions(+), 7 deletions(-) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 02c1067bc..e58e0174c 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -461,6 +461,13 @@ def __post_init__(self): if self.s3_upload is not None: self.s3_upload.__post_init__() + if self.lighteval is not None: + if self.lighteval.eval_interval is None: + self.lighteval.eval_interval = self.checkpoints.checkpoint_interval + else: + assert ( + self.lighteval.eval_interval % self.checkpoints.checkpoint_interval == 0 + ), f"eval_interval={self.lighteval.eval_interval} must be a multiple of checkpoint_interval={self.checkpoints.checkpoint_interval}" # Some final sanity checks across separate arguments sections: if self.profiler is not None and self.profiler.profiler_export_path is not None: diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index 3e919976f..a571380aa 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -97,9 +97,11 @@ class LightEvalConfig: the saved config when running LightEval after training. """ - slurm_script_dir: Optional[str] = "eval_results/launch-config" - logs_path: Optional[str] = "eval_results/logs" - local_checkpoint_dir: str = "/scratch" # Base directory for temporary checkpoint storage, will store under {local_checkpoint_dir}/{run_name}/{step} + slurm_script_dir: Optional[Path] = Path("eval_results/launch-config") + logs_path: Optional[Path] = Path("eval_results/logs") + local_checkpoint_dir: Path = Path( + "/scratch" + ) # Base directory for temporary checkpoint storage, will store under {local_checkpoint_dir}/{run_name}/{step} parallelism: Optional[ParallelismArgs] = None batch_size: Optional[int] = None generation: Optional[Union[GenerationArgs, Dict[str, GenerationArgs]]] = None @@ -107,8 +109,14 @@ class LightEvalConfig: logging: Optional[LightEvalLoggingArgs] = None wandb: Optional[LightEvalWandbLoggerConfig] = None slurm: Optional[LightEvalSlurm] = None - nanotron_path: Optional[str] = "./" - eval_config_override: str = "smollm3_eval.yaml" # Previously hardcoded in run_slurm_one_job + nanotron_path: Optional[Path] = "./" + eval_config_override: Path = Path("smollm3_eval.yaml") # Previously hardcoded in run_slurm_one_job + eval_interval: Optional[ + int + ] = None # Must be multiple of checkpoint_interval. If None, eval will be done after each checkpoint upload to s3 + eval_interval_file: Optional[ + Path + ] = None # If specified, eval_interval will be read from this file upon the next evaluation. def __post_init__(self): if self.parallelism is None: @@ -116,3 +124,7 @@ def __post_init__(self): if self.slurm is None: self.slurm = LightEvalSlurm() self.local_checkpoint_dir = str(Path(self.local_checkpoint_dir).expanduser()) + if self.eval_interval_file is not None and Path(self.eval_interval_file).exists(): + logger.warning( + f"Eval interval file {self.eval_interval_file} exists. `eval_interval` will be replaced by the value in the file upon the next evaluation. You should probably delete this file if that's not what you want." + ) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 743627835..bc72a5db3 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -1174,6 +1174,39 @@ def setup_log_writers( return loggerwriter def pre_save_checkpoint(self) -> Path: + # Check if eval_interval should be updated from file + eval_interval_file = self.config.lighteval.eval_interval_file + if eval_interval_file is not None and Path(eval_interval_file).exists(): + try: + with open(eval_interval_file, "r") as f: + new_eval_interval = int(f.read().strip()) + + # Verify that the new interval is a multiple of checkpoint_interval + if new_eval_interval == self.config.lighteval.eval_interval: + pass + elif new_eval_interval % self.config.checkpoints.checkpoint_interval == 0: + log_rank( + f"Updating lighteval.eval_interval from {self.config.lighteval.eval_interval} to {new_eval_interval}", + logger=logger, + level=logging.INFO, + rank=0, + ) + self.config.lighteval.eval_interval = new_eval_interval + else: + log_rank( + f"New eval_interval={new_eval_interval} must be a multiple of checkpoint_interval={self.config.checkpoints.checkpoint_interval}. Keeping current value: {self.config.lighteval.eval_interval}", + logger=logger, + level=logging.WARNING, + rank=0, + ) + except (ValueError, IOError) as e: + log_rank( + f"Error reading eval_interval from file: {e}. Keeping current value: {self.config.lighteval.eval_interval}", + logger=logger, + level=logging.WARNING, + rank=0, + ) + if self.s3_mover is not None: self.s3_mover.distributed_wait_for_completion(self.parallel_context.world_pg) if self.s3_mover.post_upload_callback_outputs is not None: @@ -1192,8 +1225,12 @@ def post_save_checkpoint(self): if dist.get_rank(self.parallel_context.world_pg) == 0: if self.config.lighteval is not None and self.s3_mover is None: - checkpoint_path = Path(self.config.checkpoints.checkpoints_path) / f"{self.config.general.step}" - self.lighteval_runner.eval_single_checkpoint(checkpoint_path) + if ( + self.config.lighteval.eval_interval is None + or self.iteration_step % self.config.lighteval.eval_interval == 0 + ): + checkpoint_path = Path(self.config.checkpoints.checkpoints_path) / f"{self.config.general.step}" + self.lighteval_runner.eval_single_checkpoint(checkpoint_path) def save_checkpoint(self) -> Path: self.pre_save_checkpoint() From 298492e0e8785b44c1429fccdb9306be40d8acdc Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 17 Apr 2025 20:18:15 +0000 Subject: [PATCH 20/31] serialize sanity checks --- src/nanotron/config/config.py | 5 +++-- src/nanotron/serialize/main.py | 2 +- src/nanotron/trainer.py | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index e58e0174c..c16f076c1 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -550,14 +550,15 @@ def global_batch_size(self): def global_batch_size_in_tokens(self): return self.global_batch_size * self.tokens.sequence_length - def save_as_yaml(self, file_path: str): + def save_as_yaml(self, file_path: str, sanity_checks: bool = True): config_dict = serialize(self) file_path = str(file_path) with open(file_path, "w") as f: yaml.dump(config_dict, f) # Sanity test config can be reloaded - _ = get_config_from_file(file_path, config_class=self.__class__) + if sanity_checks: + _ = get_config_from_file(file_path, config_class=self.__class__) def get_yaml(self): config_dict = serialize(self) diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index b1445b481..2b5d45585 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -64,7 +64,7 @@ def save( try: if should_save_config: - config.save_as_yaml(root_folder / "config.yaml") + config.save_as_yaml(root_folder / "config.yaml", sanity_checks=sanity_checks) except Exception as e: # TODO @nouamane: catch full disk error log_rank( diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index bc72a5db3..00c26943b 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -1267,6 +1267,7 @@ def save_checkpoint(self) -> Path: root_folder=checkpoint_path, training_metadata=self.metadata, config=self.config, + sanity_checks=not self.config.general.ignore_sanity_checks, ) save_random_states( random_states=self.random_states, parallel_context=self.parallel_context, root_folder=checkpoint_path From 4219ec8b42a6926b9e2f134af2ff6c6063cbee04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Thu, 17 Apr 2025 20:20:37 +0000 Subject: [PATCH 21/31] add output dir and s3_save path in the config --- src/nanotron/config/lighteval_config.py | 4 +++- src/nanotron/eval/one_job_runner.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index 3e919976f..36f0ed551 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -107,8 +107,10 @@ class LightEvalConfig: logging: Optional[LightEvalLoggingArgs] = None wandb: Optional[LightEvalWandbLoggerConfig] = None slurm: Optional[LightEvalSlurm] = None + s3_save_path: Optional[str] = None # should not be dependent of the run_name + output_dir: Optional[str] = None # we should sanity check that it's the same as the one in the eval_config_override nanotron_path: Optional[str] = "./" - eval_config_override: str = "smollm3_eval.yaml" # Previously hardcoded in run_slurm_one_job + eval_config_override: str = None def __post_init__(self): if self.parallelism is None: diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py index 13f773d29..2eb524e3c 100644 --- a/src/nanotron/eval/one_job_runner.py +++ b/src/nanotron/eval/one_job_runner.py @@ -240,8 +240,9 @@ def run_slurm_one_job( --checkpoint-config-path $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml \\ --lighteval-override {lighteval_config.eval_config_override} --cache-dir {slurm_config.hf_cache}""" - slurm_script += """ + slurm_script += f""" +s5cmd cp --if-size-differ "{lighteval_config.output_dir}*" {lighteval_config.s3_save_path} echo "Cleaning up downloaded checkpoints..." rm -rf "$LOCAL_DOWNLOAD_CHECKPOINT_FOLDER" echo "Cleanup completed" From 016760e1541a9061d94ff426eab72a3612c69bc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Thu, 17 Apr 2025 20:23:42 +0000 Subject: [PATCH 22/31] fix s3 only if define --- src/nanotron/eval/one_job_runner.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py index 2eb524e3c..4030b43ee 100644 --- a/src/nanotron/eval/one_job_runner.py +++ b/src/nanotron/eval/one_job_runner.py @@ -240,9 +240,11 @@ def run_slurm_one_job( --checkpoint-config-path $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml \\ --lighteval-override {lighteval_config.eval_config_override} --cache-dir {slurm_config.hf_cache}""" - slurm_script += f""" - + if lighteval_config.output_dir is not None and lighteval_config.s3_save_path is not None: + slurm_script += f""" s5cmd cp --if-size-differ "{lighteval_config.output_dir}*" {lighteval_config.s3_save_path} +""" + slurm_script += """ echo "Cleaning up downloaded checkpoints..." rm -rf "$LOCAL_DOWNLOAD_CHECKPOINT_FOLDER" echo "Cleanup completed" From 85138cabcefef40d13d62ad3ea8867098f797db1 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 17 Apr 2025 20:28:58 +0000 Subject: [PATCH 23/31] fixes --- src/nanotron/eval/one_job_runner.py | 8 ++++++++ src/nanotron/s3_checkpoints/s3_mover.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py index 13f773d29..51cf5351d 100644 --- a/src/nanotron/eval/one_job_runner.py +++ b/src/nanotron/eval/one_job_runner.py @@ -26,6 +26,14 @@ def __init__(self, config: Config, parallel_context: Optional[ParallelContext] = def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: """Run light evaluation on uploaded files.""" + if ( + self.config.lighteval.eval_interval is not None + and self.config.general.step % self.config.lighteval.eval_interval != 0 + ): + logger.debug( + f"Skipping evaluation at step {self.config.general.step} because eval_interval is {self.config.lighteval.eval_interval}" + ) + return config_files = [ f for f in uploaded_files if "config.py" in f["destination"] or "config.yaml" in f["destination"] ] diff --git a/src/nanotron/s3_checkpoints/s3_mover.py b/src/nanotron/s3_checkpoints/s3_mover.py index 483019d5a..32842a909 100644 --- a/src/nanotron/s3_checkpoints/s3_mover.py +++ b/src/nanotron/s3_checkpoints/s3_mover.py @@ -225,7 +225,7 @@ def distributed_wait_for_completion(self, group: Optional[ProcessGroup] = None): dist.all_gather(test_tensor_list, test_tensor, group=group, async_op=False) dist.barrier() all_saved = sum(bool(tensor.item()) for tensor in test_tensor_list) - time.sleep(1) + time.sleep(1) # TODO @nouamane: make this configurable def is_previous_save_finished(self) -> bool: """Return True if a potential previous checkpoint has been fully uploaded to S3 From fefb560c4f7e87d3e1d6ff50b9243f8b49842dd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Thu, 17 Apr 2025 20:47:04 +0000 Subject: [PATCH 24/31] add requeue --- src/nanotron/eval/one_job_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py index 4030b43ee..884a44505 100644 --- a/src/nanotron/eval/one_job_runner.py +++ b/src/nanotron/eval/one_job_runner.py @@ -122,7 +122,8 @@ def run_slurm_one_job( #SBATCH --exclusive #SBATCH --qos={slurm_config.qos} #SBATCH --time={slurm_config.time} -#SBATCH --output={eval_logs_path}/%j-{timestamp}.out""" +#SBATCH --output={eval_logs_path}/%j-{timestamp}.out +#SBATCH --requeue""" if slurm_config.reservation: slurm_script += f"\n#SBATCH --reservation={slurm_config.reservation}" From 45580361e6b6f44df210792dbe4024e810abf8d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Fri, 18 Apr 2025 23:16:02 +0000 Subject: [PATCH 25/31] add wandb with lighteval and fix eval interval --- src/nanotron/config/lighteval_config.py | 7 ++ src/nanotron/eval/one_job_runner.py | 30 ++++++--- src/nanotron/eval/upload_to_wandb.py | 87 +++++++++++++++++++++++++ 3 files changed, 116 insertions(+), 8 deletions(-) create mode 100644 src/nanotron/eval/upload_to_wandb.py diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index 363ee9887..1f0ffc650 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -110,6 +110,9 @@ class LightEvalConfig: wandb: Optional[LightEvalWandbLoggerConfig] = None slurm: Optional[LightEvalSlurm] = None s3_save_path: Optional[str] = None # should not be dependent of the run_name + upload_to_wandb: Optional[bool] = False + wandb_project: Optional[str] = None + wandb_entity: Optional[str] = None output_dir: Optional[str] = None # we should sanity check that it's the same as the one in the eval_config_override nanotron_path: Optional[str] = "./" eval_config_override: str = None @@ -127,6 +130,10 @@ def __post_init__(self): if self.slurm is None: self.slurm = LightEvalSlurm() self.local_checkpoint_dir = str(Path(self.local_checkpoint_dir).expanduser()) + if self.upload_to_wandb: + assert self.s3_save_path is not None, " We should have a s3_save_path if we want to upload to wandb" # todo: add the option to read from local folder i guess + assert self.wandb_project is not None, "wandb_project must be specified if upload_to_wandb is True" + assert self.wandb_entity is not None, "wandb_entity must be specified if upload_to_wandb is True" if self.eval_interval_file is not None and Path(self.eval_interval_file).exists(): logger.warning( f"Eval interval file {self.eval_interval_file} exists. `eval_interval` will be replaced by the value in the file upon the next evaluation. You should probably delete this file if that's not what you want." diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py index 884a44505..9f1f1f5a5 100644 --- a/src/nanotron/eval/one_job_runner.py +++ b/src/nanotron/eval/one_job_runner.py @@ -52,13 +52,16 @@ def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: logger.warning( f"Lighteval Runner got {len(uploaded_files)} files. Using {checkpoint_path} as checkpoint path." ) - - slurm_job_id, slurm_log = run_slurm_one_job( - config=self.config, - lighteval_config=self.lighteval_config, - model_checkpoint_path=checkpoint_path, - current_step=self.config.general.step, - ) + if self.config.general.step % self.lighteval_config.eval_interval == 0: + slurm_job_id, slurm_log = run_slurm_one_job( + config=self.config, + lighteval_config=self.lighteval_config, + model_checkpoint_path=checkpoint_path, + current_step=self.config.general.step, + ) + else: + logger.warning(f"Skipping evaluation at step {self.config.general.step} because it's not a multiple of {self.lighteval_config.eval_interval}") + return None, None return slurm_job_id, slurm_log @@ -243,7 +246,18 @@ def run_slurm_one_job( --cache-dir {slurm_config.hf_cache}""" if lighteval_config.output_dir is not None and lighteval_config.s3_save_path is not None: slurm_script += f""" -s5cmd cp --if-size-differ "{lighteval_config.output_dir}*" {lighteval_config.s3_save_path} +s5cmd cp --if-size-differ "{lighteval_config.output_dir}*" {lighteval_config.s3_save_path}/ +""" + if lighteval_config.upload_to_wandb: + gbs_tok = config.parallelism.dp * config.tokens.micro_batch_size * config.tokens.sequence_length * config.tokens.batch_accumulation_per_replica + slurm_script += f""" +python {nanotron_path}/src/nanotron/eval/upload_to_wandb.py \\ + --wandb_project {lighteval_config.wandb_project} \\ + --wandb_entity {lighteval_config.wandb_entity} \\ + --model_name {general_run_name} \\ + --results_path {lighteval_config.s3_save_path}/results/results/{general_run_name}/{current_step}/ \\ + --train_step {current_step} \\ + --consumed_tokens {current_step*gbs_tok} """ slurm_script += """ echo "Cleaning up downloaded checkpoints..." diff --git a/src/nanotron/eval/upload_to_wandb.py b/src/nanotron/eval/upload_to_wandb.py new file mode 100644 index 000000000..aa8c12d41 --- /dev/null +++ b/src/nanotron/eval/upload_to_wandb.py @@ -0,0 +1,87 @@ +import json +import s3fs +import wandb +import re +import argparse +from wandb.sdk.lib.runid import generate_id + + +def push_to_wandb(wandb_project, wandb_entity, model_name, results_path, train_step, consumed_tokens): + s3 = s3fs.S3FileSystem(anon=False) + all_metrics = { + # basic X axis replacements for all metrics + "consumed_tokens": consumed_tokens, + "train_step": train_step, + } + + for result_file in sorted(s3.ls(results_path)): + if not result_file.endswith(".json"): + continue + + with s3.open(result_file, "r") as f: + results = json.loads(f.read())["results"] + + for benchmark, metrics in results.items(): + if benchmark == "all": + continue + + # extract dataset and config name + match = re.search(r"\|(.*?)(?::(.*?))?\|", benchmark) + if match: + dataset, subtask = match.groups() + + for metric_name, metric_value in metrics.items(): + if "_stderr" in metric_name: + continue + # wandb-friendly metric name + wandb_metric = f"{dataset}/{subtask}/{metric_name}" if subtask else f"{dataset}/{metric_name}" + all_metrics[wandb_metric] = metric_value + + run_id = f"{model_name}-{generate_id()}" + + # try to find the run in wandb and resume it + api = wandb.Api() + runs = api.runs(f"{wandb_entity}/{wandb_project}") + for run in runs: + if run.name == model_name: + run_id = run.id + break + + wandb.init( + project=wandb_project, + entity=wandb_entity, + name=model_name, + id=run_id, + config={ + "model_name": model_name, + }, + resume="allow", + ) + + # log all metrics for this checkpoint + wandb.log(all_metrics) + + wandb.finish() + +if __name__ == "__main__": + # Setup argument parser + parser = argparse.ArgumentParser(description="Upload evaluation results to Weights & Biases.") + parser.add_argument("--wandb_project", type=str, required=True, help="WandB project name.") + parser.add_argument("--wandb_entity", type=str, required=True, help="WandB entity name.") + parser.add_argument("--model_name", type=str, required=True, help="Name of the model.") + parser.add_argument("--results_path", type=str, required=True, help="S3 path to the results directory.") + parser.add_argument("--train_step", type=int, required=True, help="Training step corresponding to the checkpoint.") + parser.add_argument("--consumed_tokens", type=int, required=True, help="Total consumed tokens up to this checkpoint.") + + # Parse arguments + args = parser.parse_args() + + # Call the main function with parsed arguments + push_to_wandb( + wandb_project=args.wandb_project, + wandb_entity=args.wandb_entity, + model_name=args.model_name, + results_path=args.results_path, + train_step=args.train_step, + consumed_tokens=args.consumed_tokens + ) From b5ea942207ee6b344a967c1706e5f97bac185323 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Sun, 20 Apr 2025 00:27:46 +0000 Subject: [PATCH 26/31] fix this little space :( --- src/nanotron/eval/one_job_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py index b0c70ea82..8739a31a9 100644 --- a/src/nanotron/eval/one_job_runner.py +++ b/src/nanotron/eval/one_job_runner.py @@ -263,7 +263,7 @@ def run_slurm_one_job( --wandb_project {lighteval_config.wandb_project} \\ --wandb_entity {lighteval_config.wandb_entity} \\ --model_name {general_run_name} \\ - --results_path {lighteval_config.s3_save_path}/results/results/{general_run_name}/{current_step}/ \\ + --results_path {lighteval_config.s3_save_path}/results/results/{general_run_name}/{current_step}/ \\ --train_step {current_step} \\ --consumed_tokens {current_step*gbs_tok} """ From 561ca6b71deaa42f02f634f1ef5b1e37202886b1 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Wed, 23 Apr 2025 19:37:56 +0000 Subject: [PATCH 27/31] folder_path should always have s3 when using s3 (fix consumed tokens issue) --- src/nanotron/data/nemo_dataset/blendable_dataset.py | 3 +++ src/nanotron/data/tokenized_bytes.py | 7 ++++--- src/nanotron/trainer.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/nanotron/data/nemo_dataset/blendable_dataset.py b/src/nanotron/data/nemo_dataset/blendable_dataset.py index e4e999169..8e59f7e1c 100644 --- a/src/nanotron/data/nemo_dataset/blendable_dataset.py +++ b/src/nanotron/data/nemo_dataset/blendable_dataset.py @@ -181,6 +181,9 @@ def get_consumption_stats(self): """ stats = {} for dataset_idx, dataset in enumerate(self.datasets): + assert ( + "s3" in dataset.folder_path + ), "Only S3 paths are supported for consumption stats" # TODO: remove this stats[dataset.folder_path] = {"tokens": self.consumed_tokens[dataset_idx]} return stats diff --git a/src/nanotron/data/tokenized_bytes.py b/src/nanotron/data/tokenized_bytes.py index 4f9063eb6..880983bb6 100644 --- a/src/nanotron/data/tokenized_bytes.py +++ b/src/nanotron/data/tokenized_bytes.py @@ -369,12 +369,13 @@ def __init__( ) from datatrove.utils.dataset import url_to_fs - fs_folder, folder_path = url_to_fs(folder_path) + fs_folder, stripped_folder_path = url_to_fs(folder_path) matched_files = ( - fs_folder.find(folder_path, detail=False, maxdepth=1 if not recursive else None) + fs_folder.find(stripped_folder_path, detail=False, maxdepth=1 if not recursive else None) if not filename_pattern else fs_folder.glob( - os.path.join(folder_path, filename_pattern), maxdepth=1 if not recursive else None + os.path.join(stripped_folder_path, filename_pattern), + maxdepth=1 if not recursive else None, ) ) matched_files = sorted(matched_files) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 00c26943b..d86e397da 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -1175,7 +1175,7 @@ def setup_log_writers( def pre_save_checkpoint(self) -> Path: # Check if eval_interval should be updated from file - eval_interval_file = self.config.lighteval.eval_interval_file + eval_interval_file = self.config.lighteval.eval_interval_file if self.config.lighteval is not None else None if eval_interval_file is not None and Path(eval_interval_file).exists(): try: with open(eval_interval_file, "r") as f: From 7724cf1acf1a9c6ea4b170290a3cd584b4e1b02f Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 24 Apr 2025 12:14:02 +0000 Subject: [PATCH 28/31] config qwen --- examples/config_qwen.py | 16 ++++++++++------ examples/config_qwen.yaml | 24 ++++++++++++------------ 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/examples/config_qwen.py b/examples/config_qwen.py index 8ca8487b3..1e8bc50db 100644 --- a/examples/config_qwen.py +++ b/examples/config_qwen.py @@ -30,7 +30,7 @@ "410m": (24, 1024, 16, 16, 4096), # ~410M params # Small to medium models "1b": (16, 2048, 16, 16, 5632), # ~1B params - "3b": (28, 2048, 16, 2, 11008), # ~3B params + "3b": (36, 2048, 16, 4, 11008), # ~3B params # Standard sizes "7b": (32, 4096, 32, 32, 11008), # ~7B params "13b": (40, 5120, 40, 40, 13824), # ~13B params @@ -47,7 +47,7 @@ def get_args(): parser.add_argument( "--model", choices=MODEL_SIZES.keys(), - default="custom", + default="3b", help="Model size to generate config for (e.g., 7b, 13b)", ) parser.add_argument( @@ -76,6 +76,10 @@ def get_args(): tokens_group.add_argument("--mbs", type=int, default=3, help="Micro batch size") tokens_group.add_argument("--acc", type=int, default=1, help="Batch accumulation per replica") + # checkpoints + checkpoints_group = parser.add_argument_group("checkpoints") + checkpoints_group.add_argument("--ckpt-save", type=int, default=10, help="Checkpoint save interval") + args = parser.parse_args() return args @@ -108,7 +112,7 @@ def get_model_config(model_size: str) -> Qwen2Config: is_qwen2_config=True, pad_token_id=None, _attn_implementation="flash_attention_2", - # sliding_window_size=20, + _use_doc_masking=True, ) @@ -154,7 +158,7 @@ def calculate_parameters(model_config: Qwen2Config) -> str: def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config: learning_rate = LRSchedulerArgs( - learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5 + learning_rate=3e-4, lr_warmup_steps=2000, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=0 ) parallelism = ParallelismArgs( dp=args.dp, @@ -175,7 +179,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config ) optimizer = OptimizerArgs( zero_stage=args.zero, - weight_decay=0.01, + weight_decay=0.1, clip_grad=1.0, accumulate_grad_in_fp32=True, learning_rate_scheduler=learning_rate, @@ -192,7 +196,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config return Config( general=GeneralArgs(project="debug", run=args.run, seed=seed, ignore_sanity_checks=args.no_sanity), - checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10), + checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=args.ckpt_save), parallelism=parallelism, model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config), # tokenizer=TokenizerArgs("HuggingFaceTB/cosmo2-tokenizer"), diff --git a/examples/config_qwen.yaml b/examples/config_qwen.yaml index a2ce9bd14..cf6f40fac 100644 --- a/examples/config_qwen.yaml +++ b/examples/config_qwen.yaml @@ -1,5 +1,5 @@ checkpoints: - checkpoint_interval: 10 + checkpoint_interval: 100000 checkpoints_path: checkpoints checkpoints_path_is_shared_file_system: false load_lr_scheduler: true @@ -30,9 +30,9 @@ data_stages: general: benchmark_csv_path: null consumed_train_samples: null - ignore_sanity_checks: false + ignore_sanity_checks: true project: debug - run: qwen_20250423_201000_16423158 + run: qwen_20250424_120835_16423158 seed: 42 step: null lighteval: null @@ -50,24 +50,24 @@ model: make_vocab_size_divisible_by: 1 model_config: _attn_implementation: flash_attention_2 - _fused_rms_norm: false - _fused_rotary_emb: false - _use_doc_masking: false - _use_qkv_packed: false + _fused_rms_norm: true + _fused_rotary_emb: true + _use_doc_masking: true + _use_qkv_packed: true attention_bias: false bos_token_id: 1 eos_token_id: 2 flex_attention_mask: null hidden_act: silu - hidden_size: 256 + hidden_size: 2048 initializer_range: 0.02 - intermediate_size: 768 + intermediate_size: 11008 is_qwen2_config: true max_position_embeddings: 4096 moe_config: null no_rope_layer: null - num_attention_heads: 4 - num_hidden_layers: 12 + num_attention_heads: 16 + num_hidden_layers: 36 num_key_value_heads: 4 pad_token_id: null pretraining_tp: 1 @@ -108,7 +108,7 @@ parallelism: pp: 1 pp_engine: 1f1b recompute_layer: false - tp: 1 + tp: 2 tp_linear_async_communication: true tp_mode: REDUCE_SCATTER tp_recompute_allgather: true From 46949b6c7f5e8f4ef99e166b9ebc29d5e56292d6 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Thu, 24 Apr 2025 12:14:11 +0000 Subject: [PATCH 29/31] . --- examples/config_qwen.py | 6 +++++- src/nanotron/data/nemo_dataset/blendable_dataset.py | 3 --- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/config_qwen.py b/examples/config_qwen.py index 1e8bc50db..a5d901b24 100644 --- a/examples/config_qwen.py +++ b/examples/config_qwen.py @@ -223,7 +223,11 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config world_size = args.dp * args.tp * args.pp * args.cp if world_size <= 8: print( - f"CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={world_size} run_train.py --config-file {args.out}" + f"ENABLE_TIMERS=1 DEBUG_CPU=1 STATS_SAMPLING_INTERVAL_IN_SEC=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={world_size} run_train.py --config-file {args.out}" ) + print("You can also use environment variables for more debugging:") + print(" - ENABLE_TIMERS=1: Enable detailed timing information") + print(" - DEBUG_CPU=1: Log CPU and memory usage statistics") + print(" - STATS_SAMPLING_INTERVAL_IN_SEC=1: Set sampling interval for metrics collection") else: print("Checkout slurm_launcher.py to launch a multi-node job") diff --git a/src/nanotron/data/nemo_dataset/blendable_dataset.py b/src/nanotron/data/nemo_dataset/blendable_dataset.py index 8e59f7e1c..e4e999169 100644 --- a/src/nanotron/data/nemo_dataset/blendable_dataset.py +++ b/src/nanotron/data/nemo_dataset/blendable_dataset.py @@ -181,9 +181,6 @@ def get_consumption_stats(self): """ stats = {} for dataset_idx, dataset in enumerate(self.datasets): - assert ( - "s3" in dataset.folder_path - ), "Only S3 paths are supported for consumption stats" # TODO: remove this stats[dataset.folder_path] = {"tokens": self.consumed_tokens[dataset_idx]} return stats From 10784044b492a0c345f552fd33492330d1aa58cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hynek=20Kydl=C3=AD=C4=8Dek?= Date: Thu, 8 May 2025 22:13:23 +0200 Subject: [PATCH 30/31] fix makefile, sync with datatrove, update lighteval config (#363) * fix makefile, sync with datatrove, update lighteval config * fix path --------- Co-authored-by: Hynek Kydlicek --- src/nanotron/config/config.py | 2 +- src/nanotron/config/lighteval_config.py | 7 +- src/nanotron/config/models_config.py | 2 +- src/nanotron/data/nanoset.py | 4 +- src/nanotron/data/nemo_dataset/Makefile | 3 +- src/nanotron/data/tokenized_bytes.py | 181 +++++++----------------- src/nanotron/eval/one_job_runner.py | 26 ++-- 7 files changed, 68 insertions(+), 157 deletions(-) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index c16f076c1..a5e8e6eb5 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -632,7 +632,7 @@ def get_config_from_dict( InitScalingMethod: lambda x: InitScalingMethod[x.upper()], SamplerType: lambda x: SamplerType[x.upper()], }, - # strict_unions_match=True, + strict_unions_match=True, strict=True, ), ) diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index 0806acffe..72852831e 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -79,11 +79,11 @@ class LightEvalSlurm: gpus_per_node: int = 8 partition: str = "hopper-prod" - hf_cache: str = "~/.cache/huggingface" + hf_cache: Optional[str] = None cpus_per_task: int = 88 qos: str = "low" time: str = "24:00:00" - reservation: Optional[str] = "smollm" + reservation: Optional[str] = None def __post_init__(self): self.hf_cache = str(Path(self.hf_cache).expanduser()) @@ -117,8 +117,7 @@ class LightEvalConfig: str ] = None # we should sanity check that it's the same as the one in the eval_config_override nanotron_path: Optional[str] = "./" - eval_config_override: str = None - eval_config_override: Path = None # Previously hardcoded in run_slurm_one_job + lighteval_config_path: Path = None # Previously hardcoded in run_slurm_one_job eval_interval: Optional[ int ] = None # Must be multiple of checkpoint_interval. If None, eval will be done after each checkpoint upload to s3 diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 999d13379..8d0b6b005 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -279,4 +279,4 @@ def n_inner(self): return self.intermediate_size -NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Qwen2Config, Any] +NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Qwen2Config] diff --git a/src/nanotron/data/nanoset.py b/src/nanotron/data/nanoset.py index 026970ccb..057b2e129 100644 --- a/src/nanotron/data/nanoset.py +++ b/src/nanotron/data/nanoset.py @@ -65,14 +65,14 @@ def __init__( for dataset_folder in self.dataset_folders: self.datatrove_datasets.append( DatatroveFolderDataset( - folder_path=dataset_folder, + data_folder=dataset_folder, filename_pattern=os.path.join(dataset_folder, "*.ds"), seq_len=sequence_length, recursive=False, token_size=self.token_size, shuffle=True, return_positions=self.return_positions, # if set to True, the position ids are directly build datatrove - eos_token_id=self.eos_token_id, + positions_from_eos_token_id=self.eos_token_id, ) ) diff --git a/src/nanotron/data/nemo_dataset/Makefile b/src/nanotron/data/nemo_dataset/Makefile index 150939026..47805f6a9 100644 --- a/src/nanotron/data/nemo_dataset/Makefile +++ b/src/nanotron/data/nemo_dataset/Makefile @@ -15,7 +15,8 @@ CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color CPPFLAGS += $(shell python3 -m pybind11 --includes) LIBNAME = helpers -LIBEXT = $(shell python3-config --extension-suffix) +# Works with uv too (unlike python3-config --extension-suffix) +LIBEXT = $(shell python3 -c "import sysconfig; print(sysconfig.get_config_var('EXT_SUFFIX'))") default: $(LIBNAME)$(LIBEXT) diff --git a/src/nanotron/data/tokenized_bytes.py b/src/nanotron/data/tokenized_bytes.py index 880983bb6..f0c927b50 100644 --- a/src/nanotron/data/tokenized_bytes.py +++ b/src/nanotron/data/tokenized_bytes.py @@ -23,6 +23,8 @@ from nanotron.logging import human_format, log_rank from nanotron.parallel import ParallelContext from nanotron.utils import main_rank_first +from huggingface_hub import cached_assets_path + try: tb_logger_available = True @@ -302,15 +304,13 @@ def __init__( filename_pattern: str = None, recursive: bool = True, token_size: int = 2, - max_tokens: int | None = None, shuffle: bool = False, seed: int = 42, return_positions: bool = False, eos_token_id: int | None = None, skip_in_stream: bool = True, num_samples: Optional[int] = None, - folder_read_path: Optional[str] = None, - force_update_cache: bool = os.environ.get("FORCE_UPDATE_CACHE_S3", 0) == "1", + paths_file: str | None = None, ): log_rank("Using DatatroveFolderDataset", logger=logger, level=logging.INFO, rank=0) if return_positions and not eos_token_id: @@ -320,117 +320,17 @@ def __init__( level=logging.WARNING, rank=0, ) - - # Handle S3 paths specially - matched_files = None - file_sizes = None - if folder_path.startswith("s3://"): - cache_dir = os.path.expanduser("~/.cache/nanotron/s3_cache") - os.makedirs(cache_dir, exist_ok=True) - - # Create a unique cache key based on the folder path and pattern - import hashlib - - cache_key = hashlib.md5(f"{folder_path}:{filename_pattern}:{recursive}".encode()).hexdigest() - cache_file = os.path.join(cache_dir, f"{cache_key}.cache") - - with main_rank_first(): - if dist.get_rank() == 0: - # Check if we have a valid cache - if os.path.exists(cache_file) and not force_update_cache: - try: - import pickle - - with open(cache_file, "rb") as f: - cached_data = pickle.load(f) - matched_files = cached_data["matched_files"] - file_sizes = cached_data["file_sizes"] - log_rank( - "[TokenizedBytesFolderDataset] Using cached S3 file list", - logger=logger, - level=logging.INFO, - rank=0, - ) - except Exception as e: - log_rank( - f"[TokenizedBytesFolderDataset] Failed to load cache, fetching from S3: {e}", - logger=logger, - level=logging.WARNING, - rank=0, - ) - - # If no cache or cache invalid, fetch from S3 - if matched_files is None: - log_rank( - "[TokenizedBytesFolderDataset] Fetching file list from S3...", - logger=logger, - level=logging.INFO, - rank=0, - ) - from datatrove.utils.dataset import url_to_fs - - fs_folder, stripped_folder_path = url_to_fs(folder_path) - matched_files = ( - fs_folder.find(stripped_folder_path, detail=False, maxdepth=1 if not recursive else None) - if not filename_pattern - else fs_folder.glob( - os.path.join(stripped_folder_path, filename_pattern), - maxdepth=1 if not recursive else None, - ) - ) - matched_files = sorted(matched_files) - - # Get file sizes - file_sizes = {} - for path in matched_files: - file_path = fs_folder.unstrip_protocol(path) - fs, file_path = url_to_fs(file_path) - file_sizes[file_path] = fs.size(file_path) - - # Save to cache - try: - import pickle - - with open(cache_file, "wb") as f: - pickle.dump({"matched_files": matched_files, "file_sizes": file_sizes}, f) - log_rank( - "[TokenizedBytesFolderDataset] Saved S3 file list to cache", - logger=logger, - level=logging.INFO, - rank=0, - ) - except Exception as e: - log_rank( - f"[TokenizedBytesFolderDataset] Failed to save cache: {e}", - logger=logger, - level=logging.WARNING, - rank=0, - ) - if dist.get_rank() != 0: - try: - import pickle - - with open(cache_file, "rb") as f: - cached_data = pickle.load(f) - matched_files = cached_data["matched_files"] - file_sizes = cached_data["file_sizes"] - except Exception as e: - raise RuntimeError(f"Failed to read cache file on rank {dist.get_rank()}: {e}") - super().__init__( - folder_path=folder_path, + data_folder=folder_path, seq_len=seq_len, filename_pattern=filename_pattern, recursive=recursive, token_size=token_size, - max_tokens=max_tokens, shuffle=shuffle, seed=seed, return_positions=return_positions, - eos_token_id=eos_token_id, - read_path=folder_read_path, - matched_files=matched_files, - file_sizes=file_sizes, + positions_from_eos_token_id=eos_token_id, + paths_file=paths_file, ) self.subset_log = TBFolderDatasetLog( @@ -500,20 +400,42 @@ def build_dataset( dtype=np.uint16 if token_size == 2 else np.uint32, ) + paths_file = None + if folder_read_path: + paths_file = ( + cached_assets_path(library_name="nanotron", namespace="path_files") + / f"{dataset_folder.replace('/', '_')}_paths.json" + ).as_posix() + # This ensures that the paths file is created BASED on the datasef_folder + DatatroveFolderDataset( + data_folder=dataset_folder, + filename_pattern="*.ds", + seq_len=seq_length, + recursive=False, + token_size=token_size, + shuffle=False, + return_positions=return_positions, + positions_from_eos_token_id=eos_token_id, + seed=seed, + paths_file=paths_file, + ) + + # From now on, we use the folder_read_path to read the dataset + dataset_folder = folder_read_path + return TokenizedBytesFolderDataset( folder_path=dataset_folder, filename_pattern="*.ds", seq_len=seq_length, recursive=False, token_size=token_size, - max_tokens=max_tokens, shuffle=shuffle, return_positions=return_positions, # if set to True, the position ids are directly read from datatrove eos_token_id=eos_token_id, seed=seed, skip_in_stream=skip_in_stream, num_samples=num_samples, - folder_read_path=folder_read_path, + paths_file=paths_file, ) @@ -561,29 +483,26 @@ def get_tb_datasets( for i, (dataset_folder, max_tokens) in enumerate(zip(config.dataset_folder, dataset_max_tokens)) ] - if len(datasets) == 1 and False: - outputs_dataset = datasets[0] - else: - if dist.get_rank(parallel_context.world_pg) == 0: - try: - compile_helper() - except ImportError: - raise ImportError( - "Could not compile megatron dataset C++ helper functions and therefore cannot import helpers python file." - ) - dist.barrier(parallel_context.world_pg) - weights = config.dataset_weights - if not weights: - weights = [1] * len(datasets) - - outputs_dataset = BlendableDataset( - datasets, - weights, - train_num_samples, - parallel_context=parallel_context, - seed=seed, - consumed_tokens_per_dataset_folder=consumed_tokens_per_dataset_folder, - ) + if dist.get_rank(parallel_context.world_pg) == 0: + try: + compile_helper() + except ImportError: + raise ImportError( + "Could not compile megatron dataset C++ helper functions and therefore cannot import helpers python file." + ) + dist.barrier(parallel_context.world_pg) + weights = config.dataset_weights + if not weights: + weights = [1] * len(datasets) + + outputs_dataset = BlendableDataset( + datasets, + weights, + train_num_samples, + parallel_context=parallel_context, + seed=seed, + consumed_tokens_per_dataset_folder=consumed_tokens_per_dataset_folder, + ) log_rank("Streamable datasets ready.", logger=logger, level=logging.INFO, rank=0) train_data_log = TrainDataLog( diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py index 6567ec94a..96bbf90d0 100644 --- a/src/nanotron/eval/one_job_runner.py +++ b/src/nanotron/eval/one_job_runner.py @@ -160,25 +160,17 @@ def run_slurm_one_job( export MASTER_PORT=6000 export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` -# Hugging Face token setup -if [ -z "$HUGGING_FACE_HUB_TOKEN" ]; then - if TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null); then - export HUGGING_FACE_HUB_TOKEN=$TOKEN - else - echo "Error: The environment variable HUGGING_FACE_HUB_TOKEN is not set and the token cache could not be read." - exit 1 - fi -fi - # Set environment variables export CUDA_DEVICE_MAX_CONNECTIONS=1 # export CUBLAS_WORKSPACE_CONFIG=":4096:8" # Set HuggingFace cache locations -export HUGGINGFACE_HUB_CACHE={slurm_config.hf_cache} -export HF_DATASETS_CACHE={slurm_config.hf_cache} -export HF_MODULES_CACHE={slurm_config.hf_cache} -export HF_HOME={slurm_config.hf_cache} +if [ -n "{slurm_config.hf_cache}" ] && [ "{slurm_config.hf_cache}" != "None" ]; then + export HUGGINGFACE_HUB_CACHE={slurm_config.hf_cache} + export HF_DATASETS_CACHE={slurm_config.hf_cache} + export HF_MODULES_CACHE={slurm_config.hf_cache} + export HF_HOME={slurm_config.hf_cache} +fi echo "Running on $COUNT_NODE nodes: $HOSTNAMES" @@ -250,10 +242,10 @@ def run_slurm_one_job( --node_rank $SLURM_PROCID \\ --master_addr $MASTER_ADDR \\ --master_port $MASTER_PORT \\ - {nanotron_path}/run_evals.py \\ + -m lighteval nanotron \\ --checkpoint-config-path $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml \\ - --lighteval-override {lighteval_config.eval_config_override} - --cache-dir {slurm_config.hf_cache}""" + --lighteval-config-path {lighteval_config.lighteval_config_path} + """ if lighteval_config.output_dir is not None and lighteval_config.s3_save_path is not None: slurm_script += f""" s5cmd cp --if-size-differ "{lighteval_config.output_dir}*" {lighteval_config.s3_save_path}/ From bb33a521407379633927598c292c9a246bf61720 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Mon, 23 Jun 2025 17:38:12 +0200 Subject: [PATCH 31/31] Nouamane/lighteval fix (#360) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Nouamane/lighteval (#356) * InitScalingMethod * InitScalingMethod * eval * try adding lightevalrunner to trainer * amend * amend * amend * amend * amend * amend * . * amend * amend * . * qos to low * add nanotron_path * some fix: logs, and config * cp instead of sync * eval_interval * serialize sanity checks * add output dir and s3_save path in the config * fix s3 only if define * fixes * add requeue * add wandb with lighteval and fix eval interval * fix this little space :( * folder_path should always have s3 when using s3 (fix consumed tokens issue) * config qwen * . --------- Co-authored-by: elie <97572401+eliebak@users.noreply.github.com> Co-authored-by: “eliebak” * fix inference in case of varlen (input with paddings) * . * legacy * remove bos token * max-micro-batch * separate inference from training * use_decode_text * add no use cache case to decode_tokenized --------- Co-authored-by: elie <97572401+eliebak@users.noreply.github.com> Co-authored-by: “eliebak” --- run_generate.py | 99 +++++--- src/nanotron/config/parallelism_config.py | 2 + src/nanotron/data/sft_processing.py | 4 +- src/nanotron/data/tokenized_bytes.py | 97 ++++++++ src/nanotron/generation/decode.py | 68 +++++- src/nanotron/models/qwen.py | 274 ++++++++++++++++++---- src/nanotron/nn/rotary.py | 29 +++ 7 files changed, 475 insertions(+), 98 deletions(-) diff --git a/run_generate.py b/run_generate.py index e21fe7e22..405748e3b 100644 --- a/run_generate.py +++ b/run_generate.py @@ -5,6 +5,14 @@ ``` export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations torchrun --nproc_per_node=1 run_generate.py --ckpt-path checkpoints/10 +torchrun --nproc_per_node=2 run_generate.py --ckpt-path /scratch/1044000 + +torchrun --rdzv_endpoint=127.0.0.1:12357 --nproc_per_node=2 run_generate.py --ckpt-path /scratch/1044000 --use-cache --max-micro-batch-size 2 +export CUDA_VISIBLE_DEVICES=2,3 +torchrun --rdzv_endpoint=127.0.0.1:12356 --nproc_per_node=2 run_generate.py --ckpt-path /scratch/1044000 --use-cache +export CUDA_VISIBLE_DEVICES=4,5 +torchrun --rdzv_endpoint=127.0.0.1:12355 --nproc_per_node=2 run_generate.py --ckpt-path /scratch/1044000 --max-micro-batch-size 2 --use-decode-tokenized +torchrun --rdzv_endpoint=127.0.0.1:12355 --nproc_per_node=2 run_generate.py --ckpt-path /scratch/1044000 --use-decode-tokenized ``` """ @@ -45,10 +53,7 @@ from nanotron.serialize import load_weights from nanotron.trainer import CONFIG_TO_MODEL_CLASS, mark_tied_parameters -try: - from transformers import AutoTokenizer -except ImportError: - AutoTokenizer = None +from transformers import AutoTokenizer # import lovely_tensors as lt @@ -65,6 +70,10 @@ def get_args(): parser.add_argument("--tp", type=int, default=0) parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate") parser.add_argument("--use-cache", action="store_true", help="Use KV cache to speed up generation") + parser.add_argument( + "--max-micro-batch-size", type=int, default=1, help="Maximum number of micro batches to generate" + ) + parser.add_argument("--use-decode-tokenized", action="store_true", help="Use decode_tokenized to generate text") return parser.parse_args() @@ -73,6 +82,17 @@ def main(): assert args.ckpt_path.exists(), f"Checkpoint path {args.ckpt_path} does not exist" + dummy_inputs = [ + # "The future of AI is", + # "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", + "def fib(n)", + # 'Here is an extract from a webpage: "Have you ever experienced heel pain after a heavy physical activity, or even right after a long period of standing? If you regard this as something usual and normal, then think again. Miscalled as heel pain, plantar fasciitis causes these frequent mild pains experienced in the soles of the feet. It is the inflammation and enlargement the plantar fascia tissue that is located in the heels of the feet, stretching to the base of the toes. This tissue is responsible for absorbing shock in the feet and for supporting the arches. It also plays a vital role in foot movements during walking and standing. Many factors such as excessive walking, standing, and running trigger heel pain and plantar fasciitis. A sudden increase in intensity of activities, increase in weight, and abrupt change of footwear also cause the swelling of the ligament. Non-supportive footwear lacking arch cushions and improper and worn out running or training can also lead to the problem. It is also most evident among those". Write an extensive and detailed course unit suitable for a textbook targeted at college students, related to the given extract, within the context of "Medicine". Do not just list concepts, but develop each one in detail before moving to the next, as we prioritize depth of understanding and comprehensive exploration of the subject matter over breadth. Focus on: - Rigor: Ensure in-depth coverage of the concepts/sections. - Engagement: Write with an academic, professional and engaging tone that captivates interest. - Application: Incorporate specific, practical examples, such as proofs in calculus or critical dates and figures in history. Do not include a title or an introduction, simply write the content without headlines and introductory phrases. Do not use images.', + # "Advancements in technology will lead to", + # "Tomorrow's world is shaped by", + # "What is the meaning of the word chutzpah?\nThe word chutzpah means", + ] + + config = get_config_from_file((args.ckpt_path / "config.yaml").as_posix()) model_config = config.model.model_config tokenizer_path = config.tokenizer.tokenizer_name_or_path @@ -154,36 +174,29 @@ def main(): load_weights(model=model, parallel_context=parallel_context, root_folder=checkpoint_path) model.eval() - if AutoTokenizer is not None: - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) - # tokenizer.pad_token_id = tokenizer.eos_token_id - if tokenizer.pad_token_id is None: - if tokenizer.eos_token_id is not None: - tokenizer.pad_token_id = tokenizer.eos_token_id - elif getattr(model.config, "pad_token_id", None) is not None: - tokenizer.pad_token_id = int(model.config.pad_token_id) - elif getattr(model.config, "eos_token_id", None) is not None: - tokenizer.pad_token_id = int(model.config.eos_token_id) - else: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - tokenizer.padding_side = "left" - tokenizer.truncation_side = "left" # TODO @nouamane: do we want this? - dummy_inputs = [ - # "The future of AI is", - "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", - "def fib(n)", - # 'Here is an extract from a webpage: "Have you ever experienced heel pain after a heavy physical activity, or even right after a long period of standing? If you regard this as something usual and normal, then think again. Miscalled as heel pain, plantar fasciitis causes these frequent mild pains experienced in the soles of the feet. It is the inflammation and enlargement the plantar fascia tissue that is located in the heels of the feet, stretching to the base of the toes. This tissue is responsible for absorbing shock in the feet and for supporting the arches. It also plays a vital role in foot movements during walking and standing. Many factors such as excessive walking, standing, and running trigger heel pain and plantar fasciitis. A sudden increase in intensity of activities, increase in weight, and abrupt change of footwear also cause the swelling of the ligament. Non-supportive footwear lacking arch cushions and improper and worn out running or training can also lead to the problem. It is also most evident among those". Write an extensive and detailed course unit suitable for a textbook targeted at college students, related to the given extract, within the context of "Medicine". Do not just list concepts, but develop each one in detail before moving to the next, as we prioritize depth of understanding and comprehensive exploration of the subject matter over breadth. Focus on: - Rigor: Ensure in-depth coverage of the concepts/sections. - Engagement: Write with an academic, professional and engaging tone that captivates interest. - Application: Incorporate specific, practical examples, such as proofs in calculus or critical dates and figures in history. Do not include a title or an introduction, simply write the content without headlines and introductory phrases. Do not use images.', - # "Advancements in technology will lead to", - # "Tomorrow's world is shaped by", - ] + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + # tokenizer.pad_token_id = tokenizer.eos_token_id + if tokenizer.pad_token_id is None: + if tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + elif getattr(model.config, "pad_token_id", None) is not None: + tokenizer.pad_token_id = int(model.config.pad_token_id) + elif getattr(model.config, "eos_token_id", None) is not None: + tokenizer.pad_token_id = int(model.config.eos_token_id) + else: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + tokenizer.padding_side = "left" + tokenizer.truncation_side = "left" # TODO @nouamane: do we want this? + + if not args.use_decode_tokenized: outputs = decode_text( input_iter=(GenerationInput(text=text) for text in dummy_inputs), tokenizer=tokenizer, model=model.model, parallel_context=parallel_context, max_new_tokens=args.max_new_tokens, - max_micro_batch_size=2, + max_micro_batch_size=args.max_micro_batch_size, generation_config=GenerationArgs(sampler="greedy", use_cache=args.use_cache), tokenizer_config=TokenizerConfig(max_input_length=None), is_bench=os.environ.get("USE_BENCH", "0") == "1", @@ -217,15 +230,27 @@ def main(): rank=0, ) else: + # Tokenize dummy inputs + tokenized_inputs = tokenizer( + dummy_inputs, + padding=True, + truncation=True, + return_tensors="pt", + add_special_tokens=False, # TODO: this is important to avoid adding bos token to the input + ) + input_ids = tokenized_inputs["input_ids"].to(device="cuda") + attention_mask = tokenized_inputs["attention_mask"].to(device="cuda") + outputs = decode_tokenized( - input_ids=torch.zeros(1, 1).to(dtype=torch.int64, device="cuda"), - input_mask=torch.ones(1, 1).to(dtype=torch.bool, device="cuda"), + input_ids=input_ids, + input_mask=attention_mask, model=model.model, parallel_context=parallel_context, - generation_config=GenerationArgs(sampler="greedy", use_cache=True), - max_micro_batch_size=1, - max_new_tokens=12, + generation_config=GenerationArgs(sampler="greedy", use_cache=args.use_cache), + max_micro_batch_size=args.max_micro_batch_size, + max_new_tokens=args.max_new_tokens, returns_logits=False, + bos_token_id=tokenizer.bos_token_id, ) for output in outputs: input_ids = output.input_ids @@ -234,8 +259,16 @@ def main(): assert isinstance(generated_ids, TensorPointer) continue assert isinstance(generated_ids, torch.Tensor) + + log_rank( + f"input: {tokenizer.decode(input_ids, clean_up_tokenization_spaces=False)[:1000]}", + logger=logger, + level=logging.INFO, + rank=0, + ) + log_rank( - f"generation: {generated_ids[len(input_ids) :]}", + f"generation: {tokenizer.decode(generated_ids[len(input_ids):], clean_up_tokenization_spaces=False)}", logger=logger, level=logging.INFO, rank=0, diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py index 40d95119a..46e500fba 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -35,6 +35,8 @@ class ParallelismArgs: recompute_layer: bool = False tp_recompute_allgather: bool = True + moe_layer_recompute: bool = False # TODO: legacy config for smollm + expert_parallel_size: int = 1 context_parallel_size: int = 1 diff --git a/src/nanotron/data/sft_processing.py b/src/nanotron/data/sft_processing.py index 3133a749d..40ccfbd33 100644 --- a/src/nanotron/data/sft_processing.py +++ b/src/nanotron/data/sft_processing.py @@ -60,7 +60,9 @@ def process_sft(examples, tokenizer, trainer_sequence_length): # Set position ids for all tokens (prompt and completion) to sequential values # But only where attention_mask is True (non-padding tokens) valid_length = attention_mask[i].sum().item() - position_ids[i, :valid_length] = torch.arange(valid_length) + position_ids[i, :valid_length] = torch.arange( + valid_length + ) # TODO: better to pad left, although in modeling we remove pads so it doesnt matter if left or right # Set label_mask to True only for completion tokens # If prompt consumes the entire sequence, no tokens are used for loss diff --git a/src/nanotron/data/tokenized_bytes.py b/src/nanotron/data/tokenized_bytes.py index f0c927b50..a99981c9d 100644 --- a/src/nanotron/data/tokenized_bytes.py +++ b/src/nanotron/data/tokenized_bytes.py @@ -320,6 +320,103 @@ def __init__( level=logging.WARNING, rank=0, ) + + # Handle S3 paths specially + matched_files = None + file_sizes = None + if folder_path.startswith("s3://"): + cache_dir = os.path.expanduser("~/.cache/nanotron/s3_cache") + os.makedirs(cache_dir, exist_ok=True) + + # Create a unique cache key based on the folder path and pattern + import hashlib + + cache_key = hashlib.md5(f"{folder_path}:{filename_pattern}:{recursive}".encode()).hexdigest() + cache_file = os.path.join(cache_dir, f"{cache_key}.cache") + + with main_rank_first(): + if dist.get_rank() == 0: + # Check if we have a valid cache + if os.path.exists(cache_file) and not force_update_cache: + try: + import pickle + + with open(cache_file, "rb") as f: + cached_data = pickle.load(f) + matched_files = cached_data["matched_files"] + file_sizes = cached_data["file_sizes"] + log_rank( + "[TokenizedBytesFolderDataset] Using cached S3 file list", + logger=logger, + level=logging.INFO, + rank=0, + ) + except Exception as e: + log_rank( + f"[TokenizedBytesFolderDataset] Failed to load cache, fetching from S3: {e}", + logger=logger, + level=logging.WARNING, + rank=0, + ) + + # If no cache or cache invalid, fetch from S3 + if matched_files is None: + log_rank( + "[TokenizedBytesFolderDataset] Fetching file list from S3...", + logger=logger, + level=logging.INFO, + rank=0, + ) + from datatrove.utils.dataset import url_to_fs + + fs_folder, stripped_folder_path = url_to_fs(folder_path) + matched_files = ( + fs_folder.find(stripped_folder_path, detail=False, maxdepth=1 if not recursive else None) + if not filename_pattern + else fs_folder.glob( + os.path.join(stripped_folder_path, filename_pattern), + maxdepth=1 if not recursive else None, + ) + ) + matched_files = sorted(matched_files) + + # Get file sizes + file_sizes = {} + for path in matched_files: + file_path = fs_folder.unstrip_protocol(path) + fs, file_path = url_to_fs(file_path) + file_sizes[file_path] = fs.size(file_path) + + # Save to cache + try: + import pickle + + with open(cache_file, "wb") as f: + pickle.dump({"matched_files": matched_files, "file_sizes": file_sizes}, f) + log_rank( + "[TokenizedBytesFolderDataset] Saved S3 file list to cache", + logger=logger, + level=logging.INFO, + rank=0, + ) + except Exception as e: + log_rank( + f"[TokenizedBytesFolderDataset] Failed to save cache: {e}", + logger=logger, + level=logging.WARNING, + rank=0, + ) + if dist.get_rank() != 0: + try: + import pickle + + with open(cache_file, "rb") as f: + cached_data = pickle.load(f) + matched_files = cached_data["matched_files"] + file_sizes = cached_data["file_sizes"] + except Exception as e: + raise RuntimeError(f"Failed to read cache file on rank {dist.get_rank()}: {e}") + super().__init__( data_folder=folder_path, seq_len=seq_len, diff --git a/src/nanotron/generation/decode.py b/src/nanotron/generation/decode.py index 338801100..6d87161ff 100644 --- a/src/nanotron/generation/decode.py +++ b/src/nanotron/generation/decode.py @@ -103,6 +103,17 @@ def micro_batcher( # Each dp is responsible for its own micro batches continue + if tokenizer.bos_token_id is not None: + # TODO: do we want this always? + add_special_tokens = False + logging.warn_once( + f"Tokenizer {tokenizer.name_or_path} has a bos_token_id={tokenizer.bos_token_id}, we set add_special_tokens to False to avoid adding bos token to input ids", + main_rank_only=True, + logger=logger, + ) + else: + add_special_tokens = True + if dist.get_rank(parallel_context.pp_pg) == input_rank: encodings = tokenizer( [elt.text for elt in micro_batch], @@ -111,6 +122,7 @@ def micro_batcher( padding=tokenizer_config.padding, max_length=tokenizer_config.max_input_length, truncation=tokenizer_config.truncation, + add_special_tokens=add_special_tokens, # pad_to_multiple_of=8 ) @@ -157,13 +169,17 @@ def micro_splitter( @torch.inference_mode() -def get_position_ids(input_ids, tokenizer): +def get_position_ids(input_ids, padding_token_id=None, input_mask=None): + assert padding_token_id is not None or input_mask is not None, "Either padding_token_id or input_mask must be provided" + # Find where padding ends for each sequence batch_size, seq_length = input_ids.shape - padding_token_id = tokenizer.eos_token_id # Create a mask of padding tokens - padding_mask = input_ids == padding_token_id + if padding_token_id is not None: + padding_mask = input_ids == padding_token_id + else: + padding_mask = input_mask # Find indices where non-padding tokens start # For sequences with no padding, this will be 0 @@ -305,7 +321,12 @@ def decode_text( else: batch_generated_ids = state.new_input_ids batch_generated_mask = state.new_input_mask - position_ids = get_position_ids(batch_generated_ids, tokenizer) + # assert first token isn't bos + assert ( + batch_generated_ids[0, 0] != tokenizer.bos_token_id + ), "First token is bos. Make sure you're using add_special_tokens=True when initializing the tokenizer." + padding_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id + position_ids = get_position_ids(batch_generated_ids, padding_token_id=padding_token_id) sharded_logits = model( input_ids=batch_generated_ids, position_ids=position_ids, # [batch_size, seq_len] @@ -541,8 +562,9 @@ def decode_tokenized( generation_config: GenerationArgs, max_micro_batch_size: int, max_new_tokens: int, - tokenizer: "PreTrainedTokenizer", returns_logits: Optional[bool] = False, + logits_are_batch_first: bool = True, + bos_token_id: Optional[int] = None, ) -> Generator[GenerationOutput, None, None]: """We assume the following: - Everyone receives ALL the input text. # TODO @thomasw21: technically only specific ranks need to receive input. @@ -606,7 +628,7 @@ def decode_tokenized( for batch in batches ) - for generation_iter in range(max_new_tokens): + for generation_iter in tqdm(range(max_new_tokens), desc="Generating"): all_new_decoder_input_ids_and_mask_same_rank: List[ Tuple[Union[torch.LongTensor, TensorPointer], Union[torch.BoolTensor, TensorPointer]] ] = [] @@ -614,17 +636,37 @@ def decode_tokenized( for state_id, state in enumerate(decoder_states): new_decoder_states.append(state) # Get the new logits - with attach_store(model=model, store=state.store): - position_ids = get_position_ids(state.new_input_ids, tokenizer) + if generation_config.use_cache: + raise NotImplementedError("Use-cache is not supported for now") + with attach_store(model=model, store=state.store): + position_ids = get_position_ids(state.new_input_ids, tokenizer) + sharded_logits = model( + input_ids=state.new_input_ids, + position_ids=position_ids, # [batch_size, seq_len] + ) + else: + if isinstance(state.new_input_ids, torch.Tensor): + batch_generated_ids = torch.cat(state.generation_ids, dim=-1) + batch_generated_mask = torch.cat(state.generation_mask, dim=-1) + else: + batch_generated_ids = state.new_input_ids + batch_generated_mask = state.new_input_mask + # assert first token isn't bos + if bos_token_id is not None: + assert ( + batch_generated_ids[0, 0] != bos_token_id + ), "First token is bos. Make sure you're using add_special_tokens=True when initializing the tokenizer." + position_ids = get_position_ids(batch_generated_ids, input_mask=batch_generated_mask) sharded_logits = model( - input_ids=state.new_input_ids, + input_ids=batch_generated_ids, position_ids=position_ids, # [batch_size, seq_len] - ) - if isinstance(sharded_logits, torch.Tensor): - sharded_logits = sharded_logits.transpose(0, 1) + ) # [batch_size*seq_len, vocab_size] + + sharded_logits = sharded_logits.view(*position_ids.shape, -1) # [batch_size, seq_len, vocab_size] + if isinstance(sharded_logits, torch.Tensor) and not logits_are_batch_first: + sharded_logits = sharded_logits.transpose(0, 1) # Communicate - # TODO @thomasw21: Make a diagram to show how this works nb_send: int = 0 if is_decoder_input_rank: if is_max_nb_microbatches: diff --git a/src/nanotron/models/qwen.py b/src/nanotron/models/qwen.py index 8aa6eb46a..3f35d25e3 100644 --- a/src/nanotron/models/qwen.py +++ b/src/nanotron/models/qwen.py @@ -192,7 +192,7 @@ def __init__( async_communication=tp_linear_async_communication, ) if config._use_qkv_packed: - from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding + from nanotron.nn.rotary import FlashRotaryEmbedding self.rotary_emb = FlashRotaryEmbedding( dim=self.head_dim, @@ -214,66 +214,146 @@ def __init__( # TODO: support doc masking / SWA / SFT / inference - def forward( + def _determine_seq_length_and_flat_ids(self, position_ids: torch.Tensor, inference_max_seqlen: Optional[int]): + if position_ids.ndim == 2: + seq_length = position_ids.shape[1] + flat_position_ids = position_ids.view(-1) # [batch_size*seq_length] + else: + assert ( + inference_max_seqlen is not None + ), "inference_max_seqlen must be provided if position_ids is a 1D tensor" + seq_length = inference_max_seqlen + flat_position_ids = position_ids + return seq_length, flat_position_ids + + def _forward_train_attn( self, - hidden_states: torch.Tensor, # [batch_size*seq_length, hidden_size] - position_ids: torch.Tensor, # [batch_size, seq_length] where -1 is padding - cu_seqlens: Optional[torch.Tensor] = None, # Added cu_seqlens argument + qkv: torch.Tensor, + position_ids: torch.Tensor, # Original position_ids, likely [batch_size, seq_length] + cu_seqlens: Optional[torch.Tensor], ): - # [0, 1, 2, 3, 4, 0, 1, 2, -1, -1, -1] # 2 documents with 5 and 3 tokens then padding - # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # 1 document with 11 tokens - # [0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1] # 1 document with 10 tokens then padding - # Replace -1 with 0 in position_ids to mark every padding token as a separate sequence. Ideally we want to get rid of padding tokens from qkv - # position_ids = position_ids.masked_fill(position_ids == -1, 0) - seq_length = position_ids.shape[1] - # Keep original position_ids shape for return, flatten for internal use - position_ids = position_ids.view(-1) # [batch_size*seq_length] + seq_length, flat_position_ids = self._determine_seq_length_and_flat_ids(position_ids, None) # inference_max_seqlen is None for training - qkv = self.qkv_proj(hidden_states) + if self._use_qkv_packed: + # _forward_packed uses self.training internally + attn_output = self._forward_packed(qkv, seq_length, flat_position_ids, cu_seqlens) + else: + q, k, v = qkv.split( + [self.local_q_size, self.local_kv_size, self.local_kv_size], dim=-1 + ) + q = q.view(-1, self.local_num_heads, self.head_dim) + k = k.view(-1, self.local_num_kv_heads, self.head_dim) + v = v.view(-1, self.local_num_kv_heads, self.head_dim) + + if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0: + rotary_pos_emb = self.rotary_emb( + position_ids=flat_position_ids if not self.simple_causal_mask else None, seq_length=seq_length + ) + q = self.rotary_emb.apply_rotary_pos_emb(q, rotary_pos_emb, seq_length=seq_length) + k = self.rotary_emb.apply_rotary_pos_emb(k, rotary_pos_emb, seq_length=seq_length) + else: + log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0) + + attn_output = self.attention( + q, k, v, position_ids=flat_position_ids, seq_length=seq_length, cu_seqlens=cu_seqlens + ) + + output = self.o_proj(attn_output) + return {"hidden_states": output, "position_ids": position_ids} # Return original position_ids + + def _forward_inference_attn( + self, + qkv: torch.Tensor, + position_ids: torch.Tensor, # Unpadded position_ids, likely [total_tokens] + cu_seqlens: Optional[torch.Tensor], + inference_max_seqlen: int, + ): + seq_length, flat_position_ids = self._determine_seq_length_and_flat_ids(position_ids, inference_max_seqlen) if self._use_qkv_packed: - attn_output = self._forward_packed(qkv, seq_length, position_ids, cu_seqlens) + # _forward_packed uses self.training internally + attn_output = self._forward_packed(qkv, seq_length, flat_position_ids, cu_seqlens) else: q, k, v = qkv.split( [self.local_q_size, self.local_kv_size, self.local_kv_size], dim=-1 - ) # [batch_size*seq_length, q_size], [batch_size*seq_length, kv_size] - q = q.view(-1, self.local_num_heads, self.head_dim) # [b*s, num_heads, head_dim] - k = k.view(-1, self.local_num_kv_heads, self.head_dim) # [b*s, num_kv_heads, head_dim] - v = v.view(-1, self.local_num_kv_heads, self.head_dim) # [b*s, num_kv_heads, head_dim] + ) + q = q.view(-1, self.local_num_heads, self.head_dim) + k = k.view(-1, self.local_num_kv_heads, self.head_dim) + v = v.view(-1, self.local_num_kv_heads, self.head_dim) + if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0: + # For inference, rotary_emb might be configured differently (e.g., for KV caching) + # Currently, using flat_position_ids which are unpadded and continuous for each token. rotary_pos_emb = self.rotary_emb( - position_ids=position_ids if not self.simple_causal_mask else None, seq_length=seq_length - ) # [b*s, dim] or [seq_length, dim] - q = self.rotary_emb.apply_rotary_pos_emb( - q, rotary_pos_emb, seq_length=seq_length - ) # [b*s, num_heads, head_dim] - k = self.rotary_emb.apply_rotary_pos_emb( - k, rotary_pos_emb, seq_length=seq_length - ) # [b*s, num_kv_heads, head_dim] + position_ids=flat_position_ids, seq_length=seq_length + ) + q = self.rotary_emb.apply_rotary_pos_emb(q, rotary_pos_emb, seq_length=seq_length) + k = self.rotary_emb.apply_rotary_pos_emb(k, rotary_pos_emb, seq_length=seq_length) else: log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0) + + # cu_seqlens for inference are based on unpadded data attn_output = self.attention( - q, k, v, position_ids=position_ids, seq_length=seq_length, cu_seqlens=cu_seqlens + q, k, v, position_ids=flat_position_ids, seq_length=seq_length, cu_seqlens=cu_seqlens ) + # TODO: KV Caching: return k, v here + output = self.o_proj(attn_output) - # Return original position_ids shape - return {"hidden_states": output, "position_ids": position_ids.view(-1, seq_length)} + return {"hidden_states": output, "position_ids": position_ids} # Return unpadded position_ids for inference consistency + + def forward( + self, + hidden_states: torch.Tensor, # [batch_size*seq_length, hidden_size] or [total_tokens, hidden_size] + position_ids: torch.Tensor, # [batch_size, seq_length] or [total_tokens] + cu_seqlens: Optional[torch.Tensor] = None, + inference_max_seqlen: Optional[int] = None, + ): + qkv = self.qkv_proj(hidden_states) + + if self.training: + return self._forward_train_attn( + qkv, + position_ids, # Original position_ids + cu_seqlens, + ) + else: + assert inference_max_seqlen is not None, "inference_max_seqlen must be provided for inference" + return self._forward_inference_attn( + qkv, + position_ids, # Unpadded position_ids + cu_seqlens, + inference_max_seqlen, + ) def _forward_packed(self, qkv, seq_length, position_ids, cu_seqlens): assert cu_seqlens is not None, "cu_seqlens must be provided for packed attention" q = qkv[..., : self.local_num_heads * self.head_dim] # Not contiguous, similar to flash_attn kv = qkv[..., self.local_num_heads * self.head_dim :] # Not contiguous, similar to flash_attn - q = q.view(-1, seq_length, self.local_num_heads, self.head_dim) - kv = kv.view(-1, seq_length, 2, self.local_num_kv_heads, self.head_dim) + if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0: - q, kv = self.rotary_emb( - q, kv, seqlen_offset=0, max_seqlen=None - ) # TODO: should we use position_ids here? flash_attn doesn't + if self.training: + q = q.view(-1, seq_length, self.local_num_heads, self.head_dim) + kv = kv.view(-1, seq_length, 2, self.local_num_kv_heads, self.head_dim) + q, kv = self.rotary_emb( + q, kv, seqlen_offset=0, max_seqlen=None + ) # TODO: should we use position_ids here? flash_attn doesn't + else: + # TODO: support seqlen_offsets in case of use_cache + # qkv = qkv.view(-1, self.local_num_heads + 2 * self.local_num_kv_heads, self.head_dim) + # self.rotary_emb.varlen_forward(qkv, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=seq_length) + # qkv = qkv.view(-1, (self.local_num_heads + 2 * self.local_num_kv_heads) * self.head_dim) + # q = qkv[..., : self.local_num_heads * self.head_dim] + # kv = qkv[..., self.local_num_heads * self.head_dim :] + q = q.view(-1, self.local_num_heads, self.head_dim) + kv = kv.view(-1, 2, self.local_num_kv_heads, self.head_dim) + k = kv[:, 0] + self.rotary_emb.varlen_forward(q, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=seq_length) + self.rotary_emb.varlen_forward(k, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=seq_length) else: log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0) q = q.view(-1, self.local_num_heads, self.head_dim) kv = kv.view(-1, 2, self.local_num_kv_heads, self.head_dim) - max_seqlen = seq_length # TODO: should this be max position_ids? + max_seqlen = seq_length # TODO: should this be max position_ids? As long as it doesn't change often it and not too big should be fine assert cu_seqlens.dtype == torch.int32 assert max_seqlen is not None @@ -405,11 +485,17 @@ def _core_forward( hidden_states: Union[torch.Tensor, TensorPointer], # [batch_size*seq_length, hidden_size] position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] where -1 is padding cu_seqlens: Union[torch.Tensor, TensorPointer], + inference_max_seqlen: Optional[int] = None, ) -> List[Union[torch.Tensor, TensorPointer]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - output = self.attn(hidden_states=hidden_states, position_ids=position_ids, cu_seqlens=cu_seqlens) + output = self.attn( + hidden_states=hidden_states, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + inference_max_seqlen=inference_max_seqlen, + ) hidden_states = output["hidden_states"] hidden_states = hidden_states + residual @@ -433,18 +519,22 @@ def forward( hidden_states: Union[torch.Tensor, TensorPointer], position_ids: Union[torch.Tensor, TensorPointer], cu_seqlens: Union[torch.Tensor, TensorPointer], + inference_max_seqlen: Optional[int] = None, ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: if self.recompute_layer and not isinstance(hidden_states, TensorPointer): hidden_states, position_ids, cu_seqlens = self._checkpointed_forward( - hidden_states, position_ids, cu_seqlens + hidden_states, position_ids, cu_seqlens, inference_max_seqlen ) else: - hidden_states, position_ids, cu_seqlens = self._core_forward(hidden_states, position_ids, cu_seqlens) + hidden_states, position_ids, cu_seqlens = self._core_forward( + hidden_states, position_ids, cu_seqlens, inference_max_seqlen + ) return { "hidden_states": hidden_states, "position_ids": position_ids, "cu_seqlens": cu_seqlens, + "inference_max_seqlen": inference_max_seqlen, } @@ -460,9 +550,8 @@ def __init__(self, tp_pg: dist.ProcessGroup, config: Qwen2Config, parallel_confi ) self.pg = tp_pg - def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): # [batch_size, seq_length] - input_ids = input_ids.view(-1) # [batch_size*seq_length] - input_embeds = self.token_embedding(input_ids) # [batch_size*seq_length, hidden_size] + def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): # [...] + input_embeds = self.token_embedding(input_ids) # [..., hidden_size] return {"input_embeds": input_embeds, "position_ids": position_ids} @@ -512,8 +601,8 @@ def __init__( "cp_pg": parallel_context.cp_pg, "layer_idx": layer_idx, }, - module_input_keys={"hidden_states", "position_ids", "cu_seqlens"}, - module_output_keys={"hidden_states", "position_ids", "cu_seqlens"}, + module_input_keys={"hidden_states", "position_ids", "cu_seqlens", "inference_max_seqlen"}, + module_output_keys={"hidden_states", "position_ids", "cu_seqlens", "inference_max_seqlen"}, ) for layer_idx in range(config.num_hidden_layers) ] @@ -544,36 +633,119 @@ def __init__( module_output_keys={"logits"}, ) - def forward( + def _forward_train( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] where -1 is padding ): + # Training case (handles potential packing) + # Get embeddings for the original (potentially padded/packed) sequence output = self.token_position_embeddings(input_ids=input_ids, position_ids=position_ids) - # Compute cu_seqlens + # output["position_ids"] is the original position_ids [batch_size, seq_length] + + # Compute cu_seqlens based on document starts (position_id == 0) if data is packed if position_ids.numel() > 0: start_indices = torch.where(position_ids.view(-1) == 0)[0] cu_seqlens = torch.cat( - [start_indices, torch.tensor([position_ids.numel()], dtype=torch.int32, device=start_indices.device)] + [ + start_indices, + torch.tensor([position_ids.numel()], dtype=torch.int32, device=start_indices.device), + ] ).to(torch.int32) else: - cu_seqlens = None + cu_seqlens = None # Or handle empty tensor case appropriately + # Prepare state for decoder layers using original/padded/packed data decoder_states = { - "hidden_states": output["input_embeds"], - "position_ids": output["position_ids"], - "cu_seqlens": cu_seqlens, + "hidden_states": output["input_embeds"], # Padded embeds [batch*seq_len, hidden_size] + "position_ids": output["position_ids"], # Original pos_ids [batch_size, seq_len] + "cu_seqlens": cu_seqlens, # Based on packing, might be None + # "inference_max_seqlen" is not needed for training } + # Pass the prepared decoder_states dictionary to the decoder layers for decoder_layer in self.decoder: + # Decoder layers need to handle both inference (unpadded) and training (padded/packed) states decoder_states = decoder_layer(**decoder_states) + # Final layer norm and LM head operate on the output hidden_states from the last decoder layer hidden_states = self.final_layer_norm(input=decoder_states["hidden_states"])["hidden_states"] + sharded_logits = self.lm_head(x=hidden_states)["logits"] + return sharded_logits + + def _forward_inference( + self, + input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] where -1 is padding + ): + assert ( + position_ids.ndim == 2 + ), "position_ids must be 2D for inference, otherwise how do we know when to separate samples?" + inference_max_seqlen = position_ids.shape[1] + inference_batch_size = position_ids.shape[0] + # This gives the number of non-padding tokens per sequence in the batch + seqlens_in_batch = (position_ids != -1).sum(dim=-1, dtype=torch.int32) + input_ids = input_ids.view(-1) + position_ids = position_ids.view(-1) + # Find indices of non-padding tokens using the flattened position_ids + unpad_indices = torch.nonzero(position_ids != -1, as_tuple=False).flatten() + cu_seqlens = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=seqlens_in_batch.device), + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), + ] + ) + unpadded_input_ids = input_ids[unpad_indices] # (total_tokens) = (total_unpadded_tokens) + unpadded_position_ids = position_ids[unpad_indices] # (total_tokens) + + # TODO: compute this in dataloader to avoid cpu-gpu sync + cu_seqlens = cu_seqlens.to(unpadded_input_ids.device if isinstance(unpadded_input_ids, torch.Tensor) else input_ids.device) + + output = self.token_position_embeddings( + input_ids=unpadded_input_ids, position_ids=unpadded_position_ids + ) # (total_tokens, hidden_size) + decoder_states = { + "hidden_states": output["input_embeds"], # Unpadded embeds [total_tokens, hidden_size] + "position_ids": output["position_ids"], # Unpadded pos_ids [total_tokens] + "cu_seqlens": cu_seqlens, # cu_seqlens for unpadded sequence [batch_size + 1] + "inference_max_seqlen": inference_max_seqlen, # original seq_length using for inference + } + + # Pass the prepared decoder_states dictionary to the decoder layers + for decoder_layer in self.decoder: + # Decoder layers need to handle both inference (unpadded) and training (padded/packed) states + decoder_states = decoder_layer(**decoder_states) + + # Final layer norm and LM head operate on the output hidden_states from the last decoder layer + hidden_states = self.final_layer_norm(input=decoder_states["hidden_states"])["hidden_states"] sharded_logits = self.lm_head(x=hidden_states)["logits"] + # Pad logits back to original shape + assert inference_batch_size is not None and inference_max_seqlen is not None and unpad_indices is not None + # Create zero tensor with the full padded shape (flattened batch/seq) + padded_sharded_logits = torch.zeros( + inference_batch_size * inference_max_seqlen, + sharded_logits.shape[-1], # vocab_shard_size + dtype=sharded_logits.dtype, + device=sharded_logits.device, + ) + # Scatter the unpadded logits back into the zero tensor + padded_sharded_logits[unpad_indices] = sharded_logits + # Reshape to (batch_size, sequence_length, vocab_shard_size) + sharded_logits = padded_sharded_logits.view(inference_batch_size, inference_max_seqlen, -1) return sharded_logits + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] where -1 is padding + ): + if self.training: + return self._forward_train(input_ids=input_ids, position_ids=position_ids) + else: + return self._forward_inference(input_ids=input_ids, position_ids=position_ids) + def get_block_compute_costs(self): """Computes the compute cost of each block in the model for load balancing.""" model_config = self.config diff --git a/src/nanotron/nn/rotary.py b/src/nanotron/nn/rotary.py index 4e78849f9..6b9ae18de 100644 --- a/src/nanotron/nn/rotary.py +++ b/src/nanotron/nn/rotary.py @@ -1,8 +1,37 @@ import torch +from flash_attn.layers.rotary import RotaryEmbedding as OriginalFlashRotaryEmbedding from flash_attn.layers.rotary import apply_rotary_emb as flash_apply_rotary_emb from torch import nn +class FlashRotaryEmbedding(OriginalFlashRotaryEmbedding): + """ + This is a modified version of the FlashRotaryEmbedding class that supports variable length sequences in case of inference. + """ + + def varlen_forward(self, x, seqlen_offsets, cu_seqlens, max_seqlen): + """ + x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + seqlen_offsets: (batch_size) + cu_seqlens: (batch_size + 1) + max_seqlen: int + """ + assert not self.training, "This is supposed to be used only in inference" + assert max_seqlen is not None, "max_seqlen must be provided" + self._update_cos_sin_cache(max_seqlen, device=x.device, dtype=x.dtype) + flash_apply_rotary_emb( + x, + cos=self._cos_cached, + sin=self._sin_cached, + interleaved=self.interleaved, + inplace=True, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + class RotaryEmbedding(nn.Module): def __init__( self,