⚡ Bolt: Optimize _compute_repulsion_loss with torch.cdist#116

Open
hawkh wants to merge 1 commit into main from bolt/optimize-repulsion-loss-5012791570065780251

Conversation


@hawkh hawkh commented Jun 10, 2026

Owner

💡 What: Replaced the explicit unsqueeze-based dimension expansion in _compute_repulsion_loss with torch.cdist.
🎯 Why: The previous method computed pairwise differences manually as class_embs.unsqueeze(0) - class_embs.unsqueeze(1), which allocates an intermediate tensor of size O(K^2 * D) for K class embeddings of dimension D. This caused unnecessarily high memory usage and latency, making it a critical performance bottleneck (a known PyTorch anti-pattern).
📊 Impact: Reduces the pairwise distance memory allocation from O(K^2 * D) to O(K^2) while remaining mathematically equivalent. This should improve both speed and memory efficiency during the GAN training phase.
🔬 Measurement: Run the test suite and monitor memory allocations and training speed during GAN Phase 2 training.
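
As a rough illustration of the equivalence claimed above, the following minimal sketch compares the broadcast approach with `torch.cdist` on a small random tensor. The shape values `K` and `D` and the random `class_embs` are assumptions for illustration only; the actual body of `_compute_repulsion_loss` is not shown in this PR conversation.

```python
import torch

torch.manual_seed(0)
K, D = 6, 4  # hypothetical sizes: K class embeddings of dimension D
class_embs = torch.randn(K, D)

# Broadcast approach: materializes a [K, K, D] difference tensor before reducing.
diff = class_embs.unsqueeze(0) - class_embs.unsqueeze(1)
dist_broadcast = diff.norm(p=2, dim=-1)  # shape [K, K]

# cdist approach: returns the [K, K] Euclidean distance matrix directly.
dist_cdist = torch.cdist(class_embs, class_embs, p=2.0)

# Largest elementwise disagreement between the two results.
max_abs_diff = (dist_broadcast - dist_cdist).abs().max().item()
```

For larger inputs, `torch.cdist` may by default switch to a matrix-multiplication-based path for `p=2` (see its `compute_mode` argument), so the two results should be expected to agree within floating-point tolerance rather than bit for bit.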


PR created automatically by Jules for task 5012791570065780251 started by @hawkh

Summary by CodeRabbit

  • Refactor

    • Optimized repulsion loss computation in the image synthesis pipeline for improved memory efficiency and latency.
  • Documentation

    • Added optimization recommendations for pairwise distance computations in PyTorch.

Replaced the memory-heavy `unsqueeze(0) - unsqueeze(1)` pairwise distance approach with `torch.cdist(class_embs, class_embs, p=2.0)` combined with `torch.triu_indices`. This reduces the pairwise distance memory allocation from O(K^2 * D) to O(K^2), avoiding a large intermediate tensor. Also logged the learning to the Bolt journal.
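
To put the O(K^2 * D) versus O(K^2) difference in concrete terms, here is a back-of-the-envelope calculation in plain Python. The helper name `pairwise_bytes` and the example sizes (K=1000, D=512, float32) are hypothetical and chosen only for illustration; the calculation counts the forward intermediates named in the commit message and ignores any internal workspace or tensors saved for autograd.

```python
def pairwise_bytes(K: int, D: int, bytes_per_elem: int = 4) -> dict:
    """Estimate the size of the main intermediates for K embeddings of dimension D."""
    return {
        # [K, K, D] difference tensor from unsqueeze(0) - unsqueeze(1)
        "broadcast_diff": K * K * D * bytes_per_elem,
        # [K, K] distance matrix returned by torch.cdist
        "distance_matrix": K * K * bytes_per_elem,
        # number of unique i < j pairs in the upper triangle
        "unique_pairs": K * (K - 1) // 2,
    }


sizes = pairwise_bytes(K=1000, D=512)
ratio = sizes["broadcast_diff"] // sizes["distance_matrix"]  # equals D
```

Under these assumptions the broadcast intermediate is D times larger than the distance matrix, so the saving grows with the embedding dimension.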

Co-authored-by: hawkh <113750504+hawkh@users.noreply.github.com>
@google-labs-jules


👋 Jules, reporting for duty! I'm here to lend a hand with this pull request.

When you start a review, I'll add a 👀 emoji to each comment to let you know I've read it. I'll focus on feedback directed at me and will do my best to stay out of conversations between you and other bots or reviewers to keep the noise down.

I'll push a commit with your requested changes shortly after. Please note there might be a delay between these steps, but rest assured I'm on the job!

For more direct control, you can switch me to Reactive Mode. When this mode is on, I will only act on comments where you specifically mention me with @jules. You can find this option in the Pull Request section of your global Jules UI settings. You can always switch back!

New to Jules? Learn more at jules.google/docs.


For security, I will only act on instructions from the user who triggered this task.

@qodo-code-review

qodo-code-review Bot commented Jun 10, 2026


Code Review by Qodo

🐞 Bugs (0) 📘 Rule violations (0) 📎 Requirement gaps (0)


Great, no issues found!

Qodo reviewed your code and found no material issues that require review


@coderabbitai

coderabbitai Bot commented Jun 10, 2026


Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f337b5d1-e1e3-4760-adcf-bfb0fb2ba72f

📥 Commits

Reviewing files that changed from the base of the PR and between 398921c and e5bedc3.

📒 Files selected for processing (2)
  • .jules/bolt.md
  • smote_image_synthesis/pipeline.py

📝 Walkthrough


This PR optimizes the pairwise distance computation in the repulsion loss function by replacing manual tensor expansion with torch.cdist, and selects unique pairs via torch.triu_indices instead of a boolean mask. The change reduces memory usage and latency while maintaining identical return semantics. The optimization technique is also documented in the Bolt journal.

Changes

Repulsion Loss Optimization

| Layer / File(s) | Summary |
|---|---|
| Implement torch.cdist optimization for pairwise distances: .jules/bolt.md, smote_image_synthesis/pipeline.py | Repulsion loss pairwise distance computation switches from manual tensor expansion to torch.cdist(x, x, p=2). Unique same-class non-self pair selection now uses torch.triu_indices. Pair counting is updated accordingly. Documentation records the optimization and its memory/latency benefits. |

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Poem

🐰 A hop, skip, and cdist away—
No more tensor sprawl to dismay!
With triu_indices tight and true,
Pairwise repulsion's leaner crew.
Memory soars, latency glides through! ✨

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly identifies the main change: optimizing _compute_repulsion_loss using torch.cdist, which is the primary focus of both file modifications. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@qodo-code-review


PR Summary by Qodo

Optimize repulsion loss pairwise distances with torch.cdist
✨ Enhancement 📝 Documentation 🕐 10-20 Minutes


Walkthroughs

Description
• Replace O(K^2·D) pairwise diff tensor with torch.cdist for repulsion loss.
• Avoid boolean mask allocation by indexing upper-triangle pairs via triu_indices.
• Document the PyTorch performance anti-pattern and recommended replacement in Bolt journal.
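
The pair-selection change described in the bullets above can be sketched as follows. This is an illustrative comparison, not the code from pipeline.py: the sizes `K` and `D`, the random `class_embs`, and the variable names are assumptions. It shows that indexing with `torch.triu_indices(K, K, offset=1)` picks the same unique i < j entries, in the same order, as a strict upper-triangle boolean mask, and that `numel()` gives the pair count.

```python
import torch

torch.manual_seed(0)
K, D = 5, 3  # hypothetical sizes
class_embs = torch.randn(K, D)
dists = torch.cdist(class_embs, class_embs, p=2.0)  # [K, K]

# Boolean-mask selection: builds a [K, K] mask of the strict upper triangle.
mask = torch.ones(K, K).triu(diagonal=1).bool()
pairs_from_mask = dists[mask]

# Index-based selection: row and column indices of the same entries.
i, j = torch.triu_indices(K, K, offset=1)
pairs_from_indices = dists[i, j]

# Number of unique i < j pairs, i.e. K * (K - 1) / 2.
num_pairs = pairs_from_indices.numel()
```

Note that `torch.triu_indices` itself allocates a `[2, K * (K - 1) / 2]` integer index tensor, so the change swaps one auxiliary allocation for another rather than eliminating it; whether that is a net saving is best confirmed by profiling.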
Diagram
graph TD
  A["Phase 2 training"] --> B["_compute_repulsion_loss"] --> C["Class embeddings"] --> D["torch.cdist"] --> E["triu_indices"] --> F["Repulsion loss"]

  subgraph Legend
    direction LR
    _proc["Process"] ~~~ _op["Tensor op"]
  end
High-Level Assessment

The following are alternative approaches to this PR:

1. Use torch.pdist for upper-triangle distances
  • ➕ Returns condensed distance vector directly (avoids materializing the full [K, K] matrix).
  • ➕ Naturally matches the need for unique i<j pairs.
  • ➖ Less straightforward to map back to (i,j) pairs if later needed.
  • ➖ May have different kernel performance characteristics than cdist on some devices/shapes.
2. Compute squared distances via Gram matrix (x·xᵀ)
  • ➕ Can reuse optimized GEMM kernels; can compute squared distances without explicit broadcast.
  • ➖ More complex and error-prone (numerical stability, clamping negatives).
  • ➖ Still produces an O(K^2) matrix; limited benefit vs cdist here.

Recommendation: The PR’s approach (cdist + triu_indices) is a strong default: it removes the O(K^2·D) intermediate tensor and avoids a large boolean mask, while keeping the implementation simple and readable. Consider torch.pdist only if profiling shows the remaining [K, K] distance matrix is still a bottleneck for large K.
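
For comparison, the two alternatives listed above can be sketched against the PR's cdist + triu_indices approach. This is a minimal illustration under assumed sizes and random data, with variable names that do not come from the PR. It also shows one way to map `torch.pdist` output back to (i, j) pairs: entry k of the condensed vector corresponds to the k-th column of `torch.triu_indices(K, K, offset=1)`.

```python
import torch

torch.manual_seed(0)
K, D = 5, 3  # hypothetical sizes
x = torch.randn(K, D)

# Reference: the PR's approach (full [K, K] matrix, then upper-triangle indexing).
i, j = torch.triu_indices(K, K, offset=1)
ref = torch.cdist(x, x, p=2.0)[i, j]

# Alternative 1: torch.pdist returns the condensed vector of i < j distances.
condensed = torch.pdist(x, p=2)

# Alternative 2: squared distances from the Gram matrix,
# ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * (a . b).
sq_norms = (x * x).sum(dim=1)
gram = x @ x.T
sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * gram
sq_dists = sq_dists.clamp_min(0.0)  # guard against small negatives from rounding
gram_dists = sq_dists[i, j].sqrt()
```

The clamp in the Gram-matrix variant reflects the numerical-stability concern noted in the list: rounding can push a squared distance slightly below zero, which would produce NaN under the square root.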


File Changes

Enhancement (1)
pipeline.py Optimize repulsion loss distance computation using cdist + upper-triangle indices +6/-6


• Replaces broadcasted subtraction and norm with torch.cdist to avoid allocating a [K, K, D] intermediate. Switches upper-triangle selection from a boolean mask to torch.triu_indices and uses numel() for pair counting.

smote_image_synthesis/pipeline.py


Documentation (1)
bolt.md Add Bolt journal note on efficient pairwise distance computation +3/-0


• Introduces a short journal entry documenting why broadcasted pairwise diffs are memory-heavy and recommending torch.cdist plus triu_indices to avoid large intermediate allocations.

.jules/bolt.md




1 participant