diff --git a/examples/config_tiny_llama.py b/examples/config_tiny_llama.py index 479e1d471..f1f92ab20 100644 --- a/examples/config_tiny_llama.py +++ b/examples/config_tiny_llama.py @@ -18,9 +18,13 @@ RandomInit, TokenizerArgs, TokensArgs, + MetricsLoggingArgs, ) from nanotron.logging import human_format - +CHECKPOINT_ROOT_PATH = "./checkpoints" +tokenizer_name_or_path = "Qwen/Qwen3-0.6B" +tokenizer_vocab_size = 151643 +sequence_length = 2048 model_config = LlamaConfig( # Config for a tiny model model with 1.62M parameters bos_token_id=1, @@ -38,7 +42,7 @@ rope_scaling=None, tie_word_embeddings=True, use_cache=True, - vocab_size=256, + vocab_size=tokenizer_vocab_size, ) num_params = human_format( @@ -55,7 +59,7 @@ seed = 42 learning_rate = LRSchedulerArgs( - learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5 + learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-7 ) optimizer = OptimizerArgs( @@ -73,28 +77,28 @@ ) parallelism = ParallelismArgs( - dp=2, - pp=2, - tp=2, + dp=1, + pp=1, + tp=1, pp_engine="1f1b", tp_mode="REDUCE_SCATTER", tp_linear_async_communication=True, ) -tokens = TokensArgs(sequence_length=256, train_steps=15, micro_batch_size=2, batch_accumulation_per_replica=1) +tokens = TokensArgs(sequence_length=sequence_length, train_steps=1000, micro_batch_size=2, batch_accumulation_per_replica=1) data_stages = [ - DatasetStageArgs( - name="Stable Training Stage", - start_training_step=1, - data=DataArgs( - dataset=PretrainDatasetsArgs(hf_dataset_or_datasets="stas/openwebtext-10k", text_column_name="text"), - seed=seed, - ), - ), + # DatasetStageArgs( + # name="Stable Training Stage", + # start_training_step=1, + # data=DataArgs( + # dataset=PretrainDatasetsArgs(hf_dataset_or_datasets="stas/openwebtext-10k", text_column_name="text"), + # seed=seed, + # ), + # ), DatasetStageArgs( name="Annealing Phase", - start_training_step=10, + start_training_step=1, data=DataArgs( dataset=PretrainDatasetsArgs(hf_dataset_or_datasets="stas/openwebtext-10k", text_column_name="text"), seed=seed, @@ -102,17 +106,18 @@ ), ] -checkpoints_path = "./checkpoints" +checkpoints_path = CHECKPOINT_ROOT_PATH os.makedirs(checkpoints_path, exist_ok=True) config = Config( - general=GeneralArgs(project="debug", run="tiny_llama_%date_%jobid", seed=seed), - checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10), + general=GeneralArgs(project="tiny_llama_nanotron_test", run="tiny_llama_%date_%jobid", seed=seed), + checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=25), parallelism=parallelism, model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config), - tokenizer=TokenizerArgs("robot-test/dummy-tokenizer-wordlevel"), + tokenizer=TokenizerArgs(tokenizer_name_or_path), optimizer=optimizer, logging=LoggingArgs(), + metrics_logging=MetricsLoggingArgs(1), tokens=tokens, data_stages=data_stages, profiler=None, diff --git a/examples/config_tiny_llama_resume.py b/examples/config_tiny_llama_resume.py new file mode 100644 index 000000000..6c02f05f2 --- /dev/null +++ b/examples/config_tiny_llama_resume.py @@ -0,0 +1,140 @@ +""" Example python script to generate a YAML config file which can be used to run a training with nanotron. Refer to "examples" section in the `/README.md` for more information.""" +import os + +from nanotron.config import ( + AdamWOptimizerArgs, + CheckpointsArgs, + Config, + DataArgs, + DatasetStageArgs, + GeneralArgs, + LlamaConfig, + LoggingArgs, + LRSchedulerArgs, + ModelArgs, + OptimizerArgs, + ParallelismArgs, + PretrainDatasetsArgs, + RandomInit, + TokenizerArgs, + TokensArgs, + MetricsLoggingArgs, +) +from nanotron.logging import human_format +import json + + +CHECKPOINT_ROOT_PATH = "./checkpoints" +with open(f"{CHECKPOINT_ROOT_PATH}/latest.txt", 'r') as file: + content = file.read().strip() +number = 0 +# Try to convert to integer and store in 'number' if valid +try: + number = int(content) + print(f"Valid number found: {number}") +except ValueError: + print("The file does not contain a valid integer.") + raise ValueError("latest.txt does not contain a valid integer.") + + +CHECKPOINT_PATH = f"{CHECKPOINT_ROOT_PATH}/{number}" +tokenizer_name_or_path = "Qwen/Qwen3-0.6B" +sequence_length = 2048 +model_config_dict = json.load(open(f"{CHECKPOINT_PATH}/model_config.json")) +model_config = LlamaConfig( + **model_config_dict, +) + + +num_params = human_format( + model_config.vocab_size * model_config.hidden_size * 2 + + model_config.num_hidden_layers + * ( + 3 * model_config.hidden_size * model_config.intermediate_size + + 4 * model_config.hidden_size * model_config.hidden_size + ) +).replace(".", "p") + +print(f"Model has {num_params} parameters") + +seed = 42 + +learning_rate = LRSchedulerArgs( + learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-7 +) + +optimizer = OptimizerArgs( + zero_stage=0, + weight_decay=0.01, + clip_grad=1.0, + accumulate_grad_in_fp32=True, + learning_rate_scheduler=learning_rate, + optimizer_factory=AdamWOptimizerArgs( + adam_eps=1e-08, + adam_beta1=0.9, + adam_beta2=0.95, + torch_adam_is_fused=True, + ), +) + +parallelism = ParallelismArgs( + dp=1, + pp=1, + tp=1, + pp_engine="1f1b", + tp_mode="REDUCE_SCATTER", + tp_linear_async_communication=True, +) + +tokens = TokensArgs(sequence_length=sequence_length, train_steps=2000, micro_batch_size=2, batch_accumulation_per_replica=1) + +data_stages = [ + # DatasetStageArgs( + # name="Stable Training Stage", + # start_training_step=1, + # data=DataArgs( + # dataset=PretrainDatasetsArgs(hf_dataset_or_datasets="stas/openwebtext-10k", text_column_name="text"), + # seed=seed, + # ), + # ), + DatasetStageArgs( + name="Annealing Phase", + start_training_step=1, + data=DataArgs( + dataset=PretrainDatasetsArgs(hf_dataset_or_datasets="stas/openwebtext-10k", text_column_name="text"), + seed=seed, + ), + ), +] + +checkpoints_path = CHECKPOINT_ROOT_PATH +os.makedirs(checkpoints_path, exist_ok=True) + +config = Config( + general=GeneralArgs(project="tiny_llama_nanotron_test", run="tiny_llama_%date_%jobid", seed=seed,ignore_sanity_checks=False), + checkpoints=CheckpointsArgs( + checkpoints_path=checkpoints_path, + checkpoint_interval=10, + resume_checkpoint_path=CHECKPOINT_PATH, + save_initial_state=True, + load_lr_scheduler=True, + load_optimizer=True, + ), + parallelism=parallelism, + model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config), + tokenizer=TokenizerArgs(tokenizer_name_or_path), + optimizer=optimizer, + logging=LoggingArgs(), + metrics_logging=MetricsLoggingArgs(1), + tokens=tokens, + data_stages=data_stages, + profiler=None, +) + +if __name__ == "__main__": + dir = os.path.dirname(__file__) + + # Save config as YAML file + config.save_as_yaml(f"{dir}/config_tiny_llama_resume.yaml") + + # You can now train a model with this config using `/run_train.py` diff --git a/run_train.py b/run_train.py index 35bc62dab..e17b35f4f 100644 --- a/run_train.py +++ b/run_train.py @@ -167,7 +167,7 @@ def get_dataloader_from_data_stage( input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, micro_batch_size=trainer.micro_batch_size, - consumed_train_samples_stage=consumed_train_samples_stage, + consumed_train_samples=consumed_train_samples_stage, dataloader_num_workers=data.num_loading_workers, seed_worker=data.seed, dataloader_drop_last=True, diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index db8206448..b63197b77 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -1092,7 +1092,7 @@ def init_model_randomly(self, config: Config): else: raise ValueError(f"Unknown init method {init_method}") - parametrizator = parametrizator_cls(config=config.model) + parametrizator = parametrizator_cls(config=config) log_rank( f"Parametrizing model parameters using {parametrizator.__class__.__name__}", diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 13ecf4975..2778972fe 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -572,14 +572,16 @@ def train( # Training Logs # Track consumed tokens for all dataset folders in current stage - if hasattr(self.current_base_dl, "dataset"): + if hasattr(self.current_base_dl, "dataset") and hasattr(self.current_base_dl.dataset, "get_consumption_stats"): consumption_stats = self.current_base_dl.dataset.get_consumption_stats() current_stage = self.metadata.data_stages[self.metadata.last_stage_idx] # Update consumed tokens for all folders in the consumption stats for folder_path, stats in consumption_stats.items(): current_stage.consumed_tokens_per_dataset_folder[folder_path] = stats["tokens"] - + else: + self.metadata.current_stage.consumed_tokens_per_dataset_folder.setdefault("default", 0) + self.metadata.current_stage.consumed_tokens_per_dataset_folder["default"] += self.global_batch_size * self.sequence_length # Original consumption tracking self.metadata.consumed_train_samples += self.global_batch_size # TODO: Legacy: idc abt this self.metadata.consumed_tokens_total += self.global_batch_size * self.sequence_length @@ -883,7 +885,7 @@ def get_cpu_logitems(): assert self.current_base_dl is not None, "current_base_dl should be defined" # Log consumption statistics - if hasattr(self.current_base_dl, "dataset"): + if hasattr(self.current_base_dl, "dataset") and hasattr(self.current_base_dl.dataset, "get_consumption_stats"): for dataset_name, stats in self.current_base_dl.dataset.get_consumption_stats().items(): basic_log_entries.extend( [ diff --git a/test_flash_attn.py b/test_flash_attn.py new file mode 100644 index 000000000..73a74956b --- /dev/null +++ b/test_flash_attn.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +""" +Simple test script for FlashAttention +""" + +import torch + +def test_flash_attention(): + """Test if FlashAttention is working correctly""" + + # Check if CUDA is available + if not torch.cuda.is_available(): + print("CUDA is not available. FlashAttention requires CUDA.") + return False + + print(f"CUDA available: {torch.cuda.is_available()}") + print(f"Number of GPUs: {torch.cuda.device_count()}") + print(f"Current device: {torch.cuda.current_device()}") + print(f"Device name: {torch.cuda.get_device_name()}") + + try: + # Try to import FlashAttention + from flash_attn import flash_attn_func + print("✓ FlashAttention imported successfully!") + except ImportError as e: + print(f"✗ Failed to import FlashAttention: {e}") + return False + + # Create dummy tensors + batch_size = 1 + seq_len = 128 + num_heads = 8 + head_dim = 64 + + device = torch.device('cuda') + dtype = torch.float16 + + # Create Q, K, V tensors + Q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + K = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + V = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + + print(f"Input shapes: Q={Q.shape}, K={K.shape}, V={V.shape}") + + try: + # Run FlashAttention + output = flash_attn_func(Q, K, V, dropout_p=0.0, softmax_scale=None, causal=False) + print("✓ FlashAttention ran successfully!") + print(f"Output shape: {output.shape}") + + # Test with causal attention + output_causal = flash_attn_func(Q, K, V, dropout_p=0.0, softmax_scale=None, causal=True) + print("✓ Causal FlashAttention ran successfully!") + print(f"Causal output shape: {output_causal.shape}") + + return True + + except Exception as e: + print(f"✗ FlashAttention failed to run: {e}") + return False + +def test_memory_usage(): + """Test memory usage with different sequence lengths""" + + if not torch.cuda.is_available(): + print("CUDA not available, skipping memory test") + return + + try: + from flash_attn import flash_attn_func + except ImportError: + print("FlashAttention not available, skipping memory test") + return + + device = torch.device('cuda') + dtype = torch.float16 + + # Test different sequence lengths + seq_lengths = [512, 1024, 2048] + batch_size = 1 + num_heads = 8 + head_dim = 64 + + for seq_len in seq_lengths: + print(f"\nTesting sequence length: {seq_len}") + + # Clear cache before test + torch.cuda.empty_cache() + + # Get initial memory + initial_memory = torch.cuda.memory_allocated() / 1024**3 # GB + + try: + # Create tensors + Q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + K = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + V = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + + # Run attention + output = flash_attn_func(Q, K, V, dropout_p=0.0, softmax_scale=None, causal=False) + + # Get peak memory + peak_memory = torch.cuda.max_memory_allocated() / 1024**3 # GB + + print(f" Initial memory: {initial_memory:.2f} GB") + print(f" Peak memory: {peak_memory:.2f} GB") + print(f" Memory used: {peak_memory - initial_memory:.2f} GB") + + except Exception as e: + print(f" Failed: {e}") + + # Clean up + del Q, K, V, output + torch.cuda.empty_cache() + +if __name__ == "__main__": + print("=== FlashAttention Test ===\n") + + # Basic functionality test + success = test_flash_attention() + + if success: + print("\n=== Memory Usage Test ===") + test_memory_usage() + + print("\n=== Test Complete ===") \ No newline at end of file