Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 62 additions & 4 deletions docs/source/en/testing.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,21 @@ RUN_SLOW=1 pytest tests/models/mymodel/ -v

The Hugging Face CI runs model tests without `@slow` on every pull request, and slow tests run on a nightly schedule (see [Pull request checks](./pr_checks) for what the CI validates).

## Write tests for a causal language model
## Pick a base test class

Three base classes cover the most common model families. Pick the one that matches your model's modality.

| Base class | Use for | Mixins |
|---|---|---|
| `CausalLMModelTest` | Causal language models | `ModelTesterMixin`, `GenerationTesterMixin`, `PipelineTesterMixin`, `TrainingTesterMixin`, `TensorParallelTesterMixin` |
| `VLMModelTest` | Vision-language models | `ModelTesterMixin`, `GenerationTesterMixin`, `PipelineTesterMixin` |
| `ALMModelTest` | Audio-language models | `ModelTesterMixin`, `GenerationTesterMixin`, `PipelineTesterMixin` |

`VLMModelTest` and `ALMModelTest` share a common `MultiModalModelTest` parent that nests sub-configs into a composite top-level config and places modality placeholder tokens in `input_ids` alongside the raw modality features (audio or vision). `CausalLMModelTest` doesn't use the multimodal parent. It builds on the three shared mixins and adds `TrainingTesterMixin` and `TensorParallelTesterMixin` for training and tensor-parallel coverage.

For architectures that don't fit any of the three (encoder-only, encoder-decoder, etc.), build the test infrastructure directly from the [two-class pattern](#modeltester-and-modeltest) and [test mixins](#test-mixins) described below.

## CausalLMModelTest

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not keeping the explicit title here?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i feel like the explicit title is a bit wordy as its already pretty clear this doc is about writing tests, so it might be nicer to just shorten it to just the test class itself :)


`CausalLMModelTest` is the recommended base class for testing causal language models. It inherits from five [test mixins](#test-mixins) and auto-generates tests for save/load, generation, pipelines, training, and tensor parallelism.

Expand Down Expand Up @@ -66,7 +80,7 @@ These two classes give full test coverage for `MyModel` and all its head classes

`CausalLMModelTester` only requires `base_model_class`. The tester strips the `Model` suffix to get a base name (`LlamaModel` becomes `Llama`), then appends suffixes like `Config` or `ForCausalLM` to discover related classes. If a class doesn't exist in the module, the attribute stays `None` and the tester skips the corresponding tests.

### Overriding defaults
### Overriding defaults in the CausalLMTester

If your model doesn't follow standard naming, or you need to customize behavior, override attributes on the tester or test class.

Expand Down Expand Up @@ -98,7 +112,7 @@ class YoutuModelTester(CausalLMModelTester):
self.q_lora_rank = q_lora_rank
```

## Write tests for a vision-language model
## VLMModelTest

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same question


`VLMModelTest` is the base class for vision-language models. It inherits from three mixins (`ModelTesterMixin`, `GenerationTesterMixin`, `PipelineTesterMixin`) and sets `_is_composite = True` to handle multiple sub-models.

Expand Down Expand Up @@ -134,7 +148,7 @@ class MyVLMTest(VLMModelTest, unittest.TestCase):
model_tester_class = MyVLMTester
```

### Overriding defaults
### Overriding defaults in the VLMModelTester

When the VLM needs custom vision parameters or non-default config values, override `__init__`. Set defaults with `setdefault` before calling `super().__init__(parent, **kwargs)`. The example below shows the first few defaults from [tests/models/qianfan_ocr/test_modeling_qianfan_ocr.py](https://github.com/huggingface/transformers/blob/main/tests/models/qianfan_ocr/test_modeling_qianfan_ocr.py).

Expand Down Expand Up @@ -162,6 +176,50 @@ VLM tests differ from `CausalLMModelTest` in a few ways.
- The tester's `__init__` accepts vision parameters (`image_size`, `patch_size`, `num_channels`, `num_image_tokens`) from `**kwargs` and `setdefault()`.
- `ConfigTester` uses `has_text_modality=False` because the top-level config is a composite config rather than a text model config.

## ALMModelTest

`ALMModelTest` is the base class for audio-language models (ALMs) like Qwen2Audio, AudioFlamingo3, and GraniteSpeech. It mirrors the VLM pattern with the same `MultiModalModelTest` parent and auto-discovery of head classes. The vision-side machinery is swapped for audio features, an audio sub-config, and an audio-token placement strategy.
Comment thread
stevhliu marked this conversation as resolved.

```py
class MyALMTester(ALMModelTester):
config_class = MyALMConfig
text_config_class = MyALMTextConfig
audio_config_class = MyALMAudioConfig
conditional_generation_class = MyALMForConditionalGeneration
audio_mask_key = "feature_attention_mask"


class MyALMTest(ALMModelTest, unittest.TestCase):
model_tester_class = MyALMTester
```

### Overriding defaults in the ALMModelTester

The tester's `__init__` sets ALM-specific defaults (`feat_seq_length=128`, `num_mel_bins=80`, `audio_token_id=0`). Override them with `setdefault` before calling `super().__init__(parent, **kwargs)`.

Two class attributes tell the tester how your model names things.

- `audio_mask_key`: the kwarg name your model expects for the audio mask (`"feature_attention_mask"`, `"input_features_mask"`, etc.). Leave it `None` if your model doesn't consume a separate audio mask.
- `audio_config_key`: the attribute name your top-level config uses to nest the audio sub-config. Defaults to `"audio_config"` but models like GraniteSpeech use `"encoder_config"`.

```py
class Qwen2AudioModelTester(ALMModelTester):
def __init__(self, parent, **kwargs):
kwargs.setdefault("feat_seq_length", 60)
kwargs.setdefault("max_source_positions", kwargs["feat_seq_length"] // 2)
super().__init__(parent, **kwargs)
```

`ALMModelTester` requires you to override one hook, `get_audio_embeds_mask(audio_mask)`, and exposes a few more optional ones for customization.

- `get_audio_embeds_mask(audio_mask)`: returns the per-batch mask of audio embedding positions after the encoder's downsampling. The tester uses its row-wise sum to decide how many `audio_token_id` placeholders to insert into `input_ids`, so the count must match what your encoder emits.
- `create_audio_features()`: returns the audio feature tensor. Default shape is `[batch_size, num_mel_bins, feat_seq_length]`. Override when your model, like GraniteSpeech, expects time-first features (`[batch_size, feat_seq_length, num_mel_bins]`).
- `create_audio_mask()`: returns the audio-level attention mask. The default builds a random contiguous valid region per row in the batch. Override with a deterministic full-length mask if your tests compare two `prepare_config_and_inputs_for_common()` invocations against each other, or if your audio encoder dispatches to a backend that rejects non-null masks.
- `place_audio_tokens(input_ids, config, num_audio_tokens)`: places audio placeholder tokens contiguously after `BOS`. Override only if your model needs a different layout.
- `get_audio_feature_key()`: returns the inputs-dict key for audio features (`"input_features"` by default).

In addition to the inherited multimodal tests, `ALMModelTest` adds `test_mismatching_num_audio_tokens`. The test asserts the model raises a clear `ValueError` when the number of audio features doesn't match the number of audio placeholder tokens in `input_ids`, and verifies that a prompt with multiple audio segments still forwards successfully.

## Write tests for other architectures

For encoder-only, encoder-decoder, audio, or other non-standard architectures, build the test infrastructure directly from the two-class pattern and test mixins described below.
Expand Down
Loading