Skip to content
Merged

Dev #697

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions backend/common/tests/test_notification_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from rest_framework.test import APIClient

from common.models import Notification, Org, Profile, User
from common.views import notification_views
from common.views.notification_views import (
_format_keepalive,
_format_sse,
Expand Down Expand Up @@ -119,6 +120,34 @@ async def runner():
assert payload["link"] == "/cases/abc"
assert payload["data"] == {"comment_excerpt": "hi"}

def test_stream_self_terminates_at_deadline(self):
    """A stream past MAX_STREAM_SECONDS must finish by itself.

    With the limit patched to zero the deadline is already due after the
    opening keepalive, so the second pull has to raise StopAsyncIteration.
    This guards the connection-exhaustion regression where a long-open
    browser tab pinned a worker thread / DB connection forever.
    """

    async def pull_twice():
        # Stub pubsub with no messages -> keepalive-only loop.
        idle_pubsub = _StubPubSub()
        stream = _stream_events(
            "notif:o:p", recipient_id=self.profile.id, pubsub=idle_pubsub
        )
        opening = await asyncio.wait_for(stream.__anext__(), timeout=1)
        try:
            # Deadline already passed -> this pull should end the stream.
            await asyncio.wait_for(stream.__anext__(), timeout=2)
        except StopAsyncIteration:
            ended = True
        else:
            ended = False
        await stream.aclose()
        return opening, ended

    # Patch the module-level limit by hand and always restore it, so other
    # tests see the real value even if the run below raises.
    saved_limit = notification_views.MAX_STREAM_SECONDS
    notification_views.MAX_STREAM_SECONDS = 0
    try:
        first, stopped = asyncio.run(pull_twice())
    finally:
        notification_views.MAX_STREAM_SECONDS = saved_limit

    assert first == b": keepalive\n\n"
    assert stopped, "stream did not terminate at its deadline"

def test_drops_message_for_other_recipient(self):
n_other = Notification.objects.create(
org=self.org, recipient=self.other_profile, verb="other"
Expand Down
100 changes: 81 additions & 19 deletions backend/common/views/notification_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import asyncio
import json
import logging
import time

from asgiref.sync import sync_to_async
from django.http import StreamingHttpResponse
Expand All @@ -36,6 +37,17 @@

# Seconds slept between keepalive frames when the stream has no pubsub
# (the keepalive-only branch of `_stream_events`).
KEEPALIVE_SECONDS = 15

# Upper bound on a single SSE connection's lifetime. The browser's EventSource
# transparently reconnects when the server closes the stream, so capping this
# is invisible to the user but critical for the server: under ASGI the *entire*
# stream runs inside one request (one `ThreadSensitiveContext`), so an
# unbounded `while True` stream pins a worker thread — and any DB/redis
# connection bound to it — for as long as the tab stays open (hours). Left
# unbounded this is what slowly exhausts a shared Postgres cluster and leaves
# workers undrainable on deploy. 5 minutes bounds the leak window and lets
# workers recycle.
# Unit: seconds. `_stream_events` adds this to `time.monotonic()` once, at
# stream start, to compute the deadline at which the generator returns.
MAX_STREAM_SECONDS = 300

# NOTE(review): usage is not visible in this excerpt — presumably the default
# and maximum page size for notification list queries; confirm against callers.
DEFAULT_LIMIT = 20
MAX_LIMIT = 100

Expand Down Expand Up @@ -141,38 +153,67 @@ async def _aget_serialized(notif_id, recipient_id):
"""Fetch a notification by id, scoped to recipient. Returns dict or None."""
@sync_to_async
def fetch():
notif = Notification.objects.filter(
pk=notif_id, recipient_id=recipient_id
).first()
if notif is None:
return None
return NotificationSerializer(notif).data
from django.db import connection

try:
notif = Notification.objects.filter(
pk=notif_id, recipient_id=recipient_id
).first()
if notif is None:
return None
return NotificationSerializer(notif).data
Comment on lines +156 to +164
finally:
# This runs in asgiref's shared thread-sensitive executor, outside
# Django's request/response cycle — so `close_old_connections`
# never fires for it. Without this explicit close the connection
# dangles open (the SSE tests need TransactionTestCase precisely
# because of it). Close it so each fetch reclaims its connection.
connection.close()

return await fetch()


async def _aclose_redis(obj):
"""Close a redis.asyncio client/pubsub across versions (aclose vs close)."""
if obj is None:
return
closer = getattr(obj, "aclose", None) or getattr(obj, "close", None)
if closer is None:
return
try:
result = closer()
if asyncio.iscoroutine(result):
await result
except Exception: # pragma: no cover - best effort
pass


async def _open_pubsub(channel: str):
"""Open a redis.asyncio pubsub subscribed to ``channel``.

Returns ``None`` if Redis is unreachable; the stream then runs in
keepalive-only mode (the frontend's polling-since path provides
backfill).
Returns ``(client, pubsub)``. Both are ``None`` if Redis is unreachable;
the stream then runs in keepalive-only mode (the frontend's polling-since
path provides backfill). The caller owns closing BOTH the pubsub and the
underlying client connection pool — closing only the pubsub leaks the
client's pooled connections, one per dropped SSE stream.
"""
try:
import redis.asyncio as aioredis # type: ignore
except ImportError: # pragma: no cover
return None
return None, None
from django.conf import settings

url = getattr(settings, "CELERY_BROKER_URL", None) or "redis://localhost:6379/0"
client = None
try:
client = aioredis.from_url(url)
pubsub = client.pubsub()
await pubsub.subscribe(channel)
return pubsub
return client, pubsub
except Exception as exc:
logger.warning("SSE redis subscribe failed (%s); keepalive-only mode", exc)
return None
await _aclose_redis(client)
return None, None


async def _stream_events(channel: str, recipient_id, *, pubsub=None):
Expand All @@ -182,12 +223,18 @@ async def _stream_events(channel: str, recipient_id, *, pubsub=None):
opens its own redis.asyncio pubsub via :func:`_open_pubsub`.
"""
owns_pubsub = pubsub is None
client = None
if owns_pubsub:
pubsub = await _open_pubsub(channel)
client, pubsub = await _open_pubsub(channel)
deadline = time.monotonic() + MAX_STREAM_SECONDS
try:
# Initial comment so the client confirms the stream opened.
yield _format_keepalive()
while True:
# Bounded lifetime: end the stream so the worker thread and its
# DB/redis connections are released. EventSource reconnects.
if time.monotonic() >= deadline:
return
if pubsub is None:
await asyncio.sleep(KEEPALIVE_SECONDS)
yield _format_keepalive()
Expand All @@ -211,12 +258,15 @@ async def _stream_events(channel: str, recipient_id, *, pubsub=None):
continue
yield _format_sse("notification", payload)
finally:
if owns_pubsub and pubsub is not None:
try:
await pubsub.unsubscribe(channel)
await pubsub.close()
except Exception: # pragma: no cover - best effort
pass
if owns_pubsub:
if pubsub is not None:
try:
await pubsub.unsubscribe(channel)
except Exception: # pragma: no cover - best effort
pass
await _aclose_redis(pubsub)
# Close the underlying client pool too — not just the pubsub.
await _aclose_redis(client)


def _drive_async_gen(agen):
Expand Down Expand Up @@ -271,6 +321,18 @@ def get(self, request, *args, **kwargs):
profile_id = request.profile.id
channel = notif_mod.channel_for(org_id, profile_id)

# Release THIS request thread's DB connection before streaming begins.
# Under ASGI the entire stream runs inside one request's
# ThreadSensitiveContext, so the connection opened by auth/middleware
# would otherwise stay checked out for the full (possibly hours-long)
# stream — one leaked Postgres connection per open browser tab, which
# is what exhausts the shared cluster. The stream's own reads go
# through a separate executor (see `_aget_serialized`) and close
# themselves, so dropping this one is safe.
from django.db import connection

connection.close()

response = StreamingHttpResponse(
_drive_async_gen(_stream_events(channel, profile_id)),
content_type="text/event-stream",
Expand Down
Loading