diff --git a/docs/indexes.rst b/docs/indexes.rst index c475a046..b539b3ff 100644 --- a/docs/indexes.rst +++ b/docs/indexes.rst @@ -124,6 +124,47 @@ range key of the index. Here is an example that queries the index for values of print("Item queried from index: {0}".format(item.view)) +Composite Global Secondary Keys +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +DynamoDB global secondary indexes support multi-attribute partition and sort keys. +To define a composite key in PynamoDB, declare multiple ``hash_key=True`` and/or +``range_key=True`` attributes on a ``GlobalSecondaryIndex`` in the order you want +them used. + +.. code-block:: python + + class TournamentRegionIndex(GlobalSecondaryIndex): + class Meta: + projection = AllProjection() + read_capacity_units = 2 + write_capacity_units = 1 + + tournament_id = UnicodeAttribute(hash_key=True) + region = UnicodeAttribute(hash_key=True) + round = UnicodeAttribute(range_key=True) + bracket = UnicodeAttribute(range_key=True) + +When querying a composite GSI, pass all partition-key values with ``hash_keys``. +PynamoDB validates the supplied names and sends them in the index declaration +order: + +.. code-block:: python + + for item in MatchModel.tournament_region_index.query( + hash_keys={ + 'tournament_id': 'WINTER2024', + 'region': 'NA-EAST', + } + ): + print(item) + +For sort-key conditions, DynamoDB enforces left-to-right semantics across declared +sort-key attributes. If a later sort-key attribute is used, all preceding sort-key +attributes must have equality conditions. See the DynamoDB documentation for details: +https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/GSI.DesignPattern.MultiAttributeKeys.html + + Pagination and last evaluated key ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index 301ee171..6a641fdd 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -19,7 +19,6 @@ from botocore.session import get_session from pynamodb.connection._botocore_private import BotocoreBaseClientPrivate -from pynamodb._util import bin_decode_attr from pynamodb.constants import ( RETURN_CONSUMED_CAPACITY_VALUES, RETURN_ITEM_COLL_METRICS_VALUES, RETURN_ITEM_COLL_METRICS, RETURN_CONSUMED_CAPACITY, RETURN_VALUES_VALUES, @@ -118,12 +117,12 @@ def get_key_names(self, index_name=None): if self.range_keyname: key_names.append(self.range_keyname) if index_name is not None: - index_hash_keyname = self.get_index_hash_keyname(index_name) - if index_hash_keyname not in key_names: - key_names.append(index_hash_keyname) - index_range_keyname = self.get_index_range_keyname(index_name) - if index_range_keyname is not None and index_range_keyname not in key_names: - key_names.append(index_range_keyname) + for index_hash_keyname in self.get_index_hash_keynames(index_name): + if index_hash_keyname not in key_names: + key_names.append(index_hash_keyname) + for index_range_keyname in self.get_index_range_keynames(index_name): + if index_range_keyname not in key_names: + key_names.append(index_range_keyname) return key_names def has_index_name(self, index_name): @@ -139,6 +138,15 @@ def get_index_hash_keyname(self, index_name: str) -> str: """ Returns the name of the hash key for a given index """ + hash_keynames = self.get_index_hash_keynames(index_name) + if hash_keynames: + return hash_keynames[0] + raise ValueError("No hash key attribute for index: {}".format(index_name)) + + def get_index_hash_keynames(self, index_name: str) -> List[str]: + """ + Returns the names of the hash keys for a given index + """ global_indexes = self.data.get(GLOBAL_SECONDARY_INDEXES) local_indexes = self.data.get(LOCAL_SECONDARY_INDEXES) indexes = [] @@ -148,14 +156,26 @@ def get_index_hash_keyname(self, index_name: str) -> str: indexes += global_indexes for index in indexes: if index.get(INDEX_NAME) == index_name: + hash_keynames = [] for schema_key in index.get(KEY_SCHEMA): if schema_key.get(KEY_TYPE) == HASH: - return schema_key.get(ATTR_NAME) + hash_keynames.append(schema_key.get(ATTR_NAME)) + if hash_keynames: + return hash_keynames raise ValueError("No hash key attribute for index: {}".format(index_name)) def get_index_range_keyname(self, index_name): """ - Returns the name of the hash key for a given index + Returns the name of the range key for a given index + """ + range_keynames = self.get_index_range_keynames(index_name) + if range_keynames: + return range_keynames[0] + return None + + def get_index_range_keynames(self, index_name) -> List[str]: + """ + Returns the names of the range keys for a given index """ global_indexes = self.data.get(GLOBAL_SECONDARY_INDEXES) local_indexes = self.data.get(LOCAL_SECONDARY_INDEXES) @@ -166,10 +186,12 @@ def get_index_range_keyname(self, index_name): indexes += global_indexes for index in indexes: if index.get(INDEX_NAME) == index_name: + range_keynames = [] for schema_key in index.get(KEY_SCHEMA): if schema_key.get(KEY_TYPE) == RANGE: - return schema_key.get(ATTR_NAME) - return None + range_keynames.append(schema_key.get(ATTR_NAME)) + return range_keynames + return [] def get_item_attribute_map(self, attributes: Dict, item_key=ITEM, pythonic_key: bool = True): """ @@ -1173,7 +1195,7 @@ def scan( def query( self, table_name: str, - hash_key: str, + hash_key: Optional[Any] = None, range_key_condition: Optional[Condition] = None, filter_condition: Optional[Any] = None, attributes_to_get: Optional[Any] = None, @@ -1184,6 +1206,7 @@ def query( return_consumed_capacity: Optional[str] = None, scan_index_forward: Optional[bool] = None, select: Optional[str] = None, + hash_keys: Optional[Mapping[str, Any]] = None, ) -> Dict: """ Performs the Query operation and returns the result @@ -1201,15 +1224,32 @@ def query( if index_name: if not tbl.has_index_name(index_name): raise ValueError("Table {} has no index: {}".format(table_name, index_name)) - hash_keyname = tbl.get_index_hash_keyname(index_name) + hash_keynames = tbl.get_index_hash_keynames(index_name) + range_keynames = tbl.get_index_range_keynames(index_name) else: - hash_keyname = tbl.hash_keyname + hash_keynames = [tbl.hash_keyname] + range_keynames = [tbl.range_keyname] if tbl.range_keyname else [] - hash_condition_value = {self.get_attribute_type(table_name, hash_keyname, hash_key): self.parse_attribute(hash_key)} - key_condition = Path([hash_keyname]) == hash_condition_value + hash_key_values = self._get_query_hash_key_values( + hash_key, + hash_keys, + hash_keynames, + index_name=index_name, + ) + range_key_condition = self._normalize_multi_range_key_condition( + range_key_condition, + range_keynames, + index_name=index_name, + ) + key_condition = None + for hash_keyname, hash_keyvalue in zip(hash_keynames, hash_key_values): + hash_condition_value = {self.get_attribute_type(table_name, hash_keyname, hash_keyvalue): self.parse_attribute(hash_keyvalue)} + hash_condition = Path([hash_keyname]) == hash_condition_value + key_condition = hash_condition if key_condition is None else key_condition & hash_condition if range_key_condition is not None: key_condition &= range_key_condition + assert key_condition is not None operation_kwargs[KEY_CONDITION_EXPRESSION] = key_condition.serialize( name_placeholders, expression_attribute_values) if filter_condition is not None: @@ -1252,3 +1292,145 @@ def _check_condition(self, name, condition): @staticmethod def _reverse_dict(d): return {v: k for k, v in d.items()} + + @staticmethod + def _get_query_hash_key_values( + hash_key: Optional[Any], + hash_keys: Optional[Mapping[str, Any]], + hash_keynames: Sequence[str], + index_name: Optional[str] = None, + ) -> List[Any]: + if hash_key is not None and hash_keys is not None: + raise ValueError(f"Index {index_name} received both hash_key and hash_keys") + if len(hash_keynames) == 1: + if hash_keys is None: + if hash_key is None: + raise ValueError(f"Index {index_name} requires a hash_key") + if isinstance(hash_key, (tuple, list)) or ( + isinstance(hash_key, Mapping) + and not any(key in ATTRIBUTE_TYPES for key in hash_key) + ): + raise ValueError( + f"Index {index_name} expects a single hash_key value" + ) + return [hash_key] + return Connection._get_ordered_query_hash_key_values( + hash_keys, hash_keynames, index_name=index_name + ) + if hash_key is not None: + raise ValueError( + f"Index {index_name} has multiple hash key attributes; use hash_keys=..." + ) + if hash_keys is None: + raise ValueError(f"Index {index_name} requires hash_keys") + return Connection._get_ordered_query_hash_key_values( + hash_keys, hash_keynames, index_name=index_name + ) + + @staticmethod + def _get_ordered_query_hash_key_values( + hash_keys: Mapping[str, Any], + hash_keynames: Sequence[str], + index_name: Optional[str] = None, + ) -> List[Any]: + if not isinstance(hash_keys, Mapping): + raise ValueError(f"Index {index_name} expects hash_keys to be a mapping") + missing_keys = [ + keyname for keyname in hash_keynames if keyname not in hash_keys + ] + if missing_keys: + raise ValueError( + f"Index {index_name} requires values for hash keys: {', '.join(missing_keys)}" + ) + extra_keys = [keyname for keyname in hash_keys if keyname not in hash_keynames] + if extra_keys: + raise ValueError( + f"Index {index_name} received unknown hash keys: {', '.join(extra_keys)}" + ) + return [hash_keys[keyname] for keyname in hash_keynames] + + @staticmethod + def _flatten_and_conditions(condition: Condition) -> List[Condition]: + if condition.operator == "AND": + conditions = [] + for value in condition.values: + conditions.extend(Connection._flatten_and_conditions(value)) + return conditions + return [condition] + + @staticmethod + def _condition_key_name(condition: Condition) -> Optional[str]: + path = getattr(condition.values[0], "path", None) if condition.values else None + if not isinstance(path, list) or len(path) != 1: + return None + return path[0] + + @staticmethod + def _combine_conditions(conditions: List[Condition]) -> Condition: + combined_condition = conditions[0] + for condition in conditions[1:]: + combined_condition &= condition + return combined_condition + + @staticmethod + def _normalize_multi_range_key_condition( + range_key_condition: Optional[Condition], + range_keynames: Sequence[str], + index_name: Optional[str] = None, + ) -> Optional[Condition]: + if range_key_condition is None or len(range_keynames) <= 1: + return range_key_condition + + valid_operators = {"=", "<", "<=", ">", ">=", "BETWEEN", "begins_with"} + conditions_by_key: Dict[str, Condition] = {} + context = f"Index {index_name}" + for condition in Connection._flatten_and_conditions(range_key_condition): + if condition.operator not in valid_operators: + raise ValueError( + f"{context} range_key_condition uses unsupported range key operator: {condition.operator}" + ) + key_name = Connection._condition_key_name(condition) + if key_name is None or key_name not in range_keynames: + raise ValueError( + f"{context} range_key_condition must only use range keys: {', '.join(range_keynames)}" + ) + if key_name in conditions_by_key: + raise ValueError( + f"{context} range_key_condition has multiple conditions for range key: {key_name}" + ) + conditions_by_key[key_name] = condition + + if not conditions_by_key: + return range_key_condition + + highest_position = max( + range_keynames.index(key_name) for key_name in conditions_by_key + ) + missing_prefix_keys = [ + key_name + for key_name in range_keynames[:highest_position] + if key_name not in conditions_by_key + ] + if missing_prefix_keys: + raise ValueError( + f"{context} range_key_condition must include equality conditions for preceding range keys: " + f"{', '.join(missing_prefix_keys)}" + ) + + non_equal_prefix_keys = [ + key_name + for key_name in range_keynames[:highest_position] + if conditions_by_key[key_name].operator != "=" + ] + if non_equal_prefix_keys: + raise ValueError( + f"{context} range_key_condition must use equality for preceding range keys: " + f"{', '.join(non_equal_prefix_keys)}" + ) + + ordered_conditions = [ + conditions_by_key[key_name] + for key_name in range_keynames + if key_name in conditions_by_key + ] + return Connection._combine_conditions(ordered_conditions) diff --git a/pynamodb/connection/table.py b/pynamodb/connection/table.py index 5e70ba5c..02a4cbfd 100644 --- a/pynamodb/connection/table.py +++ b/pynamodb/connection/table.py @@ -238,7 +238,7 @@ def scan( def query( self, - hash_key: str, + hash_key: Optional[Any] = None, range_key_condition: Optional[Condition] = None, filter_condition: Optional[Any] = None, attributes_to_get: Optional[Any] = None, @@ -249,6 +249,7 @@ def query( return_consumed_capacity: Optional[str] = None, scan_index_forward: Optional[bool] = None, select: Optional[str] = None, + hash_keys: Optional[Mapping[str, Any]] = None, ) -> Dict: """ Performs the Query operation and returns the result @@ -266,6 +267,7 @@ def query( return_consumed_capacity=return_consumed_capacity, scan_index_forward=scan_index_forward, select=select, + hash_keys=hash_keys, ) def describe_table(self) -> Dict: @@ -299,7 +301,8 @@ def update_table( self.table_name, read_capacity_units=read_capacity_units, write_capacity_units=write_capacity_units, - global_secondary_index_updates=global_secondary_index_updates) + global_secondary_index_updates=global_secondary_index_updates, + ) def create_table( self, @@ -326,5 +329,5 @@ def create_table( local_secondary_indexes=local_secondary_indexes, stream_specification=stream_specification, billing_mode=billing_mode, - tags=tags + tags=tags, ) diff --git a/pynamodb/indexes.py b/pynamodb/indexes.py index e282b837..04c664a0 100644 --- a/pynamodb/indexes.py +++ b/pynamodb/indexes.py @@ -2,7 +2,7 @@ PynamoDB Indexes """ from inspect import getmembers -from typing import Any, Dict, Generic, List, Optional, Type, TypeVar +from typing import Any, Dict, Generic, List, Mapping, Optional, Type, TypeVar, Union from typing import TYPE_CHECKING from pynamodb._schema import IndexSchema, GlobalSecondaryIndexSchema @@ -20,6 +20,8 @@ from pynamodb.models import Model _KeyType = Any +_HashKeysInputType = Mapping[str, _KeyType] +_SerializedHashKeyType = Union[_KeyType, Dict[str, _KeyType]] _M = TypeVar('_M', bound='Model') @@ -30,13 +32,28 @@ class Index(Generic[_M]): Meta: Any = None _model: _M + @staticmethod + def _get_attributes_in_declaration_order( + index_cls: Type['Index'], + ) -> Dict[str, Attribute]: + """ + Returns attributes in declaration order, respecting overrides. + """ + attributes: Dict[str, Attribute] = {} + for base in reversed(index_cls.__mro__): + for name, attribute in getattr(base, '__dict__', {}).items(): + if name in attributes: + del attributes[name] + if isinstance(attribute, Attribute): + # If a subclass overrides an attribute, preserve the subclass declaration order. + attributes[name] = attribute + return attributes + @classmethod def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) if cls.Meta is not None: - cls.Meta.attributes = {} - for name, attribute in getmembers(cls, lambda o: isinstance(o, Attribute)): - cls.Meta.attributes[name] = attribute + cls.Meta.attributes = cls._get_attributes_in_declaration_order(cls) def __init__(self) -> None: if self.Meta is None: @@ -50,12 +67,13 @@ def __set_name__(self, owner: Type[_M], name: str): def count( self, - hash_key: _KeyType, + hash_key: Optional[_KeyType] = None, range_key_condition: Optional[Condition] = None, filter_condition: Optional[Condition] = None, consistent_read: bool = False, limit: Optional[int] = None, rate_limit: Optional[float] = None, + hash_keys: Optional[_HashKeysInputType] = None, ) -> int: """ Count on an index @@ -68,11 +86,12 @@ def count( consistent_read=consistent_read, limit=limit, rate_limit=rate_limit, + hash_keys=hash_keys, ) def query( self, - hash_key: _KeyType, + hash_key: Optional[_KeyType] = None, range_key_condition: Optional[Condition] = None, filter_condition: Optional[Condition] = None, consistent_read: bool = False, @@ -82,6 +101,7 @@ def query( attributes_to_get: Optional[List[str]] = None, page_size: Optional[int] = None, rate_limit: Optional[float] = None, + hash_keys: Optional[_HashKeysInputType] = None, ) -> ResultIterator[_M]: """ Queries an index @@ -98,6 +118,7 @@ def query( attributes_to_get=attributes_to_get, page_size=page_size, rate_limit=rate_limit, + hash_keys=hash_keys, ) def scan( @@ -133,9 +154,224 @@ def _hash_key_attribute(cls): """ Returns the attribute class for the hash key """ - for attr_cls in cls.Meta.attributes.values(): - if attr_cls.is_hash_key: - return attr_cls + hash_key_attributes = cls._hash_key_attributes() + if hash_key_attributes: + return hash_key_attributes[0] + + @classmethod + def _hash_key_attributes(cls) -> List[Attribute]: + return [attr for attr in cls.Meta.attributes.values() if attr.is_hash_key] + + @classmethod + def _range_key_attributes(cls) -> List[Attribute]: + return [attr for attr in cls.Meta.attributes.values() if attr.is_range_key] + + @classmethod + def _hash_key_aliases( + cls, hash_key_attributes: List[Attribute] + ) -> Dict[str, Attribute]: + aliases: Dict[str, Attribute] = {} + hash_key_attribute_ids = {id(attr) for attr in hash_key_attributes} + for attr_name, attr in cls.Meta.attributes.items(): + if id(attr) in hash_key_attribute_ids: + aliases[attr_name] = attr + aliases[attr.attr_name] = attr + return aliases + + @staticmethod + def _flatten_and_conditions(condition: Condition) -> List[Condition]: + if condition.operator == "AND": + conditions: List[Condition] = [] + for value in condition.values: + conditions.extend(Index._flatten_and_conditions(value)) + return conditions + return [condition] + + @staticmethod + def _condition_key_name(condition: Condition) -> Optional[str]: + path = getattr(condition.values[0], "path", None) if condition.values else None + if not isinstance(path, list) or len(path) != 1: + return None + return path[0] + + @staticmethod + def _combine_conditions(conditions: List[Condition]) -> Condition: + combined_condition = conditions[0] + for condition in conditions[1:]: + combined_condition &= condition + return combined_condition + + @staticmethod + def _normalize_multi_key_condition( + range_key_condition: Condition, + range_keynames: List[str], + context: str, + ) -> Condition: + valid_operators = {"=", "<", "<=", ">", ">=", "BETWEEN", "begins_with"} + conditions_by_key: Dict[str, Condition] = {} + for condition in Index._flatten_and_conditions(range_key_condition): + if condition.operator not in valid_operators: + raise ValueError( + f"{context} range_key_condition uses unsupported range key operator: {condition.operator}" + ) + key_name = Index._condition_key_name(condition) + if key_name is None or key_name not in range_keynames: + raise ValueError(f"{context} range_key_condition must only use range keys: " + ", ".join(range_keynames)) + if key_name in conditions_by_key: + raise ValueError(f"{context} range_key_condition has multiple conditions for range key: {key_name}") + conditions_by_key[key_name] = condition + + if not conditions_by_key: + return range_key_condition + + highest_position = max( + range_keynames.index(key_name) for key_name in conditions_by_key + ) + missing_prefix_keys = [ + key_name + for key_name in range_keynames[:highest_position] + if key_name not in conditions_by_key + ] + if missing_prefix_keys: + raise ValueError( + f"{context} range_key_condition must include equality conditions for preceding range keys: " + + ", ".join(missing_prefix_keys) + ) + + non_equal_prefix_keys = [ + key_name + for key_name in range_keynames[:highest_position] + if conditions_by_key[key_name].operator != "=" + ] + if non_equal_prefix_keys: + raise ValueError( + f"{context} range_key_condition must use equality for preceding range keys: " + + ", ".join(non_equal_prefix_keys) + ) + + ordered_conditions = [ + conditions_by_key[key_name] + for key_name in range_keynames + if key_name in conditions_by_key + ] + return Index._combine_conditions(ordered_conditions) + + @classmethod + def _serialize_hash_key_values( + cls, + hash_key: Optional[_KeyType] = None, + hash_keys: Optional[_HashKeysInputType] = None, + ) -> _SerializedHashKeyType: + hash_key_attributes = cls._hash_key_attributes() + if not hash_key_attributes: + raise ValueError(f"{cls.__name__} has no hash key attributes") + + if hash_key is not None and hash_keys is not None: + raise ValueError(f"{cls.__name__} received both hash_key and hash_keys") + + if len(hash_key_attributes) == 1: + if hash_keys is None: + if hash_key is None: + raise ValueError(f"{cls.__name__} requires a hash_key") + if isinstance(hash_key, (tuple, list)): + raise ValueError(f"{cls.__name__} expects a single hash_key value") + if isinstance(hash_key, Mapping): + raise ValueError(f"{cls.__name__} expects hash_keys=... for named hash key values") + return hash_key_attributes[0].serialize(hash_key) + + hash_key_values = cls._get_ordered_hash_key_values( + hash_keys, hash_key_attributes + ) + return hash_key_attributes[0].serialize(hash_key_values[0]) + + if hash_keys is None: + if hash_key is None: + raise ValueError(f"{cls.__name__} requires hash_keys") + raise ValueError(f"{cls.__name__} has multiple hash key attributes; use hash_keys=...") + + hash_key_values = cls._get_ordered_hash_key_values( + hash_keys, hash_key_attributes + ) + return { + attr.attr_name: attr.serialize(value) + for attr, value in zip(hash_key_attributes, hash_key_values) + } + + @classmethod + def serialize_hash_key_values( + cls, + hash_key: Optional[_KeyType] = None, + hash_keys: Optional[_HashKeysInputType] = None, + ) -> _SerializedHashKeyType: + return cls._serialize_hash_key_values(hash_key, hash_keys=hash_keys) + + @classmethod + def _get_ordered_hash_key_values( + cls, + hash_keys: _HashKeysInputType, + hash_key_attributes: List[Attribute], + ) -> List[_KeyType]: + if not isinstance(hash_keys, Mapping): + raise ValueError(f"{cls.__name__} expects hash_keys to be a mapping") + + expected_aliases = cls._hash_key_aliases(hash_key_attributes) + + values_by_attr_name: Dict[str, _KeyType] = {} + unknown_keys = [] + for key, value in hash_keys.items(): + key_name = key + attr = expected_aliases.get(key_name) + if attr is None: + unknown_keys.append(str(key_name)) + continue + if attr.attr_name in values_by_attr_name: + raise ValueError(f"{cls.__name__} received duplicate value for hash key: {attr.attr_name}") + values_by_attr_name[attr.attr_name] = value + + if unknown_keys: + raise ValueError(f"{cls.__name__} received unknown hash keys: " + ", ".join(unknown_keys)) + + missing_keys = [ + attr.attr_name + for attr in hash_key_attributes + if attr.attr_name not in values_by_attr_name + ] + if missing_keys: + raise ValueError(f"{cls.__name__} requires values for hash keys: " + ", ".join(missing_keys)) + + return [values_by_attr_name[attr.attr_name] for attr in hash_key_attributes] + + @classmethod + def _normalize_range_key_condition( + cls, range_key_condition: Optional[Condition] + ) -> Optional[Condition]: + range_key_attributes = cls._range_key_attributes() + if range_key_condition is None or len(range_key_attributes) <= 1: + return range_key_condition + return cls._normalize_multi_key_condition( + range_key_condition, + [attr.attr_name for attr in range_key_attributes], + cls.__name__, + ) + + @classmethod + def _validate_range_key_condition( + cls, range_key_condition: Optional[Condition] + ) -> None: + cls._normalize_range_key_condition(range_key_condition) + + @classmethod + def validate_range_key_condition( + cls, range_key_condition: Optional[Condition] + ) -> None: + cls._validate_range_key_condition(range_key_condition) + + @classmethod + def _validate_key_attributes(cls) -> None: + """ + Hook for subclasses to validate key constraints. + """ + return None def _update_model_schema(self, schema: ModelSchema) -> None: raise NotImplementedError @@ -154,16 +390,31 @@ def _get_schema(cls) -> IndexSchema: 'attribute_definitions': [], } - for attr_cls in cls.Meta.attributes.values(): - if attr_cls.is_hash_key or attr_cls.is_range_key: - schema['attribute_definitions'].append({ - ATTR_NAME: attr_cls.attr_name, - ATTR_TYPE: attr_cls.attr_type, - }) - schema['key_schema'].append({ - ATTR_NAME: attr_cls.attr_name, - KEY_TYPE: HASH if attr_cls.is_hash_key else RANGE, - }) + cls._validate_key_attributes() + + hash_key_attributes = cls._hash_key_attributes() + range_key_attributes = cls._range_key_attributes() + + for attr_cls in range_key_attributes: + schema['attribute_definitions'].append({ + ATTR_NAME: attr_cls.attr_name, + ATTR_TYPE: attr_cls.attr_type, + }) + for attr_cls in hash_key_attributes: + schema['attribute_definitions'].append({ + ATTR_NAME: attr_cls.attr_name, + ATTR_TYPE: attr_cls.attr_type, + }) + for attr_cls in hash_key_attributes: + schema['key_schema'].append({ + ATTR_NAME: attr_cls.attr_name, + KEY_TYPE: HASH, + }) + for attr_cls in range_key_attributes: + schema['key_schema'].append({ + ATTR_NAME: attr_cls.attr_name, + KEY_TYPE: RANGE, + }) if cls.Meta.projection.non_key_attributes: schema['projection'][NON_KEY_ATTRIBUTES] = cls.Meta.projection.non_key_attributes return schema @@ -173,6 +424,15 @@ class GlobalSecondaryIndex(Index[_M]): """ A global secondary index """ + @classmethod + def _validate_key_attributes(cls) -> None: + hash_keys = cls._hash_key_attributes() + range_keys = cls._range_key_attributes() + if len(hash_keys) > 4: + raise ValueError(f"{cls.__name__} supports at most 4 hash key attributes") + if len(range_keys) > 4: + raise ValueError(f"{cls.__name__} supports at most 4 range key attributes") + @classmethod def _update_model_schema(cls, schema: ModelSchema) -> None: index_schema: GlobalSecondaryIndexSchema = { @@ -198,6 +458,15 @@ class LocalSecondaryIndex(Index[_M]): A local secondary index """ + @classmethod + def _validate_key_attributes(cls) -> None: + hash_keys = cls._hash_key_attributes() + range_keys = cls._range_key_attributes() + if len(hash_keys) > 1: + raise ValueError(f"{cls.__name__} supports at most one hash key attribute") + if len(range_keys) > 1: + raise ValueError(f"{cls.__name__} supports at most one range key attribute") + @classmethod def _update_model_schema(cls, schema: ModelSchema) -> None: index_schema = cls._get_schema() diff --git a/pynamodb/models.py b/pynamodb/models.py index 8e14918e..47561649 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -1,7 +1,6 @@ """ DynamoDB Models for PynamoDB """ -import random import time import logging import warnings @@ -59,8 +58,8 @@ ) _T = TypeVar('_T', bound='Model') -_KeyType = Any - +_KeyType = object +_HashKeyQueryType = Union[_KeyType, Tuple[_KeyType, ...], List[_KeyType]] log = logging.getLogger(__name__) log.addHandler(logging.NullHandler()) @@ -353,7 +352,7 @@ def batch_get( raise ValueError(f'Invalid key value {item!r}: ' 'expected non-str iterable with exactly 2 elements (hash key, range key)') try: - hash_key, range_key = item + hash_key, range_key = cast(Tuple[_KeyType, _KeyType], item) except (TypeError, ValueError): raise ValueError(f'Invalid key value {item!r}: ' 'expected iterable with exactly 2 elements (hash key, range key)') @@ -511,14 +510,14 @@ def get_save_kwargs_from_instance( @classmethod def get_operation_kwargs_from_class( cls, - hash_key: str, + hash_key: _KeyType, range_key: Optional[_KeyType] = None, condition: Optional[Condition] = None, ) -> Dict[str, Any]: hash_key, range_key = cls._serialize_keys(hash_key, range_key) return cls._get_connection().get_operation_kwargs( - hash_key=hash_key, - range_key=range_key, + hash_key=cast(str, hash_key), + range_key=cast(Optional[str], range_key), condition=condition ) @@ -542,8 +541,8 @@ def get( hash_key, range_key = cls._serialize_keys(hash_key, range_key) data = cls._get_connection().get_item( - hash_key, - range_key=range_key, + cast(str, hash_key), + range_key=cast(Optional[str], range_key), consistent_read=consistent_read, attributes_to_get=attributes_to_get, ) @@ -576,25 +575,39 @@ def count( index_name: Optional[str] = None, limit: Optional[int] = None, rate_limit: Optional[float] = None, + hash_keys: Optional[Mapping[str, _KeyType]] = None, ) -> int: """ Provides a filtered count :param hash_key: The hash key to query. Can be None. + :param hash_keys: Named hash key values for indexes with multiple hash key attributes. :param range_key_condition: Condition for range key :param filter_condition: Condition used to restrict the query results :param consistent_read: If True, a consistent read is performed :param index_name: If set, then this index is used :param rate_limit: If set then consumed capacity will be limited to this amount per second """ - if hash_key is None: + if hash_key is None and hash_keys is None: + if index_name: + raise ValueError('A hash_key or hash_keys must be given to query an index') if filter_condition is not None: raise ValueError('A hash_key must be given to use filters') return cls.describe_table().get(ITEM_COUNT) + serialized_hash_keys = None if index_name: - hash_key = cls._indexes[index_name]._hash_key_attribute().serialize(hash_key) + index = cls._indexes[index_name] + range_key_condition = index._normalize_range_key_condition(range_key_condition) + serialized_hash_key = index._serialize_hash_key_values(hash_key, hash_keys=hash_keys) + if isinstance(serialized_hash_key, dict): + serialized_hash_keys = serialized_hash_key + hash_key = None + else: + hash_key = serialized_hash_key else: + if hash_keys is not None: + raise ValueError('hash_keys can only be used with an index') hash_key = cls._serialize_keys(hash_key)[0] # If this class has a discriminator attribute, filter the query to only return instances of this class. @@ -609,7 +622,8 @@ def count( index_name=index_name, consistent_read=consistent_read, limit=limit, - select=COUNT + select=COUNT, + hash_keys=serialized_hash_keys, ) result_iterator: ResultIterator[_T] = ResultIterator( @@ -628,7 +642,7 @@ def count( @classmethod def query( cls: Type[_T], - hash_key: _KeyType, + hash_key: Optional[_KeyType] = None, range_key_condition: Optional[Condition] = None, filter_condition: Optional[Condition] = None, consistent_read: bool = False, @@ -639,11 +653,13 @@ def query( attributes_to_get: Optional[Iterable[str]] = None, page_size: Optional[int] = None, rate_limit: Optional[float] = None, + hash_keys: Optional[Mapping[str, _KeyType]] = None, ) -> ResultIterator[_T]: """ Provides a high level query API - :param hash_key: The hash key to query + :param hash_key: The hash key to query. + :param hash_keys: Named hash key values for indexes with multiple hash key attributes. :param range_key_condition: Condition for range key :param filter_condition: Condition used to restrict the query results :param consistent_read: If True, a consistent read is performed @@ -656,9 +672,22 @@ def query( :param page_size: Page size of the query to DynamoDB :param rate_limit: If set then consumed capacity will be limited to this amount per second """ + if hash_key is None and hash_keys is None: + raise ValueError('A hash_key or hash_keys must be given to query') + + serialized_hash_keys = None if index_name: - hash_key = cls._indexes[index_name]._hash_key_attribute().serialize(hash_key) + index = cls._indexes[index_name] + range_key_condition = index._normalize_range_key_condition(range_key_condition) + serialized_hash_key = index._serialize_hash_key_values(hash_key, hash_keys=hash_keys) + if isinstance(serialized_hash_key, dict): + serialized_hash_keys = serialized_hash_key + hash_key = None + else: + hash_key = serialized_hash_key else: + if hash_keys is not None: + raise ValueError('hash_keys can only be used with an index') hash_key = cls._serialize_keys(hash_key)[0] # If this class has a discriminator attribute, filter the query to only return instances of this class. @@ -679,6 +708,7 @@ def query( scan_index_forward=scan_index_forward, limit=page_size, attributes_to_get=attributes_to_get, + hash_keys=serialized_hash_keys, ) return ResultIterator( @@ -1028,8 +1058,10 @@ def _batch_get_page(cls, keys_to_get, consistent_read, attributes_to_get): data = cls._get_connection().batch_get_item( keys_to_get, consistent_read=consistent_read, attributes_to_get=attributes_to_get, ) - item_data = data.get(RESPONSES).get(cls.Meta.table_name) # type: ignore - unprocessed_items = data.get(UNPROCESSED_KEYS).get(cls.Meta.table_name, {}).get(KEYS, None) # type: ignore + responses = cast(Dict[str, Any], data.get(RESPONSES, {})) + item_data = responses.get(cls.Meta.table_name) + unprocessed_keys = cast(Dict[str, Any], data.get(UNPROCESSED_KEYS, {})) + unprocessed_items = unprocessed_keys.get(cls.Meta.table_name, {}).get(KEYS, None) return item_data, unprocessed_items @classmethod @@ -1108,7 +1140,7 @@ def _serialize_value(cls, attr, value): return {attr.attr_type: serialized} @classmethod - def _serialize_keys(cls, hash_key, range_key=None) -> Tuple[_KeyType, _KeyType]: + def _serialize_keys(cls, hash_key: _KeyType, range_key: Optional[_KeyType] = None) -> Tuple[Any, Any]: """ Serializes the hash and range keys diff --git a/tests/test_base_connection.py b/tests/test_base_connection.py index c90acd27..87f244ce 100644 --- a/tests/test_base_connection.py +++ b/tests/test_base_connection.py @@ -49,6 +49,50 @@ def test_meta_table_get_key_names__index(meta_table): assert key_names == ["ForumName", "Subject", "LastPostDateTime"] +def test_meta_table_get_key_names__composite_index(): + composite_table_data = { + "TableName": "Thread", + "AttributeDefinitions": [ + {"AttributeName": "ForumName", "AttributeType": "S"}, + {"AttributeName": "Subject", "AttributeType": "S"}, + {"AttributeName": "a_partition", "AttributeType": "S"}, + {"AttributeName": "z_partition", "AttributeType": "S"}, + {"AttributeName": "b_sort", "AttributeType": "S"}, + {"AttributeName": "c_sort", "AttributeType": "S"}, + ], + "KeySchema": [ + {"AttributeName": "ForumName", "KeyType": "HASH"}, + {"AttributeName": "Subject", "KeyType": "RANGE"}, + ], + "GlobalSecondaryIndexes": [ + { + "IndexName": "CompositeIndex", + "KeySchema": [ + {"AttributeName": "z_partition", "KeyType": "HASH"}, + {"AttributeName": "a_partition", "KeyType": "HASH"}, + {"AttributeName": "c_sort", "KeyType": "RANGE"}, + {"AttributeName": "b_sort", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "KEYS_ONLY"}, + } + ], + } + meta_table = MetaTable(composite_table_data) + + assert meta_table.get_index_hash_keynames("CompositeIndex") == ["z_partition", "a_partition"] + assert meta_table.get_index_hash_keyname("CompositeIndex") == "z_partition" + assert meta_table.get_index_range_keynames("CompositeIndex") == ["c_sort", "b_sort"] + assert meta_table.get_index_range_keyname("CompositeIndex") == "c_sort" + assert meta_table.get_key_names("CompositeIndex") == [ + "ForumName", + "Subject", + "z_partition", + "a_partition", + "c_sort", + "b_sort", + ] + + def test_meta_table_get_attribute_type(meta_table): assert meta_table.get_attribute_type('ForumName') == 'S' with pytest.raises(ValueError): @@ -219,6 +263,35 @@ def test_connection_create_table(): # Ensure that the hash key is first when creating indexes assert req.call_args[0][1]['GlobalSecondaryIndexes'][0]['KeySchema'][0]['KeyType'] == 'HASH' assert req.call_args[0][1] == params + + kwargs['global_secondary_indexes'] = [ + { + 'index_name': 'composite-index', + 'key_schema': [ + {'AttributeName': 'z_partition', 'KeyType': 'HASH'}, + {'AttributeName': 'a_partition', 'KeyType': 'HASH'}, + {'AttributeName': 'c_sort', 'KeyType': 'RANGE'}, + {'AttributeName': 'b_sort', 'KeyType': 'RANGE'}, + ], + 'projection': { + 'ProjectionType': 'KEYS_ONLY' + }, + 'provisioned_throughput': { + 'ReadCapacityUnits': 1, + 'WriteCapacityUnits': 1, + }, + } + ] + with patch(PATCH_METHOD) as req: + req.return_value = None + conn.create_table(TEST_TABLE_NAME, **kwargs) + assert req.call_args[0][1]['GlobalSecondaryIndexes'][0]['KeySchema'] == [ + {'AttributeName': 'z_partition', 'KeyType': 'HASH'}, + {'AttributeName': 'a_partition', 'KeyType': 'HASH'}, + {'AttributeName': 'c_sort', 'KeyType': 'RANGE'}, + {'AttributeName': 'b_sort', 'KeyType': 'RANGE'}, + ] + del(kwargs['global_secondary_indexes']) del(params['GlobalSecondaryIndexes']) @@ -1142,6 +1215,156 @@ def test_connection_query(): } assert req.call_args[0][1] == params + with patch(PATCH_METHOD) as req: + req.return_value = {} + conn.query( + table_name, + {'S': 'FooForum'}, + Path('Subject').startswith('thread'), + ) + params = { + 'ReturnConsumedCapacity': 'TOTAL', + 'KeyConditionExpression': '(#0 = :0 AND begins_with (#1, :1))', + 'ExpressionAttributeNames': { + '#0': 'ForumName', + '#1': 'Subject' + }, + 'ExpressionAttributeValues': { + ':0': { + 'S': 'FooForum' + }, + ':1': { + 'S': 'thread' + } + }, + 'TableName': 'Thread' + } + assert req.call_args[0][1] == params + + composite_table_name = "ThreadComposite" + composite_table_data = { + "TableName": composite_table_name, + "AttributeDefinitions": [ + {"AttributeName": "ForumName", "AttributeType": "S"}, + {"AttributeName": "Subject", "AttributeType": "S"}, + {"AttributeName": "z_partition", "AttributeType": "S"}, + {"AttributeName": "a_partition", "AttributeType": "S"}, + {"AttributeName": "c_sort", "AttributeType": "S"}, + {"AttributeName": "b_sort", "AttributeType": "S"}, + ], + "KeySchema": [ + {"AttributeName": "ForumName", "KeyType": "HASH"}, + {"AttributeName": "Subject", "KeyType": "RANGE"}, + ], + "GlobalSecondaryIndexes": [ + { + "IndexName": "CompositeIndex", + "KeySchema": [ + {"AttributeName": "z_partition", "KeyType": "HASH"}, + {"AttributeName": "a_partition", "KeyType": "HASH"}, + {"AttributeName": "c_sort", "KeyType": "RANGE"}, + {"AttributeName": "b_sort", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "KEYS_ONLY"}, + } + ], + } + conn.add_meta_table(MetaTable(composite_table_data)) + + with patch(PATCH_METHOD) as req: + req.return_value = {} + conn.query( + composite_table_name, + index_name='CompositeIndex', + hash_keys={'z_partition': 'z1', 'a_partition': 'a1'}, + ) + params = { + 'ReturnConsumedCapacity': 'TOTAL', + 'IndexName': 'CompositeIndex', + 'KeyConditionExpression': '(#0 = :0 AND #1 = :1)', + 'ExpressionAttributeNames': { + '#0': 'z_partition', + '#1': 'a_partition' + }, + 'ExpressionAttributeValues': { + ':0': {'S': 'z1'}, + ':1': {'S': 'a1'} + }, + 'TableName': composite_table_name + } + assert req.call_args[0][1] == params + + with pytest.raises(ValueError, match="multiple hash key attributes; use hash_keys"): + conn.query(composite_table_name, "z1", index_name='CompositeIndex') + + with patch(PATCH_METHOD) as req: + req.return_value = {} + conn.query( + composite_table_name, + index_name='CompositeIndex', + hash_keys={'a_partition': 'a1', 'z_partition': 'z1'}, + ) + params = { + 'ReturnConsumedCapacity': 'TOTAL', + 'IndexName': 'CompositeIndex', + 'KeyConditionExpression': '(#0 = :0 AND #1 = :1)', + 'ExpressionAttributeNames': { + '#0': 'z_partition', + '#1': 'a_partition' + }, + 'ExpressionAttributeValues': { + ':0': {'S': 'z1'}, + ':1': {'S': 'a1'} + }, + 'TableName': composite_table_name + } + assert req.call_args[0][1] == params + + with pytest.raises(ValueError, match="requires values for hash keys: a_partition"): + conn.query( + composite_table_name, + index_name='CompositeIndex', + hash_keys={'z_partition': 'z1'}, + ) + + with pytest.raises(ValueError, match="received unknown hash keys: unknown"): + conn.query( + composite_table_name, + index_name='CompositeIndex', + hash_keys={'z_partition': 'z1', 'a_partition': 'a1', 'unknown': 'u1'}, + ) + + with patch(PATCH_METHOD) as req: + req.return_value = {} + conn.query( + composite_table_name, + index_name='CompositeIndex', + hash_keys={'z_partition': 'z1', 'a_partition': 'a1'}, + range_key_condition=( + (Path('b_sort') == 'b1') & + (Path('c_sort') == 'c1') + ), + ) + params = { + 'ReturnConsumedCapacity': 'TOTAL', + 'IndexName': 'CompositeIndex', + 'KeyConditionExpression': '((#0 = :0 AND #1 = :1) AND (#2 = :2 AND #3 = :3))', + 'ExpressionAttributeNames': { + '#0': 'z_partition', + '#1': 'a_partition', + '#2': 'c_sort', + '#3': 'b_sort' + }, + 'ExpressionAttributeValues': { + ':0': {'S': 'z1'}, + ':1': {'S': 'a1'}, + ':2': {'S': 'c1'}, + ':3': {'S': 'b1'} + }, + 'TableName': composite_table_name + } + assert req.call_args[0][1] == params + with patch(PATCH_METHOD) as req: req.return_value = {} conn.query( diff --git a/tests/test_model.py b/tests/test_model.py index da54303b..25ecc724 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -169,6 +169,72 @@ class Meta: icons = BinarySetAttribute(legacy_encoding=False) +class CompositeOrderIndex(GlobalSecondaryIndex): + class Meta: + index_name = 'composite_order_idx' + read_capacity_units = 2 + write_capacity_units = 1 + projection = AllProjection() + + z_partition = UnicodeAttribute(hash_key=True) + a_partition = UnicodeAttribute(hash_key=True) + c_sort = UnicodeAttribute(range_key=True) + b_sort = UnicodeAttribute(range_key=True) + + +class CompositeIndexedModel(Model): + class Meta: + table_name = 'CompositeIndexedModel' + + item_id = UnicodeAttribute(hash_key=True) + z_partition = UnicodeAttribute() + a_partition = UnicodeAttribute() + c_sort = UnicodeAttribute() + b_sort = UnicodeAttribute() + composite_index = CompositeOrderIndex() + + +class KeylessScanIndex(GlobalSecondaryIndex): + class Meta: + index_name = 'keyless_scan_idx' + read_capacity_units = 2 + write_capacity_units = 1 + projection = AllProjection() + + +class KeylessScanIndexedModel(Model): + class Meta: + table_name = 'KeylessScanIndexedModel' + + item_id = UnicodeAttribute(hash_key=True) + keyless_scan_index = KeylessScanIndex() + + +class ThreeRangeKeyIndex(GlobalSecondaryIndex): + class Meta: + index_name = 'three_range_idx' + read_capacity_units = 2 + write_capacity_units = 1 + projection = AllProjection() + + h = UnicodeAttribute(hash_key=True) + r1 = UnicodeAttribute(range_key=True) + r2 = UnicodeAttribute(range_key=True) + r3 = UnicodeAttribute(range_key=True) + + +class ThreeRangeKeyModel(Model): + class Meta: + table_name = 'ThreeRangeKeyModel' + + item_id = UnicodeAttribute(hash_key=True) + h = UnicodeAttribute() + r1 = UnicodeAttribute() + r2 = UnicodeAttribute() + r3 = UnicodeAttribute() + three_range_index = ThreeRangeKeyIndex() + + class SimpleUserModel(Model): """ A hash key only model @@ -1083,6 +1149,24 @@ def test_index_count(self): } deep_eq(args, params, _assert=True) + def test_index_count_composite_hash_key(self): + with patch(PATCH_METHOD) as req: + req.return_value = {'Count': 1, 'ScannedCount': 1} + res = CompositeIndexedModel.composite_index.count( + hash_keys={'z_partition': 'p1', 'a_partition': 'p2'} + ) + self.assertEqual(res, 1) + params = req.call_args[0][1] + self.assertEqual(params['KeyConditionExpression'], '(#0 = :0 AND #1 = :1)') + self.assertEqual(params['ExpressionAttributeNames'], { + '#0': 'z_partition', + '#1': 'a_partition', + }) + self.assertEqual(params['ExpressionAttributeValues'], { + ':0': {'S': 'p1'}, + ':1': {'S': 'p2'}, + }) + def test_index_multipage_count(self): with patch(PATCH_METHOD) as req: last_evaluated_key = { @@ -2301,6 +2385,245 @@ def fake_dynamodb(*args, **kwargs): } ) + def test_global_index_composite_keys(self): + scope_args = {'count': 0} + + def fake_dynamodb(*args, **kwargs): + if scope_args['count'] == 0: + scope_args['count'] += 1 + raise ClientError({'Error': {'Code': 'ResourceNotFoundException', 'Message': 'Not Found'}}, + "DescribeTable") + return {} + + fake_db = MagicMock() + fake_db.side_effect = fake_dynamodb + + with patch(PATCH_METHOD, new=fake_db) as req: + CompositeIndexedModel.create_table(read_capacity_units=2, write_capacity_units=2) + args = req.call_args[0][1] + self.assertEqual( + args['GlobalSecondaryIndexes'][0]['KeySchema'], + [ + {'AttributeName': 'z_partition', 'KeyType': 'HASH'}, + {'AttributeName': 'a_partition', 'KeyType': 'HASH'}, + {'AttributeName': 'c_sort', 'KeyType': 'RANGE'}, + {'AttributeName': 'b_sort', 'KeyType': 'RANGE'}, + ] + ) + + def test_global_index_composite_query(self): + with patch(PATCH_METHOD) as req: + req.return_value = {'Count': 0, 'ScannedCount': 0, 'Items': []} + list(CompositeIndexedModel.composite_index.query( + hash_keys={'z_partition': 'p1', 'a_partition': 'p2'}, + range_key_condition=CompositeIndexedModel.c_sort == 's1', + )) + + params = req.call_args[0][1] + self.assertEqual(params['IndexName'], 'composite_order_idx') + self.assertEqual(params['TableName'], 'CompositeIndexedModel') + self.assertEqual(params['ExpressionAttributeValues'], { + ':0': {'S': 'p1'}, + ':1': {'S': 'p2'}, + ':2': {'S': 's1'}, + }) + self.assertEqual(params['ExpressionAttributeNames'], { + '#0': 'z_partition', + '#1': 'a_partition', + '#2': 'c_sort', + }) + self.assertEqual(params['KeyConditionExpression'], '((#0 = :0 AND #1 = :1) AND #2 = :2)') + + def test_global_index_composite_query_unordered_hash_keys(self): + with patch(PATCH_METHOD) as req: + req.return_value = {'Count': 0, 'ScannedCount': 0, 'Items': []} + list(CompositeIndexedModel.composite_index.query( + hash_keys={'a_partition': 'p2', 'z_partition': 'p1'} + )) + + params = req.call_args[0][1] + self.assertEqual(params['ExpressionAttributeValues'], { + ':0': {'S': 'p1'}, + ':1': {'S': 'p2'}, + }) + self.assertEqual(params['KeyConditionExpression'], '(#0 = :0 AND #1 = :1)') + + def test_global_index_composite_query_hash_key_validation(self): + with pytest.raises(ValueError, match='multiple hash key attributes; use hash_keys'): + CompositeIndexedModel.composite_index.query('p1') + + with pytest.raises(ValueError, match='multiple hash key attributes; use hash_keys'): + CompositeIndexedModel.composite_index.query(('p1',)) + + with pytest.raises(ValueError, match='requires values for hash keys: a_partition'): + CompositeIndexedModel.composite_index.query(hash_keys={'z_partition': 'p1'}) + + with pytest.raises(ValueError, match='received unknown hash keys: unknown'): + CompositeIndexedModel.composite_index.query(hash_keys={ + 'z_partition': 'p1', + 'a_partition': 'p2', + 'unknown': 'p3', + }) + + def test_global_index_public_key_validation_helpers(self): + CompositeIndexedModel.composite_index.validate_range_key_condition( + CompositeIndexedModel.c_sort == 's1' + ) + with pytest.raises(ValueError, match='preceding range keys: c_sort'): + CompositeIndexedModel.composite_index.validate_range_key_condition( + CompositeIndexedModel.b_sort == 's2' + ) + + self.assertEqual( + CompositeIndexedModel.composite_index.serialize_hash_key_values( + hash_keys={'z_partition': 'p1', 'a_partition': 'p2'} + ), + {'z_partition': 'p1', 'a_partition': 'p2'}, + ) + + def test_global_index_scan_allows_keyless_index_definition(self): + """ + Index scans do not need key schema metadata, so scan-only index definitions + should remain backwards-compatible. + """ + with patch(PATCH_METHOD) as req: + req.return_value = { + 'Count': 1, + 'ScannedCount': 1, + 'Items': [{'item_id': {STRING: 'item-1'}}], + } + + items = list(KeylessScanIndexedModel.keyless_scan_index.scan()) + + self.assertEqual([item.item_id for item in items], ['item-1']) + self.assertEqual(req.call_args[0][1]['IndexName'], 'keyless_scan_idx') + + def test_global_index_query_rejects_keyless_index_definition(self): + with pytest.raises(ValueError, match='has no hash key attributes'): + KeylessScanIndexedModel.keyless_scan_index.query('item-1') + + def test_global_index_composite_query_range_key_validation(self): + with pytest.raises(ValueError, match='preceding range keys: c_sort'): + CompositeIndexedModel.composite_index.query( + hash_keys={'z_partition': 'p1', 'a_partition': 'p2'}, + range_key_condition=CompositeIndexedModel.b_sort == 's1', + ) + + with pytest.raises(ValueError, match='use equality for preceding range keys: c_sort'): + CompositeIndexedModel.composite_index.query( + hash_keys={'z_partition': 'p1', 'a_partition': 'p2'}, + range_key_condition=( + (CompositeIndexedModel.c_sort > 's1') & + (CompositeIndexedModel.b_sort == 's2') + ), + ) + + with patch(PATCH_METHOD) as req: + req.return_value = {'Count': 0, 'ScannedCount': 0, 'Items': []} + list(CompositeIndexedModel.composite_index.query( + hash_keys={'z_partition': 'p1', 'a_partition': 'p2'}, + range_key_condition=( + (CompositeIndexedModel.c_sort == 's1') & + CompositeIndexedModel.b_sort.startswith('s2') + ), + )) + self.assertEqual( + req.call_args[0][1]['KeyConditionExpression'], + '((#0 = :0 AND #1 = :1) AND (#2 = :2 AND begins_with (#3, :3)))' + ) + + def test_global_index_composite_query_normalizes_out_of_order_range_keys(self): + with patch(PATCH_METHOD) as req: + req.return_value = {'Count': 0, 'ScannedCount': 0, 'Items': []} + list(CompositeIndexedModel.composite_index.query( + hash_keys={'z_partition': 'p1', 'a_partition': 'p2'}, + range_key_condition=( + (CompositeIndexedModel.b_sort == 's2') & + (CompositeIndexedModel.c_sort == 's1') + ), + )) + self.assertEqual( + req.call_args[0][1]['KeyConditionExpression'], + '((#0 = :0 AND #1 = :1) AND (#2 = :2 AND #3 = :3))' + ) + self.assertEqual(req.call_args[0][1]['ExpressionAttributeNames'], { + '#0': 'z_partition', + '#1': 'a_partition', + '#2': 'c_sort', + '#3': 'b_sort', + }) + self.assertEqual(req.call_args[0][1]['ExpressionAttributeValues'], { + ':0': {'S': 'p1'}, + ':1': {'S': 'p2'}, + ':2': {'S': 's1'}, + ':3': {'S': 's2'}, + }) + + def test_global_index_three_range_key_validation(self): + with pytest.raises(ValueError, match='preceding range keys: r1, r2'): + ThreeRangeKeyModel.three_range_index.query( + 'h1', + range_key_condition=ThreeRangeKeyModel.r3 == 'r3', + ) + + with patch(PATCH_METHOD) as req: + req.return_value = {'Count': 0, 'ScannedCount': 0, 'Items': []} + list(ThreeRangeKeyModel.three_range_index.query( + 'h1', + range_key_condition=( + (ThreeRangeKeyModel.r1 == 'r1') & + (ThreeRangeKeyModel.r2 == 'r2') & + (ThreeRangeKeyModel.r3 >= 'r3') + ), + )) + self.assertEqual( + req.call_args[0][1]['KeyConditionExpression'], + '(#0 = :0 AND ((#1 = :1 AND #2 = :2) AND #3 >= :3))' + ) + + def test_global_index_query_hash_only_does_not_error(self): + """ + Non-composite GSI queries should work with hash key only. + """ + with patch(PATCH_METHOD) as req: + req.return_value = {'Count': 0, 'ScannedCount': 0, 'Items': []} + list(IndexedModel.email_index.query('foo')) + + params = req.call_args[0][1] + self.assertEqual(params['IndexName'], 'custom_idx_name') + self.assertEqual(params['TableName'], 'IndexedModel') + self.assertEqual(params['KeyConditionExpression'], '#0 = :0') + self.assertEqual(params['ExpressionAttributeNames'], {'#0': 'email'}) + self.assertEqual(params['ExpressionAttributeValues'], {':0': {'S': 'foo'}}) + + def test_global_index_composite_query_last_evaluated_key(self): + with patch(PATCH_METHOD) as req: + items = [] + for idx in range(30): + items.append({ + 'item_id': {STRING: f'id-{idx}'}, + 'z_partition': {STRING: 'z1'}, + 'a_partition': {STRING: 'a1'}, + 'c_sort': {STRING: f'c-{idx}'}, + 'b_sort': {STRING: f'b-{idx}'}, + }) + + req.return_value = {'Count': len(items), 'ScannedCount': len(items), 'Items': items} + results_iter = CompositeIndexedModel.composite_index.query( + hash_keys={'z_partition': 'z1', 'a_partition': 'a1'}, + limit=25, + ) + results = list(results_iter) + + self.assertEqual(len(results), 25) + self.assertEqual(results_iter.last_evaluated_key, { + 'item_id': items[24]['item_id'], + 'z_partition': items[24]['z_partition'], + 'a_partition': items[24]['a_partition'], + 'c_sort': items[24]['c_sort'], + 'b_sort': items[24]['b_sort'], + }) + def test_local_index(self): """ Models.LocalSecondaryIndex @@ -2357,6 +2680,57 @@ def fake_dynamodb(*args, **kwargs): ) self.assertTrue('ProvisionedThroughput' not in args['LocalSecondaryIndexes'][0]) + def test_local_index_reject_multiple_hash_or_range_keys(self): + class BadLocalHashIndex(LocalSecondaryIndex): + class Meta: + index_name = 'bad_local_hash_idx' + projection = AllProjection() + + h1 = UnicodeAttribute(hash_key=True) + h2 = UnicodeAttribute(hash_key=True) + + class BadLocalRangeIndex(LocalSecondaryIndex): + class Meta: + index_name = 'bad_local_range_idx' + projection = AllProjection() + + h = UnicodeAttribute(hash_key=True) + r1 = UnicodeAttribute(range_key=True) + r2 = UnicodeAttribute(range_key=True) + + with pytest.raises(ValueError, match='at most one hash key'): + BadLocalHashIndex._get_schema() + with pytest.raises(ValueError, match='at most one range key'): + BadLocalRangeIndex._get_schema() + + def test_index_attribute_shadowing_removes_inherited_key(self): + class BaseShadowIndex(GlobalSecondaryIndex): + class Meta: + index_name = 'base_shadow_idx' + projection = AllProjection() + + h = UnicodeAttribute(hash_key=True) + r = UnicodeAttribute(range_key=True) + + class ChildShadowIndex(BaseShadowIndex): + class Meta(BaseShadowIndex.Meta): + index_name = 'child_shadow_idx' + + h = None + h2 = UnicodeAttribute(hash_key=True) + + self.assertEqual( + [attr.attr_name for attr in ChildShadowIndex._hash_key_attributes()], + ['h2'], + ) + self.assertEqual( + ChildShadowIndex._get_schema()['key_schema'], + [ + {'AttributeName': 'h2', 'KeyType': 'HASH'}, + {'AttributeName': 'r', 'KeyType': 'RANGE'}, + ], + ) + def test_projections(self): """ Models.Projection