PriorTR's Visual Token Reduction (VTR) framework uses a strategy pattern: a token-reduction method is a small class that computes an importance score per visual token. The framework handles everything else — selecting Top-K, physically dropping tokens, and re-plumbing the attention mask, position IDs, KV cache, and RoPE. You do not touch the model forward, the cache, or the masking.
Paths below are for LLaVA (image/LLaVA/llava/vtr/); InternVL (internvl_vtr/), Qwen3-VL
(visual_token_pruning/), and Video-LLaVA (videollava/vtr/) mirror the same shape.
1. Create llava/vtr/strategy/mymethod.py — implement the one required method, compute_scores:

    from .base import PruningStrategy
    from .registry import register_strategy

    @register_strategy("mymethod")  # the name you pass as strategy=mymethod
    class MyStrategy(PruningStrategy):
        def compute_scores(self, attention, image_token_range, config, **ctx):
            # attention: [batch, heads, seq, seq] at the prune layer
            # image_token_range: (img_start, img_end)
            # return: one score per image token, shape [num_img]
            img_start, img_end = image_token_range
            return attention[:, :, -1, img_start:img_end].mean(dim=1).squeeze(0)

2. Register it by adding one import to llava/vtr/strategy/__init__.py so that the decorator runs:

    from .mymethod import MyStrategy

3. Run it, selecting your strategy by name:

    python -m lmms_eval --model llava_vtr \
        --model_args pretrained=liuhaotian/llava-v1.5-7b,strategy=mymethod \
        --tasks mme --batch_size 1 --output_path ./results/mymethod

That's the whole loop. A FastV-style "rank by attention magnitude" method is essentially one line in
compute_scores.
PruningStrategy (strategy/base.py) already provides:
- select_tokens(scores, num_tokens, config): turns your scores into kept indices, honoring keep_tokens → score_threshold → keep_ratio (in that priority) and preserving original order. The framework calls this for you; you never write Top-K logic.
- _aggregate_attention(attention, image_token_range, config): collapses the attention matrix to per-token scores using query_aggregation (last/question) and head_aggregation (mean/max). Call it inside compute_scores if you just want pre-aggregated attention: scores = self._aggregate_attention(attention, image_token_range, config).
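
To make the priority order concrete, here is a minimal sketch of the selection logic using plain Python lists instead of tensors. It is not the framework's select_tokens: the function name, the "score >= threshold" rule, and the ceil rounding for keep_ratio are assumptions made for illustration.

```python
import math


def select_tokens_sketch(scores, keep_tokens=None, score_threshold=None, keep_ratio=None):
    """Return kept token indices in their original order.

    Priority (as described above): keep_tokens, then score_threshold,
    then keep_ratio. Rounding and tie-breaking here are assumptions.
    """
    n = len(scores)
    if keep_tokens is not None:
        k = min(keep_tokens, n)
    elif score_threshold is not None:
        # Threshold mode needs no ranking: enumerate() already preserves order.
        return [i for i, s in enumerate(scores) if s >= score_threshold]
    elif keep_ratio is not None:
        k = max(1, math.ceil(n * keep_ratio))
    else:
        return list(range(n))  # nothing configured: keep everything

    # Rank by score (highest first), take the top k, then restore original order.
    ranked = sorted(range(n), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k])


scores = [0.1, 0.9, 0.3, 0.7]
print(select_tokens_sketch(scores, keep_tokens=2))                  # [1, 3]
print(select_tokens_sketch(scores, score_threshold=0.3))            # [1, 2, 3]
print(select_tokens_sketch(scores, keep_tokens=2, keep_ratio=1.0))  # [1, 3]
```

The last call shows the priority rule: keep_tokens wins even though keep_ratio=1.0 would keep every token.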
You can also subclass an existing strategy (e.g. class MyVariant(PriorTRStrategy)) and override a
piece of the scoring.
To add a hyperparameter, add a field to VTRConfig (vtr/config.py):

    @dataclass
    class VTRConfig:
        ...
        my_temperature: float = 1.0  # your new knob

Read it in your strategy as config.my_temperature. It is automatically passable via
--model_args ...,my_temperature=0.7 (image models) or the matching --vtr_* flag (Video-LLaVA).
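
One way such pass-through can work is to match key=value pairs against the dataclass fields and coerce each value with the field's annotated type. The sketch below is an illustration only: VTRConfigSketch, config_from_model_args, and the keep_ratio default are invented here, and the real parsing in the framework or in lmms_eval may differ.

```python
from dataclasses import dataclass, fields


@dataclass
class VTRConfigSketch:
    """Stand-in for VTRConfig with two float fields (defaults are assumptions)."""
    keep_ratio: float = 0.5
    my_temperature: float = 1.0  # the new knob


def config_from_model_args(arg_string):
    """Build a config from 'key=value,key=value', ignoring unrelated keys."""
    field_types = {f.name: f.type for f in fields(VTRConfigSketch)}
    kwargs = {}
    for pair in arg_string.split(","):
        key, _, value = pair.partition("=")
        if key in field_types:
            # Coerce the string using the annotated type, e.g. float("0.7").
            kwargs[key] = field_types[key](value)
    return VTRConfigSketch(**kwargs)


cfg = config_from_model_args(
    "pretrained=liuhaotian/llava-v1.5-7b,strategy=mymethod,my_temperature=0.7"
)
print(cfg)  # VTRConfigSketch(keep_ratio=0.5, my_temperature=0.7)
```

Keys such as pretrained and strategy are skipped because they are not config fields in this sketch; only my_temperature overrides its default.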
compute_scores(attention, image_token_range, config, **ctx) -> Tensor[num_img]
- attention: [batch, heads, seq, seq] from the prune layer (the framework forces this layer to output attention via eager/sdpa; flash_attention_2 cannot return weights, so use sdpa/eager).
- image_token_range: (img_start, img_end); tokens before/after the image are always kept.
- config: the active VTRConfig.
- **ctx: extra per-call context (see below).
- Return: one score per image token; higher = more important. The framework keeps the Top-K and physically prunes the rest, then continues from the next layer on the shortened sequence.
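
The contract above can be pictured as a mapping from image-token scores to the sequence positions that survive pruning. This is a simplified sketch with plain lists, assuming img_end is exclusive (as in the slice img_start:img_end); kept_sequence_indices is a hypothetical helper, not a framework function.

```python
def kept_sequence_indices(scores, image_token_range, seq_len, k):
    """Return the full-sequence positions kept after pruning image tokens.

    Text tokens before and after the image are always kept; only the top-k
    image tokens (by score) survive, in their original order.
    """
    img_start, img_end = image_token_range
    if len(scores) != img_end - img_start:
        raise ValueError("need exactly one score per image token")

    # Rank image tokens by score, highest first, and keep the top k.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    kept_image = sorted(img_start + i for i in ranked[:k])

    prefix = list(range(img_start))          # tokens before the image
    suffix = list(range(img_end, seq_len))   # tokens after the image
    return prefix + kept_image + suffix


# 8-token sequence, image tokens at positions 2..5, keep the best 2.
print(kept_sequence_indices([0.2, 0.8, 0.1, 0.5], (2, 6), seq_len=8, k=2))
# [0, 1, 3, 5, 6, 7]
```

In the real framework the same index set also has to be applied to the attention mask, position IDs, KV cache, and RoPE, which is the part you do not have to write.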
To prune at several layers, pass a list: prune_layer=[3, 7, 16]. The framework prunes at each layer in turn, re-computing the
image-token range and re-plumbing the mask / position IDs / KV cache / RoPE between layers. Every
prune layer calls the same strategy (there is a single global self._vtr_strategy).
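
The bookkeeping between prune layers can be sketched as follows. This simulation assumes that the same keep_ratio is applied to the surviving image tokens at every prune layer and that counts round up; how the framework actually distributes the budget across layers is not stated here, so treat the numbers as illustrative.

```python
import math


def simulate_prune_layers(seq_len, image_token_range, prune_layers, keep_ratio):
    """Track the image-token range and sequence length after each prune layer.

    Assumption: every prune layer keeps ceil(keep_ratio * surviving image tokens).
    """
    img_start, img_end = image_token_range
    history = []
    for layer in sorted(prune_layers):
        num_img = img_end - img_start
        kept = max(1, math.ceil(num_img * keep_ratio))
        seq_len -= num_img - kept       # the sequence shrinks by the dropped tokens
        img_end = img_start + kept      # the image range is re-computed
        history.append((layer, (img_start, img_end), seq_len))
    return history


# 576 image tokens at positions 35..610 inside a 700-token sequence.
for step in simulate_prune_layers(700, (35, 611), [3, 7, 16], keep_ratio=0.5):
    print(step)
# (3, (35, 323), 412)
# (7, (35, 179), 268)
# (16, (35, 107), 196)
```

The start of the image range never moves because tokens before the image are always kept; only its end and the total length change.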
Using a different strategy at each prune layer is not built in, but it is easy to add. Write a composite/router strategy that holds several sub-strategies and
dispatches by layer. The only missing piece is the current layer index, which compute_scores does not
receive today — route it through ctx:

    # in prunable_llama.py, where attention is grabbed in the layer loop, add one line:
    vtr_ctx["layer_idx"] = layer_idx  # before _compute_and_apply_pruning(...)

    @register_strategy("router")
    class RouterStrategy(PruningStrategy):
        def __init__(self):
            self.by_layer = {3: PriorTRStrategy(), 7: FastVStrategy()}  # layer -> strategy

        def compute_scores(self, attention, image_token_range, config, **ctx):
            sub = self.by_layer[ctx["layer_idx"]]  # KeyError if a prune layer has no entry
            return sub.compute_scores(attention, image_token_range, config, **ctx)

The framework gives you two places to keep state across the prune layers of a single forward:
- vtr_ctx is one shared dict threaded through every prune layer of a forward. Put a mutable container in it (e.g. vtr_ctx["cache"] = {}) and your strategy can write at layer 3 and read at layer 7.
- The strategy instance (self._vtr_strategy) persists for the model's lifetime, so it can stash self._cache = .... Reset it at the start of each forward, or state leaks across calls.
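
Both options can be seen in a toy forward loop. The sketch assumes the framework passes the dict as keyword arguments (compute_scores(..., **vtr_ctx)), which is why a mutable container is needed: inside the strategy, ctx is a fresh dict, so new keys set on it do not travel back, but objects stored in it are shared. RememberingStrategy, reset, and fake_forward are hypothetical names, and layer_idx is only available if you route it through the context yourself.

```python
class RememberingStrategy:
    """Toy strategy that keeps state in both places described above."""

    def __init__(self):
        self._seen_layers = []  # instance state: survives across forwards

    def reset(self):
        # Call at the start of each forward so state does not leak across calls.
        self._seen_layers.clear()

    def compute_scores(self, attention, image_token_range, config, **ctx):
        # The "cache" dict is the same object at every prune layer.
        ctx["cache"].setdefault("written_at", []).append(ctx["layer_idx"])
        self._seen_layers.append(ctx["layer_idx"])
        # Setting a new key on the unpacked copy does NOT reach the caller.
        ctx["scratch"] = "lost"
        return list(attention)  # stand-in scores


def fake_forward(strategy, prune_layers):
    """Simulate one forward pass that visits each prune layer in turn."""
    strategy.reset()
    vtr_ctx = {"cache": {}}  # one context dict per forward
    for layer_idx in prune_layers:
        vtr_ctx["layer_idx"] = layer_idx
        strategy.compute_scores([0.1, 0.9], (0, 2), None, **vtr_ctx)
    return vtr_ctx


strategy = RememberingStrategy()
ctx = fake_forward(strategy, [3, 7])
print(ctx["cache"])            # {'written_at': [3, 7]}
print("scratch" in ctx)        # False
fake_forward(strategy, [3, 7])
print(strategy._seen_layers)   # [3, 7]
```

Without the reset() call, the second forward would leave _seen_layers at [3, 7, 3, 7], which is the leak the second bullet warns about.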
The one gap: compute_scores currently receives only the attention matrix, not the layer's
hidden states / features. Those are available at the call site
(_compute_and_apply_pruning(hidden_states=..., ...)) but not forwarded. If your method needs to cache
features (not just attention), add one line to route them in:

    # in _compute_and_apply_pruning, before calling compute_scores:
    vtr_ctx["hidden_states"] = hidden_states

Then read ctx["hidden_states"] in your strategy.
If your method needs a second forward, pre-LLM pruning, token merging, or multi-source signals, the
Qwen3-VL subproject (image/Qwen3-VL/visual_token_pruning/) is the richest template — it already
implements a two-forward prior (model/prior_utils.py), VisPruner-style pre-LLM pruning, token merging
(model/token_merge.py), and multi-source handling (model/deepstack_handler.py). Copy the closest
existing strategy there rather than extending LLaVA's minimal framework.
Each model ships its own VTR framework (llava/vtr, internvl_vtr, visual_token_pruning,
videollava/vtr) — independent copies, not a shared core. A strategy added to LLaVA only affects
LLaVA; to support several backbones, drop an equivalent strategy file into each (the interfaces are
nearly identical).