diff --git a/README.md b/README.md
index 491e2d8..b6ead41 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
**WordLlama** is a fast, lightweight NLP toolkit designed for tasks like fuzzy deduplication, similarity computation, ranking, clustering, and semantic text splitting. It operates with minimal inference-time dependencies and is optimized for CPU hardware, making it suitable for deployment in resource-constrained environments.
-
+
## News and Updates 🔥
diff --git a/tests/test_functional.py b/tests/test_functional.py
index bc29003..b377f60 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1,9 +1,7 @@
-import unittest
-
from wordllama import WordLlama
-class TestFunctional(unittest.TestCase):
+class TestFunctional:
def test_function_clustering(self):
wl = WordLlama.load()
wl.cluster(["a", "b"], k=2)
diff --git a/tests/test_inference.py b/tests/test_inference.py
index 97c4e55..378f4b3 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -1,4 +1,3 @@
-import unittest
from unittest.mock import MagicMock, patch
import numpy as np
@@ -9,9 +8,10 @@
np.random.seed(42)
-class TestWordLlamaInference(unittest.TestCase):
+class TestWordLlamaInference:
+ @pytest.fixture(autouse=True)
@patch("wordllama.inference.Tokenizer.from_pretrained")
- def setUp(self, mock_tokenizer):
+ def setup(self, mock_tokenizer):
np.random.seed(42)
# Mock the tokenizer
@@ -196,7 +196,3 @@ def test_normalization_effect(self):
normalized_output = self.model.embed("test string", norm=True)
norm = np.linalg.norm(normalized_output)
assert norm == pytest.approx(1, abs=1e-5)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_kmeans.py b/tests/test_kmeans.py
index 59b2c33..479849e 100644
--- a/tests/test_kmeans.py
+++ b/tests/test_kmeans.py
@@ -1,15 +1,12 @@
-import unittest
-
import numpy as np
+import pytest
-from wordllama.algorithms.kmeans import (
- # kmeans_plusplus_initialization,
- kmeans_clustering,
-)
+from wordllama.algorithms.kmeans import kmeans_clustering
-class TestKMeansClustering(unittest.TestCase):
- def setUp(self):
+class TestKMeansClustering:
+ @pytest.fixture(autouse=True)
+ def setup(self):
self.random_state = np.random.RandomState(42)
self.embeddings = np.array(
[
@@ -23,21 +20,6 @@ def setUp(self):
dtype=np.float32,
)
- # def test_kmeans_plusplus_initialization(self):
- # k = 2
- # centroids = kmeans_plusplus_initialization(
- # self.embeddings, k, self.random_state
- # )
-
- # self.assertEqual(centroids.shape[0], k)
- # self.assertEqual(centroids.shape[1], self.embeddings.shape[1])
-
- # # Check that centroids are among the original points
- # for centroid in centroids:
- # self.assertTrue(
- # any(np.allclose(centroid, point) for point in self.embeddings)
- # )
-
def test_kmeans_clustering_convergence(self):
k = 2
labels, inertia = kmeans_clustering(self.embeddings, k, random_state=self.random_state)
@@ -77,7 +59,3 @@ def test_kmeans_clustering_different_initializations(self):
labels2, inertia2 = kmeans_clustering(self.embeddings, k, random_state=42, n_init=10)
assert inertia1 > inertia2
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_minima_functions.py b/tests/test_minima_functions.py
index de355f1..887ada2 100644
--- a/tests/test_minima_functions.py
+++ b/tests/test_minima_functions.py
@@ -1,5 +1,3 @@
-import unittest
-
import numpy as np
import pytest
@@ -9,8 +7,9 @@
)
-class TestSavitzkyGolay(unittest.TestCase):
- def setUp(self):
+class TestSavitzkyGolay:
+ @pytest.fixture(autouse=True)
+ def setup(self):
self.x1 = np.linspace(0, 2 * np.pi, 100, dtype=np.float32)
self.x = np.arange(100)
self.y = np.sin(self.x1)
@@ -38,8 +37,9 @@ def test_find_local_minima_invalid_polynomial_order(self):
find_local_minima(self.y, window_size=11, poly_order=11)
-class TestWindowedCrossSimilarity(unittest.TestCase):
- def setUp(self):
+class TestWindowedCrossSimilarity:
+ @pytest.fixture(autouse=True)
+ def setup(self):
# Example embedding matrix (5 vectors of 3 dimensions each)
self.embeddings = np.array(
[
@@ -75,7 +75,3 @@ def test_windowed_cross_similarity_small_window(self):
# Test windowed cross similarity with a small window (size 3)
result = windowed_cross_similarity(self.embeddings, window_size=3)
assert result.shape[0] == self.embeddings.shape[0]
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_semantic_splitter.py b/tests/test_semantic_splitter.py
index 5826029..a9a2b51 100644
--- a/tests/test_semantic_splitter.py
+++ b/tests/test_semantic_splitter.py
@@ -1,12 +1,12 @@
-import unittest
-
import numpy as np
+import pytest
from wordllama.algorithms.semantic_splitter import SemanticSplitter
-class TestSemanticSplitter(unittest.TestCase):
- def setUp(self):
+class TestSemanticSplitter:
+ @pytest.fixture(autouse=True)
+ def setup(self):
self.splitter = SemanticSplitter()
def test_flatten(self):
@@ -62,7 +62,3 @@ def test_reconstruct_return_minima(self):
assert isinstance(roots, np.ndarray)
assert isinstance(y, np.ndarray)
assert isinstance(sim_avg, np.ndarray)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_splitting_functions.py b/tests/test_splitting_functions.py
index 8014793..198b9c0 100644
--- a/tests/test_splitting_functions.py
+++ b/tests/test_splitting_functions.py
@@ -1,5 +1,4 @@
import string
-import unittest
import pytest
@@ -11,7 +10,7 @@
)
-class TestSplitter(unittest.TestCase):
+class TestSplitter:
def test_constrained_batches(self):
# Basic batching
data = ["a", "bb", "ccc", "dddd", "eeeee"]
@@ -123,7 +122,3 @@ def test_reverse_merge(self):
result = reverse_merge(data, n=5, separator=" ")
expected = ["a bb ccc"]
assert result == expected
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_vector_similarity.py b/tests/test_vector_similarity.py
index 017baaa..c7baa9d 100644
--- a/tests/test_vector_similarity.py
+++ b/tests/test_vector_similarity.py
@@ -1,11 +1,9 @@
-import unittest
-
import numpy as np
from wordllama.algorithms import binarize_and_packbits, vector_similarity
-class TestVectorSimilarity(unittest.TestCase):
+class TestVectorSimilarity:
def test_binarization_and_packing(self):
vec = np.zeros((1, 64))
vec[0][7] = 1
@@ -25,7 +23,3 @@ def test_hamming_similarity_direct(self):
vec2 = np.expand_dims(np.random.randint(2, size=64, dtype=np.uint64), axis=0)
result = vector_similarity(vec1, vec2, binary=True)
assert isinstance(result.item(), float)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_wordllama.py b/tests/test_wordllama.py
index c84e7ce..c9da66e 100644
--- a/tests/test_wordllama.py
+++ b/tests/test_wordllama.py
@@ -1,4 +1,3 @@
-import unittest
from pathlib import Path
from unittest.mock import MagicMock, call, mock_open, patch
@@ -10,8 +9,9 @@
from wordllama.wordllama import WordLlama, WordLlamaInference
-class TestWordLlama(unittest.TestCase):
- def setUp(self):
+class TestWordLlama:
+ @pytest.fixture(autouse=True)
+ def setup(self):
self.config = "l2_supercat"
self.dim = 256
self.binary = False
@@ -176,86 +176,83 @@ def test_load_with_custom_cache_dir(self, mock_resolve_file):
}
for key, cache_dir_input in cache_dirs.items():
- with self.subTest(cache_dir=key):
- # Reset mocks
- mock_resolve_file.reset_mock()
+ # Reset mocks
+ mock_resolve_file.reset_mock()
- # Setup mock for resolve_file
- weights_path = (
- expected_resolved_dirs[key] / "weights" / "l2_supercat_256.safetensors"
+ # Setup mock for resolve_file
+ weights_path = expected_resolved_dirs[key] / "weights" / "l2_supercat_256.safetensors"
+ tokenizer_path = (
+ expected_resolved_dirs[key] / "tokenizers" / "l2_supercat_tokenizer_config.json"
+ )
+ mock_resolve_file.side_effect = [weights_path, tokenizer_path]
+
+ # Mock tokenizer and weights loading
+ with (
+ patch(
+ "wordllama.wordllama.WordLlama.load_tokenizer",
+ return_value=MagicMock(spec=Tokenizer),
+ ) as mock_load_tokenizer,
+ patch("wordllama.wordllama.safe_open", autospec=True) as mock_safe_open,
+ ):
+ # Mock the tensor returned by safe_open
+ mock_tensor = MagicMock()
+ mock_tensor.__getitem__.return_value = np.random.rand(256, 4096)
+ mock_safe_open.return_value.__enter__.return_value.get_tensor.return_value = (
+ mock_tensor
)
- tokenizer_path = (
- expected_resolved_dirs[key] / "tokenizers" / "l2_supercat_tokenizer_config.json"
+
+ # Call load with custom cache_dir
+ model = WordLlama.load(
+ config=self.model_uri,
+ cache_dir=cache_dir_input,
+ binary=self.binary,
+ dim=self.dim,
+ trunc_dim=self.trunc_dim,
)
- mock_resolve_file.side_effect = [weights_path, tokenizer_path]
-
- # Mock tokenizer and weights loading
- with (
- patch(
- "wordllama.wordllama.WordLlama.load_tokenizer",
- return_value=MagicMock(spec=Tokenizer),
- ) as mock_load_tokenizer,
- patch("wordllama.wordllama.safe_open", autospec=True) as mock_safe_open,
- ):
- # Mock the tensor returned by safe_open
- mock_tensor = MagicMock()
- mock_tensor.__getitem__.return_value = np.random.rand(256, 4096)
- mock_safe_open.return_value.__enter__.return_value.get_tensor.return_value = (
- mock_tensor
- )
-
- # Call load with custom cache_dir
- model = WordLlama.load(
- config=self.model_uri,
- cache_dir=cache_dir_input,
+
+ # Assert resolve_file was called twice with the correct cache_dir
+ expected_calls = [
+ call(
+ # WordLlama,
+ config_name="custom",
+ model_uri=self.model_uri,
+ dim=self.dim,
binary=self.binary,
+ file_type="weights",
+ cache_dir=expected_resolved_dirs[key],
+ disable_download=True,
+ remote_filename=None,
+ ),
+ call(
+ # WordLlama,
+ config_name="custom",
+ model_uri=self.model_uri,
dim=self.dim,
- trunc_dim=self.trunc_dim,
- )
-
- # Assert resolve_file was called twice with the correct cache_dir
- expected_calls = [
- call(
- # WordLlama,
- config_name="custom",
- model_uri=self.model_uri,
- dim=self.dim,
- binary=self.binary,
- file_type="weights",
- cache_dir=expected_resolved_dirs[key],
- disable_download=True,
- remote_filename=None,
- ),
- call(
- # WordLlama,
- config_name="custom",
- model_uri=self.model_uri,
- dim=self.dim,
- binary=False,
- file_type="tokenizer",
- cache_dir=expected_resolved_dirs[key],
- disable_download=True,
- remote_filename=None,
- ),
- ]
- mock_resolve_file.assert_has_calls(expected_calls, any_order=False)
- assert mock_resolve_file.call_count == 2
-
- # Assert load_tokenizer was called with correct path
- mock_load_tokenizer.assert_called_once_with(
- tokenizer_path,
- hf_model_id=self.model_uri.tokenizer_fallback,
- )
-
- # Assert safe_open was called with the weights path
- mock_safe_open.assert_called_once_with(
- weights_path,
- framework="np",
- device="cpu",
- )
-
- # Assert the returned model is an instance of WordLlamaInference
- assert isinstance(model, WordLlamaInference)
+ binary=False,
+ file_type="tokenizer",
+ cache_dir=expected_resolved_dirs[key],
+ disable_download=True,
+ remote_filename=None,
+ ),
+ ]
+ mock_resolve_file.assert_has_calls(expected_calls, any_order=False)
+ assert mock_resolve_file.call_count == 2
+
+ # Assert load_tokenizer was called with correct path
+ mock_load_tokenizer.assert_called_once_with(
+ tokenizer_path,
+ hf_model_id=self.model_uri.tokenizer_fallback,
+ )
+
+ # Assert safe_open was called with the weights path
+ mock_safe_open.assert_called_once_with(
+ weights_path,
+ framework="np",
+ device="cpu",
+ )
+
+ # Assert the returned model is an instance of WordLlamaInference
+ assert isinstance(model, WordLlamaInference)
@patch.object(WordLlama, "resolve_file", autospec=True)
def test_load_with_disable_download(self, mock_resolve_file):
@@ -446,7 +443,3 @@ def test_load_tokenizer_fallback(
# Assert the returned model is an instance of WordLlamaInference
assert isinstance(model, WordLlamaInference)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/wordllama.png b/wordllama.png
index 6fa66e3..fad4b9b 100644
Binary files a/wordllama.png and b/wordllama.png differ