diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py index f25b6cb1..a4032095 100644 --- a/scripts/train_eagle3.py +++ b/scripts/train_eagle3.py @@ -560,13 +560,19 @@ def build_dataloaders( ) # convert to dataloader - cache_params_string = ( + dataset_cache_params_string = ( f"{args.train_data_path}-" f"{args.max_length}-" f"{args.chat_template}-" f"{args.target_model_path}" # Tokenizer may also different ) - cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() + vocab_cache_params_string = ( + f"{dataset_cache_params_string}-" + f"{draft_model_config.draft_vocab_size}-" + f"{draft_model_config.vocab_size}" + ) + cache_key = hashlib.md5(dataset_cache_params_string.encode()).hexdigest() + vocab_cache_key = hashlib.md5(vocab_cache_params_string.encode()).hexdigest() train_dataset = Dataset.from_generator( generator=safe_conversations_generator, gen_kwargs={"file_path": args.train_data_path}, @@ -593,7 +599,7 @@ def build_dataloaders( target_vocab_size=draft_model_config.vocab_size, draft_vocab_size=draft_model_config.draft_vocab_size, cache_dir=os.path.join(args.cache_dir, "vocab_mapping"), - cache_key=cache_key, + cache_key=vocab_cache_key, ) if not is_online: