Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 25 additions & 20 deletions examples/config_tiny_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@
RandomInit,
TokenizerArgs,
TokensArgs,
MetricsLoggingArgs,
)
from nanotron.logging import human_format

CHECKPOINT_ROOT_PATH = "./checkpoints"
tokenizer_name_or_path = "Qwen/Qwen3-0.6B"
tokenizer_vocab_size = 151643
sequence_length = 2048
model_config = LlamaConfig(
# Config for a tiny model with 1.62M parameters
bos_token_id=1,
Expand All @@ -38,7 +42,7 @@
rope_scaling=None,
tie_word_embeddings=True,
use_cache=True,
vocab_size=256,
vocab_size=tokenizer_vocab_size,
)

num_params = human_format(
Expand All @@ -55,7 +59,7 @@
seed = 42

learning_rate = LRSchedulerArgs(
learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5
learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-7
)

optimizer = OptimizerArgs(
Expand All @@ -73,46 +77,47 @@
)

parallelism = ParallelismArgs(
dp=2,
pp=2,
tp=2,
dp=1,
pp=1,
tp=1,
pp_engine="1f1b",
tp_mode="REDUCE_SCATTER",
tp_linear_async_communication=True,
)

tokens = TokensArgs(sequence_length=256, train_steps=15, micro_batch_size=2, batch_accumulation_per_replica=1)
tokens = TokensArgs(sequence_length=sequence_length, train_steps=1000, micro_batch_size=2, batch_accumulation_per_replica=1)

data_stages = [
DatasetStageArgs(
name="Stable Training Stage",
start_training_step=1,
data=DataArgs(
dataset=PretrainDatasetsArgs(hf_dataset_or_datasets="stas/openwebtext-10k", text_column_name="text"),
seed=seed,
),
),
# DatasetStageArgs(
# name="Stable Training Stage",
# start_training_step=1,
# data=DataArgs(
# dataset=PretrainDatasetsArgs(hf_dataset_or_datasets="stas/openwebtext-10k", text_column_name="text"),
# seed=seed,
# ),
# ),
DatasetStageArgs(
name="Annealing Phase",
start_training_step=10,
start_training_step=1,
data=DataArgs(
dataset=PretrainDatasetsArgs(hf_dataset_or_datasets="stas/openwebtext-10k", text_column_name="text"),
seed=seed,
),
),
]

checkpoints_path = "./checkpoints"
checkpoints_path = CHECKPOINT_ROOT_PATH
os.makedirs(checkpoints_path, exist_ok=True)

config = Config(
general=GeneralArgs(project="debug", run="tiny_llama_%date_%jobid", seed=seed),
checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10),
general=GeneralArgs(project="tiny_llama_nanotron_test", run="tiny_llama_%date_%jobid", seed=seed),
checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=25),
parallelism=parallelism,
model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config),
tokenizer=TokenizerArgs("robot-test/dummy-tokenizer-wordlevel"),
tokenizer=TokenizerArgs(tokenizer_name_or_path),
optimizer=optimizer,
logging=LoggingArgs(),
metrics_logging=MetricsLoggingArgs(1),
tokens=tokens,
data_stages=data_stages,
profiler=None,
Expand Down
140 changes: 140 additions & 0 deletions examples/config_tiny_llama_resume.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
""" Example python script to generate a YAML config file which can be used to run a training with nanotron. Refer to "examples" section in the `/README.md` for more information."""
import os

from nanotron.config import (
AdamWOptimizerArgs,
CheckpointsArgs,
Config,
DataArgs,
DatasetStageArgs,
GeneralArgs,
LlamaConfig,
LoggingArgs,
LRSchedulerArgs,
ModelArgs,
OptimizerArgs,
ParallelismArgs,
PretrainDatasetsArgs,
RandomInit,
TokenizerArgs,
TokensArgs,
MetricsLoggingArgs,
)
from nanotron.logging import human_format
import json


