Skip to content

Reuse the HuggingFace cache for model assets#46

Open
zboyles wants to merge 1 commit into
magenta:mainfrom
zboyles:respect-hf-cache
Open

Reuse the HuggingFace cache for model assets#46
zboyles wants to merge 1 commit into
magenta:mainfrom
zboyles:respect-hf-cache

Conversation

@zboyles

@zboyles zboyles commented Jun 9, 2026

Copy link
Copy Markdown

Reuse the HuggingFace cache for model assets

Model assets (musiccoca + spectrostream resources, exported .mlxfn models, raw checkpoints) were always downloaded into ~/Documents/Magenta/magenta-rt-v2/ via hf_hub_download(local_dir=...), which never consults the global HF cache. A repo already pulled with hf download google/magenta-realtime-2 was re-downloaded over the network, and assets were duplicated on disk.

Resolve assets MAGENTA_HOME-first, then the global HF cache:

  • paths.py: add resolve_asset() plus resolve_musiccoca_dir / resolve_spectrostream_dir / resolve_model_dir helpers. The local MAGENTA_HOME layout always wins (so local exports, GCS downloads, and user overrides take precedence); otherwise the asset is served from the global HF cache. Load paths reuse the cache only (local_files_only) — no network fetch at load time — so default behavior is unchanged except that an existing hf download is now picked up. A missing asset raises a clear FileNotFoundError pointing at both locations.

  • Route load sites through the resolver: musiccoca.py, mlx/system.py (mlxfn model dir + checkpoint), jax/system.py (checkpoint).

  • mlx/spectrostream/load_weights.py: when encoder.safetensors is not a sibling of the checkpoint (e.g. assets live in the HF cache, where the checkpoint and resources sit in different repo subdirs), fall back to the resolved spectrostream resource dir. Guarded lazy import keeps the module standalone.

  • resolve_checkpoint(): build the candidate path directly instead of via checkpoints_dir(), which mkdir's as a side effect. Resolution no longer creates directories, so a cache hit leaves MAGENTA_HOME untouched.

CLI (opt-in, default behavior unchanged):

  • Add --use-hf-cache (env: MAGENTA_RT_USE_HF_CACHE) to mrt models init, mrt models download, and mrt checkpoints download. When set, downloads populate/reuse the global HF cache (omit local_dir) and the interactive picker's checkmarks consult the cache. The flag is HuggingFace-only; the GCS source and --download-path are unaffected.

Docs: note cache reuse and the flag in README.md and docs/models.md.

Related Issues

Local Pytests

Note

This change reroutes where assets load from (MAGENTA_HOME → global HF cache); it does not alter model math or sampling. Validation below focuses on (a) cache reuse and (b) that both load paths still load correctly from the cache. The numeric-parity/bitlevel tests are gated on a generated reference in this environment and skip; they're included to show no regressions. Paths anonymized.

I ran

# 1. Warm the cache (first run downloads into ~/.cache/huggingface), then re-run = cache hit
mrt models init --use-hf-cache
Initializing model resources from HuggingFace → HuggingFace cache

📦 Downloading resources/musiccoca …
  Caching audio_preprocessor.tflite …
resources/musiccoca/audio_preprocessor.t(…): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8.73M/8.73M [00:00<00:00, 14.8MB/s]
  Caching mapper.tflite …
resources/musiccoca/mapper.tflite: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 86.2M/86.2M [00:01<00:00, 61.4MB/s]
  Caching music_encoder.tflite …
resources/musiccoca/music_encoder.tflite: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 371M/371M [00:05<00:00, 66.5MB/s]
  Caching pretrained_vector_quantizer.tflite …
resources/musiccoca/pretrained_vector_qu(…): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 72.4M/72.4M [00:01<00:00, 51.5MB/s]
  Caching spm.model …
  Caching text_encoder.tflite …
resources/musiccoca/text_encoder.tflite: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 419M/419M [00:06<00:00, 64.2MB/s]

📦 Downloading resources/spectrostream …
  Caching decoder.safetensors …
resources/spectrostream/decoder.safetens(…): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 210M/210M [00:04<00:00, 52.0MB/s]
  Caching encoder.safetensors …
resources/spectrostream/encoder.safetens(…): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 37.0M/37.0M [00:01<00:00, 27.8MB/s]
  Caching quantizer.safetensors …
resources/spectrostream/quantizer.safete(…): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67.1M/67.1M [00:01<00:00, 37.4MB/s]
  Caching spectrostream_encoder.mlxfn …
