diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 432bed4..e821b8d 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -2,36 +2,74 @@ name: CI
on:
pull_request:
- branches:
- - main
+ branches: [main]
+ push:
+ branches: [main]
jobs:
- build:
+ lint:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+
+ - name: Install uv
+ uses: astral-sh/setup-uv@v2
+ with:
+ enable-cache: true
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: '3.11'
+
+ - name: Install ruff and pre-commit
+ run: uv pip install --system ruff pre-commit
+
+ - name: Run ruff
+ run: |
+ ruff check src/ tests/
+ ruff format --check src/ tests/
+
+ - name: Run pre-commit
+ run: pre-commit run --all-files
+ env:
+ SKIP: pytest-fast
+
+ test:
runs-on: ${{ matrix.os }}
strategy:
+ fail-fast: false
matrix:
- os: [ubuntu-latest]
- python-version: ['3.9', '3.10', '3.11']
+ os: [ubuntu-latest, macos-latest, windows-latest]
+ python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0 # Needed for setuptools_scm
+
+ - name: Install uv
+ uses: astral-sh/setup-uv@v2
+ with:
+ enable-cache: true
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- pip install setuptools wheel setuptools_scm cython numpy
+ - name: Install package with dev dependencies
+ run: uv sync --all-extras
- - name: Build and install package
- run: |
- python setup.py build_ext --inplace
- python -m pip install .
+ - name: Build Cython extensions
+ run: uv run python setup.py build_ext --inplace
- - name: Test installation
- run: |
- python -m unittest discover -s tests
+ - name: Run pytest
+ run: uv run pytest tests/ -v --cov=wordllama --cov-report=xml --cov-report=term-missing
+ - name: Upload coverage to Codecov
+ if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11'
+ uses: codecov/codecov-action@v3
+ with:
+ files: ./coverage.xml
+ fail_ci_if_error: false
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 97a47f0..811f902 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -222,10 +222,13 @@ jobs:
with:
python-version: '3.11'
+ - name: Install uv
+ uses: astral-sh/setup-uv@v2
+
- name: Install required Python packages
run: |
- python -m pip install --upgrade pip setuptools setuptools_scm wheel cython numpy
- python -m pip install build twine
+ pip install uv
+ uv pip install --system setuptools setuptools_scm wheel cython "numpy>=2" build twine
- name: Print detected version
run: |
@@ -262,4 +265,3 @@ jobs:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
packages_dir: dist/
-
diff --git a/.gitignore b/.gitignore
index e26add1..9e0b9c3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -162,10 +162,15 @@ cython_debug/
#.idea/
# Ignore generated Cython files
-wordllama/algorithms/*.c
-wordllama/algorithms/*.cpp
-wordllama/algorithms/*.html
+src/wordllama/algorithms/*.c
+src/wordllama/algorithms/*.cpp
+src/wordllama/algorithms/*.html
# Ignore the generated version file
-wordllama/_version.py
+src/wordllama/_version.py
+# uv
+uv.lock
+
+# ruff
+.ruff_cache/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..01fec72
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,32 @@
+repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v6.0.0
+ hooks:
+ - id: trailing-whitespace
+ - id: end-of-file-fixer
+ exclude: '\.pyx$|\.pxd$'
+ - id: check-yaml
+ - id: check-toml
+ - id: check-json
+ - id: check-added-large-files
+ args: ['--maxkb=1024']
+ exclude: 'weights/.*\.safetensors$'
+ - id: check-merge-conflict
+
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.14.7
+ hooks:
+ - id: ruff
+ args: [--fix]
+ - id: ruff-format
+
+ - repo: local
+ hooks:
+ - id: pytest-fast
+ name: pytest-fast
+ entry: uv run pytest
+ language: system
+ pass_filenames: false
+ args: [-m, "not slow", --tb=short, -q]
+ types: [python]
+ stages: [pre-commit]
diff --git a/MANIFEST.in b/MANIFEST.in
index 612420c..f799f3a 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,8 +1,8 @@
include LICENSE
include README.md
-recursive-include wordllama *.py *.toml *.json
-include wordllama/weights/*.safetensors
-include wordllama/algorithms/*.pyx
-include wordllama/algorithms/*.pxd
-include wordllama/algorithms/*.so
-include wordllama/algorithms/*.pyd
+recursive-include src/wordllama *.py *.toml *.json
+include src/wordllama/weights/*.safetensors
+include src/wordllama/algorithms/*.pyx
+include src/wordllama/algorithms/*.pxd
+include src/wordllama/algorithms/*.so
+include src/wordllama/algorithms/*.pyd
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..f90927e
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,66 @@
+.PHONY: help install install-dev build clean test test-cov lint format pre-commit-install pre-commit-run tag all
+
+help:
+ @echo "WordLlama Development Makefile"
+ @echo ""
+ @echo "Available targets:"
+ @echo " install - Install package and dependencies"
+ @echo " install-dev - Install package with dev dependencies"
+ @echo " build - Build Cython extensions"
+ @echo " clean - Clean build artifacts"
+ @echo " test - Run tests"
+ @echo " test-cov - Run tests with coverage"
+ @echo " lint - Run ruff linter"
+ @echo " format - Format code with ruff"
+ @echo " pre-commit-install - Install pre-commit hooks"
+ @echo " pre-commit-run - Run pre-commit on all files"
+ @echo " tag VERSION=X.Y.Z - Create and push a new release tag"
+ @echo " all - Clean, build, lint, format, and test"
+
+install:
+ uv sync
+
+install-dev:
+ uv sync --all-extras
+
+build:
+ uv run python setup.py build_ext --inplace
+
+clean:
+ rm -rf build/ dist/ *.egg-info
+ rm -rf src/wordllama/**/*.so src/wordllama/**/*.c src/wordllama/**/*.cpp
+ rm -rf .pytest_cache .ruff_cache htmlcov .coverage
+ find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true
+
+test:
+ uv run pytest
+
+test-cov:
+ uv run pytest --cov=wordllama --cov-report=html --cov-report=term-missing
+
+lint:
+ uv run ruff check src/ tests/
+
+format:
+ uv run ruff format src/ tests/
+ uv run ruff check --fix src/ tests/
+
+pre-commit-install:
+ uv run pre-commit install
+
+pre-commit-run:
+ uv run pre-commit run --all-files
+
+tag:
+ @if [ -z "$(VERSION)" ]; then \
+ echo "Error: VERSION is required. Usage: make tag VERSION=0.4.0"; \
+ exit 1; \
+ fi
+ @echo "Creating and pushing tag v$(VERSION)..."
+ git tag -a v$(VERSION) -m "Release version $(VERSION)"
+ git push origin v$(VERSION)
+ @echo "- Tag v$(VERSION) created and pushed"
+ @echo "- GitHub Actions will build and publish the release"
+
+all: clean build lint format test
+ @echo "- All tasks completed successfully"
diff --git a/README.md b/README.md
index c5ce668..491e2d8 100644
--- a/README.md
+++ b/README.md
@@ -83,7 +83,7 @@ print(f"\nBest Match: {best_candidate} (Score: {sim_key(best_candidate):.4f})")
# 2. Foundations of neural science (Score: 0.2115)
# 3. Introduction to philosophy: logic (Score: 0.1067)
# 4. Cooking delicious pasta at home (Score: 0.0045)
-#
+#
# Best Match: Introduction to neural networks (Score: 0.3414)
```
@@ -143,7 +143,7 @@ The following table presents the performance of WordLlama models compared to oth
8k documents from the `ag_news` dataset
- Single core performance (CPU), i9 12th gen, DDR4 3200
-- NVIDIA A4500 (GPU)
+- NVIDIA A4500 (GPU)
@@ -200,7 +200,7 @@ print(f"\nBest Match: {best_candidate} (Score: {sim_key(best_candidate):.4f})")
# 2. Foundations of neural science (Score: 0.2115)
# 3. Introduction to philosophy: logic (Score: 0.1067)
# 4. Cooking delicious pasta at home (Score: 0.0045)
-#
+#
# Best Match: Introduction to neural networks (Score: 0.3414)
```
@@ -341,6 +341,21 @@ The L2 Supercat model was trained using a batch size of 512 on a single A100 GPU
- DSPy evaluators
- Retrieval-Augmented Generation (RAG) pipelines
+## Development
+
+For local development:
+
+```bash
+git clone https://github.com/dleemiller/WordLlama.git
+cd WordLlama
+pip install uv
+uv sync --all-extras
+uv run python setup.py build_ext --inplace
+uv run pytest
+```
+
+See the [Makefile](Makefile) for common development commands.
+
## Extracting Token Embeddings
To extract token embeddings from a model, ensure you have agreed to the user agreement and logged in using the Hugging Face CLI (for LLaMA models). You can then use the following snippet:
diff --git a/build_tools/build_wheels.sh b/build_tools/build_wheels.sh
index eaca201..8c20b82 100755
--- a/build_tools/build_wheels.sh
+++ b/build_tools/build_wheels.sh
@@ -31,17 +31,17 @@ if [[ "$(uname)" == "Darwin" ]]; then
else
# For x86_64 builds, adjust deployment target and install llvm-openmp via Conda
export MACOSX_DEPLOYMENT_TARGET=13.0 # Matches Homebrew's libomp minimum
-
+
# Install llvm-openmp via Conda
OPENMP_URL="https://anaconda.org/conda-forge/llvm-openmp/19.1.6/download/osx-64/llvm-openmp-19.1.6-ha54dae1_0.conda"
echo "Installing llvm-openmp via Conda for x86_64..."
sudo conda create -n build $OPENMP_URL
PREFIX="$CONDA_HOME/envs/build"
-
+
# Use system Clang and point it to Conda's OpenMP paths
export CC="/usr/bin/clang"
export CXX="/usr/bin/clang++"
-
+
# Locate omp.h dynamically
OMP_INCLUDE_DIR=$(find $PREFIX -type d -name "include" | head -n 1)
if [[ -n "$OMP_INCLUDE_DIR" && -f "$OMP_INCLUDE_DIR/omp.h" ]]; then
@@ -52,7 +52,7 @@ if [[ "$(uname)" == "Darwin" ]]; then
ls -R $PREFIX # Debug: Show the structure of the Conda environment
exit 1
fi
-
+
# Set flags
export CPPFLAGS="-Xpreprocessor -fopenmp -I$OMP_INCLUDE_DIR"
export CFLAGS="-I$OMP_INCLUDE_DIR -ffp-contract=off"
@@ -74,4 +74,3 @@ python -m pip install --upgrade pip
# Install cibuildwheel and build wheels
python -m pip install --upgrade cibuildwheel
python -m cibuildwheel --output-dir wheelhouse
-
diff --git a/classifiers.txt b/classifiers.txt
index 442896a..547e84a 100644
--- a/classifiers.txt
+++ b/classifiers.txt
@@ -1,4 +1,3 @@
Programming Language :: Python :: 3
License :: OSI Approved :: MIT License
Operating System :: OS Independent
-
diff --git a/dataset_loader.py b/dataset_loader.py
index d900224..eb8437e 100644
--- a/dataset_loader.py
+++ b/dataset_loader.py
@@ -1,5 +1,6 @@
from datasets import load_dataset
+
def load_datasets(seed=42):
def shuffled_load(path, *args, **kwargs):
return load_dataset(path, *args, **kwargs).shuffle(seed)
@@ -8,42 +9,54 @@ def shuffled_load(path, *args, **kwargs):
"train": {
# NLI (Natural Language Inference) datasets
"all-nli": shuffled_load("sentence-transformers/all-nli", "triplet", split="train"),
- "nli-for-simcse": shuffled_load("sentence-transformers/nli-for-simcse", "triplet", split="train"),
-
+ "nli-for-simcse": shuffled_load(
+ "sentence-transformers/nli-for-simcse", "triplet", split="train"
+ ),
# Information Retrieval datasets
- "msmarco": shuffled_load("sentence-transformers/msmarco-bm25", "triplet", split="train"),
+ "msmarco": shuffled_load(
+ "sentence-transformers/msmarco-bm25", "triplet", split="train"
+ ),
"mr-tydi": shuffled_load("sentence-transformers/mr-tydi", "en-triplet", split="train"),
-
# Text Summarization / Compression datasets
- "compression": shuffled_load("sentence-transformers/sentence-compression", split="train"),
+ "compression": shuffled_load(
+ "sentence-transformers/sentence-compression", split="train"
+ ),
"simple-wiki": shuffled_load("sentence-transformers/simple-wiki", split="train"),
-
# News datasets
"agnews": shuffled_load("sentence-transformers/agnews", split="train"),
"ccnews": shuffled_load("sentence-transformers/ccnews", split="train"),
"npr": shuffled_load("sentence-transformers/npr", split="train"),
-
# Question Answering (QA) datasets
"gooaq": shuffled_load("sentence-transformers/gooaq", split="train"),
- "yahoo-answers": shuffled_load("sentence-transformers/yahoo-answers", "title-question-answer-pair", split="train"),
+ "yahoo-answers": shuffled_load(
+ "sentence-transformers/yahoo-answers", "title-question-answer-pair", split="train"
+ ),
"eli5": shuffled_load("sentence-transformers/eli5", split="train"),
"amazon-qa": shuffled_load("sentence-transformers/amazon-qa", split="train[0:1000000]"),
"squad": shuffled_load("sentence-transformers/squad", split="train"),
- "natural_questions": shuffled_load("sentence-transformers/natural-questions", split="train"),
+ "natural_questions": shuffled_load(
+ "sentence-transformers/natural-questions", split="train"
+ ),
"hotpotqa": shuffled_load("sentence-transformers/hotpotqa", "triplet", split="train"),
-
# Duplicate Detection datasets
- "quora_duplicates": shuffled_load("sentence-transformers/quora-duplicates", "pair", split="train"),
- "quora_triplets": shuffled_load("sentence-transformers/quora-duplicates", "triplet", split="train"),
-
+ "quora_duplicates": shuffled_load(
+ "sentence-transformers/quora-duplicates", "pair", split="train"
+ ),
+ "quora_triplets": shuffled_load(
+ "sentence-transformers/quora-duplicates", "triplet", split="train"
+ ),
# Scientific / Academic datasets
"specter": shuffled_load("sentence-transformers/specter", "triplet", split="train"),
-
# Stack Exchange datasets
- "stackexchange_bbp": shuffled_load("sentence-transformers/stackexchange-duplicates", "body-body-pair", split="train"),
- "stackexchange_ttp": shuffled_load("sentence-transformers/stackexchange-duplicates", "title-title-pair", split="train"),
- "stackexchange_ppp": shuffled_load("sentence-transformers/stackexchange-duplicates", "post-post-pair", split="train"),
-
+ "stackexchange_bbp": shuffled_load(
+ "sentence-transformers/stackexchange-duplicates", "body-body-pair", split="train"
+ ),
+ "stackexchange_ttp": shuffled_load(
+ "sentence-transformers/stackexchange-duplicates", "title-title-pair", split="train"
+ ),
+ "stackexchange_ppp": shuffled_load(
+ "sentence-transformers/stackexchange-duplicates", "post-post-pair", split="train"
+ ),
# Lexical / Linguistic datasets
"altlex": shuffled_load("sentence-transformers/altlex", split="train"),
},
diff --git a/eval_mteb.py b/eval_mteb.py
index 13907b6..1891e32 100644
--- a/eval_mteb.py
+++ b/eval_mteb.py
@@ -1,5 +1,6 @@
# ruff: noqa: E402
from __future__ import annotations
+
import os
# Set environment variables
@@ -8,17 +9,17 @@
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1"
-import mteb
import logging
-
from functools import partial
-from typing import Any, List
-from wordllama import load_training, Config
+from typing import Any
+
+import mteb
import numpy as np
from more_itertools import chunked
-
from mteb.model_meta import ModelMeta
+from wordllama import Config, load_training
+
logger = logging.getLogger(__name__)
@@ -109,23 +110,18 @@
class WordLlamaWrapper:
- def __init__(
- self, model_name: str, config, embed_dim: int | None = None, **kwargs
- ) -> None:
+ def __init__(self, model_name: str, config, embed_dim: int | None = None, **kwargs) -> None:
self._model_name = model_name
self._embed_dim = embed_dim
print(model_name)
self.model = load_training(model_name, config, dims=embed_dim).to("cuda")
- def encode(self, sentences: List[str], batch_size=512, **kwargs: Any) -> np.ndarray:
+ def encode(self, sentences: list[str], batch_size=512, **kwargs: Any) -> np.ndarray:
all_embeddings = []
for chunk in chunked(sentences, batch_size):
embed_chunk = (
- self.model.embed(chunk, return_pt=True, norm=True)
- .to("cpu")
- .detach()
- .numpy()
+ self.model.embed(chunk, return_pt=True, norm=True).to("cpu").detach().numpy()
)
all_embeddings.append(embed_chunk)
@@ -135,7 +131,6 @@ def encode(self, sentences: List[str], batch_size=512, **kwargs: Any) -> np.ndar
if __name__ == "__main__":
-
TASK_LIST = (
TASK_LIST_CLASSIFICATION
+ TASK_LIST_CLUSTERING
diff --git a/find_mteb.sh b/find_mteb.sh
index 9224659..b4bc79a 100755
--- a/find_mteb.sh
+++ b/find_mteb.sh
@@ -43,12 +43,12 @@ while IFS= read -r line; do
filename=$(echo "$line" | cut -d' ' -f1)
score=$(echo "$line" | cut -d' ' -f2)
subset=$(echo "$line" | cut -d' ' -f3)
-
+
echo "$filename $score $subset"
-
+
# Remove .json extension for matching
task_name=${filename%.json}
-
+
for category in "${!TASK_LISTS[@]}"; do
if [[ "${TASK_LISTS[$category]}" == *"$task_name"* ]]; then
eval "${category}_scores+=(${score})"
@@ -63,8 +63,8 @@ while IFS= read -r line; do
fi
done
done < <(find "$search_dir" -type f -name "*.json" -print0 | xargs -0 -I {} sh -c '
-jq -r ".. | objects | select((.hf_subset? == \"en\" or .hf_subset? == \"default\") and .main_score?) | \"\(.main_score) \(.hf_subset)\"" {} |
-while read score subset; do
+jq -r ".. | objects | select((.hf_subset? == \"en\" or .hf_subset? == \"default\") and .main_score?) | \"\(.main_score) \(.hf_subset)\"" {} |
+while read score subset; do
formatted_score=$(printf "%.2f" $(echo "$score * 100" | bc))
echo "$(basename {}) $formatted_score $subset"
done
@@ -86,7 +86,7 @@ print_results() {
local scores=("${!1}")
local files=("${!2}")
local name=$3
-
+
if [ ${#scores[@]} -ne 0 ]; then
average=$(calculate_average "$1")
echo ""
diff --git a/pyproject.toml b/pyproject.toml
index 0a7ca4e..a42b27d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ dynamic = ["version"]
description = "WordLlama NLP Utility"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }
-requires-python = ">=3.8"
+requires-python = ">=3.9"
authors = [{ name = "Lee Miller", email = "dleemiller@gmail.com" }]
dependencies = [
"numpy>=2",
@@ -20,6 +20,17 @@ dependencies = [
]
[project.optional-dependencies]
+dev = [
+ "setuptools",
+ "wheel",
+ "setuptools_scm[toml]",
+ "Cython",
+ "pytest>=7.4",
+ "pytest-cov>=4.1",
+ "pytest-xdist>=3.3",
+ "ruff>=0.1.0",
+ "pre-commit>=3.5",
+]
train = [
"accelerate",
"torch>=2",
@@ -37,7 +48,7 @@ Repository = "https://github.com/dleemiller/WordLlama"
include-package-data = true # Ensures non-code files are included in the package
[tool.setuptools.packages.find]
-where = ["."]
+where = ["src"]
include = ["wordllama*"]
[tool.setuptools.package-data]
@@ -57,7 +68,7 @@ wordllama = [
classifiers = { file = "classifiers.txt" }
[tool.setuptools_scm]
-write_to = "wordllama/_version.py"
+write_to = "src/wordllama/_version.py"
version_scheme = "post-release"
local_scheme = "no-local-version"
@@ -74,7 +85,7 @@ print(similarity_score);"
[tool.cibuildwheel.macos]
# before-all = """
# brew install openblas libomp llvm
-#
+#
# # Create a NumPy site.cfg so that it detects your Homebrew-installed OpenBLAS
# cat > ~/.numpy-site.cfg <= -t) & (x < -interval)
tmp[mask2] = (1 / o) * torch.pow(-x[mask2], (1 - o) / o)
tmp[(x <= interval) & (x >= 0)] = approximate_function(interval, o) / interval
- tmp[(x <= 0) & (x >= -interval)] = (
- -approximate_function(-interval, o) / interval
- )
+ tmp[(x <= 0) & (x >= -interval)] = -approximate_function(-interval, o) / interval
# calculate the final gradient
grad_x = tmp * grad_output.clone()
diff --git a/wordllama/adapters/mlp.py b/src/wordllama/adapters/mlp.py
similarity index 100%
rename from wordllama/adapters/mlp.py
rename to src/wordllama/adapters/mlp.py
index bb9794e..94497e2 100644
--- a/wordllama/adapters/mlp.py
+++ b/src/wordllama/adapters/mlp.py
@@ -1,7 +1,7 @@
import os
-from torch import nn
import safetensors.torch as st
+from torch import nn
class MLP(nn.Module):
diff --git a/wordllama/adapters/projector.py b/src/wordllama/adapters/projector.py
similarity index 99%
rename from wordllama/adapters/projector.py
rename to src/wordllama/adapters/projector.py
index 502e701..821e550 100644
--- a/wordllama/adapters/projector.py
+++ b/src/wordllama/adapters/projector.py
@@ -1,6 +1,7 @@
import os
-from torch import nn
+
import safetensors.torch as st
+from torch import nn
class Projector(nn.Module):
diff --git a/wordllama/adapters/weighted_mlp.py b/src/wordllama/adapters/weighted_mlp.py
similarity index 91%
rename from wordllama/adapters/weighted_mlp.py
rename to src/wordllama/adapters/weighted_mlp.py
index 5fcbee3..4c3e621 100644
--- a/wordllama/adapters/weighted_mlp.py
+++ b/src/wordllama/adapters/weighted_mlp.py
@@ -1,10 +1,10 @@
import os
-from nltk.corpus import stopwords
+import safetensors.torch as st
import torch
-from torch import nn
import torch.nn.functional as F
-import safetensors.torch as st
+from nltk.corpus import stopwords
+from torch import nn
class WeightedMLP(nn.Module):
@@ -22,9 +22,7 @@ def __init__(self, in_dim, out_dim, hidden_dim, tokenizer, dropout=0.1):
n_vocab = len(tokenizer.vocab)
self.weights = nn.Parameter(torch.ones(n_vocab))
stopword_list = set(stopwords.words("english"))
- stopword_ids = [
- tokenizer.vocab[word] for word in stopword_list if word in tokenizer.vocab
- ]
+ stopword_ids = [tokenizer.vocab[word] for word in stopword_list if word in tokenizer.vocab]
with torch.no_grad():
for stopword_id in stopword_ids:
self.weights[stopword_id] = 0.1
@@ -34,9 +32,7 @@ def forward(self, tensors) -> dict:
weights = F.gelu(
self.weights[token_ids]
) # use gelu to limit negative contribution of weights
- weighted_embeddings = self.mlp(tensors["token_embeddings"]) * weights.unsqueeze(
- -1
- )
+ weighted_embeddings = self.mlp(tensors["token_embeddings"]) * weights.unsqueeze(-1)
tensors.update({"x": weighted_embeddings})
return tensors
diff --git a/wordllama/adapters/weighted_projector.py b/src/wordllama/adapters/weighted_projector.py
similarity index 87%
rename from wordllama/adapters/weighted_projector.py
rename to src/wordllama/adapters/weighted_projector.py
index 879e59e..812700e 100644
--- a/wordllama/adapters/weighted_projector.py
+++ b/src/wordllama/adapters/weighted_projector.py
@@ -1,10 +1,10 @@
import os
-from nltk.corpus import stopwords
+import safetensors.torch as st
import torch
-from torch import nn
import torch.nn.functional as F
-import safetensors.torch as st
+from nltk.corpus import stopwords
+from torch import nn
class WeightedProjector(nn.Module):
@@ -17,9 +17,7 @@ def __init__(self, in_dim, out_dim, tokenizer, n_vocab, key="token_embeddings"):
# Initialize stopword weights to a lower value
self.weights = nn.Parameter(torch.ones(n_vocab))
stopword_list = set(filter(lambda x: len(x) > 1, stopwords.words("english")))
- stopword_ids = [
- tokenizer.vocab[word] for word in stopword_list if word in tokenizer.vocab
- ]
+ stopword_ids = [tokenizer.vocab[word] for word in stopword_list if word in tokenizer.vocab]
print(f"Num stopword ids: {len(stopword_ids)}")
with torch.no_grad():
for stopword_id in stopword_ids:
@@ -30,9 +28,7 @@ def forward(self, tensors) -> dict:
weights = F.gelu(
self.weights[token_ids]
) # use gelu to limit negative contribution of weights
- weighted_embeddings = self.proj(
- tensors["token_embeddings"] * weights.unsqueeze(-1)
- )
+ weighted_embeddings = self.proj(tensors["token_embeddings"] * weights.unsqueeze(-1))
tensors.update({"x": weighted_embeddings})
return tensors
diff --git a/wordllama/algorithms/__init__.py b/src/wordllama/algorithms/__init__.py
similarity index 59%
rename from wordllama/algorithms/__init__.py
rename to src/wordllama/algorithms/__init__.py
index dd772eb..7090bce 100644
--- a/wordllama/algorithms/__init__.py
+++ b/src/wordllama/algorithms/__init__.py
@@ -1,7 +1,7 @@
-from .kmeans import kmeans_clustering
-from .vector_similarity import vector_similarity, binarize_and_packbits
from .deduplicate_helpers import deduplicate_embeddings
-from .splitter import split_sentences, constrained_batches, constrained_coalesce
+from .kmeans import kmeans_clustering
+from .splitter import constrained_batches, constrained_coalesce, split_sentences
+from .vector_similarity import binarize_and_packbits, vector_similarity
__all__ = [
"kmeans_clustering",
@@ -10,5 +10,5 @@
"deduplicate_embeddings",
"split_sentences",
"constrained_batches",
- "constrained_coalesce"
+ "constrained_coalesce",
]
diff --git a/wordllama/algorithms/deduplicate_helpers.pyx b/src/wordllama/algorithms/deduplicate_helpers.pyx
similarity index 100%
rename from wordllama/algorithms/deduplicate_helpers.pyx
rename to src/wordllama/algorithms/deduplicate_helpers.pyx
diff --git a/wordllama/algorithms/find_local_minima.pyx b/src/wordllama/algorithms/find_local_minima.pyx
similarity index 99%
rename from wordllama/algorithms/find_local_minima.pyx
rename to src/wordllama/algorithms/find_local_minima.pyx
index 113b639..676c7ac 100644
--- a/wordllama/algorithms/find_local_minima.pyx
+++ b/src/wordllama/algorithms/find_local_minima.pyx
@@ -87,7 +87,7 @@ cpdef tuple find_local_minima(np.ndarray[DTYPE_t, ndim=1] y, int window_size=11,
# Ensure that y is a float32 array
y_float = np.asarray(y, dtype=np.float32)
-
+
return _find_local_minima_impl(y_float, window_size, poly_order)
cdef tuple _find_local_minima_impl(DTYPE_t[:] y, int window_size, int poly_order):
@@ -110,7 +110,7 @@ cdef tuple _find_local_minima_impl(DTYPE_t[:] y, int window_size, int poly_order
# Precompute Savitzky-Golay coefficients
cdef np.ndarray[DTYPE_t, ndim=2] coeffs = compute_savitzky_golay_coeffs(window_size, poly_order)
-
+
# Apply the filter for the first and second derivatives
cdef np.ndarray[DTYPE_t, ndim=1] dy = apply_savitzky_golay_filter(coeffs, y, deriv=1)
cdef np.ndarray[DTYPE_t, ndim=1] ddy = apply_savitzky_golay_filter(coeffs, y, deriv=2)
@@ -154,11 +154,11 @@ cpdef np.ndarray[DTYPE_t, ndim=1] windowed_cross_similarity(np.ndarray[DTYPE_t,
"""
cdef int n = embeddings.shape[0]
-
+
# Ensure the window size is odd and >= 3
if window_size < 3 or window_size % 2 == 0:
raise ValueError("Window size must be odd and >= 3")
-
+
cdef int half_window = window_size // 2
# Pre-allocate the output array for storing windowed averages
diff --git a/wordllama/algorithms/kmeans.pyx b/src/wordllama/algorithms/kmeans.pyx
similarity index 99%
rename from wordllama/algorithms/kmeans.pyx
rename to src/wordllama/algorithms/kmeans.pyx
index 3a48a2e..9c99630 100644
--- a/wordllama/algorithms/kmeans.pyx
+++ b/src/wordllama/algorithms/kmeans.pyx
@@ -83,7 +83,7 @@ cdef kmeans_plusplus_initialization(float[:, :] embeddings, int k, object random
cumulative_probabilities = np.cumsum(probabilities)
r = random_state.rand()
index = np.searchsorted(cumulative_probabilities, r)
-
+
for j in range(n_features):
centroids_view[i, j] = embeddings[index, j]
@@ -115,10 +115,10 @@ cdef _kmeans_single(np.ndarray[FLOAT_t, ndim=2] X, int k, int min_iterations, in
prev_inertia = inertia
inertia = np.sum(np.min(distances, axis=1) ** 2)
-
+
if iteration >= min_iterations - 1 and abs(prev_inertia - inertia) < tolerance:
break
-
+
centers = new_centers
end = clock()
@@ -149,7 +149,7 @@ def kmeans_clustering(np.ndarray[FLOAT_t, ndim=2] X, int k, int n_init=10, int m
for i in range(n_init):
labels, centers, inertia, n_iterations, time_taken = _kmeans_single(X, k, min_iterations, max_iterations, tolerance, random_state)
logger.info(f"Initialization {i + 1}/{n_init}: Inertia = {inertia:.2f}, Iterations = {n_iterations}, Time = {time_taken:.2f} seconds")
-
+
if inertia < best_inertia:
best_labels = labels
best_centers = centers
diff --git a/wordllama/algorithms/semantic_splitter.py b/src/wordllama/algorithms/semantic_splitter.py
similarity index 84%
rename from wordllama/algorithms/semantic_splitter.py
rename to src/wordllama/algorithms/semantic_splitter.py
index 3490416..c548054 100644
--- a/wordllama/algorithms/semantic_splitter.py
+++ b/src/wordllama/algorithms/semantic_splitter.py
@@ -1,12 +1,13 @@
-import numpy as np
-from typing import List, Tuple, Union
from itertools import chain
+from typing import Union
+
+import numpy as np
+
from .find_local_minima import find_local_minima, windowed_cross_similarity
from .splitter import (
- constrained_batches,
constrained_coalesce,
- split_sentences,
reverse_merge,
+ split_sentences,
)
@@ -19,7 +20,7 @@ class SemanticSplitter:
"""
@staticmethod
- def flatten(nested_list: List[List[any]]) -> List[any]:
+ def flatten(nested_list: list[list[any]]) -> list[any]:
"""
Flatten a list of lists into a single list.
@@ -37,7 +38,7 @@ def constrained_split(
target_size: int,
separator: str = " ",
min_size: int = 24,
- ) -> List[str]:
+ ) -> list[str]:
"""
Split text into chunks of approximately target_size.
@@ -64,7 +65,7 @@ def split(
target_size: int,
cleanup_size: int = 24,
intermediate_size: int = 96,
- ) -> List[str]:
+ ) -> list[str]:
"""
Split the input text into chunks based on semantic coherence.
@@ -77,21 +78,17 @@ def split(
Returns:
List[str]: List of text chunks.
"""
- assert (
- target_size > intermediate_size
- ), "Target size must be larger than intermediate size."
- assert (
- intermediate_size > cleanup_size
- ), "Intermediate size must be larger than cleanup size."
+ assert target_size > intermediate_size, "Target size must be larger than intermediate size."
+ assert intermediate_size > cleanup_size, (
+ "Intermediate size must be larger than cleanup size."
+ )
lines = text.splitlines()
lines = constrained_coalesce(lines, intermediate_size, separator="\n")
lines = reverse_merge(lines, n=cleanup_size, separator="\n")
chunks = [
- cls.constrained_split(
- line, target_size, min_size=cleanup_size, separator=" "
- )
+ cls.constrained_split(line, target_size, min_size=cleanup_size, separator=" ")
if len(line) > target_size
else [line]
for line in lines
@@ -103,7 +100,7 @@ def split(
@classmethod
def reconstruct(
cls,
- lines: List[str],
+ lines: list[str],
norm_embed: np.ndarray,
target_size: int,
window_size: int,
@@ -111,7 +108,7 @@ def reconstruct(
savgol_window: int,
max_score_pct: float = 0.4,
return_minima: bool = False,
- ) -> Union[List[str], Tuple[np.ndarray, np.ndarray, np.ndarray]]:
+ ) -> Union[list[str], tuple[np.ndarray, np.ndarray, np.ndarray]]:
"""
Reconstruct text chunks based on semantic similarity.
@@ -133,14 +130,10 @@ def reconstruct(
Raises:
AssertionError: If the number of texts doesn't equal the number of embeddings.
"""
- assert (
- len(lines) == norm_embed.shape[0]
- ), "Number of texts must equal number of embeddings"
+ assert len(lines) == norm_embed.shape[0], "Number of texts must equal number of embeddings"
sim_avg = windowed_cross_similarity(norm_embed, window_size)
- roots, y = find_local_minima(
- sim_avg, poly_order=poly_order, window_size=savgol_window
- )
+ roots, y = find_local_minima(sim_avg, poly_order=poly_order, window_size=savgol_window)
if return_minima:
return roots, y, sim_avg
diff --git a/wordllama/algorithms/splitter.pyx b/src/wordllama/algorithms/splitter.pyx
similarity index 99%
rename from wordllama/algorithms/splitter.pyx
rename to src/wordllama/algorithms/splitter.pyx
index 7d33109..5316da4 100644
--- a/wordllama/algorithms/splitter.pyx
+++ b/src/wordllama/algorithms/splitter.pyx
@@ -69,7 +69,7 @@ def split_sentences(str text not None, set punct_chars=None):
cdef cset[Py_UCS4] punct_chars_c
if punct_chars is None:
- punct_chars = {'.', '!', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹', '।', '॥', '၊', '။', '።', '፧', '፨',
+ punct_chars = {'.', '!', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹', '।', '॥', '၊', '။', '።', '፧', '፨',
'᙮', '᜵', '᜶', '᠃', '᠉', '᥄', '᥅', '᪨', '᪩', '᪪', '᪫', '᭚', '᭛', '᭞', '᭟',
'᰻', '᰼', '᱾', '᱿', '‼', '‽', '⁇', '⁈', '⁉', '⸮', '⸼', '꓿', '꘎', '꘏', '꛳',
'꛷', '꡶', '꡷', '꣎', '꣏', '꤯', '꧈', '꧉', '꩝', '꩞', '꩟', '꫰', '꫱', '꯫', '﹒',
@@ -131,7 +131,7 @@ cdef tuple _combine_pass(list iterable, Py_ssize_t max_size, str separator):
"""
A single pass to combine successive pairs of items if their combined size doesn't
exceed max_size. The separator used for joining pairs can be configured.
-
+
Returns:
tuple: A new list of coalesced items and the number of changes made.
"""
diff --git a/wordllama/algorithms/vector_similarity.pxd b/src/wordllama/algorithms/vector_similarity.pxd
similarity index 100%
rename from wordllama/algorithms/vector_similarity.pxd
rename to src/wordllama/algorithms/vector_similarity.pxd
diff --git a/wordllama/algorithms/vector_similarity.pyx b/src/wordllama/algorithms/vector_similarity.pyx
similarity index 99%
rename from wordllama/algorithms/vector_similarity.pyx
rename to src/wordllama/algorithms/vector_similarity.pyx
index 232dfea..1572135 100644
--- a/wordllama/algorithms/vector_similarity.pyx
+++ b/src/wordllama/algorithms/vector_similarity.pyx
@@ -95,7 +95,7 @@ cpdef object vector_similarity(
max_distance = a_binary.shape[1] * 64
if max_distance == 0:
raise ValueError("Binary vectors must have at least one bit")
-
+
# convert to similarity
similarity = 1.0 - 2.0 * (dist / max_distance).astype(np.float32)
diff --git a/wordllama/config/__init__.py b/src/wordllama/config/__init__.py
similarity index 89%
rename from wordllama/config/__init__.py
rename to src/wordllama/config/__init__.py
index a2200ec..cae909b 100644
--- a/wordllama/config/__init__.py
+++ b/src/wordllama/config/__init__.py
@@ -1,9 +1,10 @@
-import toml
from pathlib import Path
+from typing import Optional
+
+import toml
from pydantic import BaseModel
-from typing import List, Dict, Optional
-from .models import ModelURI, WordLlamaModels, Model2VecModels
+from .models import Model2VecModels, ModelURI, WordLlamaModels
class TokenizerInferenceConfig(BaseModel):
@@ -37,7 +38,7 @@ class TrainingConfig(BaseModel):
class MatryoshkaConfig(BaseModel):
- dims: List[int]
+ dims: list[int]
class WordLlamaModel(BaseModel):
@@ -57,7 +58,7 @@ class WordLlamaConfig(BaseModel):
class Config:
- _configurations: Dict[str, WordLlamaConfig] = {}
+ _configurations: dict[str, WordLlamaConfig] = {}
@classmethod
def setup(cls):
@@ -67,7 +68,7 @@ def setup(cls):
setattr(cls, config_name, config) # Set as class attributes for easy access
@staticmethod
- def load_configurations() -> Dict[str, WordLlamaConfig]:
+ def load_configurations() -> dict[str, WordLlamaConfig]:
"""Load configurations from TOML files within the same directory as this script."""
config_dir = Path(__file__).parent / "train"
configs = {}
diff --git a/wordllama/config/models.py b/src/wordllama/config/models.py
similarity index 90%
rename from wordllama/config/models.py
rename to src/wordllama/config/models.py
index d08d066..1d9551e 100644
--- a/wordllama/config/models.py
+++ b/src/wordllama/config/models.py
@@ -1,12 +1,12 @@
from dataclasses import dataclass
-from typing import List, Optional
+from typing import Optional
@dataclass
class ModelURI:
repo_id: str
- available_dims: List[int]
- binary_dims: List[int]
+ available_dims: list[int]
+ binary_dims: list[int]
tokenizer_config: Optional[str]
remote_filename: Optional[str] = None
remote_tokenizer_filename: Optional[str] = None
@@ -15,18 +15,15 @@ class ModelURI:
class WordLlamaModels:
-
@classmethod
- def list_configs(cls) -> List[str]:
+ def list_configs(cls) -> list[str]:
"""
Return a list of configuration names defined as `ModelURI` instances in the class.
Returns:
List[str]: A list of configuration attribute names.
"""
- return [
- name for name, value in cls.__dict__.items() if isinstance(value, ModelURI)
- ]
+ return [name for name, value in cls.__dict__.items() if isinstance(value, ModelURI)]
l2_supercat = ModelURI(
repo_id="dleemiller/word-llama-l2-supercat",
@@ -46,18 +43,15 @@ def list_configs(cls) -> List[str]:
class Model2VecModels:
-
@classmethod
- def list_configs(cls) -> List[str]:
+ def list_configs(cls) -> list[str]:
"""
Return a list of configuration names defined as `ModelURI` instances in the class.
Returns:
List[str]: A list of configuration attribute names.
"""
- return [
- name for name, value in cls.__dict__.items() if isinstance(value, ModelURI)
- ]
+ return [name for name, value in cls.__dict__.items() if isinstance(value, ModelURI)]
potion_base_8m = ModelURI(
repo_id="minishlab/potion-base-8M",
diff --git a/wordllama/config/train/command_rplus.toml b/src/wordllama/config/train/command_rplus.toml
similarity index 99%
rename from wordllama/config/train/command_rplus.toml
rename to src/wordllama/config/train/command_rplus.toml
index 9f21018..42399cf 100644
--- a/wordllama/config/train/command_rplus.toml
+++ b/src/wordllama/config/train/command_rplus.toml
@@ -28,4 +28,3 @@ binarizer_ste = "tanh"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/dbrx.toml b/src/wordllama/config/train/dbrx.toml
similarity index 99%
rename from wordllama/config/train/dbrx.toml
rename to src/wordllama/config/train/dbrx.toml
index f5d9e16..322322c 100644
--- a/wordllama/config/train/dbrx.toml
+++ b/src/wordllama/config/train/dbrx.toml
@@ -28,4 +28,3 @@ binarizer_ste = "ste"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/deberta_v3_large.toml b/src/wordllama/config/train/deberta_v3_large.toml
similarity index 95%
rename from wordllama/config/train/deberta_v3_large.toml
rename to src/wordllama/config/train/deberta_v3_large.toml
index bbd9e1d..970d946 100644
--- a/wordllama/config/train/deberta_v3_large.toml
+++ b/src/wordllama/config/train/deberta_v3_large.toml
@@ -3,7 +3,7 @@ dim = 1024
n_vocab = 128100
hf_model_id = "microsoft/deberta-v3-large"
is_encoder = true
-pad_token = "<|end_of_text|>"
+pad_token = "<|end_of_text|>"
[tokenizer]
return_tensors = "pt"
@@ -29,4 +29,3 @@ binarizer_ste = "tanh"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/deepseekv2.toml b/src/wordllama/config/train/deepseekv2.toml
similarity index 99%
rename from wordllama/config/train/deepseekv2.toml
rename to src/wordllama/config/train/deepseekv2.toml
index 07116b9..b3036b1 100644
--- a/wordllama/config/train/deepseekv2.toml
+++ b/src/wordllama/config/train/deepseekv2.toml
@@ -28,4 +28,3 @@ binarizer_ste = "ste"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/gemma2_27B.toml b/src/wordllama/config/train/gemma2_27B.toml
similarity index 99%
rename from wordllama/config/train/gemma2_27B.toml
rename to src/wordllama/config/train/gemma2_27B.toml
index 69aa380..66dbcb9 100644
--- a/wordllama/config/train/gemma2_27B.toml
+++ b/src/wordllama/config/train/gemma2_27B.toml
@@ -28,4 +28,3 @@ binarizer_ste = "tanh"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/l2_supercat.toml b/src/wordllama/config/train/l2_supercat.toml
similarity index 99%
rename from wordllama/config/train/l2_supercat.toml
rename to src/wordllama/config/train/l2_supercat.toml
index 0ac4f88..bd2b14d 100644
--- a/wordllama/config/train/l2_supercat.toml
+++ b/src/wordllama/config/train/l2_supercat.toml
@@ -32,4 +32,3 @@ binarizer_ste = "ste"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/l2p3.toml b/src/wordllama/config/train/l2p3.toml
similarity index 96%
rename from wordllama/config/train/l2p3.toml
rename to src/wordllama/config/train/l2p3.toml
index 28a8eab..dd1adea 100644
--- a/wordllama/config/train/l2p3.toml
+++ b/src/wordllama/config/train/l2p3.toml
@@ -2,7 +2,7 @@
dim = 13312
n_vocab = 32000
hf_model_id = "meta-llama/Llama-2-70b-hf"
-pad_token = ""
+pad_token = ""
[tokenizer]
return_tensors = "pt"
@@ -28,4 +28,3 @@ binarizer_ste = "ste"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/l2p3_lg.toml b/src/wordllama/config/train/l2p3_lg.toml
similarity index 96%
rename from wordllama/config/train/l2p3_lg.toml
rename to src/wordllama/config/train/l2p3_lg.toml
index f6c2f32..a814346 100644
--- a/wordllama/config/train/l2p3_lg.toml
+++ b/src/wordllama/config/train/l2p3_lg.toml
@@ -2,7 +2,7 @@
dim = 17408
n_vocab = 32000
hf_model_id = "meta-llama/Llama-2-70b-hf"
-pad_token = ""
+pad_token = ""
[tokenizer]
return_tensors = "pt"
@@ -28,4 +28,3 @@ binarizer_ste = "ste"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/l3_supercat.toml b/src/wordllama/config/train/l3_supercat.toml
similarity index 95%
rename from wordllama/config/train/l3_supercat.toml
rename to src/wordllama/config/train/l3_supercat.toml
index 0fe7c82..8f2df63 100644
--- a/wordllama/config/train/l3_supercat.toml
+++ b/src/wordllama/config/train/l3_supercat.toml
@@ -2,7 +2,7 @@
dim = 28672
n_vocab = 128256
hf_model_id = "meta-llama/Meta-Llama-3.1-405B"
-pad_token = "<|end_of_text|>"
+pad_token = "<|end_of_text|>"
[tokenizer]
return_tensors = "pt"
@@ -32,4 +32,3 @@ binarizer_ste = "ste"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/llama2_70B.toml b/src/wordllama/config/train/llama2_70B.toml
similarity index 96%
rename from wordllama/config/train/llama2_70B.toml
rename to src/wordllama/config/train/llama2_70B.toml
index 0f9fc46..af8e9f6 100644
--- a/wordllama/config/train/llama2_70B.toml
+++ b/src/wordllama/config/train/llama2_70B.toml
@@ -2,7 +2,7 @@
dim = 8192
n_vocab = 32000
hf_model_id = "meta-llama/Llama-2-70b-hf"
-pad_token = ""
+pad_token = ""
[tokenizer]
return_tensors = "pt"
@@ -28,4 +28,3 @@ binarizer_ste = "ste"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/llama3_70B.toml b/src/wordllama/config/train/llama3_70B.toml
similarity index 94%
rename from wordllama/config/train/llama3_70B.toml
rename to src/wordllama/config/train/llama3_70B.toml
index 0b31c7b..fe67067 100644
--- a/wordllama/config/train/llama3_70B.toml
+++ b/src/wordllama/config/train/llama3_70B.toml
@@ -2,7 +2,7 @@
dim = 8192
n_vocab = 128256
hf_model_id = "meta-llama/Meta-Llama-3-70B"
-pad_token = "<|end_of_text|>"
+pad_token = "<|end_of_text|>"
[tokenizer]
return_tensors = "pt"
@@ -28,4 +28,3 @@ binarizer_ste = "ste"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/llama3_8B.toml b/src/wordllama/config/train/llama3_8B.toml
similarity index 99%
rename from wordllama/config/train/llama3_8B.toml
rename to src/wordllama/config/train/llama3_8B.toml
index 0ce7baa..88f42f0 100644
--- a/wordllama/config/train/llama3_8B.toml
+++ b/src/wordllama/config/train/llama3_8B.toml
@@ -28,4 +28,3 @@ binarizer_ste = "tanh"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/llamaguard.toml b/src/wordllama/config/train/llamaguard.toml
similarity index 96%
rename from wordllama/config/train/llamaguard.toml
rename to src/wordllama/config/train/llamaguard.toml
index 04dabba..dcd89a8 100644
--- a/wordllama/config/train/llamaguard.toml
+++ b/src/wordllama/config/train/llamaguard.toml
@@ -2,7 +2,7 @@
dim = 4096
n_vocab = 32000
hf_model_id = "meta-llama/LlamaGuard-7b"
-pad_token = ""
+pad_token = ""
[tokenizer]
return_tensors = "pt"
@@ -28,4 +28,3 @@ binarizer_ste = "ste"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/miqu.toml b/src/wordllama/config/train/miqu.toml
similarity index 96%
rename from wordllama/config/train/miqu.toml
rename to src/wordllama/config/train/miqu.toml
index 3858e06..46ae966 100644
--- a/wordllama/config/train/miqu.toml
+++ b/src/wordllama/config/train/miqu.toml
@@ -2,7 +2,7 @@
dim = 8192
n_vocab = 32000
hf_model_id = "152334H/miqu-1-70b-sf"
-pad_token = ""
+pad_token = ""
[tokenizer]
return_tensors = "pt"
@@ -28,4 +28,3 @@ binarizer_ste = "ste"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/mixtral.toml b/src/wordllama/config/train/mixtral.toml
similarity index 99%
rename from wordllama/config/train/mixtral.toml
rename to src/wordllama/config/train/mixtral.toml
index 52c1084..1591485 100644
--- a/wordllama/config/train/mixtral.toml
+++ b/src/wordllama/config/train/mixtral.toml
@@ -28,4 +28,3 @@ binarizer_ste = "tanh"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/mixtral_8x22B.toml b/src/wordllama/config/train/mixtral_8x22B.toml
similarity index 99%
rename from wordllama/config/train/mixtral_8x22B.toml
rename to src/wordllama/config/train/mixtral_8x22B.toml
index f55f278..61368e0 100644
--- a/wordllama/config/train/mixtral_8x22B.toml
+++ b/src/wordllama/config/train/mixtral_8x22B.toml
@@ -28,4 +28,3 @@ binarizer_ste = "tanh"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/openelm_3B.toml b/src/wordllama/config/train/openelm_3B.toml
similarity index 96%
rename from wordllama/config/train/openelm_3B.toml
rename to src/wordllama/config/train/openelm_3B.toml
index eef1f75..632972f 100644
--- a/wordllama/config/train/openelm_3B.toml
+++ b/src/wordllama/config/train/openelm_3B.toml
@@ -2,7 +2,7 @@
dim = 3072
n_vocab = 32000
hf_model_id = "apple/OpenELM-3B"
-pad_token = ""
+pad_token = ""
[tokenizer]
return_tensors = "pt"
@@ -28,4 +28,3 @@ binarizer_ste = "ste"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/phi3_medium.toml b/src/wordllama/config/train/phi3_medium.toml
similarity index 95%
rename from wordllama/config/train/phi3_medium.toml
rename to src/wordllama/config/train/phi3_medium.toml
index d98b3cd..81987af 100644
--- a/wordllama/config/train/phi3_medium.toml
+++ b/src/wordllama/config/train/phi3_medium.toml
@@ -2,7 +2,7 @@
dim = 5120
n_vocab = 32064
hf_model_id = "microsoft/Phi-3-medium-4k-instruct"
-pad_token = "<|endoftext|>"
+pad_token = "<|endoftext|>"
[tokenizer]
return_tensors = "pt"
@@ -28,4 +28,3 @@ binarizer_ste = "ste"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/qwen2_72B.toml b/src/wordllama/config/train/qwen2_72B.toml
similarity index 95%
rename from wordllama/config/train/qwen2_72B.toml
rename to src/wordllama/config/train/qwen2_72B.toml
index f61f246..db7904d 100644
--- a/wordllama/config/train/qwen2_72B.toml
+++ b/src/wordllama/config/train/qwen2_72B.toml
@@ -2,7 +2,7 @@
dim = 8192
n_vocab = 152064
hf_model_id = "Qwen/Qwen2-72B"
-pad_token = "<|endoftext|>"
+pad_token = "<|endoftext|>"
[tokenizer]
return_tensors = "pt"
@@ -28,4 +28,3 @@ binarizer_ste = "ste"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/config/train/yi_1v5_34B.toml b/src/wordllama/config/train/yi_1v5_34B.toml
similarity index 99%
rename from wordllama/config/train/yi_1v5_34B.toml
rename to src/wordllama/config/train/yi_1v5_34B.toml
index cfe18c7..d4ee9fe 100644
--- a/wordllama/config/train/yi_1v5_34B.toml
+++ b/src/wordllama/config/train/yi_1v5_34B.toml
@@ -28,4 +28,3 @@ binarizer_ste = "tanh"
[matryoshka]
dims = [1024, 512, 256, 128, 64]
-
diff --git a/wordllama/embedding/__init__.py b/src/wordllama/embedding/__init__.py
similarity index 100%
rename from wordllama/embedding/__init__.py
rename to src/wordllama/embedding/__init__.py
diff --git a/wordllama/embedding/word_llama_embedding.py b/src/wordllama/embedding/word_llama_embedding.py
similarity index 87%
rename from wordllama/embedding/word_llama_embedding.py
rename to src/wordllama/embedding/word_llama_embedding.py
index 9873bb4..781358a 100644
--- a/wordllama/embedding/word_llama_embedding.py
+++ b/src/wordllama/embedding/word_llama_embedding.py
@@ -1,13 +1,13 @@
+import warnings
+from typing import Union
+
import numpy as np
+import safetensors.torch
import torch
from torch import nn
-import safetensors.torch
-
from transformers import AutoTokenizer
-from typing import Union, List, Dict
from ..adapters import AvgPool
-import warnings
class WordLlamaEmbedding(nn.Module):
@@ -24,9 +24,7 @@ def __init__(self, config, tokenizer_kwargs=None, dims=None):
super().__init__()
self.config = config
model = config.model
- self.embedding = nn.Embedding(
- model.n_vocab, model.dim if dims is None else dims
- )
+ self.embedding = nn.Embedding(model.n_vocab, model.dim if dims is None else dims)
if tokenizer_kwargs:
self.tokenizer_kwargs = tokenizer_kwargs
@@ -44,7 +42,7 @@ def build(cls, filepath, config, dims=None):
safetensors.torch.load_model(word_llama, filepath)
return word_llama
- def forward(self, tensors: Dict[str, torch.Tensor]):
+ def forward(self, tensors: dict[str, torch.Tensor]):
return {
"token_ids": tensors["input_ids"],
"token_embeddings": self.embedding(tensors["input_ids"]),
@@ -61,7 +59,7 @@ def tokenize(self, *args, **kwargs):
@torch.inference_mode()
def embed(
self,
- texts: Union[str, List[str]],
+ texts: Union[str, list[str]],
norm: bool = False,
binarize: bool = False,
pack: bool = True,
@@ -82,14 +80,14 @@ def embed(
# Clamp out-of-bounds input_ids
if tensors["input_ids"].max() >= self.embedding.num_embeddings:
- warnings.warn("Some input_ids are out of bounds. Clamping to valid range.")
- tensors["input_ids"] = tensors["input_ids"].clamp(
- 0, self.embedding.num_embeddings - 1
+ warnings.warn(
+ "Some input_ids are out of bounds. Clamping to valid range.", stacklevel=2
)
+ tensors["input_ids"] = tensors["input_ids"].clamp(0, self.embedding.num_embeddings - 1)
# Check for NaNs in input_ids and replace with 0
if torch.isnan(tensors["input_ids"]).any():
- warnings.warn("NaN values found in input_ids. Replacing NaNs with 0.")
+ warnings.warn("NaN values found in input_ids. Replacing NaNs with 0.", stacklevel=2)
tensors["input_ids"][torch.isnan(tensors["input_ids"])] = 0
# Ensure at least one non-zero value in the attention mask
@@ -97,7 +95,8 @@ def embed(
attention_mask = tensors["attention_mask"]
if (attention_mask.sum(dim=1) == 0).any():
warnings.warn(
- "Some attention masks are all zeros. Setting the first token to 1 for these cases."
+ "Some attention masks are all zeros. Setting the first token to 1 for these cases.",
+ stacklevel=2,
)
attention_mask[attention_mask.sum(dim=1) == 0, 0] = 1
diff --git a/wordllama/extract/__init__.py b/src/wordllama/extract/__init__.py
similarity index 100%
rename from wordllama/extract/__init__.py
rename to src/wordllama/extract/__init__.py
index e3f80eb..00db1b7 100644
--- a/wordllama/extract/__init__.py
+++ b/src/wordllama/extract/__init__.py
@@ -1,4 +1,4 @@
-from .extract_llama_70B import extract_llama_70B
from .extract_hf import extract_from_hf
+from .extract_llama_70B import extract_llama_70B
__all__ = ["extract_from_hf", "extract_llama_70B"]
diff --git a/wordllama/extract/extract_hf.py b/src/wordllama/extract/extract_hf.py
similarity index 93%
rename from wordllama/extract/extract_hf.py
rename to src/wordllama/extract/extract_hf.py
index d882263..2848157 100644
--- a/wordllama/extract/extract_hf.py
+++ b/src/wordllama/extract/extract_hf.py
@@ -1,5 +1,5 @@
import safetensors.torch
-from transformers import AutoModelForCausalLM, AutoModel
+from transformers import AutoModel, AutoModelForCausalLM
from ..config import WordLlamaConfig
from ..embedding.word_llama_embedding import WordLlamaEmbedding
diff --git a/wordllama/extract/extract_llama_70B.py b/src/wordllama/extract/extract_llama_70B.py
similarity index 99%
rename from wordllama/extract/extract_llama_70B.py
rename to src/wordllama/extract/extract_llama_70B.py
index 76f4597..329d38c 100644
--- a/wordllama/extract/extract_llama_70B.py
+++ b/src/wordllama/extract/extract_llama_70B.py
@@ -1,4 +1,5 @@
import os
+
import safetensors.torch
from ..config import Config
diff --git a/wordllama/extract/extract_safetensors.py b/src/wordllama/extract/extract_safetensors.py
similarity index 100%
rename from wordllama/extract/extract_safetensors.py
rename to src/wordllama/extract/extract_safetensors.py
diff --git a/wordllama/inference.py b/src/wordllama/inference.py
similarity index 89%
rename from wordllama/inference.py
rename to src/wordllama/inference.py
index f5159ac..8f51f2b 100644
--- a/wordllama/inference.py
+++ b/src/wordllama/inference.py
@@ -1,16 +1,16 @@
+import logging
+from typing import Callable, Optional, Union
+
import numpy as np
from tokenizers import Tokenizer
-from typing import Callable, List, Optional, Tuple, Union
-import logging
from .algorithms import (
- kmeans_clustering,
- vector_similarity,
binarize_and_packbits,
deduplicate_embeddings,
+ kmeans_clustering,
+ vector_similarity,
)
from .algorithms.semantic_splitter import SemanticSplitter
-from .config import WordLlamaConfig
from .mode_decorators import dense_only
# Set up logging
@@ -41,7 +41,7 @@ def __init__(
self.tokenizer.enable_padding()
self.tokenizer.no_truncation()
- def tokenize(self, texts: Union[str, List[str]]) -> List:
+ def tokenize(self, texts: Union[str, list[str]]) -> list:
"""Tokenize input texts using the configured tokenizer.
Args:
@@ -55,18 +55,16 @@ def tokenize(self, texts: Union[str, List[str]]) -> List:
else:
assert isinstance(texts, list), "Input texts must be str or List[str]"
- return self.tokenizer.encode_batch(
- texts, is_pretokenized=False, add_special_tokens=False
- )
+ return self.tokenizer.encode_batch(texts, is_pretokenized=False, add_special_tokens=False)
def embed(
self,
- texts: Union[str, List[str]],
+ texts: Union[str, list[str]],
norm: bool = False,
return_np: bool = True,
pool_embeddings: bool = True,
batch_size: int = 64,
- ) -> Union[np.ndarray, List]:
+ ) -> Union[np.ndarray, list]:
"""Generate embeddings for input texts with optional normalization and binarization.
Args:
@@ -118,9 +116,7 @@ def embed(
# Normalize embeddings after pooling
if norm:
- batch_embeddings /= np.linalg.norm(
- batch_embeddings, axis=1, keepdims=True
- )
+ batch_embeddings /= np.linalg.norm(batch_embeddings, axis=1, keepdims=True)
# Binarize embeddings
if self.binary:
@@ -197,15 +193,13 @@ def key(self, query: str, norm: bool = True) -> Callable[[str], float]:
def similarity_key(candidate: str) -> float:
candidate_embedding = self.embed(candidate, norm=norm)
- return self.vector_similarity(
- query_embedding[0], candidate_embedding[0]
- ).item()
+ return self.vector_similarity(query_embedding[0], candidate_embedding[0]).item()
return similarity_key
def rank(
- self, query: str, docs: List[str], sort: bool = True, batch_size: int = 64
- ) -> List[Tuple[str, float]]:
+ self, query: str, docs: list[str], sort: bool = True, batch_size: int = 64
+ ) -> list[tuple[str, float]]:
"""Rank documents based on their similarity to a query.
Result may be sorted by similarity score in descending order, or not (see `sort` parameter)
@@ -220,9 +214,8 @@ def rank(
List[Tuple[str, float]]: A list of tuples `(doc, score)`.
"""
assert isinstance(query, str), "Query must be a string"
- assert (
- isinstance(docs, list) and len(docs) > 1
- ), "Docs must be a list of 2 more more strings."
+ assert isinstance(docs, list), "Docs must be a list"
+ assert len(docs) > 1, "Docs must contain 2 or more strings"
query_embedding = self.embed(query)
doc_embeddings = self.embed(docs, batch_size=batch_size)
scores = self.vector_similarity(query_embedding[0], doc_embeddings)
@@ -235,11 +228,11 @@ def rank(
def deduplicate(
self,
- docs: List[str],
+ docs: list[str],
threshold: float = 0.9,
return_indices: bool = False,
batch_size: Optional[int] = None,
- ) -> List[Union[str, int]]:
+ ) -> list[Union[str, int]]:
"""Deduplicate documents based on a similarity threshold.
Args:
@@ -255,20 +248,16 @@ def deduplicate(
if batch_size is None:
batch_size = 500 if self.binary else 5000
- duplicate_indices = deduplicate_embeddings(
- doc_embeddings, threshold, batch_size
- )
+ duplicate_indices = deduplicate_embeddings(doc_embeddings, threshold, batch_size)
if return_indices:
# turn set of numpy int into sorted list of python int
- duplicate_indices = list(map(lambda x: x.item(), duplicate_indices))
+ duplicate_indices = [x.item() for x in duplicate_indices]
return sorted(duplicate_indices)
- unique_docs = [
- doc for idx, doc in enumerate(docs) if idx not in duplicate_indices
- ]
+ unique_docs = [doc for idx, doc in enumerate(docs) if idx not in duplicate_indices]
return unique_docs
- def topk(self, query: str, candidates: List[str], k: int = 3) -> List[str]:
+ def topk(self, query: str, candidates: list[str], k: int = 3) -> list[str]:
"""Retrieve the top-k most similar documents to a query.
Args:
@@ -279,15 +268,13 @@ def topk(self, query: str, candidates: List[str], k: int = 3) -> List[str]:
Returns:
List[str]: The top-k documents most similar to the query.
"""
- assert (
- len(candidates) > k
- ), f"Number of candidates ({len(candidates)}) must be greater than k ({k})"
+ assert len(candidates) > k, (
+ f"Number of candidates ({len(candidates)}) must be greater than k ({k})"
+ )
ranked_docs = self.rank(query, candidates)
return [doc for doc, score in ranked_docs[:k]]
- def filter(
- self, query: str, candidates: List[str], threshold: float = 0.3
- ) -> List[str]:
+ def filter(self, query: str, candidates: list[str], threshold: float = 0.3) -> list[str]:
"""Filter documents to include only those similar to the query above a threshold.
Args:
@@ -305,23 +292,21 @@ def filter(
).squeeze()
filtered_docs = [
- doc
- for doc, score in zip(candidates, similarity_scores)
- if score > threshold
+ doc for doc, score in zip(candidates, similarity_scores) if score > threshold
]
return filtered_docs
@dense_only
def cluster(
self,
- docs: List[str],
+ docs: list[str],
k: int,
max_iterations: int = 100,
tolerance: float = 1e-4,
n_init: int = 10,
min_iterations: int = 5,
random_state=None,
- ) -> Tuple[List[int], float]:
+ ) -> tuple[list[int], float]:
"""Cluster documents into `k` clusters using KMeans clustering.
Args:
@@ -363,7 +348,7 @@ def split(
cleanup_size: int = 24,
intermediate_size: int = 96,
return_minima: bool = False,
- ) -> List[str]:
+ ) -> list[str]:
"""Split text into semantically coherent chunks.
Args:
diff --git a/wordllama/mode_decorators.py b/src/wordllama/mode_decorators.py
similarity index 100%
rename from wordllama/mode_decorators.py
rename to src/wordllama/mode_decorators.py
diff --git a/wordllama/tokenizers/__init__.py b/src/wordllama/tokenizers/__init__.py
similarity index 85%
rename from wordllama/tokenizers/__init__.py
rename to src/wordllama/tokenizers/__init__.py
index 31140ae..3efc0dc 100644
--- a/wordllama/tokenizers/__init__.py
+++ b/src/wordllama/tokenizers/__init__.py
@@ -1,4 +1,5 @@
from pathlib import Path
+
from tokenizers import Tokenizer
@@ -17,9 +18,7 @@ def tokenizer_from_file(file_name: str) -> Tokenizer:
tokenizer_path = current_dir / file_name
# Assert that the tokenizer configuration file exists
- assert (
- tokenizer_path.exists()
- ), f"Tokenizer configuration file {tokenizer_path} does not exist."
+ assert tokenizer_path.exists(), f"Tokenizer configuration file {tokenizer_path} does not exist."
# Load the tokenizer from the specified file
tokenizer = Tokenizer.from_file(str(tokenizer_path))
diff --git a/wordllama/tokenizers/l2_supercat_tokenizer_config.json b/src/wordllama/tokenizers/l2_supercat_tokenizer_config.json
similarity index 99%
rename from wordllama/tokenizers/l2_supercat_tokenizer_config.json
rename to src/wordllama/tokenizers/l2_supercat_tokenizer_config.json
index 599eb71..16ca4e4 100644
--- a/wordllama/tokenizers/l2_supercat_tokenizer_config.json
+++ b/src/wordllama/tokenizers/l2_supercat_tokenizer_config.json
@@ -93389,4 +93389,4 @@
"▁ ▁▁▁▁▁▁▁▁▁▁▁▁▁▁"
]
}
-}
\ No newline at end of file
+}
diff --git a/wordllama/trainers/__init__.py b/src/wordllama/trainers/__init__.py
similarity index 100%
rename from wordllama/trainers/__init__.py
rename to src/wordllama/trainers/__init__.py
diff --git a/wordllama/trainers/reduce_dimension.py b/src/wordllama/trainers/reduce_dimension.py
similarity index 95%
rename from wordllama/trainers/reduce_dimension.py
rename to src/wordllama/trainers/reduce_dimension.py
index 130d8d2..6decaa7 100644
--- a/wordllama/trainers/reduce_dimension.py
+++ b/src/wordllama/trainers/reduce_dimension.py
@@ -11,7 +11,6 @@
SimilarityFunction,
)
-
logger = logging.getLogger(__name__)
@@ -58,9 +57,7 @@ def setup_evaluator(self) -> SequentialEvaluator:
)
for dim in self.config.matryoshka_dims
]
- return SequentialEvaluator(
- evaluators, main_score_function=lambda scores: scores[0]
- )
+ return SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[0])
def initialize_trainer(self) -> SentenceTransformerTrainer:
return SentenceTransformerTrainer(
diff --git a/wordllama/weights/l2_supercat_256.safetensors b/src/wordllama/weights/l2_supercat_256.safetensors
similarity index 100%
rename from wordllama/weights/l2_supercat_256.safetensors
rename to src/wordllama/weights/l2_supercat_256.safetensors
diff --git a/wordllama/wordllama.py b/src/wordllama/wordllama.py
similarity index 94%
rename from wordllama/wordllama.py
rename to src/wordllama/wordllama.py
index 296d503..81062c2 100644
--- a/wordllama/wordllama.py
+++ b/src/wordllama/wordllama.py
@@ -1,15 +1,14 @@
import logging
import warnings
from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Optional, Union
import requests
from safetensors import safe_open
from tokenizers import Tokenizer
+from .config import Model2VecModels, ModelURI, WordLlamaModels
from .inference import WordLlamaInference
-from .config import ModelURI, WordLlamaModels, Model2VecModels
-
logger = logging.getLogger(__name__)
@@ -26,7 +25,7 @@ class WordLlama:
DEFAULT_CACHE_DIR = Path.home() / ".cache" / "wordllama"
@classmethod
- def list_configs(cls) -> Dict[str, List[str]]:
+ def list_configs(cls) -> dict[str, list[str]]:
"""
List the available configurations.
@@ -166,9 +165,7 @@ def resolve_file(
response = requests.get(url, stream=True)
response.raise_for_status()
except requests.RequestException as e:
- logger.error(
- f"Failed to download {file_type} file '{filename}' from '{url}': {e}"
- )
+ logger.error(f"Failed to download {file_type} file '{filename}' from '{url}': {e}")
raise FileNotFoundError(
f"Failed to download {file_type} file '{filename}' from '{url}'."
) from e
@@ -267,18 +264,12 @@ def load(
# Validate dimensions
if dim not in model_uri.available_dims:
- raise ValueError(
- f"Model dimension must be one of {model_uri.available_dims}"
- )
+ raise ValueError(f"Model dimension must be one of {model_uri.available_dims}")
if trunc_dim is not None:
if trunc_dim > dim:
- raise ValueError(
- f"Cannot truncate to a higher dimension ({trunc_dim} > {dim})"
- )
+ raise ValueError(f"Cannot truncate to a higher dimension ({trunc_dim} > {dim})")
if trunc_dim not in model_uri.available_dims:
- raise ValueError(
- f"Truncated dimension must be one of {model_uri.available_dims}"
- )
+ raise ValueError(f"Truncated dimension must be one of {model_uri.available_dims}")
return cls._load(
config_name=config_name,
@@ -384,14 +375,13 @@ def load_tokenizer(
if use_local_if_exists:
# Check in the default path
if tokenizer_file_path.exists():
- logger.debug(
- f"Loading tokenizer from local config_name: {tokenizer_file_path}"
- )
+ logger.debug(f"Loading tokenizer from local config_name: {tokenizer_file_path}")
return Tokenizer.from_file(str(tokenizer_file_path))
else:
warnings.warn(
f"Local tokenizer config not found at {tokenizer_file_path}. "
- "Falling back to Hugging Face."
+ "Falling back to Hugging Face.",
+ stacklevel=2,
)
if hf_model_id:
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 6876a46..bc29003 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1,4 +1,5 @@
import unittest
+
from wordllama import WordLlama
@@ -24,8 +25,8 @@ def test_function_sorted(self):
sim_key = wl.key(query)
sorted_candidates = sorted(candidates, key=sim_key, reverse=True)
- self.assertIsInstance(sorted_candidates, list)
- self.assertEqual(len(sorted_candidates), len(candidates))
+ assert isinstance(sorted_candidates, list)
+ assert len(sorted_candidates) == len(candidates)
def test_function_max(self):
wl = WordLlama.load()
@@ -35,7 +36,5 @@ def test_function_max(self):
sim_key = wl.key(query)
best_candidate = max(candidates, key=sim_key)
- self.assertIn(best_candidate, candidates)
- self.assertEqual(
- best_candidate, max(candidates, key=lambda x: wl.similarity(query, x))
- )
+ assert best_candidate in candidates
+ assert best_candidate == max(candidates, key=lambda x: wl.similarity(query, x))
diff --git a/tests/test_inference.py b/tests/test_inference.py
index 7101421..97c4e55 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -1,15 +1,10 @@
import unittest
-from unittest.mock import patch, MagicMock
+from unittest.mock import MagicMock, patch
+
import numpy as np
+import pytest
+
from wordllama.inference import WordLlamaInference
-from wordllama.config import (
- WordLlamaConfig,
- WordLlamaModel,
- TokenizerConfig,
- TrainingConfig,
- MatryoshkaConfig,
- TokenizerInferenceConfig,
-)
np.random.seed(42)
@@ -43,9 +38,9 @@ def mock_encode_batch(texts, *args, **kwargs):
def test_deduplicate_cosine(self, mock_embed):
docs = ["doc1", "doc1_dup", "a second document that is different", "doc1_dup2"]
deduplicated_docs = self.model.deduplicate(docs, threshold=0.9)
- self.assertEqual(len(deduplicated_docs), 2)
- self.assertIn("doc1", deduplicated_docs)
- self.assertIn("a second document that is different", deduplicated_docs)
+ assert len(deduplicated_docs) == 2
+ assert "doc1" in deduplicated_docs
+ assert "a second document that is different" in deduplicated_docs
@patch.object(
WordLlamaInference,
@@ -62,10 +57,10 @@ def test_deduplicate_cosine(self, mock_embed):
def test_deduplicate_no_duplicates(self, mock_embed):
docs = ["doc1", "doc2", "doc3"]
deduplicated_docs = self.model.deduplicate(docs, threshold=0.9)
- self.assertEqual(len(deduplicated_docs), 3)
- self.assertIn("doc1", deduplicated_docs)
- self.assertIn("doc2", deduplicated_docs)
- self.assertIn("doc3", deduplicated_docs)
+ assert len(deduplicated_docs) == 3
+ assert "doc1" in deduplicated_docs
+ assert "doc2" in deduplicated_docs
+ assert "doc3" in deduplicated_docs
@patch.object(
WordLlamaInference,
@@ -75,8 +70,8 @@ def test_deduplicate_no_duplicates(self, mock_embed):
def test_deduplicate_all_duplicates(self, mock_embed):
docs = ["doc1", "doc1_dup", "doc1_dup2"]
deduplicated_docs = self.model.deduplicate(docs, threshold=0.9)
- self.assertEqual(len(deduplicated_docs), 1)
- self.assertIn("doc1", deduplicated_docs)
+ assert len(deduplicated_docs) == 1
+ assert "doc1" in deduplicated_docs
@patch.object(
WordLlamaInference,
@@ -85,32 +80,30 @@ def test_deduplicate_all_duplicates(self, mock_embed):
)
def test_deduplicate_return_indices(self, mock_embed):
docs = ["doc1", "doc1_dup", "doc1_dup2"]
- duplicated_idx = self.model.deduplicate(
- docs, return_indices=True, threshold=0.9
- )
- self.assertEqual(len(duplicated_idx), 2)
- self.assertIn(1, duplicated_idx)
- self.assertIn(2, duplicated_idx)
+ duplicated_idx = self.model.deduplicate(docs, return_indices=True, threshold=0.9)
+ assert len(duplicated_idx) == 2
+ assert 1 in duplicated_idx
+ assert 2 in duplicated_idx
def test_tokenize(self):
tokens = self.model.tokenize("test string")
self.mock_tokenizer.encode_batch.assert_called_with(
["test string"], is_pretokenized=False, add_special_tokens=False
)
- self.assertEqual(len(tokens), 1)
+ assert len(tokens) == 1
def test_embed(self):
embeddings = self.model.embed("test string", return_np=True)
- self.assertEqual(embeddings.shape, (1, 64))
+ assert embeddings.shape == (1, 64)
def test_cluster_fails_binary(self):
self.model.binary = True
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
self.model.cluster(["a", "b", "c"])
def test_split_fails_binary(self):
self.model.binary = True
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
self.model.split("a" * 1000)
def test_similarity_cosine(self):
@@ -119,7 +112,7 @@ def mock_encode_batch(texts, *args, **kwargs):
self.mock_tokenizer.encode_batch.side_effect = mock_encode_batch
sim_score = self.model.similarity("test string 1", "test string 2")
- self.assertTrue(isinstance(sim_score, float))
+ assert isinstance(sim_score, float)
def test_similarity_hamming(self):
def mock_encode_batch(texts, *args, **kwargs):
@@ -129,7 +122,7 @@ def mock_encode_batch(texts, *args, **kwargs):
self.model.binary = True
sim_score = self.model.similarity("test string 1", "test string 2")
- self.assertTrue(isinstance(sim_score, float))
+ assert isinstance(sim_score, float)
def test_rank_cosine(self):
def mock_encode_batch(texts, *args, **kwargs):
@@ -145,7 +138,7 @@ def mock_embed(texts, *args, **kwargs):
if isinstance(texts, str):
texts = [texts]
embeddings = []
- for i, text in enumerate(texts):
+ for i, _text in enumerate(texts):
embedding = np.zeros(64, dtype=np.float32)
embedding[1 if len(texts) == 1 else i] = 1
embeddings.append(embedding)
@@ -155,15 +148,15 @@ def mock_embed(texts, *args, **kwargs):
docs = ["doc1", "doc2", "doc3"]
ranked_docs = self.model.rank("test query", docs)
- self.assertEqual(len(ranked_docs), len(docs))
- self.assertTrue(all(isinstance(score, float) for doc, score in ranked_docs))
- self.assertEqual(ranked_docs[0], ("doc2", 1.0))
+ assert len(ranked_docs) == len(docs)
+ assert all(isinstance(score, float) for doc, score in ranked_docs)
+ assert ranked_docs[0] == ("doc2", 1.0)
# test turning off sorting
unsorted_docs = self.model.rank("test query", docs, sort=False)
- self.assertEqual(len(unsorted_docs), len(docs))
- self.assertTrue(all(isinstance(score, float) for doc, score in unsorted_docs))
- self.assertEqual(unsorted_docs[1], ("doc2", 1.0))
+ assert len(unsorted_docs) == len(docs)
+ assert all(isinstance(score, float) for doc, score in unsorted_docs)
+ assert unsorted_docs[1] == ("doc2", 1.0)
def test_rank_hamming(self):
def mock_encode_batch(texts, *args, **kwargs):
@@ -178,8 +171,8 @@ def mock_encode_batch(texts, *args, **kwargs):
ranked_docs = self.model.rank("test query", docs)
mock_hamming.assert_called_once() # check setting binary calls hamming
- self.assertEqual(len(ranked_docs), len(docs))
- self.assertTrue(all(isinstance(score, float) for doc, score in ranked_docs))
+ assert len(ranked_docs) == len(docs)
+ assert all(isinstance(score, float) for doc, score in ranked_docs)
def test_instantiate_with_truncation(self):
truncated_embedding = np.random.rand(128256, 32)
@@ -187,22 +180,22 @@ def test_instantiate_with_truncation(self):
embedding=truncated_embedding,
tokenizer=self.mock_tokenizer,
)
- self.assertEqual(truncated_model.embedding.shape[1], 32)
+ assert truncated_model.embedding.shape[1] == 32
def test_error_on_wrong_embedding_type(self):
- with self.assertRaises(TypeError):
+ with pytest.raises(TypeError):
self.model.embed(np.array([1, 2]))
def test_binarization_and_packing(self):
self.model.binary = True
binary_output = self.model.embed("test string")
- self.assertIsInstance(binary_output, np.ndarray)
- self.assertEqual(binary_output.dtype, np.uint64)
+ assert isinstance(binary_output, np.ndarray)
+ assert binary_output.dtype == np.uint64
def test_normalization_effect(self):
normalized_output = self.model.embed("test string", norm=True)
norm = np.linalg.norm(normalized_output)
- self.assertAlmostEqual(norm, 1, places=5)
+ assert norm == pytest.approx(1, abs=1e-5)
if __name__ == "__main__":
diff --git a/tests/test_kmeans.py b/tests/test_kmeans.py
index 9a82111..59b2c33 100644
--- a/tests/test_kmeans.py
+++ b/tests/test_kmeans.py
@@ -1,4 +1,5 @@
import unittest
+
import numpy as np
from wordllama.algorithms.kmeans import (
@@ -39,53 +40,43 @@ def setUp(self):
def test_kmeans_clustering_convergence(self):
k = 2
- labels, inertia = kmeans_clustering(
- self.embeddings, k, random_state=self.random_state
- )
+ labels, inertia = kmeans_clustering(self.embeddings, k, random_state=self.random_state)
- self.assertEqual(len(labels), self.embeddings.shape[0])
- self.assertGreater(inertia, 0)
+ assert len(labels) == self.embeddings.shape[0]
+ assert inertia > 0
def test_kmeans_clustering_labels(self):
k = 2
- labels, _ = kmeans_clustering(
- self.embeddings, k, random_state=self.random_state
- )
+ labels, _ = kmeans_clustering(self.embeddings, k, random_state=self.random_state)
# Check that labels are within the valid range
for label in labels:
- self.assertIn(label, range(k))
+ assert label in range(k)
def test_kmeans_clustering_different_k(self):
k = 3
- labels, _ = kmeans_clustering(
- self.embeddings, k, random_state=self.random_state
- )
+ labels, _ = kmeans_clustering(self.embeddings, k, random_state=self.random_state)
- self.assertEqual(len(labels), self.embeddings.shape[0])
+ assert len(labels) == self.embeddings.shape[0]
# Check that labels are within the valid range
for label in labels:
- self.assertIn(label, range(k))
+ assert label in range(k)
def test_kmeans_clustering_random_state(self):
k = 2
labels1, losses1 = kmeans_clustering(self.embeddings, k, random_state=42)
labels2, losses2 = kmeans_clustering(self.embeddings, k, random_state=42)
- self.assertEqual(labels1, labels2)
- self.assertEqual(losses1, losses2)
+ assert labels1 == labels2
+ assert losses1 == losses2
def test_kmeans_clustering_different_initializations(self):
k = 2
- labels1, inertia1 = kmeans_clustering(
- self.embeddings, k, random_state=42, n_init=1
- )
- labels2, inertia2 = kmeans_clustering(
- self.embeddings, k, random_state=42, n_init=10
- )
+ labels1, inertia1 = kmeans_clustering(self.embeddings, k, random_state=42, n_init=1)
+ labels2, inertia2 = kmeans_clustering(self.embeddings, k, random_state=42, n_init=10)
- self.assertGreater(inertia1, inertia2)
+ assert inertia1 > inertia2
if __name__ == "__main__":
diff --git a/tests/test_minima_functions.py b/tests/test_minima_functions.py
index b74cb4d..de355f1 100644
--- a/tests/test_minima_functions.py
+++ b/tests/test_minima_functions.py
@@ -1,5 +1,8 @@
import unittest
+
import numpy as np
+import pytest
+
from wordllama.algorithms.find_local_minima import (
find_local_minima,
windowed_cross_similarity,
@@ -17,23 +20,21 @@ def test_find_local_minima(self):
x_minima, y_minima = find_local_minima(self.y, window_size=3, poly_order=2)
# Known minima for sin(x) in the given range [0, 2*pi]
- expected_x_minima = np.array([3 * np.pi / 2], dtype=np.float32)
+ np.array([3 * np.pi / 2], dtype=np.float32)
expected_y_minima = np.array([-1.0], dtype=np.float32)
# Check if the found minima are correct (allow small numerical tolerance)
- np.testing.assert_array_almost_equal(
- self.y[x_minima], expected_y_minima, decimal=2
- )
+ np.testing.assert_array_almost_equal(self.y[x_minima], expected_y_minima, decimal=2)
np.testing.assert_array_almost_equal(y_minima, expected_y_minima, decimal=2)
def test_find_local_minima_invalid_window_size(self):
# Test that the function raises a ValueError for an invalid window size
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
find_local_minima(self.y, window_size=2, poly_order=2)
def test_find_local_minima_invalid_polynomial_order(self):
# Test that the function raises a ValueError for an invalid polynomial order
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
find_local_minima(self.y, window_size=11, poly_order=11)
@@ -63,17 +64,17 @@ def test_windowed_cross_similarity(self):
def test_windowed_cross_similarity_invalid_window(self):
# Test invalid window size (even window size should raise ValueError)
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
windowed_cross_similarity(self.embeddings, window_size=4)
# Test invalid window size (window size < 3)
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
windowed_cross_similarity(self.embeddings, window_size=2)
def test_windowed_cross_similarity_small_window(self):
# Test windowed cross similarity with a small window (size 3)
result = windowed_cross_similarity(self.embeddings, window_size=3)
- self.assertEqual(result.shape[0], self.embeddings.shape[0])
+ assert result.shape[0] == self.embeddings.shape[0]
if __name__ == "__main__":
diff --git a/tests/test_semantic_splitter.py b/tests/test_semantic_splitter.py
index 2cf989b..5826029 100644
--- a/tests/test_semantic_splitter.py
+++ b/tests/test_semantic_splitter.py
@@ -1,5 +1,7 @@
import unittest
+
import numpy as np
+
from wordllama.algorithms.semantic_splitter import SemanticSplitter
@@ -10,21 +12,19 @@ def setUp(self):
def test_flatten(self):
nested_list = [[1, 2], [3, 4], [5, 6]]
flattened = self.splitter.flatten(nested_list)
- self.assertEqual(flattened, [1, 2, 3, 4, 5, 6])
+ assert flattened == [1, 2, 3, 4, 5, 6]
def test_constrained_split(self):
text = "This is a test sentence. Another sentence here. And one more. " * 10
chunks = self.splitter.constrained_split(text, target_size=50)
- self.assertTrue(all(len(chunk) <= 50 for chunk in chunks))
- self.assertEqual(" ".join(chunks), text.strip())
+ assert all(len(chunk) <= 50 for chunk in chunks)
+ assert " ".join(chunks) == text.strip()
def test_split(self):
text = "Short sentence.\n\nTwo sentences. Without a line break.\n\nAnother short one."
- chunks = self.splitter.split(
- text, target_size=30, cleanup_size=10, intermediate_size=20
- )
- self.assertTrue(all(len(chunk) <= 30 for chunk in chunks))
- self.assertTrue(all(len(chunk) >= 10 for chunk in chunks))
+ chunks = self.splitter.split(text, target_size=30, cleanup_size=10, intermediate_size=20)
+ assert all(len(chunk) <= 30 for chunk in chunks)
+ assert all(len(chunk) >= 10 for chunk in chunks)
def test_reconstruct(self):
lines = ["Short text.", "Another short text.", "A bit longer text here."]
@@ -39,14 +39,12 @@ def test_reconstruct(self):
savgol_window=3,
)
- self.assertIsInstance(reconstructed, list)
- self.assertTrue(all(isinstance(chunk, str) for chunk in reconstructed))
+ assert isinstance(reconstructed, list)
+ assert all(isinstance(chunk, str) for chunk in reconstructed)
def test_reconstruct_return_minima(self):
lines = ["Short text.", "Another short text.", "A bit longer text here."]
- embeddings = np.random.rand(3, 16).astype(
- np.float32
- ) # 3 texts, 16-dimensional embeddings
+ embeddings = np.random.rand(3, 16).astype(np.float32) # 3 texts, 16-dimensional embeddings
result = self.splitter.reconstruct(
lines,
@@ -58,12 +56,12 @@ def test_reconstruct_return_minima(self):
return_minima=True,
)
- self.assertIsInstance(result, tuple)
- self.assertEqual(len(result), 3)
+ assert isinstance(result, tuple)
+ assert len(result) == 3
roots, y, sim_avg = result
- self.assertIsInstance(roots, np.ndarray)
- self.assertIsInstance(y, np.ndarray)
- self.assertIsInstance(sim_avg, np.ndarray)
+ assert isinstance(roots, np.ndarray)
+ assert isinstance(y, np.ndarray)
+ assert isinstance(sim_avg, np.ndarray)
if __name__ == "__main__":
diff --git a/tests/test_splitting_functions.py b/tests/test_splitting_functions.py
index 471a691..8014793 100644
--- a/tests/test_splitting_functions.py
+++ b/tests/test_splitting_functions.py
@@ -1,11 +1,14 @@
+import string
import unittest
+
+import pytest
+
from wordllama.algorithms.splitter import (
constrained_batches,
- split_sentences,
constrained_coalesce,
reverse_merge,
+ split_sentences,
)
-import string
class TestSplitter(unittest.TestCase):
@@ -14,26 +17,26 @@ def test_constrained_batches(self):
data = ["a", "bb", "ccc", "dddd", "eeeee"]
batches = list(constrained_batches(data, max_size=5))
expected = [("a", "bb"), ("ccc",), ("dddd",), ("eeeee",)]
- self.assertEqual(batches, expected)
+ assert batches == expected
# Batching with max_count
data = ["a", "bb", "ccc", "dddd", "eeeee"]
batches = list(constrained_batches(data, max_size=10, max_count=2))
- self.assertEqual(batches, [("a", "bb"), ("ccc", "dddd"), ("eeeee",)])
+ assert batches == [("a", "bb"), ("ccc", "dddd"), ("eeeee",)]
# Batching with get_len
data = ["a", "bb", "ccc", "dddd", "eeeee"]
batches = list(constrained_batches(data, max_size=5, get_len=lambda x: 1))
- self.assertEqual(batches, [("a", "bb", "ccc", "dddd", "eeeee")])
+ assert batches == [("a", "bb", "ccc", "dddd", "eeeee")]
# Non-strict mode
data = ["aaaaaa", "b", "c"]
batches = list(constrained_batches(data, max_size=5, strict=False))
- self.assertEqual(batches, [("aaaaaa",), ("b", "c")])
+ assert batches == [("aaaaaa",), ("b", "c")]
# Strict mode with item exceeding max_size
data = ["aaaaaa", "b", "c"]
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
batches = list(constrained_batches(data, max_size=5))
def test_split_sentences(self):
@@ -45,37 +48,35 @@ def test_split_sentences(self):
"This is another sentence!",
"And another one?",
]
- self.assertEqual(sentences, expected)
+ assert sentences == expected
# Test with no punctuation
text = "This is a text without punctuation"
sentences = split_sentences(text)
expected = ["This is a text without punctuation"]
- self.assertEqual(sentences, expected)
+ assert sentences == expected
# Test with custom punctuation
text = "Sentence one# Sentence two# Sentence three"
sentences = split_sentences(text, punct_chars={"#"})
expected = ["Sentence one#", "Sentence two#", "Sentence three"]
- self.assertEqual(sentences, expected)
+ assert sentences == expected
# Test with text ending without punctuation
text = "This is a sentence. This is another sentence"
sentences = split_sentences(text)
expected = ["This is a sentence.", "This is another sentence"]
- self.assertEqual(sentences, expected)
+ assert sentences == expected
def test_constrained_coalesce(self):
letters = list(string.ascii_lowercase)
# Using the example from the documentation
result = constrained_coalesce(letters, max_size=5, separator="")
expected = ["abcd", "efgh", "ijkl", "mnop", "qrst", "uvwx", "yz"]
- self.assertEqual(result, expected)
+ assert result == expected
# Test with max_iterations=1
- result = constrained_coalesce(
- letters, max_size=5, separator="", max_iterations=1
- )
+ result = constrained_coalesce(letters, max_size=5, separator="", max_iterations=1)
expected_one_pass = [
"ab",
"cd",
@@ -91,37 +92,37 @@ def test_constrained_coalesce(self):
"wx",
"yz",
]
- self.assertEqual(result, expected_one_pass)
+ assert result == expected_one_pass
# Test with data that cannot be combined
data = ["a"] * 100
result = constrained_coalesce(data, max_size=1, max_iterations=5)
- self.assertEqual(result, ["a"] * 100)
+ assert result == ["a"] * 100
def test_reverse_merge(self):
# Basic merging test
data = ["long enough", "short", "tiny", "adequate", "s"]
result = reverse_merge(data, n=6, separator=" ")
expected = ["long enough short tiny", "adequate s"]
- self.assertEqual(result, expected)
+ assert result == expected
# Test with empty list
data = []
result = reverse_merge(data, n=5)
expected = []
- self.assertEqual(result, expected)
+ assert result == expected
# All strings longer than n
data = ["string1", "string2", "string3"]
result = reverse_merge(data, n=5)
expected = ["string1", "string2", "string3"]
- self.assertEqual(result, expected)
+ assert result == expected
# All strings shorter than n
data = ["a", "bb", "ccc"]
result = reverse_merge(data, n=5, separator=" ")
expected = ["a bb ccc"]
- self.assertEqual(result, expected)
+ assert result == expected
if __name__ == "__main__":
diff --git a/tests/test_vector_similarity.py b/tests/test_vector_similarity.py
index e066d96..017baaa 100644
--- a/tests/test_vector_similarity.py
+++ b/tests/test_vector_similarity.py
@@ -1,7 +1,8 @@
import unittest
+
import numpy as np
-from wordllama.algorithms import vector_similarity, binarize_and_packbits
+from wordllama.algorithms import binarize_and_packbits, vector_similarity
class TestVectorSimilarity(unittest.TestCase):
@@ -9,21 +10,21 @@ def test_binarization_and_packing(self):
vec = np.zeros((1, 64))
vec[0][7] = 1
binary_output = binarize_and_packbits(vec)
- self.assertIsInstance(binary_output, np.ndarray)
- self.assertEqual(binary_output.dtype, np.uint64)
- self.assertEqual(binary_output, 1)
+ assert isinstance(binary_output, np.ndarray)
+ assert binary_output.dtype == np.uint64
+ assert binary_output == 1
def test_cosine_similarity_direct(self):
vec1 = np.random.rand(1, 64)
vec2 = np.random.rand(1, 64)
result = vector_similarity(vec1, vec2, binary=False)
- self.assertIsInstance(result.item(), float)
+ assert isinstance(result.item(), float)
def test_hamming_similarity_direct(self):
vec1 = np.expand_dims(np.random.randint(2, size=64, dtype=np.uint64), axis=0)
vec2 = np.expand_dims(np.random.randint(2, size=64, dtype=np.uint64), axis=0)
result = vector_similarity(vec1, vec2, binary=True)
- self.assertIsInstance(result.item(), float)
+ assert isinstance(result.item(), float)
if __name__ == "__main__":
diff --git a/tests/test_wordllama.py b/tests/test_wordllama.py
index b5a8c42..c84e7ce 100644
--- a/tests/test_wordllama.py
+++ b/tests/test_wordllama.py
@@ -1,18 +1,13 @@
import unittest
-from unittest.mock import patch, MagicMock, mock_open, call
from pathlib import Path
+from unittest.mock import MagicMock, call, mock_open, patch
+
import numpy as np
+import pytest
from tokenizers import Tokenizer
+
+from wordllama.config.models import WordLlamaModels
from wordllama.wordllama import WordLlama, WordLlamaInference
-from wordllama.config import (
- WordLlamaConfig,
- TokenizerConfig,
- MatryoshkaConfig,
- WordLlamaModel,
- TrainingConfig,
- TokenizerInferenceConfig,
-)
-from wordllama.config.models import ModelURI, WordLlamaModels
class TestWordLlama(unittest.TestCase):
@@ -24,7 +19,7 @@ def setUp(self):
# Assemble WordLlamaConfig
self.config = "l2_supercat"
- self.model_uri = getattr(WordLlamaModels, "l2_supercat")
+ self.model_uri = WordLlamaModels.l2_supercat
def test_list_configs(self):
output = WordLlama.list_configs()
@@ -82,9 +77,7 @@ def test_resolve_file_downloads_if_not_found(
handle.write.assert_has_calls([call(b"chunk1"), call(b"chunk2")])
# Assert the returned path is correct
- self.assertEqual(
- weights_path, Path("/dummy/cache/weights/l2_supercat_256.safetensors")
- )
+ assert weights_path == Path("/dummy/cache/weights/l2_supercat_256.safetensors")
@patch.object(WordLlama, "resolve_file", autospec=True)
def test_load_with_default_cache_dir(self, mock_resolve_file):
@@ -94,24 +87,21 @@ def test_load_with_default_cache_dir(self, mock_resolve_file):
# Setup mock for resolve_file
default_cache_dir = WordLlama.DEFAULT_CACHE_DIR
weights_path = default_cache_dir / "weights" / "l2_supercat_256.safetensors"
- tokenizer_path = (
- default_cache_dir / "tokenizers" / "l2_supercat_tokenizer_config.json"
- )
+ tokenizer_path = default_cache_dir / "tokenizers" / "l2_supercat_tokenizer_config.json"
mock_resolve_file.side_effect = [weights_path, tokenizer_path]
# Mock tokenizer and weights loading
- with patch(
- "wordllama.wordllama.WordLlama.load_tokenizer",
- return_value=MagicMock(spec=Tokenizer),
- ) as mock_load_tokenizer, patch(
- "wordllama.wordllama.safe_open", autospec=True
- ) as mock_safe_open:
+ with (
+ patch(
+ "wordllama.wordllama.WordLlama.load_tokenizer",
+ return_value=MagicMock(spec=Tokenizer),
+ ) as mock_load_tokenizer,
+ patch("wordllama.wordllama.safe_open", autospec=True) as mock_safe_open,
+ ):
# Mock the tensor returned by safe_open
mock_tensor = MagicMock()
mock_tensor.__getitem__.return_value = np.random.rand(256, 4096)
- mock_safe_open.return_value.__enter__.return_value.get_tensor.return_value = (
- mock_tensor
- )
+ mock_safe_open.return_value.__enter__.return_value.get_tensor.return_value = mock_tensor
# Call load without specifying cache_dir
model = WordLlama.load(
@@ -146,7 +136,7 @@ def test_load_with_default_cache_dir(self, mock_resolve_file):
),
]
mock_resolve_file.assert_has_calls(expected_calls, any_order=False)
- self.assertEqual(mock_resolve_file.call_count, 2)
+ assert mock_resolve_file.call_count == 2
# Assert load_tokenizer was called with correct path
mock_load_tokenizer.assert_called_once_with(
@@ -162,7 +152,7 @@ def test_load_with_default_cache_dir(self, mock_resolve_file):
)
# Assert the returned model is an instance of WordLlamaInference
- self.assertIsInstance(model, WordLlamaInference)
+ assert isinstance(model, WordLlamaInference)
@patch.object(WordLlama, "resolve_file", autospec=True)
def test_load_with_custom_cache_dir(self, mock_resolve_file):
@@ -182,7 +172,7 @@ def test_load_with_custom_cache_dir(self, mock_resolve_file):
"tilde": Path("~/tmp_cache").expanduser(),
"relative": Path("tmp").resolve(),
"relative_dot": Path("./tmp").resolve(),
- "absolute": Path("/tmp/cache_dir"),
+ "absolute": Path("/tmp/cache_dir").resolve(),
}
for key, cache_dir_input in cache_dirs.items():
@@ -192,24 +182,21 @@ def test_load_with_custom_cache_dir(self, mock_resolve_file):
# Setup mock for resolve_file
weights_path = (
- expected_resolved_dirs[key]
- / "weights"
- / "l2_supercat_256.safetensors"
+ expected_resolved_dirs[key] / "weights" / "l2_supercat_256.safetensors"
)
tokenizer_path = (
- expected_resolved_dirs[key]
- / "tokenizers"
- / "l2_supercat_tokenizer_config.json"
+ expected_resolved_dirs[key] / "tokenizers" / "l2_supercat_tokenizer_config.json"
)
mock_resolve_file.side_effect = [weights_path, tokenizer_path]
# Mock tokenizer and weights loading
- with patch(
- "wordllama.wordllama.WordLlama.load_tokenizer",
- return_value=MagicMock(spec=Tokenizer),
- ) as mock_load_tokenizer, patch(
- "wordllama.wordllama.safe_open", autospec=True
- ) as mock_safe_open:
+ with (
+ patch(
+ "wordllama.wordllama.WordLlama.load_tokenizer",
+ return_value=MagicMock(spec=Tokenizer),
+ ) as mock_load_tokenizer,
+ patch("wordllama.wordllama.safe_open", autospec=True) as mock_safe_open,
+ ):
# Mock the tensor returned by safe_open
mock_tensor = MagicMock()
mock_tensor.__getitem__.return_value = np.random.rand(256, 4096)
@@ -252,7 +239,7 @@ def test_load_with_custom_cache_dir(self, mock_resolve_file):
),
]
mock_resolve_file.assert_has_calls(expected_calls, any_order=False)
- self.assertEqual(mock_resolve_file.call_count, 2)
+ assert mock_resolve_file.call_count == 2
# Assert load_tokenizer was called with correct path
mock_load_tokenizer.assert_called_once_with(
@@ -268,7 +255,7 @@ def test_load_with_custom_cache_dir(self, mock_resolve_file):
)
# Assert the returned model is an instance of WordLlamaInference
- self.assertIsInstance(model, WordLlamaInference)
+ assert isinstance(model, WordLlamaInference)
@patch.object(WordLlama, "resolve_file", autospec=True)
def test_load_with_disable_download(self, mock_resolve_file):
@@ -279,7 +266,7 @@ def test_load_with_disable_download(self, mock_resolve_file):
mock_resolve_file.side_effect = FileNotFoundError("File not found")
# Call load with disable_download=True and expect FileNotFoundError
- with self.assertRaises(FileNotFoundError):
+ with pytest.raises(FileNotFoundError):
WordLlama.load(
config=self.config,
cache_dir=Path("/dummy/cache"),
@@ -303,7 +290,7 @@ def test_load_with_disable_download(self, mock_resolve_file):
)
]
mock_resolve_file.assert_has_calls(expected_calls, any_order=False)
- self.assertEqual(mock_resolve_file.call_count, 1)
+ assert mock_resolve_file.call_count == 1
@patch.object(WordLlama, "resolve_file", autospec=True)
def test_load_with_truncated_dimension(self, mock_resolve_file):
@@ -312,24 +299,21 @@ def test_load_with_truncated_dimension(self, mock_resolve_file):
"""
# Setup mock for resolve_file
weights_path = Path("/dummy/cache/weights/l2_supercat_256.safetensors")
- tokenizer_path = Path(
- "/dummy/cache/tokenizers/l2_supercat_tokenizer_config.json"
- )
+ tokenizer_path = Path("/dummy/cache/tokenizers/l2_supercat_tokenizer_config.json")
mock_resolve_file.side_effect = [weights_path, tokenizer_path]
# Mock tokenizer and weights loading
- with patch(
- "wordllama.wordllama.WordLlama.load_tokenizer",
- return_value=MagicMock(spec=Tokenizer),
- ) as mock_load_tokenizer, patch(
- "wordllama.wordllama.safe_open", autospec=True
- ) as mock_safe_open:
+ with (
+ patch(
+ "wordllama.wordllama.WordLlama.load_tokenizer",
+ return_value=MagicMock(spec=Tokenizer),
+ ) as mock_load_tokenizer,
+ patch("wordllama.wordllama.safe_open", autospec=True) as mock_safe_open,
+ ):
# Mock the tensor returned by safe_open
mock_tensor = MagicMock()
mock_tensor.__getitem__.return_value = np.random.rand(256, 4096)
- mock_safe_open.return_value.__enter__.return_value.get_tensor.return_value = (
- mock_tensor
- )
+ mock_safe_open.return_value.__enter__.return_value.get_tensor.return_value = mock_tensor
# Call load with trunc_dim
model = WordLlama.load(
@@ -364,7 +348,7 @@ def test_load_with_truncated_dimension(self, mock_resolve_file):
),
]
mock_resolve_file.assert_has_calls(expected_calls, any_order=False)
- self.assertEqual(mock_resolve_file.call_count, 2)
+ assert mock_resolve_file.call_count == 2
# Assert load_tokenizer was called with correct path
mock_load_tokenizer.assert_called_once_with(
@@ -380,7 +364,7 @@ def test_load_with_truncated_dimension(self, mock_resolve_file):
)
# Assert the returned model is an instance of WordLlamaInference
- self.assertIsInstance(model, WordLlamaInference)
+ assert isinstance(model, WordLlamaInference)
# Assert that the embedding was truncated
mock_tensor.__getitem__.assert_called_with((slice(None), slice(None, 128)))
@@ -400,15 +384,14 @@ def test_load_tokenizer_fallback(
# Setup mocks
# First call for weights, second call for tokenizer
weights_path = Path("/dummy/cache/weights/l2_supercat_256.safetensors")
- tokenizer_path = Path(
- "/dummy/cache/tokenizers/l2_supercat_tokenizer_config.json"
- )
+ tokenizer_path = Path("/dummy/cache/tokenizers/l2_supercat_tokenizer_config.json")
mock_resolve_file.side_effect = [weights_path, tokenizer_path]
# Simulate tokenizer config does not exist by patching Path.exists
- with patch(
- "wordllama.wordllama.Path.exists", side_effect=[False, False]
- ), patch("wordllama.wordllama.safe_open", autospec=True) as mock_safe_open:
+ with (
+ patch("wordllama.wordllama.Path.exists", side_effect=[False, False]),
+ patch("wordllama.wordllama.safe_open", autospec=True) as mock_safe_open,
+ ):
# Mock the tensor returned by safe_open
mock_tensor = MagicMock()
mock_tensor.__getitem__.return_value = np.random.rand(256, 4096)
@@ -449,7 +432,7 @@ def test_load_tokenizer_fallback(
),
]
mock_resolve_file.assert_has_calls(expected_calls, any_order=False)
- self.assertEqual(mock_resolve_file.call_count, 2)
+ assert mock_resolve_file.call_count == 2
# Assert Tokenizer.from_pretrained was called since local config was not found
mock_from_pretrained.assert_called_once_with("meta-llama/Llama-2-70b-hf")
@@ -462,7 +445,7 @@ def test_load_tokenizer_fallback(
)
# Assert the returned model is an instance of WordLlamaInference
- self.assertIsInstance(model, WordLlamaInference)
+ assert isinstance(model, WordLlamaInference)
if __name__ == "__main__":
diff --git a/train.py b/train.py
index 9d4fcf8..731f1a1 100644
--- a/train.py
+++ b/train.py
@@ -7,23 +7,24 @@
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1"
-import torch
-import tqdm
-import safetensors.torch
from datetime import datetime
from pathlib import Path
+
+import safetensors.torch
+import torch
+import tqdm
+from dataset_loader import load_datasets
from sentence_transformers import (
- SentenceTransformerTrainingArguments,
SentenceTransformer,
+ SentenceTransformerTrainingArguments,
)
from sentence_transformers.training_args import MultiDatasetBatchSamplers
-from wordllama import load_training, Config
+from wordllama import Config, load_training
+from wordllama.adapters import AvgPool, Binarizer, WeightedProjector
from wordllama.config import WordLlamaModel
from wordllama.embedding.word_llama_embedding import WordLlamaEmbedding
from wordllama.trainers.reduce_dimension import ReduceDimension
-from wordllama.adapters import AvgPool, WeightedProjector, Binarizer
-from dataset_loader import load_datasets
class ReduceDimensionConfig:
@@ -108,9 +109,7 @@ def save(self, checkpoint: Path, outdir: Path):
# load the projector weights
max_dim = max(self.matryoshka_dims)
- proj_path = (
- checkpoint / "1_WeightedProjector" / "weighted_projector.safetensors"
- )
+ proj_path = checkpoint / "1_WeightedProjector" / "weighted_projector.safetensors"
proj = WeightedProjector(
self.config.model.dim,
max_dim,
@@ -183,9 +182,7 @@ class TmpConfig:
required=True,
help="Name of your configuration (eg. [your_config].toml)",
)
- parser_save.add_argument(
- "--checkpoint", type=str, required=True, help="Path to the checkpoint"
- )
+ parser_save.add_argument("--checkpoint", type=str, required=True, help="Path to the checkpoint")
parser_save.add_argument(
"--outdir", type=str, required=True, help="Directory to save the models"
)
@@ -196,9 +193,7 @@ class TmpConfig:
# Execute based on the command
if args.command == "train":
- config = ReduceDimensionConfig(
- config_name, binarize=args.binarize, norm=args.norm
- )
+ config = ReduceDimensionConfig(config_name, binarize=args.binarize, norm=args.norm)
trainer = ReduceDimension(config)
trainer.train()
diff --git a/tutorials/blog/semantic_split/wl_semantic_blog.md b/tutorials/blog/semantic_split/wl_semantic_blog.md
index bc5e7de..543f739 100644
--- a/tutorials/blog/semantic_split/wl_semantic_blog.md
+++ b/tutorials/blog/semantic_split/wl_semantic_blog.md
@@ -16,7 +16,7 @@ When splitting/chunking text for Retrieval-Augmented Generation (RAG) applicatio
These methods don't require language modeling, but they lack semantic awareness. More advanced techniques using language models can provide better semantic coherence but typically at a significant computational cost and latency. And, although semantic splitting is conceptually simple, it still involves multiple steps to refine a quality algorithm.
-WordLlama is a good platform for accomplishing this, since it can incorporate basic semantic information into the chunking process, without adding significant computational requirements. Here, we develop a recipe for semantic splitting with WordLlama using an intuitive process.
+WordLlama is a good platform for accomplishing this, since it can incorporate basic semantic information into the chunking process, without adding significant computational requirements. Here, we develop a recipe for semantic splitting with WordLlama using an intuitive process.
## Target texts
@@ -83,21 +83,21 @@ print(text[0:300])
J. R. R. Tolkien — The Lord Of The Rings. (1/4)
-----------------------------------------------
-
-
+
+
THE LORD OF THE RINGS
-
+
by
-
+
J. R. R. TOLKIEN
-
-
-
+
+
+
Part 1: The Fellowship of the Ring
Part 2: The Two Towers
Part 3: The Return of the King
-
-
+
+
_Complete with Index and Full Appendi
@@ -115,21 +115,21 @@ import seaborn as sns
def plot_chars(chars_per_line):
sns.set(style="whitegrid")
fig, axes = plt.subplots(2, 1, figsize=(10, 8))
-
+
# First plot: full range
sns.histplot(chars_per_line, bins=200, ax=axes[0], kde=False)
axes[0].set_title("Characters per Line - Full Range")
axes[0].set_xlabel("# Characters")
axes[0].set_ylabel("$log($Counts$)$")
axes[0].semilogy(True)
-
+
# Second plot: zoomed-in range
sns.histplot(chars_per_line, bins=1000, ax=axes[1], kde=False)
axes[1].set_title("Characters per Line - Zoomed In (0 to 100)")
axes[1].set_xlabel("# Characters")
axes[1].set_ylabel("Counts")
axes[1].set_xlim((0, 200))
-
+
plt.tight_layout()
plt.show()
```
@@ -146,9 +146,9 @@ plot_chars(chars_per_line)
```
-
+

-
+
Here we can see a bunch of small fragments with close to zero size. Additionally, there are some smaller segments below 50 characters. While most of the chunks are fewer than 1k characters, there are a few larger ones as well. The chunks that are a few characters or less are not likely to carry much semantic information and are disproportionate compared to most of the other segments.
@@ -208,9 +208,9 @@ plot_chars(chars_per_line)
```
-
+

-
+
This is better. Let's take care of the larger segments.
@@ -268,9 +268,9 @@ plot_chars(chars_per_line)
```
-
+

-
+
Now we have a more reasonable starting point for doing semantic splitting. Let's use wordllama to embed the segments into vectors, and compute similarity for all the segments.
@@ -301,9 +301,9 @@ plt.title("Cross-similarity of lines")
-
+

-
+
Here's where we can see how wordllama can help. As we traverse the diagonal, we can identify blocks of similar texts. The very small block in the upper left corner is the table of contents.
@@ -334,12 +334,12 @@ plt.show()
```
-
+

-
-With the size of our segments, even 10-20 segments is a decent chunk size. Here we can zoom in on the minimum around the **dashed red line index (308)**.
+
+With the size of our segments, even 10-20 segments is a decent chunk size. Here we can zoom in on the minimum around the **dashed red line index (308)**.
```python
@@ -350,7 +350,7 @@ print("\n".join([lines[i] if i != 308 else f">>>>>>>>>>>>>{lines[i]}<<<<<<<<<<<<
'So do I,' said Gandalf. 'And I wonder many other things. Good-bye now! Take care of yourself! Look out for me, especially at unlikely times! Good-bye!'
Frodo saw him to the door. He gave a final wave of his hand, and walked off at a surprising pace; but Frodo thought the old wizard looked unusually bent, almost as if he was carrying a great weight. The evening was closing in, and his cloaked figure quickly vanished into the twilight. Frodo did not see him again for a long time.
>>>>>>>>>>>>>
-
+
_Chapter 2_
The Shadow of the Past
<<<<<<<<<<<<<
@@ -404,9 +404,9 @@ plt.show()
```
-
+

-
+
Well that was fun. Now all that's left is to bring the sections back up to our target size.
@@ -450,9 +450,9 @@ ax.semilogy(True)
-
+

-
+
### Visualize
@@ -464,21 +464,21 @@ from IPython.display import Markdown, display
def display_strings(string_list, offset=0):
"""
Convert a list of strings into a markdown table and display it in a Jupyter notebook.
-
+
Parameters:
- string_list (list): The list of strings to display
-
+
Returns:
- None (displays the table in the notebook)
"""
# Create the table header
table = "| Index | Text |\n|-------|------|\n"
-
+
# Add each string to the table
for i, text in enumerate(string_list):
row = f"| {i + offset} | {text[:600]}{'...' if len(text) > 600 else ''} |\n"
table += row
-
+
# Display the table
display(Markdown(table))
@@ -520,7 +520,7 @@ print(f"Length of text: {len(text):.2e} chars\n# of chunks: {len(results)}\n\nPr
Length of text: 1.02e+06 chars
# of chunks: 784
-
+
Processing time:
CPU times: user 1.31 s, sys: 111 ms, total: 1.43 s
Wall time: 677 ms
diff --git a/tutorials/extract_token_embeddings.md b/tutorials/extract_token_embeddings.md
index f9f33d8..e88efa3 100644
--- a/tutorials/extract_token_embeddings.md
+++ b/tutorials/extract_token_embeddings.md
@@ -52,7 +52,7 @@ In [1]: from safetensors import safe_open
In [2]: with safe_open("/home/lee/Downloads/model-00001-of-00002.safetensors", framework="pt") as f:
...: weights = f.get_tensor("model.embed_tokens.weight")
- ...:
+ ...:
In [3]: weights.shape
Out[3]: torch.Size([256000, 2304])