[MM][Perf][CG] Support dual-path ViT full CUDA graph for DeepSeek-OCR#43586

Open
shen-shanshan wants to merge 15 commits into
vllm-project:mainfrom
shen-shanshan:vit-cg

Conversation

@shen-shanshan

@shen-shanshan shen-shanshan commented May 25, 2026

Contributor

Purpose

This PR implements the SupportsEncoderCudaGraph protocol for DeepseekOCRForCausalLM, enabling full CUDA graph capture of the vision encoder with a dual-path graph architecture. DeepSeek-OCR uses a two-tower ViT (SAM + CLIP) with a dynamic tiling mechanism — global images at 1024×1024 and optional local patches at 640×640.

Rather than capturing a single monolithic graph for both paths, this PR introduces a dual-path design (enable_dual_path_graph=True) that captures two independent graph sets: one for global images (a constant 272 tokens each) and one for local patches (100 tokens each). During inference, the manager independently selects the smallest fitting budget for each path. This allows partial graph fallback (one path hits a captured graph while the other falls back to eager execution) and skips local-path graphs entirely when no patches are present. The design avoids wasted compute on zero-padded patch buffers for untiled images, and avoids graphs that would otherwise be invalidated by the per-image variable crop_shape.
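The per-path budget selection described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the code in this PR: `select_budget`, `plan_dual_path`, and the example budget lists are hypothetical, and budgets are assumed to be expressed in tokens. Only the per-item token counts (272 and 100) come from the description.

```python
from bisect import bisect_left

GLOBAL_TOKENS_PER_IMAGE = 272  # from the PR description
LOCAL_TOKENS_PER_PATCH = 100   # from the PR description


def select_budget(needed_tokens, budgets):
    """Return the smallest budget >= needed_tokens, or None if none fits.

    `budgets` is a hypothetical ascending list of captured token budgets.
    """
    idx = bisect_left(budgets, needed_tokens)
    return budgets[idx] if idx < len(budgets) else None


def plan_dual_path(num_images, num_patches, global_budgets, local_budgets):
    """Decide, independently per path, between graph replay, eager, or skip."""
    paths = (
        ("global", num_images * GLOBAL_TOKENS_PER_IMAGE, global_budgets),
        ("local", num_patches * LOCAL_TOKENS_PER_PATCH, local_budgets),
    )
    plan = {}
    for name, needed, budgets in paths:
        if needed == 0:
            # Nothing to encode on this path, so no graph is replayed.
            plan[name] = ("skip", None)
            continue
        budget = select_budget(needed, budgets)
        # A missing budget means this path alone falls back to eager.
        plan[name] = ("eager", None) if budget is None else ("graph", budget)
    return plan
```

For example, `plan_dual_path(2, 9, [272, 544], [400, 800])` picks the 544-token graph for the global path while the local path (900 tokens needed) falls back to eager, which is the partial-fallback case.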

Note

Find more background details in the DeepSeek-OCR technical report.


TODO:

Test Plan

E2E functional test:

python examples/generate/multimodal/vision_language_offline.py -m deepseek_ocr --modality "image" --enable-vit-cuda-graph

Benchmark:

vllm bench mm-processor \
--model /shared/models/modelscope/models/deepseek-ai/DeepSeek-OCR \
--max-model-len 8192 \
--dataset-name random-mm \
--random-mm-base-items-per-request 1 \
--random-mm-num-mm-items-range-ratio 0.0 \
--random-mm-bucket-config '{(896, 896, 1): 1.0}' \
--random-mm-limit-mm-per-prompt '{"image": 1, "video": 0}' \
--num-prompts 100 \
--compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_max_vision_items_per_batch": 8}'

Test Result

E2E functional test results:

--------------------------------------------------
The image captures the majestic Tokyo Skytree, the tallest tower in Japan, standing tall against the backdrop of a clear blue sky. The tower, painted in a pristine white, is adorned with a distinctive blue and white striped pattern on its upper section, adding a touch of color to the otherwise monochrome structure. The perspective of the photo is from a low angle, giving the viewer a sense of the tower's impressive height. In the foreground, cherry blossom trees in full bloom add a splash of pink to the scene, their delicate petals contrasting beautifully with the stark white of the tower. The image beautifully encapsulates the blend of urban architecture and natural beauty that characterizes Tokyo.
--------------------------------------------------
The image captures the majestic Tokyo Skytree, the tallest tower in Japan, standing tall against the backdrop of a clear blue sky. The tower, painted in a pristine white, is adorned with a distinctive blue and white striped pattern on its upper section, adding a touch of color to the otherwise monochrome structure. The perspective of the photo is from a low angle, giving the viewer a sense of the tower's impressive height. In the foreground, cherry blossom trees in full bloom add a splash of pink to the scene, their branches reaching out towards the tower. The image beautifully juxtaposes the modernity of the Tokyo Skytree with the natural beauty of the cherry blossoms.
--------------------------------------------------
The image captures the majestic Tokyo Skytree, the tallest tower in Japan, standing tall against the backdrop of a clear blue sky. The tower, painted in a pristine white, is adorned with a distinctive blue and white striped pattern on its upper section, adding a touch of color to the otherwise monochrome structure. The perspective of the photo is from a low angle, giving the viewer a sense of the tower's impressive height. In the foreground, cherry blossom trees in full bloom add a splash of pink to the scene, their branches reaching out towards the tower. The image beautifully juxtaposes the modernity of the Tokyo Skytree with the natural beauty of the cherry blossoms.
--------------------------------------------------
The image captures the majestic Tokyo Skytree, the tallest tower in Japan, standing tall against the backdrop of a clear blue sky. The tower, painted in a pristine white, is adorned with a distinctive blue and white striped pattern on its upper section, adding a touch of color to the otherwise monochrome structure. The perspective of the photo is from a low angle, giving the viewer a sense of the tower's impressive height. In the foreground, cherry blossom trees in full bloom add a splash of pink to the scene, their branches reaching out towards the tower. The image beautifully juxtaposes the modernity of the Tokyo Skytree with the natural beauty of the cherry blossoms.
--------------------------------------------------

Benchmark results (old version: only global images are captured in the CUDA graph):

| Input Size | Tiled | Mean Latency | P99 Latency |
| --- | --- | --- | --- |
| (224, 224) | | -12.76% (20.22ms -> 17.64ms) ↓ | -15.49% (22.59ms -> 19.09ms) ↓ |
| (448, 448) | | -17.10% (21.34ms -> 17.69ms) ↓ | -29.15% (28.47ms -> 20.17ms) ↓ |
| (896, 896) | | -2.95% (42.08ms -> 40.84ms) ↓ | -6.60% (44.07ms -> 41.16ms) ↓ |
| (1024, 1024) | | -4.75% (42.96ms -> 40.92ms) ↓ | -6.96% (45.00ms -> 41.87ms) ↓ |

Benchmark results (new version: dual-path graph selection for global images and local patches, respectively):

| Input Size | Tiled | Mean Latency | P99 Latency |
| --- | --- | --- | --- |
| (224, 224) | | -13.93% (20.74ms -> 17.85ms) ↓ | -15.49% (22.59ms -> 19.09ms) ↓ |
| (448, 448) | | -20.96% (23.24ms -> 18.37ms) ↓ | -27.10% (29.11ms -> 21.22ms) ↓ |
| (896, 896) | | -4.59% (42.08ms -> 40.15ms) ↓ | -8.12% (44.07ms -> 40.49ms) ↓ |
| (1024, 1024) | | -6.56% (42.96ms -> 40.14ms) ↓ | -10.31% (45.00ms -> 40.36ms) ↓ |
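The percentages in the benchmark results are consistent with a relative latency change, (after - before) / before. The helper below is a small sketch for re-deriving them from the reported millisecond values; `relative_change_pct` is a hypothetical name and is not part of the benchmark tooling.

```python
def relative_change_pct(before_ms, after_ms):
    """Relative latency change in percent; negative means faster."""
    return round((after_ms - before_ms) / before_ms * 100, 2)


# Values taken from the (1024, 1024) row of the dual-path results.
mean_change = relative_change_pct(42.96, 40.14)
p99_change = relative_change_pct(45.00, 40.36)
print(mean_change, p99_change)
```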

Note

When image_width <= 640 and image_height <= 640, the multimodal inputs contain only the global image, and no local patches are generated.
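The condition in this note can be written as a simple predicate. This is a sketch only: `has_local_patches` is a hypothetical helper, and it assumes that any image exceeding 640 in either dimension is tiled, which the note implies but does not state outright.

```python
TILE_THRESHOLD = 640  # from the note above


def has_local_patches(image_width, image_height, threshold=TILE_THRESHOLD):
    """Return True if the image is expected to produce local patches.

    Assumption: tiling happens whenever either side exceeds the threshold.
    """
    fits_in_one_tile = image_width <= threshold and image_height <= threshold
    return not fits_in_one_tile
```

Under this assumption, the (224, 224) and (448, 448) benchmark inputs exercise only the global path, while (896, 896) and (1024, 1024) exercise both paths.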


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

@shen-shanshan shen-shanshan marked this pull request as draft May 25, 2026 09:24
@mergify

mergify Bot commented May 25, 2026

Contributor

Documentation preview: https://vllm--43586.org.readthedocs.build/en/43586/

@mergify mergify Bot added documentation Improvements or additions to documentation deepseek Related to DeepSeek models nvidia labels May 25, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request implements the SupportsEncoderCudaGraph protocol for the DeepseekOCRForCausalLM model, enabling CUDA graph support for its vision encoder. The implementation includes methods for token calculation, input preparation, and post-processing. Critical feedback was provided regarding a potential TypeError in _get_num_input_output_tokens due to a missing null check, and incorrect configuration in EncoderCudaGraphConfig where input_key_by_modality should be used instead of input_keys. Furthermore, the images_crop tensor must be explicitly registered in buffer_keys and included in both capture and replay buffers to ensure correct graph execution. Finally, the get_max_frames_per_video method should be added to fully comply with the protocol.

Comment thread vllm/model_executor/models/deepseek_ocr.py Outdated
Comment thread vllm/model_executor/models/deepseek_ocr.py
Comment thread vllm/model_executor/models/deepseek_ocr.py Outdated
Comment thread vllm/model_executor/models/deepseek_ocr.py Outdated
Comment thread vllm/model_executor/models/deepseek_ocr.py Outdated
@mergify

mergify Bot commented May 29, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @shen-shanshan.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 29, 2026
@mergify mergify Bot removed the needs-rebase label May 29, 2026
@shen-shanshan shen-shanshan marked this pull request as ready for review June 3, 2026 04:01
@mergify mergify Bot added the multi-modality Related to multi-modality (#4194) label Jun 3, 2026
@mergify

mergify Bot commented Jun 4, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @shen-shanshan.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@shen-shanshan

Contributor Author

CC @Isotr0py

At first, I tried to include the local_patch encoding in the ViT CUDA graph, but I found that this solution has some drawbacks:

  1. The grid shape and the number of local_patch entries are dynamic and depend on the input images. Newline tokens must also be appended to the end of each row in the grid, which makes the tensor shape dynamic and unpredictable. Thus, only the encoding of the raw local_patch can go into the CUDA graph, and the newline tokens have to be added in postprocess_encoder_output().
  2. To make sure images with the maximum number of local_patch entries can be replayed correctly, we have to capture max_crops buffers for each input, even for images smaller than (640, 640), which are not tiled and have no local patches. In this case, performance is worse than eager execution, because it leads to additional, redundant replay of dummy local_patch buffers.
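The first drawback can be illustrated with a toy example. The sketch below is not the model code: `append_row_newlines` is a hypothetical helper that works on plain Python lists, whereas the real post-processing operates on tensors. It only shows why appending one newline token per grid row makes the output length depend on the grid shape.

```python
def append_row_newlines(grid, newline_token):
    """Flatten a 2D grid of tokens, appending a newline token after each row."""
    flat = []
    for row in grid:
        flat.extend(row)
        flat.append(newline_token)
    return flat


def output_length(rows, cols):
    """Length after flattening: each row contributes cols tokens plus a newline."""
    return rows * (cols + 1)


# Two grids with the same number of cells but different shapes
# produce different output lengths.
wide = [["t"] * 6 for _ in range(2)]   # 2 rows x 6 cols
tall = [["t"] * 2 for _ in range(6)]   # 6 rows x 2 cols
print(len(append_row_newlines(wide, "\n")), len(append_row_newlines(tall, "\n")))
```

Because the length varies with the grid shape and not only with the number of cells, a fixed-shape captured buffer cannot hold this step, which matches the decision to do it in post-processing.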

Thus, I decided to include only the global_image encoding in the ViT CUDA graph, execute the local_patch encoding eagerly, and then assemble the two in postprocess_encoder_output(). This is a performance tradeoff between images smaller than (640, 640) and larger images.

Future plan (in follow-up PRs): I want to explore a dual-path ViT CUDA graph budget selection mechanism for DeepSeek-OCR and Step3-VL to decouple the global_image replay from the local_patch replay.

Comment thread docs/design/cuda_graphs_multimodal.md Outdated
Comment thread vllm/model_executor/models/deepseek_ocr.py Outdated
Comment thread vllm/model_executor/models/deepseek_ocr.py
@mergify

mergify Bot commented Jun 9, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @shen-shanshan.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@Isotr0py Isotr0py self-assigned this Jun 12, 2026

@Isotr0py Isotr0py left a comment

Member

Overall LGTM, just left a nit.

Comment thread vllm/v1/worker/encoder_cudagraph.py Outdated
@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA Jun 12, 2026
@Isotr0py Isotr0py added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 12, 2026
@mergify

mergify Bot commented Jun 12, 2026

Contributor

Hi @shen-shanshan, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

@mergify

mergify Bot commented Jun 12, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @shen-shanshan.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Jun 12, 2026
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
@Isotr0py

Member

Seems multimodal CI failures are related: https://buildkite.com/vllm/ci/builds/71879#019ebeee-1634-4a5e-b0a4-3280eadb8c8b

Signed-off-by: shen-shanshan <467638484@qq.com>
@mergify mergify Bot added llama Related to Llama models qwen Related to Qwen models labels Jun 13, 2026
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
@shen-shanshan

Contributor Author

> Seems multimodal CI failures are related: https://buildkite.com/vllm/ci/builds/71879#019ebeee-1634-4a5e-b0a4-3280eadb8c8b

Sorry, there are still some issues, so please don't merge yet. I will fix them soon.

Labels

deepseek Related to DeepSeek models
documentation Improvements or additions to documentation
llama Related to Llama models
multi-modality Related to multi-modality (#4194)
nvidia
qwen Related to Qwen models
ready ONLY add when PR is ready to merge/full CI is needed
v1

Projects

Status: Ready

2 participants