resources/spectrostream/spectrostream_en(…): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104M/104M [00:01<00:00, 57.1MB/s]

✓ Init complete.


time mrt models init --use-hf-cache

Initializing model resources from HuggingFace → HuggingFace cache

📦 Downloading resources/musiccoca …
  Caching audio_preprocessor.tflite …
  Caching mapper.tflite …
  Caching music_encoder.tflite …
  Caching pretrained_vector_quantizer.tflite …
  Caching spm.model …
  Caching text_encoder.tflite …

📦 Downloading resources/spectrostream …
  Caching decoder.safetensors …
  Caching encoder.safetensors …
  Caching quantizer.safetensors …
  Caching spectrostream_encoder.mlxfn …

✓ Init complete.


# 2. Confirm the custom folder stays clean (nothing written to ~/Documents)
ls -R ~/Documents/Magenta/magenta-rt-v2/ 2>/dev/null || echo "magenta-rt-v2 has no downloaded assets (as intended)" 
magenta-rt-v2 has no downloaded assets (as intended)

# 3. Pull a model into the cache and generate (default .mlxfn path)

mrt models download mrt2_small --use-hf-cache
📦 Downloading model mrt2_small from HuggingFace → HuggingFace cache …
  Caching mrt2_small.mlxfn …
  Caching mrt2_small_state.safetensors …

✓ Model 'mrt2_small' downloaded.


mrt mlx generate --prompt "disco funk" --duration 4.0 --model=mrt2_small
INFO:magenta_rt.mlx.system:Loading mlxfn: ~/.cache/huggingface/hub/models--google--magenta-realtime-2/snapshots/010aa0dcb0dfd27b24f0ad07b4dad63e8f9521cc/models/mrt2_small/mrt2_small.mlxfn
INFO:magenta_rt.mlx.system:Loaded 165 state arrays
INFO:magenta_rt.mlx.system:Warming up (5 steps)...
INFO:magenta_rt.mlx.system:Warm-up done (1.0s).
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
Generated 100 frames in 1.3s (78.4 steps/s, 12.8 ms/step)
Target: 25 steps/s, 40 ms/step for real-time
Saved to ~/Documents/Magenta/magenta-rt-v2/outputs/output_audio_mlx_mrt2_small.wav (4.0s of audio)


# 4. Exercise the Python (--no-mlxfn) path: checkpoint + SpectroStream encoder from cache

mrt checkpoints download mrt2_small.safetensors --use-hf-cache 
📦 Downloading checkpoint mrt2_small.safetensors from HuggingFace → HuggingFace cache …
checkpoints/mrt2_small.safetensors: 100%|██████████████████████████████████████████████████████████████████████████| 1.13G/1.13G [00:14<00:00, 77.7MB/s]

✓ Checkpoint 'mrt2_small.safetensors' downloaded.


mrt mlx generate --prompt "disco funk" --duration 4.0 --model=mrt2_small --no-mlxfn --bits=8

INFO:magenta_rt.mlx.system:Loading checkpoint: ~/.cache/huggingface/hub/models--google--magenta-realtime-2/snapshots/010aa0dcb0dfd27b24f0ad07b4dad63e8f9521cc/checkpoints/mrt2_small.safetensors
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: '/opt/homebrew/opt/ffmpeg/lib/libtpu.so' (no such file), 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file)
WARNING:root:Classifier free guidance scale(s) specified, but negative(s) not specified; using zeros as negatives.
INFO:magenta_rt.mlx.system:Quantizing to 8-bit (group_size=64).
INFO:magenta_rt.mlx.system:Warming up...
INFO:magenta_rt.mlx.system:Warm-up done (0.3s).
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
Materializing deferred layers...
Materialized ✓
Loading checkpoint: ~/.cache/huggingface/hub/models--google--magenta-realtime-2/snapshots/010aa0dcb0dfd27b24f0ad07b4dad63e8f9521cc/checkpoints/mrt2_small.safetensors
Loading encoder...
  Detected branched encoder embedding.
  Branched encoder embeddings loaded.
Loading decoder...
Loading temporal body (transformer)...
Loading depth body...
  Loaded depth_input_adapter weights.
Warning: 2 Depthformer parameters were NOT updated during loading.
  NOT UPDATED: decoder.embedder.layers.1._scale
  NOT UPDATED: sampler.decoder.embedder.layers.1._scale
