Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions packages/leann-core/src/leann/embedding_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,9 +481,10 @@ def compute_embeddings(
base_url=provider_options.get("base_url"),
api_key=provider_options.get("api_key"),
provider_options=provider_options,
is_build=is_build,
)
elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name)
return compute_embeddings_mlx(texts, model_name, is_build=is_build)
elif mode == "ollama":
return compute_embeddings_ollama(
texts,
Expand Down Expand Up @@ -848,8 +849,8 @@ def compute_embeddings_openai(
base_url: Optional[str] = None,
api_key: Optional[str] = None,
provider_options: Optional[dict[str, Any]] = None,
is_build: bool = False,
) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API"""
try:
import openai
Expand Down Expand Up @@ -936,9 +937,12 @@ def compute_embeddings_openai(

total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
batch_range = range(0, len(texts), max_batch_size)
batch_iterator = tqdm(
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
)
if is_build:
batch_iterator = tqdm(
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
)
else:
batch_iterator = batch_range
except ImportError:
# Fallback when tqdm is not available
batch_iterator = range(0, len(texts), max_batch_size)
Expand Down Expand Up @@ -967,8 +971,9 @@ def compute_embeddings_openai(
return embeddings


def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int = 16) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
def compute_embeddings_mlx(
chunks: list[str], model_name: str, batch_size: int = 16, is_build: bool = False
) -> np.ndarray:
"""Computes embeddings using an MLX model."""
try:
import mlx.core as mx
Expand Down Expand Up @@ -999,9 +1004,11 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
try:
from tqdm import tqdm

batch_iterator = tqdm(
range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch"
)
batch_range = range(0, len(chunks), batch_size)
if is_build:
batch_iterator = tqdm(batch_range, desc="Computing embeddings", unit="batch")
else:
batch_iterator = batch_range
except ImportError:
batch_iterator = range(0, len(chunks), batch_size)

Expand Down
Loading