Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@
description='Werkzeug encrypted cookie',
packages=['werkzeug_encryptedcookie'],
platforms='any',
install_requires=[
'pycryptodome>=3.11.0',
'secure-cookie',
'brotli>=1.0.1',
'Werkzeug>=2.0.0,<2.1.0',
],
install_requires=['pycryptodome>=3.11.0', 'brotli>=1.0.1'],
classifiers=[
'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
'Topic :: Internet :: WWW/HTTP :: Dynamic Content',
'Topic :: Software Development :: Libraries :: Python Modules'
],
Expand Down
94 changes: 54 additions & 40 deletions tests.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import timedelta
from time import time

from werkzeug_encryptedcookie import EncryptedCookie, SecureEncryptedCookie
Expand Down Expand Up @@ -25,82 +26,89 @@ def test_dumps_loads(self):
assert r == case

def test_encrypt_decrypt(self):
key = b'my little key'
cookie = self.Cookie(b'my little key')
for case in [b'{"a": "b"}', b'{"a": "pr\xc3\xb3ba"}']:
r1 = self.Cookie.encrypt(case, key)
r2 = self.Cookie.encrypt(case, key)
r1 = cookie.encrypt(case)
r2 = cookie.encrypt(case)
assert isinstance(r1, bytes)
assert isinstance(r2, bytes)
assert r1 != r2

r1_broken = self.Cookie.decrypt(r1, b'another key')
r1_broken = self.Cookie(b'another key').decrypt(r1)
assert r1_broken != case

r1 = self.Cookie.decrypt(r1, key)
r2 = self.Cookie.decrypt(r2, key)
r1 = cookie.decrypt(r1)
r2 = cookie.decrypt(r2)
assert r1 == case
assert r2 == case

def test_serialize_unserialize(self):
key = b'my little key'
cookie = self.Cookie(b'my little key')
for case in [{'a': 'b'}, {'a': 'próba'}, {'próba': '123'}]:
r = self.Cookie(case, key).serialize()
r = cookie.serialize(case)
assert isinstance(r, bytes)
# Check it is ascii
r.decode('ascii')

r = self.Cookie.unserialize(r, key)
r = cookie.unserialize(r)
assert r == case

def test_unserialize_binary(self):
"""
Test unserialize compatibility with existing binary data.
"""
key = b'my little key'
cookie = self.Cookie(b'my little key')
for case in [
b'GXCS2JfvmfQJwuxYUITWTmnanyjkIP0IHKbZF2u7oz2qnuIRGuzJbF5JhZrp',
b'bvK0dvBIBuPqIrG+o4Zmmu6ln7bLoR+xTz906R8GQAAAaM2rlncYNzsKIsmU',
]:
r = self.Cookie.unserialize(case, key)
r = cookie.unserialize(case)
assert {'a': 'próba'} == dict(r)

def test_expires(self):
key = b'my little key'
c = self.Cookie({'a': 'próba'}, key)
cookie = self.Cookie(b'my little key')
data = {'a': 'próba'}
c = cookie

r = self.Cookie.unserialize(c.serialize(time() - 1), key)
r = cookie.unserialize(c.serialize(data, time() - 1))
assert not r

# Make sure previous expire not stored in cookie object.
# (such bug present in original SecureCookie)
r = self.Cookie.unserialize(c.serialize(), key)
assert r
r = cookie.unserialize(c.serialize(data))
assert r == data

r = cookie.unserialize(c.serialize(data, time() + 1))
assert r == data

r = self.Cookie.unserialize(c.serialize(time() + 1), key)
assert r
r = cookie.unserialize(c.serialize(data, timedelta(-1)))
assert not r

r = cookie.unserialize(c.serialize(data, timedelta(1)))
assert r == data

def test_fail_with_another_key(self):
c = self.Cookie({'a': 'próba'}, 'one key')
r = self.Cookie.unserialize(c.serialize(), b'another key')
r = self.Cookie(b'one key').serialize({'a': 'próba'})
r = self.Cookie(b'another key').unserialize(r)
assert not r

def test_fail_when_not_json(self):
key = b'my little key'
r = self.RawCookie.encrypt(b'{"a", "pr\xc3\xb3ba"}', key)
r = self.RawCookie.unserialize(r, key)
cookie = self.RawCookie(b'my little key')
r = cookie.encrypt(b'{"a", "pr\xc3\xb3ba"}')
r = cookie.unserialize(r)
assert not r

def test_fail_when_corrupted(self):
key = b'my little key'
r = self.RawCookie({"a": "próba"}, key).serialize()
r = self.RawCookie.unserialize(r[:20] + r[21:], key)
cookie = self.RawCookie(b'my little key')
r = cookie.serialize({'a': 'próba'})
r = cookie.unserialize(r[:20] + r[21:])
assert not r

def test_compression_and_decompression(self):
key = b'my little key'
case = {'a': 'próba'}
no_compress = self.NoCompressCookie(case, key)
compress = self.CompressCookie(case, key)
no_compress = self.NoCompressCookie(key)
compress = self.CompressCookie(key)
cases = (
# No-compressed instance unserialized by no-compressed instance
(no_compress, no_compress),
Expand All @@ -112,7 +120,7 @@ def test_compression_and_decompression(self):
(compress, compress),
)
for cookie1, cookie2 in cases:
result = cookie2.unserialize(cookie1.serialize(), key)
result = cookie2.unserialize(cookie1.serialize(case))
assert result == case


Expand All @@ -131,27 +139,33 @@ class CompressCookie(Cookie): # pyright: ignore[reportIncompatibleVariableOverr
compress_cookie = True

def test_unsigned(self):
key, case = b'my little key', b'{"a": "pr\xc3\xb3ba"}'
r = self.Cookie.encrypt(case, key)
signed = EncryptedCookie.decrypt(r, key)
assert case in signed

r = EncryptedCookie.encrypt(signed, key)
r = self.Cookie.decrypt(r, key)
secure_cookie = self.Cookie(b'my little key')
unsecure_cookie = EncryptedCookie(b'my little key')
case = b'{"a": "pr\xc3\xb3ba"}'

# Check that encrypted data is the same as in original cookie
r = secure_cookie.encrypt(case)
signed = unsecure_cookie.decrypt(r)
assert case == signed[:-4]

# Should be the same as secure_cookie.encrypt(case)
r = unsecure_cookie.encrypt(signed)
r = secure_cookie.decrypt(r)
assert r == case

r = EncryptedCookie.encrypt(signed[:-1] + b'!', key)
r = self.Cookie.decrypt(r, key)
# Try to fake signature
r = unsecure_cookie.encrypt(case + b'xxxx')
r = secure_cookie.decrypt(r)
assert r == b''

def test_unserialize_binary(self):
"""
Test unserialize compatibility with existing binary data.
"""
key = b'my little key'
cookie = self.Cookie(b'my little key')
for case in [
b'vGSOoyvh3KREQNzFhAbhl/oSugKPMJ8QDvp4VWRtSpgUA3670wlkbv1kzA15HQ9oBw==',
b'78EM1wnaIkz6FP0EDxHPk6xeGFam2w6cSr6FWosRf6X3H7ILJvhA+gkuq+6AT9iD6g=='
]:
r = self.Cookie.unserialize(case, key)
r = cookie.unserialize(case)
assert {'a': 'próba'} == dict(r)
78 changes: 41 additions & 37 deletions werkzeug_encryptedcookie/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,52 +5,60 @@
import secrets
import struct
import zlib
from datetime import timedelta
from hashlib import sha1
from time import time

import brotli
from Crypto.Cipher import ARC4
from secure_cookie.cookie import SecureCookie, _date_to_unix


class EncryptedCookie(SecureCookie):
def _date_to_unix(arg: float | int | timedelta):
"""
Converts int or timedelta object into the seconds from epoch in UTC.
"""
if isinstance(arg, timedelta):
arg = time() + arg.total_seconds()
return int(arg)


class EncryptedCookie:
quote_base64 = True
compress_cookie = True
compress_cookie_header = b'~!~brtl~!~'
# to avoid deprecation warnings
serialization_method = json

@classmethod
def _get_cipher(cls, key: bytes) -> ARC4.ARC4Cipher:
return ARC4.new(sha1(key).digest())
def __init__(self, secret_key: bytes):
self.secret_key = secret_key

def _get_cipher(self, nonce: bytes) -> ARC4.ARC4Cipher:
return ARC4.new(sha1(self.secret_key + nonce).digest())

@classmethod
def dumps(cls, data: dict) -> bytes:
return json.dumps(data, ensure_ascii=False).encode()

@classmethod
def encrypt(cls, data: bytes, secret_key: bytes) -> bytes:
def encrypt(self, data: bytes) -> bytes:
nonce = secrets.token_bytes(16)
cipher = cls._get_cipher(secret_key + nonce)
cipher = self._get_cipher(nonce)
return nonce + cipher.encrypt(data)

@classmethod
def compress(cls, data: bytes) -> bytes:
return cls.compress_cookie_header + brotli.compress(data, quality=8)

def serialize(self, expires=None) -> bytes:
if self.secret_key is None:
raise RuntimeError('no secret key defined')

data = dict(self)
if expires:
def serialize(
self, data: dict, expires: float | int | timedelta | None = None
) -> bytes:
data = data.copy()
if expires is not None:
data['_expires'] = _date_to_unix(expires)

payload = self.dumps(data)

if self.compress_cookie:
payload = self.compress(payload)

string = self.encrypt(payload, self.secret_key)
string = self.encrypt(payload)

if self.quote_base64:
string = base64.b64encode(string)
Expand All @@ -61,11 +69,10 @@ def serialize(self, expires=None) -> bytes:
def loads(cls, data: bytes) -> dict:
return json.loads(data.decode('utf-8'))

@classmethod
def decrypt(cls, string: bytes, secret_key: bytes) -> bytes:
def decrypt(self, string: bytes) -> bytes:
nonce, payload = string[:16], string[16:]

cipher = cls._get_cipher(secret_key + nonce)
cipher = self._get_cipher(nonce)
return cipher.decrypt(payload)

@classmethod
Expand All @@ -79,43 +86,40 @@ def decompress(cls, data: bytes) -> bytes:

return data

@classmethod
def unserialize(cls, string: bytes, secret_key: bytes) -> EncryptedCookie:
if cls.quote_base64:
def unserialize(self, string: bytes) -> dict:
if self.quote_base64:
try:
string = base64.b64decode(string)
except Exception:
pass

payload = cls.decrypt(string, secret_key)
payload = cls.decompress(payload)
payload = self.decrypt(string)
payload = self.decompress(payload)

try:
data = cls.loads(payload)
data = self.loads(payload)
except ValueError:
data = None
data = {}

if data and '_expires' in data:
if time() > data['_expires']:
data = None
data = {}
else:
del data['_expires']

return cls(data, secret_key, False)
return data


class SecureEncryptedCookie(EncryptedCookie):
@classmethod
def encrypt(cls, data: bytes, secret_key: bytes) -> bytes:
crc = zlib.crc32(data, zlib.crc32(secret_key))
def encrypt(self, data: bytes) -> bytes:
crc = zlib.crc32(data, zlib.crc32(self.secret_key))
data += struct.pack('>I', crc & 0xffffffff)
return super().encrypt(data, secret_key)
return super().encrypt(data)

@classmethod
def decrypt(cls, string: bytes, secret_key: bytes) -> bytes:
data = super().decrypt(string, secret_key)
def decrypt(self, string: bytes) -> bytes:
data = super().decrypt(string)
data, crc1 = data[:-4], data[-4:]
crc2 = zlib.crc32(data, zlib.crc32(secret_key))
crc2 = zlib.crc32(data, zlib.crc32(self.secret_key))
if crc1 != struct.pack('>I', crc2 & 0xffffffff):
return b''
return data