CHECKPOINT_ROOT_PATH = "./checkpoints"

# `latest.txt` must contain a single integer; that integer names the
# checkpoint sub-directory (under CHECKPOINT_ROOT_PATH) to resume from.
with open(f"{CHECKPOINT_ROOT_PATH}/latest.txt", "r", encoding="utf-8") as file:
    content = file.read().strip()

try:
    number = int(content)
except ValueError as err:
    print("The file does not contain a valid integer.")
    # Chain the original error so the offending content shows up in the traceback.
    raise ValueError("latest.txt does not contain a valid integer.") from err
else:
    print(f"Valid number found: {number}")


CHECKPOINT_PATH = f"{CHECKPOINT_ROOT_PATH}/{number}"
tokenizer_name_or_path = "Qwen/Qwen3-0.6B"
sequence_length = 2048

# Rebuild the model config from the JSON stored next to the checkpoint.
# A context manager closes the file handle instead of leaking it.
with open(f"{CHECKPOINT_PATH}/model_config.json", "r", encoding="utf-8") as config_file:
    model_config_dict = json.load(config_file)
model_config = LlamaConfig(
    **model_config_dict,
)


# Estimate the parameter count from the loaded model config.
embedding_params = model_config.vocab_size * model_config.hidden_size * 2
params_per_layer = (
    3 * model_config.hidden_size * model_config.intermediate_size
    + 4 * model_config.hidden_size * model_config.hidden_size
)
total_params = embedding_params + model_config.num_hidden_layers * params_per_layer

# Any "." in the formatted count is replaced by "p".
num_params = human_format(total_params).replace(".", "p")

print(f"Model has {num_params} parameters")

seed = 42

# Learning-rate schedule: linear warmup, then cosine decay down to min_decay_lr.
learning_rate = LRSchedulerArgs(
    learning_rate=3e-4,
    lr_warmup_steps=2,
    lr_warmup_style="linear",
    lr_decay_style="cosine",
    min_decay_lr=1e-7,
)

# AdamW hyper-parameters, handed to OptimizerArgs as its optimizer factory.
adamw_args = AdamWOptimizerArgs(
    adam_eps=1e-08,
    adam_beta1=0.9,
    adam_beta2=0.95,
    torch_adam_is_fused=True,
)

optimizer = OptimizerArgs(
    zero_stage=0,
    weight_decay=0.01,
    clip_grad=1.0,
    accumulate_grad_in_fp32=True,
    learning_rate_scheduler=learning_rate,
    optimizer_factory=adamw_args,
)

# All three parallel sizes (dp, pp, tp) are set to 1.
parallelism = ParallelismArgs(
    pp_engine="1f1b",
    tp_mode="REDUCE_SCATTER",
    tp_linear_async_communication=True,
    dp=1,
    pp=1,
    tp=1,
)

# Token budget for the run; sequence_length comes from the module-level constant.
tokens = TokensArgs(
    sequence_length=sequence_length,
    train_steps=2000,
    micro_batch_size=2,
    batch_accumulation_per_replica=1,
)

# Data source shared by the single active stage below.
openwebtext_data = DataArgs(
    dataset=PretrainDatasetsArgs(hf_dataset_or_datasets="stas/openwebtext-10k", text_column_name="text"),
    seed=seed,
)

data_stages = [
    # The "Stable Training Stage" entry is disabled; only the stage below is active.
    # DatasetStageArgs(
    #     name="Stable Training Stage",
    #     start_training_step=1,
    #     data=DataArgs(
    #         dataset=PretrainDatasetsArgs(hf_dataset_or_datasets="stas/openwebtext-10k", text_column_name="text"),
    #         seed=seed,
    #     ),
    # ),
    DatasetStageArgs(name="Annealing Phase", start_training_step=1, data=openwebtext_data),
]

checkpoints_path = CHECKPOINT_ROOT_PATH
os.makedirs(checkpoints_path, exist_ok=True)

