Skip to content
3 changes: 3 additions & 0 deletions changelog/1063.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
`RelatedNode`, `RelatedNodeSync`, `RelationshipManager` and `RelationshipManagerSync` are now generic over their peer type, and `infrahubctl protocols` parameterises generated relationships accordingly (e.g. `device: RelatedNode[NetworkDevice]`, `interfaces: RelationshipManager[NetworkInterface]`).

Traversing a relationship via `.peer`, `.peers` or indexing now preserves the peer's type instead of collapsing to the dynamic `InfrahubNode`, so chains such as `device.rack.peer.name.value` type-check under mypy/ty without casts. Existing un-parameterised `RelatedNode` / `RelationshipManager` usage is unaffected — the peer type defaults to `InfrahubNode` / `InfrahubNodeSync`, preserving current behaviour. ([#1063](https://github.com/opsmill/infrahub-sdk-python/issues/1063))
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,15 @@ fetch(self, timeout: int | None = None) -> None
#### `peer`

```python
peer(self) -> InfrahubNode
peer(self) -> PeerT
```

Return the peer node, or raise ValueError if no identifier is available.

#### `get`

```python
get(self) -> InfrahubNode
get(self) -> PeerT
```

Return the peer node, performing a store lookup if not materialized.
Expand Down Expand Up @@ -135,15 +135,15 @@ fetch(self, timeout: int | None = None) -> None
#### `peer`

```python
peer(self) -> InfrahubNodeSync
peer(self) -> PeerTSync
```

Return the peer node, or raise ValueError if no identifier is available.

#### `get`

```python
get(self) -> InfrahubNodeSync
get(self) -> PeerTSync
```

Return the peer node, performing a store lookup if not materialized.
Expand Down
35 changes: 22 additions & 13 deletions infrahub_sdk/node/related_node.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Generic, cast

from typing_extensions import TypeVar

from ..exceptions import Error
from ..protocols_base import CoreNodeBase
Expand All @@ -13,6 +15,13 @@
from ..schema import RelationshipSchemaAPI
from .node import InfrahubNode, InfrahubNodeBase, InfrahubNodeSync

# Type of the related peer node. Defaults to ``InfrahubNode``/``InfrahubNodeSync`` so that
# existing un-parameterised ``RelatedNode`` / ``RelatedNodeSync`` usage keeps returning the
# dynamic node, while generated protocols can parameterise it (e.g. ``RelatedNode[CoreDevice]``)
# to preserve the peer type through ``.peer`` / ``.get()``.
PeerT = TypeVar("PeerT", default="InfrahubNode")
PeerTSync = TypeVar("PeerTSync", default="InfrahubNodeSync")


