Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions backend/infrahub/git/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,11 @@ async def add_git_repository(model: GitRepositoryAdd) -> None:
# Notify other workers they need to clone the repository and check out the SHA pinned
# by this initial sync, so the whole pool converges even if upstream advances meanwhile.
notification = messages.RefreshGitFetch(
meta=Meta(initiator_id=WORKER_IDENTITY, request_id=get_log_data().get("request_id", "")),
meta=Meta(
initiator_id=WORKER_IDENTITY,
request_id=get_log_data().get("request_id", ""),
user_request_id=get_log_data().get("user_request_id", ""),
),
location=model.location,
repository_id=model.repository_id,
repository_name=model.repository_name,
Expand Down Expand Up @@ -150,7 +154,11 @@ async def add_git_repository_read_only(model: GitRepositoryAddReadOnly) -> None:

# Notify other workers they need to clone the repository and check out the resolved commit
notification = messages.RefreshGitFetch(
meta=Meta(initiator_id=WORKER_IDENTITY, request_id=get_log_data().get("request_id", "")),
meta=Meta(
initiator_id=WORKER_IDENTITY,
request_id=get_log_data().get("request_id", ""),
user_request_id=get_log_data().get("user_request_id", ""),
),
location=model.location,
repository_id=model.repository_id,
repository_name=model.repository_name,
Expand Down Expand Up @@ -324,7 +332,11 @@ async def sync_remote_repositories() -> None:
# Tell workers to fetch and check out the SHA pinned by this sync, so the whole
# pool converges on the same commit even if upstream advances during fan-out.
message = messages.RefreshGitFetch(
meta=Meta(initiator_id=WORKER_IDENTITY, request_id=get_log_data().get("request_id", "")),
meta=Meta(
initiator_id=WORKER_IDENTITY,
request_id=get_log_data().get("request_id", ""),
user_request_id=get_log_data().get("user_request_id", ""),
),
location=repository.location.value,
repository_id=repository.id,
repository_name=repository.name.value,
Expand Down Expand Up @@ -368,7 +380,11 @@ async def git_branch_create(
# New branch has been pushed remotely, tell workers to fetch it and check out the SHA it
# was created at so the pool converges even if upstream advances during fan-out.
message = messages.RefreshGitFetch(
meta=Meta(initiator_id=WORKER_IDENTITY, request_id=get_log_data().get("request_id", "")),
meta=Meta(
initiator_id=WORKER_IDENTITY,
request_id=get_log_data().get("request_id", ""),
user_request_id=get_log_data().get("user_request_id", ""),
),
location=repo.get_location(),
repository_id=str(repo.id),
repository_name=repo.name,
Expand Down Expand Up @@ -410,7 +426,11 @@ async def git_branch_delete(

message_bus = await get_message_bus()
message = messages.RefreshGitRepositoryBranchDeleted(
meta=Meta(initiator_id=WORKER_IDENTITY, request_id=get_log_data().get("request_id", "")),
meta=Meta(
initiator_id=WORKER_IDENTITY,
request_id=get_log_data().get("request_id", ""),
user_request_id=get_log_data().get("user_request_id", ""),
),
repository_id=str(repo.id),
repository_name=repo.name,
repository_kind=InfrahubKind.REPOSITORY,
Expand Down Expand Up @@ -605,7 +625,11 @@ async def pull_read_only(model: GitRepositoryPullReadOnly) -> None:

# Tell workers to fetch and check out the resolved commit to stay in sync
message = messages.RefreshGitFetch(
meta=Meta(initiator_id=WORKER_IDENTITY, request_id=get_log_data().get("request_id", "")),
meta=Meta(
initiator_id=WORKER_IDENTITY,
request_id=get_log_data().get("request_id", ""),
user_request_id=get_log_data().get("user_request_id", ""),
),
location=model.location,
repository_id=model.repository_id,
repository_name=model.repository_name,
Expand Down Expand Up @@ -683,7 +707,11 @@ async def merge_git_repository(model: GitRepositoryMerge) -> None:
# Destination branch has changed and pushed remotely, tell workers to re-fetch and
# check out the merge commit so the pool converges even if upstream advances meanwhile.
message = messages.RefreshGitFetch(
meta=Meta(initiator_id=WORKER_IDENTITY, request_id=get_log_data().get("request_id", "")),
meta=Meta(
initiator_id=WORKER_IDENTITY,
request_id=get_log_data().get("request_id", ""),
user_request_id=get_log_data().get("user_request_id", ""),
),
location=repo.location,
repository_id=str(repo.id),
repository_name=repo.name,
Expand Down
4 changes: 4 additions & 0 deletions backend/infrahub/message_bus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

class Meta(BaseModel):
request_id: str = ""
user_request_id: str = Field(default="", description="Optional request-id submitted by the caller through the API")
correlation_id: str | None = Field(default=None)
reply_to: str | None = Field(default=None)
initiator_id: str | None = Field(
Expand Down Expand Up @@ -38,6 +39,7 @@ class InfrahubMessage(BaseModel):
def assign_meta(self, parent: InfrahubMessage) -> None:
"""Assign relevant meta properties from a parent message."""
self.meta.request_id = parent.meta.request_id
self.meta.user_request_id = parent.meta.user_request_id
self.meta.initiator_id = parent.meta.initiator_id

def assign_header(self, key: str, value: Any) -> None:
Expand All @@ -54,6 +56,8 @@ def set_log_data(self, routing_key: str) -> None:
set_log_data(key="routing_key", value=routing_key)
if self.meta.request_id:
set_log_data(key="request_id", value=self.meta.request_id)
if self.meta.user_request_id:
set_log_data(key="user_request_id", value=self.meta.user_request_id)

@property
def reply_requested(self) -> bool:
Expand Down
5 changes: 5 additions & 0 deletions backend/infrahub/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ async def logging_middleware(request: Request, call_next: Callable[[Request], Aw
request_id = correlation_id.get()

set_log_data(key="request_id", value=request_id)

user_request_id = request.headers.get("x-infrahub-request-id")
if user_request_id:
set_log_data(key="user_request_id", value=user_request_id)

set_log_data(key="app", value="infrahub.api")
set_log_data(key="worker", value=WORKER_IDENTITY)

Expand Down
6 changes: 5 additions & 1 deletion backend/infrahub/services/adapters/message_bus/nats.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
async def _add_request_id(message: InfrahubMessage) -> None:
log_data = get_log_data()
message.meta.request_id = log_data.get("request_id", "")
message.meta.user_request_id = log_data.get("user_request_id", "")


class NATSMessageBus(InfrahubMessageBus):
Expand Down Expand Up @@ -287,7 +288,10 @@ async def rpc(self, message: InfrahubMessage, response_class: type[ResponseClass
log_data = get_log_data()
request_id = log_data.get("request_id", "")
message.meta = Meta(
request_id=request_id, correlation_id=correlation_id, reply_to=self.callback_queue.config.name
request_id=request_id,
user_request_id=log_data.get("user_request_id", ""),
correlation_id=correlation_id,
reply_to=self.callback_queue.config.name,
)

await self.send(message=message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def set_channel(self: SpanBuilder, channel: AbstractChannel) -> None:
async def _add_request_id(message: InfrahubMessage) -> None:
log_data = get_log_data()
message.meta.request_id = log_data.get("request_id", "")
message.meta.user_request_id = log_data.get("user_request_id", "")


class RabbitMQMessageBus(InfrahubMessageBus):
Expand Down Expand Up @@ -239,7 +240,12 @@ async def rpc(self, message: InfrahubMessage, response_class: type[ResponseClass

log_data = get_log_data()
request_id = log_data.get("request_id", "")
message.meta = Meta(request_id=request_id, correlation_id=correlation_id, reply_to=self.callback_queue.name)
message.meta = Meta(
request_id=request_id,
user_request_id=log_data.get("user_request_id", ""),
correlation_id=correlation_id,
reply_to=self.callback_queue.name,
)

await self.send(message=message)

Expand Down
50 changes: 50 additions & 0 deletions backend/tests/unit/message_bus/test_meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from __future__ import annotations

from infrahub.log import clear_log_context, get_log_data, set_log_data
from infrahub.message_bus import InfrahubMessage, Meta
from infrahub.services.adapters.message_bus.rabbitmq import _add_request_id


def test_assign_meta_propagates_user_request_id() -> None:
parent = InfrahubMessage(meta=Meta(request_id="generated", user_request_id="caller-123"))
child = InfrahubMessage()

child.assign_meta(parent=parent)

assert child.meta.request_id == "generated"
assert child.meta.user_request_id == "caller-123"


def test_set_log_data_binds_user_request_id() -> None:
clear_log_context()
message = InfrahubMessage(meta=Meta(request_id="generated", user_request_id="caller-123"))

message.set_log_data(routing_key="dummy.routing.key")

log_data = get_log_data()
assert log_data["request_id"] == "generated"
assert log_data["user_request_id"] == "caller-123"
clear_log_context()


def test_set_log_data_skips_empty_user_request_id() -> None:
clear_log_context()
message = InfrahubMessage(meta=Meta(request_id="generated"))

message.set_log_data(routing_key="dummy.routing.key")

assert "user_request_id" not in get_log_data()
clear_log_context()


async def test_add_request_id_enricher_copies_user_request_id() -> None:
clear_log_context()
set_log_data(key="request_id", value="generated")
set_log_data(key="user_request_id", value="caller-123")
message = InfrahubMessage()

await _add_request_id(message=message)

assert message.meta.request_id == "generated"
assert message.meta.user_request_id == "caller-123"
clear_log_context()