diff --git a/backend/infrahub/git/tasks.py b/backend/infrahub/git/tasks.py index 64358d3aea5..65a86617fc3 100644 --- a/backend/infrahub/git/tasks.py +++ b/backend/infrahub/git/tasks.py @@ -110,7 +110,11 @@ async def add_git_repository(model: GitRepositoryAdd) -> None: # Notify other workers they need to clone the repository and check out the SHA pinned # by this initial sync, so the whole pool converges even if upstream advances meanwhile. notification = messages.RefreshGitFetch( - meta=Meta(initiator_id=WORKER_IDENTITY, request_id=get_log_data().get("request_id", "")), + meta=Meta( + initiator_id=WORKER_IDENTITY, + request_id=get_log_data().get("request_id", ""), + user_request_id=get_log_data().get("user_request_id", ""), + ), location=model.location, repository_id=model.repository_id, repository_name=model.repository_name, @@ -150,7 +154,11 @@ async def add_git_repository_read_only(model: GitRepositoryAddReadOnly) -> None: # Notify other workers they need to clone the repository and check out the resolved commit notification = messages.RefreshGitFetch( - meta=Meta(initiator_id=WORKER_IDENTITY, request_id=get_log_data().get("request_id", "")), + meta=Meta( + initiator_id=WORKER_IDENTITY, + request_id=get_log_data().get("request_id", ""), + user_request_id=get_log_data().get("user_request_id", ""), + ), location=model.location, repository_id=model.repository_id, repository_name=model.repository_name, @@ -324,7 +332,11 @@ async def sync_remote_repositories() -> None: # Tell workers to fetch and check out the SHA pinned by this sync, so the whole # pool converges on the same commit even if upstream advances during fan-out. message = messages.RefreshGitFetch( - meta=Meta(initiator_id=WORKER_IDENTITY, request_id=get_log_data().get("request_id", "")), + meta=Meta( + initiator_id=WORKER_IDENTITY, + request_id=get_log_data().get("request_id", ""), + user_request_id=get_log_data().get("user_request_id", ""), + ), location=repository.location.value, repository_id=repository.id, repository_name=repository.name.value, @@ -368,7 +380,11 @@ async def git_branch_create( # New branch has been pushed remotely, tell workers to fetch it and check out the SHA it # was created at so the pool converges even if upstream advances during fan-out. message = messages.RefreshGitFetch( - meta=Meta(initiator_id=WORKER_IDENTITY, request_id=get_log_data().get("request_id", "")), + meta=Meta( + initiator_id=WORKER_IDENTITY, + request_id=get_log_data().get("request_id", ""), + user_request_id=get_log_data().get("user_request_id", ""), + ), location=repo.get_location(), repository_id=str(repo.id), repository_name=repo.name, @@ -410,7 +426,11 @@ async def git_branch_delete( message_bus = await get_message_bus() message = messages.RefreshGitRepositoryBranchDeleted( - meta=Meta(initiator_id=WORKER_IDENTITY, request_id=get_log_data().get("request_id", "")), + meta=Meta( + initiator_id=WORKER_IDENTITY, + request_id=get_log_data().get("request_id", ""), + user_request_id=get_log_data().get("user_request_id", ""), + ), repository_id=str(repo.id), repository_name=repo.name, repository_kind=InfrahubKind.REPOSITORY, @@ -605,7 +625,11 @@ async def pull_read_only(model: GitRepositoryPullReadOnly) -> None: # Tell workers to fetch and check out the resolved commit to stay in sync message = messages.RefreshGitFetch( - meta=Meta(initiator_id=WORKER_IDENTITY, request_id=get_log_data().get("request_id", "")), + meta=Meta( + initiator_id=WORKER_IDENTITY, + request_id=get_log_data().get("request_id", ""), + user_request_id=get_log_data().get("user_request_id", ""), + ), location=model.location, repository_id=model.repository_id, repository_name=model.repository_name, @@ -683,7 +707,11 @@ async def merge_git_repository(model: GitRepositoryMerge) -> None: # Destination branch has changed and pushed remotely, tell workers to re-fetch and # check out the merge commit so the pool converges even if upstream advances meanwhile. message = messages.RefreshGitFetch( - meta=Meta(initiator_id=WORKER_IDENTITY, request_id=get_log_data().get("request_id", "")), + meta=Meta( + initiator_id=WORKER_IDENTITY, + request_id=get_log_data().get("request_id", ""), + user_request_id=get_log_data().get("user_request_id", ""), + ), location=repo.location, repository_id=str(repo.id), repository_name=repo.name, diff --git a/backend/infrahub/message_bus/__init__.py b/backend/infrahub/message_bus/__init__.py index d8286fd340b..7d85032e5da 100644 --- a/backend/infrahub/message_bus/__init__.py +++ b/backend/infrahub/message_bus/__init__.py @@ -11,6 +11,7 @@ class Meta(BaseModel): request_id: str = "" + user_request_id: str = Field(default="", description="Optional request-id submitted by the caller through the API") correlation_id: str | None = Field(default=None) reply_to: str | None = Field(default=None) initiator_id: str | None = Field( @@ -38,6 +39,7 @@ class InfrahubMessage(BaseModel): def assign_meta(self, parent: InfrahubMessage) -> None: """Assign relevant meta properties from a parent message.""" self.meta.request_id = parent.meta.request_id + self.meta.user_request_id = parent.meta.user_request_id self.meta.initiator_id = parent.meta.initiator_id def assign_header(self, key: str, value: Any) -> None: @@ -54,6 +56,8 @@ def set_log_data(self, routing_key: str) -> None: set_log_data(key="routing_key", value=routing_key) if self.meta.request_id: set_log_data(key="request_id", value=self.meta.request_id) + if self.meta.user_request_id: + set_log_data(key="user_request_id", value=self.meta.user_request_id) @property def reply_requested(self) -> bool: diff --git a/backend/infrahub/server.py b/backend/infrahub/server.py index 5e7bc9334a5..c23628271e4 100644 --- a/backend/infrahub/server.py +++ b/backend/infrahub/server.py @@ -157,6 +157,11 @@ async def logging_middleware(request: Request, call_next: Callable[[Request], Aw request_id = correlation_id.get() set_log_data(key="request_id", value=request_id) + + user_request_id = request.headers.get("x-infrahub-request-id") + if user_request_id: + set_log_data(key="user_request_id", value=user_request_id) + set_log_data(key="app", value="infrahub.api") set_log_data(key="worker", value=WORKER_IDENTITY) diff --git a/backend/infrahub/services/adapters/message_bus/nats.py b/backend/infrahub/services/adapters/message_bus/nats.py index 859bfc74c01..05b4d6d3154 100644 --- a/backend/infrahub/services/adapters/message_bus/nats.py +++ b/backend/infrahub/services/adapters/message_bus/nats.py @@ -31,6 +31,7 @@ async def _add_request_id(message: InfrahubMessage) -> None: log_data = get_log_data() message.meta.request_id = log_data.get("request_id", "") + message.meta.user_request_id = log_data.get("user_request_id", "") class NATSMessageBus(InfrahubMessageBus): @@ -287,7 +288,10 @@ async def rpc(self, message: InfrahubMessage, response_class: type[ResponseClass log_data = get_log_data() request_id = log_data.get("request_id", "") message.meta = Meta( - request_id=request_id, correlation_id=correlation_id, reply_to=self.callback_queue.config.name + request_id=request_id, + user_request_id=log_data.get("user_request_id", ""), + correlation_id=correlation_id, + reply_to=self.callback_queue.config.name, ) await self.send(message=message) diff --git a/backend/infrahub/services/adapters/message_bus/rabbitmq.py b/backend/infrahub/services/adapters/message_bus/rabbitmq.py index 24d0836e3fe..f178eb874d0 100644 --- a/backend/infrahub/services/adapters/message_bus/rabbitmq.py +++ b/backend/infrahub/services/adapters/message_bus/rabbitmq.py @@ -58,6 +58,7 @@ def set_channel(self: SpanBuilder, channel: AbstractChannel) -> None: async def _add_request_id(message: InfrahubMessage) -> None: log_data = get_log_data() message.meta.request_id = log_data.get("request_id", "") + message.meta.user_request_id = log_data.get("user_request_id", "") class RabbitMQMessageBus(InfrahubMessageBus): @@ -239,7 +240,12 @@ async def rpc(self, message: InfrahubMessage, response_class: type[ResponseClass log_data = get_log_data() request_id = log_data.get("request_id", "") - message.meta = Meta(request_id=request_id, correlation_id=correlation_id, reply_to=self.callback_queue.name) + message.meta = Meta( + request_id=request_id, + user_request_id=log_data.get("user_request_id", ""), + correlation_id=correlation_id, + reply_to=self.callback_queue.name, + ) await self.send(message=message) diff --git a/backend/tests/unit/message_bus/test_meta.py b/backend/tests/unit/message_bus/test_meta.py new file mode 100644 index 00000000000..f837a2269e3 --- /dev/null +++ b/backend/tests/unit/message_bus/test_meta.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from infrahub.log import clear_log_context, get_log_data, set_log_data +from infrahub.message_bus import InfrahubMessage, Meta +from infrahub.services.adapters.message_bus.rabbitmq import _add_request_id + + +def test_assign_meta_propagates_user_request_id() -> None: + parent = InfrahubMessage(meta=Meta(request_id="generated", user_request_id="caller-123")) + child = InfrahubMessage() + + child.assign_meta(parent=parent) + + assert child.meta.request_id == "generated" + assert child.meta.user_request_id == "caller-123" + + +def test_set_log_data_binds_user_request_id() -> None: + clear_log_context() + message = InfrahubMessage(meta=Meta(request_id="generated", user_request_id="caller-123")) + + message.set_log_data(routing_key="dummy.routing.key") + + log_data = get_log_data() + assert log_data["request_id"] == "generated" + assert log_data["user_request_id"] == "caller-123" + clear_log_context() + + +def test_set_log_data_skips_empty_user_request_id() -> None: + clear_log_context() + message = InfrahubMessage(meta=Meta(request_id="generated")) + + message.set_log_data(routing_key="dummy.routing.key") + + assert "user_request_id" not in get_log_data() + clear_log_context() + + +async def test_add_request_id_enricher_copies_user_request_id() -> None: + clear_log_context() + set_log_data(key="request_id", value="generated") + set_log_data(key="user_request_id", value="caller-123") + message = InfrahubMessage() + + await _add_request_id(message=message) + + assert message.meta.request_id == "generated" + assert message.meta.user_request_id == "caller-123" + clear_log_context()