class RelatedNodeBase:
"""Base class for representing a related node in a relationship."""
Expand Down Expand Up @@ -221,7 +230,7 @@ def _generate_query_data(
return data


class RelatedNode(RelatedNodeBase):
class RelatedNode(RelatedNodeBase, Generic[PeerT]):
"""Represents a RelatedNodeBase in an asynchronous context."""

def __init__(
Expand Down Expand Up @@ -254,11 +263,11 @@ async def fetch(self, timeout: int | None = None) -> None:
)

@property
def peer(self) -> InfrahubNode:
def peer(self) -> PeerT:
"""Return the peer node, or raise ValueError if no identifier is available."""
return self.get()

def get(self) -> InfrahubNode:
def get(self) -> PeerT:
"""Return the peer node, performing a store lookup if not materialized.

When resolving via hfid_str the returned node has a non-None id even
Expand All @@ -271,18 +280,18 @@ def get(self) -> InfrahubNode:

"""
if self._peer:
return self._peer # type: ignore[return-value]
return cast("PeerT", self._peer)

if self.id and self.typename:
return self._client.store.get(key=self.id, kind=self.typename, branch=self._branch) # type: ignore[return-value]
return cast("PeerT", self._client.store.get(key=self.id, kind=self.typename, branch=self._branch))

if self.hfid_str:
return self._client.store.get(key=self.hfid_str, branch=self._branch) # type: ignore[return-value]
return cast("PeerT", self._client.store.get(key=self.hfid_str, branch=self._branch))

raise ValueError("Node must have at least one identifier (ID or HFID) to query it.")


class RelatedNodeSync(RelatedNodeBase):
class RelatedNodeSync(RelatedNodeBase, Generic[PeerTSync]):
"""Represents a related node in a synchronous context."""

def __init__(
Expand Down Expand Up @@ -315,11 +324,11 @@ def fetch(self, timeout: int | None = None) -> None:
)

@property
def peer(self) -> InfrahubNodeSync:
def peer(self) -> PeerTSync:
"""Return the peer node, or raise ValueError if no identifier is available."""
return self.get()

def get(self) -> InfrahubNodeSync:
def get(self) -> PeerTSync:
"""Return the peer node, performing a store lookup if not materialized.

When resolving via hfid_str the returned node has a non-None id even
Expand All @@ -332,12 +341,12 @@ def get(self) -> InfrahubNodeSync:

"""
if self._peer:
return self._peer # type: ignore[return-value]
return cast("PeerTSync", self._peer)

if self.id and self.typename:
return self._client.store.get(key=self.id, kind=self.typename, branch=self._branch) # type: ignore[return-value]
return cast("PeerTSync", self._client.store.get(key=self.id, kind=self.typename, branch=self._branch))

if self.hfid_str:
return self._client.store.get(key=self.hfid_str, branch=self._branch) # type: ignore[return-value]
return cast("PeerTSync", self._client.store.get(key=self.hfid_str, branch=self._branch))

raise ValueError("Node must have at least one identifier (ID or HFID) to query it.")
49 changes: 33 additions & 16 deletions infrahub_sdk/node/relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections import defaultdict
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Generic, cast

from ..exceptions import (
Error,
Expand All @@ -11,15 +11,15 @@
from ..types import Order
from .constants import PROPERTIES_FLAG, PROPERTIES_OBJECT
from .metadata import NodeMetadata, RelationshipMetadata
from .related_node import RelatedNode, RelatedNodeSync
from .related_node import PeerT, PeerTSync, RelatedNode, RelatedNodeSync

if TYPE_CHECKING:
from ..client import InfrahubClient, InfrahubClientSync
from ..schema import RelationshipSchemaAPI
from .node import InfrahubNode, InfrahubNodeSync


class RelationshipManagerBase:
class RelationshipManagerBase(Generic[PeerT]):
"""Base class for RelationshipManager and RelationshipManagerSync."""

def __init__(self, name: str, branch: str, schema: RelationshipSchemaAPI) -> None:
Expand All @@ -41,7 +41,7 @@ def __init__(self, name: str, branch: str, schema: RelationshipSchemaAPI) -> Non
self._properties_object = PROPERTIES_OBJECT
self._properties = self._properties_flag + self._properties_object

self.peers: list[RelatedNode | RelatedNodeSync] = []
self.peers: list[RelatedNode[PeerT] | RelatedNodeSync[PeerT]] = []

@property
def peer_ids(self) -> list[str]:
Expand Down Expand Up @@ -115,7 +115,7 @@ def _generate_query_data(
return data


class RelationshipManager(RelationshipManagerBase):
class RelationshipManager(RelationshipManagerBase[PeerT]):
"""Manages relationships of a node in an asynchronous context."""

def __init__(
Expand Down Expand Up @@ -155,12 +155,18 @@ def __init__(
if isinstance(data, list):
for item in data:
self.peers.append(
RelatedNode(name=name, client=self.client, branch=self.branch, schema=schema, data=item)
cast(
"RelatedNode[PeerT]",
RelatedNode(name=name, client=self.client, branch=self.branch, schema=schema, data=item),
)
)
elif isinstance(data, dict) and "edges" in data:
for item in data["edges"]:
self.peers.append(
RelatedNode(name=name, client=self.client, branch=self.branch, schema=schema, data=item)
cast(
"RelatedNode[PeerT]",
RelatedNode(name=name, client=self.client, branch=self.branch, schema=schema, data=item),
)
)
else:
raise ValueError(
Expand All @@ -169,8 +175,8 @@ def __init__(
f"Wrap the value in a list, e.g. {name}=[value]."
)

def __getitem__(self, item: int) -> RelatedNode:
return self.peers[item] # type: ignore[return-value]
def __getitem__(self, item: int) -> RelatedNode[PeerT]:
return cast("RelatedNode[PeerT]", self.peers[item])

async def fetch(self) -> None:
if not self.initialized:
Expand Down Expand Up @@ -217,7 +223,9 @@ def add(self, data: str | RelatedNode | dict) -> None:
"""
if not self.initialized:
raise UninitializedError("Must call fetch() on RelationshipManager before editing members")
new_node = RelatedNode(schema=self.schema, client=self.client, branch=self.branch, data=data)
new_node = cast(
"RelatedNode[PeerT]", RelatedNode(schema=self.schema, client=self.client, branch=self.branch, data=data)
)

if (new_node.id and new_node.id not in self.peer_ids) or (
new_node.hfid and new_node.hfid not in self.peer_hfids
Expand Down Expand Up @@ -252,7 +260,7 @@ def remove(self, data: str | RelatedNode | dict) -> None:
self._has_update = True


class RelationshipManagerSync(RelationshipManagerBase):
class RelationshipManagerSync(RelationshipManagerBase[PeerTSync]):
"""Manages relationships of a node in a synchronous context."""

def __init__(
Expand Down Expand Up @@ -292,12 +300,18 @@ def __init__(
if isinstance(data, list):
for item in data:
self.peers.append(
RelatedNodeSync(name=name, client=self.client, branch=self.branch, schema=schema, data=item)
cast(
"RelatedNodeSync[PeerTSync]",
RelatedNodeSync(name=name, client=self.client, branch=self.branch, schema=schema, data=item),
)
)
elif isinstance(data, dict) and "edges" in data:
for item in data["edges"]:
self.peers.append(
RelatedNodeSync(name=name, client=self.client, branch=self.branch, schema=schema, data=item)
cast(
"RelatedNodeSync[PeerTSync]",
RelatedNodeSync(name=name, client=self.client, branch=self.branch, schema=schema, data=item),
)
)
else:
raise ValueError(
Expand All @@ -306,8 +320,8 @@ def __init__(
f"Wrap the value in a list, e.g. {name}=[value]."
)

def __getitem__(self, item: int) -> RelatedNodeSync:
return self.peers[item] # type: ignore[return-value]
def __getitem__(self, item: int) -> RelatedNodeSync[PeerTSync]:
return cast("RelatedNodeSync[PeerTSync]", self.peers[item])

def fetch(self) -> None:
if not self.initialized:
Expand Down Expand Up @@ -354,7 +368,10 @@ def add(self, data: str | RelatedNodeSync | dict) -> None:
"""
if not self.initialized:
raise UninitializedError("Must call fetch() on RelationshipManager before editing members")
new_node = RelatedNodeSync(schema=self.schema, client=self.client, branch=self.branch, data=data)
new_node = cast(
"RelatedNodeSync[PeerTSync]",
RelatedNodeSync(schema=self.schema, client=self.client, branch=self.branch, data=data),
)

if (new_node.id and new_node.id not in self.peer_ids) or (
new_node.hfid and new_node.hfid not in self.peer_hfids
Expand Down
11 changes: 8 additions & 3 deletions infrahub_sdk/protocols_generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,24 @@ def _jinja2_filter_render_attribute(value: AttributeSchemaAPI) -> str:

return f"{value.name}: {attribute_kind}"

@staticmethod
def _jinja2_filter_render_relationship(value: RelationshipSchemaAPI, sync: bool = False) -> str:
def _jinja2_filter_render_relationship(self, value: RelationshipSchemaAPI, sync: bool = False) -> str:
name = value.name
cardinality = value.cardinality
peer = value.peer

type_ = "RelatedNode"
if cardinality == RelationshipCardinality.MANY:
type_ = "RelationshipManager"

if sync:
type_ += "Sync"
# Core peer protocols expose a dedicated ``*Sync`` variant; reference it in sync
# output. Locally generated node/generic classes keep their name (they already
# inherit the sync base class), so only swap when a ``*Sync`` peer actually exists.
if f"{peer}Sync" in self.base_protocols:
peer = f"{peer}Sync"

return f"{name}: {type_}"
return f"{name}: {type_}[{peer}]"

@staticmethod
def _sort_and_filter_models(
Expand Down
12 changes: 6 additions & 6 deletions infrahub_sdk/protocols_generator/template.j2
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class {{ generic.namespace + generic.name }}({{core_node_name}}):
{{ relationship | render_relationship(sync) }}
{% endfor %}
{% if generic.hierarchical | default(false) %}
parent: {{ "RelatedNode" | syncify(sync) }}
children: {{ "RelationshipManager" | syncify(sync) }}
parent: {{ "RelatedNode" | syncify(sync) }}[{{ generic.namespace + generic.name }}]
children: {{ "RelationshipManager" | syncify(sync) }}[{{ generic.namespace + generic.name }}]
{% endif %}
{% endfor %}

Expand All @@ -71,8 +71,8 @@ class {{ node.namespace + node.name }}({{ node.inherit_from | syncify(sync) | jo
{{ relationship | render_relationship(sync) }}
{% endfor %}
{% if node.hierarchical | default(false) %}
parent: {{ "RelatedNode" | syncify(sync) }}
children: {{ "RelationshipManager" | syncify(sync) }}
parent: {{ "RelatedNode" | syncify(sync) }}[{{ node.namespace + node.name }}]
children: {{ "RelationshipManager" | syncify(sync) }}[{{ node.namespace + node.name }}]
{% endif %}

{% endfor %}
Expand All @@ -91,8 +91,8 @@ class {{ node.namespace + node.name }}({{ node.inherit_from | syncify(sync) | jo
{{ relationship | render_relationship(sync) }}
{% endfor %}
{% if node.hierarchical | default(false) %}
parent: {{ "RelatedNode" | syncify(sync) }}
children: {{ "RelationshipManager" | syncify(sync) }}
parent: {{ "RelatedNode" | syncify(sync) }}[{{ node.namespace + node.name }}]
children: {{ "RelationshipManager" | syncify(sync) }}[{{ node.namespace + node.name }}]
{% endif %}

{% endfor %}
Expand Down
14 changes: 7 additions & 7 deletions tests/unit/sdk/test_protocols_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,13 @@ class LocationSite(LocationGeneric):
name: String
physical_address: StringOptional
shortname: String
children: RelationshipManagerSync
member_of_groups: RelationshipManagerSync
parent: RelatedNodeSync
profiles: RelationshipManagerSync
servers: RelationshipManagerSync
subscriber_of_groups: RelationshipManagerSync
tags: RelationshipManagerSync
children: RelationshipManagerSync[LocationRack]
member_of_groups: RelationshipManagerSync[CoreGroupSync]
parent: RelatedNodeSync[LocationCountry]
profiles: RelationshipManagerSync[CoreProfileSync]
servers: RelationshipManagerSync[NetworkManagementServer]
subscriber_of_groups: RelationshipManagerSync[CoreGroupSync]
tags: RelationshipManagerSync[BuiltinTagSync]
"""

assert location_site_sync in sync_protocols
Expand Down