Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion photomap/backend/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
EmbeddingCacheMismatch,
ImageTextEncoder,
build_encoder,
capture_download_progress,
get_cached_encoder,
)
from .metadata_extraction import MetadataExtractor
Expand Down Expand Up @@ -605,6 +606,7 @@ def _process_images_batch(
progress_callback: Callable | None = None,
batch_size: int = 1,
num_workers: int = 1,
download_callback: Callable | None = None,
) -> IndexResult:
"""
Process a batch of images and return IndexResult.
Expand All @@ -619,13 +621,17 @@ def _process_images_batch(
parallel. Default 1 keeps the legacy serial path. >1 enables a
bounded producer/consumer pipeline so the GPU stays fed while images
decode concurrently.
download_callback: Optional callback(downloaded, total, desc) invoked with
byte progress while the encoder weights are downloaded on first use.
``None`` (the CLI/sync path) leaves tqdm's console output untouched.
"""
if batch_size < 1:
raise ValueError(f"batch_size must be >= 1 (got {batch_size})")
if num_workers < 1:
raise ValueError(f"num_workers must be >= 1 (got {num_workers})")

encoder = self._build_encoder()
with capture_download_progress(download_callback):
encoder = self._build_encoder()
embedding_dim = encoder.embedding_dim

embeddings: list[np.ndarray] = []
Expand Down Expand Up @@ -752,15 +758,23 @@ async def _process_images_batch_async(
def progress_cb(i: int, total: int, message: str) -> None:
if progress_tracker.is_cancel_requested(album_key):
raise IndexingCancelled("Indexing cancelled by user")
# The first encode marks the end of any model-download phase; flip
# the status back to INDEXING (a no-op when nothing was downloaded).
if i == 0:
progress_tracker.begin_indexing(album_key, total)
progress_tracker.update_progress(album_key, i, message)

def download_cb(downloaded: int, total: int | None, desc: str) -> None:
progress_tracker.report_download(album_key, downloaded, total)

async with _get_indexing_semaphore():
return await asyncio.to_thread(
self._process_images_batch,
image_paths,
progress_cb,
batch_size,
num_workers,
download_cb,
)

def _save_embeddings(self, index_result: IndexResult) -> None:
Expand Down
101 changes: 101 additions & 0 deletions photomap/backend/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@

from __future__ import annotations

import contextlib
import importlib
import logging
import math
import sys
import threading
import time
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterator
from typing import ClassVar

import numpy as np
Expand Down Expand Up @@ -436,6 +439,104 @@ def _free_cuda(device: str) -> None:
torch.cuda.empty_cache()


# --- Model download progress capture --------------------------------------
# The first time an encoder is built on a fresh install, its weights (hundreds
# of MB) are fetched from the network, which can take minutes with no UI
# feedback. Each backend renders that fetch through a ``tqdm`` byte-bar, but via
# a *different* module-level ``tqdm`` reference:
#
# - ``clip.clip.tqdm`` -- openai-clip URL download
# - ``open_clip.pretrained.tqdm`` -- open-clip URL download (non-HF tags)
# - ``huggingface_hub.utils.tqdm.tqdm`` -- HF Hub fetches (open-clip HF + siglip)
#
# ``capture_download_progress`` temporarily swaps those references for a
# subclass that forwards byte progress to a caller-supplied callback, so the
# indexing UI can show a real download bar. The originals are restored on exit,
# and the swap only happens when a callback is provided — so the CLI/console
# path keeps its normal tqdm output.

# (module, attribute) pairs to patch. Resolved lazily inside the context manager
# so a backend that isn't installed is simply skipped.
_DOWNLOAD_TQDM_TARGETS: tuple[tuple[str, str], ...] = (
("clip.clip", "tqdm"),
("open_clip.pretrained", "tqdm"),
("huggingface_hub.utils.tqdm", "tqdm"),
)

# tqdm reports download bars in these units; anything else (e.g. a plain
# iteration counter) is ignored so we only ever surface real byte progress.
_BYTE_UNITS = frozenset({"B", "iB", "bytes"})


def _make_reporting_tqdm(base_cls: type, callback: Callable[[int, int | None, str], None]) -> type:
"""Build a ``tqdm`` subclass that forwards byte progress to ``callback``.

Subclassing the real ``tqdm`` keeps all of its behavior intact (including
console rendering); we only add a side-channel report on each update/close.
Reports are best-effort: any exception raised by ``callback`` is swallowed
so a UI hiccup can never interrupt or fail a model download.
"""

class _ReportingTqdm(base_cls): # type: ignore[valid-type, misc]
def update(self, n: float = 1): # noqa: ANN001 - mirror tqdm signature
ret = super().update(n)
self._photomap_report()
return ret

def close(self):
self._photomap_report()
return super().close()

def _photomap_report(self) -> None:
try:
if getattr(self, "unit", "") not in _BYTE_UNITS:
return
total = getattr(self, "total", None)
callback(
int(getattr(self, "n", 0) or 0),
int(total) if total else None,
str(getattr(self, "desc", "") or ""),
)
except Exception:
# Never let progress reporting break a download.
pass

return _ReportingTqdm


@contextlib.contextmanager
def capture_download_progress(
callback: Callable[[int, int | None, str], None] | None,
) -> Iterator[None]:
"""Route encoder-weight download progress to ``callback`` within the block.

``callback(downloaded, total, desc)`` is invoked with cumulative bytes for
the *currently active* download bar (``total`` is ``None`` when the server
omits ``Content-Length``). When ``callback`` is ``None`` this is a no-op, so
callers that don't want UI reporting (the CLI/sync path) get unchanged
behavior.
"""
if callback is None:
yield
return

patched: list[tuple[object, str, object]] = []
for module_name, attr in _DOWNLOAD_TQDM_TARGETS:
try:
module = importlib.import_module(module_name)
original = getattr(module, attr)
except (ImportError, AttributeError):
continue
patched.append((module, attr, original))
setattr(module, attr, _make_reporting_tqdm(original, callback))

try:
yield
finally:
for module, attr, original in patched:
setattr(module, attr, original)


def build_encoder(
spec: str | None = None,
*,
Expand Down
47 changes: 47 additions & 0 deletions photomap/backend/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
class IndexStatus(Enum):
IDLE = "idle"
SCANNING = "scanning"
DOWNLOADING = "downloading"
INDEXING = "indexing"
UMAPPING = "mapping"
CURATING = "curating"
Expand Down Expand Up @@ -114,6 +115,51 @@ def update_progress(
):
progress.status = IndexStatus.COMPLETED

def report_download(
self,
album_key: str,
downloaded: int,
total: int | None,
message: str = "Downloading encoder model…",
) -> None:
"""Surface encoder-weight download progress as a DOWNLOADING phase.

``downloaded``/``total`` are *byte* counts (``total`` may be ``None`` when
the server omits ``Content-Length``); they reuse the
``images_processed``/``total_images`` fields so the existing percentage
and ETA machinery drives the UI bar without special-casing. On the first
transition into DOWNLOADING the start time is reset so the byte-rate ETA
reflects the download rather than the preceding scan. The cancel flag is
deliberately left untouched.
"""
with self._lock:
progress = self._progress.get(album_key)
if progress is None:
return
if progress.status != IndexStatus.DOWNLOADING:
progress.start_time = time.time()
progress.status = IndexStatus.DOWNLOADING
progress.images_processed = max(downloaded, 0)
progress.total_images = total if total and total > 0 else 0
progress.current_step = message

def begin_indexing(self, album_key: str, total_images: int) -> None:
"""Transition an album into the INDEXING phase.

Used to flip back from DOWNLOADING once the encoder is ready and image
encoding starts. Resets the processed count and start time (so the ETA
excludes any preceding download) but leaves the cancel flag intact.
"""
with self._lock:
progress = self._progress.get(album_key)
if progress is None:
return
progress.status = IndexStatus.INDEXING
progress.images_processed = 0
progress.total_images = total_images
progress.current_step = "Starting indexing"
progress.start_time = time.time()

def set_error(self, album_key: str, error_message: str):
"""Set error status for an album."""
with self._lock:
Expand Down Expand Up @@ -155,6 +201,7 @@ def is_running(self, album_key: str) -> bool:
progress = self._progress.get(album_key)
return progress is not None and progress.status in [
IndexStatus.SCANNING,
IndexStatus.DOWNLOADING,
IndexStatus.INDEXING,
IndexStatus.UMAPPING,
IndexStatus.CURATING,
Expand Down
3 changes: 3 additions & 0 deletions photomap/frontend/static/css/album-manager.css
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@
.index-status.scanning {
color: #636ceb;
}
.index-status.downloading {
color: #9c27b0 !important;
}
.index-status.indexing {
color: #ff9800 !important;
}
Expand Down
8 changes: 8 additions & 0 deletions photomap/frontend/static/javascript/album-manager.js
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ export class AlbumManager {

static STATUS_CLASSES = {
SCANNING: "index-status scanning",
DOWNLOADING: "index-status downloading",
INDEXING: "index-status indexing",
UMAPPING: "index-status mapping",
COMPLETED: "index-status completed",
Expand Down Expand Up @@ -1292,6 +1293,13 @@ export class AlbumManager {
status.textContent = progress.current_step || "Scanning for images...";
status.style.color = "#ff9800"; // Orange for scanning
estimatedTime.textContent = "";
} else if (progress.status === "downloading") {
// First-use model download. The bar width/percent and ETA are already set
// from progress_percentage/estimated_time_remaining in updateProgress();
// here we just label the phase and leave the ETA line in place.
status.className = AlbumManager.STATUS_CLASSES.DOWNLOADING;
status.textContent = progress.current_step || "Downloading encoder model…";
status.style.color = "#9c27b0"; // Purple for downloading
} else if (progress.status === "mapping") {
status.className = AlbumManager.STATUS_CLASSES.UMAPPING;
status.textContent = progress.current_step || "Generating image map...";
Expand Down
96 changes: 96 additions & 0 deletions tests/backend/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
OpenClipEncoder,
SiglipEncoder,
build_encoder,
capture_download_progress,
clear_encoder_cache,
default_encoder_spec,
get_cached_encoder,
Expand Down Expand Up @@ -570,3 +571,98 @@ def fake_build(spec=None, *, cache_dir=None, device=None):
assert d is not a
assert call_count["n"] == 3
clear_encoder_cache()


# --- Model download progress capture --------------------------------------


def _hf_tqdm_module():
"""Return the huggingface_hub submodule whose ``tqdm`` we patch, or None."""
import importlib

try:
return importlib.import_module("huggingface_hub.utils.tqdm")
except ImportError: # pragma: no cover - hf_hub is a hard dep in practice
return None


def test_capture_download_progress_noop_when_callback_none():
module = _hf_tqdm_module()
if module is None:
pytest.skip("huggingface_hub not installed")
original = module.tqdm
with capture_download_progress(None):
# No callback -> tqdm references must be left exactly as-is so the
# console download bars behave normally on the CLI path.
assert module.tqdm is original
assert module.tqdm is original


def test_capture_download_progress_patches_and_restores():
module = _hf_tqdm_module()
if module is None:
pytest.skip("huggingface_hub not installed")
original = module.tqdm

def cb(downloaded, total, desc):
pass

with capture_download_progress(cb):
assert module.tqdm is not original
assert issubclass(module.tqdm, original)
# Originals restored on exit, even though the block did real work.
assert module.tqdm is original


def test_capture_forwards_byte_progress():
module = _hf_tqdm_module()
if module is None:
pytest.skip("huggingface_hub not installed")
import io

reports: list[tuple[int, int | None, str]] = []

with capture_download_progress(lambda d, t, desc: reports.append((d, t, desc))):
bar = module.tqdm(total=100, unit="B", desc="model.safetensors", file=io.StringIO())
bar.update(40)
bar.update(60)
bar.close()

assert reports, "expected byte progress to be reported"
# Cumulative byte counts and total flow straight through to the callback.
assert (40, 100, "model.safetensors") in reports
assert reports[-1] == (100, 100, "model.safetensors")


def test_capture_ignores_non_byte_bars():
module = _hf_tqdm_module()
if module is None:
pytest.skip("huggingface_hub not installed")
import io

reports: list[tuple[int, int | None, str]] = []

with capture_download_progress(lambda d, t, desc: reports.append((d, t, desc))):
# A plain iteration counter (default unit "it") is not a download and
# must not be surfaced as model-download progress.
bar = module.tqdm(total=100, unit="it", file=io.StringIO())
bar.update(10)
bar.close()

assert reports == []


def test_capture_swallows_callback_errors():
module = _hf_tqdm_module()
if module is None:
pytest.skip("huggingface_hub not installed")
import io

def boom(downloaded, total, desc):
raise RuntimeError("UI exploded")

# A misbehaving callback must never break the underlying download.
with capture_download_progress(boom):
bar = module.tqdm(total=100, unit="B", file=io.StringIO())
bar.update(50) # must not raise
bar.close()
Loading
Loading