Skip to content

Commit ff01ee0

Browse files
committed
update train config
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
1 parent c06cef5 commit ff01ee0

3 files changed

Lines changed: 22 additions & 9 deletions

File tree

training/tensor_parallel/hf_integration/configs/ds_config.json

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
}
2222
},
2323
"zero_optimization": {
24-
"stage": 1,
24+
"stage": 2,
2525
"gather_16bit_weights_on_model_save": true
2626
},
2727
"tensor_parallel":{

training/tensor_parallel/hf_integration/run.sh

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,10 @@
1-
weight_path=/host/ssd/hf_models/llama2-7b-hf
2-
# weight_path=/host/ssd/hf_models/Meta-Llama-3.1-8B
1+
# Default to a public HF model for out-of-the-box runs.
2+
weight_path=facebook/opt-125m
33
export WANDB_MODE=disabled
4-
num_gpus=8
4+
num_gpus=${NUM_GPUS:-8}
55
epoch=3
66
mbs=2
7-
MODE=${1:-zero1tp}
7+
MODE=${1:-zero2tp}
88
if [ "$MODE" == "zero1tp" ]; then
99
ZERO_STAGE=1
1010
AUTOTP_SIZE=4
@@ -33,6 +33,13 @@ else
3333
echo "error '$MODE',please use 'zero' or 'tp'。"
3434
exit 1
3535
fi
36+
37+
# HF Trainer + Accelerate currently builds a 1D device mesh of size AUTOTP_SIZE.
38+
# If num_gpus > AUTOTP_SIZE, ranks outside the mesh fail during init_device_mesh.
39+
if [ "$AUTOTP_SIZE" -gt 1 ] && [ "$num_gpus" -ne "$AUTOTP_SIZE" ]; then
40+
echo "Adjusting num_gpus to AUTOTP_SIZE=$AUTOTP_SIZE to avoid device_mesh init failure."
41+
num_gpus=$AUTOTP_SIZE
42+
fi
3643
TEMPLATE_FILE="configs/ds_config_temp.json"
3744
OUTPUT_FILE="configs/ds_config.json"
3845
sed -e "s/\${zero_stage}/${ZERO_STAGE}/g" \
@@ -50,15 +57,15 @@ deepspeed --num_gpus $num_gpus \
5057
--gradient_checkpointing false \
5158
--per_device_train_batch_size $per_device_train_batch_size \
5259
--per_device_eval_batch_size 1 \
53-
--evaluation_strategy no \
60+
--eval_strategy no \
5461
--save_strategy steps \
5562
--save_steps 10000 \
5663
--gradient_accumulation_steps 4 \
5764
--learning_rate 0 \
5865
--learning_rate 2e-5 \
5966
--weight_decay 0. \
60-
--warmup_ratio 0.03 \
67+
--warmup_steps 0 \
6168
--lr_scheduler_type cosine \
6269
--logging_steps 1 \
6370
--tf32 True \
64-
--deepspeed "./configs/ds_config.json"
71+
--deepspeed "./configs/ds_config.json"

training/tensor_parallel/hf_integration/train.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -252,7 +252,13 @@ class MemoryCallback(TrainerCallback):
252252
def on_step_end(self, args, state, control, **kwargs):
253253
see_memory_usage("After step end", force=True)
254254
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
255-
trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args,callbacks=[MemoryCallback], **data_module)
255+
trainer = Trainer(
256+
model=model,
257+
processing_class=tokenizer,
258+
args=training_args,
259+
callbacks=[MemoryCallback],
260+
**data_module,
261+
)
256262

257263
trainer.train()
258264
# load&save distributed checkpoint

0 commit comments

Comments
 (0)