Fix HF model dtype fallback in forward_pass_logit_checker.py#4310

Draft
subawocit wants to merge 1 commit into
mainfrom
fix-logit-checker-dtype

@subawocit subawocit commented Jun 30, 2026

Collaborator

Description

In forward_pass_logit_checker.py around line 431, we map the MaxText config dtype to its PyTorch equivalent:

dtype_mapping = { "bfloat16": torch.bfloat16, "float32": torch.float32, ... }
torch_dtype = dtype_mapping.get(config.dtype.name, torch.bfloat16)

Accessing config.dtype.name returns the uppercase name (e.g., "FLOAT32"). Since the keys in dtype_mapping are lowercase ("float32"), the .get() lookup always missed and fell back to the default, so the Hugging Face reference model was always loaded in bfloat16, regardless of the precision configured for MaxText.

This precision mismatch artificially inflates the KL divergence and other numerical differences between the MaxText and Hugging Face logits.
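
The silent fallback can be reproduced in isolation. This is a minimal sketch, not the checker's actual code: `DType` is a hypothetical stand-in for the type of `config.dtype` (assumed to be an enum with uppercase member names, as the description implies), and plain strings stand in for the `torch` dtypes so the snippet runs without PyTorch.

```python
from enum import Enum


class DType(Enum):
    """Hypothetical stand-in for the type of config.dtype (uppercase member names)."""

    BFLOAT16 = "bfloat16"
    FLOAT32 = "float32"


# Strings stand in for torch.bfloat16 / torch.float32 so this runs without PyTorch.
dtype_mapping = {"bfloat16": "torch.bfloat16", "float32": "torch.float32"}

config_dtype = DType.FLOAT32

# .name is the member name ("FLOAT32"), which never matches the lowercase keys,
# so .get() silently returns the bfloat16 default instead of raising.
buggy = dtype_mapping.get(config_dtype.name, "torch.bfloat16")
print(config_dtype.name, "->", buggy)
```

Because `.get()` is given a default, the miss never raises, which is why the wrong precision went unnoticed.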

BUGS: b/529889649

Fix

Convert config.dtype.name to lowercase before looking it up in dtype_mapping.
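
A sketch of the corrected lookup, under the same assumptions as the description: `DType` and `resolve_torch_dtype` are hypothetical names used only for illustration, and strings stand in for the `torch` dtypes so the snippet runs without PyTorch. The actual diff may differ in form.

```python
from enum import Enum


class DType(Enum):
    """Hypothetical stand-in for the type of config.dtype (uppercase member names)."""

    BFLOAT16 = "bfloat16"
    FLOAT32 = "float32"


# Strings stand in for torch.bfloat16 / torch.float32 so this runs without PyTorch.
dtype_mapping = {"bfloat16": "torch.bfloat16", "float32": "torch.float32"}


def resolve_torch_dtype(dtype, mapping, default="torch.bfloat16"):
    """Hypothetical helper: normalize the enum name to lowercase before the lookup."""
    return mapping.get(dtype.name.lower(), default)


fixed = resolve_torch_dtype(DType.FLOAT32, dtype_mapping)
print(DType.FLOAT32.name, "->", fixed)
```

With the key normalized, a float32 MaxText config resolves to the float32 entry, and the bfloat16 default only applies to dtypes that are genuinely absent from the mapping.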

Tests

set -ex

run_id='2026-06-30-22-08-44'

USE_MULTIMODAL=false
SCAN_LAYERS=false
scan_status="unscanned"

export HF_TOKEN=<your-token>

MODEL_NAME='gemma3-4b'
HF_GOLDEN_MODEL='google/gemma-3-4b-it'
export MODEL_BUCKET=${MODEL_BUCKET:-gs://yuchenhou-maxtext-logs/checkpoints}
BASE_OUTPUT_DIRECTORY=${MODEL_BUCKET}/${MODEL_NAME}/${run_id}
export CKPT_PATH=${MODEL_BUCKET}/${MODEL_NAME}/${run_id}/unscanned/${run_id}/0/items
export UNSCANNED_CKPT_PATH=${CKPT_PATH}
export LOCAL_PATH=/tmp/hf/${MODEL_NAME}/${run_id}

python3 -m maxtext.checkpoint_conversion.to_huggingface \
    "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml  model_name=${MODEL_NAME} tokenizer_type="huggingface" load_parameters_path=${CKPT_PATH} base_output_directory=${LOCAL_PATH} use_multimodal=${USE_MULTIMODAL} scan_layers=$SCAN_LAYERS dtype=float32 hf_access_token=${HF_TOKEN} override_model_config=true

python3 -m tests.utils.forward_pass_logit_checker \
    "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path=${HF_GOLDEN_MODEL} load_parameters_path=${UNSCANNED_CKPT_PATH} model_name=${MODEL_NAME} use_multimodal=${USE_MULTIMODAL} scan_layers=false --hf_model_path=${LOCAL_PATH} --max_kl_div=0.03 --run_hf_model=true dtype=float32 matmul_precision=highest hardware=cpu skip_jax_distributed_system=True override_model_config=true

Test Results

KL divergence between Hugging Face and MaxText model logits (lower is better)

| Prompt | Before Fix | After Fix |
|---|---|---|
| "I love to" | Avg: 8.70e-04 (Max: 2.62e-03) | Avg: 2.63e-05 (Max: 3.32e-05) |
| "Today is a" | Avg: 2.16e-03 (Max: 5.49e-03) | Avg: 2.72e-05 (Max: 4.03e-05) |
| "What is the" | Avg: 2.20e-03 (Max: 4.88e-03) | Avg: 3.09e-05 (Max: 4.89e-05) |
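
For readers unfamiliar with the metric, here is a rough sketch of a per-position KL divergence between two logit vectors. It is an illustration only: the direction KL(reference || test), the use of nats, and the function names are assumptions, and the checker's own implementation may differ.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def kl_divergence(ref_logits, test_logits):
    """KL(P_ref || P_test) in nats for one token position (assumed direction)."""
    log_p = log_softmax(ref_logits)
    log_q = log_softmax(test_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


# Toy logits, not taken from the runs above.
same = kl_divergence([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
perturbed = kl_divergence([1.0, 2.0, 3.0], [1.0, 2.1, 2.9])
print(same, perturbed)
```

Identical distributions give zero, and any perturbation of the logits beyond a constant shift gives a positive value, which is why a precision mismatch in the reference model shows up as a larger KL.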

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov Bot commented Jun 30, 2026


Codecov Report

✅ All modified and coverable lines are covered by tests.