Loading spectrostream...
  Loading quantizer embeddings...
  Loading SpectroStream decoder...
    SpectroStream decoder weights loaded.
  Loading SpectroStream encoder...
    SpectroStream encoder weights loaded.
Converting depthformer params to bfloat16...
Weight loading complete (11 groups loaded)
  Depthformer params:      229,732,614
  SpectroStream params:       61,713,666
  Total loaded params:     291,446,280
Generated 100 frames in 2.0s (50.9 steps/s, 19.6 ms/step)
Target: 25 steps/s, 40 ms/step for real-time
Saved to ~/Documents/Magenta/magenta-rt-v2/outputs/output_audio_mlx_mrt2_small.wav (4.0s of audio)

# 5. Parity suite
pytest -s tests/test_musiccoca.py tests/test_prefill_correctness.py tests/test_bitlevel_parity.py

and observed the following output:

# (2) custom folder untouched on cache hit:
magenta-rt-v2 has no downloaded assets (as intended)

# (3) default .mlxfn generate — model resolved from the HF cache:
INFO:magenta_rt.mlx.system:Loading mlxfn: ~/.cache/huggingface/hub/models--google--magenta-realtime-2/snapshots/<rev>/models/mrt2_small/mrt2_small.mlxfn
Generated 100 frames in 1.3s (78.4 steps/s, 12.8 ms/step)
Saved to ~/Documents/Magenta/magenta-rt-v2/outputs/output_audio_mlx_mrt2_small.wav (4.0s of audio)

# (4) --no-mlxfn generate — checkpoint AND SpectroStream encoder resolved from the HF cache:
INFO:magenta_rt.mlx.system:Loading checkpoint: ~/.cache/huggingface/hub/models--google--magenta-realtime-2/snapshots/<rev>/checkpoints/mrt2_small.safetensors
Loading spectrostream...
  Loading SpectroStream encoder...
    SpectroStream encoder weights loaded.
Weight loading complete (11 groups loaded)
Saved to ~/Documents/Magenta/magenta-rt-v2/outputs/output_audio_mlx_mrt2_small.wav (4.0s of audio)

# (5) parity suite — clean, no failures (parity/bitlevel tests skip without a generated reference):
======================================================= 3 passed, 21 skipped, 4 warnings in 2.59s =======================================================

Benchmark Regression Test

N/A — this change only affects asset resolution (which directory weights are loaded from). It does not touch sampling, the model graph, or any performance-sensitive code path, so there is no benchmark surface to regress.

Model assets (musiccoca + spectrostream resources, exported .mlxfn models,
raw checkpoints) were always downloaded into ~/Documents/Magenta/magenta-rt-v2/
via hf_hub_download(local_dir=...), which never consults the global HF cache.
A repo already pulled with `hf download google/magenta-realtime-2` was
re-downloaded over the network, and assets were duplicated on disk.

Resolve assets MAGENTA_HOME-first, then the global HF cache:

- paths.py: add resolve_asset() plus resolve_musiccoca_dir/spectrostream_dir/
  model_dir helpers. The local MAGENTA_HOME layout always wins (so local
  exports, GCS downloads, and user overrides take precedence); otherwise the
  asset is served from the global HF cache. Load paths reuse the cache only
  (local_files_only) — no network fetch at load time — so default behavior is
  unchanged except that an existing `hf download` is now picked up. A missing
  asset raises a clear FileNotFoundError pointing at both locations.

- Route load sites through the resolver: musiccoca.py, mlx/system.py (mlxfn
  model dir + checkpoint), jax/system.py (checkpoint).

- mlx/spectrostream/load_weights.py: when encoder.safetensors is not a sibling
  of the checkpoint (e.g. assets live in the HF cache, where the checkpoint and
  resources sit in different repo subdirs), fall back to the resolved
  spectrostream resource dir. Guarded lazy import keeps the module standalone.

- resolve_checkpoint(): build the candidate path directly instead of via
  checkpoints_dir(), which mkdir's as a side effect. Resolution no longer
  creates directories, so a cache hit leaves MAGENTA_HOME untouched.

CLI (opt-in, default behavior unchanged):

- Add --use-hf-cache (env: MAGENTA_RT_USE_HF_CACHE) to `mrt models init`,
  `mrt models download`, and `mrt checkpoints download`. When set, downloads
  populate/reuse the global HF cache (omit local_dir) and the interactive
  picker's checkmarks consult the cache. The flag is HuggingFace-only; the GCS
  source and --download-path are unaffected.

Docs: note cache reuse and the flag in README.md and docs/models.md.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant