Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions docs/indexes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,47 @@ range key of the index. Here is an example that queries the index for values of
print("Item queried from index: {0}".format(item.view))


Composite Global Secondary Keys
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

DynamoDB global secondary indexes support multi-attribute partition and sort keys.
To define a composite key in PynamoDB, declare multiple ``hash_key=True`` and/or
``range_key=True`` attributes on a ``GlobalSecondaryIndex`` in the order you want
them used.

.. code-block:: python

class TournamentRegionIndex(GlobalSecondaryIndex):
class Meta:
projection = AllProjection()
read_capacity_units = 2
write_capacity_units = 1

tournament_id = UnicodeAttribute(hash_key=True)
region = UnicodeAttribute(hash_key=True)
round = UnicodeAttribute(range_key=True)
bracket = UnicodeAttribute(range_key=True)

When querying a composite GSI, pass all partition-key values with ``hash_keys``.
PynamoDB validates the supplied names and sends them in the index declaration
order:

.. code-block:: python

for item in MatchModel.tournament_region_index.query(
hash_keys={
'tournament_id': 'WINTER2024',
'region': 'NA-EAST',
}
):
print(item)

For sort-key conditions, DynamoDB enforces left-to-right semantics across declared
sort-key attributes. If a later sort-key attribute is used, all preceding sort-key
attributes must have equality conditions. See the DynamoDB documentation for details:
https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/GSI.DesignPattern.MultiAttributeKeys.html


Pagination and last evaluated key
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
214 changes: 198 additions & 16 deletions pynamodb/connection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from botocore.session import get_session

from pynamodb.connection._botocore_private import BotocoreBaseClientPrivate
from pynamodb._util import bin_decode_attr
from pynamodb.constants import (
RETURN_CONSUMED_CAPACITY_VALUES, RETURN_ITEM_COLL_METRICS_VALUES,
RETURN_ITEM_COLL_METRICS, RETURN_CONSUMED_CAPACITY, RETURN_VALUES_VALUES,
Expand Down Expand Up @@ -118,12 +117,12 @@ def get_key_names(self, index_name=None):
if self.range_keyname:
key_names.append(self.range_keyname)
if index_name is not None:
index_hash_keyname = self.get_index_hash_keyname(index_name)
if index_hash_keyname not in key_names:
key_names.append(index_hash_keyname)
index_range_keyname = self.get_index_range_keyname(index_name)
if index_range_keyname is not None and index_range_keyname not in key_names:
key_names.append(index_range_keyname)
for index_hash_keyname in self.get_index_hash_keynames(index_name):
if index_hash_keyname not in key_names:
key_names.append(index_hash_keyname)
for index_range_keyname in self.get_index_range_keynames(index_name):
if index_range_keyname not in key_names:
key_names.append(index_range_keyname)
return key_names

def has_index_name(self, index_name):
Expand All @@ -139,6 +138,15 @@ def get_index_hash_keyname(self, index_name: str) -> str:
"""
Returns the name of the hash key for a given index
"""
hash_keynames = self.get_index_hash_keynames(index_name)
if hash_keynames:
return hash_keynames[0]
raise ValueError("No hash key attribute for index: {}".format(index_name))

def get_index_hash_keynames(self, index_name: str) -> List[str]:
"""
Returns the names of the hash keys for a given index
"""
global_indexes = self.data.get(GLOBAL_SECONDARY_INDEXES)
local_indexes = self.data.get(LOCAL_SECONDARY_INDEXES)
indexes = []
Expand All @@ -148,14 +156,26 @@ def get_index_hash_keyname(self, index_name: str) -> str:
indexes += global_indexes
for index in indexes:
if index.get(INDEX_NAME) == index_name:
hash_keynames = []
for schema_key in index.get(KEY_SCHEMA):
if schema_key.get(KEY_TYPE) == HASH:
return schema_key.get(ATTR_NAME)
hash_keynames.append(schema_key.get(ATTR_NAME))
if hash_keynames:
return hash_keynames
raise ValueError("No hash key attribute for index: {}".format(index_name))

def get_index_range_keyname(self, index_name):
"""
Returns the name of the hash key for a given index
Returns the name of the range key for a given index
"""
range_keynames = self.get_index_range_keynames(index_name)
if range_keynames:
return range_keynames[0]
return None

def get_index_range_keynames(self, index_name) -> List[str]:
"""
Returns the names of the range keys for a given index
"""
global_indexes = self.data.get(GLOBAL_SECONDARY_INDEXES)
local_indexes = self.data.get(LOCAL_SECONDARY_INDEXES)
Expand All @@ -166,10 +186,12 @@ def get_index_range_keyname(self, index_name):
indexes += global_indexes
for index in indexes:
if index.get(INDEX_NAME) == index_name:
range_keynames = []
for schema_key in index.get(KEY_SCHEMA):
if schema_key.get(KEY_TYPE) == RANGE:
return schema_key.get(ATTR_NAME)
return None
range_keynames.append(schema_key.get(ATTR_NAME))
return range_keynames
return []

def get_item_attribute_map(self, attributes: Dict, item_key=ITEM, pythonic_key: bool = True):
"""
Expand Down Expand Up @@ -1173,7 +1195,7 @@ def scan(
def query(
self,
table_name: str,
hash_key: str,
hash_key: Optional[Any] = None,
range_key_condition: Optional[Condition] = None,
filter_condition: Optional[Any] = None,
attributes_to_get: Optional[Any] = None,
Expand All @@ -1184,6 +1206,7 @@ def query(
return_consumed_capacity: Optional[str] = None,
scan_index_forward: Optional[bool] = None,
select: Optional[str] = None,
hash_keys: Optional[Mapping[str, Any]] = None,
) -> Dict:
"""
Performs the Query operation and returns the result
Expand All @@ -1201,15 +1224,32 @@ def query(
if index_name:
if not tbl.has_index_name(index_name):
raise ValueError("Table {} has no index: {}".format(table_name, index_name))
hash_keyname = tbl.get_index_hash_keyname(index_name)
hash_keynames = tbl.get_index_hash_keynames(index_name)
range_keynames = tbl.get_index_range_keynames(index_name)
else:
hash_keyname = tbl.hash_keyname
hash_keynames = [tbl.hash_keyname]
range_keynames = [tbl.range_keyname] if tbl.range_keyname else []

hash_condition_value = {self.get_attribute_type(table_name, hash_keyname, hash_key): self.parse_attribute(hash_key)}
key_condition = Path([hash_keyname]) == hash_condition_value
hash_key_values = self._get_query_hash_key_values(
hash_key,
hash_keys,
hash_keynames,
index_name=index_name,
)
range_key_condition = self._normalize_multi_range_key_condition(
range_key_condition,
range_keynames,
index_name=index_name,
)
key_condition = None
for hash_keyname, hash_keyvalue in zip(hash_keynames, hash_key_values):
hash_condition_value = {self.get_attribute_type(table_name, hash_keyname, hash_keyvalue): self.parse_attribute(hash_keyvalue)}
hash_condition = Path([hash_keyname]) == hash_condition_value
key_condition = hash_condition if key_condition is None else key_condition & hash_condition
if range_key_condition is not None:
key_condition &= range_key_condition

assert key_condition is not None
operation_kwargs[KEY_CONDITION_EXPRESSION] = key_condition.serialize(
name_placeholders, expression_attribute_values)
if filter_condition is not None:
Expand Down Expand Up @@ -1252,3 +1292,145 @@ def _check_condition(self, name, condition):
@staticmethod
def _reverse_dict(d):
return {v: k for k, v in d.items()}

@staticmethod
def _get_query_hash_key_values(
hash_key: Optional[Any],
hash_keys: Optional[Mapping[str, Any]],
hash_keynames: Sequence[str],
index_name: Optional[str] = None,
) -> List[Any]:
if hash_key is not None and hash_keys is not None:
raise ValueError(f"Index {index_name} received both hash_key and hash_keys")
if len(hash_keynames) == 1:
if hash_keys is None:
if hash_key is None:
raise ValueError(f"Index {index_name} requires a hash_key")
if isinstance(hash_key, (tuple, list)) or (
isinstance(hash_key, Mapping)
and not any(key in ATTRIBUTE_TYPES for key in hash_key)
):
raise ValueError(
f"Index {index_name} expects a single hash_key value"
)
return [hash_key]
return Connection._get_ordered_query_hash_key_values(
hash_keys, hash_keynames, index_name=index_name
)
if hash_key is not None:
raise ValueError(
f"Index {index_name} has multiple hash key attributes; use hash_keys=..."
)
if hash_keys is None:
raise ValueError(f"Index {index_name} requires hash_keys")
return Connection._get_ordered_query_hash_key_values(
hash_keys, hash_keynames, index_name=index_name
)

@staticmethod
def _get_ordered_query_hash_key_values(
hash_keys: Mapping[str, Any],
hash_keynames: Sequence[str],
index_name: Optional[str] = None,
) -> List[Any]:
if not isinstance(hash_keys, Mapping):
raise ValueError(f"Index {index_name} expects hash_keys to be a mapping")
missing_keys = [
keyname for keyname in hash_keynames if keyname not in hash_keys
]
if missing_keys:
raise ValueError(
f"Index {index_name} requires values for hash keys: {', '.join(missing_keys)}"
)
extra_keys = [keyname for keyname in hash_keys if keyname not in hash_keynames]
if extra_keys:
raise ValueError(
f"Index {index_name} received unknown hash keys: {', '.join(extra_keys)}"
)
return [hash_keys[keyname] for keyname in hash_keynames]

@staticmethod
def _flatten_and_conditions(condition: Condition) -> List[Condition]:
if condition.operator == "AND":
conditions = []
for value in condition.values:
conditions.extend(Connection._flatten_and_conditions(value))
return conditions
return [condition]

@staticmethod
def _condition_key_name(condition: Condition) -> Optional[str]:
path = getattr(condition.values[0], "path", None) if condition.values else None
if not isinstance(path, list) or len(path) != 1:
return None
return path[0]

@staticmethod
def _combine_conditions(conditions: List[Condition]) -> Condition:
combined_condition = conditions[0]
for condition in conditions[1:]:
combined_condition &= condition
return combined_condition

@staticmethod
def _normalize_multi_range_key_condition(
range_key_condition: Optional[Condition],
range_keynames: Sequence[str],
index_name: Optional[str] = None,
) -> Optional[Condition]:
if range_key_condition is None or len(range_keynames) <= 1:
return range_key_condition

valid_operators = {"=", "<", "<=", ">", ">=", "BETWEEN", "begins_with"}
conditions_by_key: Dict[str, Condition] = {}
context = f"Index {index_name}"
for condition in Connection._flatten_and_conditions(range_key_condition):
if condition.operator not in valid_operators:
raise ValueError(
f"{context} range_key_condition uses unsupported range key operator: {condition.operator}"
)
key_name = Connection._condition_key_name(condition)
if key_name is None or key_name not in range_keynames:
raise ValueError(
f"{context} range_key_condition must only use range keys: {', '.join(range_keynames)}"
)
if key_name in conditions_by_key:
raise ValueError(
f"{context} range_key_condition has multiple conditions for range key: {key_name}"
)
conditions_by_key[key_name] = condition

if not conditions_by_key:
return range_key_condition

highest_position = max(
range_keynames.index(key_name) for key_name in conditions_by_key
)
missing_prefix_keys = [
key_name
for key_name in range_keynames[:highest_position]
if key_name not in conditions_by_key
]
if missing_prefix_keys:
raise ValueError(
f"{context} range_key_condition must include equality conditions for preceding range keys: "
f"{', '.join(missing_prefix_keys)}"
)

non_equal_prefix_keys = [
key_name
for key_name in range_keynames[:highest_position]
if conditions_by_key[key_name].operator != "="
]
if non_equal_prefix_keys:
raise ValueError(
f"{context} range_key_condition must use equality for preceding range keys: "
f"{', '.join(non_equal_prefix_keys)}"
)

ordered_conditions = [
conditions_by_key[key_name]
for key_name in range_keynames
if key_name in conditions_by_key
]
return Connection._combine_conditions(ordered_conditions)
9 changes: 6 additions & 3 deletions pynamodb/connection/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def scan(

def query(
self,
hash_key: str,
hash_key: Optional[Any] = None,
range_key_condition: Optional[Condition] = None,
filter_condition: Optional[Any] = None,
attributes_to_get: Optional[Any] = None,
Expand All @@ -249,6 +249,7 @@ def query(
return_consumed_capacity: Optional[str] = None,
scan_index_forward: Optional[bool] = None,
select: Optional[str] = None,
hash_keys: Optional[Mapping[str, Any]] = None,
) -> Dict:
"""
Performs the Query operation and returns the result
Expand All @@ -266,6 +267,7 @@ def query(
return_consumed_capacity=return_consumed_capacity,
scan_index_forward=scan_index_forward,
select=select,
hash_keys=hash_keys,
)

def describe_table(self) -> Dict:
Expand Down Expand Up @@ -299,7 +301,8 @@ def update_table(
self.table_name,
read_capacity_units=read_capacity_units,
write_capacity_units=write_capacity_units,
global_secondary_index_updates=global_secondary_index_updates)
global_secondary_index_updates=global_secondary_index_updates,
)

def create_table(
self,
Expand All @@ -326,5 +329,5 @@ def create_table(
local_secondary_indexes=local_secondary_indexes,
stream_specification=stream_specification,
billing_mode=billing_mode,
tags=tags
tags=tags,
)
Loading