diff --git a/setup.py b/setup.py index 389e553..db19da0 100644 --- a/setup.py +++ b/setup.py @@ -12,16 +12,15 @@ description='Werkzeug encrypted cookie', packages=['werkzeug_encryptedcookie'], platforms='any', - install_requires=[ - 'pycryptodome>=3.11.0', - 'secure-cookie', - 'brotli>=1.0.1', - 'Werkzeug>=2.0.0,<2.1.0', - ], + install_requires=['pycryptodome>=3.11.0', 'brotli>=1.0.1'], classifiers=[ 'Programming Language :: Python', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Topic :: Internet :: WWW/HTTP :: Dynamic Content', 'Topic :: Software Development :: Libraries :: Python Modules' ], diff --git a/tests.py b/tests.py index 997128a..2ec5b77 100644 --- a/tests.py +++ b/tests.py @@ -1,3 +1,4 @@ +from datetime import timedelta from time import time from werkzeug_encryptedcookie import EncryptedCookie, SecureEncryptedCookie @@ -25,82 +26,89 @@ def test_dumps_loads(self): assert r == case def test_encrypt_decrypt(self): - key = b'my little key' + cookie = self.Cookie(b'my little key') for case in [b'{"a": "b"}', b'{"a": "pr\xc3\xb3ba"}']: - r1 = self.Cookie.encrypt(case, key) - r2 = self.Cookie.encrypt(case, key) + r1 = cookie.encrypt(case) + r2 = cookie.encrypt(case) assert isinstance(r1, bytes) assert isinstance(r2, bytes) assert r1 != r2 - r1_broken = self.Cookie.decrypt(r1, b'another key') + r1_broken = self.Cookie(b'another key').decrypt(r1) assert r1_broken != case - r1 = self.Cookie.decrypt(r1, key) - r2 = self.Cookie.decrypt(r2, key) + r1 = cookie.decrypt(r1) + r2 = cookie.decrypt(r2) assert r1 == case assert r2 == case def test_serialize_unserialize(self): - key = b'my little key' + cookie = self.Cookie(b'my little key') for case in [{'a': 'b'}, {'a': 'próba'}, {'próba': '123'}]: - r = self.Cookie(case, key).serialize() + r = cookie.serialize(case) assert isinstance(r, bytes) # Check it is ascii r.decode('ascii') - r = self.Cookie.unserialize(r, key) + r = cookie.unserialize(r) assert r == case def test_unserialize_binary(self): """ Test unserialize compatibility with existing binary data. """ - key = b'my little key' + cookie = self.Cookie(b'my little key') for case in [ b'GXCS2JfvmfQJwuxYUITWTmnanyjkIP0IHKbZF2u7oz2qnuIRGuzJbF5JhZrp', b'bvK0dvBIBuPqIrG+o4Zmmu6ln7bLoR+xTz906R8GQAAAaM2rlncYNzsKIsmU', ]: - r = self.Cookie.unserialize(case, key) + r = cookie.unserialize(case) assert {'a': 'próba'} == dict(r) def test_expires(self): - key = b'my little key' - c = self.Cookie({'a': 'próba'}, key) + cookie = self.Cookie(b'my little key') + data = {'a': 'próba'} + c = cookie - r = self.Cookie.unserialize(c.serialize(time() - 1), key) + r = cookie.unserialize(c.serialize(data, time() - 1)) assert not r # Make sure previous expire not stored in cookie object. # (such bug present in original SecureCookie) - r = self.Cookie.unserialize(c.serialize(), key) - assert r + r = cookie.unserialize(c.serialize(data)) + assert r == data + + r = cookie.unserialize(c.serialize(data, time() + 1)) + assert r == data - r = self.Cookie.unserialize(c.serialize(time() + 1), key) - assert r + r = cookie.unserialize(c.serialize(data, timedelta(-1))) + assert not r + + r = cookie.unserialize(c.serialize(data, timedelta(1))) + assert r == data def test_fail_with_another_key(self): - c = self.Cookie({'a': 'próba'}, 'one key') - r = self.Cookie.unserialize(c.serialize(), b'another key') + r = self.Cookie(b'one key').serialize({'a': 'próba'}) + r = self.Cookie(b'another key').unserialize(r) assert not r def test_fail_when_not_json(self): - key = b'my little key' - r = self.RawCookie.encrypt(b'{"a", "pr\xc3\xb3ba"}', key) - r = self.RawCookie.unserialize(r, key) + cookie = self.RawCookie(b'my little key') + r = cookie.encrypt(b'{"a", "pr\xc3\xb3ba"}') + r = cookie.unserialize(r) assert not r def test_fail_when_corrupted(self): - key = b'my little key' - r = self.RawCookie({"a": "próba"}, key).serialize() - r = self.RawCookie.unserialize(r[:20] + r[21:], key) + cookie = self.RawCookie(b'my little key') + r = cookie.serialize({'a': 'próba'}) + r = cookie.unserialize(r[:20] + r[21:]) assert not r def test_compression_and_decompression(self): key = b'my little key' case = {'a': 'próba'} - no_compress = self.NoCompressCookie(case, key) - compress = self.CompressCookie(case, key) + no_compress = self.NoCompressCookie(key) + compress = self.CompressCookie(key) cases = ( # No-compressed instance unserialized by no-compressed instance (no_compress, no_compress), @@ -112,7 +120,7 @@ def test_compression_and_decompression(self): (compress, compress), ) for cookie1, cookie2 in cases: - result = cookie2.unserialize(cookie1.serialize(), key) + result = cookie2.unserialize(cookie1.serialize(case)) assert result == case @@ -131,27 +139,33 @@ class CompressCookie(Cookie): # pyright: ignore[reportIncompatibleVariableOverr compress_cookie = True def test_unsigned(self): - key, case = b'my little key', b'{"a": "pr\xc3\xb3ba"}' - r = self.Cookie.encrypt(case, key) - signed = EncryptedCookie.decrypt(r, key) - assert case in signed - - r = EncryptedCookie.encrypt(signed, key) - r = self.Cookie.decrypt(r, key) + secure_cookie = self.Cookie(b'my little key') + unsecure_cookie = EncryptedCookie(b'my little key') + case = b'{"a": "pr\xc3\xb3ba"}' + + # Check that encrypted data is the same as in original cookie + r = secure_cookie.encrypt(case) + signed = unsecure_cookie.decrypt(r) + assert case == signed[:-4] + + # Should be the same as secure_cookie.encrypt(case) + r = unsecure_cookie.encrypt(signed) + r = secure_cookie.decrypt(r) assert r == case - r = EncryptedCookie.encrypt(signed[:-1] + b'!', key) - r = self.Cookie.decrypt(r, key) + # Try to fake signature + r = unsecure_cookie.encrypt(case + b'xxxx') + r = secure_cookie.decrypt(r) assert r == b'' def test_unserialize_binary(self): """ Test unserialize compatibility with existing binary data. """ - key = b'my little key' + cookie = self.Cookie(b'my little key') for case in [ b'vGSOoyvh3KREQNzFhAbhl/oSugKPMJ8QDvp4VWRtSpgUA3670wlkbv1kzA15HQ9oBw==', b'78EM1wnaIkz6FP0EDxHPk6xeGFam2w6cSr6FWosRf6X3H7ILJvhA+gkuq+6AT9iD6g==' ]: - r = self.Cookie.unserialize(case, key) + r = cookie.unserialize(case) assert {'a': 'próba'} == dict(r) diff --git a/werkzeug_encryptedcookie/__init__.py b/werkzeug_encryptedcookie/__init__.py index 9494ead..acda84b 100644 --- a/werkzeug_encryptedcookie/__init__.py +++ b/werkzeug_encryptedcookie/__init__.py @@ -5,44 +5,52 @@ import secrets import struct import zlib +from datetime import timedelta from hashlib import sha1 from time import time import brotli from Crypto.Cipher import ARC4 -from secure_cookie.cookie import SecureCookie, _date_to_unix -class EncryptedCookie(SecureCookie): +def _date_to_unix(arg: float | int | timedelta): + """ + Converts int or timedelta object into the seconds from epoch in UTC. + """ + if isinstance(arg, timedelta): + arg = time() + arg.total_seconds() + return int(arg) + + +class EncryptedCookie: + quote_base64 = True compress_cookie = True compress_cookie_header = b'~!~brtl~!~' - # to avoid deprecation warnings - serialization_method = json - @classmethod - def _get_cipher(cls, key: bytes) -> ARC4.ARC4Cipher: - return ARC4.new(sha1(key).digest()) + def __init__(self, secret_key: bytes): + self.secret_key = secret_key + + def _get_cipher(self, nonce: bytes) -> ARC4.ARC4Cipher: + return ARC4.new(sha1(self.secret_key + nonce).digest()) @classmethod def dumps(cls, data: dict) -> bytes: return json.dumps(data, ensure_ascii=False).encode() - @classmethod - def encrypt(cls, data: bytes, secret_key: bytes) -> bytes: + def encrypt(self, data: bytes) -> bytes: nonce = secrets.token_bytes(16) - cipher = cls._get_cipher(secret_key + nonce) + cipher = self._get_cipher(nonce) return nonce + cipher.encrypt(data) @classmethod def compress(cls, data: bytes) -> bytes: return cls.compress_cookie_header + brotli.compress(data, quality=8) - def serialize(self, expires=None) -> bytes: - if self.secret_key is None: - raise RuntimeError('no secret key defined') - - data = dict(self) - if expires: + def serialize( + self, data: dict, expires: float | int | timedelta | None = None + ) -> bytes: + data = data.copy() + if expires is not None: data['_expires'] = _date_to_unix(expires) payload = self.dumps(data) @@ -50,7 +58,7 @@ def serialize(self, expires=None) -> bytes: if self.compress_cookie: payload = self.compress(payload) - string = self.encrypt(payload, self.secret_key) + string = self.encrypt(payload) if self.quote_base64: string = base64.b64encode(string) @@ -61,11 +69,10 @@ def serialize(self, expires=None) -> bytes: def loads(cls, data: bytes) -> dict: return json.loads(data.decode('utf-8')) - @classmethod - def decrypt(cls, string: bytes, secret_key: bytes) -> bytes: + def decrypt(self, string: bytes) -> bytes: nonce, payload = string[:16], string[16:] - cipher = cls._get_cipher(secret_key + nonce) + cipher = self._get_cipher(nonce) return cipher.decrypt(payload) @classmethod @@ -79,43 +86,40 @@ def decompress(cls, data: bytes) -> bytes: return data - @classmethod - def unserialize(cls, string: bytes, secret_key: bytes) -> EncryptedCookie: - if cls.quote_base64: + def unserialize(self, string: bytes) -> dict: + if self.quote_base64: try: string = base64.b64decode(string) except Exception: pass - payload = cls.decrypt(string, secret_key) - payload = cls.decompress(payload) + payload = self.decrypt(string) + payload = self.decompress(payload) try: - data = cls.loads(payload) + data = self.loads(payload) except ValueError: - data = None + data = {} if data and '_expires' in data: if time() > data['_expires']: - data = None + data = {} else: del data['_expires'] - return cls(data, secret_key, False) + return data class SecureEncryptedCookie(EncryptedCookie): - @classmethod - def encrypt(cls, data: bytes, secret_key: bytes) -> bytes: - crc = zlib.crc32(data, zlib.crc32(secret_key)) + def encrypt(self, data: bytes) -> bytes: + crc = zlib.crc32(data, zlib.crc32(self.secret_key)) data += struct.pack('>I', crc & 0xffffffff) - return super().encrypt(data, secret_key) + return super().encrypt(data) - @classmethod - def decrypt(cls, string: bytes, secret_key: bytes) -> bytes: - data = super().decrypt(string, secret_key) + def decrypt(self, string: bytes) -> bytes: + data = super().decrypt(string) data, crc1 = data[:-4], data[-4:] - crc2 = zlib.crc32(data, zlib.crc32(secret_key)) + crc2 = zlib.crc32(data, zlib.crc32(self.secret_key)) if crc1 != struct.pack('>I', crc2 & 0xffffffff): return b'' return data