audio tester class #45391
run-slow: audioflamingo3, granite_speech, qwen2_audio
This comment contains models: ["models/audioflamingo3", "models/granite_speech", "models/qwen2_audio"]
eustlb
left a comment
This is cool!! 🔥
Models that should be covered by this PR:
- audioflamingo3
- glmasr
- granite_speech
- higgs_audio_v2
- kyutai_speech_to_text
- qwen2_audio
- vibevoice_asr
- voxtral
- voxtral_realtime
- musicflamingo
Might also be covered:
- gemma3n
- gemma4
- qwen2_5_omni
- qwen3_omni_moe
    def get_num_audio_tokens(self, audio_features):
        """Compute number of audio placeholder tokens from features. Override for different subsampling."""
        # Default: 2-stage pooling (common for Whisper-style encoders)
        input_length = (audio_features.shape[-1] - 1) // 2 + 1
        return (input_length - 2) // 2 + 1
we shouldn't put Whisper defaults here, but rather force subclasses to implement this method
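One way to follow this suggestion is sketched below. The class names `ALMTesterSketch` and `WhisperStyleTesterSketch` are hypothetical and not taken from the PR. The base tester raises `NotImplementedError`, and the two-stage formula from the diff moves into the subclass that needs it. A `SimpleNamespace` stands in for a feature tensor so the sketch runs without torch.

```python
from types import SimpleNamespace


class ALMTesterSketch:
    """Hypothetical base tester: deliberately has no encoder-specific default."""

    def get_num_audio_tokens(self, audio_features):
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_num_audio_tokens for its encoder."
        )


class WhisperStyleTesterSketch(ALMTesterSketch):
    """Hypothetical subclass that owns the two-stage formula shown in the diff."""

    def get_num_audio_tokens(self, audio_features):
        # First length reduction over the last (time) axis.
        input_length = (audio_features.shape[-1] - 1) // 2 + 1
        # Second length reduction.
        return (input_length - 2) // 2 + 1


# Stand-in for a (batch, mel_bins, frames) feature tensor; only `.shape` is read.
features = SimpleNamespace(shape=(2, 80, 100))
num_tokens = WhisperStyleTesterSketch().get_num_audio_tokens(features)
```

With 100 frames, the subclass formula gives 25 placeholder tokens, while calling the method on the base class fails loudly instead of silently applying a Whisper-style default.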
    input_ids = input_ids.clone()
    input_ids[input_ids == self.audio_token_id] = self.pad_token_id
    for i in range(input_ids.shape[0]):
        n = num_audio_tokens[i].item() if isinstance(num_audio_tokens, torch.Tensor) else num_audio_tokens
        if 1 + int(n) > self.seq_length:
            raise ValueError(
                f"Cannot place {int(n)} audio tokens after BOS in a sequence of length {self.seq_length}. "
                "This likely indicates a mismatch between your feature extraction/configuration and your sequence length. "
                "Please ensure `seq_length` is >= the number of audio embedding positions + 1."
            )
        input_ids[i, 1 : 1 + int(n)] = self.audio_token_id
    return input_ids
I like it, it allows testing a different number of multimodal inputs per sample!
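The per-sample behaviour can be illustrated with a small torch-free sketch. It assumes plain Python lists in place of tensors, and the function name `place_audio_tokens` and the token ids are made up for illustration. The logic mirrors the snippet: scrub stray placeholders, then write `n` placeholders right after BOS for each row.

```python
def place_audio_tokens(input_ids, num_audio_tokens, audio_token_id, pad_token_id):
    """Return a copy of `input_ids` with audio placeholders at positions 1..n of each row."""
    seq_length = len(input_ids[0])
    result = []
    for i, row in enumerate(input_ids):
        # Replace any pre-existing placeholder so only the intended span remains.
        row = [pad_token_id if tok == audio_token_id else tok for tok in row]
        # Accept either one count for the whole batch or one count per sample.
        per_sample = isinstance(num_audio_tokens, (list, tuple))
        n = int(num_audio_tokens[i] if per_sample else num_audio_tokens)
        if 1 + n > seq_length:
            raise ValueError(
                f"Cannot place {n} audio tokens after BOS in a sequence of length {seq_length}."
            )
        row[1 : 1 + n] = [audio_token_id] * n
        result.append(row)
    return result


# Hypothetical ids: 1 = BOS, 99 = audio placeholder, 0 = pad.
batch = [[1, 5, 6, 99, 8, 9], [1, 5, 6, 7, 8, 9]]
placed = place_audio_tokens(batch, [2, 4], audio_token_id=99, pad_token_id=0)
```

Here the first row receives two placeholders and the second receives four, and the stray placeholder at index 3 of the first row is replaced by the pad id.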
        return {self.audio_config_key: self.get_audio_config()}

    def _prepare_modality_inputs(self, input_ids, config):
        # TODO: add a clear diagram that explains input prep ?
run-slow: audioflamingo3, gemma3, glmasr, granite_speech, llava_next, musicflamingo, qwen2_5_omni, qwen2_audio, qwen3_omni_moe, qwen3_vl, qwen3_vl_moe, vibevoice_asr, voxtral, voxtral_realtime
This comment contains models: ["models/audioflamingo3", "models/gemma3", "models/glmasr", "models/granite_speech", "models/llava_next", "models/musicflamingo", "models/qwen2_5_omni", "models/qwen2_audio", "models/qwen3_omni_moe", "models/qwen3_vl", "models/qwen3_vl_moe", "models/vibevoice_asr", "models/voxtral", "models/voxtral_realtime"]
[For maintainers] Suggested jobs to run (before merge)
run-slow: audioflamingo3, gemma3, glmasr, granite_speech, granite_speech_plus, llava_next, musicflamingo, qwen2_5_omni, qwen2_audio, qwen3_omni_moe, qwen3_vl, qwen3_vl_moe, vibevoice_asr, voxtral, voxtral_realtime
* audio tester
* tweak check repo for audio tester
* audio -> ALM
* ALMTester: no audio/text defaults; better input prep
* update test_sdpa_can_dispatch_composite_models to handle ALMs
* propagate to other model classes
* cleaner
* updates
* audio_mask_key + updates
* typo
* simplify granite speech
* nits
* some more cleaning
* add test_mismatching_num_audio_tokens
* add get_placeholder_mask
* specific to musicflamingo
* granite speech fix
* let's factorise alm/vlm testers
* make fix-repo
* unskip test_sdpa_can_dispatch_on_flash on qwen2_audio
* should not be skipped
* make fix-repo
* test_mismatching_num_audio_tokens should be skipped for voxtral_realtime
* nit
* _special_token_ids as property and skipped in prepare_config_and_inputs_for_common
* MoE params in common class
* add _TEXT_MODEL_TESTER_DEFAULTS to avoid divergence
* nit
* clearer inits
* _prepare_modality_inputs return dict
* format
* split line for readability
* ran python utils/check_modular_conversion.py --fix_and_overwrite
* testing auto cancel
* testing auto cancel - part 2
* remove comment
* update granite speech plus tests
* fix test
---------
Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
What does this PR do?
Similarly to the VLM tester, this patch introduces an audio tester class, used in
Adding a new audio-language model using this will require ~8-20 lines for the tester (vs ~100-160 before). The boilerplate (config introspection, input preparation, SDPA dispatch test, common skips) lives in one place.
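To make the line-count claim concrete, here is a rough, self-contained sketch of the pattern. It is not the PR's actual class: `BaseALMTesterSketch`, `ToyAudioModelTester`, the attribute names other than `audio_config_key`, and all token ids are hypothetical. The shared base assembles the config and the input ids, and a new model's tester only supplies what is specific to its encoder.

```python
class BaseALMTesterSketch:
    """Hypothetical shared base holding the boilerplate every audio tester needs."""

    batch_size = 2
    seq_length = 12
    bos_token_id = 1
    pad_token_id = 0
    audio_token_id = 9
    text_token_id = 5  # hypothetical filler id for non-audio positions
    audio_config_key = "audio_config"

    def get_audio_config(self):
        raise NotImplementedError

    def get_num_audio_tokens(self, num_frames):
        raise NotImplementedError

    def get_config(self):
        # Shared config assembly: the audio sub-config is nested under a per-model key.
        return {self.audio_config_key: self.get_audio_config(), "pad_token_id": self.pad_token_id}

    def prepare_input_ids(self, num_frames):
        # Shared input prep: BOS, then audio placeholders, then text filler.
        n = self.get_num_audio_tokens(num_frames)
        if 1 + n > self.seq_length:
            raise ValueError(f"Cannot place {n} audio tokens in a sequence of length {self.seq_length}.")
        row = [self.bos_token_id] + [self.audio_token_id] * n
        row += [self.text_token_id] * (self.seq_length - len(row))
        return [list(row) for _ in range(self.batch_size)]


class ToyAudioModelTester(BaseALMTesterSketch):
    """Hypothetical per-model tester: only the model-specific pieces are written here."""

    audio_config_key = "encoder_config"

    def get_audio_config(self):
        return {"hidden_size": 32, "num_mel_bins": 80}

    def get_num_audio_tokens(self, num_frames):
        return num_frames // 4  # assumed 4x subsampling for this toy encoder


tester = ToyAudioModelTester()
config = tester.get_config()
input_ids = tester.prepare_input_ids(num_frames=20)
```

The subclass stays under ten lines because everything that does not depend on the encoder lives in the base class.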