general_args = GeneralArgs(
    project="tiny_llama_nanotron_test",
    run="tiny_llama_%date_%jobid",
    seed=seed,
    ignore_sanity_checks=False,
)

# Resume from CHECKPOINT_PATH, with the load_optimizer and load_lr_scheduler flags enabled.
checkpoint_args = CheckpointsArgs(
    checkpoints_path=checkpoints_path,
    checkpoint_interval=10,
    resume_checkpoint_path=CHECKPOINT_PATH,
    save_initial_state=True,
    load_lr_scheduler=True,
    load_optimizer=True,
)

model_args = ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config)

# Assemble the full training configuration from the pieces defined above.
config = Config(
    general=general_args,
    checkpoints=checkpoint_args,
    parallelism=parallelism,
    model=model_args,
    tokenizer=TokenizerArgs(tokenizer_name_or_path),
    optimizer=optimizer,
    logging=LoggingArgs(),
    metrics_logging=MetricsLoggingArgs(1),
    tokens=tokens,
    data_stages=data_stages,
    profiler=None,
)

if __name__ == "__main__":
    # Resolve to an absolute directory first: dirname() of a bare relative
    # `__file__` is "", which would make the output path start at "/".
    # The name `output_dir` also avoids shadowing the builtin `dir`.
    output_dir = os.path.dirname(os.path.abspath(__file__))

    # Save config as YAML file next to this script
    config.save_as_yaml(os.path.join(output_dir, "config_tiny_llama_resume.yaml"))

    # You can now train a model with this config using `/run_train.py`
2 changes: 1 addition & 1 deletion run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def get_dataloader_from_data_stage(
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
micro_batch_size=trainer.micro_batch_size,
consumed_train_samples_stage=consumed_train_samples_stage,
consumed_train_samples=consumed_train_samples_stage,
dataloader_num_workers=data.num_loading_workers,
seed_worker=data.seed,
dataloader_drop_last=True,
Expand Down
2 changes: 1 addition & 1 deletion src/nanotron/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,7 @@ def init_model_randomly(self, config: Config):
else:
raise ValueError(f"Unknown init method {init_method}")

parametrizator = parametrizator_cls(config=config.model)
parametrizator = parametrizator_cls(config=config)

log_rank(
f"Parametrizing model parameters using {parametrizator.__class__.__name__}",
Expand Down
8 changes: 5 additions & 3 deletions src/nanotron/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,14 +572,16 @@ def train(

# Training Logs
# Track consumed tokens for all dataset folders in current stage
if hasattr(self.current_base_dl, "dataset"):
if hasattr(self.current_base_dl, "dataset") and hasattr(self.current_base_dl.dataset, "get_consumption_stats"):
consumption_stats = self.current_base_dl.dataset.get_consumption_stats()
current_stage = self.metadata.data_stages[self.metadata.last_stage_idx]

# Update consumed tokens for all folders in the consumption stats
for folder_path, stats in consumption_stats.items():
current_stage.consumed_tokens_per_dataset_folder[folder_path] = stats["tokens"]

else:
self.metadata.current_stage.consumed_tokens_per_dataset_folder.setdefault("default", 0)
self.metadata.current_stage.consumed_tokens_per_dataset_folder["default"] += self.global_batch_size * self.sequence_length
# Original consumption tracking
self.metadata.consumed_train_samples += self.global_batch_size # TODO: Legacy: idc abt this
self.metadata.consumed_tokens_total += self.global_batch_size * self.sequence_length
Expand Down Expand Up @@ -883,7 +885,7 @@ def get_cpu_logitems():
assert self.current_base_dl is not None, "current_base_dl should be defined"

# Log consumption statistics
if hasattr(self.current_base_dl, "dataset"):
if hasattr(self.current_base_dl, "dataset") and hasattr(self.current_base_dl.dataset, "get_consumption_stats"):
for dataset_name, stats in self.current_base_dl.dataset.get_consumption_stats().items():
basic_log_entries.extend(
[
Expand Down
Loading