From 16387bd5932cf2ae1f02781238ab8d3853957d07 Mon Sep 17 00:00:00 2001 From: Lee Miller Date: Mon, 1 Dec 2025 09:56:17 -0700 Subject: [PATCH 1/9] update project tooling for uv/ruff/pytest --- .gitignore | 13 +- .pre-commit-config.yaml | 32 +++++ MANIFEST.in | 12 +- Makefile | 54 ++++++++ pyproject.toml | 109 +++++++++++++++-- setup.py | 16 +-- {wordllama => src/wordllama}/RESULTS.md | 1 - {wordllama => src/wordllama}/__init__.py | 6 +- .../wordllama}/adapters/__init__.py | 13 +- .../wordllama}/adapters/avg_pool.py | 8 +- .../wordllama}/adapters/binarizer.py | 7 +- {wordllama => src/wordllama}/adapters/mlp.py | 2 +- .../wordllama}/adapters/projector.py | 3 +- .../wordllama}/adapters/weighted_mlp.py | 14 +-- .../wordllama}/adapters/weighted_projector.py | 14 +-- .../wordllama}/algorithms/__init__.py | 8 +- .../algorithms/deduplicate_helpers.pyx | 0 .../algorithms/find_local_minima.pyx | 8 +- .../wordllama}/algorithms/kmeans.pyx | 8 +- .../algorithms/semantic_splitter.py | 35 +++--- .../wordllama}/algorithms/splitter.pyx | 4 +- .../algorithms/vector_similarity.pxd | 0 .../algorithms/vector_similarity.pyx | 2 +- .../wordllama}/config/__init__.py | 13 +- {wordllama => src/wordllama}/config/models.py | 20 ++- .../config/train/command_rplus.toml | 1 - .../wordllama}/config/train/dbrx.toml | 1 - .../config/train/deberta_v3_large.toml | 3 +- .../wordllama}/config/train/deepseekv2.toml | 1 - .../wordllama}/config/train/gemma2_27B.toml | 1 - .../wordllama}/config/train/l2_supercat.toml | 1 - .../wordllama}/config/train/l2p3.toml | 3 +- .../wordllama}/config/train/l2p3_lg.toml | 3 +- .../wordllama}/config/train/l3_supercat.toml | 3 +- .../wordllama}/config/train/llama2_70B.toml | 3 +- .../wordllama}/config/train/llama3_70B.toml | 3 +- .../wordllama}/config/train/llama3_8B.toml | 1 - .../wordllama}/config/train/llamaguard.toml | 3 +- .../wordllama}/config/train/miqu.toml | 3 +- .../wordllama}/config/train/mixtral.toml | 1 - .../config/train/mixtral_8x22B.toml | 1 - .../wordllama}/config/train/openelm_3B.toml | 3 +- .../wordllama}/config/train/phi3_medium.toml | 3 +- .../wordllama}/config/train/qwen2_72B.toml | 3 +- .../wordllama}/config/train/yi_1v5_34B.toml | 1 - .../wordllama}/embedding/__init__.py | 0 .../embedding/word_llama_embedding.py | 27 ++-- .../wordllama}/extract/__init__.py | 2 +- .../wordllama}/extract/extract_hf.py | 2 +- .../wordllama}/extract/extract_llama_70B.py | 1 + .../wordllama}/extract/extract_safetensors.py | 0 {wordllama => src/wordllama}/inference.py | 67 ++++------ .../wordllama}/mode_decorators.py | 0 .../wordllama}/tokenizers/__init__.py | 5 +- .../l2_supercat_tokenizer_config.json | 2 +- .../wordllama}/trainers/__init__.py | 0 .../wordllama}/trainers/reduce_dimension.py | 5 +- .../weights/l2_supercat_256.safetensors | Bin {wordllama => src/wordllama}/wordllama.py | 30 ++--- tests/test_functional.py | 11 +- tests/test_inference.py | 81 ++++++------ tests/test_kmeans.py | 37 +++--- tests/test_minima_functions.py | 19 +-- tests/test_semantic_splitter.py | 34 +++--- tests/test_splitting_functions.py | 43 +++---- tests/test_vector_similarity.py | 13 +- tests/test_wordllama.py | 115 ++++++++---------- 67 files changed, 506 insertions(+), 432 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 Makefile rename {wordllama => src/wordllama}/RESULTS.md (99%) rename {wordllama => src/wordllama}/__init__.py (95%) rename {wordllama => src/wordllama}/adapters/__init__.py (63%) rename {wordllama => src/wordllama}/adapters/avg_pool.py (92%) rename {wordllama => src/wordllama}/adapters/binarizer.py (95%) rename {wordllama => src/wordllama}/adapters/mlp.py (100%) rename {wordllama => src/wordllama}/adapters/projector.py (99%) rename {wordllama => src/wordllama}/adapters/weighted_mlp.py (91%) rename {wordllama => src/wordllama}/adapters/weighted_projector.py (87%) rename {wordllama => src/wordllama}/algorithms/__init__.py (59%) rename {wordllama => src/wordllama}/algorithms/deduplicate_helpers.pyx (100%) rename {wordllama => src/wordllama}/algorithms/find_local_minima.pyx (99%) rename {wordllama => src/wordllama}/algorithms/kmeans.pyx (99%) rename {wordllama => src/wordllama}/algorithms/semantic_splitter.py (86%) rename {wordllama => src/wordllama}/algorithms/splitter.pyx (99%) rename {wordllama => src/wordllama}/algorithms/vector_similarity.pxd (100%) rename {wordllama => src/wordllama}/algorithms/vector_similarity.pyx (99%) rename {wordllama => src/wordllama}/config/__init__.py (89%) rename {wordllama => src/wordllama}/config/models.py (90%) rename {wordllama => src/wordllama}/config/train/command_rplus.toml (99%) rename {wordllama => src/wordllama}/config/train/dbrx.toml (99%) rename {wordllama => src/wordllama}/config/train/deberta_v3_large.toml (95%) rename {wordllama => src/wordllama}/config/train/deepseekv2.toml (99%) rename {wordllama => src/wordllama}/config/train/gemma2_27B.toml (99%) rename {wordllama => src/wordllama}/config/train/l2_supercat.toml (99%) rename {wordllama => src/wordllama}/config/train/l2p3.toml (96%) rename {wordllama => src/wordllama}/config/train/l2p3_lg.toml (96%) rename {wordllama => src/wordllama}/config/train/l3_supercat.toml (95%) rename {wordllama => src/wordllama}/config/train/llama2_70B.toml (96%) rename {wordllama => src/wordllama}/config/train/llama3_70B.toml (94%) rename {wordllama => src/wordllama}/config/train/llama3_8B.toml (99%) rename {wordllama => src/wordllama}/config/train/llamaguard.toml (96%) rename {wordllama => src/wordllama}/config/train/miqu.toml (96%) rename {wordllama => src/wordllama}/config/train/mixtral.toml (99%) rename {wordllama => src/wordllama}/config/train/mixtral_8x22B.toml (99%) rename {wordllama => src/wordllama}/config/train/openelm_3B.toml (96%) rename {wordllama => src/wordllama}/config/train/phi3_medium.toml (95%) rename {wordllama => src/wordllama}/config/train/qwen2_72B.toml (95%) rename {wordllama => src/wordllama}/config/train/yi_1v5_34B.toml (99%) rename {wordllama => src/wordllama}/embedding/__init__.py (100%) rename {wordllama => src/wordllama}/embedding/word_llama_embedding.py (87%) rename {wordllama => src/wordllama}/extract/__init__.py (100%) rename {wordllama => src/wordllama}/extract/extract_hf.py (93%) rename {wordllama => src/wordllama}/extract/extract_llama_70B.py (99%) rename {wordllama => src/wordllama}/extract/extract_safetensors.py (100%) rename {wordllama => src/wordllama}/inference.py (90%) rename {wordllama => src/wordllama}/mode_decorators.py (100%) rename {wordllama => src/wordllama}/tokenizers/__init__.py (85%) rename {wordllama => src/wordllama}/tokenizers/l2_supercat_tokenizer_config.json (99%) rename {wordllama => src/wordllama}/trainers/__init__.py (100%) rename {wordllama => src/wordllama}/trainers/reduce_dimension.py (95%) rename {wordllama => src/wordllama}/weights/l2_supercat_256.safetensors (100%) rename {wordllama => src/wordllama}/wordllama.py (94%) diff --git a/.gitignore b/.gitignore index e26add1..9e0b9c3 100644 --- a/.gitignore +++ b/.gitignore @@ -162,10 +162,15 @@ cython_debug/ #.idea/ # Ignore generated Cython files -wordllama/algorithms/*.c -wordllama/algorithms/*.cpp -wordllama/algorithms/*.html +src/wordllama/algorithms/*.c +src/wordllama/algorithms/*.cpp +src/wordllama/algorithms/*.html # Ignore the generated version file -wordllama/_version.py +src/wordllama/_version.py +# uv +uv.lock + +# ruff +.ruff_cache/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..69d1c51 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,32 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + exclude: '\.pyx$|\.pxd$' + - id: check-yaml + - id: check-toml + - id: check-json + - id: check-added-large-files + args: ['--maxkb=1024'] + exclude: 'weights/.*\.safetensors$' + - id: check-merge-conflict + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.8.4 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format + + - repo: local + hooks: + - id: pytest-fast + name: pytest-fast + entry: uv run pytest + language: system + pass_filenames: false + args: [-m, "not slow", --tb=short, -q] + types: [python] + stages: [pre-commit] diff --git a/MANIFEST.in b/MANIFEST.in index 612420c..f799f3a 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,8 +1,8 @@ include LICENSE include README.md -recursive-include wordllama *.py *.toml *.json -include wordllama/weights/*.safetensors -include wordllama/algorithms/*.pyx -include wordllama/algorithms/*.pxd -include wordllama/algorithms/*.so -include wordllama/algorithms/*.pyd +recursive-include src/wordllama *.py *.toml *.json +include src/wordllama/weights/*.safetensors +include src/wordllama/algorithms/*.pyx +include src/wordllama/algorithms/*.pxd +include src/wordllama/algorithms/*.so +include src/wordllama/algorithms/*.pyd diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..74e56a1 --- /dev/null +++ b/Makefile @@ -0,0 +1,54 @@ +.PHONY: help install install-dev build clean test test-cov lint format pre-commit-install pre-commit-run all + +help: + @echo "WordLlama Development Makefile" + @echo "" + @echo "Available targets:" + @echo " install - Install package and dependencies" + @echo " install-dev - Install package with dev dependencies" + @echo " build - Build Cython extensions" + @echo " clean - Clean build artifacts" + @echo " test - Run tests" + @echo " test-cov - Run tests with coverage" + @echo " lint - Run ruff linter" + @echo " format - Format code with ruff" + @echo " pre-commit-install - Install pre-commit hooks" + @echo " pre-commit-run - Run pre-commit on all files" + @echo " all - Clean, build, lint, format, and test" + +install: + uv sync + +install-dev: + uv sync --all-extras + +build: + uv run python setup.py build_ext --inplace + +clean: + rm -rf build/ dist/ *.egg-info + rm -rf src/wordllama/**/*.so src/wordllama/**/*.c src/wordllama/**/*.cpp + rm -rf .pytest_cache .ruff_cache htmlcov .coverage + find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true + +test: + uv run pytest + +test-cov: + uv run pytest --cov=wordllama --cov-report=html --cov-report=term-missing + +lint: + uv run ruff check src/ tests/ + +format: + uv run ruff format src/ tests/ + uv run ruff check --fix src/ tests/ + +pre-commit-install: + uv run pre-commit install + +pre-commit-run: + uv run pre-commit run --all-files + +all: clean build lint format test + @echo "✓ All tasks completed successfully" diff --git a/pyproject.toml b/pyproject.toml index 0a7ca4e..d4bf64e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ dynamic = ["version"] description = "WordLlama NLP Utility" readme = { file = "README.md", content-type = "text/markdown" } license = { file = "LICENSE" } -requires-python = ">=3.8" +requires-python = ">=3.9" authors = [{ name = "Lee Miller", email = "dleemiller@gmail.com" }] dependencies = [ "numpy>=2", @@ -20,6 +20,13 @@ dependencies = [ ] [project.optional-dependencies] +dev = [ + "pytest>=7.4", + "pytest-cov>=4.1", + "pytest-xdist>=3.3", + "ruff>=0.1.0", + "pre-commit>=3.5", +] train = [ "accelerate", "torch>=2", @@ -37,7 +44,7 @@ Repository = "https://github.com/dleemiller/WordLlama" include-package-data = true # Ensures non-code files are included in the package [tool.setuptools.packages.find] -where = ["."] +where = ["src"] include = ["wordllama*"] [tool.setuptools.package-data] @@ -57,7 +64,7 @@ wordllama = [ classifiers = { file = "classifiers.txt" } [tool.setuptools_scm] -write_to = "wordllama/_version.py" +write_to = "src/wordllama/_version.py" version_scheme = "post-release" local_scheme = "no-local-version" @@ -74,7 +81,7 @@ print(similarity_score);" [tool.cibuildwheel.macos] # before-all = """ # brew install openblas libomp llvm -# +# # # Create a NumPy site.cfg so that it detects your Homebrew-installed OpenBLAS # cat > ~/.numpy-site.cfg <= -t) & (x < -interval) tmp[mask2] = (1 / o) * torch.pow(-x[mask2], (1 - o) / o) tmp[(x <= interval) & (x >= 0)] = approximate_function(interval, o) / interval - tmp[(x <= 0) & (x >= -interval)] = ( - -approximate_function(-interval, o) / interval - ) + tmp[(x <= 0) & (x >= -interval)] = -approximate_function(-interval, o) / interval # calculate the final gradient grad_x = tmp * grad_output.clone() diff --git a/wordllama/adapters/mlp.py b/src/wordllama/adapters/mlp.py similarity index 100% rename from wordllama/adapters/mlp.py rename to src/wordllama/adapters/mlp.py index bb9794e..94497e2 100644 --- a/wordllama/adapters/mlp.py +++ b/src/wordllama/adapters/mlp.py @@ -1,7 +1,7 @@ import os -from torch import nn import safetensors.torch as st +from torch import nn class MLP(nn.Module): diff --git a/wordllama/adapters/projector.py b/src/wordllama/adapters/projector.py similarity index 99% rename from wordllama/adapters/projector.py rename to src/wordllama/adapters/projector.py index 502e701..821e550 100644 --- a/wordllama/adapters/projector.py +++ b/src/wordllama/adapters/projector.py @@ -1,6 +1,7 @@ import os -from torch import nn + import safetensors.torch as st +from torch import nn class Projector(nn.Module): diff --git a/wordllama/adapters/weighted_mlp.py b/src/wordllama/adapters/weighted_mlp.py similarity index 91% rename from wordllama/adapters/weighted_mlp.py rename to src/wordllama/adapters/weighted_mlp.py index 5fcbee3..4c3e621 100644 --- a/wordllama/adapters/weighted_mlp.py +++ b/src/wordllama/adapters/weighted_mlp.py @@ -1,10 +1,10 @@ import os -from nltk.corpus import stopwords +import safetensors.torch as st import torch -from torch import nn import torch.nn.functional as F -import safetensors.torch as st +from nltk.corpus import stopwords +from torch import nn class WeightedMLP(nn.Module): @@ -22,9 +22,7 @@ def __init__(self, in_dim, out_dim, hidden_dim, tokenizer, dropout=0.1): n_vocab = len(tokenizer.vocab) self.weights = nn.Parameter(torch.ones(n_vocab)) stopword_list = set(stopwords.words("english")) - stopword_ids = [ - tokenizer.vocab[word] for word in stopword_list if word in tokenizer.vocab - ] + stopword_ids = [tokenizer.vocab[word] for word in stopword_list if word in tokenizer.vocab] with torch.no_grad(): for stopword_id in stopword_ids: self.weights[stopword_id] = 0.1 @@ -34,9 +32,7 @@ def forward(self, tensors) -> dict: weights = F.gelu( self.weights[token_ids] ) # use gelu to limit negative contribution of weights - weighted_embeddings = self.mlp(tensors["token_embeddings"]) * weights.unsqueeze( - -1 - ) + weighted_embeddings = self.mlp(tensors["token_embeddings"]) * weights.unsqueeze(-1) tensors.update({"x": weighted_embeddings}) return tensors diff --git a/wordllama/adapters/weighted_projector.py b/src/wordllama/adapters/weighted_projector.py similarity index 87% rename from wordllama/adapters/weighted_projector.py rename to src/wordllama/adapters/weighted_projector.py index 879e59e..812700e 100644 --- a/wordllama/adapters/weighted_projector.py +++ b/src/wordllama/adapters/weighted_projector.py @@ -1,10 +1,10 @@ import os -from nltk.corpus import stopwords +import safetensors.torch as st import torch -from torch import nn import torch.nn.functional as F -import safetensors.torch as st +from nltk.corpus import stopwords +from torch import nn class WeightedProjector(nn.Module): @@ -17,9 +17,7 @@ def __init__(self, in_dim, out_dim, tokenizer, n_vocab, key="token_embeddings"): # Initialize stopword weights to a lower value self.weights = nn.Parameter(torch.ones(n_vocab)) stopword_list = set(filter(lambda x: len(x) > 1, stopwords.words("english"))) - stopword_ids = [ - tokenizer.vocab[word] for word in stopword_list if word in tokenizer.vocab - ] + stopword_ids = [tokenizer.vocab[word] for word in stopword_list if word in tokenizer.vocab] print(f"Num stopword ids: {len(stopword_ids)}") with torch.no_grad(): for stopword_id in stopword_ids: @@ -30,9 +28,7 @@ def forward(self, tensors) -> dict: weights = F.gelu( self.weights[token_ids] ) # use gelu to limit negative contribution of weights - weighted_embeddings = self.proj( - tensors["token_embeddings"] * weights.unsqueeze(-1) - ) + weighted_embeddings = self.proj(tensors["token_embeddings"] * weights.unsqueeze(-1)) tensors.update({"x": weighted_embeddings}) return tensors diff --git a/wordllama/algorithms/__init__.py b/src/wordllama/algorithms/__init__.py similarity index 59% rename from wordllama/algorithms/__init__.py rename to src/wordllama/algorithms/__init__.py index dd772eb..7090bce 100644 --- a/wordllama/algorithms/__init__.py +++ b/src/wordllama/algorithms/__init__.py @@ -1,7 +1,7 @@ -from .kmeans import kmeans_clustering -from .vector_similarity import vector_similarity, binarize_and_packbits from .deduplicate_helpers import deduplicate_embeddings -from .splitter import split_sentences, constrained_batches, constrained_coalesce +from .kmeans import kmeans_clustering +from .splitter import constrained_batches, constrained_coalesce, split_sentences +from .vector_similarity import binarize_and_packbits, vector_similarity __all__ = [ "kmeans_clustering", @@ -10,5 +10,5 @@ "deduplicate_embeddings", "split_sentences", "constrained_batches", - "constrained_coalesce" + "constrained_coalesce", ] diff --git a/wordllama/algorithms/deduplicate_helpers.pyx b/src/wordllama/algorithms/deduplicate_helpers.pyx similarity index 100% rename from wordllama/algorithms/deduplicate_helpers.pyx rename to src/wordllama/algorithms/deduplicate_helpers.pyx diff --git a/wordllama/algorithms/find_local_minima.pyx b/src/wordllama/algorithms/find_local_minima.pyx similarity index 99% rename from wordllama/algorithms/find_local_minima.pyx rename to src/wordllama/algorithms/find_local_minima.pyx index 113b639..676c7ac 100644 --- a/wordllama/algorithms/find_local_minima.pyx +++ b/src/wordllama/algorithms/find_local_minima.pyx @@ -87,7 +87,7 @@ cpdef tuple find_local_minima(np.ndarray[DTYPE_t, ndim=1] y, int window_size=11, # Ensure that y is a float32 array y_float = np.asarray(y, dtype=np.float32) - + return _find_local_minima_impl(y_float, window_size, poly_order) cdef tuple _find_local_minima_impl(DTYPE_t[:] y, int window_size, int poly_order): @@ -110,7 +110,7 @@ cdef tuple _find_local_minima_impl(DTYPE_t[:] y, int window_size, int poly_order # Precompute Savitzky-Golay coefficients cdef np.ndarray[DTYPE_t, ndim=2] coeffs = compute_savitzky_golay_coeffs(window_size, poly_order) - + # Apply the filter for the first and second derivatives cdef np.ndarray[DTYPE_t, ndim=1] dy = apply_savitzky_golay_filter(coeffs, y, deriv=1) cdef np.ndarray[DTYPE_t, ndim=1] ddy = apply_savitzky_golay_filter(coeffs, y, deriv=2) @@ -154,11 +154,11 @@ cpdef np.ndarray[DTYPE_t, ndim=1] windowed_cross_similarity(np.ndarray[DTYPE_t, """ cdef int n = embeddings.shape[0] - + # Ensure the window size is odd and >= 3 if window_size < 3 or window_size % 2 == 0: raise ValueError("Window size must be odd and >= 3") - + cdef int half_window = window_size // 2 # Pre-allocate the output array for storing windowed averages diff --git a/wordllama/algorithms/kmeans.pyx b/src/wordllama/algorithms/kmeans.pyx similarity index 99% rename from wordllama/algorithms/kmeans.pyx rename to src/wordllama/algorithms/kmeans.pyx index 3a48a2e..9c99630 100644 --- a/wordllama/algorithms/kmeans.pyx +++ b/src/wordllama/algorithms/kmeans.pyx @@ -83,7 +83,7 @@ cdef kmeans_plusplus_initialization(float[:, :] embeddings, int k, object random cumulative_probabilities = np.cumsum(probabilities) r = random_state.rand() index = np.searchsorted(cumulative_probabilities, r) - + for j in range(n_features): centroids_view[i, j] = embeddings[index, j] @@ -115,10 +115,10 @@ cdef _kmeans_single(np.ndarray[FLOAT_t, ndim=2] X, int k, int min_iterations, in prev_inertia = inertia inertia = np.sum(np.min(distances, axis=1) ** 2) - + if iteration >= min_iterations - 1 and abs(prev_inertia - inertia) < tolerance: break - + centers = new_centers end = clock() @@ -149,7 +149,7 @@ def kmeans_clustering(np.ndarray[FLOAT_t, ndim=2] X, int k, int n_init=10, int m for i in range(n_init): labels, centers, inertia, n_iterations, time_taken = _kmeans_single(X, k, min_iterations, max_iterations, tolerance, random_state) logger.info(f"Initialization {i + 1}/{n_init}: Inertia = {inertia:.2f}, Iterations = {n_iterations}, Time = {time_taken:.2f} seconds") - + if inertia < best_inertia: best_labels = labels best_centers = centers diff --git a/wordllama/algorithms/semantic_splitter.py b/src/wordllama/algorithms/semantic_splitter.py similarity index 86% rename from wordllama/algorithms/semantic_splitter.py rename to src/wordllama/algorithms/semantic_splitter.py index 3490416..69ba2a3 100644 --- a/wordllama/algorithms/semantic_splitter.py +++ b/src/wordllama/algorithms/semantic_splitter.py @@ -1,12 +1,13 @@ -import numpy as np -from typing import List, Tuple, Union from itertools import chain +from typing import Union + +import numpy as np + from .find_local_minima import find_local_minima, windowed_cross_similarity from .splitter import ( - constrained_batches, constrained_coalesce, - split_sentences, reverse_merge, + split_sentences, ) @@ -19,7 +20,7 @@ class SemanticSplitter: """ @staticmethod - def flatten(nested_list: List[List[any]]) -> List[any]: + def flatten(nested_list: list[list[any]]) -> list[any]: """ Flatten a list of lists into a single list. @@ -37,7 +38,7 @@ def constrained_split( target_size: int, separator: str = " ", min_size: int = 24, - ) -> List[str]: + ) -> list[str]: """ Split text into chunks of approximately target_size. @@ -64,7 +65,7 @@ def split( target_size: int, cleanup_size: int = 24, intermediate_size: int = 96, - ) -> List[str]: + ) -> list[str]: """ Split the input text into chunks based on semantic coherence. @@ -77,9 +78,7 @@ def split( Returns: List[str]: List of text chunks. """ - assert ( - target_size > intermediate_size - ), "Target size must be larger than intermediate size." + assert target_size > intermediate_size, "Target size must be larger than intermediate size." assert ( intermediate_size > cleanup_size ), "Intermediate size must be larger than cleanup size." @@ -89,9 +88,7 @@ def split( lines = reverse_merge(lines, n=cleanup_size, separator="\n") chunks = [ - cls.constrained_split( - line, target_size, min_size=cleanup_size, separator=" " - ) + cls.constrained_split(line, target_size, min_size=cleanup_size, separator=" ") if len(line) > target_size else [line] for line in lines @@ -103,7 +100,7 @@ def split( @classmethod def reconstruct( cls, - lines: List[str], + lines: list[str], norm_embed: np.ndarray, target_size: int, window_size: int, @@ -111,7 +108,7 @@ def reconstruct( savgol_window: int, max_score_pct: float = 0.4, return_minima: bool = False, - ) -> Union[List[str], Tuple[np.ndarray, np.ndarray, np.ndarray]]: + ) -> Union[list[str], tuple[np.ndarray, np.ndarray, np.ndarray]]: """ Reconstruct text chunks based on semantic similarity. @@ -133,14 +130,10 @@ def reconstruct( Raises: AssertionError: If the number of texts doesn't equal the number of embeddings. """ - assert ( - len(lines) == norm_embed.shape[0] - ), "Number of texts must equal number of embeddings" + assert len(lines) == norm_embed.shape[0], "Number of texts must equal number of embeddings" sim_avg = windowed_cross_similarity(norm_embed, window_size) - roots, y = find_local_minima( - sim_avg, poly_order=poly_order, window_size=savgol_window - ) + roots, y = find_local_minima(sim_avg, poly_order=poly_order, window_size=savgol_window) if return_minima: return roots, y, sim_avg diff --git a/wordllama/algorithms/splitter.pyx b/src/wordllama/algorithms/splitter.pyx similarity index 99% rename from wordllama/algorithms/splitter.pyx rename to src/wordllama/algorithms/splitter.pyx index 7d33109..5316da4 100644 --- a/wordllama/algorithms/splitter.pyx +++ b/src/wordllama/algorithms/splitter.pyx @@ -69,7 +69,7 @@ def split_sentences(str text not None, set punct_chars=None): cdef cset[Py_UCS4] punct_chars_c if punct_chars is None: - punct_chars = {'.', '!', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹', '।', '॥', '၊', '။', '።', '፧', '፨', + punct_chars = {'.', '!', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹', '।', '॥', '၊', '။', '።', '፧', '፨', '᙮', '᜵', '᜶', '᠃', '᠉', '᥄', '᥅', '᪨', '᪩', '᪪', '᪫', '᭚', '᭛', '᭞', '᭟', '᰻', '᰼', '᱾', '᱿', '‼', '‽', '⁇', '⁈', '⁉', '⸮', '⸼', '꓿', '꘎', '꘏', '꛳', '꛷', '꡶', '꡷', '꣎', '꣏', '꤯', '꧈', '꧉', '꩝', '꩞', '꩟', '꫰', '꫱', '꯫', '﹒', @@ -131,7 +131,7 @@ cdef tuple _combine_pass(list iterable, Py_ssize_t max_size, str separator): """ A single pass to combine successive pairs of items if their combined size doesn't exceed max_size. The separator used for joining pairs can be configured. - + Returns: tuple: A new list of coalesced items and the number of changes made. """ diff --git a/wordllama/algorithms/vector_similarity.pxd b/src/wordllama/algorithms/vector_similarity.pxd similarity index 100% rename from wordllama/algorithms/vector_similarity.pxd rename to src/wordllama/algorithms/vector_similarity.pxd diff --git a/wordllama/algorithms/vector_similarity.pyx b/src/wordllama/algorithms/vector_similarity.pyx similarity index 99% rename from wordllama/algorithms/vector_similarity.pyx rename to src/wordllama/algorithms/vector_similarity.pyx index 232dfea..1572135 100644 --- a/wordllama/algorithms/vector_similarity.pyx +++ b/src/wordllama/algorithms/vector_similarity.pyx @@ -95,7 +95,7 @@ cpdef object vector_similarity( max_distance = a_binary.shape[1] * 64 if max_distance == 0: raise ValueError("Binary vectors must have at least one bit") - + # convert to similarity similarity = 1.0 - 2.0 * (dist / max_distance).astype(np.float32) diff --git a/wordllama/config/__init__.py b/src/wordllama/config/__init__.py similarity index 89% rename from wordllama/config/__init__.py rename to src/wordllama/config/__init__.py index a2200ec..cae909b 100644 --- a/wordllama/config/__init__.py +++ b/src/wordllama/config/__init__.py @@ -1,9 +1,10 @@ -import toml from pathlib import Path +from typing import Optional + +import toml from pydantic import BaseModel -from typing import List, Dict, Optional -from .models import ModelURI, WordLlamaModels, Model2VecModels +from .models import Model2VecModels, ModelURI, WordLlamaModels class TokenizerInferenceConfig(BaseModel): @@ -37,7 +38,7 @@ class TrainingConfig(BaseModel): class MatryoshkaConfig(BaseModel): - dims: List[int] + dims: list[int] class WordLlamaModel(BaseModel): @@ -57,7 +58,7 @@ class WordLlamaConfig(BaseModel): class Config: - _configurations: Dict[str, WordLlamaConfig] = {} + _configurations: dict[str, WordLlamaConfig] = {} @classmethod def setup(cls): @@ -67,7 +68,7 @@ def setup(cls): setattr(cls, config_name, config) # Set as class attributes for easy access @staticmethod - def load_configurations() -> Dict[str, WordLlamaConfig]: + def load_configurations() -> dict[str, WordLlamaConfig]: """Load configurations from TOML files within the same directory as this script.""" config_dir = Path(__file__).parent / "train" configs = {} diff --git a/wordllama/config/models.py b/src/wordllama/config/models.py similarity index 90% rename from wordllama/config/models.py rename to src/wordllama/config/models.py index d08d066..1d9551e 100644 --- a/wordllama/config/models.py +++ b/src/wordllama/config/models.py @@ -1,12 +1,12 @@ from dataclasses import dataclass -from typing import List, Optional +from typing import Optional @dataclass class ModelURI: repo_id: str - available_dims: List[int] - binary_dims: List[int] + available_dims: list[int] + binary_dims: list[int] tokenizer_config: Optional[str] remote_filename: Optional[str] = None remote_tokenizer_filename: Optional[str] = None @@ -15,18 +15,15 @@ class ModelURI: class WordLlamaModels: - @classmethod - def list_configs(cls) -> List[str]: + def list_configs(cls) -> list[str]: """ Return a list of configuration names defined as `ModelURI` instances in the class. Returns: List[str]: A list of configuration attribute names. """ - return [ - name for name, value in cls.__dict__.items() if isinstance(value, ModelURI) - ] + return [name for name, value in cls.__dict__.items() if isinstance(value, ModelURI)] l2_supercat = ModelURI( repo_id="dleemiller/word-llama-l2-supercat", @@ -46,18 +43,15 @@ def list_configs(cls) -> List[str]: class Model2VecModels: - @classmethod - def list_configs(cls) -> List[str]: + def list_configs(cls) -> list[str]: """ Return a list of configuration names defined as `ModelURI` instances in the class. Returns: List[str]: A list of configuration attribute names. """ - return [ - name for name, value in cls.__dict__.items() if isinstance(value, ModelURI) - ] + return [name for name, value in cls.__dict__.items() if isinstance(value, ModelURI)] potion_base_8m = ModelURI( repo_id="minishlab/potion-base-8M", diff --git a/wordllama/config/train/command_rplus.toml b/src/wordllama/config/train/command_rplus.toml similarity index 99% rename from wordllama/config/train/command_rplus.toml rename to src/wordllama/config/train/command_rplus.toml index 9f21018..42399cf 100644 --- a/wordllama/config/train/command_rplus.toml +++ b/src/wordllama/config/train/command_rplus.toml @@ -28,4 +28,3 @@ binarizer_ste = "tanh" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/dbrx.toml b/src/wordllama/config/train/dbrx.toml similarity index 99% rename from wordllama/config/train/dbrx.toml rename to src/wordllama/config/train/dbrx.toml index f5d9e16..322322c 100644 --- a/wordllama/config/train/dbrx.toml +++ b/src/wordllama/config/train/dbrx.toml @@ -28,4 +28,3 @@ binarizer_ste = "ste" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/deberta_v3_large.toml b/src/wordllama/config/train/deberta_v3_large.toml similarity index 95% rename from wordllama/config/train/deberta_v3_large.toml rename to src/wordllama/config/train/deberta_v3_large.toml index bbd9e1d..970d946 100644 --- a/wordllama/config/train/deberta_v3_large.toml +++ b/src/wordllama/config/train/deberta_v3_large.toml @@ -3,7 +3,7 @@ dim = 1024 n_vocab = 128100 hf_model_id = "microsoft/deberta-v3-large" is_encoder = true -pad_token = "<|end_of_text|>" +pad_token = "<|end_of_text|>" [tokenizer] return_tensors = "pt" @@ -29,4 +29,3 @@ binarizer_ste = "tanh" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/deepseekv2.toml b/src/wordllama/config/train/deepseekv2.toml similarity index 99% rename from wordllama/config/train/deepseekv2.toml rename to src/wordllama/config/train/deepseekv2.toml index 07116b9..b3036b1 100644 --- a/wordllama/config/train/deepseekv2.toml +++ b/src/wordllama/config/train/deepseekv2.toml @@ -28,4 +28,3 @@ binarizer_ste = "ste" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/gemma2_27B.toml b/src/wordllama/config/train/gemma2_27B.toml similarity index 99% rename from wordllama/config/train/gemma2_27B.toml rename to src/wordllama/config/train/gemma2_27B.toml index 69aa380..66dbcb9 100644 --- a/wordllama/config/train/gemma2_27B.toml +++ b/src/wordllama/config/train/gemma2_27B.toml @@ -28,4 +28,3 @@ binarizer_ste = "tanh" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/l2_supercat.toml b/src/wordllama/config/train/l2_supercat.toml similarity index 99% rename from wordllama/config/train/l2_supercat.toml rename to src/wordllama/config/train/l2_supercat.toml index 0ac4f88..bd2b14d 100644 --- a/wordllama/config/train/l2_supercat.toml +++ b/src/wordllama/config/train/l2_supercat.toml @@ -32,4 +32,3 @@ binarizer_ste = "ste" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/l2p3.toml b/src/wordllama/config/train/l2p3.toml similarity index 96% rename from wordllama/config/train/l2p3.toml rename to src/wordllama/config/train/l2p3.toml index 28a8eab..dd1adea 100644 --- a/wordllama/config/train/l2p3.toml +++ b/src/wordllama/config/train/l2p3.toml @@ -2,7 +2,7 @@ dim = 13312 n_vocab = 32000 hf_model_id = "meta-llama/Llama-2-70b-hf" -pad_token = "" +pad_token = "" [tokenizer] return_tensors = "pt" @@ -28,4 +28,3 @@ binarizer_ste = "ste" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/l2p3_lg.toml b/src/wordllama/config/train/l2p3_lg.toml similarity index 96% rename from wordllama/config/train/l2p3_lg.toml rename to src/wordllama/config/train/l2p3_lg.toml index f6c2f32..a814346 100644 --- a/wordllama/config/train/l2p3_lg.toml +++ b/src/wordllama/config/train/l2p3_lg.toml @@ -2,7 +2,7 @@ dim = 17408 n_vocab = 32000 hf_model_id = "meta-llama/Llama-2-70b-hf" -pad_token = "" +pad_token = "" [tokenizer] return_tensors = "pt" @@ -28,4 +28,3 @@ binarizer_ste = "ste" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/l3_supercat.toml b/src/wordllama/config/train/l3_supercat.toml similarity index 95% rename from wordllama/config/train/l3_supercat.toml rename to src/wordllama/config/train/l3_supercat.toml index 0fe7c82..8f2df63 100644 --- a/wordllama/config/train/l3_supercat.toml +++ b/src/wordllama/config/train/l3_supercat.toml @@ -2,7 +2,7 @@ dim = 28672 n_vocab = 128256 hf_model_id = "meta-llama/Meta-Llama-3.1-405B" -pad_token = "<|end_of_text|>" +pad_token = "<|end_of_text|>" [tokenizer] return_tensors = "pt" @@ -32,4 +32,3 @@ binarizer_ste = "ste" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/llama2_70B.toml b/src/wordllama/config/train/llama2_70B.toml similarity index 96% rename from wordllama/config/train/llama2_70B.toml rename to src/wordllama/config/train/llama2_70B.toml index 0f9fc46..af8e9f6 100644 --- a/wordllama/config/train/llama2_70B.toml +++ b/src/wordllama/config/train/llama2_70B.toml @@ -2,7 +2,7 @@ dim = 8192 n_vocab = 32000 hf_model_id = "meta-llama/Llama-2-70b-hf" -pad_token = "" +pad_token = "" [tokenizer] return_tensors = "pt" @@ -28,4 +28,3 @@ binarizer_ste = "ste" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/llama3_70B.toml b/src/wordllama/config/train/llama3_70B.toml similarity index 94% rename from wordllama/config/train/llama3_70B.toml rename to src/wordllama/config/train/llama3_70B.toml index 0b31c7b..fe67067 100644 --- a/wordllama/config/train/llama3_70B.toml +++ b/src/wordllama/config/train/llama3_70B.toml @@ -2,7 +2,7 @@ dim = 8192 n_vocab = 128256 hf_model_id = "meta-llama/Meta-Llama-3-70B" -pad_token = "<|end_of_text|>" +pad_token = "<|end_of_text|>" [tokenizer] return_tensors = "pt" @@ -28,4 +28,3 @@ binarizer_ste = "ste" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/llama3_8B.toml b/src/wordllama/config/train/llama3_8B.toml similarity index 99% rename from wordllama/config/train/llama3_8B.toml rename to src/wordllama/config/train/llama3_8B.toml index 0ce7baa..88f42f0 100644 --- a/wordllama/config/train/llama3_8B.toml +++ b/src/wordllama/config/train/llama3_8B.toml @@ -28,4 +28,3 @@ binarizer_ste = "tanh" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/llamaguard.toml b/src/wordllama/config/train/llamaguard.toml similarity index 96% rename from wordllama/config/train/llamaguard.toml rename to src/wordllama/config/train/llamaguard.toml index 04dabba..dcd89a8 100644 --- a/wordllama/config/train/llamaguard.toml +++ b/src/wordllama/config/train/llamaguard.toml @@ -2,7 +2,7 @@ dim = 4096 n_vocab = 32000 hf_model_id = "meta-llama/LlamaGuard-7b" -pad_token = "" +pad_token = "" [tokenizer] return_tensors = "pt" @@ -28,4 +28,3 @@ binarizer_ste = "ste" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/miqu.toml b/src/wordllama/config/train/miqu.toml similarity index 96% rename from wordllama/config/train/miqu.toml rename to src/wordllama/config/train/miqu.toml index 3858e06..46ae966 100644 --- a/wordllama/config/train/miqu.toml +++ b/src/wordllama/config/train/miqu.toml @@ -2,7 +2,7 @@ dim = 8192 n_vocab = 32000 hf_model_id = "152334H/miqu-1-70b-sf" -pad_token = "" +pad_token = "" [tokenizer] return_tensors = "pt" @@ -28,4 +28,3 @@ binarizer_ste = "ste" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/mixtral.toml b/src/wordllama/config/train/mixtral.toml similarity index 99% rename from wordllama/config/train/mixtral.toml rename to src/wordllama/config/train/mixtral.toml index 52c1084..1591485 100644 --- a/wordllama/config/train/mixtral.toml +++ b/src/wordllama/config/train/mixtral.toml @@ -28,4 +28,3 @@ binarizer_ste = "tanh" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/mixtral_8x22B.toml b/src/wordllama/config/train/mixtral_8x22B.toml similarity index 99% rename from wordllama/config/train/mixtral_8x22B.toml rename to src/wordllama/config/train/mixtral_8x22B.toml index f55f278..61368e0 100644 --- a/wordllama/config/train/mixtral_8x22B.toml +++ b/src/wordllama/config/train/mixtral_8x22B.toml @@ -28,4 +28,3 @@ binarizer_ste = "tanh" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/openelm_3B.toml b/src/wordllama/config/train/openelm_3B.toml similarity index 96% rename from wordllama/config/train/openelm_3B.toml rename to src/wordllama/config/train/openelm_3B.toml index eef1f75..632972f 100644 --- a/wordllama/config/train/openelm_3B.toml +++ b/src/wordllama/config/train/openelm_3B.toml @@ -2,7 +2,7 @@ dim = 3072 n_vocab = 32000 hf_model_id = "apple/OpenELM-3B" -pad_token = "" +pad_token = "" [tokenizer] return_tensors = "pt" @@ -28,4 +28,3 @@ binarizer_ste = "ste" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/phi3_medium.toml b/src/wordllama/config/train/phi3_medium.toml similarity index 95% rename from wordllama/config/train/phi3_medium.toml rename to src/wordllama/config/train/phi3_medium.toml index d98b3cd..81987af 100644 --- a/wordllama/config/train/phi3_medium.toml +++ b/src/wordllama/config/train/phi3_medium.toml @@ -2,7 +2,7 @@ dim = 5120 n_vocab = 32064 hf_model_id = "microsoft/Phi-3-medium-4k-instruct" -pad_token = "<|endoftext|>" +pad_token = "<|endoftext|>" [tokenizer] return_tensors = "pt" @@ -28,4 +28,3 @@ binarizer_ste = "ste" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/qwen2_72B.toml b/src/wordllama/config/train/qwen2_72B.toml similarity index 95% rename from wordllama/config/train/qwen2_72B.toml rename to src/wordllama/config/train/qwen2_72B.toml index f61f246..db7904d 100644 --- a/wordllama/config/train/qwen2_72B.toml +++ b/src/wordllama/config/train/qwen2_72B.toml @@ -2,7 +2,7 @@ dim = 8192 n_vocab = 152064 hf_model_id = "Qwen/Qwen2-72B" -pad_token = "<|endoftext|>" +pad_token = "<|endoftext|>" [tokenizer] return_tensors = "pt" @@ -28,4 +28,3 @@ binarizer_ste = "ste" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/config/train/yi_1v5_34B.toml b/src/wordllama/config/train/yi_1v5_34B.toml similarity index 99% rename from wordllama/config/train/yi_1v5_34B.toml rename to src/wordllama/config/train/yi_1v5_34B.toml index cfe18c7..d4ee9fe 100644 --- a/wordllama/config/train/yi_1v5_34B.toml +++ b/src/wordllama/config/train/yi_1v5_34B.toml @@ -28,4 +28,3 @@ binarizer_ste = "tanh" [matryoshka] dims = [1024, 512, 256, 128, 64] - diff --git a/wordllama/embedding/__init__.py b/src/wordllama/embedding/__init__.py similarity index 100% rename from wordllama/embedding/__init__.py rename to src/wordllama/embedding/__init__.py diff --git a/wordllama/embedding/word_llama_embedding.py b/src/wordllama/embedding/word_llama_embedding.py similarity index 87% rename from wordllama/embedding/word_llama_embedding.py rename to src/wordllama/embedding/word_llama_embedding.py index 9873bb4..781358a 100644 --- a/wordllama/embedding/word_llama_embedding.py +++ b/src/wordllama/embedding/word_llama_embedding.py @@ -1,13 +1,13 @@ +import warnings +from typing import Union + import numpy as np +import safetensors.torch import torch from torch import nn -import safetensors.torch - from transformers import AutoTokenizer -from typing import Union, List, Dict from ..adapters import AvgPool -import warnings class WordLlamaEmbedding(nn.Module): @@ -24,9 +24,7 @@ def __init__(self, config, tokenizer_kwargs=None, dims=None): super().__init__() self.config = config model = config.model - self.embedding = nn.Embedding( - model.n_vocab, model.dim if dims is None else dims - ) + self.embedding = nn.Embedding(model.n_vocab, model.dim if dims is None else dims) if tokenizer_kwargs: self.tokenizer_kwargs = tokenizer_kwargs @@ -44,7 +42,7 @@ def build(cls, filepath, config, dims=None): safetensors.torch.load_model(word_llama, filepath) return word_llama - def forward(self, tensors: Dict[str, torch.Tensor]): + def forward(self, tensors: dict[str, torch.Tensor]): return { "token_ids": tensors["input_ids"], "token_embeddings": self.embedding(tensors["input_ids"]), @@ -61,7 +59,7 @@ def tokenize(self, *args, **kwargs): @torch.inference_mode() def embed( self, - texts: Union[str, List[str]], + texts: Union[str, list[str]], norm: bool = False, binarize: bool = False, pack: bool = True, @@ -82,14 +80,14 @@ def embed( # Clamp out-of-bounds input_ids if tensors["input_ids"].max() >= self.embedding.num_embeddings: - warnings.warn("Some input_ids are out of bounds. Clamping to valid range.") - tensors["input_ids"] = tensors["input_ids"].clamp( - 0, self.embedding.num_embeddings - 1 + warnings.warn( + "Some input_ids are out of bounds. Clamping to valid range.", stacklevel=2 ) + tensors["input_ids"] = tensors["input_ids"].clamp(0, self.embedding.num_embeddings - 1) # Check for NaNs in input_ids and replace with 0 if torch.isnan(tensors["input_ids"]).any(): - warnings.warn("NaN values found in input_ids. Replacing NaNs with 0.") + warnings.warn("NaN values found in input_ids. Replacing NaNs with 0.", stacklevel=2) tensors["input_ids"][torch.isnan(tensors["input_ids"])] = 0 # Ensure at least one non-zero value in the attention mask @@ -97,7 +95,8 @@ def embed( attention_mask = tensors["attention_mask"] if (attention_mask.sum(dim=1) == 0).any(): warnings.warn( - "Some attention masks are all zeros. Setting the first token to 1 for these cases." + "Some attention masks are all zeros. Setting the first token to 1 for these cases.", + stacklevel=2, ) attention_mask[attention_mask.sum(dim=1) == 0, 0] = 1 diff --git a/wordllama/extract/__init__.py b/src/wordllama/extract/__init__.py similarity index 100% rename from wordllama/extract/__init__.py rename to src/wordllama/extract/__init__.py index e3f80eb..00db1b7 100644 --- a/wordllama/extract/__init__.py +++ b/src/wordllama/extract/__init__.py @@ -1,4 +1,4 @@ -from .extract_llama_70B import extract_llama_70B from .extract_hf import extract_from_hf +from .extract_llama_70B import extract_llama_70B __all__ = ["extract_from_hf", "extract_llama_70B"] diff --git a/wordllama/extract/extract_hf.py b/src/wordllama/extract/extract_hf.py similarity index 93% rename from wordllama/extract/extract_hf.py rename to src/wordllama/extract/extract_hf.py index d882263..2848157 100644 --- a/wordllama/extract/extract_hf.py +++ b/src/wordllama/extract/extract_hf.py @@ -1,5 +1,5 @@ import safetensors.torch -from transformers import AutoModelForCausalLM, AutoModel +from transformers import AutoModel, AutoModelForCausalLM from ..config import WordLlamaConfig from ..embedding.word_llama_embedding import WordLlamaEmbedding diff --git a/wordllama/extract/extract_llama_70B.py b/src/wordllama/extract/extract_llama_70B.py similarity index 99% rename from wordllama/extract/extract_llama_70B.py rename to src/wordllama/extract/extract_llama_70B.py index 76f4597..329d38c 100644 --- a/wordllama/extract/extract_llama_70B.py +++ b/src/wordllama/extract/extract_llama_70B.py @@ -1,4 +1,5 @@ import os + import safetensors.torch from ..config import Config diff --git a/wordllama/extract/extract_safetensors.py b/src/wordllama/extract/extract_safetensors.py similarity index 100% rename from wordllama/extract/extract_safetensors.py rename to src/wordllama/extract/extract_safetensors.py diff --git a/wordllama/inference.py b/src/wordllama/inference.py similarity index 90% rename from wordllama/inference.py rename to src/wordllama/inference.py index f5159ac..47264f4 100644 --- a/wordllama/inference.py +++ b/src/wordllama/inference.py @@ -1,16 +1,16 @@ +import logging +from typing import Callable, Optional, Union + import numpy as np from tokenizers import Tokenizer -from typing import Callable, List, Optional, Tuple, Union -import logging from .algorithms import ( - kmeans_clustering, - vector_similarity, binarize_and_packbits, deduplicate_embeddings, + kmeans_clustering, + vector_similarity, ) from .algorithms.semantic_splitter import SemanticSplitter -from .config import WordLlamaConfig from .mode_decorators import dense_only # Set up logging @@ -41,7 +41,7 @@ def __init__( self.tokenizer.enable_padding() self.tokenizer.no_truncation() - def tokenize(self, texts: Union[str, List[str]]) -> List: + def tokenize(self, texts: Union[str, list[str]]) -> list: """Tokenize input texts using the configured tokenizer. Args: @@ -55,18 +55,16 @@ def tokenize(self, texts: Union[str, List[str]]) -> List: else: assert isinstance(texts, list), "Input texts must be str or List[str]" - return self.tokenizer.encode_batch( - texts, is_pretokenized=False, add_special_tokens=False - ) + return self.tokenizer.encode_batch(texts, is_pretokenized=False, add_special_tokens=False) def embed( self, - texts: Union[str, List[str]], + texts: Union[str, list[str]], norm: bool = False, return_np: bool = True, pool_embeddings: bool = True, batch_size: int = 64, - ) -> Union[np.ndarray, List]: + ) -> Union[np.ndarray, list]: """Generate embeddings for input texts with optional normalization and binarization. Args: @@ -118,9 +116,7 @@ def embed( # Normalize embeddings after pooling if norm: - batch_embeddings /= np.linalg.norm( - batch_embeddings, axis=1, keepdims=True - ) + batch_embeddings /= np.linalg.norm(batch_embeddings, axis=1, keepdims=True) # Binarize embeddings if self.binary: @@ -197,15 +193,13 @@ def key(self, query: str, norm: bool = True) -> Callable[[str], float]: def similarity_key(candidate: str) -> float: candidate_embedding = self.embed(candidate, norm=norm) - return self.vector_similarity( - query_embedding[0], candidate_embedding[0] - ).item() + return self.vector_similarity(query_embedding[0], candidate_embedding[0]).item() return similarity_key def rank( - self, query: str, docs: List[str], sort: bool = True, batch_size: int = 64 - ) -> List[Tuple[str, float]]: + self, query: str, docs: list[str], sort: bool = True, batch_size: int = 64 + ) -> list[tuple[str, float]]: """Rank documents based on their similarity to a query. Result may be sorted by similarity score in descending order, or not (see `sort` parameter) @@ -220,9 +214,8 @@ def rank( List[Tuple[str, float]]: A list of tuples `(doc, score)`. """ assert isinstance(query, str), "Query must be a string" - assert ( - isinstance(docs, list) and len(docs) > 1 - ), "Docs must be a list of 2 more more strings." + assert isinstance(docs, list), "Docs must be a list" + assert len(docs) > 1, "Docs must contain 2 or more strings" query_embedding = self.embed(query) doc_embeddings = self.embed(docs, batch_size=batch_size) scores = self.vector_similarity(query_embedding[0], doc_embeddings) @@ -235,11 +228,11 @@ def rank( def deduplicate( self, - docs: List[str], + docs: list[str], threshold: float = 0.9, return_indices: bool = False, batch_size: Optional[int] = None, - ) -> List[Union[str, int]]: + ) -> list[Union[str, int]]: """Deduplicate documents based on a similarity threshold. Args: @@ -255,20 +248,16 @@ def deduplicate( if batch_size is None: batch_size = 500 if self.binary else 5000 - duplicate_indices = deduplicate_embeddings( - doc_embeddings, threshold, batch_size - ) + duplicate_indices = deduplicate_embeddings(doc_embeddings, threshold, batch_size) if return_indices: # turn set of numpy int into sorted list of python int - duplicate_indices = list(map(lambda x: x.item(), duplicate_indices)) + duplicate_indices = [x.item() for x in duplicate_indices] return sorted(duplicate_indices) - unique_docs = [ - doc for idx, doc in enumerate(docs) if idx not in duplicate_indices - ] + unique_docs = [doc for idx, doc in enumerate(docs) if idx not in duplicate_indices] return unique_docs - def topk(self, query: str, candidates: List[str], k: int = 3) -> List[str]: + def topk(self, query: str, candidates: list[str], k: int = 3) -> list[str]: """Retrieve the top-k most similar documents to a query. Args: @@ -285,9 +274,7 @@ def topk(self, query: str, candidates: List[str], k: int = 3) -> List[str]: ranked_docs = self.rank(query, candidates) return [doc for doc, score in ranked_docs[:k]] - def filter( - self, query: str, candidates: List[str], threshold: float = 0.3 - ) -> List[str]: + def filter(self, query: str, candidates: list[str], threshold: float = 0.3) -> list[str]: """Filter documents to include only those similar to the query above a threshold. Args: @@ -305,23 +292,21 @@ def filter( ).squeeze() filtered_docs = [ - doc - for doc, score in zip(candidates, similarity_scores) - if score > threshold + doc for doc, score in zip(candidates, similarity_scores) if score > threshold ] return filtered_docs @dense_only def cluster( self, - docs: List[str], + docs: list[str], k: int, max_iterations: int = 100, tolerance: float = 1e-4, n_init: int = 10, min_iterations: int = 5, random_state=None, - ) -> Tuple[List[int], float]: + ) -> tuple[list[int], float]: """Cluster documents into `k` clusters using KMeans clustering. Args: @@ -363,7 +348,7 @@ def split( cleanup_size: int = 24, intermediate_size: int = 96, return_minima: bool = False, - ) -> List[str]: + ) -> list[str]: """Split text into semantically coherent chunks. Args: diff --git a/wordllama/mode_decorators.py b/src/wordllama/mode_decorators.py similarity index 100% rename from wordllama/mode_decorators.py rename to src/wordllama/mode_decorators.py diff --git a/wordllama/tokenizers/__init__.py b/src/wordllama/tokenizers/__init__.py similarity index 85% rename from wordllama/tokenizers/__init__.py rename to src/wordllama/tokenizers/__init__.py index 31140ae..3efc0dc 100644 --- a/wordllama/tokenizers/__init__.py +++ b/src/wordllama/tokenizers/__init__.py @@ -1,4 +1,5 @@ from pathlib import Path + from tokenizers import Tokenizer @@ -17,9 +18,7 @@ def tokenizer_from_file(file_name: str) -> Tokenizer: tokenizer_path = current_dir / file_name # Assert that the tokenizer configuration file exists - assert ( - tokenizer_path.exists() - ), f"Tokenizer configuration file {tokenizer_path} does not exist." + assert tokenizer_path.exists(), f"Tokenizer configuration file {tokenizer_path} does not exist." # Load the tokenizer from the specified file tokenizer = Tokenizer.from_file(str(tokenizer_path)) diff --git a/wordllama/tokenizers/l2_supercat_tokenizer_config.json b/src/wordllama/tokenizers/l2_supercat_tokenizer_config.json similarity index 99% rename from wordllama/tokenizers/l2_supercat_tokenizer_config.json rename to src/wordllama/tokenizers/l2_supercat_tokenizer_config.json index 599eb71..16ca4e4 100644 --- a/wordllama/tokenizers/l2_supercat_tokenizer_config.json +++ b/src/wordllama/tokenizers/l2_supercat_tokenizer_config.json @@ -93389,4 +93389,4 @@ "▁ ▁▁▁▁▁▁▁▁▁▁▁▁▁▁" ] } -} \ No newline at end of file +} diff --git a/wordllama/trainers/__init__.py b/src/wordllama/trainers/__init__.py similarity index 100% rename from wordllama/trainers/__init__.py rename to src/wordllama/trainers/__init__.py diff --git a/wordllama/trainers/reduce_dimension.py b/src/wordllama/trainers/reduce_dimension.py similarity index 95% rename from wordllama/trainers/reduce_dimension.py rename to src/wordllama/trainers/reduce_dimension.py index 130d8d2..6decaa7 100644 --- a/wordllama/trainers/reduce_dimension.py +++ b/src/wordllama/trainers/reduce_dimension.py @@ -11,7 +11,6 @@ SimilarityFunction, ) - logger = logging.getLogger(__name__) @@ -58,9 +57,7 @@ def setup_evaluator(self) -> SequentialEvaluator: ) for dim in self.config.matryoshka_dims ] - return SequentialEvaluator( - evaluators, main_score_function=lambda scores: scores[0] - ) + return SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[0]) def initialize_trainer(self) -> SentenceTransformerTrainer: return SentenceTransformerTrainer( diff --git a/wordllama/weights/l2_supercat_256.safetensors b/src/wordllama/weights/l2_supercat_256.safetensors similarity index 100% rename from wordllama/weights/l2_supercat_256.safetensors rename to src/wordllama/weights/l2_supercat_256.safetensors diff --git a/wordllama/wordllama.py b/src/wordllama/wordllama.py similarity index 94% rename from wordllama/wordllama.py rename to src/wordllama/wordllama.py index 296d503..81062c2 100644 --- a/wordllama/wordllama.py +++ b/src/wordllama/wordllama.py @@ -1,15 +1,14 @@ import logging import warnings from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Optional, Union import requests from safetensors import safe_open from tokenizers import Tokenizer +from .config import Model2VecModels, ModelURI, WordLlamaModels from .inference import WordLlamaInference -from .config import ModelURI, WordLlamaModels, Model2VecModels - logger = logging.getLogger(__name__) @@ -26,7 +25,7 @@ class WordLlama: DEFAULT_CACHE_DIR = Path.home() / ".cache" / "wordllama" @classmethod - def list_configs(cls) -> Dict[str, List[str]]: + def list_configs(cls) -> dict[str, list[str]]: """ List the available configurations. @@ -166,9 +165,7 @@ def resolve_file( response = requests.get(url, stream=True) response.raise_for_status() except requests.RequestException as e: - logger.error( - f"Failed to download {file_type} file '{filename}' from '{url}': {e}" - ) + logger.error(f"Failed to download {file_type} file '{filename}' from '{url}': {e}") raise FileNotFoundError( f"Failed to download {file_type} file '{filename}' from '{url}'." ) from e @@ -267,18 +264,12 @@ def load( # Validate dimensions if dim not in model_uri.available_dims: - raise ValueError( - f"Model dimension must be one of {model_uri.available_dims}" - ) + raise ValueError(f"Model dimension must be one of {model_uri.available_dims}") if trunc_dim is not None: if trunc_dim > dim: - raise ValueError( - f"Cannot truncate to a higher dimension ({trunc_dim} > {dim})" - ) + raise ValueError(f"Cannot truncate to a higher dimension ({trunc_dim} > {dim})") if trunc_dim not in model_uri.available_dims: - raise ValueError( - f"Truncated dimension must be one of {model_uri.available_dims}" - ) + raise ValueError(f"Truncated dimension must be one of {model_uri.available_dims}") return cls._load( config_name=config_name, @@ -384,14 +375,13 @@ def load_tokenizer( if use_local_if_exists: # Check in the default path if tokenizer_file_path.exists(): - logger.debug( - f"Loading tokenizer from local config_name: {tokenizer_file_path}" - ) + logger.debug(f"Loading tokenizer from local config_name: {tokenizer_file_path}") return Tokenizer.from_file(str(tokenizer_file_path)) else: warnings.warn( f"Local tokenizer config not found at {tokenizer_file_path}. " - "Falling back to Hugging Face." + "Falling back to Hugging Face.", + stacklevel=2, ) if hf_model_id: diff --git a/tests/test_functional.py b/tests/test_functional.py index 6876a46..bc29003 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,4 +1,5 @@ import unittest + from wordllama import WordLlama @@ -24,8 +25,8 @@ def test_function_sorted(self): sim_key = wl.key(query) sorted_candidates = sorted(candidates, key=sim_key, reverse=True) - self.assertIsInstance(sorted_candidates, list) - self.assertEqual(len(sorted_candidates), len(candidates)) + assert isinstance(sorted_candidates, list) + assert len(sorted_candidates) == len(candidates) def test_function_max(self): wl = WordLlama.load() @@ -35,7 +36,5 @@ def test_function_max(self): sim_key = wl.key(query) best_candidate = max(candidates, key=sim_key) - self.assertIn(best_candidate, candidates) - self.assertEqual( - best_candidate, max(candidates, key=lambda x: wl.similarity(query, x)) - ) + assert best_candidate in candidates + assert best_candidate == max(candidates, key=lambda x: wl.similarity(query, x)) diff --git a/tests/test_inference.py b/tests/test_inference.py index 7101421..97c4e55 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1,15 +1,10 @@ import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch + import numpy as np +import pytest + from wordllama.inference import WordLlamaInference -from wordllama.config import ( - WordLlamaConfig, - WordLlamaModel, - TokenizerConfig, - TrainingConfig, - MatryoshkaConfig, - TokenizerInferenceConfig, -) np.random.seed(42) @@ -43,9 +38,9 @@ def mock_encode_batch(texts, *args, **kwargs): def test_deduplicate_cosine(self, mock_embed): docs = ["doc1", "doc1_dup", "a second document that is different", "doc1_dup2"] deduplicated_docs = self.model.deduplicate(docs, threshold=0.9) - self.assertEqual(len(deduplicated_docs), 2) - self.assertIn("doc1", deduplicated_docs) - self.assertIn("a second document that is different", deduplicated_docs) + assert len(deduplicated_docs) == 2 + assert "doc1" in deduplicated_docs + assert "a second document that is different" in deduplicated_docs @patch.object( WordLlamaInference, @@ -62,10 +57,10 @@ def test_deduplicate_cosine(self, mock_embed): def test_deduplicate_no_duplicates(self, mock_embed): docs = ["doc1", "doc2", "doc3"] deduplicated_docs = self.model.deduplicate(docs, threshold=0.9) - self.assertEqual(len(deduplicated_docs), 3) - self.assertIn("doc1", deduplicated_docs) - self.assertIn("doc2", deduplicated_docs) - self.assertIn("doc3", deduplicated_docs) + assert len(deduplicated_docs) == 3 + assert "doc1" in deduplicated_docs + assert "doc2" in deduplicated_docs + assert "doc3" in deduplicated_docs @patch.object( WordLlamaInference, @@ -75,8 +70,8 @@ def test_deduplicate_no_duplicates(self, mock_embed): def test_deduplicate_all_duplicates(self, mock_embed): docs = ["doc1", "doc1_dup", "doc1_dup2"] deduplicated_docs = self.model.deduplicate(docs, threshold=0.9) - self.assertEqual(len(deduplicated_docs), 1) - self.assertIn("doc1", deduplicated_docs) + assert len(deduplicated_docs) == 1 + assert "doc1" in deduplicated_docs @patch.object( WordLlamaInference, @@ -85,32 +80,30 @@ def test_deduplicate_all_duplicates(self, mock_embed): ) def test_deduplicate_return_indices(self, mock_embed): docs = ["doc1", "doc1_dup", "doc1_dup2"] - duplicated_idx = self.model.deduplicate( - docs, return_indices=True, threshold=0.9 - ) - self.assertEqual(len(duplicated_idx), 2) - self.assertIn(1, duplicated_idx) - self.assertIn(2, duplicated_idx) + duplicated_idx = self.model.deduplicate(docs, return_indices=True, threshold=0.9) + assert len(duplicated_idx) == 2 + assert 1 in duplicated_idx + assert 2 in duplicated_idx def test_tokenize(self): tokens = self.model.tokenize("test string") self.mock_tokenizer.encode_batch.assert_called_with( ["test string"], is_pretokenized=False, add_special_tokens=False ) - self.assertEqual(len(tokens), 1) + assert len(tokens) == 1 def test_embed(self): embeddings = self.model.embed("test string", return_np=True) - self.assertEqual(embeddings.shape, (1, 64)) + assert embeddings.shape == (1, 64) def test_cluster_fails_binary(self): self.model.binary = True - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.model.cluster(["a", "b", "c"]) def test_split_fails_binary(self): self.model.binary = True - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.model.split("a" * 1000) def test_similarity_cosine(self): @@ -119,7 +112,7 @@ def mock_encode_batch(texts, *args, **kwargs): self.mock_tokenizer.encode_batch.side_effect = mock_encode_batch sim_score = self.model.similarity("test string 1", "test string 2") - self.assertTrue(isinstance(sim_score, float)) + assert isinstance(sim_score, float) def test_similarity_hamming(self): def mock_encode_batch(texts, *args, **kwargs): @@ -129,7 +122,7 @@ def mock_encode_batch(texts, *args, **kwargs): self.model.binary = True sim_score = self.model.similarity("test string 1", "test string 2") - self.assertTrue(isinstance(sim_score, float)) + assert isinstance(sim_score, float) def test_rank_cosine(self): def mock_encode_batch(texts, *args, **kwargs): @@ -145,7 +138,7 @@ def mock_embed(texts, *args, **kwargs): if isinstance(texts, str): texts = [texts] embeddings = [] - for i, text in enumerate(texts): + for i, _text in enumerate(texts): embedding = np.zeros(64, dtype=np.float32) embedding[1 if len(texts) == 1 else i] = 1 embeddings.append(embedding) @@ -155,15 +148,15 @@ def mock_embed(texts, *args, **kwargs): docs = ["doc1", "doc2", "doc3"] ranked_docs = self.model.rank("test query", docs) - self.assertEqual(len(ranked_docs), len(docs)) - self.assertTrue(all(isinstance(score, float) for doc, score in ranked_docs)) - self.assertEqual(ranked_docs[0], ("doc2", 1.0)) + assert len(ranked_docs) == len(docs) + assert all(isinstance(score, float) for doc, score in ranked_docs) + assert ranked_docs[0] == ("doc2", 1.0) # test turning off sorting unsorted_docs = self.model.rank("test query", docs, sort=False) - self.assertEqual(len(unsorted_docs), len(docs)) - self.assertTrue(all(isinstance(score, float) for doc, score in unsorted_docs)) - self.assertEqual(unsorted_docs[1], ("doc2", 1.0)) + assert len(unsorted_docs) == len(docs) + assert all(isinstance(score, float) for doc, score in unsorted_docs) + assert unsorted_docs[1] == ("doc2", 1.0) def test_rank_hamming(self): def mock_encode_batch(texts, *args, **kwargs): @@ -178,8 +171,8 @@ def mock_encode_batch(texts, *args, **kwargs): ranked_docs = self.model.rank("test query", docs) mock_hamming.assert_called_once() # check setting binary calls hamming - self.assertEqual(len(ranked_docs), len(docs)) - self.assertTrue(all(isinstance(score, float) for doc, score in ranked_docs)) + assert len(ranked_docs) == len(docs) + assert all(isinstance(score, float) for doc, score in ranked_docs) def test_instantiate_with_truncation(self): truncated_embedding = np.random.rand(128256, 32) @@ -187,22 +180,22 @@ def test_instantiate_with_truncation(self): embedding=truncated_embedding, tokenizer=self.mock_tokenizer, ) - self.assertEqual(truncated_model.embedding.shape[1], 32) + assert truncated_model.embedding.shape[1] == 32 def test_error_on_wrong_embedding_type(self): - with self.assertRaises(TypeError): + with pytest.raises(TypeError): self.model.embed(np.array([1, 2])) def test_binarization_and_packing(self): self.model.binary = True binary_output = self.model.embed("test string") - self.assertIsInstance(binary_output, np.ndarray) - self.assertEqual(binary_output.dtype, np.uint64) + assert isinstance(binary_output, np.ndarray) + assert binary_output.dtype == np.uint64 def test_normalization_effect(self): normalized_output = self.model.embed("test string", norm=True) norm = np.linalg.norm(normalized_output) - self.assertAlmostEqual(norm, 1, places=5) + assert norm == pytest.approx(1, abs=1e-5) if __name__ == "__main__": diff --git a/tests/test_kmeans.py b/tests/test_kmeans.py index 9a82111..59b2c33 100644 --- a/tests/test_kmeans.py +++ b/tests/test_kmeans.py @@ -1,4 +1,5 @@ import unittest + import numpy as np from wordllama.algorithms.kmeans import ( @@ -39,53 +40,43 @@ def setUp(self): def test_kmeans_clustering_convergence(self): k = 2 - labels, inertia = kmeans_clustering( - self.embeddings, k, random_state=self.random_state - ) + labels, inertia = kmeans_clustering(self.embeddings, k, random_state=self.random_state) - self.assertEqual(len(labels), self.embeddings.shape[0]) - self.assertGreater(inertia, 0) + assert len(labels) == self.embeddings.shape[0] + assert inertia > 0 def test_kmeans_clustering_labels(self): k = 2 - labels, _ = kmeans_clustering( - self.embeddings, k, random_state=self.random_state - ) + labels, _ = kmeans_clustering(self.embeddings, k, random_state=self.random_state) # Check that labels are within the valid range for label in labels: - self.assertIn(label, range(k)) + assert label in range(k) def test_kmeans_clustering_different_k(self): k = 3 - labels, _ = kmeans_clustering( - self.embeddings, k, random_state=self.random_state - ) + labels, _ = kmeans_clustering(self.embeddings, k, random_state=self.random_state) - self.assertEqual(len(labels), self.embeddings.shape[0]) + assert len(labels) == self.embeddings.shape[0] # Check that labels are within the valid range for label in labels: - self.assertIn(label, range(k)) + assert label in range(k) def test_kmeans_clustering_random_state(self): k = 2 labels1, losses1 = kmeans_clustering(self.embeddings, k, random_state=42) labels2, losses2 = kmeans_clustering(self.embeddings, k, random_state=42) - self.assertEqual(labels1, labels2) - self.assertEqual(losses1, losses2) + assert labels1 == labels2 + assert losses1 == losses2 def test_kmeans_clustering_different_initializations(self): k = 2 - labels1, inertia1 = kmeans_clustering( - self.embeddings, k, random_state=42, n_init=1 - ) - labels2, inertia2 = kmeans_clustering( - self.embeddings, k, random_state=42, n_init=10 - ) + labels1, inertia1 = kmeans_clustering(self.embeddings, k, random_state=42, n_init=1) + labels2, inertia2 = kmeans_clustering(self.embeddings, k, random_state=42, n_init=10) - self.assertGreater(inertia1, inertia2) + assert inertia1 > inertia2 if __name__ == "__main__": diff --git a/tests/test_minima_functions.py b/tests/test_minima_functions.py index b74cb4d..de355f1 100644 --- a/tests/test_minima_functions.py +++ b/tests/test_minima_functions.py @@ -1,5 +1,8 @@ import unittest + import numpy as np +import pytest + from wordllama.algorithms.find_local_minima import ( find_local_minima, windowed_cross_similarity, @@ -17,23 +20,21 @@ def test_find_local_minima(self): x_minima, y_minima = find_local_minima(self.y, window_size=3, poly_order=2) # Known minima for sin(x) in the given range [0, 2*pi] - expected_x_minima = np.array([3 * np.pi / 2], dtype=np.float32) + np.array([3 * np.pi / 2], dtype=np.float32) expected_y_minima = np.array([-1.0], dtype=np.float32) # Check if the found minima are correct (allow small numerical tolerance) - np.testing.assert_array_almost_equal( - self.y[x_minima], expected_y_minima, decimal=2 - ) + np.testing.assert_array_almost_equal(self.y[x_minima], expected_y_minima, decimal=2) np.testing.assert_array_almost_equal(y_minima, expected_y_minima, decimal=2) def test_find_local_minima_invalid_window_size(self): # Test that the function raises a ValueError for an invalid window size - with self.assertRaises(ValueError): + with pytest.raises(ValueError): find_local_minima(self.y, window_size=2, poly_order=2) def test_find_local_minima_invalid_polynomial_order(self): # Test that the function raises a ValueError for an invalid polynomial order - with self.assertRaises(ValueError): + with pytest.raises(ValueError): find_local_minima(self.y, window_size=11, poly_order=11) @@ -63,17 +64,17 @@ def test_windowed_cross_similarity(self): def test_windowed_cross_similarity_invalid_window(self): # Test invalid window size (even window size should raise ValueError) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): windowed_cross_similarity(self.embeddings, window_size=4) # Test invalid window size (window size < 3) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): windowed_cross_similarity(self.embeddings, window_size=2) def test_windowed_cross_similarity_small_window(self): # Test windowed cross similarity with a small window (size 3) result = windowed_cross_similarity(self.embeddings, window_size=3) - self.assertEqual(result.shape[0], self.embeddings.shape[0]) + assert result.shape[0] == self.embeddings.shape[0] if __name__ == "__main__": diff --git a/tests/test_semantic_splitter.py b/tests/test_semantic_splitter.py index 2cf989b..5826029 100644 --- a/tests/test_semantic_splitter.py +++ b/tests/test_semantic_splitter.py @@ -1,5 +1,7 @@ import unittest + import numpy as np + from wordllama.algorithms.semantic_splitter import SemanticSplitter @@ -10,21 +12,19 @@ def setUp(self): def test_flatten(self): nested_list = [[1, 2], [3, 4], [5, 6]] flattened = self.splitter.flatten(nested_list) - self.assertEqual(flattened, [1, 2, 3, 4, 5, 6]) + assert flattened == [1, 2, 3, 4, 5, 6] def test_constrained_split(self): text = "This is a test sentence. Another sentence here. And one more. " * 10 chunks = self.splitter.constrained_split(text, target_size=50) - self.assertTrue(all(len(chunk) <= 50 for chunk in chunks)) - self.assertEqual(" ".join(chunks), text.strip()) + assert all(len(chunk) <= 50 for chunk in chunks) + assert " ".join(chunks) == text.strip() def test_split(self): text = "Short sentence.\n\nTwo sentences. Without a line break.\n\nAnother short one." - chunks = self.splitter.split( - text, target_size=30, cleanup_size=10, intermediate_size=20 - ) - self.assertTrue(all(len(chunk) <= 30 for chunk in chunks)) - self.assertTrue(all(len(chunk) >= 10 for chunk in chunks)) + chunks = self.splitter.split(text, target_size=30, cleanup_size=10, intermediate_size=20) + assert all(len(chunk) <= 30 for chunk in chunks) + assert all(len(chunk) >= 10 for chunk in chunks) def test_reconstruct(self): lines = ["Short text.", "Another short text.", "A bit longer text here."] @@ -39,14 +39,12 @@ def test_reconstruct(self): savgol_window=3, ) - self.assertIsInstance(reconstructed, list) - self.assertTrue(all(isinstance(chunk, str) for chunk in reconstructed)) + assert isinstance(reconstructed, list) + assert all(isinstance(chunk, str) for chunk in reconstructed) def test_reconstruct_return_minima(self): lines = ["Short text.", "Another short text.", "A bit longer text here."] - embeddings = np.random.rand(3, 16).astype( - np.float32 - ) # 3 texts, 16-dimensional embeddings + embeddings = np.random.rand(3, 16).astype(np.float32) # 3 texts, 16-dimensional embeddings result = self.splitter.reconstruct( lines, @@ -58,12 +56,12 @@ def test_reconstruct_return_minima(self): return_minima=True, ) - self.assertIsInstance(result, tuple) - self.assertEqual(len(result), 3) + assert isinstance(result, tuple) + assert len(result) == 3 roots, y, sim_avg = result - self.assertIsInstance(roots, np.ndarray) - self.assertIsInstance(y, np.ndarray) - self.assertIsInstance(sim_avg, np.ndarray) + assert isinstance(roots, np.ndarray) + assert isinstance(y, np.ndarray) + assert isinstance(sim_avg, np.ndarray) if __name__ == "__main__": diff --git a/tests/test_splitting_functions.py b/tests/test_splitting_functions.py index 471a691..8014793 100644 --- a/tests/test_splitting_functions.py +++ b/tests/test_splitting_functions.py @@ -1,11 +1,14 @@ +import string import unittest + +import pytest + from wordllama.algorithms.splitter import ( constrained_batches, - split_sentences, constrained_coalesce, reverse_merge, + split_sentences, ) -import string class TestSplitter(unittest.TestCase): @@ -14,26 +17,26 @@ def test_constrained_batches(self): data = ["a", "bb", "ccc", "dddd", "eeeee"] batches = list(constrained_batches(data, max_size=5)) expected = [("a", "bb"), ("ccc",), ("dddd",), ("eeeee",)] - self.assertEqual(batches, expected) + assert batches == expected # Batching with max_count data = ["a", "bb", "ccc", "dddd", "eeeee"] batches = list(constrained_batches(data, max_size=10, max_count=2)) - self.assertEqual(batches, [("a", "bb"), ("ccc", "dddd"), ("eeeee",)]) + assert batches == [("a", "bb"), ("ccc", "dddd"), ("eeeee",)] # Batching with get_len data = ["a", "bb", "ccc", "dddd", "eeeee"] batches = list(constrained_batches(data, max_size=5, get_len=lambda x: 1)) - self.assertEqual(batches, [("a", "bb", "ccc", "dddd", "eeeee")]) + assert batches == [("a", "bb", "ccc", "dddd", "eeeee")] # Non-strict mode data = ["aaaaaa", "b", "c"] batches = list(constrained_batches(data, max_size=5, strict=False)) - self.assertEqual(batches, [("aaaaaa",), ("b", "c")]) + assert batches == [("aaaaaa",), ("b", "c")] # Strict mode with item exceeding max_size data = ["aaaaaa", "b", "c"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): batches = list(constrained_batches(data, max_size=5)) def test_split_sentences(self): @@ -45,37 +48,35 @@ def test_split_sentences(self): "This is another sentence!", "And another one?", ] - self.assertEqual(sentences, expected) + assert sentences == expected # Test with no punctuation text = "This is a text without punctuation" sentences = split_sentences(text) expected = ["This is a text without punctuation"] - self.assertEqual(sentences, expected) + assert sentences == expected # Test with custom punctuation text = "Sentence one# Sentence two# Sentence three" sentences = split_sentences(text, punct_chars={"#"}) expected = ["Sentence one#", "Sentence two#", "Sentence three"] - self.assertEqual(sentences, expected) + assert sentences == expected # Test with text ending without punctuation text = "This is a sentence. This is another sentence" sentences = split_sentences(text) expected = ["This is a sentence.", "This is another sentence"] - self.assertEqual(sentences, expected) + assert sentences == expected def test_constrained_coalesce(self): letters = list(string.ascii_lowercase) # Using the example from the documentation result = constrained_coalesce(letters, max_size=5, separator="") expected = ["abcd", "efgh", "ijkl", "mnop", "qrst", "uvwx", "yz"] - self.assertEqual(result, expected) + assert result == expected # Test with max_iterations=1 - result = constrained_coalesce( - letters, max_size=5, separator="", max_iterations=1 - ) + result = constrained_coalesce(letters, max_size=5, separator="", max_iterations=1) expected_one_pass = [ "ab", "cd", @@ -91,37 +92,37 @@ def test_constrained_coalesce(self): "wx", "yz", ] - self.assertEqual(result, expected_one_pass) + assert result == expected_one_pass # Test with data that cannot be combined data = ["a"] * 100 result = constrained_coalesce(data, max_size=1, max_iterations=5) - self.assertEqual(result, ["a"] * 100) + assert result == ["a"] * 100 def test_reverse_merge(self): # Basic merging test data = ["long enough", "short", "tiny", "adequate", "s"] result = reverse_merge(data, n=6, separator=" ") expected = ["long enough short tiny", "adequate s"] - self.assertEqual(result, expected) + assert result == expected # Test with empty list data = [] result = reverse_merge(data, n=5) expected = [] - self.assertEqual(result, expected) + assert result == expected # All strings longer than n data = ["string1", "string2", "string3"] result = reverse_merge(data, n=5) expected = ["string1", "string2", "string3"] - self.assertEqual(result, expected) + assert result == expected # All strings shorter than n data = ["a", "bb", "ccc"] result = reverse_merge(data, n=5, separator=" ") expected = ["a bb ccc"] - self.assertEqual(result, expected) + assert result == expected if __name__ == "__main__": diff --git a/tests/test_vector_similarity.py b/tests/test_vector_similarity.py index e066d96..017baaa 100644 --- a/tests/test_vector_similarity.py +++ b/tests/test_vector_similarity.py @@ -1,7 +1,8 @@ import unittest + import numpy as np -from wordllama.algorithms import vector_similarity, binarize_and_packbits +from wordllama.algorithms import binarize_and_packbits, vector_similarity class TestVectorSimilarity(unittest.TestCase): @@ -9,21 +10,21 @@ def test_binarization_and_packing(self): vec = np.zeros((1, 64)) vec[0][7] = 1 binary_output = binarize_and_packbits(vec) - self.assertIsInstance(binary_output, np.ndarray) - self.assertEqual(binary_output.dtype, np.uint64) - self.assertEqual(binary_output, 1) + assert isinstance(binary_output, np.ndarray) + assert binary_output.dtype == np.uint64 + assert binary_output == 1 def test_cosine_similarity_direct(self): vec1 = np.random.rand(1, 64) vec2 = np.random.rand(1, 64) result = vector_similarity(vec1, vec2, binary=False) - self.assertIsInstance(result.item(), float) + assert isinstance(result.item(), float) def test_hamming_similarity_direct(self): vec1 = np.expand_dims(np.random.randint(2, size=64, dtype=np.uint64), axis=0) vec2 = np.expand_dims(np.random.randint(2, size=64, dtype=np.uint64), axis=0) result = vector_similarity(vec1, vec2, binary=True) - self.assertIsInstance(result.item(), float) + assert isinstance(result.item(), float) if __name__ == "__main__": diff --git a/tests/test_wordllama.py b/tests/test_wordllama.py index b5a8c42..323f4f1 100644 --- a/tests/test_wordllama.py +++ b/tests/test_wordllama.py @@ -1,18 +1,13 @@ import unittest -from unittest.mock import patch, MagicMock, mock_open, call from pathlib import Path +from unittest.mock import MagicMock, call, mock_open, patch + import numpy as np +import pytest from tokenizers import Tokenizer + +from wordllama.config.models import WordLlamaModels from wordllama.wordllama import WordLlama, WordLlamaInference -from wordllama.config import ( - WordLlamaConfig, - TokenizerConfig, - MatryoshkaConfig, - WordLlamaModel, - TrainingConfig, - TokenizerInferenceConfig, -) -from wordllama.config.models import ModelURI, WordLlamaModels class TestWordLlama(unittest.TestCase): @@ -24,7 +19,7 @@ def setUp(self): # Assemble WordLlamaConfig self.config = "l2_supercat" - self.model_uri = getattr(WordLlamaModels, "l2_supercat") + self.model_uri = WordLlamaModels.l2_supercat def test_list_configs(self): output = WordLlama.list_configs() @@ -82,9 +77,7 @@ def test_resolve_file_downloads_if_not_found( handle.write.assert_has_calls([call(b"chunk1"), call(b"chunk2")]) # Assert the returned path is correct - self.assertEqual( - weights_path, Path("/dummy/cache/weights/l2_supercat_256.safetensors") - ) + assert weights_path == Path("/dummy/cache/weights/l2_supercat_256.safetensors") @patch.object(WordLlama, "resolve_file", autospec=True) def test_load_with_default_cache_dir(self, mock_resolve_file): @@ -94,24 +87,21 @@ def test_load_with_default_cache_dir(self, mock_resolve_file): # Setup mock for resolve_file default_cache_dir = WordLlama.DEFAULT_CACHE_DIR weights_path = default_cache_dir / "weights" / "l2_supercat_256.safetensors" - tokenizer_path = ( - default_cache_dir / "tokenizers" / "l2_supercat_tokenizer_config.json" - ) + tokenizer_path = default_cache_dir / "tokenizers" / "l2_supercat_tokenizer_config.json" mock_resolve_file.side_effect = [weights_path, tokenizer_path] # Mock tokenizer and weights loading - with patch( - "wordllama.wordllama.WordLlama.load_tokenizer", - return_value=MagicMock(spec=Tokenizer), - ) as mock_load_tokenizer, patch( - "wordllama.wordllama.safe_open", autospec=True - ) as mock_safe_open: + with ( + patch( + "wordllama.wordllama.WordLlama.load_tokenizer", + return_value=MagicMock(spec=Tokenizer), + ) as mock_load_tokenizer, + patch("wordllama.wordllama.safe_open", autospec=True) as mock_safe_open, + ): # Mock the tensor returned by safe_open mock_tensor = MagicMock() mock_tensor.__getitem__.return_value = np.random.rand(256, 4096) - mock_safe_open.return_value.__enter__.return_value.get_tensor.return_value = ( - mock_tensor - ) + mock_safe_open.return_value.__enter__.return_value.get_tensor.return_value = mock_tensor # Call load without specifying cache_dir model = WordLlama.load( @@ -146,7 +136,7 @@ def test_load_with_default_cache_dir(self, mock_resolve_file): ), ] mock_resolve_file.assert_has_calls(expected_calls, any_order=False) - self.assertEqual(mock_resolve_file.call_count, 2) + assert mock_resolve_file.call_count == 2 # Assert load_tokenizer was called with correct path mock_load_tokenizer.assert_called_once_with( @@ -162,7 +152,7 @@ def test_load_with_default_cache_dir(self, mock_resolve_file): ) # Assert the returned model is an instance of WordLlamaInference - self.assertIsInstance(model, WordLlamaInference) + assert isinstance(model, WordLlamaInference) @patch.object(WordLlama, "resolve_file", autospec=True) def test_load_with_custom_cache_dir(self, mock_resolve_file): @@ -192,24 +182,21 @@ def test_load_with_custom_cache_dir(self, mock_resolve_file): # Setup mock for resolve_file weights_path = ( - expected_resolved_dirs[key] - / "weights" - / "l2_supercat_256.safetensors" + expected_resolved_dirs[key] / "weights" / "l2_supercat_256.safetensors" ) tokenizer_path = ( - expected_resolved_dirs[key] - / "tokenizers" - / "l2_supercat_tokenizer_config.json" + expected_resolved_dirs[key] / "tokenizers" / "l2_supercat_tokenizer_config.json" ) mock_resolve_file.side_effect = [weights_path, tokenizer_path] # Mock tokenizer and weights loading - with patch( - "wordllama.wordllama.WordLlama.load_tokenizer", - return_value=MagicMock(spec=Tokenizer), - ) as mock_load_tokenizer, patch( - "wordllama.wordllama.safe_open", autospec=True - ) as mock_safe_open: + with ( + patch( + "wordllama.wordllama.WordLlama.load_tokenizer", + return_value=MagicMock(spec=Tokenizer), + ) as mock_load_tokenizer, + patch("wordllama.wordllama.safe_open", autospec=True) as mock_safe_open, + ): # Mock the tensor returned by safe_open mock_tensor = MagicMock() mock_tensor.__getitem__.return_value = np.random.rand(256, 4096) @@ -252,7 +239,7 @@ def test_load_with_custom_cache_dir(self, mock_resolve_file): ), ] mock_resolve_file.assert_has_calls(expected_calls, any_order=False) - self.assertEqual(mock_resolve_file.call_count, 2) + assert mock_resolve_file.call_count == 2 # Assert load_tokenizer was called with correct path mock_load_tokenizer.assert_called_once_with( @@ -268,7 +255,7 @@ def test_load_with_custom_cache_dir(self, mock_resolve_file): ) # Assert the returned model is an instance of WordLlamaInference - self.assertIsInstance(model, WordLlamaInference) + assert isinstance(model, WordLlamaInference) @patch.object(WordLlama, "resolve_file", autospec=True) def test_load_with_disable_download(self, mock_resolve_file): @@ -279,7 +266,7 @@ def test_load_with_disable_download(self, mock_resolve_file): mock_resolve_file.side_effect = FileNotFoundError("File not found") # Call load with disable_download=True and expect FileNotFoundError - with self.assertRaises(FileNotFoundError): + with pytest.raises(FileNotFoundError): WordLlama.load( config=self.config, cache_dir=Path("/dummy/cache"), @@ -303,7 +290,7 @@ def test_load_with_disable_download(self, mock_resolve_file): ) ] mock_resolve_file.assert_has_calls(expected_calls, any_order=False) - self.assertEqual(mock_resolve_file.call_count, 1) + assert mock_resolve_file.call_count == 1 @patch.object(WordLlama, "resolve_file", autospec=True) def test_load_with_truncated_dimension(self, mock_resolve_file): @@ -312,24 +299,21 @@ def test_load_with_truncated_dimension(self, mock_resolve_file): """ # Setup mock for resolve_file weights_path = Path("/dummy/cache/weights/l2_supercat_256.safetensors") - tokenizer_path = Path( - "/dummy/cache/tokenizers/l2_supercat_tokenizer_config.json" - ) + tokenizer_path = Path("/dummy/cache/tokenizers/l2_supercat_tokenizer_config.json") mock_resolve_file.side_effect = [weights_path, tokenizer_path] # Mock tokenizer and weights loading - with patch( - "wordllama.wordllama.WordLlama.load_tokenizer", - return_value=MagicMock(spec=Tokenizer), - ) as mock_load_tokenizer, patch( - "wordllama.wordllama.safe_open", autospec=True - ) as mock_safe_open: + with ( + patch( + "wordllama.wordllama.WordLlama.load_tokenizer", + return_value=MagicMock(spec=Tokenizer), + ) as mock_load_tokenizer, + patch("wordllama.wordllama.safe_open", autospec=True) as mock_safe_open, + ): # Mock the tensor returned by safe_open mock_tensor = MagicMock() mock_tensor.__getitem__.return_value = np.random.rand(256, 4096) - mock_safe_open.return_value.__enter__.return_value.get_tensor.return_value = ( - mock_tensor - ) + mock_safe_open.return_value.__enter__.return_value.get_tensor.return_value = mock_tensor # Call load with trunc_dim model = WordLlama.load( @@ -364,7 +348,7 @@ def test_load_with_truncated_dimension(self, mock_resolve_file): ), ] mock_resolve_file.assert_has_calls(expected_calls, any_order=False) - self.assertEqual(mock_resolve_file.call_count, 2) + assert mock_resolve_file.call_count == 2 # Assert load_tokenizer was called with correct path mock_load_tokenizer.assert_called_once_with( @@ -380,7 +364,7 @@ def test_load_with_truncated_dimension(self, mock_resolve_file): ) # Assert the returned model is an instance of WordLlamaInference - self.assertIsInstance(model, WordLlamaInference) + assert isinstance(model, WordLlamaInference) # Assert that the embedding was truncated mock_tensor.__getitem__.assert_called_with((slice(None), slice(None, 128))) @@ -400,15 +384,14 @@ def test_load_tokenizer_fallback( # Setup mocks # First call for weights, second call for tokenizer weights_path = Path("/dummy/cache/weights/l2_supercat_256.safetensors") - tokenizer_path = Path( - "/dummy/cache/tokenizers/l2_supercat_tokenizer_config.json" - ) + tokenizer_path = Path("/dummy/cache/tokenizers/l2_supercat_tokenizer_config.json") mock_resolve_file.side_effect = [weights_path, tokenizer_path] # Simulate tokenizer config does not exist by patching Path.exists - with patch( - "wordllama.wordllama.Path.exists", side_effect=[False, False] - ), patch("wordllama.wordllama.safe_open", autospec=True) as mock_safe_open: + with ( + patch("wordllama.wordllama.Path.exists", side_effect=[False, False]), + patch("wordllama.wordllama.safe_open", autospec=True) as mock_safe_open, + ): # Mock the tensor returned by safe_open mock_tensor = MagicMock() mock_tensor.__getitem__.return_value = np.random.rand(256, 4096) @@ -449,7 +432,7 @@ def test_load_tokenizer_fallback( ), ] mock_resolve_file.assert_has_calls(expected_calls, any_order=False) - self.assertEqual(mock_resolve_file.call_count, 2) + assert mock_resolve_file.call_count == 2 # Assert Tokenizer.from_pretrained was called since local config was not found mock_from_pretrained.assert_called_once_with("meta-llama/Llama-2-70b-hf") @@ -462,7 +445,7 @@ def test_load_tokenizer_fallback( ) # Assert the returned model is an instance of WordLlamaInference - self.assertIsInstance(model, WordLlamaInference) + assert isinstance(model, WordLlamaInference) if __name__ == "__main__": From 130912c00a6b0c4d371e54d0a1ab9fe43c3287aa Mon Sep 17 00:00:00 2001 From: Lee Miller Date: Mon, 1 Dec 2025 10:14:06 -0700 Subject: [PATCH 2/9] update build and release --- .github/workflows/ci.yml | 81 ++++++++++++++++++++++++++--------- .github/workflows/publish.yml | 8 ++-- Makefile | 16 ++++++- README.md | 21 +++++++-- 4 files changed, 97 insertions(+), 29 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 432bed4..a83d3dc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,36 +2,75 @@ name: CI on: pull_request: - branches: - - main + branches: [main] + push: + branches: [main] jobs: - build: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + enable-cache: true + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install ruff and pre-commit + run: uv pip install --system ruff pre-commit + + - name: Run ruff + run: | + ruff check src/ tests/ + ruff format --check src/ tests/ + + - name: Run pre-commit + run: pre-commit run --all-files + + test: runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: - os: [ubuntu-latest] - python-version: ['3.9', '3.10', '3.11'] + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + with: + fetch-depth: 0 # Needed for setuptools_scm + + - name: Install uv + uses: astral-sh/setup-uv@v2 + with: + enable-cache: true + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} + - name: Install build dependencies + run: uv pip install --system setuptools wheel setuptools_scm cython "numpy>=2" - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install setuptools wheel setuptools_scm cython numpy + - name: Build Cython extensions + run: uv run python setup.py build_ext --inplace - - name: Build and install package - run: | - python setup.py build_ext --inplace - python -m pip install . + - name: Install package with dev dependencies + run: uv pip install --system -e ".[dev]" - - name: Test installation - run: | - python -m unittest discover -s tests + - name: Run pytest + run: uv run pytest tests/ -v --cov=wordllama --cov-report=xml --cov-report=term-missing + - name: Upload coverage to Codecov + if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11' + uses: codecov/codecov-action@v3 + with: + files: ./coverage.xml + fail_ci_if_error: false diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 97a47f0..811f902 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -222,10 +222,13 @@ jobs: with: python-version: '3.11' + - name: Install uv + uses: astral-sh/setup-uv@v2 + - name: Install required Python packages run: | - python -m pip install --upgrade pip setuptools setuptools_scm wheel cython numpy - python -m pip install build twine + pip install uv + uv pip install --system setuptools setuptools_scm wheel cython "numpy>=2" build twine - name: Print detected version run: | @@ -262,4 +265,3 @@ jobs: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} packages_dir: dist/ - diff --git a/Makefile b/Makefile index 74e56a1..f90927e 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: help install install-dev build clean test test-cov lint format pre-commit-install pre-commit-run all +.PHONY: help install install-dev build clean test test-cov lint format pre-commit-install pre-commit-run tag all help: @echo "WordLlama Development Makefile" @@ -14,6 +14,7 @@ help: @echo " format - Format code with ruff" @echo " pre-commit-install - Install pre-commit hooks" @echo " pre-commit-run - Run pre-commit on all files" + @echo " tag VERSION=X.Y.Z - Create and push a new release tag" @echo " all - Clean, build, lint, format, and test" install: @@ -50,5 +51,16 @@ pre-commit-install: pre-commit-run: uv run pre-commit run --all-files +tag: + @if [ -z "$(VERSION)" ]; then \ + echo "Error: VERSION is required. Usage: make tag VERSION=0.4.0"; \ + exit 1; \ + fi + @echo "Creating and pushing tag v$(VERSION)..." + git tag -a v$(VERSION) -m "Release version $(VERSION)" + git push origin v$(VERSION) + @echo "- Tag v$(VERSION) created and pushed" + @echo "- GitHub Actions will build and publish the release" + all: clean build lint format test - @echo "✓ All tasks completed successfully" + @echo "- All tasks completed successfully" diff --git a/README.md b/README.md index c5ce668..491e2d8 100644 --- a/README.md +++ b/README.md @@ -83,7 +83,7 @@ print(f"\nBest Match: {best_candidate} (Score: {sim_key(best_candidate):.4f})") # 2. Foundations of neural science (Score: 0.2115) # 3. Introduction to philosophy: logic (Score: 0.1067) # 4. Cooking delicious pasta at home (Score: 0.0045) -# +# # Best Match: Introduction to neural networks (Score: 0.3414) ``` @@ -143,7 +143,7 @@ The following table presents the performance of WordLlama models compared to oth 8k documents from the `ag_news` dataset - Single core performance (CPU), i9 12th gen, DDR4 3200 -- NVIDIA A4500 (GPU) +- NVIDIA A4500 (GPU)

