diff --git a/h/migrations/versions/b6be2385d907_add_the_user_group_lms_role_column.py b/h/migrations/versions/b6be2385d907_add_the_user_group_lms_role_column.py new file mode 100644 index 00000000000..d7695440fc2 --- /dev/null +++ b/h/migrations/versions/b6be2385d907_add_the_user_group_lms_role_column.py @@ -0,0 +1,28 @@ +"""Add the user_group.lms_role column.""" + +from alembic import op +from sqlalchemy import CheckConstraint, Column, UnicodeText + +revision = "b6be2385d907" +down_revision = "3794945d8e88" + + +def upgrade(): + op.add_column( + "user_group", + Column( + "lms_role", + UnicodeText, + CheckConstraint( + " OR ".join( + f"lms_role = '{role}'" for role in ["lms_instructor", "lms_student"] + ), + name="validate_lms_role", + ), + nullable=True, + ), + ) + + +def downgrade(): + op.drop_column("user_group", "lms_role") diff --git a/h/models/document/_document.py b/h/models/document/_document.py index 75302954c72..5694754a813 100644 --- a/h/models/document/_document.py +++ b/h/models/document/_document.py @@ -160,6 +160,7 @@ def merge_documents(session, documents, updated=None): from h.services.annotation_write import AnnotationWriteService # noqa: PLC0415 AnnotationWriteService.change_document(session, duplicate_ids, master) + _merge_checkpoints(session, duplicate_ids, master) session.query(Document).filter(Document.id.in_(duplicate_ids)).delete( synchronize_session="fetch" ) @@ -170,6 +171,57 @@ def merge_documents(session, documents, updated=None): return master +def _merge_checkpoints(session, duplicate_ids, master): + """ + Re-point Hide & Reveal checkpoints from the duplicate documents to master. + + This mirrors how annotations are re-pointed + by AnnotationWriteService.change_document. + + Checkpoints that collide on master (same group and previous checkpoint) + are collapsed into one that keeps the most restrictive reveal_date (an + annotation stays hidden while any of the merged checkpoints would hide + it), so a merge can never reveal annotations that should remain hidden. 
+ """ + from h.models import Checkpoint # noqa: PLC0415 + + checkpoints = ( + session.query(Checkpoint) + .filter(Checkpoint.document_id.in_([master.id, *duplicate_ids])) + .all() + ) + + by_key: dict = {} + for checkpoint in checkpoints: + key = (checkpoint.group_id, checkpoint.previous_checkpoint_id) + by_key.setdefault(key, []).append(checkpoint) + + for colliding in by_key.values(): + # Prefer a checkpoint already on master as the survivor, so we don't + # momentarily violate the unique constraint by re-pointing onto it. + colliding.sort(key=lambda checkpoint: checkpoint.document_id != master.id) + survivor, *losers = colliding + + reveal_date = _most_restrictive_reveal_date(colliding) + + for loser in losers: + session.delete(loser) + session.flush() + + survivor.document_id = master.id + survivor.reveal_date = reveal_date + + session.flush() + + +def _most_restrictive_reveal_date(checkpoints): + """Return the reveal_date that keeps annotations hidden the longest.""" + reveal_dates = [checkpoint.reveal_date for checkpoint in checkpoints] + if any(reveal_date is None for reveal_date in reveal_dates): + return None + return max(reveal_dates) + + def update_document_metadata( # noqa: PLR0913 session, target_uri, diff --git a/h/models/group.py b/h/models/group.py index 7a7ea4860d0..b3648fcec00 100644 --- a/h/models/group.py +++ b/h/models/group.py @@ -44,6 +44,17 @@ class GroupMembershipRoles(enum.StrEnum): OWNER = "owner" +class LMSRole(enum.StrEnum): + """A member's LMS role within a group, synced from the LMS for hide/reveal authz. + + Deliberately separate from GroupMembershipRoles: these are non-hierarchical + and are not surfaced in h's role UI. 
+ """ + + LMS_INSTRUCTOR = "lms_instructor" + LMS_STUDENT = "lms_student" + + class GroupMembership(Base): __tablename__ = "user_group" @@ -76,6 +87,15 @@ class GroupMembership(Base): nullable=False, ) + lms_role = sa.Column( + sa.UnicodeText, + sa.CheckConstraint( + " OR ".join(f"lms_role = '{role.value}'" for role in LMSRole), + name="validate_lms_role", + ), + nullable=True, + ) + created = sa.Column(sa.DateTime, default=datetime.datetime.utcnow, index=True) updated = sa.Column( diff --git a/h/services/__init__.py b/h/services/__init__.py index 22f16cbe1ff..6b5061117ba 100644 --- a/h/services/__init__.py +++ b/h/services/__init__.py @@ -11,6 +11,7 @@ BulkGroupService, BulkLMSStatsService, ) +from h.services.checkpoint import CheckpointService from h.services.email import EmailService from h.services.http import HTTPService from h.services.job_queue import JobQueueService @@ -51,6 +52,9 @@ def includeme(config): # pragma: no cover # noqa: PLR0915 config.register_service_factory( "h.services.annotation_write.service_factory", iface=AnnotationWriteService ) + config.register_service_factory( + "h.services.checkpoint.factory", iface=CheckpointService + ) config.register_service_factory("h.services.mention.factory", iface=MentionService) config.register_service_factory( "h.services.notification.factory", iface=NotificationService diff --git a/h/services/checkpoint.py b/h/services/checkpoint.py new file mode 100644 index 00000000000..74c5a79728b --- /dev/null +++ b/h/services/checkpoint.py @@ -0,0 +1,46 @@ +from datetime import datetime + +from sqlalchemy import or_, select + +from h.models import Checkpoint, Document + + +class CheckpointService: + """Resolve Hide & Reveal checkpoints for annotation-search authorization.""" + + def __init__(self, db): + self.db = db + + def active_checkpoint(self, group_id: int, uri: str) -> Checkpoint | None: + """ + Return an active (unrevealed) checkpoint for `(group_id, uri)`, or None. 
+ + The `uri` is resolved to its Document(s) the same way the search layer + resolves the request's `uri` param, so the checkpoint lookup matches the + annotations the search will return even when the same document is + addressed by an equivalent URI (e.g. a PDF fingerprint). + + A checkpoint is "active" (still hiding annotations) when its reveal_date + has not yet passed: it is NULL (never revealed) or in the future. + """ + document_ids = [doc.id for doc in Document.find_by_uris(self.db, [uri])] + if not document_ids: + return None + + return self.db.scalar( + select(Checkpoint) + .where(Checkpoint.group_id == group_id) + .where(Checkpoint.document_id.in_(document_ids)) + .where( + or_( + Checkpoint.reveal_date.is_(None), + Checkpoint.reveal_date > datetime.utcnow(), # noqa: DTZ003 + ) + ) + .limit(1) + ) + + +def factory(_context, request) -> CheckpointService: + """Return a CheckpointService instance for the passed context and request.""" + return CheckpointService(db=request.db) diff --git a/tests/unit/h/models/document/_document_test.py b/tests/unit/h/models/document/_document_test.py index 4a56f787330..9e110e6e9bb 100644 --- a/tests/unit/h/models/document/_document_test.py +++ b/tests/unit/h/models/document/_document_test.py @@ -226,6 +226,82 @@ def test_it_moves_annotations_to_the_first(self, db_session, duplicate_docs): assert count == expected_count + def test_it_moves_checkpoints_to_the_first( + self, db_session, duplicate_docs, factories + ): + checkpoint = factories.Checkpoint(document=duplicate_docs[1]) + + merge_documents(db_session, duplicate_docs) + db_session.flush() + + assert checkpoint.document_id == duplicate_docs[0].id + + def test_it_keeps_checkpoints_in_different_groups( + self, db_session, duplicate_docs, factories + ): + checkpoint_1 = factories.Checkpoint(document=duplicate_docs[0]) + checkpoint_2 = factories.Checkpoint(document=duplicate_docs[1]) + + merge_documents(db_session, duplicate_docs) + db_session.flush() + + # Different groups 
don't collide, so both survive on the master. + assert checkpoint_1.document_id == duplicate_docs[0].id + assert checkpoint_2.document_id == duplicate_docs[0].id + + def test_it_collapses_colliding_checkpoints_to_the_most_restrictive( + self, db_session, duplicate_docs, factories + ): + group = factories.Group() + # Same group on two merging documents => the checkpoints collide. + factories.Checkpoint( + group=group, + document=duplicate_docs[0], + reveal_date=_datetime(2000, 1, 1), # noqa: DTZ001 # already revealed + ) + factories.Checkpoint( + group=group, + document=duplicate_docs[1], + reveal_date=None, # never revealed = most restrictive + ) + + merge_documents(db_session, duplicate_docs) + db_session.flush() + + survivors = ( + db_session.query(models.Checkpoint) + .filter_by(group_id=group.id, document_id=duplicate_docs[0].id) + .all() + ) + assert len(survivors) == 1 + assert survivors[0].reveal_date is None + + def test_it_collapses_colliding_checkpoints_to_the_latest_reveal_date( + self, db_session, duplicate_docs, factories + ): + group = factories.Group() + factories.Checkpoint( + group=group, + document=duplicate_docs[0], + reveal_date=_datetime(2000, 1, 1), # noqa: DTZ001 + ) + factories.Checkpoint( + group=group, + document=duplicate_docs[1], + reveal_date=_datetime(2030, 1, 1), # noqa: DTZ001 # hides for longest + ) + + merge_documents(db_session, duplicate_docs) + db_session.flush() + + survivors = ( + db_session.query(models.Checkpoint) + .filter_by(group_id=group.id, document_id=duplicate_docs[0].id) + .all() + ) + assert len(survivors) == 1 + assert survivors[0].reveal_date == _datetime(2030, 1, 1) # noqa: DTZ001 + def test_it_raises_retryable_error_when_flush_fails( self, db_session, duplicate_docs, monkeypatch ): diff --git a/tests/unit/h/services/checkpoint_test.py b/tests/unit/h/services/checkpoint_test.py new file mode 100644 index 00000000000..de5240c0eb5 --- /dev/null +++ b/tests/unit/h/services/checkpoint_test.py @@ -0,0 +1,94 @@ +from 
datetime import datetime, timedelta +from unittest import mock + +import pytest + +from h.services.checkpoint import CheckpointService, factory + + +class TestActiveCheckpoint: + def test_it_returns_an_unrevealed_checkpoint(self, svc, group, document): + checkpoint = self.checkpoint(group, document, reveal_date=None) + + assert svc.active_checkpoint(group.id, "http://example.com/page") == checkpoint + + def test_it_returns_a_checkpoint_with_a_future_reveal_date( + self, svc, group, document + ): + checkpoint = self.checkpoint( + group, + document, + reveal_date=datetime.utcnow() + timedelta(days=1), # noqa: DTZ003 + ) + + assert svc.active_checkpoint(group.id, "http://example.com/page") == checkpoint + + def test_it_returns_None_when_the_checkpoint_is_revealed( + self, svc, group, document + ): + self.checkpoint( + group, + document, + reveal_date=datetime.utcnow() - timedelta(days=1), # noqa: DTZ003 + ) + + assert svc.active_checkpoint(group.id, "http://example.com/page") is None + + @pytest.mark.usefixtures("document") + def test_it_returns_None_when_there_is_no_checkpoint(self, svc, group): + assert svc.active_checkpoint(group.id, "http://example.com/page") is None + + def test_it_returns_None_for_a_different_group( + self, svc, group, document, factories + ): + self.checkpoint(group, document, reveal_date=None) + other_group = factories.Group() + + assert svc.active_checkpoint(other_group.id, "http://example.com/page") is None + + def test_it_resolves_the_uri_to_the_document(self, svc, group, document, factories): + # A second URI on the same document (e.g. a PDF fingerprint) must + # resolve to the same checkpoint. 
+ factories.DocumentURI(document=document, uri="urn:x-pdf:the-fingerprint") + checkpoint = self.checkpoint(group, document, reveal_date=None) + + assert ( + svc.active_checkpoint(group.id, "urn:x-pdf:the-fingerprint") == checkpoint + ) + + def test_it_returns_None_for_an_unknown_uri(self, svc, group, document): + self.checkpoint(group, document, reveal_date=None) + + assert svc.active_checkpoint(group.id, "http://example.com/other") is None + + def checkpoint(self, group, document, reveal_date): + return self.factories.Checkpoint( + group=group, document=document, reveal_date=reveal_date + ) + + @pytest.fixture(autouse=True) + def _factories(self, factories): + self.factories = factories + + @pytest.fixture + def group(self, factories): + return factories.Group() + + @pytest.fixture + def document(self, factories): + document = factories.Document() + factories.DocumentURI(document=document, uri="http://example.com/page") + return document + + +class TestFactory: + def test_it(self, pyramid_request): + svc = factory(mock.sentinel.context, pyramid_request) + + assert isinstance(svc, CheckpointService) + assert svc.db == pyramid_request.db + + +@pytest.fixture +def svc(db_session): + return CheckpointService(db=db_session)