Skip to content

perf(jit): lean cache-key fast path for @flyc.jit launches#669

Open
fsx950223 wants to merge 3 commits into
mainfrom
worktree-jit-key-fastpath
Open

perf(jit): lean cache-key fast path for @flyc.jit launches#669
fsx950223 wants to merge 3 commits into
mainfrom
worktree-jit-key-fastpath

Conversation

@fsx950223

@fsx950223 fsx950223 commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Summary

Every @flyc.jit launch rebuilds its cache key to probe _call_state_cache (polymorphic dispatch — unavoidable). Profiling the softmax launcher showed ~50us of host overhead per cache-hit launch. This PR drives that down in three steps.

1. Lean cache-key fast path (34b5f48)

_resolve_and_make_cache_key was dominated by two avoidable costs:

  1. building a full TensorAdaptor per tensor (DLPack export) just to read a cache signature — ~6.6us/tensor;
  2. isinstance(arg, JitArgument) where JitArgument is a runtime_checkable Protocol → structural check ~2.2us/arg.

Added a lean fast probe that produces a byte-identical key without either cost:

  • TensorAdaptor.lean_cache_signature(t) — derives the signature directly from dtype/shape/strides (no DLPack); the full adaptor is built only on a cache miss when actually compiling.
  • JitFunction._fast_cache_key + _build_fast_key_plan — per-param plan computed once; lean tensor sig + a cheap hasattr("__cache_signature__") instead of the Protocol isinstance; does NOT mutate bound args.
  • __call__ probes with the lean key first; the cache-hit path passes raw args straight to CallState. The miss path is unchanged and its key is identical, so the cached CallState is found by the fast probe.

2. Boundary hardening (e511c9f)

Verified lean_cache_signature stays byte-identical to TensorAdaptor.__cache_signature__ across the layouts where the framework and DLPack stride views can disagree (DLPack coerces unit/zero-size strides) — the lean path follows the framework view via _pick_unit_stride_axis, same as the full path:

broadcast (stride 0), 3d broadcast, leading/mid/all size-1 dims, trailing size-1 (ambiguous unit axis), permuted, channels_last, strided rows, zero-size dims, fp8 transposed.

Plus test_lean_and_full_reject_no_unit_stride_consistently: a tensor with no stride-1 axis is rejected by both paths, so the fast probe never silently dispatches a tensor the full path would refuse. Also corrected the _fast_cache_key comment that overclaimed hasattr(...) ≡ the 3-method isinstance(arg, JitArgument).

3. Drift-check optimization (e33eca2)

After the key build dropped, the next lever was _check_globals_drift (~7.7us, 16% of the cache-hit __call__), which re-summarizes every captured global per launch. Most captured "globals" are modules / imported helpers / classes (softmax: 7 of 9 refs), summarized purely by id() — and the callable branch builds repr(val) every call.

