diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py index 654c7264..6ff5d8bc 100644 --- a/packages/leann-core/src/leann/embedding_compute.py +++ b/packages/leann-core/src/leann/embedding_compute.py @@ -481,9 +481,10 @@ def compute_embeddings( base_url=provider_options.get("base_url"), api_key=provider_options.get("api_key"), provider_options=provider_options, + is_build=is_build, ) elif mode == "mlx": - return compute_embeddings_mlx(texts, model_name) + return compute_embeddings_mlx(texts, model_name, is_build=is_build) elif mode == "ollama": return compute_embeddings_ollama( texts, @@ -848,8 +849,8 @@ def compute_embeddings_openai( base_url: Optional[str] = None, api_key: Optional[str] = None, provider_options: Optional[dict[str, Any]] = None, + is_build: bool = False, ) -> np.ndarray: - # TODO: @yichuan-w add progress bar only in build mode """Compute embeddings using OpenAI API""" try: import openai @@ -936,9 +937,12 @@ def compute_embeddings_openai( total_batches = (len(texts) + max_batch_size - 1) // max_batch_size batch_range = range(0, len(texts), max_batch_size) - batch_iterator = tqdm( - batch_range, desc="Computing embeddings", unit="batch", total=total_batches - ) + if is_build: + batch_iterator = tqdm( + batch_range, desc="Computing embeddings", unit="batch", total=total_batches + ) + else: + batch_iterator = batch_range except ImportError: # Fallback when tqdm is not available batch_iterator = range(0, len(texts), max_batch_size) @@ -967,8 +971,9 @@ def compute_embeddings_openai( return embeddings -def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int = 16) -> np.ndarray: - # TODO: @yichuan-w add progress bar only in build mode +def compute_embeddings_mlx( + chunks: list[str], model_name: str, batch_size: int = 16, is_build: bool = False +) -> np.ndarray: """Computes embeddings using an MLX model.""" try: import mlx.core as mx @@ -999,9 +1004,11 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int = try: from tqdm import tqdm - batch_iterator = tqdm( - range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch" - ) + batch_range = range(0, len(chunks), batch_size) + if is_build: + batch_iterator = tqdm(batch_range, desc="Computing embeddings", unit="batch") + else: + batch_iterator = batch_range except ImportError: batch_iterator = range(0, len(chunks), batch_size)