-
Notifications
You must be signed in to change notification settings - Fork 33.5k
[docs] ALMModelTest #45900
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+62
−4
Merged
[docs] ALMModelTest #45900
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,7 +36,21 @@ RUN_SLOW=1 pytest tests/models/mymodel/ -v | |
|
|
||
| The Hugging Face CI runs model tests without `@slow` on every pull request, and slow tests run on a nightly schedule (see [Pull request checks](./pr_checks) for what the CI validates). | ||
|
|
||
| ## Write tests for a causal language model | ||
| ## Pick a base test class | ||
|
|
||
| Three base classes cover the most common model families. Pick the one that matches your model's modality. | ||
|
|
||
| | Base class | Use for | Mixins | | ||
| |---|---|---| | ||
| | `CausalLMModelTest` | Causal language models | `ModelTesterMixin`, `GenerationTesterMixin`, `PipelineTesterMixin`, `TrainingTesterMixin`, `TensorParallelTesterMixin` | | ||
| | `VLMModelTest` | Vision-language models | `ModelTesterMixin`, `GenerationTesterMixin`, `PipelineTesterMixin` | | ||
| | `ALMModelTest` | Audio-language models | `ModelTesterMixin`, `GenerationTesterMixin`, `PipelineTesterMixin` | | ||
|
|
||
| `VLMModelTest` and `ALMModelTest` share a common `MultiModalModelTest` parent that nests sub-configs into a composite top-level config and places modality placeholder tokens in `input_ids` alongside the raw modality features (audio or vision). `CausalLMModelTest` doesn't use the multimodal parent. It builds on the three shared mixins and adds `TrainingTesterMixin` and `TensorParallelTesterMixin` for training and tensor-parallel coverage. | ||
|
|
||
| For architectures that don't fit any of the three (encoder-only, encoder-decoder, etc.), build the test infrastructure directly from the [two-class pattern](#modeltester-and-modeltest) and [test mixins](#test-mixins) described below. | ||
|
|
||
| ## CausalLMModelTest | ||
|
|
||
| `CausalLMModelTest` is the recommended base class for testing causal language models. It inherits from five [test mixins](#test-mixins) and auto-generates tests for save/load, generation, pipelines, training, and tensor parallelism. | ||
|
|
||
|
|
@@ -66,7 +80,7 @@ These two classes give full test coverage for `MyModel` and all its head classes | |
|
|
||
| `CausalLMModelTester` only requires `base_model_class`. The tester strips the `Model` suffix to get a base name (`LlamaModel` becomes `Llama`), then appends suffixes like `Config` or `ForCausalLM` to discover related classes. If a class doesn't exist in the module, the attribute stays `None` and the tester skips the corresponding tests. | ||
|
|
||
| ### Overriding defaults | ||
| ### Overriding defaults in the CausalLMTester | ||
|
|
||
| If your model doesn't follow standard naming, or you need to customize behavior, override attributes on the tester or test class. | ||
|
|
||
|
|
@@ -98,7 +112,7 @@ class YoutuModelTester(CausalLMModelTester): | |
| self.q_lora_rank = q_lora_rank | ||
| ``` | ||
|
|
||
| ## Write tests for a vision-language model | ||
| ## VLMModelTest | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same question |
||
|
|
||
| `VLMModelTest` is the base class for vision-language models. It inherits from three mixins (`ModelTesterMixin`, `GenerationTesterMixin`, `PipelineTesterMixin`) and sets `_is_composite = True` to handle multiple sub-models. | ||
|
|
||
|
|
@@ -134,7 +148,7 @@ class MyVLMTest(VLMModelTest, unittest.TestCase): | |
| model_tester_class = MyVLMTester | ||
| ``` | ||
|
|
||
| ### Overriding defaults | ||
| ### Overriding defaults in the VLMModelTester | ||
|
|
||
| When the VLM needs custom vision parameters or non-default config values, override `__init__`. Set defaults with `setdefault` before calling `super().__init__(parent, **kwargs)`. The example below shows the first few defaults from [tests/models/qianfan_ocr/test_modeling_qianfan_ocr.py](https://github.com/huggingface/transformers/blob/main/tests/models/qianfan_ocr/test_modeling_qianfan_ocr.py). | ||
|
|
||
|
|
@@ -162,6 +176,50 @@ VLM tests differ from `CausalLMModelTest` in a few ways. | |
| - The tester's `__init__` accepts vision parameters (`image_size`, `patch_size`, `num_channels`, `num_image_tokens`) from `**kwargs` and `setdefault()`. | ||
| - `ConfigTester` uses `has_text_modality=False` because the top-level config is a composite config rather than a text model config. | ||
|
|
||
| ## ALMModelTest | ||
|
|
||
| `ALMModelTest` is the base class for audio-language models (ALMs) like Qwen2Audio, AudioFlamingo3, and GraniteSpeech. It mirrors the VLM pattern with the same `MultiModalModelTest` parent and auto-discovery of head classes. The vision-side machinery is swapped for audio features, an audio sub-config, and an audio-token placement strategy. | ||
|
stevhliu marked this conversation as resolved.
|
||
|
|
||
| ```py | ||
| class MyALMTester(ALMModelTester): | ||
| config_class = MyALMConfig | ||
| text_config_class = MyALMTextConfig | ||
| audio_config_class = MyALMAudioConfig | ||
| conditional_generation_class = MyALMForConditionalGeneration | ||
| audio_mask_key = "feature_attention_mask" | ||
|
|
||
|
|
||
| class MyALMTest(ALMModelTest, unittest.TestCase): | ||
| model_tester_class = MyALMTester | ||
| ``` | ||
|
|
||
| ### Overriding defaults in the ALMModelTester | ||
|
|
||
| The tester's `__init__` sets ALM-specific defaults (`feat_seq_length=128`, `num_mel_bins=80`, `audio_token_id=0`). Override them with `setdefault` before calling `super().__init__(parent, **kwargs)`. | ||
|
|
||
| Two class attributes tell the tester how your model names things. | ||
|
|
||
| - `audio_mask_key`: the kwarg name your model expects for the audio mask (`"feature_attention_mask"`, `"input_features_mask"`, etc.). Leave it `None` if your model doesn't consume a separate audio mask. | ||
| - `audio_config_key`: the attribute name your top-level config uses to nest the audio sub-config. Defaults to `"audio_config"` but models like GraniteSpeech use `"encoder_config"`. | ||
|
|
||
| ```py | ||
| class Qwen2AudioModelTester(ALMModelTester): | ||
| def __init__(self, parent, **kwargs): | ||
| kwargs.setdefault("feat_seq_length", 60) | ||
| kwargs.setdefault("max_source_positions", kwargs["feat_seq_length"] // 2) | ||
| super().__init__(parent, **kwargs) | ||
| ``` | ||
|
|
||
| `ALMModelTester` requires you to override one hook, `get_audio_embeds_mask(audio_mask)`, and exposes a few more optional ones for customization. | ||
|
|
||
| - `get_audio_embeds_mask(audio_mask)`: returns the per-batch mask of audio embedding positions after the encoder's downsampling. The tester uses its row-wise sum to decide how many `audio_token_id` placeholders to insert into `input_ids`, so the count must match what your encoder emits. | ||
| - `create_audio_features()`: returns the audio feature tensor. Default shape is `[batch_size, num_mel_bins, feat_seq_length]`. Override when your model, like GraniteSpeech, expects time-first features (`[batch_size, feat_seq_length, num_mel_bins]`). | ||
| - `create_audio_mask()`: returns the audio-level attention mask. The default builds a random contiguous valid region per row in the batch. Override with a deterministic full-length mask if your tests compare two `prepare_config_and_inputs_for_common()` invocations against each other, or if your audio encoder dispatches to a backend that rejects non-null masks. | ||
| - `place_audio_tokens(input_ids, config, num_audio_tokens)`: places audio placeholder tokens contiguously after `BOS`. Override only if your model needs a different layout. | ||
| - `get_audio_feature_key()`: returns the inputs-dict key for audio features (`"input_features"` by default). | ||
|
|
||
| In addition to the inherited multimodal tests, `ALMModelTest` adds `test_mismatching_num_audio_tokens`. The test asserts the model raises a clear `ValueError` when the number of audio features doesn't match the number of audio placeholder tokens in `input_ids`, and verifies that a prompt with multiple audio segments still forwards successfully. | ||
|
|
||
| ## Write tests for other architectures | ||
|
|
||
| For encoder-only, encoder-decoder, audio, or other non-standard architectures, build the test infrastructure directly from the two-class pattern and test mixins described below. | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not keeping the explicit title here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i feel like the explicit title is a bit wordy as its already pretty clear this doc is about writing tests, so it might be nicer to just shorten it to just the test class itself :)