Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion infrahub_sdk/ctl/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ async def report(
git_files_changed = await check_git_files_changed(client, branch=branch_name)

proposed_changes = await client.filters(
kind=CoreProposedChange, # type: ignore[type-abstract]
kind=CoreProposedChange,
source_branch__value=branch_name,
prefetch_relationships=True,
property=True,
Expand Down
7 changes: 4 additions & 3 deletions infrahub_sdk/node/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, NamedTuple, get_args

from ..protocols_base import CoreNodeBase
from ..uuidt import UUIDT
from .constants import ATTRIBUTE_METADATA_OBJECT, IP_TYPES, PROPERTIES_FLAG, PROPERTIES_OBJECT, SAFE_VALUE
from .property import NodeProperty
Expand Down Expand Up @@ -115,7 +114,7 @@ def _initialize_graphql_payload(self) -> _GraphQLPayloadAttribute:
# Pool-based allocation (dict data or resource-pool node)
if self._from_pool is not None:
return _GraphQLPayloadAttribute(payload={"from_pool": self._from_pool}, variables={}, needs_metadata=True)
if isinstance(self.value, CoreNodeBase) and self.value.is_resource_pool():
if hasattr(self.value, "is_resource_pool") and self.value.is_resource_pool():
return _GraphQLPayloadAttribute(
payload={"from_pool": {"id": self.value.id}}, variables={}, needs_metadata=True
)
Expand Down Expand Up @@ -190,4 +189,6 @@ def is_from_pool_attribute(self) -> bool:
True if the attribute value is a resource pool node or was explicitly allocated from a pool.

"""
return (isinstance(self.value, CoreNodeBase) and self.value.is_resource_pool()) or self._from_pool is not None
return (
hasattr(self.value, "is_resource_pool") and self.value.is_resource_pool()
) or self._from_pool is not None
12 changes: 7 additions & 5 deletions infrahub_sdk/node/related_node.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

from ..exceptions import Error
from ..protocols_base import CoreNodeBase
Expand All @@ -11,7 +11,7 @@
if TYPE_CHECKING:
from ..client import InfrahubClient, InfrahubClientSync
from ..schema import RelationshipSchemaAPI
from .node import InfrahubNode, InfrahubNodeSync
from .node import InfrahubNode, InfrahubNodeBase, InfrahubNodeSync


class RelatedNodeBase:
Expand All @@ -36,7 +36,7 @@ def __init__(self, branch: str, schema: RelationshipSchemaAPI, data: Any | dict,
self._properties_object = PROPERTIES_OBJECT
self._properties = self._properties_flag + self._properties_object

self._peer = None
self._peer: InfrahubNodeBase | CoreNodeBase | None = None
self._id: str | None = None
self._hfid: list[str] | None = None
self._display_label: str | None = None
Expand All @@ -45,8 +45,10 @@ def __init__(self, branch: str, schema: RelationshipSchemaAPI, data: Any | dict,
self._source_typename: str | None = None
self._relationship_metadata: RelationshipMetadata | None = None

if isinstance(data, (CoreNodeBase)):
self._peer = data
# Check for InfrahubNodeBase instances using duck-typing (_schema attribute)
# to avoid circular imports, or CoreNodeBase instances
if isinstance(data, CoreNodeBase) or hasattr(data, "_schema"):

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2: hasattr(data, "_schema") is an overly broad peer check and can misclassify non-node objects, causing runtime attribute errors.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At infrahub_sdk/node/related_node.py, line 50:

<comment>`hasattr(data, "_schema")` is an overly broad peer check and can misclassify non-node objects, causing runtime attribute errors.</comment>

<file context>
@@ -45,8 +45,10 @@ def __init__(self, branch: str, schema: RelationshipSchemaAPI, data: Any | dict,
-            self._peer = data
+        # Check for InfrahubNodeBase instances using duck-typing (_schema attribute)
+        # to avoid circular imports, or CoreNodeBase instances
+        if isinstance(data, CoreNodeBase) or hasattr(data, "_schema"):
+            self._peer = cast("InfrahubNodeBase | CoreNodeBase", data)
             for prop in self._properties:
</file context>

self._peer = cast("InfrahubNodeBase | CoreNodeBase", data)
for prop in self._properties:
setattr(self, prop, None)
self._relationship_metadata = None
Expand Down
27 changes: 15 additions & 12 deletions infrahub_sdk/protocols_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,7 @@ class AnyAttributeOptional(Attribute):
value: float | None


@runtime_checkable
class CoreNodeBase(Protocol):
class CoreNodeBase:
_schema: MainSchemaTypes
_internal_id: str
id: str # NOTE this is incorrect, should be str | None
Expand All @@ -190,25 +189,30 @@ def get_human_friendly_id(self) -> list[str] | None: ...

def get_human_friendly_id_as_string(self, include_kind: bool = False) -> str | None: ...

def get_kind(self) -> str: ...
def get_kind(self) -> str:
raise NotImplementedError

def get_all_kinds(self) -> list[str]: ...
def get_all_kinds(self) -> list[str]:
raise NotImplementedError

def get_branch(self) -> str: ...
def get_branch(self) -> str:
raise NotImplementedError

def is_ip_prefix(self) -> bool: ...
def is_ip_prefix(self) -> bool:
raise NotImplementedError

def is_ip_address(self) -> bool: ...
def is_ip_address(self) -> bool:
raise NotImplementedError

def is_resource_pool(self) -> bool: ...
def is_resource_pool(self) -> bool:
raise NotImplementedError

def get_raw_graphql_data(self) -> dict | None: ...

def get_node_metadata(self) -> NodeMetadata | None: ...


@runtime_checkable
class CoreNode(CoreNodeBase, Protocol):
class CoreNode(CoreNodeBase):
async def save(
self,
allow_upsert: bool = False,
Expand All @@ -232,8 +236,7 @@ async def add_relationships(self, relation_to_update: str, related_nodes: list[s
async def remove_relationships(self, relation_to_update: str, related_nodes: list[str]) -> None: ...


@runtime_checkable
class CoreNodeSync(CoreNodeBase, Protocol):
class CoreNodeSync(CoreNodeBase):
def save(
self,
allow_upsert: bool = False,
Expand Down
5 changes: 3 additions & 2 deletions infrahub_sdk/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ValidationError,
)
from ..graphql import Mutation
from ..protocols_base import CoreNodeBase
from ..queries import SCHEMA_HASH_SYNC_STATUS
from .export import RESTRICTED_NAMESPACES, NamespaceExport, SchemaExport, schema_to_export_dict
from .main import (
Expand Down Expand Up @@ -253,14 +254,14 @@ def _get_schema_name(schema: type[SchemaType | SchemaTypeSync] | str) -> str:
if isinstance(schema, str):
return schema

if hasattr(schema, "_is_runtime_protocol") and getattr(schema, "_is_runtime_protocol", None):
if issubclass(schema, CoreNodeBase):
if inspect.iscoroutinefunction(schema.save):
return schema.__name__
if schema.__name__[-4:] == "Sync":
return schema.__name__[:-4]
return schema.__name__

raise ValueError("schema must be a protocol or a string")
raise ValueError("schema must be a CoreNode subclass or a string")

@staticmethod
def _parse_schema_response(response: httpx.Response, branch: str) -> MutableMapping[str, Any]:
Expand Down
14 changes: 12 additions & 2 deletions infrahub_sdk/store.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import inspect
import warnings
from typing import TYPE_CHECKING, Literal, overload

from infrahub_sdk.protocols_base import CoreNodeBase

from .exceptions import NodeInvalidError, NodeNotFoundError
from .node.parsers import parse_human_friendly_id

Expand All @@ -16,8 +19,15 @@ def get_schema_name(schema: type[SchemaType | SchemaTypeSync] | str | None = Non
if isinstance(schema, str):
return schema

if hasattr(schema, "_is_runtime_protocol") and schema._is_runtime_protocol: # type: ignore[union-attr]
return schema.__name__ # type: ignore[union-attr]
if schema is None:
return None

if issubclass(schema, CoreNodeBase):
if inspect.iscoroutinefunction(schema.save):
return schema.__name__
if schema.__name__[-4:] == "Sync":
return schema.__name__[:-4]
return schema.__name__

return None

Expand Down
2 changes: 1 addition & 1 deletion infrahub_sdk/testing/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ async def wait_for_sync_to_complete(
) -> bool:
for _ in range(retries):
repo = await client.get(
kind=CoreGenericRepository, # type: ignore[type-abstract]
kind=CoreGenericRepository,
name__value=self.name,
branch=branch or self.initial_branch,
)
Expand Down
8 changes: 7 additions & 1 deletion tests/unit/sdk/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from infrahub_sdk.exceptions import NodeInvalidError, NodeNotFoundError
from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync
from infrahub_sdk.store import NodeStore, NodeStoreSync
from infrahub_sdk.protocols import BuiltinIPAddressSync, BuiltinIPPrefix
from infrahub_sdk.store import NodeStore, NodeStoreSync, get_schema_name

if TYPE_CHECKING:
from infrahub_sdk.schema import NodeSchemaAPI
Expand Down Expand Up @@ -157,3 +158,8 @@ def test_node_store_get_with_hfid(
store.get(kind="BuiltinLocation", key="anotherkey")
with pytest.raises(NodeNotFoundError):
store.get(key="anotherkey")


def test_store_get_schema_name() -> None:
assert get_schema_name(schema=BuiltinIPPrefix) == BuiltinIPPrefix.__name__
assert get_schema_name(schema=BuiltinIPAddressSync) == BuiltinIPAddressSync.__name__[:-4]