For these identity-stable values an is compare against the baseline object is exactly equivalent to re-summarizing (the snapshot's only discriminant is identity), and holding the object alive also removes id-reuse aliasing. Scalars and containers stay on the by-value path so in-place mutation is still caught.

Measurements (softmax 8192² / 256², MI308X)

before after
cache-key build 52us 8.6us (6x)
_check_globals_drift 7.7us 1.57us (5x)
full __call__ (cache hit) ~108us → ~50us ~40.5us (min-of-8)

Helps the polymorphic launch path (varying shapes, where flyc.compile can't be used); GPU-bound kernels were already hiding this behind execution.

Tests

  • tests/unit/test_lean_cache_signature.py — lean sig ≡ __cache_signature__ and _fast_cache_key_build_full_cache_key across dtypes/ranks/strides incl. all the boundary layouts above.
  • tests/unit/test_jit_cache_key_completeness.pytest_drift_identity_stable_fastpath_and_container_mutation covers the is-hit, is-miss-rebind, and in-place-container-mutation paths; existing drift tests updated to the new baseline shape.
  • Verified test_softmax.py passes. (Note: test_tensor_cache_signature.py failures in my local run are a pre-existing stale-_mlir-binary mismatch, unrelated to these Python-only changes.)

🤖 Generated with Claude Code

Every @flyc.jit launch rebuilds its cache key to probe _call_state_cache
(polymorphic dispatch — unavoidable). Profiling the softmax launcher showed
~50us of that ~108us host overhead was _resolve_and_make_cache_key, dominated
by two avoidable costs:

  1. building a full TensorAdaptor per tensor (DLPack export) just to read a
     cache signature — ~6.6us/tensor;
  2. isinstance(arg, JitArgument) where JitArgument is a runtime_checkable
     Protocol → structural check ~2.2us/arg.

Add a lean fast probe that produces a BYTE-IDENTICAL key without either cost:

- TensorAdaptor.lean_cache_signature(t): derives the signature directly from
  dtype/shape/strides (no DLPack); the full adaptor is built only on a cache
  miss when actually compiling.
- JitFunction._fast_cache_key + _build_fast_key_plan: per-param plan computed
  once; uses lean tensor sig and a cheap hasattr("__cache_signature__") instead
  of the Protocol isinstance; does NOT mutate bound args.
- __call__ probes with the lean key first; the cache-hit path passes raw args
  straight to CallState (whose slot extractors already accept raw tensors). The
  miss path is unchanged (full resolve + compile), and its key is identical to
  the lean key so the CallState it caches is found by the fast probe.

Measured (softmax 8192^2, MI308X): key build 52us -> 8.6us (6x); full __call__
on a launch-bound shape 108us -> 50us. Helps the polymorphic path (varying
shapes, where flyc.compile can't be used); GPU-bound kernels were already
hiding this behind execution.

tests/unit/test_lean_cache_signature.py asserts lean_cache_signature ==
TensorAdaptor.__cache_signature__ and _fast_cache_key == _build_full_cache_key
across dtypes/ranks/strides (incl. fp8, transposed, unit-size, non-pow2).
Verified test_softmax.py and test_pa.py pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings June 9, 2026 08:56

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR optimizes the @flyc.jit launch hot path by adding a lean cache-key probe that avoids per-launch TensorAdaptor/DLPack construction and avoids runtime_checkable Protocol structural isinstance checks, while ensuring the fast key is byte-identical to the full cache key.

Changes:

  • Add TensorAdaptor.lean_cache_signature(t) to derive tensor cache signatures from dtype/shape/strides without DLPack.
  • Add JitFunction._fast_cache_key + a lazily-built per-signature plan, and probe _call_state_cache with the lean key before doing full resolution/compilation.
  • Add GPU unit tests validating lean_cache_signature and _fast_cache_key match the full-path signatures/keys.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

File Description
tests/unit/test_lean_cache_signature.py Adds regression tests ensuring lean tensor signatures and fast cache keys match the full path.
python/flydsl/compiler/jit_function.py Introduces fast cache-key planning/building and uses it for the cache-hit probe in __call__.
python/flydsl/compiler/jit_argument.py Adds TensorAdaptor.lean_cache_signature to avoid DLPack export on cache-hit probing.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread python/flydsl/compiler/jit_function.py Outdated
Comment on lines +1487 to +1490
# Arg already a JitArgument (cheap duck-check == the Protocol isinstance).
if hasattr(arg, "__cache_signature__"):
parts.append((name, arg.__cache_signature__()))
continue
Comment on lines +328 to +332
strides = tuple(int(s) for s in t.stride())
rank = len(strides)
unit_axis = next((i for i, s in enumerate(strides) if s == 1), None)
if unit_axis is None:
raise RuntimeError("tensor has no axis with stride == 1; layout-dynamic memref requires one")
fsx950223 and others added 2 commits June 9, 2026 09:08
…k comment

Boundary review of the lean cache-key fast path. Verified
lean_cache_signature stays byte-identical to TensorAdaptor.__cache_signature__
across the layouts where the framework and DLPack stride views can disagree
(DLPack coerces unit/zero-size strides) — the lean path follows the framework
view via _pick_unit_stride_axis, matching the full path:

  broadcast (stride 0), 3d broadcast, leading/mid/all size-1 dims, trailing
  size-1 (ambiguous unit axis), permuted, channels_last, strided rows,
  zero-size dims, fp8 transposed.

Also add test_lean_and_full_reject_no_unit_stride_consistently: a tensor with
no stride-1 axis cannot be a layout-dynamic memref; both the full path
(TensorAdaptor.__init__) and the lean path must reject it, so the fast probe
never silently dispatches a tensor the full path would refuse.

Fix the _fast_cache_key comment that claimed hasattr("__cache_signature__") is
equivalent to the runtime_checkable isinstance(arg, JitArgument): the Protocol
also probes __get_ir_types__/__get_c_pointers__. They select the same branch
only because every __cache_signature__ implementer in the codebase is a
complete JitArgument — state that instead of claiming exact equivalence.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Follow-up to the lean cache-key fast path. After the key build dropped to
~8.6us, profiling the softmax launcher (256x512, MI308X) showed the next
host-overhead lever was _check_globals_drift at ~7.7us — 16% of the ~50us
cache-hit __call__ — run on every launch.

The cost is re-summarizing every captured global via _snapshot_global_value
each call. For a typical kernel most captured "globals" are modules, imported
helper functions, and enum/classes (softmax captures 9 refs: 7 are
module/function/type, 1 EnumType, only BLOCK_THREADS is a scalar). These are
summarized purely by id() — and the callable branch additionally builds
repr(val) every call.

For such identity-stable values an `is` comparison against the baseline object
is exactly equivalent to (and far cheaper than) re-summarizing the snapshot:
the snapshot's only discriminant is the identity. Holding the baseline object
also keeps it alive, eliminating the id-reuse aliasing that comparing stored
id() snapshots is theoretically prone to.

- _is_identity_stable(val): mirrors the id()-based ("callable"/"obj") branches
  of _snapshot_global_value(stable=False); False for scalars and builtin
  containers (those stay on the by-value path so in-place mutation is caught).
- _snapshot_refs_for_drift(refs): drift baseline now maps (name, mod) ->
  (snapshot, fastref), fastref being the live object for identity-stable values
  else None.
- _check_globals_drift: short-circuits when fastref is not None and the global
  is still bound to it; otherwise falls through to the unchanged full compare.

Measured: _check_globals_drift 7.7us -> 1.57us (5x); full __call__ ~49.6us ->
~40.5us (min-of-8). Drift semantics unchanged: rebinding a scalar OR an
identity-stable function/module still raises, in-place container mutation still
raises, no false positives after restore.

test_drift_identity_stable_fastpath_and_container_mutation covers the is-hit,
is-miss-rebind, and in-place-container-mutation paths; the per-owner-cls drift
test is updated to the new baseline shape via _snapshot_refs_for_drift.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@sjfeng1999

Copy link
Copy Markdown
Collaborator

I’m not fully convinced this is the right layer to optimize. Passing raw torch.Tensor to @flyc.jit is primarily the convenience path; performance-sensitive code can construct the TensorAdaptor explicitly with flyc.from_dlpack(...).mark_layout_dynamic(...) to avoid the cost of constructing dlpack, and can use flyc.compile() to avoid per-call cache-key construction entirely.

@fsx950223

Copy link
Copy Markdown
Contributor Author

I’m not fully convinced this is the right layer to optimize. Passing raw torch.Tensor to @flyc.jit is primarily the convenience path; performance-sensitive code can construct the TensorAdaptor explicitly with flyc.from_dlpack(...).mark_layout_dynamic(...) to avoid the cost of constructing dlpack, and can use flyc.compile() to avoid per-call cache-key construction entirely.

What's the purpose of using dlpack here?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants