From 286c2f9543628993dc204fa220192a713557c9e7 Mon Sep 17 00:00:00 2001 From: keremeyuboglu <32223948+keremeyuboglu@users.noreply.github.com> Date: Wed, 29 Apr 2026 17:11:08 +0300 Subject: [PATCH 1/8] Added composite keys for hash and range keys --- docs/indexes.rst | 34 +++++++ examples/indexes.py | 20 +++- pynamodb/connection/base.py | 93 +++++++++++++++--- pynamodb/connection/table.py | 4 +- pynamodb/indexes.py | 124 ++++++++++++++++++++---- pynamodb/models.py | 11 ++- setup.py | 59 +++++++----- tests/test_base_connection.py | 166 +++++++++++++++++++++++++++++++ tests/test_model.py | 177 +++++++++++++++++++++++++++++++++- 9 files changed, 610 insertions(+), 78 deletions(-) diff --git a/docs/indexes.rst b/docs/indexes.rst index c475a0461..e5cf47043 100644 --- a/docs/indexes.rst +++ b/docs/indexes.rst @@ -124,6 +124,40 @@ range key of the index. Here is an example that queries the index for values of print("Item queried from index: {0}".format(item.view)) +Composite Global Secondary Keys +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +DynamoDB global secondary indexes support multi-attribute partition and sort keys. +To define a composite key in PynamoDB, declare multiple ``hash_key=True`` and/or +``range_key=True`` attributes on a ``GlobalSecondaryIndex`` in the order you want +them used. + +.. code-block:: python + + class TournamentRegionIndex(GlobalSecondaryIndex): + class Meta: + projection = AllProjection() + read_capacity_units = 2 + write_capacity_units = 1 + + tournament_id = UnicodeAttribute(hash_key=True) + region = UnicodeAttribute(hash_key=True) + round = UnicodeAttribute(range_key=True) + bracket = UnicodeAttribute(range_key=True) + +When querying a composite GSI, pass all partition-key values as a tuple (or list) +in the same declaration order: + +.. code-block:: python + + for item in MatchModel.tournament_region_index.query(('WINTER2024', 'NA-EAST')): + print(item) + +For sort-key conditions, DynamoDB enforces left-to-right semantics across declared +sort-key attributes. See the DynamoDB documentation for details: +https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/GSI.DesignPattern.MultiAttributeKeys.html + + Pagination and last evaluated key ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/examples/indexes.py b/examples/indexes.py index d1f15ad7b..58525e770 100644 --- a/examples/indexes.py +++ b/examples/indexes.py @@ -1,6 +1,7 @@ """ Examples using DynamoDB indexes """ + import datetime from pynamodb.models import Model from pynamodb.indexes import GlobalSecondaryIndex, AllProjection, LocalSecondaryIndex @@ -11,6 +12,7 @@ class ViewIndex(GlobalSecondaryIndex): """ This class represents a global secondary index """ + class Meta: # You can override the index name by setting it below index_name = "viewIdx" @@ -18,6 +20,7 @@ class Meta: write_capacity_units = 1 # All attributes are projected projection = AllProjection() + # This attribute is the hash key for the index # Note that this attribute must also exist # in the model @@ -28,20 +31,23 @@ class TestModel(Model): """ A test model that uses a global secondary index """ + class Meta: table_name = "TestModel" # Set host for using DynamoDB Local host = "http://localhost:8000" + forum = UnicodeAttribute(hash_key=True) thread = UnicodeAttribute(range_key=True) view_index = ViewIndex() view = NumberAttribute(default=0) + if not TestModel.exists(): TestModel.create_table(read_capacity_units=1, write_capacity_units=1, wait=True) # Create an item -test_item = TestModel('forum-example', 'thread-example') +test_item = TestModel("forum-example", "thread-example") test_item.view = 1 test_item.save() @@ -57,6 +63,7 @@ class Meta: table_name = "GamePlayerOpponentIndex" host = "http://localhost:8000" projection = AllProjection() + player_id = UnicodeAttribute(hash_key=True) winner_id = UnicodeAttribute(range_key=True) @@ -68,6 +75,7 @@ class Meta: table_name = "GameOpponentTimeIndex" host = "http://localhost:8000" projection = AllProjection() + winner_id = UnicodeAttribute(hash_key=True) created_time = UnicodeAttribute(range_key=True) @@ -78,6 +86,7 @@ class Meta: write_capacity_units = 1 table_name = "GameModel" host = "http://localhost:8000" + player_id = UnicodeAttribute(hash_key=True) created_time = UTCDateTimeAttribute(range_key=True) winner_id = UnicodeAttribute() @@ -86,17 +95,18 @@ class Meta: player_opponent_index = GamePlayerOpponentIndex() opponent_time_index = GameOpponentTimeIndex() + if not GameModel.exists(): GameModel.create_table(wait=True) # Create an item -item = GameModel('1234', datetime.datetime.utcnow()) -item.winner_id = '5678' +item = GameModel("1234", datetime.datetime.now(datetime.UTC)) +item.winner_id = "5678" item.save() # Indexes can be queried easily using the index's hash key -for item in GameModel.player_opponent_index.query('1234'): +for item in GameModel.player_opponent_index.query("1234"): print("Item queried from index: {0}".format(item)) # Count on an index -print(GameModel.player_opponent_index.count('1234')) +print(GameModel.player_opponent_index.count("1234")) diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index 301ee1715..302a95ea1 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -118,12 +118,12 @@ def get_key_names(self, index_name=None): if self.range_keyname: key_names.append(self.range_keyname) if index_name is not None: - index_hash_keyname = self.get_index_hash_keyname(index_name) - if index_hash_keyname not in key_names: - key_names.append(index_hash_keyname) - index_range_keyname = self.get_index_range_keyname(index_name) - if index_range_keyname is not None and index_range_keyname not in key_names: - key_names.append(index_range_keyname) + for index_hash_keyname in self.get_index_hash_keynames(index_name): + if index_hash_keyname not in key_names: + key_names.append(index_hash_keyname) + for index_range_keyname in self.get_index_range_keynames(index_name): + if index_range_keyname not in key_names: + key_names.append(index_range_keyname) return key_names def has_index_name(self, index_name): @@ -139,6 +139,15 @@ def get_index_hash_keyname(self, index_name: str) -> str: """ Returns the name of the hash key for a given index """ + hash_keynames = self.get_index_hash_keynames(index_name) + if hash_keynames: + return hash_keynames[0] + raise ValueError("No hash key attribute for index: {}".format(index_name)) + + def get_index_hash_keynames(self, index_name: str) -> List[str]: + """ + Returns the names of the hash keys for a given index + """ global_indexes = self.data.get(GLOBAL_SECONDARY_INDEXES) local_indexes = self.data.get(LOCAL_SECONDARY_INDEXES) indexes = [] @@ -148,14 +157,26 @@ def get_index_hash_keyname(self, index_name: str) -> str: indexes += global_indexes for index in indexes: if index.get(INDEX_NAME) == index_name: + hash_keynames = [] for schema_key in index.get(KEY_SCHEMA): if schema_key.get(KEY_TYPE) == HASH: - return schema_key.get(ATTR_NAME) + hash_keynames.append(schema_key.get(ATTR_NAME)) + if hash_keynames: + return hash_keynames raise ValueError("No hash key attribute for index: {}".format(index_name)) def get_index_range_keyname(self, index_name): """ - Returns the name of the hash key for a given index + Returns the name of the range key for a given index + """ + range_keynames = self.get_index_range_keynames(index_name) + if range_keynames: + return range_keynames[0] + return None + + def get_index_range_keynames(self, index_name) -> List[str]: + """ + Returns the names of the range keys for a given index """ global_indexes = self.data.get(GLOBAL_SECONDARY_INDEXES) local_indexes = self.data.get(LOCAL_SECONDARY_INDEXES) @@ -166,10 +187,12 @@ def get_index_range_keyname(self, index_name): indexes += global_indexes for index in indexes: if index.get(INDEX_NAME) == index_name: + range_keynames = [] for schema_key in index.get(KEY_SCHEMA): if schema_key.get(KEY_TYPE) == RANGE: - return schema_key.get(ATTR_NAME) - return None + range_keynames.append(schema_key.get(ATTR_NAME)) + return range_keynames + return [] def get_item_attribute_map(self, attributes: Dict, item_key=ITEM, pythonic_key: bool = True): """ @@ -1173,7 +1196,7 @@ def scan( def query( self, table_name: str, - hash_key: str, + hash_key: Union[object, Sequence[object], Mapping[str, object]], range_key_condition: Optional[Condition] = None, filter_condition: Optional[Any] = None, attributes_to_get: Optional[Any] = None, @@ -1201,12 +1224,22 @@ def query( if index_name: if not tbl.has_index_name(index_name): raise ValueError("Table {} has no index: {}".format(table_name, index_name)) - hash_keyname = tbl.get_index_hash_keyname(index_name) + hash_keynames = tbl.get_index_hash_keynames(index_name) else: - hash_keyname = tbl.hash_keyname + hash_keynames = [tbl.hash_keyname] - hash_condition_value = {self.get_attribute_type(table_name, hash_keyname, hash_key): self.parse_attribute(hash_key)} - key_condition = Path([hash_keyname]) == hash_condition_value + hash_key_values = self._get_query_hash_key_values( + hash_key, + hash_keynames, + index_name=index_name, + ) + key_condition = None + for hash_keyname, hash_keyvalue in zip(hash_keynames, hash_key_values): + hash_condition_value = { + self.get_attribute_type(table_name, hash_keyname, hash_keyvalue): self.parse_attribute(hash_keyvalue) + } + hash_condition = Path([hash_keyname]) == hash_condition_value + key_condition = hash_condition if key_condition is None else key_condition & hash_condition if range_key_condition is not None: key_condition &= range_key_condition @@ -1252,3 +1285,33 @@ def _check_condition(self, name, condition): @staticmethod def _reverse_dict(d): return {v: k for k, v in d.items()} + + @staticmethod + def _get_query_hash_key_values( + hash_key: Union[object, Sequence[object], Mapping[str, object]], + hash_keynames: Sequence[str], + index_name: Optional[str] = None, + ) -> List[object]: + if len(hash_keynames) == 1: + return [hash_key] + if isinstance(hash_key, (tuple, list)): + if len(hash_key) != len(hash_keynames): + raise ValueError( + f"Index {index_name} expects {len(hash_keynames)} hash key values, got {len(hash_key)}" + ) + return list(hash_key) + if isinstance(hash_key, Mapping): + missing_keys = [keyname for keyname in hash_keynames if keyname not in hash_key] + if missing_keys: + raise ValueError( + f"Index {index_name} requires values for hash keys: {', '.join(missing_keys)}" + ) + extra_keys = [keyname for keyname in hash_key if keyname not in hash_keynames] + if extra_keys: + raise ValueError( + f"Index {index_name} received unknown hash keys: {', '.join(extra_keys)}" + ) + return [hash_key[keyname] for keyname in hash_keynames] + raise ValueError( + f"Index {index_name} expects {len(hash_keynames)} hash key values as tuple/list" + ) diff --git a/pynamodb/connection/table.py b/pynamodb/connection/table.py index 5e70ba5cc..430958413 100644 --- a/pynamodb/connection/table.py +++ b/pynamodb/connection/table.py @@ -3,7 +3,7 @@ ~~~~~~~~~~~~~~~~~~~~~~~~~~~ """ -from typing import Any, Dict, Mapping, Optional, Sequence +from typing import Any, Dict, Mapping, Optional, Sequence, Union from pynamodb.connection.base import Connection, MetaTable from pynamodb.constants import DEFAULT_BILLING_MODE, KEY @@ -238,7 +238,7 @@ def scan( def query( self, - hash_key: str, + hash_key: Union[object, Sequence[object], Mapping[str, object]], range_key_condition: Optional[Condition] = None, filter_condition: Optional[Any] = None, attributes_to_get: Optional[Any] = None, diff --git a/pynamodb/indexes.py b/pynamodb/indexes.py index e282b8372..c49da3c28 100644 --- a/pynamodb/indexes.py +++ b/pynamodb/indexes.py @@ -1,8 +1,7 @@ """ PynamoDB Indexes """ -from inspect import getmembers -from typing import Any, Dict, Generic, List, Optional, Type, TypeVar +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union from typing import TYPE_CHECKING from pynamodb._schema import IndexSchema, GlobalSecondaryIndexSchema @@ -19,7 +18,9 @@ if TYPE_CHECKING: from pynamodb.models import Model -_KeyType = Any +_KeyType = object +_HashKeyInputType = Union[_KeyType, Tuple[_KeyType, ...], List[_KeyType]] +_SerializedHashKeyType = Union[_KeyType, Tuple[_KeyType, ...]] _M = TypeVar('_M', bound='Model') @@ -30,13 +31,26 @@ class Index(Generic[_M]): Meta: Any = None _model: _M + @staticmethod + def _get_attributes_in_declaration_order(index_cls: Type['Index']) -> Dict[str, Attribute]: + """ + Returns attributes in declaration order, respecting overrides. + """ + attributes: Dict[str, Attribute] = {} + for base in reversed(index_cls.__mro__): + for name, attribute in getattr(base, "__dict__", {}).items(): + if isinstance(attribute, Attribute): + # If a subclass overrides an attribute, preserve the subclass declaration order. + if name in attributes: + del attributes[name] + attributes[name] = attribute + return attributes + @classmethod def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) if cls.Meta is not None: - cls.Meta.attributes = {} - for name, attribute in getmembers(cls, lambda o: isinstance(o, Attribute)): - cls.Meta.attributes[name] = attribute + cls.Meta.attributes = cls._get_attributes_in_declaration_order(cls) def __init__(self) -> None: if self.Meta is None: @@ -50,7 +64,7 @@ def __set_name__(self, owner: Type[_M], name: str): def count( self, - hash_key: _KeyType, + hash_key: _HashKeyInputType, range_key_condition: Optional[Condition] = None, filter_condition: Optional[Condition] = None, consistent_read: bool = False, @@ -72,7 +86,7 @@ def count( def query( self, - hash_key: _KeyType, + hash_key: _HashKeyInputType, range_key_condition: Optional[Condition] = None, filter_condition: Optional[Condition] = None, consistent_read: bool = False, @@ -133,9 +147,45 @@ def _hash_key_attribute(cls): """ Returns the attribute class for the hash key """ - for attr_cls in cls.Meta.attributes.values(): - if attr_cls.is_hash_key: - return attr_cls + hash_key_attributes = cls._hash_key_attributes() + if hash_key_attributes: + return hash_key_attributes[0] + + @classmethod + def _hash_key_attributes(cls) -> List[Attribute]: + return [attr for attr in cls.Meta.attributes.values() if attr.is_hash_key] + + @classmethod + def _range_key_attributes(cls) -> List[Attribute]: + return [attr for attr in cls.Meta.attributes.values() if attr.is_range_key] + + @classmethod + def _serialize_hash_key_values(cls, hash_key: _HashKeyInputType) -> _SerializedHashKeyType: + hash_key_attributes = cls._hash_key_attributes() + if len(hash_key_attributes) <= 1: + if len(hash_key_attributes) == 0: + raise ValueError(f"{cls.__name__} has no hash key attributes") + return hash_key_attributes[0].serialize(hash_key) + + if not isinstance(hash_key, (tuple, list)): + raise ValueError( + f"{cls.__name__} expects {len(hash_key_attributes)} hash key values as a tuple/list" + ) + if len(hash_key) != len(hash_key_attributes): + raise ValueError( + f"{cls.__name__} expects {len(hash_key_attributes)} hash key values, got {len(hash_key)}" + ) + return tuple( + attr.serialize(value) + for attr, value in zip(hash_key_attributes, hash_key) + ) + + @classmethod + def _validate_key_attributes(cls) -> None: + """ + Hook for subclasses to validate key constraints. + """ + return None def _update_model_schema(self, schema: ModelSchema) -> None: raise NotImplementedError @@ -154,16 +204,29 @@ def _get_schema(cls) -> IndexSchema: 'attribute_definitions': [], } - for attr_cls in cls.Meta.attributes.values(): - if attr_cls.is_hash_key or attr_cls.is_range_key: - schema['attribute_definitions'].append({ - ATTR_NAME: attr_cls.attr_name, - ATTR_TYPE: attr_cls.attr_type, - }) - schema['key_schema'].append({ - ATTR_NAME: attr_cls.attr_name, - KEY_TYPE: HASH if attr_cls.is_hash_key else RANGE, - }) + cls._validate_key_attributes() + + hash_key_attributes = cls._hash_key_attributes() + range_key_attributes = cls._range_key_attributes() + + for attr_cls in hash_key_attributes: + schema['attribute_definitions'].append({ + ATTR_NAME: attr_cls.attr_name, + ATTR_TYPE: attr_cls.attr_type, + }) + schema['key_schema'].append({ + ATTR_NAME: attr_cls.attr_name, + KEY_TYPE: HASH, + }) + for attr_cls in range_key_attributes: + schema['attribute_definitions'].append({ + ATTR_NAME: attr_cls.attr_name, + ATTR_TYPE: attr_cls.attr_type, + }) + schema['key_schema'].append({ + ATTR_NAME: attr_cls.attr_name, + KEY_TYPE: RANGE, + }) if cls.Meta.projection.non_key_attributes: schema['projection'][NON_KEY_ATTRIBUTES] = cls.Meta.projection.non_key_attributes return schema @@ -173,6 +236,17 @@ class GlobalSecondaryIndex(Index[_M]): """ A global secondary index """ + @classmethod + def _validate_key_attributes(cls) -> None: + hash_keys = cls._hash_key_attributes() + range_keys = cls._range_key_attributes() + if len(hash_keys) == 0: + raise ValueError(f"{cls.__name__} must have at least one hash key attribute") + if len(hash_keys) > 4: + raise ValueError(f"{cls.__name__} supports at most 4 hash key attributes") + if len(range_keys) > 4: + raise ValueError(f"{cls.__name__} supports at most 4 range key attributes") + @classmethod def _update_model_schema(cls, schema: ModelSchema) -> None: index_schema: GlobalSecondaryIndexSchema = { @@ -197,6 +271,14 @@ class LocalSecondaryIndex(Index[_M]): """ A local secondary index """ + @classmethod + def _validate_key_attributes(cls) -> None: + hash_keys = cls._hash_key_attributes() + range_keys = cls._range_key_attributes() + if len(hash_keys) > 1: + raise ValueError(f"{cls.__name__} supports at most one hash key attribute") + if len(range_keys) > 1: + raise ValueError(f"{cls.__name__} supports at most one range key attribute") @classmethod def _update_model_schema(cls, schema: ModelSchema) -> None: diff --git a/pynamodb/models.py b/pynamodb/models.py index 8e14918ed..136a66a4c 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -59,7 +59,8 @@ ) _T = TypeVar('_T', bound='Model') -_KeyType = Any +_KeyType = object +_HashKeyQueryType = Union[_KeyType, Tuple[_KeyType, ...], List[_KeyType]] log = logging.getLogger(__name__) @@ -569,7 +570,7 @@ def from_raw_data(cls: Type[_T], data: Dict[str, Any]) -> _T: @classmethod def count( cls: Type[_T], - hash_key: Optional[_KeyType] = None, + hash_key: Optional[_HashKeyQueryType] = None, range_key_condition: Optional[Condition] = None, filter_condition: Optional[Condition] = None, consistent_read: bool = False, @@ -593,7 +594,7 @@ def count( return cls.describe_table().get(ITEM_COUNT) if index_name: - hash_key = cls._indexes[index_name]._hash_key_attribute().serialize(hash_key) + hash_key = cls._indexes[index_name]._serialize_hash_key_values(hash_key) else: hash_key = cls._serialize_keys(hash_key)[0] @@ -628,7 +629,7 @@ def count( @classmethod def query( cls: Type[_T], - hash_key: _KeyType, + hash_key: _HashKeyQueryType, range_key_condition: Optional[Condition] = None, filter_condition: Optional[Condition] = None, consistent_read: bool = False, @@ -657,7 +658,7 @@ def query( :param rate_limit: If set then consumed capacity will be limited to this amount per second """ if index_name: - hash_key = cls._indexes[index_name]._hash_key_attribute().serialize(hash_key) + hash_key = cls._indexes[index_name]._serialize_hash_key_values(hash_key) else: hash_key = cls._serialize_keys(hash_key)[0] diff --git a/setup.py b/setup.py index fd0f72ef5..e2d04d39c 100644 --- a/setup.py +++ b/setup.py @@ -2,43 +2,50 @@ install_requires = [ - 'botocore>=1.12.54', + "botocore>=1.12.54", 'typing-extensions>=4; python_version<"3.11"', ] setup( - name='pynamodb', - version=__import__('pynamodb').__version__, - packages=find_packages(exclude=('examples', 'tests', 'typing_tests', 'tests.integration',)), - url='http://jlafon.io/pynamodb.html', + name="pynamodb", + version=__import__("pynamodb").__version__, + packages=find_packages( + exclude=( + "examples", + "tests", + "typing_tests", + "tests.integration", + ) + ), + url="http://jlafon.io/pynamodb.html", project_urls={ - 'Source': 'https://github.com/pynamodb/PynamoDB', + "Source": "https://github.com/pynamodb/PynamoDB", }, - author='Jharrod LaFon', - author_email='jlafon@eyesopen.com', - description='A Pythonic Interface to DynamoDB', - long_description=open('README.rst').read(), - long_description_content_type='text/x-rst', + author="Jharrod LaFon", + author_email="jlafon@eyesopen.com", + description="A Pythonic Interface to DynamoDB", + long_description=open("README.rst").read(), + long_description_content_type="text/x-rst", zip_safe=False, - license='MIT', - keywords='python dynamodb amazon', + license="MIT", + keywords="python dynamodb amazon", python_requires=">=3.7", install_requires=install_requires, classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Developers', - 'Programming Language :: Python', - 'Operating System :: OS Independent', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', - 'License :: OSI Approved :: MIT License', + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Programming Language :: Python", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "License :: OSI Approved :: MIT License", ], extras_require={ - 'signals': ['blinker>=1.3,<2.0'], + "signals": ["blinker>=1.3,<2.0"], }, - package_data={'pynamodb': ['py.typed']}, + package_data={"pynamodb": ["py.typed"]}, ) diff --git a/tests/test_base_connection.py b/tests/test_base_connection.py index c90acd27d..3e6634de3 100644 --- a/tests/test_base_connection.py +++ b/tests/test_base_connection.py @@ -49,6 +49,50 @@ def test_meta_table_get_key_names__index(meta_table): assert key_names == ["ForumName", "Subject", "LastPostDateTime"] +def test_meta_table_get_key_names__composite_index(): + composite_table_data = { + "TableName": "Thread", + "AttributeDefinitions": [ + {"AttributeName": "ForumName", "AttributeType": "S"}, + {"AttributeName": "Subject", "AttributeType": "S"}, + {"AttributeName": "a_partition", "AttributeType": "S"}, + {"AttributeName": "z_partition", "AttributeType": "S"}, + {"AttributeName": "b_sort", "AttributeType": "S"}, + {"AttributeName": "c_sort", "AttributeType": "S"}, + ], + "KeySchema": [ + {"AttributeName": "ForumName", "KeyType": "HASH"}, + {"AttributeName": "Subject", "KeyType": "RANGE"}, + ], + "GlobalSecondaryIndexes": [ + { + "IndexName": "CompositeIndex", + "KeySchema": [ + {"AttributeName": "z_partition", "KeyType": "HASH"}, + {"AttributeName": "a_partition", "KeyType": "HASH"}, + {"AttributeName": "c_sort", "KeyType": "RANGE"}, + {"AttributeName": "b_sort", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "KEYS_ONLY"}, + } + ], + } + meta_table = MetaTable(composite_table_data) + + assert meta_table.get_index_hash_keynames("CompositeIndex") == ["z_partition", "a_partition"] + assert meta_table.get_index_hash_keyname("CompositeIndex") == "z_partition" + assert meta_table.get_index_range_keynames("CompositeIndex") == ["c_sort", "b_sort"] + assert meta_table.get_index_range_keyname("CompositeIndex") == "c_sort" + assert meta_table.get_key_names("CompositeIndex") == [ + "ForumName", + "Subject", + "z_partition", + "a_partition", + "c_sort", + "b_sort", + ] + + def test_meta_table_get_attribute_type(meta_table): assert meta_table.get_attribute_type('ForumName') == 'S' with pytest.raises(ValueError): @@ -219,6 +263,35 @@ def test_connection_create_table(): # Ensure that the hash key is first when creating indexes assert req.call_args[0][1]['GlobalSecondaryIndexes'][0]['KeySchema'][0]['KeyType'] == 'HASH' assert req.call_args[0][1] == params + + kwargs['global_secondary_indexes'] = [ + { + 'index_name': 'composite-index', + 'key_schema': [ + {'AttributeName': 'z_partition', 'KeyType': 'HASH'}, + {'AttributeName': 'a_partition', 'KeyType': 'HASH'}, + {'AttributeName': 'c_sort', 'KeyType': 'RANGE'}, + {'AttributeName': 'b_sort', 'KeyType': 'RANGE'}, + ], + 'projection': { + 'ProjectionType': 'KEYS_ONLY' + }, + 'provisioned_throughput': { + 'ReadCapacityUnits': 1, + 'WriteCapacityUnits': 1, + }, + } + ] + with patch(PATCH_METHOD) as req: + req.return_value = None + conn.create_table(TEST_TABLE_NAME, **kwargs) + assert req.call_args[0][1]['GlobalSecondaryIndexes'][0]['KeySchema'] == [ + {'AttributeName': 'z_partition', 'KeyType': 'HASH'}, + {'AttributeName': 'a_partition', 'KeyType': 'HASH'}, + {'AttributeName': 'c_sort', 'KeyType': 'RANGE'}, + {'AttributeName': 'b_sort', 'KeyType': 'RANGE'}, + ] + del(kwargs['global_secondary_indexes']) del(params['GlobalSecondaryIndexes']) @@ -1142,6 +1215,99 @@ def test_connection_query(): } assert req.call_args[0][1] == params + composite_table_name = "ThreadComposite" + composite_table_data = { + "TableName": composite_table_name, + "AttributeDefinitions": [ + {"AttributeName": "ForumName", "AttributeType": "S"}, + {"AttributeName": "Subject", "AttributeType": "S"}, + {"AttributeName": "z_partition", "AttributeType": "S"}, + {"AttributeName": "a_partition", "AttributeType": "S"}, + {"AttributeName": "c_sort", "AttributeType": "S"}, + {"AttributeName": "b_sort", "AttributeType": "S"}, + ], + "KeySchema": [ + {"AttributeName": "ForumName", "KeyType": "HASH"}, + {"AttributeName": "Subject", "KeyType": "RANGE"}, + ], + "GlobalSecondaryIndexes": [ + { + "IndexName": "CompositeIndex", + "KeySchema": [ + {"AttributeName": "z_partition", "KeyType": "HASH"}, + {"AttributeName": "a_partition", "KeyType": "HASH"}, + {"AttributeName": "c_sort", "KeyType": "RANGE"}, + {"AttributeName": "b_sort", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "KEYS_ONLY"}, + } + ], + } + conn.add_meta_table(MetaTable(composite_table_data)) + + with patch(PATCH_METHOD) as req: + req.return_value = {} + conn.query( + composite_table_name, + ("z1", "a1"), + index_name='CompositeIndex' + ) + params = { + 'ReturnConsumedCapacity': 'TOTAL', + 'IndexName': 'CompositeIndex', + 'KeyConditionExpression': '(#0 = :0 AND #1 = :1)', + 'ExpressionAttributeNames': { + '#0': 'z_partition', + '#1': 'a_partition' + }, + 'ExpressionAttributeValues': { + ':0': {'S': 'z1'}, + ':1': {'S': 'a1'} + }, + 'TableName': composite_table_name + } + assert req.call_args[0][1] == params + + with pytest.raises(ValueError, match="expects 2 hash key values"): + conn.query(composite_table_name, "z1", index_name='CompositeIndex') + + with patch(PATCH_METHOD) as req: + req.return_value = {} + conn.query( + composite_table_name, + {'a_partition': 'a1', 'z_partition': 'z1'}, + index_name='CompositeIndex' + ) + params = { + 'ReturnConsumedCapacity': 'TOTAL', + 'IndexName': 'CompositeIndex', + 'KeyConditionExpression': '(#0 = :0 AND #1 = :1)', + 'ExpressionAttributeNames': { + '#0': 'z_partition', + '#1': 'a_partition' + }, + 'ExpressionAttributeValues': { + ':0': {'S': 'z1'}, + ':1': {'S': 'a1'} + }, + 'TableName': composite_table_name + } + assert req.call_args[0][1] == params + + with pytest.raises(ValueError, match="requires values for hash keys: a_partition"): + conn.query( + composite_table_name, + {'z_partition': 'z1'}, + index_name='CompositeIndex' + ) + + with pytest.raises(ValueError, match="received unknown hash keys: unknown"): + conn.query( + composite_table_name, + {'z_partition': 'z1', 'a_partition': 'a1', 'unknown': 'u1'}, + index_name='CompositeIndex' + ) + with patch(PATCH_METHOD) as req: req.return_value = {} conn.query( diff --git a/tests/test_model.py b/tests/test_model.py index da54303bb..966206eb0 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -169,6 +169,31 @@ class Meta: icons = BinarySetAttribute(legacy_encoding=False) +class CompositeOrderIndex(GlobalSecondaryIndex): + class Meta: + index_name = 'composite_order_idx' + read_capacity_units = 2 + write_capacity_units = 1 + projection = AllProjection() + + z_partition = UnicodeAttribute(hash_key=True) + a_partition = UnicodeAttribute(hash_key=True) + c_sort = UnicodeAttribute(range_key=True) + b_sort = UnicodeAttribute(range_key=True) + + +class CompositeIndexedModel(Model): + class Meta: + table_name = 'CompositeIndexedModel' + + item_id = UnicodeAttribute(hash_key=True) + z_partition = UnicodeAttribute() + a_partition = UnicodeAttribute() + c_sort = UnicodeAttribute() + b_sort = UnicodeAttribute() + composite_index = CompositeOrderIndex() + + class SimpleUserModel(Model): """ A hash key only model @@ -1083,6 +1108,22 @@ def test_index_count(self): } deep_eq(args, params, _assert=True) + def test_index_count_composite_hash_key(self): + with patch(PATCH_METHOD) as req: + req.return_value = {'Count': 1, 'ScannedCount': 1} + res = CompositeIndexedModel.composite_index.count(('p1', 'p2')) + self.assertEqual(res, 1) + params = req.call_args[0][1] + self.assertEqual(params['KeyConditionExpression'], '(#0 = :0 AND #1 = :1)') + self.assertEqual(params['ExpressionAttributeNames'], { + '#0': 'z_partition', + '#1': 'a_partition', + }) + self.assertEqual(params['ExpressionAttributeValues'], { + ':0': {'S': 'p1'}, + ':1': {'S': 'p2'}, + }) + def test_index_multipage_count(self): with patch(PATCH_METHOD) as req: last_evaluated_key = { @@ -2301,6 +2342,111 @@ def fake_dynamodb(*args, **kwargs): } ) + def test_global_index_composite_keys(self): + scope_args = {'count': 0} + + def fake_dynamodb(*args, **kwargs): + if scope_args['count'] == 0: + scope_args['count'] += 1 + raise ClientError({'Error': {'Code': 'ResourceNotFoundException', 'Message': 'Not Found'}}, + "DescribeTable") + return {} + + fake_db = MagicMock() + fake_db.side_effect = fake_dynamodb + + with patch(PATCH_METHOD, new=fake_db) as req: + CompositeIndexedModel.create_table(read_capacity_units=2, write_capacity_units=2) + args = req.call_args[0][1] + self.assertEqual( + args['GlobalSecondaryIndexes'][0]['KeySchema'], + [ + {'AttributeName': 'z_partition', 'KeyType': 'HASH'}, + {'AttributeName': 'a_partition', 'KeyType': 'HASH'}, + {'AttributeName': 'c_sort', 'KeyType': 'RANGE'}, + {'AttributeName': 'b_sort', 'KeyType': 'RANGE'}, + ] + ) + + def test_global_index_composite_query(self): + with patch(PATCH_METHOD) as req: + req.return_value = {'Count': 0, 'ScannedCount': 0, 'Items': []} + list(CompositeIndexedModel.composite_index.query(('p1', 'p2'), CompositeIndexedModel.c_sort == 's1')) + + params = req.call_args[0][1] + self.assertEqual(params['IndexName'], 'composite_order_idx') + self.assertEqual(params['TableName'], 'CompositeIndexedModel') + self.assertEqual(params['ExpressionAttributeValues'], { + ':0': {'S': 'p1'}, + ':1': {'S': 'p2'}, + ':2': {'S': 's1'}, + }) + self.assertEqual(params['ExpressionAttributeNames'], { + '#0': 'z_partition', + '#1': 'a_partition', + '#2': 'c_sort', + }) + self.assertEqual(params['KeyConditionExpression'], '((#0 = :0 AND #1 = :1) AND #2 = :2)') + + def test_global_index_composite_query_list_hash_key(self): + with patch(PATCH_METHOD) as req: + req.return_value = {'Count': 0, 'ScannedCount': 0, 'Items': []} + list(CompositeIndexedModel.composite_index.query(['p1', 'p2'])) + + params = req.call_args[0][1] + self.assertEqual(params['ExpressionAttributeValues'], { + ':0': {'S': 'p1'}, + ':1': {'S': 'p2'}, + }) + self.assertEqual(params['KeyConditionExpression'], '(#0 = :0 AND #1 = :1)') + + def test_global_index_composite_query_hash_key_validation(self): + with pytest.raises(ValueError, match='expects 2 hash key values'): + CompositeIndexedModel.composite_index.query('p1') + + with pytest.raises(ValueError, match='expects 2 hash key values, got 1'): + CompositeIndexedModel.composite_index.query(('p1',)) + + def test_global_index_query_hash_only_does_not_error(self): + """ + Non-composite GSI queries should work with hash key only. + """ + with patch(PATCH_METHOD) as req: + req.return_value = {'Count': 0, 'ScannedCount': 0, 'Items': []} + list(IndexedModel.email_index.query('foo')) + + params = req.call_args[0][1] + self.assertEqual(params['IndexName'], 'custom_idx_name') + self.assertEqual(params['TableName'], 'IndexedModel') + self.assertEqual(params['KeyConditionExpression'], '#0 = :0') + self.assertEqual(params['ExpressionAttributeNames'], {'#0': 'email'}) + self.assertEqual(params['ExpressionAttributeValues'], {':0': {'S': 'foo'}}) + + def test_global_index_composite_query_last_evaluated_key(self): + with patch(PATCH_METHOD) as req: + items = [] + for idx in range(30): + items.append({ + 'item_id': {STRING: f'id-{idx}'}, + 'z_partition': {STRING: 'z1'}, + 'a_partition': {STRING: 'a1'}, + 'c_sort': {STRING: f'c-{idx}'}, + 'b_sort': {STRING: f'b-{idx}'}, + }) + + req.return_value = {'Count': len(items), 'ScannedCount': len(items), 'Items': items} + results_iter = CompositeIndexedModel.composite_index.query(('z1', 'a1'), limit=25) + results = list(results_iter) + + self.assertEqual(len(results), 25) + self.assertEqual(results_iter.last_evaluated_key, { + 'item_id': items[24]['item_id'], + 'z_partition': items[24]['z_partition'], + 'a_partition': items[24]['a_partition'], + 'c_sort': items[24]['c_sort'], + 'b_sort': items[24]['b_sort'], + }) + def test_local_index(self): """ Models.LocalSecondaryIndex @@ -2314,14 +2460,14 @@ def test_local_index(self): 'AttributeType': 'S', 'AttributeName': 'user_name' }, - { - 'AttributeType': 'NS', - 'AttributeName': 'numbers' - }, { 'AttributeType': 'S', 'AttributeName': 'email' }, + { + 'AttributeType': 'NS', + 'AttributeName': 'numbers' + }, ] ) self.assertEqual(schema['local_secondary_indexes'][0]['projection']['ProjectionType'], 'INCLUDE') @@ -2357,6 +2503,29 @@ def fake_dynamodb(*args, **kwargs): ) self.assertTrue('ProvisionedThroughput' not in args['LocalSecondaryIndexes'][0]) + def test_local_index_reject_multiple_hash_or_range_keys(self): + class BadLocalHashIndex(LocalSecondaryIndex): + class Meta: + index_name = 'bad_local_hash_idx' + projection = AllProjection() + + h1 = UnicodeAttribute(hash_key=True) + h2 = UnicodeAttribute(hash_key=True) + + class BadLocalRangeIndex(LocalSecondaryIndex): + class Meta: + index_name = 'bad_local_range_idx' + projection = AllProjection() + + h = UnicodeAttribute(hash_key=True) + r1 = UnicodeAttribute(range_key=True) + r2 = UnicodeAttribute(range_key=True) + + with pytest.raises(ValueError, match='at most one hash key'): + BadLocalHashIndex._get_schema() + with pytest.raises(ValueError, match='at most one range key'): + BadLocalRangeIndex._get_schema() + def test_projections(self): """ Models.Projection From 91ab1f4ad1f341ef1da5a096a84c7efcc940efea Mon Sep 17 00:00:00 2001 From: keremeyuboglu <32223948+keremeyuboglu@users.noreply.github.com> Date: Thu, 30 Apr 2026 14:03:42 +0300 Subject: [PATCH 2/8] Updated hash key and replaced hash keys dictionary for composite indexes instead --- docs/indexes.rst | 15 +- examples/indexes.py | 20 +- pynamodb/connection/base.py | 709 ++++++++++++++++++++++------------ pynamodb/connection/table.py | 11 +- pynamodb/indexes.py | 314 ++++++++++++--- pynamodb/models.py | 535 +++++++++++++++---------- setup.py | 59 ++- tests/test_base_connection.py | 18 +- tests/test_model.py | 172 ++++++++- 9 files changed, 1274 insertions(+), 579 deletions(-) diff --git a/docs/indexes.rst b/docs/indexes.rst index e5cf47043..b539b3ffc 100644 --- a/docs/indexes.rst +++ b/docs/indexes.rst @@ -145,16 +145,23 @@ them used. round = UnicodeAttribute(range_key=True) bracket = UnicodeAttribute(range_key=True) -When querying a composite GSI, pass all partition-key values as a tuple (or list) -in the same declaration order: +When querying a composite GSI, pass all partition-key values with ``hash_keys``. +PynamoDB validates the supplied names and sends them in the index declaration +order: .. code-block:: python - for item in MatchModel.tournament_region_index.query(('WINTER2024', 'NA-EAST')): + for item in MatchModel.tournament_region_index.query( + hash_keys={ + 'tournament_id': 'WINTER2024', + 'region': 'NA-EAST', + } + ): print(item) For sort-key conditions, DynamoDB enforces left-to-right semantics across declared -sort-key attributes. See the DynamoDB documentation for details: +sort-key attributes. If a later sort-key attribute is used, all preceding sort-key +attributes must have equality conditions. See the DynamoDB documentation for details: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/GSI.DesignPattern.MultiAttributeKeys.html diff --git a/examples/indexes.py b/examples/indexes.py index 58525e770..d1f15ad7b 100644 --- a/examples/indexes.py +++ b/examples/indexes.py @@ -1,7 +1,6 @@ """ Examples using DynamoDB indexes """ - import datetime from pynamodb.models import Model from pynamodb.indexes import GlobalSecondaryIndex, AllProjection, LocalSecondaryIndex @@ -12,7 +11,6 @@ class ViewIndex(GlobalSecondaryIndex): """ This class represents a global secondary index """ - class Meta: # You can override the index name by setting it below index_name = "viewIdx" @@ -20,7 +18,6 @@ class Meta: write_capacity_units = 1 # All attributes are projected projection = AllProjection() - # This attribute is the hash key for the index # Note that this attribute must also exist # in the model @@ -31,23 +28,20 @@ class TestModel(Model): """ A test model that uses a global secondary index """ - class Meta: table_name = "TestModel" # Set host for using DynamoDB Local host = "http://localhost:8000" - forum = UnicodeAttribute(hash_key=True) thread = UnicodeAttribute(range_key=True) view_index = ViewIndex() view = NumberAttribute(default=0) - if not TestModel.exists(): TestModel.create_table(read_capacity_units=1, write_capacity_units=1, wait=True) # Create an item -test_item = TestModel("forum-example", "thread-example") +test_item = TestModel('forum-example', 'thread-example') test_item.view = 1 test_item.save() @@ -63,7 +57,6 @@ class Meta: table_name = "GamePlayerOpponentIndex" host = "http://localhost:8000" projection = AllProjection() - player_id = UnicodeAttribute(hash_key=True) winner_id = UnicodeAttribute(range_key=True) @@ -75,7 +68,6 @@ class Meta: table_name = "GameOpponentTimeIndex" host = "http://localhost:8000" projection = AllProjection() - winner_id = UnicodeAttribute(hash_key=True) created_time = UnicodeAttribute(range_key=True) @@ -86,7 +78,6 @@ class Meta: write_capacity_units = 1 table_name = "GameModel" host = "http://localhost:8000" - player_id = UnicodeAttribute(hash_key=True) created_time = UTCDateTimeAttribute(range_key=True) winner_id = UnicodeAttribute() @@ -95,18 +86,17 @@ class Meta: player_opponent_index = GamePlayerOpponentIndex() opponent_time_index = GameOpponentTimeIndex() - if not GameModel.exists(): GameModel.create_table(wait=True) # Create an item -item = GameModel("1234", datetime.datetime.now(datetime.UTC)) -item.winner_id = "5678" +item = GameModel('1234', datetime.datetime.utcnow()) +item.winner_id = '5678' item.save() # Indexes can be queried easily using the index's hash key -for item in GameModel.player_opponent_index.query("1234"): +for item in GameModel.player_opponent_index.query('1234'): print("Item queried from index: {0}".format(item)) # Count on an index -print(GameModel.player_opponent_index.count("1234")) +print(GameModel.player_opponent_index.count('1234')) diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index 302a95ea1..dbcd6a0d1 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -6,6 +6,7 @@ import uuid from threading import local from typing import Any, Dict, List, Mapping, Optional, Sequence, Union, cast + if sys.version_info >= (3, 8): from typing import Literal else: @@ -19,7 +20,6 @@ from botocore.session import get_session from pynamodb.connection._botocore_private import BotocoreBaseClientPrivate -from pynamodb._util import bin_decode_attr from pynamodb.constants import ( RETURN_CONSUMED_CAPACITY_VALUES, RETURN_ITEM_COLL_METRICS_VALUES, RETURN_ITEM_COLL_METRICS, RETURN_CONSUMED_CAPACITY, RETURN_VALUES_VALUES, @@ -194,27 +194,27 @@ def get_index_range_keynames(self, index_name) -> List[str]: return range_keynames return [] - def get_item_attribute_map(self, attributes: Dict, item_key=ITEM, pythonic_key: bool = True): + def get_item_attribute_map( + self, attributes: Dict, item_key=ITEM, pythonic_key: bool = True + ): """ Builds up a dynamodb compatible AttributeValue map """ if pythonic_key: item_key = item_key - attr_map: Dict[str, Dict] = { - item_key: {} - } + attr_map: Dict[str, Dict] = {item_key: {}} for key, value in attributes.items(): # In this case, the user provided a mapping # {'key': {'S': 'value'}} if isinstance(value, dict): attr_map[item_key][key] = value else: - attr_map[item_key][key] = { - self.get_attribute_type(key): value - } + attr_map[item_key][key] = {self.get_attribute_type(key): value} return attr_map - def get_attribute_type(self, attribute_name: str, value: Optional[Any] = None) -> str: + def get_attribute_type( + self, attribute_name: str, value: Optional[Any] = None + ) -> str: """ Returns the proper attribute type for a given attribute name """ @@ -225,10 +225,14 @@ def get_attribute_type(self, attribute_name: str, value: Optional[Any] = None) - for key in ATTRIBUTE_TYPES: if key in value: return key - attr_names = [attr.get(ATTR_NAME) for attr in self.data.get(ATTR_DEFINITIONS, [])] + attr_names = [ + attr.get(ATTR_NAME) for attr in self.data.get(ATTR_DEFINITIONS, []) + ] raise ValueError("No attribute {} in {}".format(attribute_name, attr_names)) - def get_identifier_map(self, hash_key: str, range_key: Optional[str] = None, key: str = KEY): + def get_identifier_map( + self, hash_key: str, range_key: Optional[str] = None, key: str = KEY + ): """ Builds the identifier map that is common to several operations """ @@ -249,12 +253,13 @@ def get_exclusive_start_key_map(self, exclusive_start_key): """ Builds the exclusive start key attribute map """ - if isinstance(exclusive_start_key, dict) and self.hash_keyname in exclusive_start_key: + if ( + isinstance(exclusive_start_key, dict) + and self.hash_keyname in exclusive_start_key + ): # This is useful when paginating results, as the LastEvaluatedKey returned is already # structured properly - return { - EXCLUSIVE_START_KEY: exclusive_start_key - } + return {EXCLUSIVE_START_KEY: exclusive_start_key} else: return { EXCLUSIVE_START_KEY: { @@ -270,24 +275,26 @@ class Connection(object): A higher level abstraction over botocore """ - def __init__(self, - region: Optional[str] = None, - host: Optional[str] = None, - read_timeout_seconds: Optional[float] = None, - connect_timeout_seconds: Optional[float] = None, - max_retry_attempts: Optional[int] = None, - retry_configuration: Optional[ - Union[ - Literal["LEGACY"], - Literal["UNSET"], - "botocore.config._RetryDict", - ] - ] = None, - max_pool_connections: Optional[int] = None, - extra_headers: Optional[Mapping[str, str]] = None, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_session_token: Optional[str] = None): + def __init__( + self, + region: Optional[str] = None, + host: Optional[str] = None, + read_timeout_seconds: Optional[float] = None, + connect_timeout_seconds: Optional[float] = None, + max_retry_attempts: Optional[int] = None, + retry_configuration: Optional[ + Union[ + Literal["LEGACY"], + Literal["UNSET"], + "botocore.config._RetryDict", + ] + ] = None, + max_pool_connections: Optional[int] = None, + extra_headers: Optional[Mapping[str, str]] = None, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + ): self._tables: Dict[str, MetaTable] = {} self.host = host self._local = local() @@ -296,22 +303,26 @@ def __init__(self, if region: self.region = region else: - self.region = get_settings_value('region') + self.region = get_settings_value("region") if connect_timeout_seconds is not None: self._connect_timeout_seconds = connect_timeout_seconds else: - self._connect_timeout_seconds = get_settings_value('connect_timeout_seconds') + self._connect_timeout_seconds = get_settings_value( + "connect_timeout_seconds" + ) if read_timeout_seconds is not None: self._read_timeout_seconds = read_timeout_seconds else: - self._read_timeout_seconds = get_settings_value('read_timeout_seconds') + self._read_timeout_seconds = get_settings_value("read_timeout_seconds") if max_retry_attempts is not None: self._max_retry_attempts_exception = max_retry_attempts else: - self._max_retry_attempts_exception = get_settings_value('max_retry_attempts') + self._max_retry_attempts_exception = get_settings_value( + "max_retry_attempts" + ) # Since we have the pattern of using `None` to indicate "read from the # settings", we use a literal of "UNSET" to indicate we want the @@ -323,17 +334,17 @@ def __init__(self, elif retry_configuration is not None: self._retry_configuration = retry_configuration else: - self._retry_configuration = get_settings_value('retry_configuration') + self._retry_configuration = get_settings_value("retry_configuration") if max_pool_connections is not None: self._max_pool_connections = max_pool_connections else: - self._max_pool_connections = get_settings_value('max_pool_connections') + self._max_pool_connections = get_settings_value("max_pool_connections") if extra_headers is not None: self._extra_headers = extra_headers else: - self._extra_headers = get_settings_value('extra_headers') + self._extra_headers = get_settings_value("extra_headers") self._aws_access_key_id = aws_access_key_id self._aws_secret_access_key = aws_secret_access_key @@ -348,7 +359,14 @@ def dispatch(self, operation_name: str, operation_kwargs: Dict) -> Dict: Raises TableDoesNotExist if the specified table does not exist """ - if operation_name not in [DESCRIBE_TABLE, LIST_TABLES, UPDATE_TABLE, UPDATE_TIME_TO_LIVE, DELETE_TABLE, CREATE_TABLE]: + if operation_name not in [ + DESCRIBE_TABLE, + LIST_TABLES, + UPDATE_TABLE, + UPDATE_TIME_TO_LIVE, + DELETE_TABLE, + CREATE_TABLE, + ]: if RETURN_CONSUMED_CAPACITY not in operation_kwargs: operation_kwargs.update(self.get_consumed_capacity_map(TOTAL)) log.debug("Calling %s with arguments %s", operation_name, operation_kwargs) @@ -364,18 +382,33 @@ def dispatch(self, operation_name: str, operation_kwargs: Dict) -> Dict: capacity = data.get(CONSUMED_CAPACITY) if isinstance(capacity, dict) and CAPACITY_UNITS in capacity: capacity = capacity.get(CAPACITY_UNITS) - log.debug("%s %s consumed %s units", data.get(TABLE_NAME, ''), operation_name, capacity) + log.debug( + "%s %s consumed %s units", + data.get(TABLE_NAME, ""), + operation_name, + capacity, + ) return data def send_post_boto_callback(self, operation_name, req_uuid, table_name): try: - post_dynamodb_send.send(self, operation_name=operation_name, table_name=table_name, req_uuid=req_uuid) + post_dynamodb_send.send( + self, + operation_name=operation_name, + table_name=table_name, + req_uuid=req_uuid, + ) except Exception: log.exception("post_boto callback threw an exception.") def send_pre_boto_callback(self, operation_name, req_uuid, table_name): try: - pre_dynamodb_send.send(self, operation_name=operation_name, table_name=table_name, req_uuid=req_uuid) + pre_dynamodb_send.send( + self, + operation_name=operation_name, + table_name=table_name, + req_uuid=req_uuid, + ) except Exception: log.exception("pre_boto callback threw an exception.") @@ -387,13 +420,15 @@ def _make_api_call(self, operation_name: str, operation_kwargs: Dict) -> Dict: try: return self.client._make_api_call(operation_name, operation_kwargs) except ClientError as e: - resp_metadata = e.response.get('ResponseMetadata', {}).get('HTTPHeaders', {}) - cancellation_reasons = e.response.get('CancellationReasons', []) + resp_metadata = e.response.get("ResponseMetadata", {}).get( + "HTTPHeaders", {} + ) + cancellation_reasons = e.response.get("CancellationReasons", []) - botocore_props = {'Error': e.response.get('Error', {})} + botocore_props = {"Error": e.response.get("Error", {})} verbose_props = { - 'request_id': resp_metadata.get('x-amzn-requestid', ''), - 'table_name': self._get_table_name_for_error_context(operation_kwargs), + "request_id": resp_metadata.get("x-amzn-requestid", ""), + "table_name": self._get_table_name_for_error_context(operation_kwargs), } raise VerboseClientError( botocore_props, @@ -402,10 +437,14 @@ def _make_api_call(self, operation_name: str, operation_kwargs: Dict) -> Dict: cancellation_reasons=( ( CancellationReason( - code=d['Code'], - message=d.get('Message'), - raw_item=cast(Optional[Dict[str, Dict[str, Any]]], d.get('Item')), - ) if d['Code'] != 'None' else None + code=d["Code"], + message=d.get("Message"), + raw_item=cast( + Optional[Dict[str, Dict[str, Any]]], d.get("Item") + ), + ) + if d["Code"] != "None" + else None ) for d in cancellation_reasons ), @@ -414,7 +453,7 @@ def _make_api_call(self, operation_name: str, operation_kwargs: Dict) -> Dict: def _get_table_name_for_error_context(self, operation_kwargs) -> str: # First handle the two multi-table cases: batch and transaction operations if REQUEST_ITEMS in operation_kwargs: - return ','.join(operation_kwargs[REQUEST_ITEMS]) + return ",".join(operation_kwargs[REQUEST_ITEMS]) elif TRANSACT_ITEMS in operation_kwargs: table_names = [] for item in operation_kwargs[TRANSACT_ITEMS]: @@ -429,12 +468,14 @@ def session(self) -> botocore.session.Session: Returns a valid botocore session """ # botocore client creation is not thread safe as of v1.2.5+ (see issue #153) - if getattr(self._local, 'session', None) is None: + if getattr(self._local, "session", None) is None: self._local.session = get_session() if self._aws_access_key_id and self._aws_secret_access_key: - self._local.session.set_credentials(self._aws_access_key_id, - self._aws_secret_access_key, - self._aws_session_token) + self._local.session.set_credentials( + self._aws_access_key_id, + self._aws_secret_access_key, + self._aws_session_token, + ) return self._local.session @property @@ -446,15 +487,18 @@ def client(self) -> BotocoreBaseClientPrivate: # https://github.com/boto/botocore/blob/4d55c9b4142/botocore/credentials.py#L1016-L1021 # if the client does not have credentials, we create a new client # otherwise the client is permanently poisoned in the case of metadata service flakiness when using IAM roles - if not self._client or (self._client._request_signer and not self._client._request_signer._credentials): + if not self._client or ( + self._client._request_signer + and not self._client._request_signer._credentials + ): # Check if we are using the "LEGACY" retry mode to keep previous PynamoDB # retry behavior, or if we are using the new retry configuration settings. if self._retry_configuration != "LEGACY": retries = self._retry_configuration else: retries = { - 'total_max_attempts': 1 + self._max_retry_attempts_exception, - 'mode': 'standard', + "total_max_attempts": 1 + self._max_retry_attempts_exception, + "mode": "standard", } config = botocore.client.Config( @@ -464,9 +508,16 @@ def client(self) -> BotocoreBaseClientPrivate: max_pool_connections=self._max_pool_connections, retries=retries, ) - self._client = cast(BotocoreBaseClientPrivate, self.session.create_client(SERVICE_NAME, self.region, endpoint_url=self.host, config=config)) + self._client = cast( + BotocoreBaseClientPrivate, + self.session.create_client( + SERVICE_NAME, self.region, endpoint_url=self.host, config=config + ), + ) - self._client.meta.events.register_first('before-send.*.*', self._before_send) + self._client.meta.events.register_first( + "before-send.*.*", self._before_send + ) return self._client def add_meta_table(self, meta_table: MetaTable) -> None: @@ -508,20 +559,26 @@ def create_table( PROVISIONED_THROUGHPUT: { READ_CAPACITY_UNITS: read_capacity_units, WRITE_CAPACITY_UNITS: write_capacity_units, - } + }, } attrs_list = [] if attribute_definitions is None: raise ValueError("attribute_definitions argument is required") for attr in attribute_definitions: - attrs_list.append({ - ATTR_NAME: attr.get(ATTR_NAME) or attr['attribute_name'], - ATTR_TYPE: attr.get(ATTR_TYPE) or attr['attribute_type'] - }) + attrs_list.append( + { + ATTR_NAME: attr.get(ATTR_NAME) or attr["attribute_name"], + ATTR_TYPE: attr.get(ATTR_TYPE) or attr["attribute_type"], + } + ) operation_kwargs[ATTR_DEFINITIONS] = attrs_list if billing_mode not in AVAILABLE_BILLING_MODES: - raise ValueError("incorrect value for billing_mode, available modes: {}".format(AVAILABLE_BILLING_MODES)) + raise ValueError( + "incorrect value for billing_mode, available modes: {}".format( + AVAILABLE_BILLING_MODES + ) + ) if billing_mode == PAY_PER_REQUEST_BILLING_MODE: del operation_kwargs[PROVISIONED_THROUGHPUT] elif billing_mode == PROVISIONED_BILLING_MODE: @@ -531,10 +588,12 @@ def create_table( global_secondary_indexes_list = [] for index in global_secondary_indexes: index_kwargs = { - INDEX_NAME: index.get('index_name'), - KEY_SCHEMA: sorted(index.get('key_schema'), key=lambda x: x.get(KEY_TYPE)), - PROJECTION: index.get('projection'), - PROVISIONED_THROUGHPUT: index.get('provisioned_throughput') + INDEX_NAME: index.get("index_name"), + KEY_SCHEMA: sorted( + index.get("key_schema"), key=lambda x: x.get(KEY_TYPE) + ), + PROJECTION: index.get("projection"), + PROVISIONED_THROUGHPUT: index.get("provisioned_throughput"), } if billing_mode == PAY_PER_REQUEST_BILLING_MODE: del index_kwargs[PROVISIONED_THROUGHPUT] @@ -545,35 +604,38 @@ def create_table( raise ValueError("key_schema is required") key_schema_list = [] for item in key_schema: - key_schema_list.append({ - ATTR_NAME: item.get(ATTR_NAME) or item['attribute_name'], - KEY_TYPE: str(item.get(KEY_TYPE) or item['key_type']).upper() - }) - operation_kwargs[KEY_SCHEMA] = sorted(key_schema_list, key=lambda x: x.get(KEY_TYPE)) + key_schema_list.append( + { + ATTR_NAME: item.get(ATTR_NAME) or item["attribute_name"], + KEY_TYPE: str(item.get(KEY_TYPE) or item["key_type"]).upper(), + } + ) + operation_kwargs[KEY_SCHEMA] = sorted( + key_schema_list, key=lambda x: x.get(KEY_TYPE) + ) local_secondary_indexes_list = [] if local_secondary_indexes: for index in local_secondary_indexes: - local_secondary_indexes_list.append({ - INDEX_NAME: index.get('index_name'), - KEY_SCHEMA: sorted(index.get('key_schema'), key=lambda x: x.get(KEY_TYPE)), - PROJECTION: index.get('projection'), - }) + local_secondary_indexes_list.append( + { + INDEX_NAME: index.get("index_name"), + KEY_SCHEMA: sorted( + index.get("key_schema"), key=lambda x: x.get(KEY_TYPE) + ), + PROJECTION: index.get("projection"), + } + ) operation_kwargs[LOCAL_SECONDARY_INDEXES] = local_secondary_indexes_list if stream_specification: operation_kwargs[STREAM_SPECIFICATION] = { - STREAM_ENABLED: stream_specification['stream_enabled'], - STREAM_VIEW_TYPE: stream_specification['stream_view_type'] + STREAM_ENABLED: stream_specification["stream_enabled"], + STREAM_VIEW_TYPE: stream_specification["stream_view_type"], } if tags: - operation_kwargs[TAGS] = [ - { - KEY: k, - VALUE: v - } for k, v in tags.items() - ] + operation_kwargs[TAGS] = [{KEY: k, VALUE: v} for k, v in tags.items()] try: data = self.dispatch(CREATE_TABLE, operation_kwargs) @@ -590,7 +652,7 @@ def update_time_to_live(self, table_name: str, ttl_attribute_name: str) -> Dict: TIME_TO_LIVE_SPECIFICATION: { ATTR_NAME: ttl_attribute_name, ENABLED: True, - } + }, } try: return self.dispatch(UPDATE_TIME_TO_LIVE, operation_kwargs) @@ -601,9 +663,7 @@ def delete_table(self, table_name: str) -> Dict: """ Performs the DeleteTable operation """ - operation_kwargs = { - TABLE_NAME: table_name - } + operation_kwargs = {TABLE_NAME: table_name} try: data = self.dispatch(DELETE_TABLE, operation_kwargs) except BOTOCORE_EXCEPTIONS as e: @@ -620,29 +680,38 @@ def update_table( """ Performs the UpdateTable operation """ - operation_kwargs: Dict[str, Any] = { - TABLE_NAME: table_name - } - if read_capacity_units and not write_capacity_units or write_capacity_units and not read_capacity_units: - raise ValueError("read_capacity_units and write_capacity_units are required together") + operation_kwargs: Dict[str, Any] = {TABLE_NAME: table_name} + if ( + read_capacity_units + and not write_capacity_units + or write_capacity_units + and not read_capacity_units + ): + raise ValueError( + "read_capacity_units and write_capacity_units are required together" + ) if read_capacity_units and write_capacity_units: operation_kwargs[PROVISIONED_THROUGHPUT] = { READ_CAPACITY_UNITS: read_capacity_units, - WRITE_CAPACITY_UNITS: write_capacity_units + WRITE_CAPACITY_UNITS: write_capacity_units, } if global_secondary_index_updates: global_secondary_indexes_list = [] for index in global_secondary_index_updates: - global_secondary_indexes_list.append({ - UPDATE: { - INDEX_NAME: index.get('index_name'), - PROVISIONED_THROUGHPUT: { - READ_CAPACITY_UNITS: index.get('read_capacity_units'), - WRITE_CAPACITY_UNITS: index.get('write_capacity_units') + global_secondary_indexes_list.append( + { + UPDATE: { + INDEX_NAME: index.get("index_name"), + PROVISIONED_THROUGHPUT: { + READ_CAPACITY_UNITS: index.get("read_capacity_units"), + WRITE_CAPACITY_UNITS: index.get("write_capacity_units"), + }, } } - }) - operation_kwargs[GLOBAL_SECONDARY_INDEX_UPDATES] = global_secondary_indexes_list + ) + operation_kwargs[GLOBAL_SECONDARY_INDEX_UPDATES] = ( + global_secondary_indexes_list + ) try: return self.dispatch(UPDATE_TABLE, operation_kwargs) except BOTOCORE_EXCEPTIONS as e: @@ -658,13 +727,11 @@ def list_tables( """ operation_kwargs: Dict[str, Any] = {} if exclusive_start_table_name: - operation_kwargs.update({ - EXCLUSIVE_START_TABLE_NAME: exclusive_start_table_name - }) + operation_kwargs.update( + {EXCLUSIVE_START_TABLE_NAME: exclusive_start_table_name} + ) if limit is not None: - operation_kwargs.update({ - LIMIT: limit - }) + operation_kwargs.update({LIMIT: limit}) try: return self.dispatch(LIST_TABLES, operation_kwargs) except BOTOCORE_EXCEPTIONS as e: @@ -674,9 +741,7 @@ def describe_table(self, table_name: str) -> Dict: """ Performs the DescribeTable operation """ - operation_kwargs = { - TABLE_NAME: table_name - } + operation_kwargs = {TABLE_NAME: table_name} try: data = self.dispatch(DESCRIBE_TABLE, operation_kwargs) table_data = data.get(TABLE_KEY) @@ -690,8 +755,8 @@ def describe_table(self, table_name: str) -> Dict: except BotoCoreError as e: raise TableError("Unable to describe table: {}".format(e), e) except ClientError as e: - if 'ResourceNotFound' in e.response['Error']['Code']: - raise TableDoesNotExist(e.response['Error']['Message']) + if "ResourceNotFound" in e.response["Error"]["Code"]: + raise TableDoesNotExist(e.response["Error"]["Message"]) else: raise @@ -709,15 +774,10 @@ def get_item_attribute_map( if tbl is None: raise TableError("No such table {}".format(table_name)) return tbl.get_item_attribute_map( - attributes, - item_key=item_key, - pythonic_key=pythonic_key) + attributes, item_key=item_key, pythonic_key=pythonic_key + ) - def parse_attribute( - self, - attribute: Any, - return_type: bool = False - ) -> Any: + def parse_attribute(self, attribute: Any, return_type: bool = False) -> Any: """ Returns the attribute value, where the attribute can be a raw attribute value, or a dictionary containing the type: @@ -736,10 +796,7 @@ def parse_attribute( return attribute def get_attribute_type( - self, - table_name: str, - attribute_name: str, - value: Optional[Any] = None + self, table_name: str, attribute_name: str, value: Optional[Any] = None ) -> str: """ Returns the proper attribute type for a given attribute name @@ -755,7 +812,7 @@ def get_identifier_map( table_name: str, hash_key: str, range_key: Optional[str] = None, - key: str = KEY + key: str = KEY, ) -> Dict: """ Builds the identifier map that is common to several operations @@ -770,48 +827,60 @@ def get_consumed_capacity_map(self, return_consumed_capacity: str) -> Dict: Builds the consumed capacity map that is common to several operations """ if return_consumed_capacity.upper() not in RETURN_CONSUMED_CAPACITY_VALUES: - raise ValueError("{} must be one of {}".format(RETURN_ITEM_COLL_METRICS, RETURN_CONSUMED_CAPACITY_VALUES)) - return { - RETURN_CONSUMED_CAPACITY: str(return_consumed_capacity).upper() - } + raise ValueError( + "{} must be one of {}".format( + RETURN_ITEM_COLL_METRICS, RETURN_CONSUMED_CAPACITY_VALUES + ) + ) + return {RETURN_CONSUMED_CAPACITY: str(return_consumed_capacity).upper()} def get_return_values_map(self, return_values: str) -> Dict: """ Builds the return values map that is common to several operations """ if return_values.upper() not in RETURN_VALUES_VALUES: - raise ValueError("{} must be one of {}".format(RETURN_VALUES, RETURN_VALUES_VALUES)) - return { - RETURN_VALUES: str(return_values).upper() - } + raise ValueError( + "{} must be one of {}".format(RETURN_VALUES, RETURN_VALUES_VALUES) + ) + return {RETURN_VALUES: str(return_values).upper()} def get_return_values_on_condition_failure_map( - self, - return_values_on_condition_failure: str + self, return_values_on_condition_failure: str ) -> Dict: """ Builds the return values map that is common to several operations """ if return_values_on_condition_failure.upper() not in RETURN_VALUES_VALUES: - raise ValueError("{} must be one of {}".format( - RETURN_VALUES_ON_CONDITION_FAILURE, - RETURN_VALUES_ON_CONDITION_FAILURE_VALUES - )) + raise ValueError( + "{} must be one of {}".format( + RETURN_VALUES_ON_CONDITION_FAILURE, + RETURN_VALUES_ON_CONDITION_FAILURE_VALUES, + ) + ) return { - RETURN_VALUES_ON_CONDITION_FAILURE: str(return_values_on_condition_failure).upper() + RETURN_VALUES_ON_CONDITION_FAILURE: str( + return_values_on_condition_failure + ).upper() } def get_item_collection_map(self, return_item_collection_metrics: str) -> Dict: """ Builds the item collection map """ - if return_item_collection_metrics.upper() not in RETURN_ITEM_COLL_METRICS_VALUES: - raise ValueError("{} must be one of {}".format(RETURN_ITEM_COLL_METRICS, RETURN_ITEM_COLL_METRICS_VALUES)) - return { - RETURN_ITEM_COLL_METRICS: str(return_item_collection_metrics).upper() - } + if ( + return_item_collection_metrics.upper() + not in RETURN_ITEM_COLL_METRICS_VALUES + ): + raise ValueError( + "{} must be one of {}".format( + RETURN_ITEM_COLL_METRICS, RETURN_ITEM_COLL_METRICS_VALUES + ) + ) + return {RETURN_ITEM_COLL_METRICS: str(return_item_collection_metrics).upper()} - def get_exclusive_start_key_map(self, table_name: str, exclusive_start_key: str) -> Dict: + def get_exclusive_start_key_map( + self, table_name: str, exclusive_start_key: str + ) -> Dict: """ Builds the exclusive start key attribute map """ @@ -834,43 +903,58 @@ def get_operation_kwargs( return_values: Optional[str] = None, return_consumed_capacity: Optional[str] = None, return_item_collection_metrics: Optional[str] = None, - return_values_on_condition_failure: Optional[str] = None + return_values_on_condition_failure: Optional[str] = None, ) -> Dict: - self._check_condition('condition', condition) + self._check_condition("condition", condition) operation_kwargs: Dict[str, Any] = {} - name_placeholders: Dict[str, str] = {} + name_placeholders: Dict[str, str] = {} expression_attribute_values: Dict[str, Any] = {} operation_kwargs[TABLE_NAME] = table_name - operation_kwargs.update(self.get_identifier_map(table_name, hash_key, range_key, key=key)) + operation_kwargs.update( + self.get_identifier_map(table_name, hash_key, range_key, key=key) + ) if attributes and operation_kwargs.get(ITEM) is not None: attrs = self.get_item_attribute_map(table_name, attributes) operation_kwargs[ITEM].update(attrs[ITEM]) if attributes_to_get is not None: - projection_expression = create_projection_expression(attributes_to_get, name_placeholders) + projection_expression = create_projection_expression( + attributes_to_get, name_placeholders + ) operation_kwargs[PROJECTION_EXPRESSION] = projection_expression if condition is not None: - condition_expression = condition.serialize(name_placeholders, expression_attribute_values) + condition_expression = condition.serialize( + name_placeholders, expression_attribute_values + ) operation_kwargs[CONDITION_EXPRESSION] = condition_expression if consistent_read is not None: operation_kwargs[CONSISTENT_READ] = consistent_read if return_values is not None: operation_kwargs.update(self.get_return_values_map(return_values)) if return_values_on_condition_failure is not None: - operation_kwargs.update(self.get_return_values_on_condition_failure_map(return_values_on_condition_failure)) + operation_kwargs.update( + self.get_return_values_on_condition_failure_map( + return_values_on_condition_failure + ) + ) if return_consumed_capacity is not None: - operation_kwargs.update(self.get_consumed_capacity_map(return_consumed_capacity)) + operation_kwargs.update( + self.get_consumed_capacity_map(return_consumed_capacity) + ) if return_item_collection_metrics is not None: - operation_kwargs.update(self.get_item_collection_map(return_item_collection_metrics)) + operation_kwargs.update( + self.get_item_collection_map(return_item_collection_metrics) + ) if actions is not None: update_expression = Update(*actions) operation_kwargs[UPDATE_EXPRESSION] = update_expression.serialize( - name_placeholders, - expression_attribute_values + name_placeholders, expression_attribute_values ) if name_placeholders: - operation_kwargs[EXPRESSION_ATTRIBUTE_NAMES] = self._reverse_dict(name_placeholders) + operation_kwargs[EXPRESSION_ATTRIBUTE_NAMES] = self._reverse_dict( + name_placeholders + ) if expression_attribute_values: operation_kwargs[EXPRESSION_ATTRIBUTE_VALUES] = expression_attribute_values return operation_kwargs @@ -895,7 +979,7 @@ def delete_item( condition=condition, return_values=return_values, return_consumed_capacity=return_consumed_capacity, - return_item_collection_metrics=return_item_collection_metrics + return_item_collection_metrics=return_item_collection_metrics, ) try: return self.dispatch(DELETE_ITEM, operation_kwargs) @@ -957,7 +1041,7 @@ def put_item( condition=condition, return_values=return_values, return_consumed_capacity=return_consumed_capacity, - return_item_collection_metrics=return_item_collection_metrics + return_item_collection_metrics=return_item_collection_metrics, ) try: return self.dispatch(PUT_ITEM, operation_kwargs) @@ -968,15 +1052,19 @@ def _get_transact_operation_kwargs( self, client_request_token: Optional[str] = None, return_consumed_capacity: Optional[str] = None, - return_item_collection_metrics: Optional[str] = None + return_item_collection_metrics: Optional[str] = None, ) -> Dict: operation_kwargs = {} if client_request_token is not None: operation_kwargs[CLIENT_REQUEST_TOKEN] = client_request_token if return_consumed_capacity is not None: - operation_kwargs.update(self.get_consumed_capacity_map(return_consumed_capacity)) + operation_kwargs.update( + self.get_consumed_capacity_map(return_consumed_capacity) + ) if return_item_collection_metrics is not None: - operation_kwargs.update(self.get_item_collection_map(return_item_collection_metrics)) + operation_kwargs.update( + self.get_item_collection_map(return_item_collection_metrics) + ) return operation_kwargs @@ -997,20 +1085,14 @@ def transact_write_items( transact_items.extend( {TRANSACT_CONDITION_CHECK: item} for item in condition_check_items ) - transact_items.extend( - {TRANSACT_DELETE: item} for item in delete_items - ) - transact_items.extend( - {TRANSACT_PUT: item} for item in put_items - ) - transact_items.extend( - {TRANSACT_UPDATE: item} for item in update_items - ) + transact_items.extend({TRANSACT_DELETE: item} for item in delete_items) + transact_items.extend({TRANSACT_PUT: item} for item in put_items) + transact_items.extend({TRANSACT_UPDATE: item} for item in update_items) operation_kwargs = self._get_transact_operation_kwargs( client_request_token=client_request_token, return_consumed_capacity=return_consumed_capacity, - return_item_collection_metrics=return_item_collection_metrics + return_item_collection_metrics=return_item_collection_metrics, ) operation_kwargs[TRANSACT_ITEMS] = transact_items @@ -1027,10 +1109,10 @@ def transact_get_items( """ Performs the TransactGet operation and returns the result """ - operation_kwargs = self._get_transact_operation_kwargs(return_consumed_capacity=return_consumed_capacity) - operation_kwargs[TRANSACT_ITEMS] = [ - {TRANSACT_GET: item} for item in get_items - ] + operation_kwargs = self._get_transact_operation_kwargs( + return_consumed_capacity=return_consumed_capacity + ) + operation_kwargs[TRANSACT_ITEMS] = [{TRANSACT_GET: item} for item in get_items] try: return self.dispatch(TRANSACT_GET_ITEMS, operation_kwargs) @@ -1050,27 +1132,35 @@ def batch_write_item( """ if put_items is None and delete_items is None: raise ValueError("Either put_items or delete_items must be specified") - operation_kwargs: Dict[str, Any] = { - REQUEST_ITEMS: { - table_name: [] - } - } + operation_kwargs: Dict[str, Any] = {REQUEST_ITEMS: {table_name: []}} if return_consumed_capacity: - operation_kwargs.update(self.get_consumed_capacity_map(return_consumed_capacity)) + operation_kwargs.update( + self.get_consumed_capacity_map(return_consumed_capacity) + ) if return_item_collection_metrics: - operation_kwargs.update(self.get_item_collection_map(return_item_collection_metrics)) + operation_kwargs.update( + self.get_item_collection_map(return_item_collection_metrics) + ) put_items_list = [] if put_items: for item in put_items: - put_items_list.append({ - PUT_REQUEST: self.get_item_attribute_map(table_name, item, pythonic_key=False) - }) + put_items_list.append( + { + PUT_REQUEST: self.get_item_attribute_map( + table_name, item, pythonic_key=False + ) + } + ) delete_items_list = [] if delete_items: for item in delete_items: - delete_items_list.append({ - DELETE_REQUEST: self.get_item_attribute_map(table_name, item, item_key=KEY, pythonic_key=False) - }) + delete_items_list.append( + { + DELETE_REQUEST: self.get_item_attribute_map( + table_name, item, item_key=KEY, pythonic_key=False + ) + } + ) operation_kwargs[REQUEST_ITEMS][table_name] = delete_items_list + put_items_list try: return self.dispatch(BATCH_WRITE_ITEM, operation_kwargs) @@ -1088,20 +1178,20 @@ def batch_get_item( """ Performs the batch get item operation """ - operation_kwargs: Dict[str, Any] = { - REQUEST_ITEMS: { - table_name: {} - } - } + operation_kwargs: Dict[str, Any] = {REQUEST_ITEMS: {table_name: {}}} args_map: Dict[str, Any] = {} name_placeholders: Dict[str, str] = {} if consistent_read: args_map[CONSISTENT_READ] = consistent_read if return_consumed_capacity: - operation_kwargs.update(self.get_consumed_capacity_map(return_consumed_capacity)) + operation_kwargs.update( + self.get_consumed_capacity_map(return_consumed_capacity) + ) if attributes_to_get is not None: - projection_expression = create_projection_expression(attributes_to_get, name_placeholders) + projection_expression = create_projection_expression( + attributes_to_get, name_placeholders + ) args_map[PROJECTION_EXPRESSION] = projection_expression if name_placeholders: args_map[EXPRESSION_ATTRIBUTE_NAMES] = self._reverse_dict(name_placeholders) @@ -1109,9 +1199,7 @@ def batch_get_item( keys_map: Dict[str, List] = {KEYS: []} for key in keys: - keys_map[KEYS].append( - self.get_item_attribute_map(table_name, key)[ITEM] - ) + keys_map[KEYS].append(self.get_item_attribute_map(table_name, key)[ITEM]) operation_kwargs[REQUEST_ITEMS][table_name].update(keys_map) try: return self.dispatch(BATCH_GET_ITEM, operation_kwargs) @@ -1134,7 +1222,7 @@ def get_item( hash_key=hash_key, range_key=range_key, consistent_read=consistent_read, - attributes_to_get=attributes_to_get + attributes_to_get=attributes_to_get, ) try: return self.dispatch(GET_ITEM, operation_kwargs) @@ -1157,26 +1245,34 @@ def scan( """ Performs the scan operation """ - self._check_condition('filter_condition', filter_condition) + self._check_condition("filter_condition", filter_condition) operation_kwargs: Dict[str, Any] = {TABLE_NAME: table_name} name_placeholders: Dict[str, str] = {} expression_attribute_values: Dict[str, Any] = {} if filter_condition is not None: - filter_expression = filter_condition.serialize(name_placeholders, expression_attribute_values) + filter_expression = filter_condition.serialize( + name_placeholders, expression_attribute_values + ) operation_kwargs[FILTER_EXPRESSION] = filter_expression if attributes_to_get is not None: - projection_expression = create_projection_expression(attributes_to_get, name_placeholders) + projection_expression = create_projection_expression( + attributes_to_get, name_placeholders + ) operation_kwargs[PROJECTION_EXPRESSION] = projection_expression if index_name: operation_kwargs[INDEX_NAME] = index_name if limit is not None: operation_kwargs[LIMIT] = limit if return_consumed_capacity: - operation_kwargs.update(self.get_consumed_capacity_map(return_consumed_capacity)) + operation_kwargs.update( + self.get_consumed_capacity_map(return_consumed_capacity) + ) if exclusive_start_key: - operation_kwargs.update(self.get_exclusive_start_key_map(table_name, exclusive_start_key)) + operation_kwargs.update( + self.get_exclusive_start_key_map(table_name, exclusive_start_key) + ) if segment is not None: operation_kwargs[SEGMENT] = segment if total_segments: @@ -1184,7 +1280,9 @@ def scan( if consistent_read: operation_kwargs[CONSISTENT_READ] = consistent_read if name_placeholders: - operation_kwargs[EXPRESSION_ATTRIBUTE_NAMES] = self._reverse_dict(name_placeholders) + operation_kwargs[EXPRESSION_ATTRIBUTE_NAMES] = self._reverse_dict( + name_placeholders + ) if expression_attribute_values: operation_kwargs[EXPRESSION_ATTRIBUTE_VALUES] = expression_attribute_values @@ -1196,7 +1294,7 @@ def scan( def query( self, table_name: str, - hash_key: Union[object, Sequence[object], Mapping[str, object]], + hash_key: Optional[Any] = None, range_key_condition: Optional[Condition] = None, filter_condition: Optional[Any] = None, attributes_to_get: Optional[Any] = None, @@ -1207,12 +1305,13 @@ def query( return_consumed_capacity: Optional[str] = None, scan_index_forward: Optional[bool] = None, select: Optional[str] = None, + hash_keys: Optional[Mapping[str, Any]] = None, ) -> Dict: """ Performs the Query operation and returns the result """ - self._check_condition('range_key_condition', range_key_condition) - self._check_condition('filter_condition', filter_condition) + self._check_condition("range_key_condition", range_key_condition) + self._check_condition("filter_condition", filter_condition) operation_kwargs: Dict[str, Any] = {TABLE_NAME: table_name} name_placeholders: Dict[str, str] = {} @@ -1223,44 +1322,69 @@ def query( raise TableError("No such table: {}".format(table_name)) if index_name: if not tbl.has_index_name(index_name): - raise ValueError("Table {} has no index: {}".format(table_name, index_name)) + raise ValueError( + "Table {} has no index: {}".format(table_name, index_name) + ) hash_keynames = tbl.get_index_hash_keynames(index_name) + range_keynames = tbl.get_index_range_keynames(index_name) else: hash_keynames = [tbl.hash_keyname] + range_keynames = [tbl.range_keyname] if tbl.range_keyname else [] hash_key_values = self._get_query_hash_key_values( hash_key, + hash_keys, hash_keynames, index_name=index_name, ) + self._validate_multi_range_key_condition( + range_key_condition, + range_keynames, + index_name=index_name, + ) key_condition = None for hash_keyname, hash_keyvalue in zip(hash_keynames, hash_key_values): hash_condition_value = { - self.get_attribute_type(table_name, hash_keyname, hash_keyvalue): self.parse_attribute(hash_keyvalue) + self.get_attribute_type( + table_name, hash_keyname, hash_keyvalue + ): self.parse_attribute(hash_keyvalue) } hash_condition = Path([hash_keyname]) == hash_condition_value - key_condition = hash_condition if key_condition is None else key_condition & hash_condition + key_condition = ( + hash_condition + if key_condition is None + else key_condition & hash_condition + ) if range_key_condition is not None: key_condition &= range_key_condition operation_kwargs[KEY_CONDITION_EXPRESSION] = key_condition.serialize( - name_placeholders, expression_attribute_values) + name_placeholders, expression_attribute_values + ) if filter_condition is not None: - filter_expression = filter_condition.serialize(name_placeholders, expression_attribute_values) + filter_expression = filter_condition.serialize( + name_placeholders, expression_attribute_values + ) operation_kwargs[FILTER_EXPRESSION] = filter_expression if attributes_to_get: - projection_expression = create_projection_expression(attributes_to_get, name_placeholders) + projection_expression = create_projection_expression( + attributes_to_get, name_placeholders + ) operation_kwargs[PROJECTION_EXPRESSION] = projection_expression if consistent_read: operation_kwargs[CONSISTENT_READ] = True if exclusive_start_key: - operation_kwargs.update(self.get_exclusive_start_key_map(table_name, exclusive_start_key)) + operation_kwargs.update( + self.get_exclusive_start_key_map(table_name, exclusive_start_key) + ) if index_name: operation_kwargs[INDEX_NAME] = index_name if limit is not None: operation_kwargs[LIMIT] = limit if return_consumed_capacity: - operation_kwargs.update(self.get_consumed_capacity_map(return_consumed_capacity)) + operation_kwargs.update( + self.get_consumed_capacity_map(return_consumed_capacity) + ) if select: if select.upper() not in SELECT_VALUES: raise ValueError("{} must be one of {}".format(SELECT, SELECT_VALUES)) @@ -1268,7 +1392,9 @@ def query( if scan_index_forward is not None: operation_kwargs[SCAN_INDEX_FORWARD] = scan_index_forward if name_placeholders: - operation_kwargs[EXPRESSION_ATTRIBUTE_NAMES] = self._reverse_dict(name_placeholders) + operation_kwargs[EXPRESSION_ATTRIBUTE_NAMES] = self._reverse_dict( + name_placeholders + ) if expression_attribute_values: operation_kwargs[EXPRESSION_ATTRIBUTE_VALUES] = expression_attribute_values @@ -1288,30 +1414,125 @@ def _reverse_dict(d): @staticmethod def _get_query_hash_key_values( - hash_key: Union[object, Sequence[object], Mapping[str, object]], + hash_key: Optional[Any], + hash_keys: Optional[Mapping[str, Any]], hash_keynames: Sequence[str], index_name: Optional[str] = None, - ) -> List[object]: + ) -> List[Any]: + if hash_key is not None and hash_keys is not None: + raise ValueError(f"Index {index_name} received both hash_key and hash_keys") if len(hash_keynames) == 1: - return [hash_key] - if isinstance(hash_key, (tuple, list)): - if len(hash_key) != len(hash_keynames): + if hash_keys is None: + if hash_key is None: + raise ValueError(f"Index {index_name} requires a hash_key") + if isinstance(hash_key, (tuple, list, Mapping)): + raise ValueError( + f"Index {index_name} expects a single hash_key value" + ) + return [hash_key] + return Connection._get_ordered_query_hash_key_values( + hash_keys, hash_keynames, index_name=index_name + ) + if hash_key is not None: + raise ValueError( + f"Index {index_name} has multiple hash key attributes; use hash_keys=..." + ) + if hash_keys is None: + raise ValueError(f"Index {index_name} requires hash_keys") + return Connection._get_ordered_query_hash_key_values( + hash_keys, hash_keynames, index_name=index_name + ) + + @staticmethod + def _get_ordered_query_hash_key_values( + hash_keys: Mapping[str, Any], + hash_keynames: Sequence[str], + index_name: Optional[str] = None, + ) -> List[Any]: + if not isinstance(hash_keys, Mapping): + raise ValueError(f"Index {index_name} expects hash_keys to be a mapping") + missing_keys = [ + keyname for keyname in hash_keynames if keyname not in hash_keys + ] + if missing_keys: + raise ValueError( + f"Index {index_name} requires values for hash keys: {', '.join(missing_keys)}" + ) + extra_keys = [keyname for keyname in hash_keys if keyname not in hash_keynames] + if extra_keys: + raise ValueError( + f"Index {index_name} received unknown hash keys: {', '.join(extra_keys)}" + ) + return [hash_keys[keyname] for keyname in hash_keynames] + + @staticmethod + def _flatten_and_conditions(condition: Condition) -> List[Condition]: + if condition.operator == "AND": + conditions = [] + for value in condition.values: + conditions.extend(Connection._flatten_and_conditions(value)) + return conditions + return [condition] + + @staticmethod + def _condition_key_name(condition: Condition) -> Optional[str]: + path = getattr(condition.values[0], "path", None) if condition.values else None + if not isinstance(path, list) or len(path) != 1: + return None + return path[0] + + @staticmethod + def _validate_multi_range_key_condition( + range_key_condition: Optional[Condition], + range_keynames: Sequence[str], + index_name: Optional[str] = None, + ) -> None: + if range_key_condition is None or len(range_keynames) <= 1: + return + + valid_operators = {"=", "<", "<=", ">", ">=", "BETWEEN", "begins_with"} + conditions_by_key: Dict[str, Condition] = {} + context = f"Index {index_name}" + for condition in Connection._flatten_and_conditions(range_key_condition): + if condition.operator not in valid_operators: raise ValueError( - f"Index {index_name} expects {len(hash_keynames)} hash key values, got {len(hash_key)}" + f"{context} range_key_condition uses unsupported range key operator: {condition.operator}" ) - return list(hash_key) - if isinstance(hash_key, Mapping): - missing_keys = [keyname for keyname in hash_keynames if keyname not in hash_key] - if missing_keys: + key_name = Connection._condition_key_name(condition) + if key_name not in range_keynames: raise ValueError( - f"Index {index_name} requires values for hash keys: {', '.join(missing_keys)}" + f"{context} range_key_condition must only use range keys: {', '.join(range_keynames)}" ) - extra_keys = [keyname for keyname in hash_key if keyname not in hash_keynames] - if extra_keys: + if key_name in conditions_by_key: raise ValueError( - f"Index {index_name} received unknown hash keys: {', '.join(extra_keys)}" + f"{context} range_key_condition has multiple conditions for range key: {key_name}" ) - return [hash_key[keyname] for keyname in hash_keynames] - raise ValueError( - f"Index {index_name} expects {len(hash_keynames)} hash key values as tuple/list" + conditions_by_key[key_name] = condition + + if not conditions_by_key: + return + + highest_position = max( + range_keynames.index(key_name) for key_name in conditions_by_key ) + missing_prefix_keys = [ + key_name + for key_name in range_keynames[:highest_position] + if key_name not in conditions_by_key + ] + if missing_prefix_keys: + raise ValueError( + f"{context} range_key_condition must include equality conditions for preceding range keys: " + f"{', '.join(missing_prefix_keys)}" + ) + + non_equal_prefix_keys = [ + key_name + for key_name in range_keynames[:highest_position] + if conditions_by_key[key_name].operator != "=" + ] + if non_equal_prefix_keys: + raise ValueError( + f"{context} range_key_condition must use equality for preceding range keys: " + f"{', '.join(non_equal_prefix_keys)}" + ) diff --git a/pynamodb/connection/table.py b/pynamodb/connection/table.py index 430958413..7c87c3aeb 100644 --- a/pynamodb/connection/table.py +++ b/pynamodb/connection/table.py @@ -3,7 +3,7 @@ ~~~~~~~~~~~~~~~~~~~~~~~~~~~ """ -from typing import Any, Dict, Mapping, Optional, Sequence, Union +from typing import Any, Dict, Mapping, Optional, Sequence from pynamodb.connection.base import Connection, MetaTable from pynamodb.constants import DEFAULT_BILLING_MODE, KEY @@ -238,7 +238,8 @@ def scan( def query( self, - hash_key: Union[object, Sequence[object], Mapping[str, object]], + hash_key: Optional[Any] = None, + hash_keys: Optional[Mapping[str, Any]] = None, range_key_condition: Optional[Condition] = None, filter_condition: Optional[Any] = None, attributes_to_get: Optional[Any] = None, @@ -266,6 +267,7 @@ def query( return_consumed_capacity=return_consumed_capacity, scan_index_forward=scan_index_forward, select=select, + hash_keys=hash_keys, ) def describe_table(self) -> Dict: @@ -299,7 +301,8 @@ def update_table( self.table_name, read_capacity_units=read_capacity_units, write_capacity_units=write_capacity_units, - global_secondary_index_updates=global_secondary_index_updates) + global_secondary_index_updates=global_secondary_index_updates, + ) def create_table( self, @@ -326,5 +329,5 @@ def create_table( local_secondary_indexes=local_secondary_indexes, stream_specification=stream_specification, billing_mode=billing_mode, - tags=tags + tags=tags, ) diff --git a/pynamodb/indexes.py b/pynamodb/indexes.py index c49da3c28..ea6029e53 100644 --- a/pynamodb/indexes.py +++ b/pynamodb/indexes.py @@ -1,26 +1,27 @@ """ PynamoDB Indexes """ -from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union -from typing import TYPE_CHECKING +from inspect import getmembers +from typing import TYPE_CHECKING, Any, Dict, Generic, List, Mapping, Optional, Type, TypeVar, Union from pynamodb._schema import IndexSchema, GlobalSecondaryIndexSchema from pynamodb._schema import ModelSchema +from pynamodb.attributes import Attribute from pynamodb.constants import ( INCLUDE, ALL, KEYS_ONLY, ATTR_NAME, ATTR_TYPE, KEY_TYPE, PROJECTION_TYPE, NON_KEY_ATTRIBUTES, READ_CAPACITY_UNITS, WRITE_CAPACITY_UNITS, ) -from pynamodb.attributes import Attribute from pynamodb.expressions.condition import Condition from pynamodb.pagination import ResultIterator from pynamodb.types import HASH, RANGE + if TYPE_CHECKING: from pynamodb.models import Model -_KeyType = object -_HashKeyInputType = Union[_KeyType, Tuple[_KeyType, ...], List[_KeyType]] -_SerializedHashKeyType = Union[_KeyType, Tuple[_KeyType, ...]] +_KeyType = Any +_HashKeysInputType = Mapping[str, _KeyType] +_SerializedHashKeyType = Union[_KeyType, Dict[str, _KeyType]] _M = TypeVar('_M', bound='Model') @@ -28,17 +29,20 @@ class Index(Generic[_M]): """ Base class for secondary indexes """ + Meta: Any = None _model: _M @staticmethod - def _get_attributes_in_declaration_order(index_cls: Type['Index']) -> Dict[str, Attribute]: + def _get_attributes_in_declaration_order( + index_cls: Type['Index'], + ) -> Dict[str, Attribute]: """ Returns attributes in declaration order, respecting overrides. """ attributes: Dict[str, Attribute] = {} for base in reversed(index_cls.__mro__): - for name, attribute in getattr(base, "__dict__", {}).items(): + for name, attribute in getattr(base, '__dict__', {}).items(): if isinstance(attribute, Attribute): # If a subclass overrides an attribute, preserve the subclass declaration order. if name in attributes: @@ -54,25 +58,29 @@ def __init_subclass__(cls, **kwargs): def __init__(self) -> None: if self.Meta is None: - raise ValueError("Indexes require a Meta class for settings") - if not hasattr(self.Meta, "projection"): - raise ValueError("No projection defined, define a projection for this class") + raise ValueError('Indexes require a Meta class for settings') + if not hasattr(self.Meta, 'projection'): + raise ValueError('No projection defined, define a projection for this class') def __set_name__(self, owner: Type[_M], name: str): - if not hasattr(self.Meta, "index_name"): + if not hasattr(self.Meta, 'index_name'): self.Meta.index_name = name def count( self, - hash_key: _HashKeyInputType, + hash_key: Optional[_KeyType] = None, range_key_condition: Optional[Condition] = None, filter_condition: Optional[Condition] = None, consistent_read: bool = False, limit: Optional[int] = None, rate_limit: Optional[float] = None, + hash_keys: Optional[_HashKeysInputType] = None, ) -> int: """ Count on an index + + :param hash_key: The hash key to query. Can be None when ``hash_keys`` is provided. + :param hash_keys: Named hash key values for indexes with multiple hash key attributes. """ return self._model.count( hash_key, @@ -82,11 +90,12 @@ def count( consistent_read=consistent_read, limit=limit, rate_limit=rate_limit, + hash_keys=hash_keys, ) def query( self, - hash_key: _HashKeyInputType, + hash_key: Optional[_KeyType] = None, range_key_condition: Optional[Condition] = None, filter_condition: Optional[Condition] = None, consistent_read: bool = False, @@ -96,9 +105,13 @@ def query( attributes_to_get: Optional[List[str]] = None, page_size: Optional[int] = None, rate_limit: Optional[float] = None, + hash_keys: Optional[_HashKeysInputType] = None, ) -> ResultIterator[_M]: """ Queries an index + + :param hash_key: The hash key to query. Can be None when ``hash_keys`` is provided. + :param hash_keys: Named hash key values for indexes with multiple hash key attributes. """ return self._model.query( hash_key, @@ -112,6 +125,7 @@ def query( attributes_to_get=attributes_to_get, page_size=page_size, rate_limit=rate_limit, + hash_keys=hash_keys, ) def scan( @@ -160,26 +174,199 @@ def _range_key_attributes(cls) -> List[Attribute]: return [attr for attr in cls.Meta.attributes.values() if attr.is_range_key] @classmethod - def _serialize_hash_key_values(cls, hash_key: _HashKeyInputType) -> _SerializedHashKeyType: + def _hash_key_aliases( + cls, hash_key_attributes: List[Attribute] + ) -> Dict[str, Attribute]: + aliases: Dict[str, Attribute] = {} + hash_key_attribute_ids = {id(attr) for attr in hash_key_attributes} + for attr_name, attr in cls.Meta.attributes.items(): + if id(attr) in hash_key_attribute_ids: + aliases[attr_name] = attr + aliases[attr.attr_name] = attr + return aliases + + @staticmethod + def _flatten_and_conditions(condition: Condition) -> List[Condition]: + if condition.operator == 'AND': + conditions: List[Condition] = [] + for value in condition.values: + conditions.extend(Index._flatten_and_conditions(value)) + return conditions + return [condition] + + @staticmethod + def _condition_key_name(condition: Condition) -> Optional[str]: + path = getattr(condition.values[0], 'path', None) if condition.values else None + if not isinstance(path, list) or len(path) != 1: + return None + return path[0] + + @staticmethod + def _validate_multi_key_condition( + range_key_condition: Condition, + range_keynames: List[str], + context: str, + ) -> None: + valid_operators = {'=', '<', '<=', '>', '>=', 'BETWEEN', 'begins_with'} + conditions_by_key: Dict[str, Condition] = {} + for condition in Index._flatten_and_conditions(range_key_condition): + if condition.operator not in valid_operators: + raise ValueError( + f'{context} range_key_condition uses unsupported range key operator: {condition.operator}' + ) + key_name = Index._condition_key_name(condition) + if key_name not in range_keynames: + raise ValueError( + f'{context} range_key_condition must only use range keys: ' + ', '.join(range_keynames) + ) + if key_name in conditions_by_key: + raise ValueError( + f'{context} range_key_condition has multiple conditions for range key: {key_name}' + ) + conditions_by_key[key_name] = condition + + if not conditions_by_key: + return + + highest_position = max( + range_keynames.index(key_name) for key_name in conditions_by_key + ) + missing_prefix_keys = [ + key_name + for key_name in range_keynames[:highest_position] + if key_name not in conditions_by_key + ] + if missing_prefix_keys: + raise ValueError( + f'{context} range_key_condition must include equality conditions for preceding range keys: ' + + ', '.join(missing_prefix_keys) + ) + + non_equal_prefix_keys = [ + key_name + for key_name in range_keynames[:highest_position] + if conditions_by_key[key_name].operator != '=' + ] + if non_equal_prefix_keys: + raise ValueError( + f'{context} range_key_condition must use equality for preceding range keys: ' + + ', '.join(non_equal_prefix_keys) + ) + + @classmethod + def _serialize_hash_key_values( + cls, + hash_key: Optional[_KeyType] = None, + hash_keys: Optional[_HashKeysInputType] = None, + ) -> _SerializedHashKeyType: hash_key_attributes = cls._hash_key_attributes() - if len(hash_key_attributes) <= 1: - if len(hash_key_attributes) == 0: - raise ValueError(f"{cls.__name__} has no hash key attributes") - return hash_key_attributes[0].serialize(hash_key) + if not hash_key_attributes: + raise ValueError(f'{cls.__name__} has no hash key attributes') + + if hash_key is not None and hash_keys is not None: + raise ValueError(f'{cls.__name__} received both hash_key and hash_keys') + + if len(hash_key_attributes) == 1: + if hash_keys is None: + if hash_key is None: + raise ValueError(f'{cls.__name__} requires a hash_key') + if isinstance(hash_key, (tuple, list)): + raise ValueError(f'{cls.__name__} expects a single hash_key value') + if isinstance(hash_key, Mapping): + raise ValueError( + f'{cls.__name__} expects hash_keys=... for named hash key values' + ) + return hash_key_attributes[0].serialize(hash_key) + + hash_key_values = cls._get_ordered_hash_key_values( + hash_keys, hash_key_attributes + ) + return hash_key_attributes[0].serialize(hash_key_values[0]) - if not isinstance(hash_key, (tuple, list)): + if hash_keys is None: + if hash_key is None: + raise ValueError(f'{cls.__name__} requires hash_keys') raise ValueError( - f"{cls.__name__} expects {len(hash_key_attributes)} hash key values as a tuple/list" + f'{cls.__name__} has multiple hash key attributes; use hash_keys=...' ) - if len(hash_key) != len(hash_key_attributes): + + hash_key_values = cls._get_ordered_hash_key_values( + hash_keys, hash_key_attributes + ) + return { + attr.attr_name: attr.serialize(value) + for attr, value in zip(hash_key_attributes, hash_key_values) + } + + @classmethod + def serialize_hash_key_values( + cls, + hash_key: Optional[_KeyType] = None, + hash_keys: Optional[_HashKeysInputType] = None, + ) -> _SerializedHashKeyType: + return cls._serialize_hash_key_values(hash_key, hash_keys=hash_keys) + + @classmethod + def _get_ordered_hash_key_values( + cls, + hash_keys: _HashKeysInputType, + hash_key_attributes: List[Attribute], + ) -> List[_KeyType]: + if not isinstance(hash_keys, Mapping): + raise ValueError(f'{cls.__name__} expects hash_keys to be a mapping') + + expected_aliases = cls._hash_key_aliases(hash_key_attributes) + + values_by_attr_name: Dict[str, _KeyType] = {} + unknown_keys = [] + for key, value in hash_keys.items(): + key_name = key + attr = expected_aliases.get(key_name) + if attr is None: + unknown_keys.append(str(key_name)) + continue + if attr.attr_name in values_by_attr_name: + raise ValueError( + f'{cls.__name__} received duplicate value for hash key: {attr.attr_name}' + ) + values_by_attr_name[attr.attr_name] = value + + if unknown_keys: raise ValueError( - f"{cls.__name__} expects {len(hash_key_attributes)} hash key values, got {len(hash_key)}" + f'{cls.__name__} received unknown hash keys: ' + ', '.join(unknown_keys) ) - return tuple( - attr.serialize(value) - for attr, value in zip(hash_key_attributes, hash_key) + + missing_keys = [ + attr.attr_name + for attr in hash_key_attributes + if attr.attr_name not in values_by_attr_name + ] + if missing_keys: + raise ValueError( + f'{cls.__name__} requires values for hash keys: ' + ', '.join(missing_keys) + ) + + return [values_by_attr_name[attr.attr_name] for attr in hash_key_attributes] + + @classmethod + def _validate_range_key_condition( + cls, range_key_condition: Optional[Condition] + ) -> None: + range_key_attributes = cls._range_key_attributes() + if range_key_condition is None or len(range_key_attributes) <= 1: + return + cls._validate_multi_key_condition( + range_key_condition, + [attr.attr_name for attr in range_key_attributes], + cls.__name__, ) + @classmethod + def validate_range_key_condition( + cls, range_key_condition: Optional[Condition] + ) -> None: + cls._validate_range_key_condition(range_key_condition) + @classmethod def _validate_key_attributes(cls) -> None: """ @@ -209,26 +396,38 @@ def _get_schema(cls) -> IndexSchema: hash_key_attributes = cls._hash_key_attributes() range_key_attributes = cls._range_key_attributes() + for attr_cls in range_key_attributes: + schema['attribute_definitions'].append( + { + ATTR_NAME: attr_cls.attr_name, + ATTR_TYPE: attr_cls.attr_type, + } + ) for attr_cls in hash_key_attributes: - schema['attribute_definitions'].append({ - ATTR_NAME: attr_cls.attr_name, - ATTR_TYPE: attr_cls.attr_type, - }) - schema['key_schema'].append({ - ATTR_NAME: attr_cls.attr_name, - KEY_TYPE: HASH, - }) + schema['attribute_definitions'].append( + { + ATTR_NAME: attr_cls.attr_name, + ATTR_TYPE: attr_cls.attr_type, + } + ) + for attr_cls in hash_key_attributes: + schema['key_schema'].append( + { + ATTR_NAME: attr_cls.attr_name, + KEY_TYPE: HASH, + } + ) for attr_cls in range_key_attributes: - schema['attribute_definitions'].append({ - ATTR_NAME: attr_cls.attr_name, - ATTR_TYPE: attr_cls.attr_type, - }) - schema['key_schema'].append({ - ATTR_NAME: attr_cls.attr_name, - KEY_TYPE: RANGE, - }) + schema['key_schema'].append( + { + ATTR_NAME: attr_cls.attr_name, + KEY_TYPE: RANGE, + } + ) if cls.Meta.projection.non_key_attributes: - schema['projection'][NON_KEY_ATTRIBUTES] = cls.Meta.projection.non_key_attributes + schema['projection'][NON_KEY_ATTRIBUTES] = ( + cls.Meta.projection.non_key_attributes + ) return schema @@ -236,16 +435,15 @@ class GlobalSecondaryIndex(Index[_M]): """ A global secondary index """ + @classmethod def _validate_key_attributes(cls) -> None: hash_keys = cls._hash_key_attributes() range_keys = cls._range_key_attributes() - if len(hash_keys) == 0: - raise ValueError(f"{cls.__name__} must have at least one hash key attribute") if len(hash_keys) > 4: - raise ValueError(f"{cls.__name__} supports at most 4 hash key attributes") + raise ValueError(f'{cls.__name__} supports at most 4 hash key attributes') if len(range_keys) > 4: - raise ValueError(f"{cls.__name__} supports at most 4 range key attributes") + raise ValueError(f'{cls.__name__} supports at most 4 range key attributes') @classmethod def _update_model_schema(cls, schema: ModelSchema) -> None: @@ -255,9 +453,13 @@ def _update_model_schema(cls, schema: ModelSchema) -> None: } if hasattr(cls.Meta, 'read_capacity_units'): - index_schema['provisioned_throughput'][READ_CAPACITY_UNITS] = cls.Meta.read_capacity_units + index_schema['provisioned_throughput'][READ_CAPACITY_UNITS] = ( + cls.Meta.read_capacity_units + ) if hasattr(cls.Meta, 'write_capacity_units'): - index_schema['provisioned_throughput'][WRITE_CAPACITY_UNITS] = cls.Meta.write_capacity_units + index_schema['provisioned_throughput'][WRITE_CAPACITY_UNITS] = ( + cls.Meta.write_capacity_units + ) schema['global_secondary_indexes'].append(index_schema) # With polymorphism, indexes can use the same attribute, e.g. index1 on (thread_id, created_at) @@ -271,14 +473,15 @@ class LocalSecondaryIndex(Index[_M]): """ A local secondary index """ + @classmethod def _validate_key_attributes(cls) -> None: hash_keys = cls._hash_key_attributes() range_keys = cls._range_key_attributes() if len(hash_keys) > 1: - raise ValueError(f"{cls.__name__} supports at most one hash key attribute") + raise ValueError(f'{cls.__name__} supports at most one hash key attribute') if len(range_keys) > 1: - raise ValueError(f"{cls.__name__} supports at most one range key attribute") + raise ValueError(f'{cls.__name__} supports at most one range key attribute') @classmethod def _update_model_schema(cls, schema: ModelSchema) -> None: @@ -291,11 +494,11 @@ def _update_model_schema(cls, schema: ModelSchema) -> None: schema['attribute_definitions'].append(attr_def) - class Projection: """ A class for presenting projections """ + projection_type: Any = None non_key_attributes: Any = None @@ -304,6 +507,7 @@ class KeysOnlyProjection(Projection): """ Keys only projection """ + projection_type = KEYS_ONLY @@ -311,11 +515,14 @@ class IncludeProjection(Projection): """ An INCLUDE projection """ + projection_type = INCLUDE def __init__(self, non_attr_keys: Optional[List[str]] = None) -> None: if not non_attr_keys: - raise ValueError("The INCLUDE type projection requires a list of string attribute names") + raise ValueError( + 'The INCLUDE type projection requires a list of string attribute names' + ) self.non_key_attributes = non_attr_keys @@ -323,4 +530,5 @@ class AllProjection(Projection): """ An ALL projection """ + projection_type = ALL diff --git a/pynamodb/models.py b/pynamodb/models.py index 136a66a4c..7ae4b4bdf 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -1,28 +1,14 @@ """ DynamoDB Models for PynamoDB """ -import random -import time import logging -import warnings import sys +import time +import warnings from copy import deepcopy from inspect import getmembers -from typing import Any -from typing import Dict -from typing import Generic -from typing import Iterable -from typing import Iterator -from typing import List -from typing import Mapping -from typing import Optional -from typing import Sequence -from typing import Text -from typing import Tuple -from typing import Type -from typing import TypeVar -from typing import Union -from typing import cast +from typing import Any, Dict, Generic, Iterable, Iterator, List, Mapping, Optional, Sequence, Text, Tuple, Type, \ + TypeVar, Union, cast from pynamodb._schema import ModelSchema from pynamodb.connection.base import MetaTable @@ -62,7 +48,6 @@ _KeyType = object _HashKeyQueryType = Union[_KeyType, Tuple[_KeyType, ...], List[_KeyType]] - log = logging.getLogger(__name__) log.addHandler(logging.NullHandler()) @@ -71,6 +56,7 @@ class BatchWrite(Generic[_T]): """ A class for batch writes """ + def __init__(self, model: Type[_T], auto_commit: bool = True): self.model = model self.auto_commit = auto_commit @@ -93,10 +79,10 @@ def save(self, put_item: _T) -> None: """ if len(self.pending_operations) == self.max_operations: if not self.auto_commit: - raise ValueError("DynamoDB allows a maximum of 25 batch operations") + raise ValueError('DynamoDB allows a maximum of 25 batch operations') else: self.commit() - self.pending_operations.append({"action": PUT, "item": put_item}) + self.pending_operations.append({'action': PUT, 'item': put_item}) def delete(self, del_item: _T) -> None: """ @@ -113,10 +99,10 @@ def delete(self, del_item: _T) -> None: """ if len(self.pending_operations) == self.max_operations: if not self.auto_commit: - raise ValueError("DynamoDB allows a maximum of 25 batch operations") + raise ValueError('DynamoDB allows a maximum of 25 batch operations') else: self.commit() - self.pending_operations.append({"action": DELETE, "item": del_item}) + self.pending_operations.append({'action': DELETE, 'item': del_item}) def __enter__(self): return self @@ -132,7 +118,7 @@ def commit(self) -> None: """ Writes all of the changes that are pending """ - log.debug("%s committing batch operation", self.model) + log.debug('%s committing batch operation', self.model) put_items = [] delete_items = [] for item in self.pending_operations: @@ -158,7 +144,7 @@ def commit(self) -> None: retries += 1 if retries >= self.model.Meta.max_retry_attempts: self.failed_operations = unprocessed_items - raise PutError("Failed to batch write items: max_retry_attempts exceeded") + raise PutError('Failed to batch write items: max_retry_attempts exceeded') put_items = [] delete_items = [] for item in unprocessed_items: @@ -166,7 +152,11 @@ def commit(self) -> None: put_items.append(item.get(PUT_REQUEST).get(ITEM)) # type: ignore elif DELETE_REQUEST in item: delete_items.append(item.get(DELETE_REQUEST).get(KEY)) # type: ignore - log.info("Resending %d unprocessed keys for batch operation (retry %d)", len(unprocessed_items), retries) + log.info( + 'Resending %d unprocessed keys for batch operation (retry %d)', + len(unprocessed_items), + retries, + ) data = self.model._get_connection().batch_write_item( put_items=put_items, delete_items=delete_items, @@ -197,6 +187,7 @@ class MetaModel(AttributeContainerMeta): """ Model meta class """ + def __new__(cls, name, bases, namespace, discriminator=None): # Defined so that the discriminator can be set in the class definition. return super().__new__(cls, name, bases, namespace) @@ -208,24 +199,24 @@ def __init__(self, name, bases, namespace, discriminator=None) -> None: for attr_name, attribute in cls.get_attributes().items(): if attribute.is_hash_key: if cls._hash_keyname and cls._hash_keyname != attr_name: - raise ValueError(f"{cls.__name__} has more than one hash key: {cls._hash_keyname}, {attr_name}") + raise ValueError(f'{cls.__name__} has more than one hash key: {cls._hash_keyname}, {attr_name}') cls._hash_keyname = attr_name if attribute.is_range_key: if cls._range_keyname and cls._range_keyname != attr_name: - raise ValueError(f"{cls.__name__} has more than one range key: {cls._range_keyname}, {attr_name}") + raise ValueError(f'{cls.__name__} has more than one range key: {cls._range_keyname}, {attr_name}') cls._range_keyname = attr_name if isinstance(attribute, VersionAttribute): if cls._version_attribute_name and cls._version_attribute_name != attr_name: raise ValueError( - "The model has more than one Version attribute: {}, {}" - .format(cls._version_attribute_name, attr_name) + 'The model has more than one Version attribute: {}, {}'.format( + cls._version_attribute_name, attr_name + ) ) cls._version_attribute_name = attr_name ttl_attr_names = [name for name, attr in cls.get_attributes().items() if isinstance(attr, TTLAttribute)] if len(ttl_attr_names) > 1: - raise ValueError("{} has more than one TTL attribute: {}".format( - cls.__name__, ", ".join(ttl_attr_names))) + raise ValueError('{} has more than one TTL attribute: {}'.format(cls.__name__, ', '.join(ttl_attr_names))) if isinstance(namespace, dict): for attr_name, attr_obj in namespace.items(): @@ -235,7 +226,7 @@ def __init__(self, name, bases, namespace, discriminator=None) -> None: if not hasattr(attr_obj, HOST): setattr(attr_obj, HOST, get_settings_value('host')) if hasattr(attr_obj, 'session_cls') or hasattr(attr_obj, 'request_timeout_seconds'): - warnings.warn("The `session_cls` and `request_timeout_second` options are no longer supported") + warnings.warn('The `session_cls` and `request_timeout_second` options are no longer supported') if not hasattr(attr_obj, 'connect_timeout_seconds'): setattr(attr_obj, 'connect_timeout_seconds', get_settings_value('connect_timeout_seconds')) if not hasattr(attr_obj, 'read_timeout_seconds'): @@ -254,13 +245,13 @@ def __init__(self, name, bases, namespace, discriminator=None) -> None: setattr(attr_obj, 'aws_session_token', None) # create a custom Model.DoesNotExist derived from pynamodb.exceptions.DoesNotExist, - # so that "except Model.DoesNotExist:" would not catch other models' exceptions + # so that 'except Model.DoesNotExist:' would not catch other models' exceptions if 'DoesNotExist' not in namespace: exception_attrs = { '__module__': namespace.get('__module__'), - '__qualname__': f'{cls.__qualname__}.{"DoesNotExist"}', + '__qualname__': f'{cls.__qualname__}.DoesNotExist', } - cls.DoesNotExist = type('DoesNotExist', (DoesNotExist, ), exception_attrs) + cls.DoesNotExist = type('DoesNotExist', (DoesNotExist,), exception_attrs) @staticmethod def _initialize_indexes(cls): @@ -296,11 +287,11 @@ class Model(AttributeContainer, metaclass=MetaModel): _indexes: Dict[str, Index] def __init__( - self, - hash_key: Optional[_KeyType] = None, - range_key: Optional[_KeyType] = None, - _user_instantiated: bool = True, - **attributes: Any, + self, + hash_key: Optional[_KeyType] = None, + range_key: Optional[_KeyType] = None, + _user_instantiated: bool = True, + **attributes: Any, ) -> None: """ :param hash_key: Required. The hash key for this object. @@ -309,20 +300,24 @@ def __init__( """ if hash_key is not None: if self._hash_keyname is None: - raise ValueError(f"This model has no hash key, but a hash key value was provided: {hash_key}") + raise ValueError( + f'This model has no hash key, but a hash key value was provided: {hash_key}' + ) attributes[self._hash_keyname] = hash_key if range_key is not None: if self._range_keyname is None: - raise ValueError(f"This model has no range key, but a range key value was provided: {range_key}") + raise ValueError( + f'This model has no range key, but a range key value was provided: {range_key}' + ) attributes[self._range_keyname] = range_key super(Model, self).__init__(_user_instantiated=_user_instantiated, **attributes) @classmethod def batch_get( - cls: Type[_T], - items: Iterable[Union[_KeyType, Iterable[_KeyType]]], - consistent_read: Optional[bool] = None, - attributes_to_get: Optional[Sequence[str]] = None, + cls: Type[_T], + items: Iterable[Union[_KeyType, Iterable[_KeyType]]], + consistent_read: Optional[bool] = None, + attributes_to_get: Optional[Sequence[str]] = None, ) -> Iterator[_T]: """ BatchGetItem for this model @@ -351,23 +346,27 @@ def batch_get( item = items.pop() if range_key_attribute: if isinstance(item, str): - raise ValueError(f'Invalid key value {item!r}: ' - 'expected non-str iterable with exactly 2 elements (hash key, range key)') + raise ValueError( + f'Invalid key value {item!r}: ' + 'expected non-str iterable with exactly 2 elements (hash key, range key)' + ) try: hash_key, range_key = item except (TypeError, ValueError): - raise ValueError(f'Invalid key value {item!r}: ' - 'expected iterable with exactly 2 elements (hash key, range key)') + raise ValueError( + f'Invalid key value {item!r}: ' + 'expected iterable with exactly 2 elements (hash key, range key)' + ) hash_key_ser, range_key_ser = cls._serialize_keys(hash_key, range_key) - keys_to_get.append({ - hash_key_attribute.attr_name: hash_key_ser, - range_key_attribute.attr_name: range_key_ser, - }) + keys_to_get.append( + { + hash_key_attribute.attr_name: hash_key_ser, + range_key_attribute.attr_name: range_key_ser, + } + ) else: hash_key_ser, _ = cls._serialize_keys(item) - keys_to_get.append({ - hash_key_attribute.attr_name: hash_key_ser - }) + keys_to_get.append({hash_key_attribute.attr_name: hash_key_ser}) while keys_to_get: page, unprocessed_keys = cls._batch_get_page( @@ -395,7 +394,12 @@ def batch_write(cls: Type[_T], auto_commit: bool = True) -> BatchWrite[_T]: """ return BatchWrite(cls, auto_commit=auto_commit) - def delete(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True) -> Any: + def delete( + self, + condition: Optional[Condition] = None, + *, + add_version_condition: bool = True, + ) -> Any: """ Deletes this object from DynamoDB. @@ -410,9 +414,17 @@ def delete(self, condition: Optional[Condition] = None, *, add_version_condition if add_version_condition and version_condition is not None: condition &= version_condition - return self._get_connection().delete_item(hk_value, range_key=rk_value, condition=condition) + return self._get_connection().delete_item( + hk_value, range_key=rk_value, condition=condition + ) - def update(self, actions: List[Action], condition: Optional[Condition] = None, *, add_version_condition: bool = True) -> Any: + def update( + self, + actions: List[Action], + condition: Optional[Condition] = None, + *, + add_version_condition: bool = True, + ) -> Any: """ Updates an item using the UpdateItem operation. @@ -422,30 +434,47 @@ def update(self, actions: List[Action], condition: Optional[Condition] = None, * :param add_version_condition: For models which have a :class:`~pynamodb.attributes.VersionAttribute`, specifies whether only to update if the version matches the model that is currently loaded. Set to `False` for a 'last write wins' strategy. - Regardless, the version will always be incremented to prevent "rollbacks" by concurrent :meth:`save` calls. + Regardless, the version will always be incremented to prevent 'rollbacks' by concurrent :meth:`save` calls. :raises pynamodb.exceptions.UpdateError: if the `condition` is not met """ if not isinstance(actions, list) or len(actions) == 0: - raise TypeError("the value of `actions` is expected to be a non-empty list") + raise TypeError('the value of `actions` is expected to be a non-empty list') hk_value, rk_value = self._get_hash_range_key_serialized_values() version_condition = self._handle_version_attribute(actions=actions) if add_version_condition and version_condition is not None: condition &= version_condition - data = self._get_connection().update_item(hk_value, range_key=rk_value, return_values=ALL_NEW, condition=condition, actions=actions) + data = self._get_connection().update_item( + hk_value, + range_key=rk_value, + return_values=ALL_NEW, + condition=condition, + actions=actions, + ) item_data = data[ATTRIBUTES] stored_cls = self._get_discriminator_class(item_data) if stored_cls and stored_cls != type(self): - raise ValueError("Cannot update this item from the returned class: {}".format(stored_cls.__name__)) + raise ValueError( + 'Cannot update this item from the returned class: {}'.format( + stored_cls.__name__ + ) + ) self.deserialize(item_data) return data - def save(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True) -> Dict[str, Any]: + def save( + self, + condition: Optional[Condition] = None, + *, + add_version_condition: bool = True, + ) -> Dict[str, Any]: """ Save this object to dynamodb """ - args, kwargs = self._get_save_args(condition=condition, add_version_condition=add_version_condition) + args, kwargs = self._get_save_args( + condition=condition, add_version_condition=add_version_condition + ) data = self._get_connection().put_item(*args, **kwargs) self.update_local_version_attribute() return data @@ -459,22 +488,28 @@ def refresh(self, consistent_read: bool = False) -> None: :raises ModelInstance.DoesNotExist: if the object to be updated does not exist """ hk_value, rk_value = self._get_hash_range_key_serialized_values() - attrs = self._get_connection().get_item(hk_value, range_key=rk_value, consistent_read=consistent_read) + attrs = self._get_connection().get_item( + hk_value, range_key=rk_value, consistent_read=consistent_read + ) item_data = attrs.get(ITEM, None) if item_data is None: - raise self.DoesNotExist("This item does not exist in the table.") + raise self.DoesNotExist('This item does not exist in the table.') stored_cls = self._get_discriminator_class(item_data) if stored_cls and stored_cls != type(self): - raise ValueError("Cannot refresh this item from the returned class: {}".format(stored_cls.__name__)) + raise ValueError( + 'Cannot refresh this item from the returned class: {}'.format( + stored_cls.__name__ + ) + ) self.deserialize(item_data) def get_update_kwargs_from_instance( - self, - actions: List[Action], - condition: Optional[Condition] = None, - return_values_on_condition_failure: Optional[str] = None, - *, - add_version_condition: bool = True, + self, + actions: List[Action], + condition: Optional[Condition] = None, + return_values_on_condition_failure: Optional[str] = None, + *, + add_version_condition: bool = True, ) -> Dict[str, Any]: hk_value, rk_value = self._get_hash_range_key_serialized_values() @@ -482,14 +517,21 @@ def get_update_kwargs_from_instance( if add_version_condition and version_condition is not None: condition &= version_condition - return self._get_connection().get_operation_kwargs(hk_value, range_key=rk_value, key=KEY, actions=actions, condition=condition, return_values_on_condition_failure=return_values_on_condition_failure) + return self._get_connection().get_operation_kwargs( + hk_value, + range_key=rk_value, + key=KEY, + actions=actions, + condition=condition, + return_values_on_condition_failure=return_values_on_condition_failure, + ) def get_delete_kwargs_from_instance( - self, - condition: Optional[Condition] = None, - return_values_on_condition_failure: Optional[str] = None, - *, - add_version_condition: bool = True, + self, + condition: Optional[Condition] = None, + return_values_on_condition_failure: Optional[str] = None, + *, + add_version_condition: bool = True, ) -> Dict[str, Any]: hk_value, rk_value = self._get_hash_range_key_serialized_values() @@ -497,39 +539,45 @@ def get_delete_kwargs_from_instance( if add_version_condition and version_condition is not None: condition &= version_condition - return self._get_connection().get_operation_kwargs(hk_value, range_key=rk_value, key=KEY, condition=condition, return_values_on_condition_failure=return_values_on_condition_failure) + return self._get_connection().get_operation_kwargs( + hk_value, + range_key=rk_value, + key=KEY, + condition=condition, + return_values_on_condition_failure=return_values_on_condition_failure, + ) def get_save_kwargs_from_instance( - self, - condition: Optional[Condition] = None, - return_values_on_condition_failure: Optional[str] = None, + self, + condition: Optional[Condition] = None, + return_values_on_condition_failure: Optional[str] = None, ) -> Dict[str, Any]: args, save_kwargs = self._get_save_args(condition=condition) save_kwargs['key'] = ITEM - save_kwargs['return_values_on_condition_failure'] = return_values_on_condition_failure + save_kwargs['return_values_on_condition_failure'] = ( + return_values_on_condition_failure + ) return self._get_connection().get_operation_kwargs(*args, **save_kwargs) @classmethod def get_operation_kwargs_from_class( - cls, - hash_key: str, - range_key: Optional[_KeyType] = None, - condition: Optional[Condition] = None, + cls, + hash_key: str, + range_key: Optional[_KeyType] = None, + condition: Optional[Condition] = None, ) -> Dict[str, Any]: hash_key, range_key = cls._serialize_keys(hash_key, range_key) return cls._get_connection().get_operation_kwargs( - hash_key=hash_key, - range_key=range_key, - condition=condition + hash_key=hash_key, range_key=range_key, condition=condition ) @classmethod def get( - cls: Type[_T], - hash_key: _KeyType, - range_key: Optional[_KeyType] = None, - consistent_read: bool = False, - attributes_to_get: Optional[Sequence[Text]] = None, + cls: Type[_T], + hash_key: _KeyType, + range_key: Optional[_KeyType] = None, + consistent_read: bool = False, + attributes_to_get: Optional[Sequence[Text]] = None, ) -> _T: """ Returns a single object using the provided keys @@ -563,45 +611,66 @@ def from_raw_data(cls: Type[_T], data: Dict[str, Any]) -> _T: :param data: A serialized DynamoDB object """ if data is None: - raise ValueError("Received no data to construct object") + raise ValueError('Received no data to construct object') return cls._instantiate(data) @classmethod def count( - cls: Type[_T], - hash_key: Optional[_HashKeyQueryType] = None, - range_key_condition: Optional[Condition] = None, - filter_condition: Optional[Condition] = None, - consistent_read: bool = False, - index_name: Optional[str] = None, - limit: Optional[int] = None, - rate_limit: Optional[float] = None, + cls: Type[_T], + hash_key: Optional[_KeyType] = None, + range_key_condition: Optional[Condition] = None, + filter_condition: Optional[Condition] = None, + consistent_read: bool = False, + index_name: Optional[str] = None, + limit: Optional[int] = None, + rate_limit: Optional[float] = None, + *, + hash_keys: Optional[_HashKeyQueryType] = None, ) -> int: """ Provides a filtered count :param hash_key: The hash key to query. Can be None. + :param hash_keys: Named hash key values for indexes with multiple hash key attributes. :param range_key_condition: Condition for range key :param filter_condition: Condition used to restrict the query results :param consistent_read: If True, a consistent read is performed :param index_name: If set, then this index is used :param rate_limit: If set then consumed capacity will be limited to this amount per second """ - if hash_key is None: + if hash_key is None and hash_keys is None: + if index_name: + raise ValueError( + 'A hash_key or hash_keys must be given to query an index' + ) if filter_condition is not None: raise ValueError('A hash_key must be given to use filters') return cls.describe_table().get(ITEM_COUNT) + serialized_hash_keys = None if index_name: - hash_key = cls._indexes[index_name]._serialize_hash_key_values(hash_key) + index = cls._indexes[index_name] + index._validate_range_key_condition(range_key_condition) + serialized_hash_key = index._serialize_hash_key_values( + hash_key, hash_keys=hash_keys + ) + if isinstance(serialized_hash_key, dict): + serialized_hash_keys = serialized_hash_key + hash_key = None + else: + hash_key = serialized_hash_key else: + if hash_keys is not None: + raise ValueError('hash_keys can only be used with an index') hash_key = cls._serialize_keys(hash_key)[0] # If this class has a discriminator attribute, filter the query to only return instances of this class. discriminator_attr = cls._get_discriminator_attribute() if discriminator_attr: - filter_condition &= discriminator_attr.is_in(*discriminator_attr.get_registered_subclasses(cls)) + filter_condition &= discriminator_attr.is_in( + *discriminator_attr.get_registered_subclasses(cls) + ) query_args = (hash_key,) query_kwargs = dict( @@ -610,7 +679,8 @@ def count( index_name=index_name, consistent_read=consistent_read, limit=limit, - select=COUNT + select=COUNT, + hash_keys=serialized_hash_keys, ) result_iterator: ResultIterator[_T] = ResultIterator( @@ -628,23 +698,26 @@ def count( @classmethod def query( - cls: Type[_T], - hash_key: _HashKeyQueryType, - range_key_condition: Optional[Condition] = None, - filter_condition: Optional[Condition] = None, - consistent_read: bool = False, - index_name: Optional[str] = None, - scan_index_forward: Optional[bool] = None, - limit: Optional[int] = None, - last_evaluated_key: Optional[Dict[str, Dict[str, Any]]] = None, - attributes_to_get: Optional[Iterable[str]] = None, - page_size: Optional[int] = None, - rate_limit: Optional[float] = None, + cls: Type[_T], + hash_key: Optional[_KeyType] = None, + range_key_condition: Optional[Condition] = None, + filter_condition: Optional[Condition] = None, + consistent_read: bool = False, + index_name: Optional[str] = None, + scan_index_forward: Optional[bool] = None, + limit: Optional[int] = None, + last_evaluated_key: Optional[Dict[str, Dict[str, Any]]] = None, + attributes_to_get: Optional[Iterable[str]] = None, + page_size: Optional[int] = None, + rate_limit: Optional[float] = None, + *, + hash_keys: Optional[_HashKeyQueryType] = None, ) -> ResultIterator[_T]: """ Provides a high level query API - :param hash_key: The hash key to query + :param hash_key: The hash key to query. + :param hash_keys: Named hash key values for indexes with multiple hash key attributes. :param range_key_condition: Condition for range key :param filter_condition: Condition used to restrict the query results :param consistent_read: If True, a consistent read is performed @@ -657,15 +730,32 @@ def query( :param page_size: Page size of the query to DynamoDB :param rate_limit: If set then consumed capacity will be limited to this amount per second """ + if hash_key is None and hash_keys is None: + raise ValueError('A hash_key or hash_keys must be given to query') + + serialized_hash_keys = None if index_name: - hash_key = cls._indexes[index_name]._serialize_hash_key_values(hash_key) + index = cls._indexes[index_name] + index._validate_range_key_condition(range_key_condition) + serialized_hash_key = index._serialize_hash_key_values( + hash_key, hash_keys=hash_keys + ) + if isinstance(serialized_hash_key, dict): + serialized_hash_keys = serialized_hash_key + hash_key = None + else: + hash_key = serialized_hash_key else: + if hash_keys is not None: + raise ValueError('hash_keys can only be used with an index') hash_key = cls._serialize_keys(hash_key)[0] # If this class has a discriminator attribute, filter the query to only return instances of this class. discriminator_attr = cls._get_discriminator_attribute() if discriminator_attr: - filter_condition &= discriminator_attr.is_in(*discriminator_attr.get_registered_subclasses(cls)) + filter_condition &= discriminator_attr.is_in( + *discriminator_attr.get_registered_subclasses(cls) + ) if page_size is None: page_size = limit @@ -680,6 +770,7 @@ def query( scan_index_forward=scan_index_forward, limit=page_size, attributes_to_get=attributes_to_get, + hash_keys=serialized_hash_keys, ) return ResultIterator( @@ -693,17 +784,17 @@ def query( @classmethod def scan( - cls: Type[_T], - filter_condition: Optional[Condition] = None, - segment: Optional[int] = None, - total_segments: Optional[int] = None, - limit: Optional[int] = None, - last_evaluated_key: Optional[Dict[str, Dict[str, Any]]] = None, - page_size: Optional[int] = None, - consistent_read: Optional[bool] = None, - index_name: Optional[str] = None, - rate_limit: Optional[float] = None, - attributes_to_get: Optional[Sequence[str]] = None, + cls: Type[_T], + filter_condition: Optional[Condition] = None, + segment: Optional[int] = None, + total_segments: Optional[int] = None, + limit: Optional[int] = None, + last_evaluated_key: Optional[Dict[str, Dict[str, Any]]] = None, + page_size: Optional[int] = None, + consistent_read: Optional[bool] = None, + index_name: Optional[str] = None, + rate_limit: Optional[float] = None, + attributes_to_get: Optional[Sequence[str]] = None, ) -> ResultIterator[_T]: """ Iterates through all items in the table @@ -722,7 +813,9 @@ def scan( # If this class has a discriminator attribute, filter the scan to only return instances of this class. discriminator_attr = cls._get_discriminator_attribute() if discriminator_attr: - filter_condition &= discriminator_attr.is_in(*discriminator_attr.get_registered_subclasses(cls)) + filter_condition &= discriminator_attr.is_in( + *discriminator_attr.get_registered_subclasses(cls) + ) if page_size is None: page_size = limit @@ -736,7 +829,7 @@ def scan( total_segments=total_segments, consistent_read=consistent_read, index_name=index_name, - attributes_to_get=attributes_to_get + attributes_to_get=attributes_to_get, ) return ResultIterator( @@ -782,12 +875,12 @@ def describe_table(cls) -> Any: @classmethod def create_table( - cls, - wait: bool = False, - read_capacity_units: Optional[int] = None, - write_capacity_units: Optional[int] = None, - billing_mode: Optional[str] = None, - ignore_update_ttl_errors: bool = False, + cls, + wait: bool = False, + read_capacity_units: Optional[int] = None, + write_capacity_units: Optional[int] = None, + billing_mode: Optional[str] = None, + ignore_update_ttl_errors: bool = False, ) -> Any: """ Create the table for this model @@ -812,7 +905,7 @@ def create_table( if hasattr(cls.Meta, 'stream_view_type'): operation_kwargs['stream_specification'] = { 'stream_enabled': True, - 'stream_view_type': cls.Meta.stream_view_type + 'stream_view_type': cls.Meta.stream_view_type, } if hasattr(cls.Meta, 'billing_mode'): operation_kwargs['billing_mode'] = cls.Meta.billing_mode @@ -824,9 +917,7 @@ def create_table( operation_kwargs['write_capacity_units'] = write_capacity_units if billing_mode is not None: operation_kwargs['billing_mode'] = billing_mode - cls._get_connection().create_table( - **operation_kwargs - ) + cls._get_connection().create_table(**operation_kwargs) if wait: while True: status = cls._get_connection().describe_table() @@ -837,7 +928,7 @@ def create_table( else: time.sleep(2) else: - raise TableError("No TableStatus returned for table") + raise TableError('No TableStatus returned for table') cls.update_ttl(ignore_update_ttl_errors) @@ -855,7 +946,9 @@ def update_ttl(cls, ignore_update_ttl_errors: bool) -> None: cls._get_connection().update_time_to_live(ttl_attribute.attr_name) except Exception: if ignore_update_ttl_errors: - log.info("Unable to update the TTL for {}".format(cls.Meta.table_name)) + log.info( + 'Unable to update the TTL for {}'.format(cls.Meta.table_name) + ) else: raise @@ -874,14 +967,15 @@ def _get_schema(cls) -> ModelSchema: } for attr_name, attr_cls in cls.get_attributes().items(): if attr_cls.is_hash_key or attr_cls.is_range_key: - schema['attribute_definitions'].append({ - ATTR_NAME: attr_cls.attr_name, - ATTR_TYPE: attr_cls.attr_type - }) - schema['key_schema'].append({ - KEY_TYPE: HASH if attr_cls.is_hash_key else RANGE, - ATTR_NAME: attr_cls.attr_name - }) + schema['attribute_definitions'].append( + {ATTR_NAME: attr_cls.attr_name, ATTR_TYPE: attr_cls.attr_type} + ) + schema['key_schema'].append( + { + KEY_TYPE: HASH if attr_cls.is_hash_key else RANGE, + ATTR_NAME: attr_cls.attr_name, + } + ) indexes = cls._indexes.copy() # add indexes from derived classes that we might initialize @@ -895,7 +989,12 @@ def _get_schema(cls) -> ModelSchema: return schema - def _get_save_args(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True) -> Tuple[Iterable[Any], Dict[str, Any]]: + def _get_save_args( + self, + condition: Optional[Condition] = None, + *, + add_version_condition: bool = True, + ) -> Tuple[Iterable[Any], Dict[str, Any]]: """ Gets the proper *args, **kwargs for saving and retrieving this object @@ -908,12 +1007,16 @@ def _get_save_args(self, condition: Optional[Condition] = None, *, add_version_c """ attribute_values = self.serialize(null_check=True) hash_key_attribute = self._hash_key_attribute() - hash_key = attribute_values.pop(hash_key_attribute.attr_name, {}).get(hash_key_attribute.attr_type) + hash_key = attribute_values.pop(hash_key_attribute.attr_name, {}).get( + hash_key_attribute.attr_type + ) range_key = None range_key_attribute = self._range_key_attribute() if range_key_attribute: - range_key = attribute_values.pop(range_key_attribute.attr_name, {}).get(range_key_attribute.attr_type) - args = (hash_key, ) + range_key = attribute_values.pop(range_key_attribute.attr_name, {}).get( + range_key_attribute.attr_type + ) + args = (hash_key,) kwargs = {} if range_key is not None: kwargs['range_key'] = range_key @@ -926,7 +1029,7 @@ def _get_save_args(self, condition: Optional[Condition] = None, *, add_version_c def _get_hash_range_key_serialized_values(self) -> Tuple[Any, Optional[Any]]: if self._hash_keyname is None: - raise Exception("The model has no hash key") + raise Exception('The model has no hash key') attrs = self.get_attributes() @@ -941,7 +1044,12 @@ def _get_hash_range_key_serialized_values(self) -> Tuple[Any, Optional[Any]]: return hk_serialized_value, rk_serialized_value - def _handle_version_attribute(self, *, attributes: Optional[Dict[str, Any]] = None, actions: Optional[List[Action]] = None) -> Optional[Condition]: + def _handle_version_attribute( + self, + *, + attributes: Optional[Dict[str, Any]] = None, + actions: Optional[List[Action]] = None, + ) -> Optional[Condition]: """ Handles modifying the request to set or increment the version attribute. """ @@ -954,13 +1062,17 @@ def _handle_version_attribute(self, *, attributes: Optional[Dict[str, Any]] = No if value is not None: condition = version_attribute == value if attributes is not None: - attributes[version_attribute.attr_name] = self._serialize_value(version_attribute, value + 1) + attributes[version_attribute.attr_name] = self._serialize_value( + version_attribute, value + 1 + ) if actions is not None: actions.append(version_attribute.add(1)) else: condition = version_attribute.does_not_exist() if attributes is not None: - attributes[version_attribute.attr_name] = self._serialize_value(version_attribute, 1) + attributes[version_attribute.attr_name] = self._serialize_value( + version_attribute, 1 + ) if actions is not None: actions.append(version_attribute.set(1)) @@ -1025,12 +1137,16 @@ def _batch_get_page(cls, keys_to_get, consistent_read, attributes_to_get): :param consistent_read: Whether or not this needs to be consistent :param attributes_to_get: A list of attributes to return """ - log.debug("Fetching a BatchGetItem page") + log.debug('Fetching a BatchGetItem page') data = cls._get_connection().batch_get_item( - keys_to_get, consistent_read=consistent_read, attributes_to_get=attributes_to_get, + keys_to_get, + consistent_read=consistent_read, + attributes_to_get=attributes_to_get, ) item_data = data.get(RESPONSES).get(cls.Meta.table_name) # type: ignore - unprocessed_items = data.get(UNPROCESSED_KEYS).get(cls.Meta.table_name, {}).get(KEYS, None) # type: ignore + unprocessed_items = ( + data.get(UNPROCESSED_KEYS).get(cls.Meta.table_name, {}).get(KEYS, None) + ) # type: ignore return item_data, unprocessed_items @classmethod @@ -1038,57 +1154,63 @@ def _get_connection(cls) -> TableConnection: """ Returns a (cached) connection """ - if not hasattr(cls, "Meta"): + if not hasattr(cls, 'Meta'): raise AttributeError( 'As of v1.0 PynamoDB Models require a `Meta` class.\n' 'Model: {}.{}\n' 'See https://pynamodb.readthedocs.io/en/latest/release_notes.html\n'.format( - cls.__module__, cls.__name__, + cls.__module__, + cls.__name__, ), ) - elif not hasattr(cls.Meta, "table_name") or cls.Meta.table_name is None: + elif not hasattr(cls.Meta, 'table_name') or cls.Meta.table_name is None: raise AttributeError( 'As of v1.0 PynamoDB Models must have a table_name\n' 'Model: {}.{}\n' 'See https://pynamodb.readthedocs.io/en/latest/release_notes.html'.format( - cls.__module__, cls.__name__, + cls.__module__, + cls.__name__, ), ) # For now we just check that the connection exists and (in the case of model inheritance) # points to the same table. In the future we should update the connection if any of the attributes differ. if cls._connection is None or cls._connection.table_name != cls.Meta.table_name: schema = cls._get_schema() - meta_table = MetaTable({ - constants.TABLE_NAME: cls.Meta.table_name, - constants.KEY_SCHEMA: schema['key_schema'], - constants.ATTR_DEFINITIONS: schema['attribute_definitions'], - constants.GLOBAL_SECONDARY_INDEXES: [ - { - constants.INDEX_NAME: index_schema['index_name'], - constants.KEY_SCHEMA: index_schema['key_schema'], - } - for index_schema in schema['global_secondary_indexes'] - ], - constants.LOCAL_SECONDARY_INDEXES: [ - { - constants.INDEX_NAME: index_schema['index_name'], - constants.KEY_SCHEMA: index_schema['key_schema'], - } - for index_schema in schema['local_secondary_indexes'] - ], - }) - cls._connection = TableConnection(cls.Meta.table_name, - meta_table=meta_table, - region=cls.Meta.region, - host=cls.Meta.host, - connect_timeout_seconds=cls.Meta.connect_timeout_seconds, - read_timeout_seconds=cls.Meta.read_timeout_seconds, - max_retry_attempts=cls.Meta.max_retry_attempts, - max_pool_connections=cls.Meta.max_pool_connections, - extra_headers=cls.Meta.extra_headers, - aws_access_key_id=cls.Meta.aws_access_key_id, - aws_secret_access_key=cls.Meta.aws_secret_access_key, - aws_session_token=cls.Meta.aws_session_token) + meta_table = MetaTable( + { + constants.TABLE_NAME: cls.Meta.table_name, + constants.KEY_SCHEMA: schema['key_schema'], + constants.ATTR_DEFINITIONS: schema['attribute_definitions'], + constants.GLOBAL_SECONDARY_INDEXES: [ + { + constants.INDEX_NAME: index_schema['index_name'], + constants.KEY_SCHEMA: index_schema['key_schema'], + } + for index_schema in schema['global_secondary_indexes'] + ], + constants.LOCAL_SECONDARY_INDEXES: [ + { + constants.INDEX_NAME: index_schema['index_name'], + constants.KEY_SCHEMA: index_schema['key_schema'], + } + for index_schema in schema['local_secondary_indexes'] + ], + } + ) + cls._connection = TableConnection( + cls.Meta.table_name, + meta_table=meta_table, + region=cls.Meta.region, + host=cls.Meta.host, + connect_timeout_seconds=cls.Meta.connect_timeout_seconds, + read_timeout_seconds=cls.Meta.read_timeout_seconds, + max_retry_attempts=cls.Meta.max_retry_attempts, + max_pool_connections=cls.Meta.max_pool_connections, + extra_headers=cls.Meta.extra_headers, + aws_access_key_id=cls.Meta.aws_access_key_id, + aws_secret_access_key=cls.Meta.aws_secret_access_key, + aws_session_token=cls.Meta.aws_session_token, + ) return cls._connection @classmethod @@ -1149,6 +1271,7 @@ class _ModelFuture(Generic[_T]): For example: when performing a TransactGet request, this is a stand-in for a model that will be returned when the operation is complete """ + def __init__(self, model_cls: Type[_T]) -> None: self._model_cls = model_cls self._model: Optional[_T] = None diff --git a/setup.py b/setup.py index e2d04d39c..fd0f72ef5 100644 --- a/setup.py +++ b/setup.py @@ -2,50 +2,43 @@ install_requires = [ - "botocore>=1.12.54", + 'botocore>=1.12.54', 'typing-extensions>=4; python_version<"3.11"', ] setup( - name="pynamodb", - version=__import__("pynamodb").__version__, - packages=find_packages( - exclude=( - "examples", - "tests", - "typing_tests", - "tests.integration", - ) - ), - url="http://jlafon.io/pynamodb.html", + name='pynamodb', + version=__import__('pynamodb').__version__, + packages=find_packages(exclude=('examples', 'tests', 'typing_tests', 'tests.integration',)), + url='http://jlafon.io/pynamodb.html', project_urls={ - "Source": "https://github.com/pynamodb/PynamoDB", + 'Source': 'https://github.com/pynamodb/PynamoDB', }, - author="Jharrod LaFon", - author_email="jlafon@eyesopen.com", - description="A Pythonic Interface to DynamoDB", - long_description=open("README.rst").read(), - long_description_content_type="text/x-rst", + author='Jharrod LaFon', + author_email='jlafon@eyesopen.com', + description='A Pythonic Interface to DynamoDB', + long_description=open('README.rst').read(), + long_description_content_type='text/x-rst', zip_safe=False, - license="MIT", - keywords="python dynamodb amazon", + license='MIT', + keywords='python dynamodb amazon', python_requires=">=3.7", install_requires=install_requires, classifiers=[ - "Development Status :: 5 - Production/Stable", - "Intended Audience :: Developers", - "Programming Language :: Python", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "License :: OSI Approved :: MIT License", + 'Development Status :: 5 - Production/Stable', + 'Intended Audience :: Developers', + 'Programming Language :: Python', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', + 'License :: OSI Approved :: MIT License', ], extras_require={ - "signals": ["blinker>=1.3,<2.0"], + 'signals': ['blinker>=1.3,<2.0'], }, - package_data={"pynamodb": ["py.typed"]}, + package_data={'pynamodb': ['py.typed']}, ) diff --git a/tests/test_base_connection.py b/tests/test_base_connection.py index 3e6634de3..7ce57b650 100644 --- a/tests/test_base_connection.py +++ b/tests/test_base_connection.py @@ -1249,8 +1249,8 @@ def test_connection_query(): req.return_value = {} conn.query( composite_table_name, - ("z1", "a1"), - index_name='CompositeIndex' + index_name='CompositeIndex', + hash_keys={'z_partition': 'z1', 'a_partition': 'a1'}, ) params = { 'ReturnConsumedCapacity': 'TOTAL', @@ -1268,15 +1268,15 @@ def test_connection_query(): } assert req.call_args[0][1] == params - with pytest.raises(ValueError, match="expects 2 hash key values"): + with pytest.raises(ValueError, match="multiple hash key attributes; use hash_keys"): conn.query(composite_table_name, "z1", index_name='CompositeIndex') with patch(PATCH_METHOD) as req: req.return_value = {} conn.query( composite_table_name, - {'a_partition': 'a1', 'z_partition': 'z1'}, - index_name='CompositeIndex' + index_name='CompositeIndex', + hash_keys={'a_partition': 'a1', 'z_partition': 'z1'}, ) params = { 'ReturnConsumedCapacity': 'TOTAL', @@ -1297,15 +1297,15 @@ def test_connection_query(): with pytest.raises(ValueError, match="requires values for hash keys: a_partition"): conn.query( composite_table_name, - {'z_partition': 'z1'}, - index_name='CompositeIndex' + index_name='CompositeIndex', + hash_keys={'z_partition': 'z1'}, ) with pytest.raises(ValueError, match="received unknown hash keys: unknown"): conn.query( composite_table_name, - {'z_partition': 'z1', 'a_partition': 'a1', 'unknown': 'u1'}, - index_name='CompositeIndex' + index_name='CompositeIndex', + hash_keys={'z_partition': 'z1', 'a_partition': 'a1', 'unknown': 'u1'}, ) with patch(PATCH_METHOD) as req: diff --git a/tests/test_model.py b/tests/test_model.py index 966206eb0..4c7534780 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -194,6 +194,47 @@ class Meta: composite_index = CompositeOrderIndex() +class KeylessScanIndex(GlobalSecondaryIndex): + class Meta: + index_name = 'keyless_scan_idx' + read_capacity_units = 2 + write_capacity_units = 1 + projection = AllProjection() + + +class KeylessScanIndexedModel(Model): + class Meta: + table_name = 'KeylessScanIndexedModel' + + item_id = UnicodeAttribute(hash_key=True) + keyless_scan_index = KeylessScanIndex() + + +class ThreeRangeKeyIndex(GlobalSecondaryIndex): + class Meta: + index_name = 'three_range_idx' + read_capacity_units = 2 + write_capacity_units = 1 + projection = AllProjection() + + h = UnicodeAttribute(hash_key=True) + r1 = UnicodeAttribute(range_key=True) + r2 = UnicodeAttribute(range_key=True) + r3 = UnicodeAttribute(range_key=True) + + +class ThreeRangeKeyModel(Model): + class Meta: + table_name = 'ThreeRangeKeyModel' + + item_id = UnicodeAttribute(hash_key=True) + h = UnicodeAttribute() + r1 = UnicodeAttribute() + r2 = UnicodeAttribute() + r3 = UnicodeAttribute() + three_range_index = ThreeRangeKeyIndex() + + class SimpleUserModel(Model): """ A hash key only model @@ -1111,7 +1152,9 @@ def test_index_count(self): def test_index_count_composite_hash_key(self): with patch(PATCH_METHOD) as req: req.return_value = {'Count': 1, 'ScannedCount': 1} - res = CompositeIndexedModel.composite_index.count(('p1', 'p2')) + res = CompositeIndexedModel.composite_index.count( + hash_keys={'z_partition': 'p1', 'a_partition': 'p2'} + ) self.assertEqual(res, 1) params = req.call_args[0][1] self.assertEqual(params['KeyConditionExpression'], '(#0 = :0 AND #1 = :1)') @@ -2371,7 +2414,10 @@ def fake_dynamodb(*args, **kwargs): def test_global_index_composite_query(self): with patch(PATCH_METHOD) as req: req.return_value = {'Count': 0, 'ScannedCount': 0, 'Items': []} - list(CompositeIndexedModel.composite_index.query(('p1', 'p2'), CompositeIndexedModel.c_sort == 's1')) + list(CompositeIndexedModel.composite_index.query( + hash_keys={'z_partition': 'p1', 'a_partition': 'p2'}, + range_key_condition=CompositeIndexedModel.c_sort == 's1', + )) params = req.call_args[0][1] self.assertEqual(params['IndexName'], 'composite_order_idx') @@ -2388,10 +2434,12 @@ def test_global_index_composite_query(self): }) self.assertEqual(params['KeyConditionExpression'], '((#0 = :0 AND #1 = :1) AND #2 = :2)') - def test_global_index_composite_query_list_hash_key(self): + def test_global_index_composite_query_unordered_hash_keys(self): with patch(PATCH_METHOD) as req: req.return_value = {'Count': 0, 'ScannedCount': 0, 'Items': []} - list(CompositeIndexedModel.composite_index.query(['p1', 'p2'])) + list(CompositeIndexedModel.composite_index.query( + hash_keys={'a_partition': 'p2', 'z_partition': 'p1'} + )) params = req.call_args[0][1] self.assertEqual(params['ExpressionAttributeValues'], { @@ -2401,12 +2449,111 @@ def test_global_index_composite_query_list_hash_key(self): self.assertEqual(params['KeyConditionExpression'], '(#0 = :0 AND #1 = :1)') def test_global_index_composite_query_hash_key_validation(self): - with pytest.raises(ValueError, match='expects 2 hash key values'): + with pytest.raises(ValueError, match='multiple hash key attributes; use hash_keys'): CompositeIndexedModel.composite_index.query('p1') - with pytest.raises(ValueError, match='expects 2 hash key values, got 1'): + with pytest.raises(ValueError, match='multiple hash key attributes; use hash_keys'): CompositeIndexedModel.composite_index.query(('p1',)) + with pytest.raises(ValueError, match='requires values for hash keys: a_partition'): + CompositeIndexedModel.composite_index.query(hash_keys={'z_partition': 'p1'}) + + with pytest.raises(ValueError, match='received unknown hash keys: unknown'): + CompositeIndexedModel.composite_index.query(hash_keys={ + 'z_partition': 'p1', + 'a_partition': 'p2', + 'unknown': 'p3', + }) + + def test_global_index_public_key_validation_helpers(self): + CompositeIndexedModel.composite_index.validate_range_key_condition( + CompositeIndexedModel.c_sort == 's1' + ) + with pytest.raises(ValueError, match='preceding range keys: c_sort'): + CompositeIndexedModel.composite_index.validate_range_key_condition( + CompositeIndexedModel.b_sort == 's2' + ) + + self.assertEqual( + CompositeIndexedModel.composite_index.serialize_hash_key_values( + hash_keys={'z_partition': 'p1', 'a_partition': 'p2'} + ), + {'z_partition': 'p1', 'a_partition': 'p2'}, + ) + + def test_global_index_scan_allows_keyless_index_definition(self): + """ + Index scans do not need key schema metadata, so scan-only index definitions + should remain backwards-compatible. + """ + with patch(PATCH_METHOD) as req: + req.return_value = { + 'Count': 1, + 'ScannedCount': 1, + 'Items': [{'item_id': {STRING: 'item-1'}}], + } + + items = list(KeylessScanIndexedModel.keyless_scan_index.scan()) + + self.assertEqual([item.item_id for item in items], ['item-1']) + self.assertEqual(req.call_args[0][1]['IndexName'], 'keyless_scan_idx') + + def test_global_index_query_rejects_keyless_index_definition(self): + with pytest.raises(ValueError, match='has no hash key attributes'): + KeylessScanIndexedModel.keyless_scan_index.query('item-1') + + def test_global_index_composite_query_range_key_validation(self): + with pytest.raises(ValueError, match='preceding range keys: c_sort'): + CompositeIndexedModel.composite_index.query( + hash_keys={'z_partition': 'p1', 'a_partition': 'p2'}, + range_key_condition=CompositeIndexedModel.b_sort == 's1', + ) + + with pytest.raises(ValueError, match='use equality for preceding range keys: c_sort'): + CompositeIndexedModel.composite_index.query( + hash_keys={'z_partition': 'p1', 'a_partition': 'p2'}, + range_key_condition=( + (CompositeIndexedModel.c_sort > 's1') & + (CompositeIndexedModel.b_sort == 's2') + ), + ) + + with patch(PATCH_METHOD) as req: + req.return_value = {'Count': 0, 'ScannedCount': 0, 'Items': []} + list(CompositeIndexedModel.composite_index.query( + hash_keys={'z_partition': 'p1', 'a_partition': 'p2'}, + range_key_condition=( + (CompositeIndexedModel.c_sort == 's1') & + CompositeIndexedModel.b_sort.startswith('s2') + ), + )) + self.assertEqual( + req.call_args[0][1]['KeyConditionExpression'], + '((#0 = :0 AND #1 = :1) AND (#2 = :2 AND begins_with (#3, :3)))' + ) + + def test_global_index_three_range_key_validation(self): + with pytest.raises(ValueError, match='preceding range keys: r1, r2'): + ThreeRangeKeyModel.three_range_index.query( + 'h1', + range_key_condition=ThreeRangeKeyModel.r3 == 'r3', + ) + + with patch(PATCH_METHOD) as req: + req.return_value = {'Count': 0, 'ScannedCount': 0, 'Items': []} + list(ThreeRangeKeyModel.three_range_index.query( + 'h1', + range_key_condition=( + (ThreeRangeKeyModel.r1 == 'r1') & + (ThreeRangeKeyModel.r2 == 'r2') & + (ThreeRangeKeyModel.r3 >= 'r3') + ), + )) + self.assertEqual( + req.call_args[0][1]['KeyConditionExpression'], + '(#0 = :0 AND ((#1 = :1 AND #2 = :2) AND #3 >= :3))' + ) + def test_global_index_query_hash_only_does_not_error(self): """ Non-composite GSI queries should work with hash key only. @@ -2435,7 +2582,10 @@ def test_global_index_composite_query_last_evaluated_key(self): }) req.return_value = {'Count': len(items), 'ScannedCount': len(items), 'Items': items} - results_iter = CompositeIndexedModel.composite_index.query(('z1', 'a1'), limit=25) + results_iter = CompositeIndexedModel.composite_index.query( + hash_keys={'z_partition': 'z1', 'a_partition': 'a1'}, + limit=25, + ) results = list(results_iter) self.assertEqual(len(results), 25) @@ -2460,14 +2610,14 @@ def test_local_index(self): 'AttributeType': 'S', 'AttributeName': 'user_name' }, - { - 'AttributeType': 'S', - 'AttributeName': 'email' - }, { 'AttributeType': 'NS', 'AttributeName': 'numbers' }, + { + 'AttributeType': 'S', + 'AttributeName': 'email' + }, ] ) self.assertEqual(schema['local_secondary_indexes'][0]['projection']['ProjectionType'], 'INCLUDE') From 0ae4c49ae034f4323ad026195f5efa36eefa6683 Mon Sep 17 00:00:00 2001 From: keremeyuboglu <32223948+keremeyuboglu@users.noreply.github.com> Date: Thu, 30 Apr 2026 17:17:28 +0300 Subject: [PATCH 3/8] Added casting and assertions for mypy --- pynamodb/connection/base.py | 4 +++- pynamodb/models.py | 30 ++++++++++++++++-------------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index dbcd6a0d1..7c5f50512 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -1358,6 +1358,7 @@ def query( if range_key_condition is not None: key_condition &= range_key_condition + assert key_condition is not None operation_kwargs[KEY_CONDITION_EXPRESSION] = key_condition.serialize( name_placeholders, expression_attribute_values ) @@ -1499,7 +1500,7 @@ def _validate_multi_range_key_condition( f"{context} range_key_condition uses unsupported range key operator: {condition.operator}" ) key_name = Connection._condition_key_name(condition) - if key_name not in range_keynames: + if key_name is None or key_name not in range_keynames: raise ValueError( f"{context} range_key_condition must only use range keys: {', '.join(range_keynames)}" ) @@ -1507,6 +1508,7 @@ def _validate_multi_range_key_condition( raise ValueError( f"{context} range_key_condition has multiple conditions for range key: {key_name}" ) + assert key_name is not None conditions_by_key[key_name] = condition if not conditions_by_key: diff --git a/pynamodb/models.py b/pynamodb/models.py index 7ae4b4bdf..10e8ab2f8 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -351,7 +351,7 @@ def batch_get( 'expected non-str iterable with exactly 2 elements (hash key, range key)' ) try: - hash_key, range_key = item + hash_key, range_key = cast(Tuple[_KeyType, _KeyType], item) except (TypeError, ValueError): raise ValueError( f'Invalid key value {item!r}: ' @@ -562,13 +562,13 @@ def get_save_kwargs_from_instance( @classmethod def get_operation_kwargs_from_class( cls, - hash_key: str, + hash_key: _KeyType, range_key: Optional[_KeyType] = None, condition: Optional[Condition] = None, ) -> Dict[str, Any]: hash_key, range_key = cls._serialize_keys(hash_key, range_key) return cls._get_connection().get_operation_kwargs( - hash_key=hash_key, range_key=range_key, condition=condition + hash_key=cast(str, hash_key), range_key=cast(Optional[str], range_key), condition=condition ) @classmethod @@ -591,8 +591,8 @@ def get( hash_key, range_key = cls._serialize_keys(hash_key, range_key) data = cls._get_connection().get_item( - hash_key, - range_key=range_key, + cast(str, hash_key), + range_key=cast(Optional[str], range_key), consistent_read=consistent_read, attributes_to_get=attributes_to_get, ) @@ -619,14 +619,13 @@ def from_raw_data(cls: Type[_T], data: Dict[str, Any]) -> _T: def count( cls: Type[_T], hash_key: Optional[_KeyType] = None, + hash_keys: Optional[Mapping[str, _KeyType]] = None, range_key_condition: Optional[Condition] = None, filter_condition: Optional[Condition] = None, consistent_read: bool = False, index_name: Optional[str] = None, limit: Optional[int] = None, rate_limit: Optional[float] = None, - *, - hash_keys: Optional[_HashKeyQueryType] = None, ) -> int: """ Provides a filtered count @@ -700,6 +699,7 @@ def count( def query( cls: Type[_T], hash_key: Optional[_KeyType] = None, + hash_keys: Optional[Mapping[str, _KeyType]] = None, range_key_condition: Optional[Condition] = None, filter_condition: Optional[Condition] = None, consistent_read: bool = False, @@ -710,8 +710,6 @@ def query( attributes_to_get: Optional[Iterable[str]] = None, page_size: Optional[int] = None, rate_limit: Optional[float] = None, - *, - hash_keys: Optional[_HashKeyQueryType] = None, ) -> ResultIterator[_T]: """ Provides a high level query API @@ -1143,10 +1141,10 @@ def _batch_get_page(cls, keys_to_get, consistent_read, attributes_to_get): consistent_read=consistent_read, attributes_to_get=attributes_to_get, ) - item_data = data.get(RESPONSES).get(cls.Meta.table_name) # type: ignore - unprocessed_items = ( - data.get(UNPROCESSED_KEYS).get(cls.Meta.table_name, {}).get(KEYS, None) - ) # type: ignore + responses = cast(Dict[str, Any], data.get(RESPONSES, {})) + item_data = responses.get(cls.Meta.table_name) + unprocessed_keys = cast(Dict[str, Any], data.get(UNPROCESSED_KEYS, {})) + unprocessed_items = unprocessed_keys.get(cls.Meta.table_name, {}).get(KEYS, None) return item_data, unprocessed_items @classmethod @@ -1231,7 +1229,11 @@ def _serialize_value(cls, attr, value): return {attr.attr_type: serialized} @classmethod - def _serialize_keys(cls, hash_key, range_key=None) -> Tuple[_KeyType, _KeyType]: + def _serialize_keys( + cls, + hash_key: _KeyType, + range_key: Optional[_KeyType] = None, + ) -> Tuple[Any, Any]: """ Serializes the hash and range keys From 54d443ef018a60e7d1504f45dcd0b999053459e7 Mon Sep 17 00:00:00 2001 From: keremeyuboglu <32223948+keremeyuboglu@users.noreply.github.com> Date: Thu, 30 Apr 2026 17:26:45 +0300 Subject: [PATCH 4/8] Moved hash keys to the latest parameter for backwards compat --- pynamodb/connection/table.py | 2 +- pynamodb/models.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pynamodb/connection/table.py b/pynamodb/connection/table.py index 7c87c3aeb..02a4cbfdd 100644 --- a/pynamodb/connection/table.py +++ b/pynamodb/connection/table.py @@ -239,7 +239,6 @@ def scan( def query( self, hash_key: Optional[Any] = None, - hash_keys: Optional[Mapping[str, Any]] = None, range_key_condition: Optional[Condition] = None, filter_condition: Optional[Any] = None, attributes_to_get: Optional[Any] = None, @@ -250,6 +249,7 @@ def query( return_consumed_capacity: Optional[str] = None, scan_index_forward: Optional[bool] = None, select: Optional[str] = None, + hash_keys: Optional[Mapping[str, Any]] = None, ) -> Dict: """ Performs the Query operation and returns the result diff --git a/pynamodb/models.py b/pynamodb/models.py index 10e8ab2f8..b69edb7c7 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -619,13 +619,13 @@ def from_raw_data(cls: Type[_T], data: Dict[str, Any]) -> _T: def count( cls: Type[_T], hash_key: Optional[_KeyType] = None, - hash_keys: Optional[Mapping[str, _KeyType]] = None, range_key_condition: Optional[Condition] = None, filter_condition: Optional[Condition] = None, consistent_read: bool = False, index_name: Optional[str] = None, limit: Optional[int] = None, rate_limit: Optional[float] = None, + hash_keys: Optional[Mapping[str, _KeyType]] = None, ) -> int: """ Provides a filtered count @@ -699,7 +699,6 @@ def count( def query( cls: Type[_T], hash_key: Optional[_KeyType] = None, - hash_keys: Optional[Mapping[str, _KeyType]] = None, range_key_condition: Optional[Condition] = None, filter_condition: Optional[Condition] = None, consistent_read: bool = False, @@ -710,6 +709,7 @@ def query( attributes_to_get: Optional[Iterable[str]] = None, page_size: Optional[int] = None, rate_limit: Optional[float] = None, + hash_keys: Optional[Mapping[str, _KeyType]] = None, ) -> ResultIterator[_T]: """ Provides a high level query API From d6b98874d88fa6871922121956f8995641aa9eea Mon Sep 17 00:00:00 2001 From: keremeyuboglu <32223948+keremeyuboglu@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:16:19 +0300 Subject: [PATCH 5/8] Preserve typed query hash keys --- pynamodb/connection/base.py | 5 ++++- tests/test_base_connection.py | 26 ++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index 7c5f50512..c32dd025c 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -1426,7 +1426,10 @@ def _get_query_hash_key_values( if hash_keys is None: if hash_key is None: raise ValueError(f"Index {index_name} requires a hash_key") - if isinstance(hash_key, (tuple, list, Mapping)): + if isinstance(hash_key, (tuple, list)) or ( + isinstance(hash_key, Mapping) + and not any(key in ATTRIBUTE_TYPES for key in hash_key) + ): raise ValueError( f"Index {index_name} expects a single hash_key value" ) diff --git a/tests/test_base_connection.py b/tests/test_base_connection.py index 7ce57b650..614db8d9e 100644 --- a/tests/test_base_connection.py +++ b/tests/test_base_connection.py @@ -1215,6 +1215,32 @@ def test_connection_query(): } assert req.call_args[0][1] == params + with patch(PATCH_METHOD) as req: + req.return_value = {} + conn.query( + table_name, + {'S': 'FooForum'}, + Path('Subject').startswith('thread'), + ) + params = { + 'ReturnConsumedCapacity': 'TOTAL', + 'KeyConditionExpression': '(#0 = :0 AND begins_with (#1, :1))', + 'ExpressionAttributeNames': { + '#0': 'ForumName', + '#1': 'Subject' + }, + 'ExpressionAttributeValues': { + ':0': { + 'S': 'FooForum' + }, + ':1': { + 'S': 'thread' + } + }, + 'TableName': 'Thread' + } + assert req.call_args[0][1] == params + composite_table_name = "ThreadComposite" composite_table_data = { "TableName": composite_table_name, From bc0a310aedb683fdfaf6f3807074679eea3602f0 Mon Sep 17 00:00:00 2001 From: keremeyuboglu <32223948+keremeyuboglu@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:17:07 +0300 Subject: [PATCH 6/8] Normalize composite range key conditions --- pynamodb/connection/base.py | 595 ++++++++++++++-------------------- pynamodb/indexes.py | 36 +- pynamodb/models.py | 4 +- tests/test_base_connection.py | 31 ++ tests/test_model.py | 27 ++ 5 files changed, 332 insertions(+), 361 deletions(-) diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index c32dd025c..6a641fddb 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -6,7 +6,6 @@ import uuid from threading import local from typing import Any, Dict, List, Mapping, Optional, Sequence, Union, cast - if sys.version_info >= (3, 8): from typing import Literal else: @@ -194,27 +193,27 @@ def get_index_range_keynames(self, index_name) -> List[str]: return range_keynames return [] - def get_item_attribute_map( - self, attributes: Dict, item_key=ITEM, pythonic_key: bool = True - ): + def get_item_attribute_map(self, attributes: Dict, item_key=ITEM, pythonic_key: bool = True): """ Builds up a dynamodb compatible AttributeValue map """ if pythonic_key: item_key = item_key - attr_map: Dict[str, Dict] = {item_key: {}} + attr_map: Dict[str, Dict] = { + item_key: {} + } for key, value in attributes.items(): # In this case, the user provided a mapping # {'key': {'S': 'value'}} if isinstance(value, dict): attr_map[item_key][key] = value else: - attr_map[item_key][key] = {self.get_attribute_type(key): value} + attr_map[item_key][key] = { + self.get_attribute_type(key): value + } return attr_map - def get_attribute_type( - self, attribute_name: str, value: Optional[Any] = None - ) -> str: + def get_attribute_type(self, attribute_name: str, value: Optional[Any] = None) -> str: """ Returns the proper attribute type for a given attribute name """ @@ -225,14 +224,10 @@ def get_attribute_type( for key in ATTRIBUTE_TYPES: if key in value: return key - attr_names = [ - attr.get(ATTR_NAME) for attr in self.data.get(ATTR_DEFINITIONS, []) - ] + attr_names = [attr.get(ATTR_NAME) for attr in self.data.get(ATTR_DEFINITIONS, [])] raise ValueError("No attribute {} in {}".format(attribute_name, attr_names)) - def get_identifier_map( - self, hash_key: str, range_key: Optional[str] = None, key: str = KEY - ): + def get_identifier_map(self, hash_key: str, range_key: Optional[str] = None, key: str = KEY): """ Builds the identifier map that is common to several operations """ @@ -253,13 +248,12 @@ def get_exclusive_start_key_map(self, exclusive_start_key): """ Builds the exclusive start key attribute map """ - if ( - isinstance(exclusive_start_key, dict) - and self.hash_keyname in exclusive_start_key - ): + if isinstance(exclusive_start_key, dict) and self.hash_keyname in exclusive_start_key: # This is useful when paginating results, as the LastEvaluatedKey returned is already # structured properly - return {EXCLUSIVE_START_KEY: exclusive_start_key} + return { + EXCLUSIVE_START_KEY: exclusive_start_key + } else: return { EXCLUSIVE_START_KEY: { @@ -275,26 +269,24 @@ class Connection(object): A higher level abstraction over botocore """ - def __init__( - self, - region: Optional[str] = None, - host: Optional[str] = None, - read_timeout_seconds: Optional[float] = None, - connect_timeout_seconds: Optional[float] = None, - max_retry_attempts: Optional[int] = None, - retry_configuration: Optional[ - Union[ - Literal["LEGACY"], - Literal["UNSET"], - "botocore.config._RetryDict", - ] - ] = None, - max_pool_connections: Optional[int] = None, - extra_headers: Optional[Mapping[str, str]] = None, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_session_token: Optional[str] = None, - ): + def __init__(self, + region: Optional[str] = None, + host: Optional[str] = None, + read_timeout_seconds: Optional[float] = None, + connect_timeout_seconds: Optional[float] = None, + max_retry_attempts: Optional[int] = None, + retry_configuration: Optional[ + Union[ + Literal["LEGACY"], + Literal["UNSET"], + "botocore.config._RetryDict", + ] + ] = None, + max_pool_connections: Optional[int] = None, + extra_headers: Optional[Mapping[str, str]] = None, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None): self._tables: Dict[str, MetaTable] = {} self.host = host self._local = local() @@ -303,26 +295,22 @@ def __init__( if region: self.region = region else: - self.region = get_settings_value("region") + self.region = get_settings_value('region') if connect_timeout_seconds is not None: self._connect_timeout_seconds = connect_timeout_seconds else: - self._connect_timeout_seconds = get_settings_value( - "connect_timeout_seconds" - ) + self._connect_timeout_seconds = get_settings_value('connect_timeout_seconds') if read_timeout_seconds is not None: self._read_timeout_seconds = read_timeout_seconds else: - self._read_timeout_seconds = get_settings_value("read_timeout_seconds") + self._read_timeout_seconds = get_settings_value('read_timeout_seconds') if max_retry_attempts is not None: self._max_retry_attempts_exception = max_retry_attempts else: - self._max_retry_attempts_exception = get_settings_value( - "max_retry_attempts" - ) + self._max_retry_attempts_exception = get_settings_value('max_retry_attempts') # Since we have the pattern of using `None` to indicate "read from the # settings", we use a literal of "UNSET" to indicate we want the @@ -334,17 +322,17 @@ def __init__( elif retry_configuration is not None: self._retry_configuration = retry_configuration else: - self._retry_configuration = get_settings_value("retry_configuration") + self._retry_configuration = get_settings_value('retry_configuration') if max_pool_connections is not None: self._max_pool_connections = max_pool_connections else: - self._max_pool_connections = get_settings_value("max_pool_connections") + self._max_pool_connections = get_settings_value('max_pool_connections') if extra_headers is not None: self._extra_headers = extra_headers else: - self._extra_headers = get_settings_value("extra_headers") + self._extra_headers = get_settings_value('extra_headers') self._aws_access_key_id = aws_access_key_id self._aws_secret_access_key = aws_secret_access_key @@ -359,14 +347,7 @@ def dispatch(self, operation_name: str, operation_kwargs: Dict) -> Dict: Raises TableDoesNotExist if the specified table does not exist """ - if operation_name not in [ - DESCRIBE_TABLE, - LIST_TABLES, - UPDATE_TABLE, - UPDATE_TIME_TO_LIVE, - DELETE_TABLE, - CREATE_TABLE, - ]: + if operation_name not in [DESCRIBE_TABLE, LIST_TABLES, UPDATE_TABLE, UPDATE_TIME_TO_LIVE, DELETE_TABLE, CREATE_TABLE]: if RETURN_CONSUMED_CAPACITY not in operation_kwargs: operation_kwargs.update(self.get_consumed_capacity_map(TOTAL)) log.debug("Calling %s with arguments %s", operation_name, operation_kwargs) @@ -382,33 +363,18 @@ def dispatch(self, operation_name: str, operation_kwargs: Dict) -> Dict: capacity = data.get(CONSUMED_CAPACITY) if isinstance(capacity, dict) and CAPACITY_UNITS in capacity: capacity = capacity.get(CAPACITY_UNITS) - log.debug( - "%s %s consumed %s units", - data.get(TABLE_NAME, ""), - operation_name, - capacity, - ) + log.debug("%s %s consumed %s units", data.get(TABLE_NAME, ''), operation_name, capacity) return data def send_post_boto_callback(self, operation_name, req_uuid, table_name): try: - post_dynamodb_send.send( - self, - operation_name=operation_name, - table_name=table_name, - req_uuid=req_uuid, - ) + post_dynamodb_send.send(self, operation_name=operation_name, table_name=table_name, req_uuid=req_uuid) except Exception: log.exception("post_boto callback threw an exception.") def send_pre_boto_callback(self, operation_name, req_uuid, table_name): try: - pre_dynamodb_send.send( - self, - operation_name=operation_name, - table_name=table_name, - req_uuid=req_uuid, - ) + pre_dynamodb_send.send(self, operation_name=operation_name, table_name=table_name, req_uuid=req_uuid) except Exception: log.exception("pre_boto callback threw an exception.") @@ -420,15 +386,13 @@ def _make_api_call(self, operation_name: str, operation_kwargs: Dict) -> Dict: try: return self.client._make_api_call(operation_name, operation_kwargs) except ClientError as e: - resp_metadata = e.response.get("ResponseMetadata", {}).get( - "HTTPHeaders", {} - ) - cancellation_reasons = e.response.get("CancellationReasons", []) + resp_metadata = e.response.get('ResponseMetadata', {}).get('HTTPHeaders', {}) + cancellation_reasons = e.response.get('CancellationReasons', []) - botocore_props = {"Error": e.response.get("Error", {})} + botocore_props = {'Error': e.response.get('Error', {})} verbose_props = { - "request_id": resp_metadata.get("x-amzn-requestid", ""), - "table_name": self._get_table_name_for_error_context(operation_kwargs), + 'request_id': resp_metadata.get('x-amzn-requestid', ''), + 'table_name': self._get_table_name_for_error_context(operation_kwargs), } raise VerboseClientError( botocore_props, @@ -437,14 +401,10 @@ def _make_api_call(self, operation_name: str, operation_kwargs: Dict) -> Dict: cancellation_reasons=( ( CancellationReason( - code=d["Code"], - message=d.get("Message"), - raw_item=cast( - Optional[Dict[str, Dict[str, Any]]], d.get("Item") - ), - ) - if d["Code"] != "None" - else None + code=d['Code'], + message=d.get('Message'), + raw_item=cast(Optional[Dict[str, Dict[str, Any]]], d.get('Item')), + ) if d['Code'] != 'None' else None ) for d in cancellation_reasons ), @@ -453,7 +413,7 @@ def _make_api_call(self, operation_name: str, operation_kwargs: Dict) -> Dict: def _get_table_name_for_error_context(self, operation_kwargs) -> str: # First handle the two multi-table cases: batch and transaction operations if REQUEST_ITEMS in operation_kwargs: - return ",".join(operation_kwargs[REQUEST_ITEMS]) + return ','.join(operation_kwargs[REQUEST_ITEMS]) elif TRANSACT_ITEMS in operation_kwargs: table_names = [] for item in operation_kwargs[TRANSACT_ITEMS]: @@ -468,14 +428,12 @@ def session(self) -> botocore.session.Session: Returns a valid botocore session """ # botocore client creation is not thread safe as of v1.2.5+ (see issue #153) - if getattr(self._local, "session", None) is None: + if getattr(self._local, 'session', None) is None: self._local.session = get_session() if self._aws_access_key_id and self._aws_secret_access_key: - self._local.session.set_credentials( - self._aws_access_key_id, - self._aws_secret_access_key, - self._aws_session_token, - ) + self._local.session.set_credentials(self._aws_access_key_id, + self._aws_secret_access_key, + self._aws_session_token) return self._local.session @property @@ -487,18 +445,15 @@ def client(self) -> BotocoreBaseClientPrivate: # https://github.com/boto/botocore/blob/4d55c9b4142/botocore/credentials.py#L1016-L1021 # if the client does not have credentials, we create a new client # otherwise the client is permanently poisoned in the case of metadata service flakiness when using IAM roles - if not self._client or ( - self._client._request_signer - and not self._client._request_signer._credentials - ): + if not self._client or (self._client._request_signer and not self._client._request_signer._credentials): # Check if we are using the "LEGACY" retry mode to keep previous PynamoDB # retry behavior, or if we are using the new retry configuration settings. if self._retry_configuration != "LEGACY": retries = self._retry_configuration else: retries = { - "total_max_attempts": 1 + self._max_retry_attempts_exception, - "mode": "standard", + 'total_max_attempts': 1 + self._max_retry_attempts_exception, + 'mode': 'standard', } config = botocore.client.Config( @@ -508,16 +463,9 @@ def client(self) -> BotocoreBaseClientPrivate: max_pool_connections=self._max_pool_connections, retries=retries, ) - self._client = cast( - BotocoreBaseClientPrivate, - self.session.create_client( - SERVICE_NAME, self.region, endpoint_url=self.host, config=config - ), - ) + self._client = cast(BotocoreBaseClientPrivate, self.session.create_client(SERVICE_NAME, self.region, endpoint_url=self.host, config=config)) - self._client.meta.events.register_first( - "before-send.*.*", self._before_send - ) + self._client.meta.events.register_first('before-send.*.*', self._before_send) return self._client def add_meta_table(self, meta_table: MetaTable) -> None: @@ -559,26 +507,20 @@ def create_table( PROVISIONED_THROUGHPUT: { READ_CAPACITY_UNITS: read_capacity_units, WRITE_CAPACITY_UNITS: write_capacity_units, - }, + } } attrs_list = [] if attribute_definitions is None: raise ValueError("attribute_definitions argument is required") for attr in attribute_definitions: - attrs_list.append( - { - ATTR_NAME: attr.get(ATTR_NAME) or attr["attribute_name"], - ATTR_TYPE: attr.get(ATTR_TYPE) or attr["attribute_type"], - } - ) + attrs_list.append({ + ATTR_NAME: attr.get(ATTR_NAME) or attr['attribute_name'], + ATTR_TYPE: attr.get(ATTR_TYPE) or attr['attribute_type'] + }) operation_kwargs[ATTR_DEFINITIONS] = attrs_list if billing_mode not in AVAILABLE_BILLING_MODES: - raise ValueError( - "incorrect value for billing_mode, available modes: {}".format( - AVAILABLE_BILLING_MODES - ) - ) + raise ValueError("incorrect value for billing_mode, available modes: {}".format(AVAILABLE_BILLING_MODES)) if billing_mode == PAY_PER_REQUEST_BILLING_MODE: del operation_kwargs[PROVISIONED_THROUGHPUT] elif billing_mode == PROVISIONED_BILLING_MODE: @@ -588,12 +530,10 @@ def create_table( global_secondary_indexes_list = [] for index in global_secondary_indexes: index_kwargs = { - INDEX_NAME: index.get("index_name"), - KEY_SCHEMA: sorted( - index.get("key_schema"), key=lambda x: x.get(KEY_TYPE) - ), - PROJECTION: index.get("projection"), - PROVISIONED_THROUGHPUT: index.get("provisioned_throughput"), + INDEX_NAME: index.get('index_name'), + KEY_SCHEMA: sorted(index.get('key_schema'), key=lambda x: x.get(KEY_TYPE)), + PROJECTION: index.get('projection'), + PROVISIONED_THROUGHPUT: index.get('provisioned_throughput') } if billing_mode == PAY_PER_REQUEST_BILLING_MODE: del index_kwargs[PROVISIONED_THROUGHPUT] @@ -604,38 +544,35 @@ def create_table( raise ValueError("key_schema is required") key_schema_list = [] for item in key_schema: - key_schema_list.append( - { - ATTR_NAME: item.get(ATTR_NAME) or item["attribute_name"], - KEY_TYPE: str(item.get(KEY_TYPE) or item["key_type"]).upper(), - } - ) - operation_kwargs[KEY_SCHEMA] = sorted( - key_schema_list, key=lambda x: x.get(KEY_TYPE) - ) + key_schema_list.append({ + ATTR_NAME: item.get(ATTR_NAME) or item['attribute_name'], + KEY_TYPE: str(item.get(KEY_TYPE) or item['key_type']).upper() + }) + operation_kwargs[KEY_SCHEMA] = sorted(key_schema_list, key=lambda x: x.get(KEY_TYPE)) local_secondary_indexes_list = [] if local_secondary_indexes: for index in local_secondary_indexes: - local_secondary_indexes_list.append( - { - INDEX_NAME: index.get("index_name"), - KEY_SCHEMA: sorted( - index.get("key_schema"), key=lambda x: x.get(KEY_TYPE) - ), - PROJECTION: index.get("projection"), - } - ) + local_secondary_indexes_list.append({ + INDEX_NAME: index.get('index_name'), + KEY_SCHEMA: sorted(index.get('key_schema'), key=lambda x: x.get(KEY_TYPE)), + PROJECTION: index.get('projection'), + }) operation_kwargs[LOCAL_SECONDARY_INDEXES] = local_secondary_indexes_list if stream_specification: operation_kwargs[STREAM_SPECIFICATION] = { - STREAM_ENABLED: stream_specification["stream_enabled"], - STREAM_VIEW_TYPE: stream_specification["stream_view_type"], + STREAM_ENABLED: stream_specification['stream_enabled'], + STREAM_VIEW_TYPE: stream_specification['stream_view_type'] } if tags: - operation_kwargs[TAGS] = [{KEY: k, VALUE: v} for k, v in tags.items()] + operation_kwargs[TAGS] = [ + { + KEY: k, + VALUE: v + } for k, v in tags.items() + ] try: data = self.dispatch(CREATE_TABLE, operation_kwargs) @@ -652,7 +589,7 @@ def update_time_to_live(self, table_name: str, ttl_attribute_name: str) -> Dict: TIME_TO_LIVE_SPECIFICATION: { ATTR_NAME: ttl_attribute_name, ENABLED: True, - }, + } } try: return self.dispatch(UPDATE_TIME_TO_LIVE, operation_kwargs) @@ -663,7 +600,9 @@ def delete_table(self, table_name: str) -> Dict: """ Performs the DeleteTable operation """ - operation_kwargs = {TABLE_NAME: table_name} + operation_kwargs = { + TABLE_NAME: table_name + } try: data = self.dispatch(DELETE_TABLE, operation_kwargs) except BOTOCORE_EXCEPTIONS as e: @@ -680,38 +619,29 @@ def update_table( """ Performs the UpdateTable operation """ - operation_kwargs: Dict[str, Any] = {TABLE_NAME: table_name} - if ( - read_capacity_units - and not write_capacity_units - or write_capacity_units - and not read_capacity_units - ): - raise ValueError( - "read_capacity_units and write_capacity_units are required together" - ) + operation_kwargs: Dict[str, Any] = { + TABLE_NAME: table_name + } + if read_capacity_units and not write_capacity_units or write_capacity_units and not read_capacity_units: + raise ValueError("read_capacity_units and write_capacity_units are required together") if read_capacity_units and write_capacity_units: operation_kwargs[PROVISIONED_THROUGHPUT] = { READ_CAPACITY_UNITS: read_capacity_units, - WRITE_CAPACITY_UNITS: write_capacity_units, + WRITE_CAPACITY_UNITS: write_capacity_units } if global_secondary_index_updates: global_secondary_indexes_list = [] for index in global_secondary_index_updates: - global_secondary_indexes_list.append( - { - UPDATE: { - INDEX_NAME: index.get("index_name"), - PROVISIONED_THROUGHPUT: { - READ_CAPACITY_UNITS: index.get("read_capacity_units"), - WRITE_CAPACITY_UNITS: index.get("write_capacity_units"), - }, + global_secondary_indexes_list.append({ + UPDATE: { + INDEX_NAME: index.get('index_name'), + PROVISIONED_THROUGHPUT: { + READ_CAPACITY_UNITS: index.get('read_capacity_units'), + WRITE_CAPACITY_UNITS: index.get('write_capacity_units') } } - ) - operation_kwargs[GLOBAL_SECONDARY_INDEX_UPDATES] = ( - global_secondary_indexes_list - ) + }) + operation_kwargs[GLOBAL_SECONDARY_INDEX_UPDATES] = global_secondary_indexes_list try: return self.dispatch(UPDATE_TABLE, operation_kwargs) except BOTOCORE_EXCEPTIONS as e: @@ -727,11 +657,13 @@ def list_tables( """ operation_kwargs: Dict[str, Any] = {} if exclusive_start_table_name: - operation_kwargs.update( - {EXCLUSIVE_START_TABLE_NAME: exclusive_start_table_name} - ) + operation_kwargs.update({ + EXCLUSIVE_START_TABLE_NAME: exclusive_start_table_name + }) if limit is not None: - operation_kwargs.update({LIMIT: limit}) + operation_kwargs.update({ + LIMIT: limit + }) try: return self.dispatch(LIST_TABLES, operation_kwargs) except BOTOCORE_EXCEPTIONS as e: @@ -741,7 +673,9 @@ def describe_table(self, table_name: str) -> Dict: """ Performs the DescribeTable operation """ - operation_kwargs = {TABLE_NAME: table_name} + operation_kwargs = { + TABLE_NAME: table_name + } try: data = self.dispatch(DESCRIBE_TABLE, operation_kwargs) table_data = data.get(TABLE_KEY) @@ -755,8 +689,8 @@ def describe_table(self, table_name: str) -> Dict: except BotoCoreError as e: raise TableError("Unable to describe table: {}".format(e), e) except ClientError as e: - if "ResourceNotFound" in e.response["Error"]["Code"]: - raise TableDoesNotExist(e.response["Error"]["Message"]) + if 'ResourceNotFound' in e.response['Error']['Code']: + raise TableDoesNotExist(e.response['Error']['Message']) else: raise @@ -774,10 +708,15 @@ def get_item_attribute_map( if tbl is None: raise TableError("No such table {}".format(table_name)) return tbl.get_item_attribute_map( - attributes, item_key=item_key, pythonic_key=pythonic_key - ) + attributes, + item_key=item_key, + pythonic_key=pythonic_key) - def parse_attribute(self, attribute: Any, return_type: bool = False) -> Any: + def parse_attribute( + self, + attribute: Any, + return_type: bool = False + ) -> Any: """ Returns the attribute value, where the attribute can be a raw attribute value, or a dictionary containing the type: @@ -796,7 +735,10 @@ def parse_attribute(self, attribute: Any, return_type: bool = False) -> Any: return attribute def get_attribute_type( - self, table_name: str, attribute_name: str, value: Optional[Any] = None + self, + table_name: str, + attribute_name: str, + value: Optional[Any] = None ) -> str: """ Returns the proper attribute type for a given attribute name @@ -812,7 +754,7 @@ def get_identifier_map( table_name: str, hash_key: str, range_key: Optional[str] = None, - key: str = KEY, + key: str = KEY ) -> Dict: """ Builds the identifier map that is common to several operations @@ -827,60 +769,48 @@ def get_consumed_capacity_map(self, return_consumed_capacity: str) -> Dict: Builds the consumed capacity map that is common to several operations """ if return_consumed_capacity.upper() not in RETURN_CONSUMED_CAPACITY_VALUES: - raise ValueError( - "{} must be one of {}".format( - RETURN_ITEM_COLL_METRICS, RETURN_CONSUMED_CAPACITY_VALUES - ) - ) - return {RETURN_CONSUMED_CAPACITY: str(return_consumed_capacity).upper()} + raise ValueError("{} must be one of {}".format(RETURN_ITEM_COLL_METRICS, RETURN_CONSUMED_CAPACITY_VALUES)) + return { + RETURN_CONSUMED_CAPACITY: str(return_consumed_capacity).upper() + } def get_return_values_map(self, return_values: str) -> Dict: """ Builds the return values map that is common to several operations """ if return_values.upper() not in RETURN_VALUES_VALUES: - raise ValueError( - "{} must be one of {}".format(RETURN_VALUES, RETURN_VALUES_VALUES) - ) - return {RETURN_VALUES: str(return_values).upper()} + raise ValueError("{} must be one of {}".format(RETURN_VALUES, RETURN_VALUES_VALUES)) + return { + RETURN_VALUES: str(return_values).upper() + } def get_return_values_on_condition_failure_map( - self, return_values_on_condition_failure: str + self, + return_values_on_condition_failure: str ) -> Dict: """ Builds the return values map that is common to several operations """ if return_values_on_condition_failure.upper() not in RETURN_VALUES_VALUES: - raise ValueError( - "{} must be one of {}".format( - RETURN_VALUES_ON_CONDITION_FAILURE, - RETURN_VALUES_ON_CONDITION_FAILURE_VALUES, - ) - ) + raise ValueError("{} must be one of {}".format( + RETURN_VALUES_ON_CONDITION_FAILURE, + RETURN_VALUES_ON_CONDITION_FAILURE_VALUES + )) return { - RETURN_VALUES_ON_CONDITION_FAILURE: str( - return_values_on_condition_failure - ).upper() + RETURN_VALUES_ON_CONDITION_FAILURE: str(return_values_on_condition_failure).upper() } def get_item_collection_map(self, return_item_collection_metrics: str) -> Dict: """ Builds the item collection map """ - if ( - return_item_collection_metrics.upper() - not in RETURN_ITEM_COLL_METRICS_VALUES - ): - raise ValueError( - "{} must be one of {}".format( - RETURN_ITEM_COLL_METRICS, RETURN_ITEM_COLL_METRICS_VALUES - ) - ) - return {RETURN_ITEM_COLL_METRICS: str(return_item_collection_metrics).upper()} + if return_item_collection_metrics.upper() not in RETURN_ITEM_COLL_METRICS_VALUES: + raise ValueError("{} must be one of {}".format(RETURN_ITEM_COLL_METRICS, RETURN_ITEM_COLL_METRICS_VALUES)) + return { + RETURN_ITEM_COLL_METRICS: str(return_item_collection_metrics).upper() + } - def get_exclusive_start_key_map( - self, table_name: str, exclusive_start_key: str - ) -> Dict: + def get_exclusive_start_key_map(self, table_name: str, exclusive_start_key: str) -> Dict: """ Builds the exclusive start key attribute map """ @@ -903,58 +833,43 @@ def get_operation_kwargs( return_values: Optional[str] = None, return_consumed_capacity: Optional[str] = None, return_item_collection_metrics: Optional[str] = None, - return_values_on_condition_failure: Optional[str] = None, + return_values_on_condition_failure: Optional[str] = None ) -> Dict: - self._check_condition("condition", condition) + self._check_condition('condition', condition) operation_kwargs: Dict[str, Any] = {} - name_placeholders: Dict[str, str] = {} + name_placeholders: Dict[str, str] = {} expression_attribute_values: Dict[str, Any] = {} operation_kwargs[TABLE_NAME] = table_name - operation_kwargs.update( - self.get_identifier_map(table_name, hash_key, range_key, key=key) - ) + operation_kwargs.update(self.get_identifier_map(table_name, hash_key, range_key, key=key)) if attributes and operation_kwargs.get(ITEM) is not None: attrs = self.get_item_attribute_map(table_name, attributes) operation_kwargs[ITEM].update(attrs[ITEM]) if attributes_to_get is not None: - projection_expression = create_projection_expression( - attributes_to_get, name_placeholders - ) + projection_expression = create_projection_expression(attributes_to_get, name_placeholders) operation_kwargs[PROJECTION_EXPRESSION] = projection_expression if condition is not None: - condition_expression = condition.serialize( - name_placeholders, expression_attribute_values - ) + condition_expression = condition.serialize(name_placeholders, expression_attribute_values) operation_kwargs[CONDITION_EXPRESSION] = condition_expression if consistent_read is not None: operation_kwargs[CONSISTENT_READ] = consistent_read if return_values is not None: operation_kwargs.update(self.get_return_values_map(return_values)) if return_values_on_condition_failure is not None: - operation_kwargs.update( - self.get_return_values_on_condition_failure_map( - return_values_on_condition_failure - ) - ) + operation_kwargs.update(self.get_return_values_on_condition_failure_map(return_values_on_condition_failure)) if return_consumed_capacity is not None: - operation_kwargs.update( - self.get_consumed_capacity_map(return_consumed_capacity) - ) + operation_kwargs.update(self.get_consumed_capacity_map(return_consumed_capacity)) if return_item_collection_metrics is not None: - operation_kwargs.update( - self.get_item_collection_map(return_item_collection_metrics) - ) + operation_kwargs.update(self.get_item_collection_map(return_item_collection_metrics)) if actions is not None: update_expression = Update(*actions) operation_kwargs[UPDATE_EXPRESSION] = update_expression.serialize( - name_placeholders, expression_attribute_values + name_placeholders, + expression_attribute_values ) if name_placeholders: - operation_kwargs[EXPRESSION_ATTRIBUTE_NAMES] = self._reverse_dict( - name_placeholders - ) + operation_kwargs[EXPRESSION_ATTRIBUTE_NAMES] = self._reverse_dict(name_placeholders) if expression_attribute_values: operation_kwargs[EXPRESSION_ATTRIBUTE_VALUES] = expression_attribute_values return operation_kwargs @@ -979,7 +894,7 @@ def delete_item( condition=condition, return_values=return_values, return_consumed_capacity=return_consumed_capacity, - return_item_collection_metrics=return_item_collection_metrics, + return_item_collection_metrics=return_item_collection_metrics ) try: return self.dispatch(DELETE_ITEM, operation_kwargs) @@ -1041,7 +956,7 @@ def put_item( condition=condition, return_values=return_values, return_consumed_capacity=return_consumed_capacity, - return_item_collection_metrics=return_item_collection_metrics, + return_item_collection_metrics=return_item_collection_metrics ) try: return self.dispatch(PUT_ITEM, operation_kwargs) @@ -1052,19 +967,15 @@ def _get_transact_operation_kwargs( self, client_request_token: Optional[str] = None, return_consumed_capacity: Optional[str] = None, - return_item_collection_metrics: Optional[str] = None, + return_item_collection_metrics: Optional[str] = None ) -> Dict: operation_kwargs = {} if client_request_token is not None: operation_kwargs[CLIENT_REQUEST_TOKEN] = client_request_token if return_consumed_capacity is not None: - operation_kwargs.update( - self.get_consumed_capacity_map(return_consumed_capacity) - ) + operation_kwargs.update(self.get_consumed_capacity_map(return_consumed_capacity)) if return_item_collection_metrics is not None: - operation_kwargs.update( - self.get_item_collection_map(return_item_collection_metrics) - ) + operation_kwargs.update(self.get_item_collection_map(return_item_collection_metrics)) return operation_kwargs @@ -1085,14 +996,20 @@ def transact_write_items( transact_items.extend( {TRANSACT_CONDITION_CHECK: item} for item in condition_check_items ) - transact_items.extend({TRANSACT_DELETE: item} for item in delete_items) - transact_items.extend({TRANSACT_PUT: item} for item in put_items) - transact_items.extend({TRANSACT_UPDATE: item} for item in update_items) + transact_items.extend( + {TRANSACT_DELETE: item} for item in delete_items + ) + transact_items.extend( + {TRANSACT_PUT: item} for item in put_items + ) + transact_items.extend( + {TRANSACT_UPDATE: item} for item in update_items + ) operation_kwargs = self._get_transact_operation_kwargs( client_request_token=client_request_token, return_consumed_capacity=return_consumed_capacity, - return_item_collection_metrics=return_item_collection_metrics, + return_item_collection_metrics=return_item_collection_metrics ) operation_kwargs[TRANSACT_ITEMS] = transact_items @@ -1109,10 +1026,10 @@ def transact_get_items( """ Performs the TransactGet operation and returns the result """ - operation_kwargs = self._get_transact_operation_kwargs( - return_consumed_capacity=return_consumed_capacity - ) - operation_kwargs[TRANSACT_ITEMS] = [{TRANSACT_GET: item} for item in get_items] + operation_kwargs = self._get_transact_operation_kwargs(return_consumed_capacity=return_consumed_capacity) + operation_kwargs[TRANSACT_ITEMS] = [ + {TRANSACT_GET: item} for item in get_items + ] try: return self.dispatch(TRANSACT_GET_ITEMS, operation_kwargs) @@ -1132,35 +1049,27 @@ def batch_write_item( """ if put_items is None and delete_items is None: raise ValueError("Either put_items or delete_items must be specified") - operation_kwargs: Dict[str, Any] = {REQUEST_ITEMS: {table_name: []}} + operation_kwargs: Dict[str, Any] = { + REQUEST_ITEMS: { + table_name: [] + } + } if return_consumed_capacity: - operation_kwargs.update( - self.get_consumed_capacity_map(return_consumed_capacity) - ) + operation_kwargs.update(self.get_consumed_capacity_map(return_consumed_capacity)) if return_item_collection_metrics: - operation_kwargs.update( - self.get_item_collection_map(return_item_collection_metrics) - ) + operation_kwargs.update(self.get_item_collection_map(return_item_collection_metrics)) put_items_list = [] if put_items: for item in put_items: - put_items_list.append( - { - PUT_REQUEST: self.get_item_attribute_map( - table_name, item, pythonic_key=False - ) - } - ) + put_items_list.append({ + PUT_REQUEST: self.get_item_attribute_map(table_name, item, pythonic_key=False) + }) delete_items_list = [] if delete_items: for item in delete_items: - delete_items_list.append( - { - DELETE_REQUEST: self.get_item_attribute_map( - table_name, item, item_key=KEY, pythonic_key=False - ) - } - ) + delete_items_list.append({ + DELETE_REQUEST: self.get_item_attribute_map(table_name, item, item_key=KEY, pythonic_key=False) + }) operation_kwargs[REQUEST_ITEMS][table_name] = delete_items_list + put_items_list try: return self.dispatch(BATCH_WRITE_ITEM, operation_kwargs) @@ -1178,20 +1087,20 @@ def batch_get_item( """ Performs the batch get item operation """ - operation_kwargs: Dict[str, Any] = {REQUEST_ITEMS: {table_name: {}}} + operation_kwargs: Dict[str, Any] = { + REQUEST_ITEMS: { + table_name: {} + } + } args_map: Dict[str, Any] = {} name_placeholders: Dict[str, str] = {} if consistent_read: args_map[CONSISTENT_READ] = consistent_read if return_consumed_capacity: - operation_kwargs.update( - self.get_consumed_capacity_map(return_consumed_capacity) - ) + operation_kwargs.update(self.get_consumed_capacity_map(return_consumed_capacity)) if attributes_to_get is not None: - projection_expression = create_projection_expression( - attributes_to_get, name_placeholders - ) + projection_expression = create_projection_expression(attributes_to_get, name_placeholders) args_map[PROJECTION_EXPRESSION] = projection_expression if name_placeholders: args_map[EXPRESSION_ATTRIBUTE_NAMES] = self._reverse_dict(name_placeholders) @@ -1199,7 +1108,9 @@ def batch_get_item( keys_map: Dict[str, List] = {KEYS: []} for key in keys: - keys_map[KEYS].append(self.get_item_attribute_map(table_name, key)[ITEM]) + keys_map[KEYS].append( + self.get_item_attribute_map(table_name, key)[ITEM] + ) operation_kwargs[REQUEST_ITEMS][table_name].update(keys_map) try: return self.dispatch(BATCH_GET_ITEM, operation_kwargs) @@ -1222,7 +1133,7 @@ def get_item( hash_key=hash_key, range_key=range_key, consistent_read=consistent_read, - attributes_to_get=attributes_to_get, + attributes_to_get=attributes_to_get ) try: return self.dispatch(GET_ITEM, operation_kwargs) @@ -1245,34 +1156,26 @@ def scan( """ Performs the scan operation """ - self._check_condition("filter_condition", filter_condition) + self._check_condition('filter_condition', filter_condition) operation_kwargs: Dict[str, Any] = {TABLE_NAME: table_name} name_placeholders: Dict[str, str] = {} expression_attribute_values: Dict[str, Any] = {} if filter_condition is not None: - filter_expression = filter_condition.serialize( - name_placeholders, expression_attribute_values - ) + filter_expression = filter_condition.serialize(name_placeholders, expression_attribute_values) operation_kwargs[FILTER_EXPRESSION] = filter_expression if attributes_to_get is not None: - projection_expression = create_projection_expression( - attributes_to_get, name_placeholders - ) + projection_expression = create_projection_expression(attributes_to_get, name_placeholders) operation_kwargs[PROJECTION_EXPRESSION] = projection_expression if index_name: operation_kwargs[INDEX_NAME] = index_name if limit is not None: operation_kwargs[LIMIT] = limit if return_consumed_capacity: - operation_kwargs.update( - self.get_consumed_capacity_map(return_consumed_capacity) - ) + operation_kwargs.update(self.get_consumed_capacity_map(return_consumed_capacity)) if exclusive_start_key: - operation_kwargs.update( - self.get_exclusive_start_key_map(table_name, exclusive_start_key) - ) + operation_kwargs.update(self.get_exclusive_start_key_map(table_name, exclusive_start_key)) if segment is not None: operation_kwargs[SEGMENT] = segment if total_segments: @@ -1280,9 +1183,7 @@ def scan( if consistent_read: operation_kwargs[CONSISTENT_READ] = consistent_read if name_placeholders: - operation_kwargs[EXPRESSION_ATTRIBUTE_NAMES] = self._reverse_dict( - name_placeholders - ) + operation_kwargs[EXPRESSION_ATTRIBUTE_NAMES] = self._reverse_dict(name_placeholders) if expression_attribute_values: operation_kwargs[EXPRESSION_ATTRIBUTE_VALUES] = expression_attribute_values @@ -1310,8 +1211,8 @@ def query( """ Performs the Query operation and returns the result """ - self._check_condition("range_key_condition", range_key_condition) - self._check_condition("filter_condition", filter_condition) + self._check_condition('range_key_condition', range_key_condition) + self._check_condition('filter_condition', filter_condition) operation_kwargs: Dict[str, Any] = {TABLE_NAME: table_name} name_placeholders: Dict[str, str] = {} @@ -1322,9 +1223,7 @@ def query( raise TableError("No such table: {}".format(table_name)) if index_name: if not tbl.has_index_name(index_name): - raise ValueError( - "Table {} has no index: {}".format(table_name, index_name) - ) + raise ValueError("Table {} has no index: {}".format(table_name, index_name)) hash_keynames = tbl.get_index_hash_keynames(index_name) range_keynames = tbl.get_index_range_keynames(index_name) else: @@ -1337,55 +1236,38 @@ def query( hash_keynames, index_name=index_name, ) - self._validate_multi_range_key_condition( + range_key_condition = self._normalize_multi_range_key_condition( range_key_condition, range_keynames, index_name=index_name, ) key_condition = None for hash_keyname, hash_keyvalue in zip(hash_keynames, hash_key_values): - hash_condition_value = { - self.get_attribute_type( - table_name, hash_keyname, hash_keyvalue - ): self.parse_attribute(hash_keyvalue) - } + hash_condition_value = {self.get_attribute_type(table_name, hash_keyname, hash_keyvalue): self.parse_attribute(hash_keyvalue)} hash_condition = Path([hash_keyname]) == hash_condition_value - key_condition = ( - hash_condition - if key_condition is None - else key_condition & hash_condition - ) + key_condition = hash_condition if key_condition is None else key_condition & hash_condition if range_key_condition is not None: key_condition &= range_key_condition assert key_condition is not None operation_kwargs[KEY_CONDITION_EXPRESSION] = key_condition.serialize( - name_placeholders, expression_attribute_values - ) + name_placeholders, expression_attribute_values) if filter_condition is not None: - filter_expression = filter_condition.serialize( - name_placeholders, expression_attribute_values - ) + filter_expression = filter_condition.serialize(name_placeholders, expression_attribute_values) operation_kwargs[FILTER_EXPRESSION] = filter_expression if attributes_to_get: - projection_expression = create_projection_expression( - attributes_to_get, name_placeholders - ) + projection_expression = create_projection_expression(attributes_to_get, name_placeholders) operation_kwargs[PROJECTION_EXPRESSION] = projection_expression if consistent_read: operation_kwargs[CONSISTENT_READ] = True if exclusive_start_key: - operation_kwargs.update( - self.get_exclusive_start_key_map(table_name, exclusive_start_key) - ) + operation_kwargs.update(self.get_exclusive_start_key_map(table_name, exclusive_start_key)) if index_name: operation_kwargs[INDEX_NAME] = index_name if limit is not None: operation_kwargs[LIMIT] = limit if return_consumed_capacity: - operation_kwargs.update( - self.get_consumed_capacity_map(return_consumed_capacity) - ) + operation_kwargs.update(self.get_consumed_capacity_map(return_consumed_capacity)) if select: if select.upper() not in SELECT_VALUES: raise ValueError("{} must be one of {}".format(SELECT, SELECT_VALUES)) @@ -1393,9 +1275,7 @@ def query( if scan_index_forward is not None: operation_kwargs[SCAN_INDEX_FORWARD] = scan_index_forward if name_placeholders: - operation_kwargs[EXPRESSION_ATTRIBUTE_NAMES] = self._reverse_dict( - name_placeholders - ) + operation_kwargs[EXPRESSION_ATTRIBUTE_NAMES] = self._reverse_dict(name_placeholders) if expression_attribute_values: operation_kwargs[EXPRESSION_ATTRIBUTE_VALUES] = expression_attribute_values @@ -1486,13 +1366,20 @@ def _condition_key_name(condition: Condition) -> Optional[str]: return path[0] @staticmethod - def _validate_multi_range_key_condition( + def _combine_conditions(conditions: List[Condition]) -> Condition: + combined_condition = conditions[0] + for condition in conditions[1:]: + combined_condition &= condition + return combined_condition + + @staticmethod + def _normalize_multi_range_key_condition( range_key_condition: Optional[Condition], range_keynames: Sequence[str], index_name: Optional[str] = None, - ) -> None: + ) -> Optional[Condition]: if range_key_condition is None or len(range_keynames) <= 1: - return + return range_key_condition valid_operators = {"=", "<", "<=", ">", ">=", "BETWEEN", "begins_with"} conditions_by_key: Dict[str, Condition] = {} @@ -1511,11 +1398,10 @@ def _validate_multi_range_key_condition( raise ValueError( f"{context} range_key_condition has multiple conditions for range key: {key_name}" ) - assert key_name is not None conditions_by_key[key_name] = condition if not conditions_by_key: - return + return range_key_condition highest_position = max( range_keynames.index(key_name) for key_name in conditions_by_key @@ -1541,3 +1427,10 @@ def _validate_multi_range_key_condition( f"{context} range_key_condition must use equality for preceding range keys: " f"{', '.join(non_equal_prefix_keys)}" ) + + ordered_conditions = [ + conditions_by_key[key_name] + for key_name in range_keynames + if key_name in conditions_by_key + ] + return Connection._combine_conditions(ordered_conditions) diff --git a/pynamodb/indexes.py b/pynamodb/indexes.py index ea6029e53..862807a4f 100644 --- a/pynamodb/indexes.py +++ b/pynamodb/indexes.py @@ -202,11 +202,18 @@ def _condition_key_name(condition: Condition) -> Optional[str]: return path[0] @staticmethod - def _validate_multi_key_condition( + def _combine_conditions(conditions: List[Condition]) -> Condition: + combined_condition = conditions[0] + for condition in conditions[1:]: + combined_condition &= condition + return combined_condition + + @staticmethod + def _normalize_multi_key_condition( range_key_condition: Condition, range_keynames: List[str], context: str, - ) -> None: + ) -> Condition: valid_operators = {'=', '<', '<=', '>', '>=', 'BETWEEN', 'begins_with'} conditions_by_key: Dict[str, Condition] = {} for condition in Index._flatten_and_conditions(range_key_condition): @@ -215,7 +222,7 @@ def _validate_multi_key_condition( f'{context} range_key_condition uses unsupported range key operator: {condition.operator}' ) key_name = Index._condition_key_name(condition) - if key_name not in range_keynames: + if key_name is None or key_name not in range_keynames: raise ValueError( f'{context} range_key_condition must only use range keys: ' + ', '.join(range_keynames) ) @@ -226,7 +233,7 @@ def _validate_multi_key_condition( conditions_by_key[key_name] = condition if not conditions_by_key: - return + return range_key_condition highest_position = max( range_keynames.index(key_name) for key_name in conditions_by_key @@ -253,6 +260,13 @@ def _validate_multi_key_condition( + ', '.join(non_equal_prefix_keys) ) + ordered_conditions = [ + conditions_by_key[key_name] + for key_name in range_keynames + if key_name in conditions_by_key + ] + return Index._combine_conditions(ordered_conditions) + @classmethod def _serialize_hash_key_values( cls, @@ -349,18 +363,24 @@ def _get_ordered_hash_key_values( return [values_by_attr_name[attr.attr_name] for attr in hash_key_attributes] @classmethod - def _validate_range_key_condition( + def _normalize_range_key_condition( cls, range_key_condition: Optional[Condition] - ) -> None: + ) -> Optional[Condition]: range_key_attributes = cls._range_key_attributes() if range_key_condition is None or len(range_key_attributes) <= 1: - return - cls._validate_multi_key_condition( + return range_key_condition + return cls._normalize_multi_key_condition( range_key_condition, [attr.attr_name for attr in range_key_attributes], cls.__name__, ) + @classmethod + def _validate_range_key_condition( + cls, range_key_condition: Optional[Condition] + ) -> None: + cls._normalize_range_key_condition(range_key_condition) + @classmethod def validate_range_key_condition( cls, range_key_condition: Optional[Condition] diff --git a/pynamodb/models.py b/pynamodb/models.py index b69edb7c7..f9870c266 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -650,7 +650,7 @@ def count( serialized_hash_keys = None if index_name: index = cls._indexes[index_name] - index._validate_range_key_condition(range_key_condition) + range_key_condition = index._normalize_range_key_condition(range_key_condition) serialized_hash_key = index._serialize_hash_key_values( hash_key, hash_keys=hash_keys ) @@ -734,7 +734,7 @@ def query( serialized_hash_keys = None if index_name: index = cls._indexes[index_name] - index._validate_range_key_condition(range_key_condition) + range_key_condition = index._normalize_range_key_condition(range_key_condition) serialized_hash_key = index._serialize_hash_key_values( hash_key, hash_keys=hash_keys ) diff --git a/tests/test_base_connection.py b/tests/test_base_connection.py index 614db8d9e..87f244ce3 100644 --- a/tests/test_base_connection.py +++ b/tests/test_base_connection.py @@ -1334,6 +1334,37 @@ def test_connection_query(): hash_keys={'z_partition': 'z1', 'a_partition': 'a1', 'unknown': 'u1'}, ) + with patch(PATCH_METHOD) as req: + req.return_value = {} + conn.query( + composite_table_name, + index_name='CompositeIndex', + hash_keys={'z_partition': 'z1', 'a_partition': 'a1'}, + range_key_condition=( + (Path('b_sort') == 'b1') & + (Path('c_sort') == 'c1') + ), + ) + params = { + 'ReturnConsumedCapacity': 'TOTAL', + 'IndexName': 'CompositeIndex', + 'KeyConditionExpression': '((#0 = :0 AND #1 = :1) AND (#2 = :2 AND #3 = :3))', + 'ExpressionAttributeNames': { + '#0': 'z_partition', + '#1': 'a_partition', + '#2': 'c_sort', + '#3': 'b_sort' + }, + 'ExpressionAttributeValues': { + ':0': {'S': 'z1'}, + ':1': {'S': 'a1'}, + ':2': {'S': 'c1'}, + ':3': {'S': 'b1'} + }, + 'TableName': composite_table_name + } + assert req.call_args[0][1] == params + with patch(PATCH_METHOD) as req: req.return_value = {} conn.query( diff --git a/tests/test_model.py b/tests/test_model.py index 4c7534780..73cdc3d78 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -2532,6 +2532,33 @@ def test_global_index_composite_query_range_key_validation(self): '((#0 = :0 AND #1 = :1) AND (#2 = :2 AND begins_with (#3, :3)))' ) + def test_global_index_composite_query_normalizes_out_of_order_range_keys(self): + with patch(PATCH_METHOD) as req: + req.return_value = {'Count': 0, 'ScannedCount': 0, 'Items': []} + list(CompositeIndexedModel.composite_index.query( + hash_keys={'z_partition': 'p1', 'a_partition': 'p2'}, + range_key_condition=( + (CompositeIndexedModel.b_sort == 's2') & + (CompositeIndexedModel.c_sort == 's1') + ), + )) + self.assertEqual( + req.call_args[0][1]['KeyConditionExpression'], + '((#0 = :0 AND #1 = :1) AND (#2 = :2 AND #3 = :3))' + ) + self.assertEqual(req.call_args[0][1]['ExpressionAttributeNames'], { + '#0': 'z_partition', + '#1': 'a_partition', + '#2': 'c_sort', + '#3': 'b_sort', + }) + self.assertEqual(req.call_args[0][1]['ExpressionAttributeValues'], { + ':0': {'S': 'p1'}, + ':1': {'S': 'p2'}, + ':2': {'S': 's1'}, + ':3': {'S': 's2'}, + }) + def test_global_index_three_range_key_validation(self): with pytest.raises(ValueError, match='preceding range keys: r1, r2'): ThreeRangeKeyModel.three_range_index.query( From 07583fe14106640d3782729fbf27fbe52e49d822 Mon Sep 17 00:00:00 2001 From: keremeyuboglu <32223948+keremeyuboglu@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:17:30 +0300 Subject: [PATCH 7/8] Respect shadowed index attributes --- pynamodb/indexes.py | 4 ++-- tests/test_model.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/pynamodb/indexes.py b/pynamodb/indexes.py index 862807a4f..1248e5dd6 100644 --- a/pynamodb/indexes.py +++ b/pynamodb/indexes.py @@ -43,10 +43,10 @@ def _get_attributes_in_declaration_order( attributes: Dict[str, Attribute] = {} for base in reversed(index_cls.__mro__): for name, attribute in getattr(base, '__dict__', {}).items(): + if name in attributes: + del attributes[name] if isinstance(attribute, Attribute): # If a subclass overrides an attribute, preserve the subclass declaration order. - if name in attributes: - del attributes[name] attributes[name] = attribute return attributes diff --git a/tests/test_model.py b/tests/test_model.py index 73cdc3d78..25ecc7241 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -2703,6 +2703,34 @@ class Meta: with pytest.raises(ValueError, match='at most one range key'): BadLocalRangeIndex._get_schema() + def test_index_attribute_shadowing_removes_inherited_key(self): + class BaseShadowIndex(GlobalSecondaryIndex): + class Meta: + index_name = 'base_shadow_idx' + projection = AllProjection() + + h = UnicodeAttribute(hash_key=True) + r = UnicodeAttribute(range_key=True) + + class ChildShadowIndex(BaseShadowIndex): + class Meta(BaseShadowIndex.Meta): + index_name = 'child_shadow_idx' + + h = None + h2 = UnicodeAttribute(hash_key=True) + + self.assertEqual( + [attr.attr_name for attr in ChildShadowIndex._hash_key_attributes()], + ['h2'], + ) + self.assertEqual( + ChildShadowIndex._get_schema()['key_schema'], + [ + {'AttributeName': 'h2', 'KeyType': 'HASH'}, + {'AttributeName': 'r', 'KeyType': 'RANGE'}, + ], + ) + def test_projections(self): """ Models.Projection From e996962f602e7b7bbaa3c18807a70861266dc7de Mon Sep 17 00:00:00 2001 From: keremeyuboglu <32223948+keremeyuboglu@users.noreply.github.com> Date: Fri, 1 May 2026 10:04:34 +0300 Subject: [PATCH 8/8] Moved hash keys to the latest parameter for backwards compat --- pynamodb/indexes.py | 149 +++++-------- pynamodb/models.py | 502 ++++++++++++++++++-------------------------- 2 files changed, 258 insertions(+), 393 deletions(-) diff --git a/pynamodb/indexes.py b/pynamodb/indexes.py index 1248e5dd6..04c664a0d 100644 --- a/pynamodb/indexes.py +++ b/pynamodb/indexes.py @@ -2,20 +2,20 @@ PynamoDB Indexes """ from inspect import getmembers -from typing import TYPE_CHECKING, Any, Dict, Generic, List, Mapping, Optional, Type, TypeVar, Union +from typing import Any, Dict, Generic, List, Mapping, Optional, Type, TypeVar, Union +from typing import TYPE_CHECKING from pynamodb._schema import IndexSchema, GlobalSecondaryIndexSchema from pynamodb._schema import ModelSchema -from pynamodb.attributes import Attribute from pynamodb.constants import ( INCLUDE, ALL, KEYS_ONLY, ATTR_NAME, ATTR_TYPE, KEY_TYPE, PROJECTION_TYPE, NON_KEY_ATTRIBUTES, READ_CAPACITY_UNITS, WRITE_CAPACITY_UNITS, ) +from pynamodb.attributes import Attribute from pynamodb.expressions.condition import Condition from pynamodb.pagination import ResultIterator from pynamodb.types import HASH, RANGE - if TYPE_CHECKING: from pynamodb.models import Model @@ -29,7 +29,6 @@ class Index(Generic[_M]): """ Base class for secondary indexes """ - Meta: Any = None _model: _M @@ -58,12 +57,12 @@ def __init_subclass__(cls, **kwargs): def __init__(self) -> None: if self.Meta is None: - raise ValueError('Indexes require a Meta class for settings') - if not hasattr(self.Meta, 'projection'): - raise ValueError('No projection defined, define a projection for this class') + raise ValueError("Indexes require a Meta class for settings") + if not hasattr(self.Meta, "projection"): + raise ValueError("No projection defined, define a projection for this class") def __set_name__(self, owner: Type[_M], name: str): - if not hasattr(self.Meta, 'index_name'): + if not hasattr(self.Meta, "index_name"): self.Meta.index_name = name def count( @@ -78,9 +77,6 @@ def count( ) -> int: """ Count on an index - - :param hash_key: The hash key to query. Can be None when ``hash_keys`` is provided. - :param hash_keys: Named hash key values for indexes with multiple hash key attributes. """ return self._model.count( hash_key, @@ -109,9 +105,6 @@ def query( ) -> ResultIterator[_M]: """ Queries an index - - :param hash_key: The hash key to query. Can be None when ``hash_keys`` is provided. - :param hash_keys: Named hash key values for indexes with multiple hash key attributes. """ return self._model.query( hash_key, @@ -187,7 +180,7 @@ def _hash_key_aliases( @staticmethod def _flatten_and_conditions(condition: Condition) -> List[Condition]: - if condition.operator == 'AND': + if condition.operator == "AND": conditions: List[Condition] = [] for value in condition.values: conditions.extend(Index._flatten_and_conditions(value)) @@ -196,7 +189,7 @@ def _flatten_and_conditions(condition: Condition) -> List[Condition]: @staticmethod def _condition_key_name(condition: Condition) -> Optional[str]: - path = getattr(condition.values[0], 'path', None) if condition.values else None + path = getattr(condition.values[0], "path", None) if condition.values else None if not isinstance(path, list) or len(path) != 1: return None return path[0] @@ -214,22 +207,18 @@ def _normalize_multi_key_condition( range_keynames: List[str], context: str, ) -> Condition: - valid_operators = {'=', '<', '<=', '>', '>=', 'BETWEEN', 'begins_with'} + valid_operators = {"=", "<", "<=", ">", ">=", "BETWEEN", "begins_with"} conditions_by_key: Dict[str, Condition] = {} for condition in Index._flatten_and_conditions(range_key_condition): if condition.operator not in valid_operators: raise ValueError( - f'{context} range_key_condition uses unsupported range key operator: {condition.operator}' + f"{context} range_key_condition uses unsupported range key operator: {condition.operator}" ) key_name = Index._condition_key_name(condition) if key_name is None or key_name not in range_keynames: - raise ValueError( - f'{context} range_key_condition must only use range keys: ' + ', '.join(range_keynames) - ) + raise ValueError(f"{context} range_key_condition must only use range keys: " + ", ".join(range_keynames)) if key_name in conditions_by_key: - raise ValueError( - f'{context} range_key_condition has multiple conditions for range key: {key_name}' - ) + raise ValueError(f"{context} range_key_condition has multiple conditions for range key: {key_name}") conditions_by_key[key_name] = condition if not conditions_by_key: @@ -245,19 +234,19 @@ def _normalize_multi_key_condition( ] if missing_prefix_keys: raise ValueError( - f'{context} range_key_condition must include equality conditions for preceding range keys: ' - + ', '.join(missing_prefix_keys) + f"{context} range_key_condition must include equality conditions for preceding range keys: " + + ", ".join(missing_prefix_keys) ) non_equal_prefix_keys = [ key_name for key_name in range_keynames[:highest_position] - if conditions_by_key[key_name].operator != '=' + if conditions_by_key[key_name].operator != "=" ] if non_equal_prefix_keys: raise ValueError( - f'{context} range_key_condition must use equality for preceding range keys: ' - + ', '.join(non_equal_prefix_keys) + f"{context} range_key_condition must use equality for preceding range keys: " + + ", ".join(non_equal_prefix_keys) ) ordered_conditions = [ @@ -275,21 +264,19 @@ def _serialize_hash_key_values( ) -> _SerializedHashKeyType: hash_key_attributes = cls._hash_key_attributes() if not hash_key_attributes: - raise ValueError(f'{cls.__name__} has no hash key attributes') + raise ValueError(f"{cls.__name__} has no hash key attributes") if hash_key is not None and hash_keys is not None: - raise ValueError(f'{cls.__name__} received both hash_key and hash_keys') + raise ValueError(f"{cls.__name__} received both hash_key and hash_keys") if len(hash_key_attributes) == 1: if hash_keys is None: if hash_key is None: - raise ValueError(f'{cls.__name__} requires a hash_key') + raise ValueError(f"{cls.__name__} requires a hash_key") if isinstance(hash_key, (tuple, list)): - raise ValueError(f'{cls.__name__} expects a single hash_key value') + raise ValueError(f"{cls.__name__} expects a single hash_key value") if isinstance(hash_key, Mapping): - raise ValueError( - f'{cls.__name__} expects hash_keys=... for named hash key values' - ) + raise ValueError(f"{cls.__name__} expects hash_keys=... for named hash key values") return hash_key_attributes[0].serialize(hash_key) hash_key_values = cls._get_ordered_hash_key_values( @@ -299,10 +286,8 @@ def _serialize_hash_key_values( if hash_keys is None: if hash_key is None: - raise ValueError(f'{cls.__name__} requires hash_keys') - raise ValueError( - f'{cls.__name__} has multiple hash key attributes; use hash_keys=...' - ) + raise ValueError(f"{cls.__name__} requires hash_keys") + raise ValueError(f"{cls.__name__} has multiple hash key attributes; use hash_keys=...") hash_key_values = cls._get_ordered_hash_key_values( hash_keys, hash_key_attributes @@ -327,7 +312,7 @@ def _get_ordered_hash_key_values( hash_key_attributes: List[Attribute], ) -> List[_KeyType]: if not isinstance(hash_keys, Mapping): - raise ValueError(f'{cls.__name__} expects hash_keys to be a mapping') + raise ValueError(f"{cls.__name__} expects hash_keys to be a mapping") expected_aliases = cls._hash_key_aliases(hash_key_attributes) @@ -340,15 +325,11 @@ def _get_ordered_hash_key_values( unknown_keys.append(str(key_name)) continue if attr.attr_name in values_by_attr_name: - raise ValueError( - f'{cls.__name__} received duplicate value for hash key: {attr.attr_name}' - ) + raise ValueError(f"{cls.__name__} received duplicate value for hash key: {attr.attr_name}") values_by_attr_name[attr.attr_name] = value if unknown_keys: - raise ValueError( - f'{cls.__name__} received unknown hash keys: ' + ', '.join(unknown_keys) - ) + raise ValueError(f"{cls.__name__} received unknown hash keys: " + ", ".join(unknown_keys)) missing_keys = [ attr.attr_name @@ -356,9 +337,7 @@ def _get_ordered_hash_key_values( if attr.attr_name not in values_by_attr_name ] if missing_keys: - raise ValueError( - f'{cls.__name__} requires values for hash keys: ' + ', '.join(missing_keys) - ) + raise ValueError(f"{cls.__name__} requires values for hash keys: " + ", ".join(missing_keys)) return [values_by_attr_name[attr.attr_name] for attr in hash_key_attributes] @@ -417,37 +396,27 @@ def _get_schema(cls) -> IndexSchema: range_key_attributes = cls._range_key_attributes() for attr_cls in range_key_attributes: - schema['attribute_definitions'].append( - { - ATTR_NAME: attr_cls.attr_name, - ATTR_TYPE: attr_cls.attr_type, - } - ) + schema['attribute_definitions'].append({ + ATTR_NAME: attr_cls.attr_name, + ATTR_TYPE: attr_cls.attr_type, + }) for attr_cls in hash_key_attributes: - schema['attribute_definitions'].append( - { - ATTR_NAME: attr_cls.attr_name, - ATTR_TYPE: attr_cls.attr_type, - } - ) + schema['attribute_definitions'].append({ + ATTR_NAME: attr_cls.attr_name, + ATTR_TYPE: attr_cls.attr_type, + }) for attr_cls in hash_key_attributes: - schema['key_schema'].append( - { - ATTR_NAME: attr_cls.attr_name, - KEY_TYPE: HASH, - } - ) + schema['key_schema'].append({ + ATTR_NAME: attr_cls.attr_name, + KEY_TYPE: HASH, + }) for attr_cls in range_key_attributes: - schema['key_schema'].append( - { - ATTR_NAME: attr_cls.attr_name, - KEY_TYPE: RANGE, - } - ) + schema['key_schema'].append({ + ATTR_NAME: attr_cls.attr_name, + KEY_TYPE: RANGE, + }) if cls.Meta.projection.non_key_attributes: - schema['projection'][NON_KEY_ATTRIBUTES] = ( - cls.Meta.projection.non_key_attributes - ) + schema['projection'][NON_KEY_ATTRIBUTES] = cls.Meta.projection.non_key_attributes return schema @@ -455,15 +424,14 @@ class GlobalSecondaryIndex(Index[_M]): """ A global secondary index """ - @classmethod def _validate_key_attributes(cls) -> None: hash_keys = cls._hash_key_attributes() range_keys = cls._range_key_attributes() if len(hash_keys) > 4: - raise ValueError(f'{cls.__name__} supports at most 4 hash key attributes') + raise ValueError(f"{cls.__name__} supports at most 4 hash key attributes") if len(range_keys) > 4: - raise ValueError(f'{cls.__name__} supports at most 4 range key attributes') + raise ValueError(f"{cls.__name__} supports at most 4 range key attributes") @classmethod def _update_model_schema(cls, schema: ModelSchema) -> None: @@ -473,13 +441,9 @@ def _update_model_schema(cls, schema: ModelSchema) -> None: } if hasattr(cls.Meta, 'read_capacity_units'): - index_schema['provisioned_throughput'][READ_CAPACITY_UNITS] = ( - cls.Meta.read_capacity_units - ) + index_schema['provisioned_throughput'][READ_CAPACITY_UNITS] = cls.Meta.read_capacity_units if hasattr(cls.Meta, 'write_capacity_units'): - index_schema['provisioned_throughput'][WRITE_CAPACITY_UNITS] = ( - cls.Meta.write_capacity_units - ) + index_schema['provisioned_throughput'][WRITE_CAPACITY_UNITS] = cls.Meta.write_capacity_units schema['global_secondary_indexes'].append(index_schema) # With polymorphism, indexes can use the same attribute, e.g. index1 on (thread_id, created_at) @@ -499,9 +463,9 @@ def _validate_key_attributes(cls) -> None: hash_keys = cls._hash_key_attributes() range_keys = cls._range_key_attributes() if len(hash_keys) > 1: - raise ValueError(f'{cls.__name__} supports at most one hash key attribute') + raise ValueError(f"{cls.__name__} supports at most one hash key attribute") if len(range_keys) > 1: - raise ValueError(f'{cls.__name__} supports at most one range key attribute') + raise ValueError(f"{cls.__name__} supports at most one range key attribute") @classmethod def _update_model_schema(cls, schema: ModelSchema) -> None: @@ -514,11 +478,11 @@ def _update_model_schema(cls, schema: ModelSchema) -> None: schema['attribute_definitions'].append(attr_def) + class Projection: """ A class for presenting projections """ - projection_type: Any = None non_key_attributes: Any = None @@ -527,7 +491,6 @@ class KeysOnlyProjection(Projection): """ Keys only projection """ - projection_type = KEYS_ONLY @@ -535,14 +498,11 @@ class IncludeProjection(Projection): """ An INCLUDE projection """ - projection_type = INCLUDE def __init__(self, non_attr_keys: Optional[List[str]] = None) -> None: if not non_attr_keys: - raise ValueError( - 'The INCLUDE type projection requires a list of string attribute names' - ) + raise ValueError("The INCLUDE type projection requires a list of string attribute names") self.non_key_attributes = non_attr_keys @@ -550,5 +510,4 @@ class AllProjection(Projection): """ An ALL projection """ - projection_type = ALL diff --git a/pynamodb/models.py b/pynamodb/models.py index f9870c266..475616498 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -1,14 +1,27 @@ """ DynamoDB Models for PynamoDB """ -import logging -import sys import time +import logging import warnings +import sys from copy import deepcopy from inspect import getmembers -from typing import Any, Dict, Generic, Iterable, Iterator, List, Mapping, Optional, Sequence, Text, Tuple, Type, \ - TypeVar, Union, cast +from typing import Any +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Text +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union +from typing import cast from pynamodb._schema import ModelSchema from pynamodb.connection.base import MetaTable @@ -56,7 +69,6 @@ class BatchWrite(Generic[_T]): """ A class for batch writes """ - def __init__(self, model: Type[_T], auto_commit: bool = True): self.model = model self.auto_commit = auto_commit @@ -79,10 +91,10 @@ def save(self, put_item: _T) -> None: """ if len(self.pending_operations) == self.max_operations: if not self.auto_commit: - raise ValueError('DynamoDB allows a maximum of 25 batch operations') + raise ValueError("DynamoDB allows a maximum of 25 batch operations") else: self.commit() - self.pending_operations.append({'action': PUT, 'item': put_item}) + self.pending_operations.append({"action": PUT, "item": put_item}) def delete(self, del_item: _T) -> None: """ @@ -99,10 +111,10 @@ def delete(self, del_item: _T) -> None: """ if len(self.pending_operations) == self.max_operations: if not self.auto_commit: - raise ValueError('DynamoDB allows a maximum of 25 batch operations') + raise ValueError("DynamoDB allows a maximum of 25 batch operations") else: self.commit() - self.pending_operations.append({'action': DELETE, 'item': del_item}) + self.pending_operations.append({"action": DELETE, "item": del_item}) def __enter__(self): return self @@ -118,7 +130,7 @@ def commit(self) -> None: """ Writes all of the changes that are pending """ - log.debug('%s committing batch operation', self.model) + log.debug("%s committing batch operation", self.model) put_items = [] delete_items = [] for item in self.pending_operations: @@ -144,7 +156,7 @@ def commit(self) -> None: retries += 1 if retries >= self.model.Meta.max_retry_attempts: self.failed_operations = unprocessed_items - raise PutError('Failed to batch write items: max_retry_attempts exceeded') + raise PutError("Failed to batch write items: max_retry_attempts exceeded") put_items = [] delete_items = [] for item in unprocessed_items: @@ -152,11 +164,7 @@ def commit(self) -> None: put_items.append(item.get(PUT_REQUEST).get(ITEM)) # type: ignore elif DELETE_REQUEST in item: delete_items.append(item.get(DELETE_REQUEST).get(KEY)) # type: ignore - log.info( - 'Resending %d unprocessed keys for batch operation (retry %d)', - len(unprocessed_items), - retries, - ) + log.info("Resending %d unprocessed keys for batch operation (retry %d)", len(unprocessed_items), retries) data = self.model._get_connection().batch_write_item( put_items=put_items, delete_items=delete_items, @@ -187,7 +195,6 @@ class MetaModel(AttributeContainerMeta): """ Model meta class """ - def __new__(cls, name, bases, namespace, discriminator=None): # Defined so that the discriminator can be set in the class definition. return super().__new__(cls, name, bases, namespace) @@ -199,24 +206,24 @@ def __init__(self, name, bases, namespace, discriminator=None) -> None: for attr_name, attribute in cls.get_attributes().items(): if attribute.is_hash_key: if cls._hash_keyname and cls._hash_keyname != attr_name: - raise ValueError(f'{cls.__name__} has more than one hash key: {cls._hash_keyname}, {attr_name}') + raise ValueError(f"{cls.__name__} has more than one hash key: {cls._hash_keyname}, {attr_name}") cls._hash_keyname = attr_name if attribute.is_range_key: if cls._range_keyname and cls._range_keyname != attr_name: - raise ValueError(f'{cls.__name__} has more than one range key: {cls._range_keyname}, {attr_name}') + raise ValueError(f"{cls.__name__} has more than one range key: {cls._range_keyname}, {attr_name}") cls._range_keyname = attr_name if isinstance(attribute, VersionAttribute): if cls._version_attribute_name and cls._version_attribute_name != attr_name: raise ValueError( - 'The model has more than one Version attribute: {}, {}'.format( - cls._version_attribute_name, attr_name - ) + "The model has more than one Version attribute: {}, {}" + .format(cls._version_attribute_name, attr_name) ) cls._version_attribute_name = attr_name ttl_attr_names = [name for name, attr in cls.get_attributes().items() if isinstance(attr, TTLAttribute)] if len(ttl_attr_names) > 1: - raise ValueError('{} has more than one TTL attribute: {}'.format(cls.__name__, ', '.join(ttl_attr_names))) + raise ValueError("{} has more than one TTL attribute: {}".format( + cls.__name__, ", ".join(ttl_attr_names))) if isinstance(namespace, dict): for attr_name, attr_obj in namespace.items(): @@ -226,7 +233,7 @@ def __init__(self, name, bases, namespace, discriminator=None) -> None: if not hasattr(attr_obj, HOST): setattr(attr_obj, HOST, get_settings_value('host')) if hasattr(attr_obj, 'session_cls') or hasattr(attr_obj, 'request_timeout_seconds'): - warnings.warn('The `session_cls` and `request_timeout_second` options are no longer supported') + warnings.warn("The `session_cls` and `request_timeout_second` options are no longer supported") if not hasattr(attr_obj, 'connect_timeout_seconds'): setattr(attr_obj, 'connect_timeout_seconds', get_settings_value('connect_timeout_seconds')) if not hasattr(attr_obj, 'read_timeout_seconds'): @@ -245,13 +252,13 @@ def __init__(self, name, bases, namespace, discriminator=None) -> None: setattr(attr_obj, 'aws_session_token', None) # create a custom Model.DoesNotExist derived from pynamodb.exceptions.DoesNotExist, - # so that 'except Model.DoesNotExist:' would not catch other models' exceptions + # so that "except Model.DoesNotExist:" would not catch other models' exceptions if 'DoesNotExist' not in namespace: exception_attrs = { '__module__': namespace.get('__module__'), - '__qualname__': f'{cls.__qualname__}.DoesNotExist', + '__qualname__': f'{cls.__qualname__}.{"DoesNotExist"}', } - cls.DoesNotExist = type('DoesNotExist', (DoesNotExist,), exception_attrs) + cls.DoesNotExist = type('DoesNotExist', (DoesNotExist, ), exception_attrs) @staticmethod def _initialize_indexes(cls): @@ -287,11 +294,11 @@ class Model(AttributeContainer, metaclass=MetaModel): _indexes: Dict[str, Index] def __init__( - self, - hash_key: Optional[_KeyType] = None, - range_key: Optional[_KeyType] = None, - _user_instantiated: bool = True, - **attributes: Any, + self, + hash_key: Optional[_KeyType] = None, + range_key: Optional[_KeyType] = None, + _user_instantiated: bool = True, + **attributes: Any, ) -> None: """ :param hash_key: Required. The hash key for this object. @@ -300,24 +307,20 @@ def __init__( """ if hash_key is not None: if self._hash_keyname is None: - raise ValueError( - f'This model has no hash key, but a hash key value was provided: {hash_key}' - ) + raise ValueError(f"This model has no hash key, but a hash key value was provided: {hash_key}") attributes[self._hash_keyname] = hash_key if range_key is not None: if self._range_keyname is None: - raise ValueError( - f'This model has no range key, but a range key value was provided: {range_key}' - ) + raise ValueError(f"This model has no range key, but a range key value was provided: {range_key}") attributes[self._range_keyname] = range_key super(Model, self).__init__(_user_instantiated=_user_instantiated, **attributes) @classmethod def batch_get( - cls: Type[_T], - items: Iterable[Union[_KeyType, Iterable[_KeyType]]], - consistent_read: Optional[bool] = None, - attributes_to_get: Optional[Sequence[str]] = None, + cls: Type[_T], + items: Iterable[Union[_KeyType, Iterable[_KeyType]]], + consistent_read: Optional[bool] = None, + attributes_to_get: Optional[Sequence[str]] = None, ) -> Iterator[_T]: """ BatchGetItem for this model @@ -346,27 +349,23 @@ def batch_get( item = items.pop() if range_key_attribute: if isinstance(item, str): - raise ValueError( - f'Invalid key value {item!r}: ' - 'expected non-str iterable with exactly 2 elements (hash key, range key)' - ) + raise ValueError(f'Invalid key value {item!r}: ' + 'expected non-str iterable with exactly 2 elements (hash key, range key)') try: hash_key, range_key = cast(Tuple[_KeyType, _KeyType], item) except (TypeError, ValueError): - raise ValueError( - f'Invalid key value {item!r}: ' - 'expected iterable with exactly 2 elements (hash key, range key)' - ) + raise ValueError(f'Invalid key value {item!r}: ' + 'expected iterable with exactly 2 elements (hash key, range key)') hash_key_ser, range_key_ser = cls._serialize_keys(hash_key, range_key) - keys_to_get.append( - { - hash_key_attribute.attr_name: hash_key_ser, - range_key_attribute.attr_name: range_key_ser, - } - ) + keys_to_get.append({ + hash_key_attribute.attr_name: hash_key_ser, + range_key_attribute.attr_name: range_key_ser, + }) else: hash_key_ser, _ = cls._serialize_keys(item) - keys_to_get.append({hash_key_attribute.attr_name: hash_key_ser}) + keys_to_get.append({ + hash_key_attribute.attr_name: hash_key_ser + }) while keys_to_get: page, unprocessed_keys = cls._batch_get_page( @@ -394,12 +393,7 @@ def batch_write(cls: Type[_T], auto_commit: bool = True) -> BatchWrite[_T]: """ return BatchWrite(cls, auto_commit=auto_commit) - def delete( - self, - condition: Optional[Condition] = None, - *, - add_version_condition: bool = True, - ) -> Any: + def delete(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True) -> Any: """ Deletes this object from DynamoDB. @@ -414,17 +408,9 @@ def delete( if add_version_condition and version_condition is not None: condition &= version_condition - return self._get_connection().delete_item( - hk_value, range_key=rk_value, condition=condition - ) + return self._get_connection().delete_item(hk_value, range_key=rk_value, condition=condition) - def update( - self, - actions: List[Action], - condition: Optional[Condition] = None, - *, - add_version_condition: bool = True, - ) -> Any: + def update(self, actions: List[Action], condition: Optional[Condition] = None, *, add_version_condition: bool = True) -> Any: """ Updates an item using the UpdateItem operation. @@ -434,47 +420,30 @@ def update( :param add_version_condition: For models which have a :class:`~pynamodb.attributes.VersionAttribute`, specifies whether only to update if the version matches the model that is currently loaded. Set to `False` for a 'last write wins' strategy. - Regardless, the version will always be incremented to prevent 'rollbacks' by concurrent :meth:`save` calls. + Regardless, the version will always be incremented to prevent "rollbacks" by concurrent :meth:`save` calls. :raises pynamodb.exceptions.UpdateError: if the `condition` is not met """ if not isinstance(actions, list) or len(actions) == 0: - raise TypeError('the value of `actions` is expected to be a non-empty list') + raise TypeError("the value of `actions` is expected to be a non-empty list") hk_value, rk_value = self._get_hash_range_key_serialized_values() version_condition = self._handle_version_attribute(actions=actions) if add_version_condition and version_condition is not None: condition &= version_condition - data = self._get_connection().update_item( - hk_value, - range_key=rk_value, - return_values=ALL_NEW, - condition=condition, - actions=actions, - ) + data = self._get_connection().update_item(hk_value, range_key=rk_value, return_values=ALL_NEW, condition=condition, actions=actions) item_data = data[ATTRIBUTES] stored_cls = self._get_discriminator_class(item_data) if stored_cls and stored_cls != type(self): - raise ValueError( - 'Cannot update this item from the returned class: {}'.format( - stored_cls.__name__ - ) - ) + raise ValueError("Cannot update this item from the returned class: {}".format(stored_cls.__name__)) self.deserialize(item_data) return data - def save( - self, - condition: Optional[Condition] = None, - *, - add_version_condition: bool = True, - ) -> Dict[str, Any]: + def save(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True) -> Dict[str, Any]: """ Save this object to dynamodb """ - args, kwargs = self._get_save_args( - condition=condition, add_version_condition=add_version_condition - ) + args, kwargs = self._get_save_args(condition=condition, add_version_condition=add_version_condition) data = self._get_connection().put_item(*args, **kwargs) self.update_local_version_attribute() return data @@ -488,28 +457,22 @@ def refresh(self, consistent_read: bool = False) -> None: :raises ModelInstance.DoesNotExist: if the object to be updated does not exist """ hk_value, rk_value = self._get_hash_range_key_serialized_values() - attrs = self._get_connection().get_item( - hk_value, range_key=rk_value, consistent_read=consistent_read - ) + attrs = self._get_connection().get_item(hk_value, range_key=rk_value, consistent_read=consistent_read) item_data = attrs.get(ITEM, None) if item_data is None: - raise self.DoesNotExist('This item does not exist in the table.') + raise self.DoesNotExist("This item does not exist in the table.") stored_cls = self._get_discriminator_class(item_data) if stored_cls and stored_cls != type(self): - raise ValueError( - 'Cannot refresh this item from the returned class: {}'.format( - stored_cls.__name__ - ) - ) + raise ValueError("Cannot refresh this item from the returned class: {}".format(stored_cls.__name__)) self.deserialize(item_data) def get_update_kwargs_from_instance( - self, - actions: List[Action], - condition: Optional[Condition] = None, - return_values_on_condition_failure: Optional[str] = None, - *, - add_version_condition: bool = True, + self, + actions: List[Action], + condition: Optional[Condition] = None, + return_values_on_condition_failure: Optional[str] = None, + *, + add_version_condition: bool = True, ) -> Dict[str, Any]: hk_value, rk_value = self._get_hash_range_key_serialized_values() @@ -517,21 +480,14 @@ def get_update_kwargs_from_instance( if add_version_condition and version_condition is not None: condition &= version_condition - return self._get_connection().get_operation_kwargs( - hk_value, - range_key=rk_value, - key=KEY, - actions=actions, - condition=condition, - return_values_on_condition_failure=return_values_on_condition_failure, - ) + return self._get_connection().get_operation_kwargs(hk_value, range_key=rk_value, key=KEY, actions=actions, condition=condition, return_values_on_condition_failure=return_values_on_condition_failure) def get_delete_kwargs_from_instance( - self, - condition: Optional[Condition] = None, - return_values_on_condition_failure: Optional[str] = None, - *, - add_version_condition: bool = True, + self, + condition: Optional[Condition] = None, + return_values_on_condition_failure: Optional[str] = None, + *, + add_version_condition: bool = True, ) -> Dict[str, Any]: hk_value, rk_value = self._get_hash_range_key_serialized_values() @@ -539,45 +495,39 @@ def get_delete_kwargs_from_instance( if add_version_condition and version_condition is not None: condition &= version_condition - return self._get_connection().get_operation_kwargs( - hk_value, - range_key=rk_value, - key=KEY, - condition=condition, - return_values_on_condition_failure=return_values_on_condition_failure, - ) + return self._get_connection().get_operation_kwargs(hk_value, range_key=rk_value, key=KEY, condition=condition, return_values_on_condition_failure=return_values_on_condition_failure) def get_save_kwargs_from_instance( - self, - condition: Optional[Condition] = None, - return_values_on_condition_failure: Optional[str] = None, + self, + condition: Optional[Condition] = None, + return_values_on_condition_failure: Optional[str] = None, ) -> Dict[str, Any]: args, save_kwargs = self._get_save_args(condition=condition) save_kwargs['key'] = ITEM - save_kwargs['return_values_on_condition_failure'] = ( - return_values_on_condition_failure - ) + save_kwargs['return_values_on_condition_failure'] = return_values_on_condition_failure return self._get_connection().get_operation_kwargs(*args, **save_kwargs) @classmethod def get_operation_kwargs_from_class( - cls, - hash_key: _KeyType, - range_key: Optional[_KeyType] = None, - condition: Optional[Condition] = None, + cls, + hash_key: _KeyType, + range_key: Optional[_KeyType] = None, + condition: Optional[Condition] = None, ) -> Dict[str, Any]: hash_key, range_key = cls._serialize_keys(hash_key, range_key) return cls._get_connection().get_operation_kwargs( - hash_key=cast(str, hash_key), range_key=cast(Optional[str], range_key), condition=condition + hash_key=cast(str, hash_key), + range_key=cast(Optional[str], range_key), + condition=condition ) @classmethod def get( - cls: Type[_T], - hash_key: _KeyType, - range_key: Optional[_KeyType] = None, - consistent_read: bool = False, - attributes_to_get: Optional[Sequence[Text]] = None, + cls: Type[_T], + hash_key: _KeyType, + range_key: Optional[_KeyType] = None, + consistent_read: bool = False, + attributes_to_get: Optional[Sequence[Text]] = None, ) -> _T: """ Returns a single object using the provided keys @@ -611,21 +561,21 @@ def from_raw_data(cls: Type[_T], data: Dict[str, Any]) -> _T: :param data: A serialized DynamoDB object """ if data is None: - raise ValueError('Received no data to construct object') + raise ValueError("Received no data to construct object") return cls._instantiate(data) @classmethod def count( - cls: Type[_T], - hash_key: Optional[_KeyType] = None, - range_key_condition: Optional[Condition] = None, - filter_condition: Optional[Condition] = None, - consistent_read: bool = False, - index_name: Optional[str] = None, - limit: Optional[int] = None, - rate_limit: Optional[float] = None, - hash_keys: Optional[Mapping[str, _KeyType]] = None, + cls: Type[_T], + hash_key: Optional[_KeyType] = None, + range_key_condition: Optional[Condition] = None, + filter_condition: Optional[Condition] = None, + consistent_read: bool = False, + index_name: Optional[str] = None, + limit: Optional[int] = None, + rate_limit: Optional[float] = None, + hash_keys: Optional[Mapping[str, _KeyType]] = None, ) -> int: """ Provides a filtered count @@ -640,9 +590,7 @@ def count( """ if hash_key is None and hash_keys is None: if index_name: - raise ValueError( - 'A hash_key or hash_keys must be given to query an index' - ) + raise ValueError('A hash_key or hash_keys must be given to query an index') if filter_condition is not None: raise ValueError('A hash_key must be given to use filters') return cls.describe_table().get(ITEM_COUNT) @@ -651,9 +599,7 @@ def count( if index_name: index = cls._indexes[index_name] range_key_condition = index._normalize_range_key_condition(range_key_condition) - serialized_hash_key = index._serialize_hash_key_values( - hash_key, hash_keys=hash_keys - ) + serialized_hash_key = index._serialize_hash_key_values(hash_key, hash_keys=hash_keys) if isinstance(serialized_hash_key, dict): serialized_hash_keys = serialized_hash_key hash_key = None @@ -667,9 +613,7 @@ def count( # If this class has a discriminator attribute, filter the query to only return instances of this class. discriminator_attr = cls._get_discriminator_attribute() if discriminator_attr: - filter_condition &= discriminator_attr.is_in( - *discriminator_attr.get_registered_subclasses(cls) - ) + filter_condition &= discriminator_attr.is_in(*discriminator_attr.get_registered_subclasses(cls)) query_args = (hash_key,) query_kwargs = dict( @@ -697,19 +641,19 @@ def count( @classmethod def query( - cls: Type[_T], - hash_key: Optional[_KeyType] = None, - range_key_condition: Optional[Condition] = None, - filter_condition: Optional[Condition] = None, - consistent_read: bool = False, - index_name: Optional[str] = None, - scan_index_forward: Optional[bool] = None, - limit: Optional[int] = None, - last_evaluated_key: Optional[Dict[str, Dict[str, Any]]] = None, - attributes_to_get: Optional[Iterable[str]] = None, - page_size: Optional[int] = None, - rate_limit: Optional[float] = None, - hash_keys: Optional[Mapping[str, _KeyType]] = None, + cls: Type[_T], + hash_key: Optional[_KeyType] = None, + range_key_condition: Optional[Condition] = None, + filter_condition: Optional[Condition] = None, + consistent_read: bool = False, + index_name: Optional[str] = None, + scan_index_forward: Optional[bool] = None, + limit: Optional[int] = None, + last_evaluated_key: Optional[Dict[str, Dict[str, Any]]] = None, + attributes_to_get: Optional[Iterable[str]] = None, + page_size: Optional[int] = None, + rate_limit: Optional[float] = None, + hash_keys: Optional[Mapping[str, _KeyType]] = None, ) -> ResultIterator[_T]: """ Provides a high level query API @@ -735,9 +679,7 @@ def query( if index_name: index = cls._indexes[index_name] range_key_condition = index._normalize_range_key_condition(range_key_condition) - serialized_hash_key = index._serialize_hash_key_values( - hash_key, hash_keys=hash_keys - ) + serialized_hash_key = index._serialize_hash_key_values(hash_key, hash_keys=hash_keys) if isinstance(serialized_hash_key, dict): serialized_hash_keys = serialized_hash_key hash_key = None @@ -751,9 +693,7 @@ def query( # If this class has a discriminator attribute, filter the query to only return instances of this class. discriminator_attr = cls._get_discriminator_attribute() if discriminator_attr: - filter_condition &= discriminator_attr.is_in( - *discriminator_attr.get_registered_subclasses(cls) - ) + filter_condition &= discriminator_attr.is_in(*discriminator_attr.get_registered_subclasses(cls)) if page_size is None: page_size = limit @@ -782,17 +722,17 @@ def query( @classmethod def scan( - cls: Type[_T], - filter_condition: Optional[Condition] = None, - segment: Optional[int] = None, - total_segments: Optional[int] = None, - limit: Optional[int] = None, - last_evaluated_key: Optional[Dict[str, Dict[str, Any]]] = None, - page_size: Optional[int] = None, - consistent_read: Optional[bool] = None, - index_name: Optional[str] = None, - rate_limit: Optional[float] = None, - attributes_to_get: Optional[Sequence[str]] = None, + cls: Type[_T], + filter_condition: Optional[Condition] = None, + segment: Optional[int] = None, + total_segments: Optional[int] = None, + limit: Optional[int] = None, + last_evaluated_key: Optional[Dict[str, Dict[str, Any]]] = None, + page_size: Optional[int] = None, + consistent_read: Optional[bool] = None, + index_name: Optional[str] = None, + rate_limit: Optional[float] = None, + attributes_to_get: Optional[Sequence[str]] = None, ) -> ResultIterator[_T]: """ Iterates through all items in the table @@ -811,9 +751,7 @@ def scan( # If this class has a discriminator attribute, filter the scan to only return instances of this class. discriminator_attr = cls._get_discriminator_attribute() if discriminator_attr: - filter_condition &= discriminator_attr.is_in( - *discriminator_attr.get_registered_subclasses(cls) - ) + filter_condition &= discriminator_attr.is_in(*discriminator_attr.get_registered_subclasses(cls)) if page_size is None: page_size = limit @@ -827,7 +765,7 @@ def scan( total_segments=total_segments, consistent_read=consistent_read, index_name=index_name, - attributes_to_get=attributes_to_get, + attributes_to_get=attributes_to_get ) return ResultIterator( @@ -873,12 +811,12 @@ def describe_table(cls) -> Any: @classmethod def create_table( - cls, - wait: bool = False, - read_capacity_units: Optional[int] = None, - write_capacity_units: Optional[int] = None, - billing_mode: Optional[str] = None, - ignore_update_ttl_errors: bool = False, + cls, + wait: bool = False, + read_capacity_units: Optional[int] = None, + write_capacity_units: Optional[int] = None, + billing_mode: Optional[str] = None, + ignore_update_ttl_errors: bool = False, ) -> Any: """ Create the table for this model @@ -903,7 +841,7 @@ def create_table( if hasattr(cls.Meta, 'stream_view_type'): operation_kwargs['stream_specification'] = { 'stream_enabled': True, - 'stream_view_type': cls.Meta.stream_view_type, + 'stream_view_type': cls.Meta.stream_view_type } if hasattr(cls.Meta, 'billing_mode'): operation_kwargs['billing_mode'] = cls.Meta.billing_mode @@ -915,7 +853,9 @@ def create_table( operation_kwargs['write_capacity_units'] = write_capacity_units if billing_mode is not None: operation_kwargs['billing_mode'] = billing_mode - cls._get_connection().create_table(**operation_kwargs) + cls._get_connection().create_table( + **operation_kwargs + ) if wait: while True: status = cls._get_connection().describe_table() @@ -926,7 +866,7 @@ def create_table( else: time.sleep(2) else: - raise TableError('No TableStatus returned for table') + raise TableError("No TableStatus returned for table") cls.update_ttl(ignore_update_ttl_errors) @@ -944,9 +884,7 @@ def update_ttl(cls, ignore_update_ttl_errors: bool) -> None: cls._get_connection().update_time_to_live(ttl_attribute.attr_name) except Exception: if ignore_update_ttl_errors: - log.info( - 'Unable to update the TTL for {}'.format(cls.Meta.table_name) - ) + log.info("Unable to update the TTL for {}".format(cls.Meta.table_name)) else: raise @@ -965,15 +903,14 @@ def _get_schema(cls) -> ModelSchema: } for attr_name, attr_cls in cls.get_attributes().items(): if attr_cls.is_hash_key or attr_cls.is_range_key: - schema['attribute_definitions'].append( - {ATTR_NAME: attr_cls.attr_name, ATTR_TYPE: attr_cls.attr_type} - ) - schema['key_schema'].append( - { - KEY_TYPE: HASH if attr_cls.is_hash_key else RANGE, - ATTR_NAME: attr_cls.attr_name, - } - ) + schema['attribute_definitions'].append({ + ATTR_NAME: attr_cls.attr_name, + ATTR_TYPE: attr_cls.attr_type + }) + schema['key_schema'].append({ + KEY_TYPE: HASH if attr_cls.is_hash_key else RANGE, + ATTR_NAME: attr_cls.attr_name + }) indexes = cls._indexes.copy() # add indexes from derived classes that we might initialize @@ -987,12 +924,7 @@ def _get_schema(cls) -> ModelSchema: return schema - def _get_save_args( - self, - condition: Optional[Condition] = None, - *, - add_version_condition: bool = True, - ) -> Tuple[Iterable[Any], Dict[str, Any]]: + def _get_save_args(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True) -> Tuple[Iterable[Any], Dict[str, Any]]: """ Gets the proper *args, **kwargs for saving and retrieving this object @@ -1005,16 +937,12 @@ def _get_save_args( """ attribute_values = self.serialize(null_check=True) hash_key_attribute = self._hash_key_attribute() - hash_key = attribute_values.pop(hash_key_attribute.attr_name, {}).get( - hash_key_attribute.attr_type - ) + hash_key = attribute_values.pop(hash_key_attribute.attr_name, {}).get(hash_key_attribute.attr_type) range_key = None range_key_attribute = self._range_key_attribute() if range_key_attribute: - range_key = attribute_values.pop(range_key_attribute.attr_name, {}).get( - range_key_attribute.attr_type - ) - args = (hash_key,) + range_key = attribute_values.pop(range_key_attribute.attr_name, {}).get(range_key_attribute.attr_type) + args = (hash_key, ) kwargs = {} if range_key is not None: kwargs['range_key'] = range_key @@ -1027,7 +955,7 @@ def _get_save_args( def _get_hash_range_key_serialized_values(self) -> Tuple[Any, Optional[Any]]: if self._hash_keyname is None: - raise Exception('The model has no hash key') + raise Exception("The model has no hash key") attrs = self.get_attributes() @@ -1042,12 +970,7 @@ def _get_hash_range_key_serialized_values(self) -> Tuple[Any, Optional[Any]]: return hk_serialized_value, rk_serialized_value - def _handle_version_attribute( - self, - *, - attributes: Optional[Dict[str, Any]] = None, - actions: Optional[List[Action]] = None, - ) -> Optional[Condition]: + def _handle_version_attribute(self, *, attributes: Optional[Dict[str, Any]] = None, actions: Optional[List[Action]] = None) -> Optional[Condition]: """ Handles modifying the request to set or increment the version attribute. """ @@ -1060,17 +983,13 @@ def _handle_version_attribute( if value is not None: condition = version_attribute == value if attributes is not None: - attributes[version_attribute.attr_name] = self._serialize_value( - version_attribute, value + 1 - ) + attributes[version_attribute.attr_name] = self._serialize_value(version_attribute, value + 1) if actions is not None: actions.append(version_attribute.add(1)) else: condition = version_attribute.does_not_exist() if attributes is not None: - attributes[version_attribute.attr_name] = self._serialize_value( - version_attribute, 1 - ) + attributes[version_attribute.attr_name] = self._serialize_value(version_attribute, 1) if actions is not None: actions.append(version_attribute.set(1)) @@ -1135,11 +1054,9 @@ def _batch_get_page(cls, keys_to_get, consistent_read, attributes_to_get): :param consistent_read: Whether or not this needs to be consistent :param attributes_to_get: A list of attributes to return """ - log.debug('Fetching a BatchGetItem page') + log.debug("Fetching a BatchGetItem page") data = cls._get_connection().batch_get_item( - keys_to_get, - consistent_read=consistent_read, - attributes_to_get=attributes_to_get, + keys_to_get, consistent_read=consistent_read, attributes_to_get=attributes_to_get, ) responses = cast(Dict[str, Any], data.get(RESPONSES, {})) item_data = responses.get(cls.Meta.table_name) @@ -1152,63 +1069,57 @@ def _get_connection(cls) -> TableConnection: """ Returns a (cached) connection """ - if not hasattr(cls, 'Meta'): + if not hasattr(cls, "Meta"): raise AttributeError( 'As of v1.0 PynamoDB Models require a `Meta` class.\n' 'Model: {}.{}\n' 'See https://pynamodb.readthedocs.io/en/latest/release_notes.html\n'.format( - cls.__module__, - cls.__name__, + cls.__module__, cls.__name__, ), ) - elif not hasattr(cls.Meta, 'table_name') or cls.Meta.table_name is None: + elif not hasattr(cls.Meta, "table_name") or cls.Meta.table_name is None: raise AttributeError( 'As of v1.0 PynamoDB Models must have a table_name\n' 'Model: {}.{}\n' 'See https://pynamodb.readthedocs.io/en/latest/release_notes.html'.format( - cls.__module__, - cls.__name__, + cls.__module__, cls.__name__, ), ) # For now we just check that the connection exists and (in the case of model inheritance) # points to the same table. In the future we should update the connection if any of the attributes differ. if cls._connection is None or cls._connection.table_name != cls.Meta.table_name: schema = cls._get_schema() - meta_table = MetaTable( - { - constants.TABLE_NAME: cls.Meta.table_name, - constants.KEY_SCHEMA: schema['key_schema'], - constants.ATTR_DEFINITIONS: schema['attribute_definitions'], - constants.GLOBAL_SECONDARY_INDEXES: [ - { - constants.INDEX_NAME: index_schema['index_name'], - constants.KEY_SCHEMA: index_schema['key_schema'], - } - for index_schema in schema['global_secondary_indexes'] - ], - constants.LOCAL_SECONDARY_INDEXES: [ - { - constants.INDEX_NAME: index_schema['index_name'], - constants.KEY_SCHEMA: index_schema['key_schema'], - } - for index_schema in schema['local_secondary_indexes'] - ], - } - ) - cls._connection = TableConnection( - cls.Meta.table_name, - meta_table=meta_table, - region=cls.Meta.region, - host=cls.Meta.host, - connect_timeout_seconds=cls.Meta.connect_timeout_seconds, - read_timeout_seconds=cls.Meta.read_timeout_seconds, - max_retry_attempts=cls.Meta.max_retry_attempts, - max_pool_connections=cls.Meta.max_pool_connections, - extra_headers=cls.Meta.extra_headers, - aws_access_key_id=cls.Meta.aws_access_key_id, - aws_secret_access_key=cls.Meta.aws_secret_access_key, - aws_session_token=cls.Meta.aws_session_token, - ) + meta_table = MetaTable({ + constants.TABLE_NAME: cls.Meta.table_name, + constants.KEY_SCHEMA: schema['key_schema'], + constants.ATTR_DEFINITIONS: schema['attribute_definitions'], + constants.GLOBAL_SECONDARY_INDEXES: [ + { + constants.INDEX_NAME: index_schema['index_name'], + constants.KEY_SCHEMA: index_schema['key_schema'], + } + for index_schema in schema['global_secondary_indexes'] + ], + constants.LOCAL_SECONDARY_INDEXES: [ + { + constants.INDEX_NAME: index_schema['index_name'], + constants.KEY_SCHEMA: index_schema['key_schema'], + } + for index_schema in schema['local_secondary_indexes'] + ], + }) + cls._connection = TableConnection(cls.Meta.table_name, + meta_table=meta_table, + region=cls.Meta.region, + host=cls.Meta.host, + connect_timeout_seconds=cls.Meta.connect_timeout_seconds, + read_timeout_seconds=cls.Meta.read_timeout_seconds, + max_retry_attempts=cls.Meta.max_retry_attempts, + max_pool_connections=cls.Meta.max_pool_connections, + extra_headers=cls.Meta.extra_headers, + aws_access_key_id=cls.Meta.aws_access_key_id, + aws_secret_access_key=cls.Meta.aws_secret_access_key, + aws_session_token=cls.Meta.aws_session_token) return cls._connection @classmethod @@ -1229,11 +1140,7 @@ def _serialize_value(cls, attr, value): return {attr.attr_type: serialized} @classmethod - def _serialize_keys( - cls, - hash_key: _KeyType, - range_key: Optional[_KeyType] = None, - ) -> Tuple[Any, Any]: + def _serialize_keys(cls, hash_key: _KeyType, range_key: Optional[_KeyType] = None) -> Tuple[Any, Any]: """ Serializes the hash and range keys @@ -1273,7 +1180,6 @@ class _ModelFuture(Generic[_T]): For example: when performing a TransactGet request, this is a stand-in for a model that will be returned when the operation is complete """ - def __init__(self, model_cls: Type[_T]) -> None: self._model_cls = model_cls self._model: Optional[_T] = None