Word Llama @@ -200,7 +200,7 @@ print(f"\nBest Match: {best_candidate} (Score: {sim_key(best_candidate):.4f})") # 2. Foundations of neural science (Score: 0.2115) # 3. Introduction to philosophy: logic (Score: 0.1067) # 4. Cooking delicious pasta at home (Score: 0.0045) -# +# # Best Match: Introduction to neural networks (Score: 0.3414) ``` @@ -341,6 +341,21 @@ The L2 Supercat model was trained using a batch size of 512 on a single A100 GPU - DSPy evaluators - Retrieval-Augmented Generation (RAG) pipelines +## Development + +For local development: + +```bash +git clone https://github.com/dleemiller/WordLlama.git +cd WordLlama +pip install uv +uv sync --all-extras +uv run python setup.py build_ext --inplace +uv run pytest +``` + +See the [Makefile](Makefile) for common development commands. + ## Extracting Token Embeddings To extract token embeddings from a model, ensure you have agreed to the user agreement and logged in using the Hugging Face CLI (for LLaMA models). You can then use the following snippet: From d4414a868163043d6066b4d8a33daf8760955189 Mon Sep 17 00:00:00 2001 From: Lee Miller Date: Mon, 1 Dec 2025 10:25:25 -0700 Subject: [PATCH 3/9] linting --- build_tools/build_wheels.sh | 9 ++- classifiers.txt | 1 - dataset_loader.py | 49 ++++++++----- eval_mteb.py | 23 +++---- find_mteb.sh | 12 ++-- pyproject.toml | 4 ++ train.py | 27 +++----- .../blog/semantic_split/wl_semantic_blog.md | 68 +++++++++---------- tutorials/extract_token_embeddings.md | 2 +- 9 files changed, 100 insertions(+), 95 deletions(-) diff --git a/build_tools/build_wheels.sh b/build_tools/build_wheels.sh index eaca201..8c20b82 100755 --- a/build_tools/build_wheels.sh +++ b/build_tools/build_wheels.sh @@ -31,17 +31,17 @@ if [[ "$(uname)" == "Darwin" ]]; then else # For x86_64 builds, adjust deployment target and install llvm-openmp via Conda export MACOSX_DEPLOYMENT_TARGET=13.0 # Matches Homebrew's libomp minimum - + # Install llvm-openmp via Conda OPENMP_URL="https://anaconda.org/conda-forge/llvm-openmp/19.1.6/download/osx-64/llvm-openmp-19.1.6-ha54dae1_0.conda" echo "Installing llvm-openmp via Conda for x86_64..." sudo conda create -n build $OPENMP_URL PREFIX="$CONDA_HOME/envs/build" - + # Use system Clang and point it to Conda's OpenMP paths export CC="/usr/bin/clang" export CXX="/usr/bin/clang++" - + # Locate omp.h dynamically OMP_INCLUDE_DIR=$(find $PREFIX -type d -name "include" | head -n 1) if [[ -n "$OMP_INCLUDE_DIR" && -f "$OMP_INCLUDE_DIR/omp.h" ]]; then @@ -52,7 +52,7 @@ if [[ "$(uname)" == "Darwin" ]]; then ls -R $PREFIX # Debug: Show the structure of the Conda environment exit 1 fi - + # Set flags export CPPFLAGS="-Xpreprocessor -fopenmp -I$OMP_INCLUDE_DIR" export CFLAGS="-I$OMP_INCLUDE_DIR -ffp-contract=off" @@ -74,4 +74,3 @@ python -m pip install --upgrade pip # Install cibuildwheel and build wheels python -m pip install --upgrade cibuildwheel python -m cibuildwheel --output-dir wheelhouse - diff --git a/classifiers.txt b/classifiers.txt index 442896a..547e84a 100644 --- a/classifiers.txt +++ b/classifiers.txt @@ -1,4 +1,3 @@ Programming Language :: Python :: 3 License :: OSI Approved :: MIT License Operating System :: OS Independent - diff --git a/dataset_loader.py b/dataset_loader.py index d900224..eb8437e 100644 --- a/dataset_loader.py +++ b/dataset_loader.py @@ -1,5 +1,6 @@ from datasets import load_dataset + def load_datasets(seed=42): def shuffled_load(path, *args, **kwargs): return load_dataset(path, *args, **kwargs).shuffle(seed) @@ -8,42 +9,54 @@ def shuffled_load(path, *args, **kwargs): "train": { # NLI (Natural Language Inference) datasets "all-nli": shuffled_load("sentence-transformers/all-nli", "triplet", split="train"), - "nli-for-simcse": shuffled_load("sentence-transformers/nli-for-simcse", "triplet", split="train"), - + "nli-for-simcse": shuffled_load( + "sentence-transformers/nli-for-simcse", "triplet", split="train" + ), # Information Retrieval datasets - "msmarco": shuffled_load("sentence-transformers/msmarco-bm25", "triplet", split="train"), + "msmarco": shuffled_load( + "sentence-transformers/msmarco-bm25", "triplet", split="train" + ), "mr-tydi": shuffled_load("sentence-transformers/mr-tydi", "en-triplet", split="train"), - # Text Summarization / Compression datasets - "compression": shuffled_load("sentence-transformers/sentence-compression", split="train"), + "compression": shuffled_load( + "sentence-transformers/sentence-compression", split="train" + ), "simple-wiki": shuffled_load("sentence-transformers/simple-wiki", split="train"), - # News datasets "agnews": shuffled_load("sentence-transformers/agnews", split="train"), "ccnews": shuffled_load("sentence-transformers/ccnews", split="train"), "npr": shuffled_load("sentence-transformers/npr", split="train"), - # Question Answering (QA) datasets "gooaq": shuffled_load("sentence-transformers/gooaq", split="train"), - "yahoo-answers": shuffled_load("sentence-transformers/yahoo-answers", "title-question-answer-pair", split="train"), + "yahoo-answers": shuffled_load( + "sentence-transformers/yahoo-answers", "title-question-answer-pair", split="train" + ), "eli5": shuffled_load("sentence-transformers/eli5", split="train"), "amazon-qa": shuffled_load("sentence-transformers/amazon-qa", split="train[0:1000000]"), "squad": shuffled_load("sentence-transformers/squad", split="train"), - "natural_questions": shuffled_load("sentence-transformers/natural-questions", split="train"), + "natural_questions": shuffled_load( + "sentence-transformers/natural-questions", split="train" + ), "hotpotqa": shuffled_load("sentence-transformers/hotpotqa", "triplet", split="train"), - # Duplicate Detection datasets - "quora_duplicates": shuffled_load("sentence-transformers/quora-duplicates", "pair", split="train"), - "quora_triplets": shuffled_load("sentence-transformers/quora-duplicates", "triplet", split="train"), - + "quora_duplicates": shuffled_load( + "sentence-transformers/quora-duplicates", "pair", split="train" + ), + "quora_triplets": shuffled_load( + "sentence-transformers/quora-duplicates", "triplet", split="train" + ), # Scientific / Academic datasets "specter": shuffled_load("sentence-transformers/specter", "triplet", split="train"), - # Stack Exchange datasets - "stackexchange_bbp": shuffled_load("sentence-transformers/stackexchange-duplicates", "body-body-pair", split="train"), - "stackexchange_ttp": shuffled_load("sentence-transformers/stackexchange-duplicates", "title-title-pair", split="train"), - "stackexchange_ppp": shuffled_load("sentence-transformers/stackexchange-duplicates", "post-post-pair", split="train"), - + "stackexchange_bbp": shuffled_load( + "sentence-transformers/stackexchange-duplicates", "body-body-pair", split="train" + ), + "stackexchange_ttp": shuffled_load( + "sentence-transformers/stackexchange-duplicates", "title-title-pair", split="train" + ), + "stackexchange_ppp": shuffled_load( + "sentence-transformers/stackexchange-duplicates", "post-post-pair", split="train" + ), # Lexical / Linguistic datasets "altlex": shuffled_load("sentence-transformers/altlex", split="train"), }, diff --git a/eval_mteb.py b/eval_mteb.py index 13907b6..1891e32 100644 --- a/eval_mteb.py +++ b/eval_mteb.py @@ -1,5 +1,6 @@ # ruff: noqa: E402 from __future__ import annotations + import os # Set environment variables @@ -8,17 +9,17 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1" -import mteb import logging - from functools import partial -from typing import Any, List -from wordllama import load_training, Config +from typing import Any + +import mteb import numpy as np from more_itertools import chunked - from mteb.model_meta import ModelMeta +from wordllama import Config, load_training + logger = logging.getLogger(__name__) @@ -109,23 +110,18 @@ class WordLlamaWrapper: - def __init__( - self, model_name: str, config, embed_dim: int | None = None, **kwargs - ) -> None: + def __init__(self, model_name: str, config, embed_dim: int | None = None, **kwargs) -> None: self._model_name = model_name self._embed_dim = embed_dim print(model_name) self.model = load_training(model_name, config, dims=embed_dim).to("cuda") - def encode(self, sentences: List[str], batch_size=512, **kwargs: Any) -> np.ndarray: + def encode(self, sentences: list[str], batch_size=512, **kwargs: Any) -> np.ndarray: all_embeddings = [] for chunk in chunked(sentences, batch_size): embed_chunk = ( - self.model.embed(chunk, return_pt=True, norm=True) - .to("cpu") - .detach() - .numpy() + self.model.embed(chunk, return_pt=True, norm=True).to("cpu").detach().numpy() ) all_embeddings.append(embed_chunk) @@ -135,7 +131,6 @@ def encode(self, sentences: List[str], batch_size=512, **kwargs: Any) -> np.ndar if __name__ == "__main__": - TASK_LIST = ( TASK_LIST_CLASSIFICATION + TASK_LIST_CLUSTERING diff --git a/find_mteb.sh b/find_mteb.sh index 9224659..b4bc79a 100755 --- a/find_mteb.sh +++ b/find_mteb.sh @@ -43,12 +43,12 @@ while IFS= read -r line; do filename=$(echo "$line" | cut -d' ' -f1) score=$(echo "$line" | cut -d' ' -f2) subset=$(echo "$line" | cut -d' ' -f3) - + echo "$filename $score $subset" - + # Remove .json extension for matching task_name=${filename%.json} - + for category in "${!TASK_LISTS[@]}"; do if [[ "${TASK_LISTS[$category]}" == *"$task_name"* ]]; then eval "${category}_scores+=(${score})" @@ -63,8 +63,8 @@ while IFS= read -r line; do fi done done < <(find "$search_dir" -type f -name "*.json" -print0 | xargs -0 -I {} sh -c ' -jq -r ".. | objects | select((.hf_subset? == \"en\" or .hf_subset? == \"default\") and .main_score?) | \"\(.main_score) \(.hf_subset)\"" {} | -while read score subset; do +jq -r ".. | objects | select((.hf_subset? == \"en\" or .hf_subset? == \"default\") and .main_score?) | \"\(.main_score) \(.hf_subset)\"" {} | +while read score subset; do formatted_score=$(printf "%.2f" $(echo "$score * 100" | bc)) echo "$(basename {}) $formatted_score $subset" done @@ -86,7 +86,7 @@ print_results() { local scores=("${!1}") local files=("${!2}") local name=$3 - + if [ ${#scores[@]} -ne 0 ]; then average=$(calculate_average "$1") echo "" diff --git a/pyproject.toml b/pyproject.toml index d4bf64e..af632f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,10 @@ dependencies = [ [project.optional-dependencies] dev = [ + "setuptools", + "wheel", + "setuptools_scm[toml]", + "Cython", "pytest>=7.4", "pytest-cov>=4.1", "pytest-xdist>=3.3", diff --git a/train.py b/train.py index 9d4fcf8..731f1a1 100644 --- a/train.py +++ b/train.py @@ -7,23 +7,24 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1" -import torch -import tqdm -import safetensors.torch from datetime import datetime from pathlib import Path + +import safetensors.torch +import torch +import tqdm +from dataset_loader import load_datasets from sentence_transformers import ( - SentenceTransformerTrainingArguments, SentenceTransformer, + SentenceTransformerTrainingArguments, ) from sentence_transformers.training_args import MultiDatasetBatchSamplers -from wordllama import load_training, Config +from wordllama import Config, load_training +from wordllama.adapters import AvgPool, Binarizer, WeightedProjector from wordllama.config import WordLlamaModel from wordllama.embedding.word_llama_embedding import WordLlamaEmbedding from wordllama.trainers.reduce_dimension import ReduceDimension -from wordllama.adapters import AvgPool, WeightedProjector, Binarizer -from dataset_loader import load_datasets class ReduceDimensionConfig: @@ -108,9 +109,7 @@ def save(self, checkpoint: Path, outdir: Path): # load the projector weights max_dim = max(self.matryoshka_dims) - proj_path = ( - checkpoint / "1_WeightedProjector" / "weighted_projector.safetensors" - ) + proj_path = checkpoint / "1_WeightedProjector" / "weighted_projector.safetensors" proj = WeightedProjector( self.config.model.dim, max_dim, @@ -183,9 +182,7 @@ class TmpConfig: required=True, help="Name of your configuration (eg. [your_config].toml)", ) - parser_save.add_argument( - "--checkpoint", type=str, required=True, help="Path to the checkpoint" - ) + parser_save.add_argument("--checkpoint", type=str, required=True, help="Path to the checkpoint") parser_save.add_argument( "--outdir", type=str, required=True, help="Directory to save the models" ) @@ -196,9 +193,7 @@ class TmpConfig: # Execute based on the command if args.command == "train": - config = ReduceDimensionConfig( - config_name, binarize=args.binarize, norm=args.norm - ) + config = ReduceDimensionConfig(config_name, binarize=args.binarize, norm=args.norm) trainer = ReduceDimension(config) trainer.train() diff --git a/tutorials/blog/semantic_split/wl_semantic_blog.md b/tutorials/blog/semantic_split/wl_semantic_blog.md index bc5e7de..543f739 100644 --- a/tutorials/blog/semantic_split/wl_semantic_blog.md +++ b/tutorials/blog/semantic_split/wl_semantic_blog.md @@ -16,7 +16,7 @@ When splitting/chunking text for Retrieval-Augmented Generation (RAG) applicatio These methods don't require language modeling, but they lack semantic awareness. More advanced techniques using language models can provide better semantic coherence but typically at a significant computational cost and latency. And, although semantic splitting is conceptually simple, it still involves multiple steps to refine a quality algorithm. -WordLlama is a good platform for accomplishing this, since it can incorporate basic semantic information into the chunking process, without adding significant computational requirements. Here, we develop a recipe for semantic splitting with WordLlama using an intuitive process. +WordLlama is a good platform for accomplishing this, since it can incorporate basic semantic information into the chunking process, without adding significant computational requirements. Here, we develop a recipe for semantic splitting with WordLlama using an intuitive process. ## Target texts @@ -83,21 +83,21 @@ print(text[0:300]) J. R. R. Tolkien — The Lord Of The Rings. (1/4) ----------------------------------------------- - - + + THE LORD OF THE RINGS - + by - + J. R. R. TOLKIEN - - - + + + Part 1: The Fellowship of the Ring Part 2: The Two Towers Part 3: The Return of the King - - + + _Complete with Index and Full Appendi @@ -115,21 +115,21 @@ import seaborn as sns def plot_chars(chars_per_line): sns.set(style="whitegrid") fig, axes = plt.subplots(2, 1, figsize=(10, 8)) - + # First plot: full range sns.histplot(chars_per_line, bins=200, ax=axes[0], kde=False) axes[0].set_title("Characters per Line - Full Range") axes[0].set_xlabel("# Characters") axes[0].set_ylabel("$log($Counts$)$") axes[0].semilogy(True) - + # Second plot: zoomed-in range sns.histplot(chars_per_line, bins=1000, ax=axes[1], kde=False) axes[1].set_title("Characters per Line - Zoomed In (0 to 100)") axes[1].set_xlabel("# Characters") axes[1].set_ylabel("Counts") axes[1].set_xlim((0, 200)) - + plt.tight_layout() plt.show() ``` @@ -146,9 +146,9 @@ plot_chars(chars_per_line) ``` - + ![png](output_6_0.png) - + Here we can see a bunch of small fragments with close to zero size. Additionally, there are some smaller segments below 50 characters. While most of the chunks are fewer than 1k characters, there are a few larger ones as well. The chunks that are a few characters or less are not likely to carry much semantic information and are disproportionate compared to most of the other segments. @@ -208,9 +208,9 @@ plot_chars(chars_per_line) ``` - + ![png](output_12_0.png) - + This is better. Let's take care of the larger segments. @@ -268,9 +268,9 @@ plot_chars(chars_per_line) ``` - + ![png](output_15_0.png) - + Now we have a more reasonable starting point for doing semantic splitting. Let's use wordllama to embed the segments into vectors, and compute similarity for all the segments. @@ -301,9 +301,9 @@ plt.title("Cross-similarity of lines") - + ![png](output_17_1.png) - + Here's where we can see how wordllama can help. As we traverse the diagonal, we can identify blocks of similar texts. The very small block in the upper left corner is the table of contents. @@ -334,12 +334,12 @@ plt.show() ``` - + ![png](output_19_0.png) - -With the size of our segments, even 10-20 segments is a decent chunk size. Here we can zoom in on the minimum around the **dashed red line index (308)**. + +With the size of our segments, even 10-20 segments is a decent chunk size. Here we can zoom in on the minimum around the **dashed red line index (308)**. ```python @@ -350,7 +350,7 @@ print("\n".join([lines[i] if i != 308 else f">>>>>>>>>>>>>{lines[i]}<<<<<<<<<<<< 'So do I,' said Gandalf. 'And I wonder many other things. Good-bye now! Take care of yourself! Look out for me, especially at unlikely times! Good-bye!' Frodo saw him to the door. He gave a final wave of his hand, and walked off at a surprising pace; but Frodo thought the old wizard looked unusually bent, almost as if he was carrying a great weight. The evening was closing in, and his cloaked figure quickly vanished into the twilight. Frodo did not see him again for a long time. >>>>>>>>>>>>> - + _Chapter 2_ The Shadow of the Past <<<<<<<<<<<<< @@ -404,9 +404,9 @@ plt.show() ``` - + ![png](output_23_0.png) - + Well that was fun. Now all that's left is to bring the sections back up to our target size. @@ -450,9 +450,9 @@ ax.semilogy(True) - + ![png](output_26_1.png) - + ### Visualize @@ -464,21 +464,21 @@ from IPython.display import Markdown, display def display_strings(string_list, offset=0): """ Convert a list of strings into a markdown table and display it in a Jupyter notebook. - + Parameters: - string_list (list): The list of strings to display - + Returns: - None (displays the table in the notebook) """ # Create the table header table = "| Index | Text |\n|-------|------|\n" - + # Add each string to the table for i, text in enumerate(string_list): row = f"| {i + offset} | {text[:600]}{'...' if len(text) > 600 else ''} |\n" table += row - + # Display the table display(Markdown(table)) @@ -520,7 +520,7 @@ print(f"Length of text: {len(text):.2e} chars\n# of chunks: {len(results)}\n\nPr Length of text: 1.02e+06 chars # of chunks: 784 - + Processing time: CPU times: user 1.31 s, sys: 111 ms, total: 1.43 s Wall time: 677 ms diff --git a/tutorials/extract_token_embeddings.md b/tutorials/extract_token_embeddings.md index f9f33d8..e88efa3 100644 --- a/tutorials/extract_token_embeddings.md +++ b/tutorials/extract_token_embeddings.md @@ -52,7 +52,7 @@ In [1]: from safetensors import safe_open In [2]: with safe_open("/home/lee/Downloads/model-00001-of-00002.safetensors", framework="pt") as f: ...: weights = f.get_tensor("model.embed_tokens.weight") - ...: + ...: In [3]: weights.shape Out[3]: torch.Size([256000, 2304]) From 281164b991793011993764ef16e694f28087271f Mon Sep 17 00:00:00 2001 From: Lee Miller Date: Mon, 1 Dec 2025 10:36:57 -0700 Subject: [PATCH 4/9] linting and updating ci for environment --- .github/workflows/ci.yml | 7 ++----- .pre-commit-config.yaml | 2 +- src/wordllama/algorithms/semantic_splitter.py | 6 +++--- src/wordllama/inference.py | 6 +++--- 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a83d3dc..bbf43fd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,15 +56,12 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: Install build dependencies - run: uv pip install --system setuptools wheel setuptools_scm cython "numpy>=2" + - name: Install package with dev dependencies + run: uv sync --all-extras - name: Build Cython extensions run: uv run python setup.py build_ext --inplace - - name: Install package with dev dependencies - run: uv pip install --system -e ".[dev]" - - name: Run pytest run: uv run pytest tests/ -v --cov=wordllama --cov-report=xml --cov-report=term-missing diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 69d1c51..bff5b68 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: - id: check-merge-conflict - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.4 + rev: v0.14.7 hooks: - id: ruff args: [--fix] diff --git a/src/wordllama/algorithms/semantic_splitter.py b/src/wordllama/algorithms/semantic_splitter.py index 69ba2a3..c548054 100644 --- a/src/wordllama/algorithms/semantic_splitter.py +++ b/src/wordllama/algorithms/semantic_splitter.py @@ -79,9 +79,9 @@ def split( List[str]: List of text chunks. """ assert target_size > intermediate_size, "Target size must be larger than intermediate size." - assert ( - intermediate_size > cleanup_size - ), "Intermediate size must be larger than cleanup size." + assert intermediate_size > cleanup_size, ( + "Intermediate size must be larger than cleanup size." + ) lines = text.splitlines() lines = constrained_coalesce(lines, intermediate_size, separator="\n") diff --git a/src/wordllama/inference.py b/src/wordllama/inference.py index 47264f4..8f51f2b 100644 --- a/src/wordllama/inference.py +++ b/src/wordllama/inference.py @@ -268,9 +268,9 @@ def topk(self, query: str, candidates: list[str], k: int = 3) -> list[str]: Returns: List[str]: The top-k documents most similar to the query. """ - assert ( - len(candidates) > k - ), f"Number of candidates ({len(candidates)}) must be greater than k ({k})" + assert len(candidates) > k, ( + f"Number of candidates ({len(candidates)}) must be greater than k ({k})" + ) ranked_docs = self.rank(query, candidates) return [doc for doc, score in ranked_docs[:k]] From 8295bf96bd8492fa3a90cedce4e64bba9a39c6b3 Mon Sep 17 00:00:00 2001 From: Lee Miller Date: Mon, 1 Dec 2025 10:44:35 -0700 Subject: [PATCH 5/9] fixing installs for lint step --- .github/workflows/ci.yml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bbf43fd..1ea5d69 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,8 +22,11 @@ jobs: with: python-version: '3.11' - - name: Install ruff and pre-commit - run: uv pip install --system ruff pre-commit + - name: Install ruff, pre-commit, and pytest + run: uv pip install --system ruff pre-commit pytest + + - name: Install package for pytest + run: uv pip install --system -e . - name: Run ruff run: | From 6d90e67cacb3c651b1b3f545e910d41aaa48eef3 Mon Sep 17 00:00:00 2001 From: Lee Miller Date: Mon, 1 Dec 2025 10:47:37 -0700 Subject: [PATCH 6/9] exception for generated version file --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index af632f4..a42b27d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,6 +131,7 @@ exclude = [ ".ruff_cache", "tutorials", "benchmark", + "_version.py", ] [tool.ruff.lint] From 56da8e5a163f4fe861641cc9694ce2f21ba56f85 Mon Sep 17 00:00:00 2001 From: Lee Miller Date: Mon, 1 Dec 2025 11:23:39 -0700 Subject: [PATCH 7/9] remove pytest from lint step --- .github/workflows/ci.yml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1ea5d69..e821b8d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,11 +22,8 @@ jobs: with: python-version: '3.11' - - name: Install ruff, pre-commit, and pytest - run: uv pip install --system ruff pre-commit pytest - - - name: Install package for pytest - run: uv pip install --system -e . + - name: Install ruff and pre-commit + run: uv pip install --system ruff pre-commit - name: Run ruff run: | @@ -35,6 +32,8 @@ jobs: - name: Run pre-commit run: pre-commit run --all-files + env: + SKIP: pytest-fast test: runs-on: ${{ matrix.os }} From 9bc8cd930ea5257d02940342f0d1d770bb6ea800 Mon Sep 17 00:00:00 2001 From: Lee Miller Date: Mon, 1 Dec 2025 11:24:52 -0700 Subject: [PATCH 8/9] update precommit --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bff5b68..01fec72 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v6.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer From 4fb7fa70b11f464be350d6eb3340a661b05fce69 Mon Sep 17 00:00:00 2001 From: Lee Miller Date: Mon, 1 Dec 2025 11:29:35 -0700 Subject: [PATCH 9/9] resolve paths to fix mac build --- tests/test_wordllama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_wordllama.py b/tests/test_wordllama.py index 323f4f1..c84e7ce 100644 --- a/tests/test_wordllama.py +++ b/tests/test_wordllama.py @@ -172,7 +172,7 @@ def test_load_with_custom_cache_dir(self, mock_resolve_file): "tilde": Path("~/tmp_cache").expanduser(), "relative": Path("tmp").resolve(), "relative_dot": Path("./tmp").resolve(), - "absolute": Path("/tmp/cache_dir"), + "absolute": Path("/tmp/cache_dir").resolve(), } for key, cache_dir_input in cache_dirs.items():