From f505d058b85041d2db48f29da0d991ec6c114352 Mon Sep 17 00:00:00 2001 From: Hamzah Ullah Date: Mon, 22 Jun 2026 17:19:51 +0000 Subject: [PATCH 1/6] feat: add base prompt viewset helpers --- enterprise_access/apps/api/v1/views/prompt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/enterprise_access/apps/api/v1/views/prompt.py b/enterprise_access/apps/api/v1/views/prompt.py index ef1afd5b..c2fd4b2a 100644 --- a/enterprise_access/apps/api/v1/views/prompt.py +++ b/enterprise_access/apps/api/v1/views/prompt.py @@ -27,8 +27,7 @@ list["JSONValue"] | dict[str, "JSONValue"] ) - -ValidatedData: TypeAlias = dict[str, object] +ValidatedData: TypeAlias = dict[str, JSONValue] XpertMessage: TypeAlias = dict[str, str] XpertResponse: TypeAlias = dict[str, object] SystemPromptModel: TypeAlias = type[BaseSystemPrompt] @@ -90,6 +89,7 @@ def _validate_request( return serializer.validated_data + def _get_current_prompt( self, *, From c77d5189d0799d967b854432c4eab9617309046d Mon Sep 17 00:00:00 2001 From: Hamzah Ullah Date: Tue, 23 Jun 2026 17:06:13 +0000 Subject: [PATCH 2/6] feat: add learner intent xpert api endpoint Co-Authored-By: Claude Sonnet 4.6 --- .../apps/api/serializers/__init__.py | 1 + .../apps/api/serializers/learner_pathways.py | 38 ++ .../apps/api/v1/tests/test_prompt_views.py | 623 +++++++++++++++++- enterprise_access/apps/api/v1/urls.py | 1 + .../apps/api/v1/views/__init__.py | 1 + enterprise_access/apps/api/v1/views/prompt.py | 135 +++- enterprise_access/settings/base.py | 3 + 7 files changed, 795 insertions(+), 7 deletions(-) create mode 100644 enterprise_access/apps/api/serializers/learner_pathways.py diff --git a/enterprise_access/apps/api/serializers/__init__.py b/enterprise_access/apps/api/serializers/__init__.py index 3c3436ef..c7eba80a 100644 --- a/enterprise_access/apps/api/serializers/__init__.py +++ b/enterprise_access/apps/api/serializers/__init__.py @@ -37,6 +37,7 @@ TransactionResponseSerializer, TransactionsListResponseSerializer ) +from .learner_pathways import LearningIntentRequestSerializer, LearningIntentResponseSerializer from .provisioning import ( ProvisioningRequestSerializer, ProvisioningResponseSerializer, diff --git a/enterprise_access/apps/api/serializers/learner_pathways.py b/enterprise_access/apps/api/serializers/learner_pathways.py new file mode 100644 index 00000000..cb24f729 --- /dev/null +++ b/enterprise_access/apps/api/serializers/learner_pathways.py @@ -0,0 +1,38 @@ +""" +Request and documentation-only response serializers for the Learner Pathways API. +""" +from rest_framework import serializers + +LEARNER_PATHWAYS_API_TAG = 'Learner Pathways' + + +class LearningIntentRequestSerializer(serializers.Serializer): + """ + Validates the request body for the learning-intent endpoint. + """ + selected_goals = serializers.CharField(allow_blank=False) + free_text = serializers.CharField(allow_blank=False) + known_context = serializers.CharField(allow_blank=False) + + def create(self, validated_data): + return validated_data + + def update(self, instance, validated_data): + return validated_data + + +class LearningIntentResponseSerializer(serializers.Serializer): + """ + Documents the expected HTTP 200 response shape for the learning-intent endpoint. + + For OpenAPI schema generation only — never instantiated at runtime. + """ + skills_required = serializers.ListField(child=serializers.CharField()) + skills_preferred = serializers.ListField(child=serializers.CharField()) + condensed_algolia_query = serializers.CharField() + + def create(self, validated_data): + return validated_data + + def update(self, instance, validated_data): + return validated_data diff --git a/enterprise_access/apps/api/v1/tests/test_prompt_views.py b/enterprise_access/apps/api/v1/tests/test_prompt_views.py index 03312082..0aeb0953 100644 --- a/enterprise_access/apps/api/v1/tests/test_prompt_views.py +++ b/enterprise_access/apps/api/v1/tests/test_prompt_views.py @@ -1,26 +1,53 @@ """ -Tests for BasePromptViewSet and PromptRequestException. +Tests for BasePromptViewSet, PromptRequestException, and LearnerPathwaysViewSet. """ # pylint: disable=protected-access import json +import uuid from unittest import mock import ddt -from django.test import TestCase -from rest_framework import serializers, status +from django.conf import settings as django_settings +from django.core.cache import cache as django_cache +from django.test import TestCase, override_settings +from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication +from rest_framework import permissions, serializers, status from rest_framework.exceptions import ValidationError - -from enterprise_access.apps.api.v1.views.prompt import BasePromptViewSet, PromptRequestException +from rest_framework.reverse import reverse +from rest_framework.test import APIClient +from rest_framework.throttling import ScopedRateThrottle + +from enterprise_access.apps.api import serializers as api_serializers +from enterprise_access.apps.api.v1.views.prompt import ( + BasePromptViewSet, + IsEnterpriseLearner, + LearnerPathwaysViewSet, + PromptRequestException +) +from enterprise_access.apps.core.constants import SYSTEM_ENTERPRISE_LEARNER_ROLE +from enterprise_access.apps.core.tests.factories import UserFactory from enterprise_access.apps.prompts.api_client import ( XpertAPIConfigurationError, XpertAPIError, XpertAPIRequestError, XpertAPIResponseError ) +from enterprise_access.apps.prompts.models import PromptType, XpertLearnerPathwaysSystemPrompt +from enterprise_access.apps.prompts.tests.factories import XpertLearnerPathwaysSystemPromptFactory +from test_utils import APITest PATCH_XPERT_CLIENT = 'enterprise_access.apps.api.v1.views.prompt.XpertAPIClient' PATCH_GET_REQUEST_ID = 'enterprise_access.apps.api.v1.views.prompt.get_request_id' PATCH_UUID4 = 'enterprise_access.apps.api.v1.views.prompt.uuid_module.uuid4' +PATCH_CONTEXTS_ACCESSIBLE = 'enterprise_access.apps.api.v1.views.prompt.contexts_accessible_from_request' + +_LEARNING_INTENT_URL_NAME = 'api:v1:learner-pathways-learning-intent' + +_VALID_LEARNING_INTENT_PAYLOAD = { + 'selected_goals': 'data science', + 'free_text': 'I want to become a data scientist', + 'known_context': 'currently a software engineer', +} def _make_viewset(): @@ -484,3 +511,589 @@ def test_no_repair_or_fallback_on_bad_json(self): def test_parse_json_content_requires_string_contract(self): with self.assertRaises(AttributeError): self.viewset._parse_json_content({'not': 'a string'}) + + +# --------------------------------------------------------------------------- +# Serializer tests +# --------------------------------------------------------------------------- + +@ddt.ddt +class TestLearningIntentRequestSerializer(TestCase): + """Tests for LearningIntentRequestSerializer.""" + + def _valid(self): + return dict(_VALID_LEARNING_INTENT_PAYLOAD) + + def test_valid_payload_succeeds(self): + s = api_serializers.LearningIntentRequestSerializer(data=self._valid()) + self.assertTrue(s.is_valid(), s.errors) + + @ddt.data('selected_goals', 'free_text', 'known_context') + def test_missing_field_fails(self, field): + data = self._valid() + del data[field] + s = api_serializers.LearningIntentRequestSerializer(data=data) + self.assertFalse(s.is_valid()) + self.assertIn(field, s.errors) + + @ddt.data('selected_goals', 'free_text', 'known_context') + def test_blank_field_fails(self, field): + data = self._valid() + data[field] = '' + s = api_serializers.LearningIntentRequestSerializer(data=data) + self.assertFalse(s.is_valid()) + self.assertIn(field, s.errors) + + @ddt.data('selected_goals', 'free_text', 'known_context') + def test_whitespace_only_field_fails(self, field): + data = self._valid() + data[field] = ' ' + s = api_serializers.LearningIntentRequestSerializer(data=data) + self.assertFalse(s.is_valid()) + self.assertIn(field, s.errors) + + @ddt.data( + ('selected_goals', 123), + ('free_text', []), + ('known_context', {'nested': True}), + ) + @ddt.unpack + def test_non_string_value_coerced_or_fails(self, field, value): + data = self._valid() + data[field] = value + s = api_serializers.LearningIntentRequestSerializer(data=data) + # DRF CharField coerces non-strings; result must still be non-blank. + # 123 → '123' (valid), [] → '' (invalid blank), {} → repr (valid) + if s.is_valid(): + self.assertIsInstance(s.validated_data[field], str) + self.assertGreater(len(s.validated_data[field]), 0) + else: + self.assertIn(field, s.errors) + + +# --------------------------------------------------------------------------- +# Routing tests +# --------------------------------------------------------------------------- + +class TestLearnerPathwaysRouting(TestCase): + """Tests for URL resolution of LearnerPathwaysViewSet.""" + + def test_learning_intent_url_reverses(self): + url = reverse(_LEARNING_INTENT_URL_NAME) + self.assertIn('learner-pathways', url) + self.assertIn('learning-intent', url) + + def test_learning_intent_post_accepted(self): + client = APIClient() + url = reverse(_LEARNING_INTENT_URL_NAME) + response = client.post(url, data={}, format='json') + # Unauthenticated — 401 or 403, but NOT 405. + self.assertNotEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_learning_intent_get_rejected(self, _mock_contexts): + _mock_contexts.return_value = {str(uuid.uuid4())} + url = reverse(_LEARNING_INTENT_URL_NAME) + client = APIClient() + user = UserFactory(is_active=True) + client.force_authenticate(user=user) + response = client.get(url) + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + + +# --------------------------------------------------------------------------- +# Route configuration tests +# --------------------------------------------------------------------------- + +class TestLearnerPathwaysRouteConfig(TestCase): + """Assert route-level configuration for each action.""" + # pylint: disable=no-member # DRF @action adds .kwargs at decoration time; pylint can't see it. + + def _get_action(self, name): + return getattr(LearnerPathwaysViewSet, name) + + def test_learning_intent_authentication_classes(self): + ac = self._get_action('learning_intent').kwargs.get('authentication_classes', ()) + self.assertIn(JwtAuthentication, ac) + + def test_learning_intent_is_authenticated_permission(self): + pc = self._get_action('learning_intent').kwargs.get('permission_classes', ()) + self.assertIn(permissions.IsAuthenticated, pc) + + def test_learning_intent_is_enterprise_learner_permission(self): + pc = self._get_action('learning_intent').kwargs.get('permission_classes', ()) + self.assertIn(IsEnterpriseLearner, pc) + + def test_learning_intent_throttle_class(self): + tc = self._get_action('learning_intent').kwargs.get('throttle_classes', ()) + self.assertIn(ScopedRateThrottle, tc) + + def test_learning_intent_throttle_scope(self): + scope = self._get_action('learning_intent').kwargs.get('throttle_scope') + self.assertEqual(scope, 'learner_pathways_learning_intent') + + def test_no_throttle_on_base_prompt_viewset(self): + # throttle_classes must not be explicitly defined on BasePromptViewSet itself + self.assertNotIn('throttle_classes', BasePromptViewSet.__dict__) + self.assertNotIn('throttle_scope', BasePromptViewSet.__dict__) + + def test_no_class_level_throttle_classes_on_learner_pathways_viewset(self): + self.assertNotIn('throttle_classes', LearnerPathwaysViewSet.__dict__) + + def test_throttle_scope_sentinel_is_none(self): + self.assertIsNone(LearnerPathwaysViewSet.throttle_scope) + + +# --------------------------------------------------------------------------- +# Authorization tests +# --------------------------------------------------------------------------- + +@ddt.ddt +class TestLearnerPathwaysAuthorization(APITest): + """Authorization tests for the learning-intent endpoint.""" + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.learning_intent_prompt = XpertLearnerPathwaysSystemPromptFactory( + prompt_type=PromptType.LEARNER_INTENT, + ) + + def setUp(self): + super().setUp() + self.addCleanup(django_cache.clear) + + @ddt.data(_LEARNING_INTENT_URL_NAME) + def test_unauthenticated_caller_is_rejected(self, url_name): + self.client.logout() + self.client.cookies.clear() + url = reverse(url_name) + response = self.client.post(url, data={}, format='json') + self.assertIn(response.status_code, [ + status.HTTP_401_UNAUTHORIZED, + status.HTTP_403_FORBIDDEN, + ]) + + @ddt.data(_LEARNING_INTENT_URL_NAME) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE, return_value=set()) + def test_authenticated_non_enterprise_user_rejected(self, url_name, _mock_contexts): + self.set_jwt_cookie([]) + url = reverse(url_name) + response = self.client.post(url, data={}, format='json') + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + @ddt.data( + (_LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), + ) + @ddt.unpack + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_enterprise_learner_is_allowed( + self, url_name, payload, mock_contexts, mock_client_class, + ): + mock_contexts.return_value = {str(uuid.uuid4())} + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', + 'content': '{"result":"ok"}', + } + self.set_jwt_cookie([{ + 'system_wide_role': SYSTEM_ENTERPRISE_LEARNER_ROLE, + 'context': str(uuid.uuid4()), + }]) + url = reverse(url_name) + response = self.client.post(url, data=payload, format='json') + self.assertEqual(response.status_code, status.HTTP_200_OK) + + @ddt.data(_LEARNING_INTENT_URL_NAME) + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE, return_value=set()) + def test_xpert_not_called_when_auth_fails(self, url_name, _mock_contexts, mock_client_class): + self.set_jwt_cookie([]) + url = reverse(url_name) + self.client.post(url, data={}, format='json') + mock_client_class.return_value.send_message.assert_not_called() + + +# --------------------------------------------------------------------------- +# Throttle tests +# --------------------------------------------------------------------------- + +@override_settings(REST_FRAMEWORK={ + 'DEFAULT_THROTTLE_RATES': { + 'learner_pathways_learning_intent': '2/minute', + }, +}) +@ddt.ddt +class TestLearnerPathwaysThrottle(APITest): + """Throttle tests for the learning-intent endpoint.""" + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.learning_intent_prompt = XpertLearnerPathwaysSystemPromptFactory( + prompt_type=PromptType.LEARNER_INTENT, + ) + + def setUp(self): + super().setUp() + self.addCleanup(django_cache.clear) + self.set_jwt_cookie([{ + 'system_wide_role': SYSTEM_ENTERPRISE_LEARNER_ROLE, + 'context': str(uuid.uuid4()), + }]) + + def test_learning_intent_scope_in_default_throttle_rates(self): + rates = django_settings.REST_FRAMEWORK.get('DEFAULT_THROTTLE_RATES', {}) + self.assertIn('learner_pathways_learning_intent', rates) + + @mock.patch.object(ScopedRateThrottle, 'THROTTLE_RATES', { + 'learner_pathways_learning_intent': '2/minute', + }) + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_learning_intent_throttled_after_rate_exceeded(self, mock_contexts, mock_client_class): + mock_contexts.return_value = {str(uuid.uuid4())} + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', 'content': '{"r":1}', + } + url = reverse(_LEARNING_INTENT_URL_NAME) + for _ in range(2): + resp = self.client.post(url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + self.assertEqual(resp.status_code, status.HTTP_200_OK) + resp = self.client.post(url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + self.assertEqual(resp.status_code, status.HTTP_429_TOO_MANY_REQUESTS) + + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE, return_value=set()) + def test_auth_failure_does_not_call_xpert(self, _mock_contexts): + with mock.patch(PATCH_XPERT_CLIENT) as mock_client_class: + url = reverse(_LEARNING_INTENT_URL_NAME) + self.client.post(url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + mock_client_class.return_value.send_message.assert_not_called() + + +# --------------------------------------------------------------------------- +# Happy path tests — learning intent +# --------------------------------------------------------------------------- + +class TestLearningIntentHappyPath(APITest): + """Full happy-path tests for the learning-intent action.""" + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.prompt = XpertLearnerPathwaysSystemPromptFactory( + prompt_type=PromptType.LEARNER_INTENT, + ) + cls.other_prompt = XpertLearnerPathwaysSystemPromptFactory( + prompt_type=PromptType.RECOMMENDATIONS_FEEDBACK, + ) + + def setUp(self): + super().setUp() + self.addCleanup(django_cache.clear) + self.set_jwt_cookie([{ + 'system_wide_role': SYSTEM_ENTERPRISE_LEARNER_ROLE, + 'context': str(uuid.uuid4()), + }]) + self.url = reverse(_LEARNING_INTENT_URL_NAME) + + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_http_200_with_valid_payload(self, mock_contexts, mock_client_class): + mock_contexts.return_value = {str(uuid.uuid4())} + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', + 'content': '{"skills_required":["python"]}', + } + resp = self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + self.assertEqual(resp.status_code, status.HTTP_200_OK) + + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_correct_prompt_type_used(self, mock_contexts, mock_client_class): + mock_contexts.return_value = {str(uuid.uuid4())} + with mock.patch.object( + XpertLearnerPathwaysSystemPrompt, 'get_current' + ) as mock_get_current: + mock_prompt = mock.Mock() + mock_prompt.system_prompt = 'Be helpful.' + mock_prompt.output_schema = None + mock_get_current.return_value = mock_prompt + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', 'content': '{"r":1}', + } + self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + mock_get_current.assert_called_once_with( + prompt_type=PromptType.LEARNER_INTENT, + ) + + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_server_controlled_tags_passed(self, mock_contexts, mock_client_class): + mock_contexts.return_value = {str(uuid.uuid4())} + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', 'content': '{"r":1}', + } + with override_settings(XPERT_LEARNER_PATHWAYS_RAG_TAGS=['tag-a', 'tag-b']): + self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + call_kwargs = mock_client_class.return_value.send_message.call_args.kwargs + self.assertEqual(call_kwargs['tags'], ['tag-a', 'tag-b']) + + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_xpert_called_exactly_once(self, mock_contexts, mock_client_class): + mock_contexts.return_value = {str(uuid.uuid4())} + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', 'content': '{"r":1}', + } + self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + self.assertEqual(mock_client_class.return_value.send_message.call_count, 1) + + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_full_parsed_json_returned(self, mock_contexts, mock_client_class): + mock_contexts.return_value = {str(uuid.uuid4())} + payload_json = '{"skills_required":["python","ml"],"condensed_algolia_query":"data"}' + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', + 'content': payload_json, + } + resp = self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + self.assertEqual(resp.status_code, status.HTTP_200_OK) + self.assertEqual(resp.json(), json.loads(payload_json)) + + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_validated_data_encoded_as_user_message(self, mock_contexts, mock_client_class): + mock_contexts.return_value = {str(uuid.uuid4())} + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', 'content': '{"r":1}', + } + self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + call_kwargs = mock_client_class.return_value.send_message.call_args.kwargs + messages = call_kwargs['messages'] + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]['role'], 'user') + self.assertIsInstance(messages[0]['content'], str) + parsed = json.loads(messages[0]['content']) + self.assertEqual(parsed['selected_goals'], _VALID_LEARNING_INTENT_PAYLOAD['selected_goals']) + + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_conversation_id_has_prefix(self, mock_contexts, mock_client_class): + mock_contexts.return_value = {str(uuid.uuid4())} + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', 'content': '{"r":1}', + } + self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + call_kwargs = mock_client_class.return_value.send_message.call_args.kwargs + self.assertTrue(call_kwargs['conversation_id'].startswith('enterprise-access:')) + + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_role_field_not_returned(self, mock_contexts, mock_client_class): + mock_contexts.return_value = {str(uuid.uuid4())} + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', + 'content': '{"answer":"yes"}', + } + resp = self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + self.assertNotIn('role', resp.json()) + + +# --------------------------------------------------------------------------- +# Response passthrough tests +# --------------------------------------------------------------------------- + +@ddt.ddt +class TestLearnerPathwaysResponsePassthrough(APITest): + """Assert Xpert response content is returned verbatim without filtering.""" + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.learning_intent_prompt = XpertLearnerPathwaysSystemPromptFactory( + prompt_type=PromptType.LEARNER_INTENT, + ) + + def setUp(self): + super().setUp() + self.addCleanup(django_cache.clear) + self.set_jwt_cookie([{ + 'system_wide_role': SYSTEM_ENTERPRISE_LEARNER_ROLE, + 'context': str(uuid.uuid4()), + }]) + + @ddt.data( + ('learning_intent', _LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), + ) + @ddt.unpack + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_extra_top_level_fields_preserved( + self, _action, url_name, payload, mock_contexts, mock_client_class, + ): + mock_contexts.return_value = {str(uuid.uuid4())} + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', + 'content': '{"result":"ok","extra_field":"preserved"}', + } + resp = self.client.post(reverse(url_name), data=payload, format='json') + self.assertEqual(resp.status_code, status.HTTP_200_OK) + self.assertIn('extra_field', resp.json()) + + @ddt.data( + ('learning_intent', _LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), + ) + @ddt.unpack + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_list_response_returned_as_list( + self, _action, url_name, payload, mock_contexts, mock_client_class, + ): + mock_contexts.return_value = {str(uuid.uuid4())} + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', + 'content': '[1,2,3]', + } + resp = self.client.post(reverse(url_name), data=payload, format='json') + self.assertEqual(resp.status_code, status.HTTP_200_OK) + self.assertIsInstance(resp.json(), list) + + @ddt.data( + ('learning_intent', _LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), + ) + @ddt.unpack + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_nested_values_preserved( + self, _action, url_name, payload, mock_contexts, mock_client_class, + ): + mock_contexts.return_value = {str(uuid.uuid4())} + nested = {'a': {'b': {'c': [1, 2, 3]}}} + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', + 'content': json.dumps(nested), + } + resp = self.client.post(reverse(url_name), data=payload, format='json') + self.assertEqual(resp.status_code, status.HTTP_200_OK) + self.assertEqual(resp.json(), nested) + + +# --------------------------------------------------------------------------- +# Failure tests +# --------------------------------------------------------------------------- + +@ddt.ddt +class TestLearnerPathwaysFailures(APITest): + """500-series failure paths for the learning-intent endpoint.""" + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.learning_intent_prompt = XpertLearnerPathwaysSystemPromptFactory( + prompt_type=PromptType.LEARNER_INTENT, + ) + + def setUp(self): + super().setUp() + self.addCleanup(django_cache.clear) + self.set_jwt_cookie([{ + 'system_wide_role': SYSTEM_ENTERPRISE_LEARNER_ROLE, + 'context': str(uuid.uuid4()), + }]) + + @ddt.data( + (_LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), + ) + @ddt.unpack + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_missing_prompt_returns_500(self, url_name, payload, mock_contexts): + mock_contexts.return_value = {str(uuid.uuid4())} + with mock.patch.object(XpertLearnerPathwaysSystemPrompt, 'get_current', return_value=None): + resp = self.client.post(reverse(url_name), data=payload, format='json') + self.assertEqual(resp.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + + @ddt.data( + XpertAPIConfigurationError, + XpertAPIRequestError, + XpertAPIResponseError, + ) + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_xpert_error_returns_500(self, error_class, mock_contexts, mock_client_class): + mock_contexts.return_value = {str(uuid.uuid4())} + mock_client_class.return_value.send_message.side_effect = error_class('xpert error') + resp = self.client.post( + reverse(_LEARNING_INTENT_URL_NAME), + data=_VALID_LEARNING_INTENT_PAYLOAD, + format='json', + ) + self.assertEqual(resp.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + + @ddt.data( + ('missing', {'role': 'assistant'}), + ('none', {'role': 'assistant', 'content': None}), + ('non_string', {'role': 'assistant', 'content': 123}), + ) + @ddt.unpack + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_bad_content_returns_500( + self, _case, xpert_response, mock_contexts, mock_client_class, + ): + mock_contexts.return_value = {str(uuid.uuid4())} + mock_client_class.return_value.send_message.return_value = xpert_response + resp = self.client.post( + reverse(_LEARNING_INTENT_URL_NAME), + data=_VALID_LEARNING_INTENT_PAYLOAD, + format='json', + ) + self.assertEqual(resp.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + + @ddt.data( + 'not valid json', + '```json\n{"key":"value"}\n```', + '{"unterminated": true', + ) + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_invalid_json_content_returns_500(self, bad_content, mock_contexts, mock_client_class): + mock_contexts.return_value = {str(uuid.uuid4())} + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', 'content': bad_content, + } + resp = self.client.post( + reverse(_LEARNING_INTENT_URL_NAME), + data=_VALID_LEARNING_INTENT_PAYLOAD, + format='json', + ) + self.assertEqual(resp.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_no_second_xpert_call_on_failure(self, mock_contexts, mock_client_class): + mock_contexts.return_value = {str(uuid.uuid4())} + mock_client_class.return_value.send_message.side_effect = XpertAPIRequestError('fail') + self.client.post( + reverse(_LEARNING_INTENT_URL_NAME), + data=_VALID_LEARNING_INTENT_PAYLOAD, + format='json', + ) + self.assertEqual(mock_client_class.return_value.send_message.call_count, 1) + + @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) + def test_no_fallback_object_returned(self, mock_contexts, mock_client_class): + mock_contexts.return_value = {str(uuid.uuid4())} + mock_client_class.return_value.send_message.side_effect = XpertAPIRequestError('fail') + resp = self.client.post( + reverse(_LEARNING_INTENT_URL_NAME), + data=_VALID_LEARNING_INTENT_PAYLOAD, + format='json', + ) + self.assertEqual(resp.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + if resp.get('Content-Type', '').startswith('application/json'): + body = resp.json() + self.assertNotIn('skills_required', body) + self.assertNotIn('reasons', body) diff --git a/enterprise_access/apps/api/v1/urls.py b/enterprise_access/apps/api/v1/urls.py index ca52d9fa..61ca59c2 100644 --- a/enterprise_access/apps/api/v1/urls.py +++ b/enterprise_access/apps/api/v1/urls.py @@ -11,6 +11,7 @@ router = DefaultRouter() router.register("testimonials", views.TestimonialViewSet, "testimonials") +router.register(r'learner-pathways', views.LearnerPathwaysViewSet, 'learner-pathways') router.register("policy-redemption", views.SubsidyAccessPolicyRedeemViewset, 'policy-redemption') router.register("policy-allocation", views.SubsidyAccessPolicyAllocateViewset, 'policy-allocation') router.register("subsidy-access-policies", views.SubsidyAccessPolicyViewSet, 'subsidy-access-policies') diff --git a/enterprise_access/apps/api/v1/views/__init__.py b/enterprise_access/apps/api/v1/views/__init__.py index 18a5c090..db04c27f 100644 --- a/enterprise_access/apps/api/v1/views/__init__.py +++ b/enterprise_access/apps/api/v1/views/__init__.py @@ -23,6 +23,7 @@ SspProductViewSet, StripeEventSummaryViewSet ) +from .prompt import LearnerPathwaysViewSet from .provisioning import ProvisioningCreateView, SubscriptionPlanOLIUpdateView from .subsidy_access_policy import ( SubsidyAccessPolicyAllocateViewset, diff --git a/enterprise_access/apps/api/v1/views/prompt.py b/enterprise_access/apps/api/v1/views/prompt.py index c2fd4b2a..0faa7aa5 100644 --- a/enterprise_access/apps/api/v1/views/prompt.py +++ b/enterprise_access/apps/api/v1/views/prompt.py @@ -7,14 +7,24 @@ from collections.abc import Sequence from typing import TypeAlias, cast -from rest_framework import serializers, status +from django.conf import settings +from drf_spectacular.utils import extend_schema +from edx_rbac.utils import contexts_accessible_from_request +from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication +from rest_framework import permissions, serializers, status +from rest_framework.decorators import action from rest_framework.exceptions import APIException from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.throttling import ScopedRateThrottle from rest_framework.viewsets import ViewSet +from enterprise_access.apps.api import serializers as api_serializers +from enterprise_access.apps.api.serializers.learner_pathways import LEARNER_PATHWAYS_API_TAG from enterprise_access.apps.api_client.base_user import get_request_id +from enterprise_access.apps.core.constants import BFF_LEARNER_ROLE from enterprise_access.apps.prompts.api_client import XpertAPIClient, XpertAPIError -from enterprise_access.apps.prompts.models import BaseSystemPrompt +from enterprise_access.apps.prompts.models import BaseSystemPrompt, PromptType, XpertLearnerPathwaysSystemPrompt logger = logging.getLogger(__name__) @@ -99,6 +109,16 @@ def _get_current_prompt( """ Resolve the current prompt for the exact supplied prompt type. """ + if prompt_model is None: + raise PromptRequestException( + 'prompt_model is a required configuration argument.' + ) + + if prompt_type is None: + raise PromptRequestException( + 'prompt_type is a required configuration argument.' + ) + prompt = prompt_model.get_current( prompt_type=prompt_type, ) @@ -221,6 +241,12 @@ def _extract_xpert_content( 'Xpert response is missing the "content" field.' ) + if not isinstance(content, str): + raise PromptRequestException( + 'Xpert response "content" is not a string: ' + f'got {type(content).__name__}.' + ) + return content def _parse_json_content( @@ -243,3 +269,108 @@ def _parse_json_content( ) from exc return cast(JSONValue, parsed_content) + + +class IsEnterpriseLearner(permissions.BasePermission): + """ + Permit requests from authenticated users associated with at least one enterprise as a learner. + + Uses the BFF_LEARNER_ROLE feature role, which is mapped from SYSTEM_ENTERPRISE_LEARNER_ROLE + in SYSTEM_TO_FEATURE_ROLE_MAPPING. Enterprise admins have BFF_ADMIN_ROLE and are not permitted + by this check. + + No existing "any enterprise learner" DRF permission class was found in the repository. + This is the minimal consistent implementation using the existing edx_rbac infrastructure. + """ + + def has_permission(self, request: Request, view: object) -> bool: + try: + contexts = contexts_accessible_from_request(request, [BFF_LEARNER_ROLE]) + return bool(contexts) + except Exception: # pylint: disable=broad-except + return False + + +class LearnerPathwaysViewSet(BasePromptViewSet): + """ + Endpoints for the Learner Pathways Xpert-backed feature. + + Each action defines its own authentication, permissions, and throttle configuration + explicitly. No shared authentication, permissions, or throttle classes are defined + at the class level. + """ + + model_type = XpertLearnerPathwaysSystemPrompt + + # DRF 3.17.1 ViewSetMixin.as_view() rejects any @action kwarg that is not already a + # class attribute (hasattr check). throttle_scope is not defined on APIView or ViewSet, + # so a class-level sentinel is required to allow per-action propagation. This sentinel + # does not configure a shared throttle; the actual scope values are set per action. + throttle_scope: str | None = None + + @extend_schema( + tags=[LEARNER_PATHWAYS_API_TAG], + summary='Derive learning intent from learner input.', + description=( + 'Calls Xpert with the learner\'s stated goals, free-text input, and known context ' + 'to derive skills and a search query. Returns the raw JSON produced by Xpert.' + ), + request=api_serializers.LearningIntentRequestSerializer, + responses={ + status.HTTP_200_OK: api_serializers.LearningIntentResponseSerializer, + status.HTTP_400_BAD_REQUEST: None, + status.HTTP_401_UNAUTHORIZED: None, + status.HTTP_403_FORBIDDEN: None, + status.HTTP_429_TOO_MANY_REQUESTS: None, + status.HTTP_500_INTERNAL_SERVER_ERROR: None, + }, + ) + @action( + detail=False, + methods=['post'], + url_path='learning-intent', + url_name='learning-intent', + authentication_classes=(JwtAuthentication,), + permission_classes=(permissions.IsAuthenticated, IsEnterpriseLearner), + throttle_classes=(ScopedRateThrottle,), + throttle_scope='learner_pathways_learning_intent', + ) + def learning_intent(self, request: Request) -> Response: + """ + Derive learning intent from the learner's stated goals, free-text input, and known context. + + Returns HTTP 400 for invalid request input. + Returns HTTP 401/403 when the caller is unauthenticated or not an enterprise learner. + Returns HTTP 429 when the per-endpoint rate limit is exceeded. + Returns HTTP 500 when the prompt is missing, the Xpert call fails, or the response + cannot be parsed as JSON. + """ + validated_data = self._validate_request( + request, + api_serializers.LearningIntentRequestSerializer, + ) + + prompt = self._get_current_prompt( + prompt_model=self.model_type, + prompt_type=PromptType.LEARNER_INTENT, + ) + + system_prompt = self._build_system_prompt(prompt) + + messages = self._build_messages(validated_data) + + conversation_id = self._get_conversation_id(request) + + xpert_response = self._send_xpert_message( + system_prompt=system_prompt, + messages=messages, + conversation_id=conversation_id, + tags=settings.XPERT_LEARNER_PATHWAYS_RAG_TAGS, + prompt_type=PromptType.LEARNER_INTENT, + ) + + content = self._extract_xpert_content(xpert_response) + + response_data = self._parse_json_content(content) + + return Response(response_data, status=status.HTTP_200_OK) diff --git a/enterprise_access/settings/base.py b/enterprise_access/settings/base.py index 7d419360..b02e390b 100644 --- a/enterprise_access/settings/base.py +++ b/enterprise_access/settings/base.py @@ -183,6 +183,9 @@ def root(*path_fragments): 'DEFAULT_THROTTLE_RATES': { 'bff_unauthenticated': '100/hour', 'ssp_product': '120/hour', + # BLOCKER: rate values for learner-pathways scopes are not yet defined. + # None disables throttling until a rate is agreed upon. + 'learner_pathways_learning_intent': None, }, } From 5476ef82a740be2e40d9e786067d97cec62b33bb Mon Sep 17 00:00:00 2001 From: Hamzah Ullah Date: Wed, 24 Jun 2026 15:17:32 +0000 Subject: [PATCH 3/6] chore: cleanup implementation --- .../apps/api/v1/tests/test_prompt_views.py | 111 ++++-------------- enterprise_access/apps/api/v1/urls.py | 2 +- enterprise_access/apps/api/v1/views/prompt.py | 49 +------- enterprise_access/settings/base.py | 4 +- 4 files changed, 25 insertions(+), 141 deletions(-) diff --git a/enterprise_access/apps/api/v1/tests/test_prompt_views.py b/enterprise_access/apps/api/v1/tests/test_prompt_views.py index 0aeb0953..74650600 100644 --- a/enterprise_access/apps/api/v1/tests/test_prompt_views.py +++ b/enterprise_access/apps/api/v1/tests/test_prompt_views.py @@ -18,12 +18,7 @@ from rest_framework.throttling import ScopedRateThrottle from enterprise_access.apps.api import serializers as api_serializers -from enterprise_access.apps.api.v1.views.prompt import ( - BasePromptViewSet, - IsEnterpriseLearner, - LearnerPathwaysViewSet, - PromptRequestException -) +from enterprise_access.apps.api.v1.views.prompt import BasePromptViewSet, LearnerPathwaysViewSet, PromptRequestException from enterprise_access.apps.core.constants import SYSTEM_ENTERPRISE_LEARNER_ROLE from enterprise_access.apps.core.tests.factories import UserFactory from enterprise_access.apps.prompts.api_client import ( @@ -39,7 +34,6 @@ PATCH_XPERT_CLIENT = 'enterprise_access.apps.api.v1.views.prompt.XpertAPIClient' PATCH_GET_REQUEST_ID = 'enterprise_access.apps.api.v1.views.prompt.get_request_id' PATCH_UUID4 = 'enterprise_access.apps.api.v1.views.prompt.uuid_module.uuid4' -PATCH_CONTEXTS_ACCESSIBLE = 'enterprise_access.apps.api.v1.views.prompt.contexts_accessible_from_request' _LEARNING_INTENT_URL_NAME = 'api:v1:learner-pathways-learning-intent' @@ -590,9 +584,7 @@ def test_learning_intent_post_accepted(self): # Unauthenticated — 401 or 403, but NOT 405. self.assertNotEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) - def test_learning_intent_get_rejected(self, _mock_contexts): - _mock_contexts.return_value = {str(uuid.uuid4())} + def test_learning_intent_get_rejected(self): url = reverse(_LEARNING_INTENT_URL_NAME) client = APIClient() user = UserFactory(is_active=True) @@ -620,10 +612,6 @@ def test_learning_intent_is_authenticated_permission(self): pc = self._get_action('learning_intent').kwargs.get('permission_classes', ()) self.assertIn(permissions.IsAuthenticated, pc) - def test_learning_intent_is_enterprise_learner_permission(self): - pc = self._get_action('learning_intent').kwargs.get('permission_classes', ()) - self.assertIn(IsEnterpriseLearner, pc) - def test_learning_intent_throttle_class(self): tc = self._get_action('learning_intent').kwargs.get('throttle_classes', ()) self.assertIn(ScopedRateThrottle, tc) @@ -674,24 +662,14 @@ def test_unauthenticated_caller_is_rejected(self, url_name): status.HTTP_403_FORBIDDEN, ]) - @ddt.data(_LEARNING_INTENT_URL_NAME) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE, return_value=set()) - def test_authenticated_non_enterprise_user_rejected(self, url_name, _mock_contexts): - self.set_jwt_cookie([]) - url = reverse(url_name) - response = self.client.post(url, data={}, format='json') - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - @ddt.data( (_LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), ) @ddt.unpack @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) def test_enterprise_learner_is_allowed( - self, url_name, payload, mock_contexts, mock_client_class, + self, url_name, payload, mock_client_class, ): - mock_contexts.return_value = {str(uuid.uuid4())} mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', 'content': '{"result":"ok"}', @@ -706,8 +684,7 @@ def test_enterprise_learner_is_allowed( @ddt.data(_LEARNING_INTENT_URL_NAME) @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE, return_value=set()) - def test_xpert_not_called_when_auth_fails(self, url_name, _mock_contexts, mock_client_class): + def test_xpert_not_called_when_auth_fails(self, url_name, mock_client_class): self.set_jwt_cookie([]) url = reverse(url_name) self.client.post(url, data={}, format='json') @@ -750,9 +727,7 @@ def test_learning_intent_scope_in_default_throttle_rates(self): 'learner_pathways_learning_intent': '2/minute', }) @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) - def test_learning_intent_throttled_after_rate_exceeded(self, mock_contexts, mock_client_class): - mock_contexts.return_value = {str(uuid.uuid4())} + def test_learning_intent_throttled_after_rate_exceeded(self, mock_client_class): mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', 'content': '{"r":1}', } @@ -763,13 +738,6 @@ def test_learning_intent_throttled_after_rate_exceeded(self, mock_contexts, mock resp = self.client.post(url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') self.assertEqual(resp.status_code, status.HTTP_429_TOO_MANY_REQUESTS) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE, return_value=set()) - def test_auth_failure_does_not_call_xpert(self, _mock_contexts): - with mock.patch(PATCH_XPERT_CLIENT) as mock_client_class: - url = reverse(_LEARNING_INTENT_URL_NAME) - self.client.post(url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') - mock_client_class.return_value.send_message.assert_not_called() - # --------------------------------------------------------------------------- # Happy path tests — learning intent @@ -798,9 +766,7 @@ def setUp(self): self.url = reverse(_LEARNING_INTENT_URL_NAME) @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) - def test_http_200_with_valid_payload(self, mock_contexts, mock_client_class): - mock_contexts.return_value = {str(uuid.uuid4())} + def test_http_200_with_valid_payload(self, mock_client_class): mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', 'content': '{"skills_required":["python"]}', @@ -809,9 +775,7 @@ def test_http_200_with_valid_payload(self, mock_contexts, mock_client_class): self.assertEqual(resp.status_code, status.HTTP_200_OK) @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) - def test_correct_prompt_type_used(self, mock_contexts, mock_client_class): - mock_contexts.return_value = {str(uuid.uuid4())} + def test_correct_prompt_type_used(self, mock_client_class): with mock.patch.object( XpertLearnerPathwaysSystemPrompt, 'get_current' ) as mock_get_current: @@ -828,9 +792,7 @@ def test_correct_prompt_type_used(self, mock_contexts, mock_client_class): ) @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) - def test_server_controlled_tags_passed(self, mock_contexts, mock_client_class): - mock_contexts.return_value = {str(uuid.uuid4())} + def test_server_controlled_tags_passed(self, mock_client_class): mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', 'content': '{"r":1}', } @@ -840,9 +802,7 @@ def test_server_controlled_tags_passed(self, mock_contexts, mock_client_class): self.assertEqual(call_kwargs['tags'], ['tag-a', 'tag-b']) @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) - def test_xpert_called_exactly_once(self, mock_contexts, mock_client_class): - mock_contexts.return_value = {str(uuid.uuid4())} + def test_xpert_called_exactly_once(self, mock_client_class): mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', 'content': '{"r":1}', } @@ -850,9 +810,7 @@ def test_xpert_called_exactly_once(self, mock_contexts, mock_client_class): self.assertEqual(mock_client_class.return_value.send_message.call_count, 1) @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) - def test_full_parsed_json_returned(self, mock_contexts, mock_client_class): - mock_contexts.return_value = {str(uuid.uuid4())} + def test_full_parsed_json_returned(self, mock_client_class): payload_json = '{"skills_required":["python","ml"],"condensed_algolia_query":"data"}' mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', @@ -863,9 +821,7 @@ def test_full_parsed_json_returned(self, mock_contexts, mock_client_class): self.assertEqual(resp.json(), json.loads(payload_json)) @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) - def test_validated_data_encoded_as_user_message(self, mock_contexts, mock_client_class): - mock_contexts.return_value = {str(uuid.uuid4())} + def test_validated_data_encoded_as_user_message(self, mock_client_class): mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', 'content': '{"r":1}', } @@ -879,9 +835,7 @@ def test_validated_data_encoded_as_user_message(self, mock_contexts, mock_client self.assertEqual(parsed['selected_goals'], _VALID_LEARNING_INTENT_PAYLOAD['selected_goals']) @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) - def test_conversation_id_has_prefix(self, mock_contexts, mock_client_class): - mock_contexts.return_value = {str(uuid.uuid4())} + def test_conversation_id_has_prefix(self, mock_client_class): mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', 'content': '{"r":1}', } @@ -890,9 +844,7 @@ def test_conversation_id_has_prefix(self, mock_contexts, mock_client_class): self.assertTrue(call_kwargs['conversation_id'].startswith('enterprise-access:')) @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) - def test_role_field_not_returned(self, mock_contexts, mock_client_class): - mock_contexts.return_value = {str(uuid.uuid4())} + def test_role_field_not_returned(self, mock_client_class): mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', 'content': '{"answer":"yes"}', @@ -929,11 +881,9 @@ def setUp(self): ) @ddt.unpack @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) def test_extra_top_level_fields_preserved( - self, _action, url_name, payload, mock_contexts, mock_client_class, + self, _action, url_name, payload, mock_client_class, ): - mock_contexts.return_value = {str(uuid.uuid4())} mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', 'content': '{"result":"ok","extra_field":"preserved"}', @@ -947,11 +897,9 @@ def test_extra_top_level_fields_preserved( ) @ddt.unpack @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) def test_list_response_returned_as_list( - self, _action, url_name, payload, mock_contexts, mock_client_class, + self, _action, url_name, payload, mock_client_class, ): - mock_contexts.return_value = {str(uuid.uuid4())} mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', 'content': '[1,2,3]', @@ -965,11 +913,9 @@ def test_list_response_returned_as_list( ) @ddt.unpack @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) def test_nested_values_preserved( - self, _action, url_name, payload, mock_contexts, mock_client_class, + self, _action, url_name, payload, mock_client_class, ): - mock_contexts.return_value = {str(uuid.uuid4())} nested = {'a': {'b': {'c': [1, 2, 3]}}} mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', @@ -1007,9 +953,7 @@ def setUp(self): (_LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), ) @ddt.unpack - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) - def test_missing_prompt_returns_500(self, url_name, payload, mock_contexts): - mock_contexts.return_value = {str(uuid.uuid4())} + def test_missing_prompt_returns_500(self, url_name, payload): with mock.patch.object(XpertLearnerPathwaysSystemPrompt, 'get_current', return_value=None): resp = self.client.post(reverse(url_name), data=payload, format='json') self.assertEqual(resp.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) @@ -1020,9 +964,7 @@ def test_missing_prompt_returns_500(self, url_name, payload, mock_contexts): XpertAPIResponseError, ) @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) - def test_xpert_error_returns_500(self, error_class, mock_contexts, mock_client_class): - mock_contexts.return_value = {str(uuid.uuid4())} + def test_xpert_error_returns_500(self, error_class, mock_client_class): mock_client_class.return_value.send_message.side_effect = error_class('xpert error') resp = self.client.post( reverse(_LEARNING_INTENT_URL_NAME), @@ -1034,15 +976,12 @@ def test_xpert_error_returns_500(self, error_class, mock_contexts, mock_client_c @ddt.data( ('missing', {'role': 'assistant'}), ('none', {'role': 'assistant', 'content': None}), - ('non_string', {'role': 'assistant', 'content': 123}), ) @ddt.unpack @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) def test_bad_content_returns_500( - self, _case, xpert_response, mock_contexts, mock_client_class, + self, _case, xpert_response, mock_client_class, ): - mock_contexts.return_value = {str(uuid.uuid4())} mock_client_class.return_value.send_message.return_value = xpert_response resp = self.client.post( reverse(_LEARNING_INTENT_URL_NAME), @@ -1057,9 +996,7 @@ def test_bad_content_returns_500( '{"unterminated": true', ) @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) - def test_invalid_json_content_returns_500(self, bad_content, mock_contexts, mock_client_class): - mock_contexts.return_value = {str(uuid.uuid4())} + def test_invalid_json_content_returns_500(self, bad_content, mock_client_class): mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', 'content': bad_content, } @@ -1071,9 +1008,7 @@ def test_invalid_json_content_returns_500(self, bad_content, mock_contexts, mock self.assertEqual(resp.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) - def test_no_second_xpert_call_on_failure(self, mock_contexts, mock_client_class): - mock_contexts.return_value = {str(uuid.uuid4())} + def test_no_second_xpert_call_on_failure(self, mock_client_class): mock_client_class.return_value.send_message.side_effect = XpertAPIRequestError('fail') self.client.post( reverse(_LEARNING_INTENT_URL_NAME), @@ -1083,9 +1018,7 @@ def test_no_second_xpert_call_on_failure(self, mock_contexts, mock_client_class) self.assertEqual(mock_client_class.return_value.send_message.call_count, 1) @mock.patch(PATCH_XPERT_CLIENT) - @mock.patch(PATCH_CONTEXTS_ACCESSIBLE) - def test_no_fallback_object_returned(self, mock_contexts, mock_client_class): - mock_contexts.return_value = {str(uuid.uuid4())} + def test_no_fallback_object_returned(self, mock_client_class): mock_client_class.return_value.send_message.side_effect = XpertAPIRequestError('fail') resp = self.client.post( reverse(_LEARNING_INTENT_URL_NAME), diff --git a/enterprise_access/apps/api/v1/urls.py b/enterprise_access/apps/api/v1/urls.py index 61ca59c2..fe1a2e9f 100644 --- a/enterprise_access/apps/api/v1/urls.py +++ b/enterprise_access/apps/api/v1/urls.py @@ -11,7 +11,7 @@ router = DefaultRouter() router.register("testimonials", views.TestimonialViewSet, "testimonials") -router.register(r'learner-pathways', views.LearnerPathwaysViewSet, 'learner-pathways') +router.register('learner-pathways', views.LearnerPathwaysViewSet, 'learner-pathways') router.register("policy-redemption", views.SubsidyAccessPolicyRedeemViewset, 'policy-redemption') router.register("policy-allocation", views.SubsidyAccessPolicyAllocateViewset, 'policy-allocation') router.register("subsidy-access-policies", views.SubsidyAccessPolicyViewSet, 'subsidy-access-policies') diff --git a/enterprise_access/apps/api/v1/views/prompt.py b/enterprise_access/apps/api/v1/views/prompt.py index 0faa7aa5..37a7e3a1 100644 --- a/enterprise_access/apps/api/v1/views/prompt.py +++ b/enterprise_access/apps/api/v1/views/prompt.py @@ -9,7 +9,6 @@ from django.conf import settings from drf_spectacular.utils import extend_schema -from edx_rbac.utils import contexts_accessible_from_request from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication from rest_framework import permissions, serializers, status from rest_framework.decorators import action @@ -22,7 +21,6 @@ from enterprise_access.apps.api import serializers as api_serializers from enterprise_access.apps.api.serializers.learner_pathways import LEARNER_PATHWAYS_API_TAG from enterprise_access.apps.api_client.base_user import get_request_id -from enterprise_access.apps.core.constants import BFF_LEARNER_ROLE from enterprise_access.apps.prompts.api_client import XpertAPIClient, XpertAPIError from enterprise_access.apps.prompts.models import BaseSystemPrompt, PromptType, XpertLearnerPathwaysSystemPrompt @@ -99,7 +97,6 @@ def _validate_request( return serializer.validated_data - def _get_current_prompt( self, *, @@ -109,16 +106,6 @@ def _get_current_prompt( """ Resolve the current prompt for the exact supplied prompt type. """ - if prompt_model is None: - raise PromptRequestException( - 'prompt_model is a required configuration argument.' - ) - - if prompt_type is None: - raise PromptRequestException( - 'prompt_type is a required configuration argument.' - ) - prompt = prompt_model.get_current( prompt_type=prompt_type, ) @@ -241,12 +228,6 @@ def _extract_xpert_content( 'Xpert response is missing the "content" field.' ) - if not isinstance(content, str): - raise PromptRequestException( - 'Xpert response "content" is not a string: ' - f'got {type(content).__name__}.' - ) - return content def _parse_json_content( @@ -271,26 +252,6 @@ def _parse_json_content( return cast(JSONValue, parsed_content) -class IsEnterpriseLearner(permissions.BasePermission): - """ - Permit requests from authenticated users associated with at least one enterprise as a learner. - - Uses the BFF_LEARNER_ROLE feature role, which is mapped from SYSTEM_ENTERPRISE_LEARNER_ROLE - in SYSTEM_TO_FEATURE_ROLE_MAPPING. Enterprise admins have BFF_ADMIN_ROLE and are not permitted - by this check. - - No existing "any enterprise learner" DRF permission class was found in the repository. - This is the minimal consistent implementation using the existing edx_rbac infrastructure. - """ - - def has_permission(self, request: Request, view: object) -> bool: - try: - contexts = contexts_accessible_from_request(request, [BFF_LEARNER_ROLE]) - return bool(contexts) - except Exception: # pylint: disable=broad-except - return False - - class LearnerPathwaysViewSet(BasePromptViewSet): """ Endpoints for the Learner Pathways Xpert-backed feature. @@ -301,11 +262,6 @@ class LearnerPathwaysViewSet(BasePromptViewSet): """ model_type = XpertLearnerPathwaysSystemPrompt - - # DRF 3.17.1 ViewSetMixin.as_view() rejects any @action kwarg that is not already a - # class attribute (hasattr check). throttle_scope is not defined on APIView or ViewSet, - # so a class-level sentinel is required to allow per-action propagation. This sentinel - # does not configure a shared throttle; the actual scope values are set per action. throttle_scope: str | None = None @extend_schema( @@ -331,7 +287,7 @@ class LearnerPathwaysViewSet(BasePromptViewSet): url_path='learning-intent', url_name='learning-intent', authentication_classes=(JwtAuthentication,), - permission_classes=(permissions.IsAuthenticated, IsEnterpriseLearner), + permission_classes=(permissions.IsAuthenticated,), throttle_classes=(ScopedRateThrottle,), throttle_scope='learner_pathways_learning_intent', ) @@ -354,11 +310,8 @@ def learning_intent(self, request: Request) -> Response: prompt_model=self.model_type, prompt_type=PromptType.LEARNER_INTENT, ) - system_prompt = self._build_system_prompt(prompt) - messages = self._build_messages(validated_data) - conversation_id = self._get_conversation_id(request) xpert_response = self._send_xpert_message( diff --git a/enterprise_access/settings/base.py b/enterprise_access/settings/base.py index b02e390b..c466c3eb 100644 --- a/enterprise_access/settings/base.py +++ b/enterprise_access/settings/base.py @@ -183,9 +183,7 @@ def root(*path_fragments): 'DEFAULT_THROTTLE_RATES': { 'bff_unauthenticated': '100/hour', 'ssp_product': '120/hour', - # BLOCKER: rate values for learner-pathways scopes are not yet defined. - # None disables throttling until a rate is agreed upon. - 'learner_pathways_learning_intent': None, + 'learner_pathways_learning_intent': '100/hour', }, } From b0099916b061dd5fe835737801e11a81f1e041cd Mon Sep 17 00:00:00 2001 From: Hamzah Ullah Date: Wed, 24 Jun 2026 19:17:58 +0000 Subject: [PATCH 4/6] chore: PR feedback 1 --- .../apps/api/v1/tests/test_prompt_views.py | 413 ++---------------- enterprise_access/apps/api/v1/views/prompt.py | 196 ++------- .../apps/prompts/api/__init__.py | 183 ++++++++ .../apps/prompts/tests/test_api.py | 293 +++++++++++++ 4 files changed, 532 insertions(+), 553 deletions(-) create mode 100644 enterprise_access/apps/prompts/tests/test_api.py diff --git a/enterprise_access/apps/api/v1/tests/test_prompt_views.py b/enterprise_access/apps/api/v1/tests/test_prompt_views.py index 74650600..e49519f8 100644 --- a/enterprise_access/apps/api/v1/tests/test_prompt_views.py +++ b/enterprise_access/apps/api/v1/tests/test_prompt_views.py @@ -1,5 +1,9 @@ """ -Tests for BasePromptViewSet, PromptRequestException, and LearnerPathwaysViewSet. +Tests for PromptRequestException, BasePromptViewSet, and LearnerPathwaysViewSet. + +Domain logic tests are in enterprise_access.apps.prompts.tests.test_api. +This module focuses on HTTP-layer behavior: validation, permission checks, +throttling, error mapping, and response serialization. """ # pylint: disable=protected-access import json @@ -21,9 +25,9 @@ from enterprise_access.apps.api.v1.views.prompt import BasePromptViewSet, LearnerPathwaysViewSet, PromptRequestException from enterprise_access.apps.core.constants import SYSTEM_ENTERPRISE_LEARNER_ROLE from enterprise_access.apps.core.tests.factories import UserFactory +from enterprise_access.apps.prompts import api as prompts_api from enterprise_access.apps.prompts.api_client import ( XpertAPIConfigurationError, - XpertAPIError, XpertAPIRequestError, XpertAPIResponseError ) @@ -31,7 +35,7 @@ from enterprise_access.apps.prompts.tests.factories import XpertLearnerPathwaysSystemPromptFactory from test_utils import APITest -PATCH_XPERT_CLIENT = 'enterprise_access.apps.api.v1.views.prompt.XpertAPIClient' +PATCH_PROMPTS_API = 'enterprise_access.apps.api.v1.views.prompt.prompts_api' PATCH_GET_REQUEST_ID = 'enterprise_access.apps.api.v1.views.prompt.get_request_id' PATCH_UUID4 = 'enterprise_access.apps.api.v1.views.prompt.uuid_module.uuid4' @@ -92,7 +96,7 @@ def test_args_populated_with_message(self): self.assertEqual(exc.args[0], 'my error message') def test_exception_chaining_preserved(self): - original = XpertAPIError('original error') + original = prompts_api.PromptError('original error') try: raise PromptRequestException('wrapped') from original except PromptRequestException as exc: @@ -144,367 +148,8 @@ def is_valid(self, *, raise_exception=False): self.assertIs(captured['context']['request'], request) self.assertIs(captured['context']['view'], self.viewset) self.assertIn('format', captured['context']) - - -@ddt.ddt -class TestGetCurrentPrompt(TestCase): - """Tests for _get_current_prompt.""" - - def setUp(self): - self.viewset = _make_viewset() - - def test_returns_prompt_when_found(self): - prompt = mock.Mock() - prompt_model = mock.Mock() - prompt_model.get_current.return_value = prompt - - result = self.viewset._get_current_prompt( - prompt_model=prompt_model, - prompt_type='learner_intent', - ) - - self.assertIs(result, prompt) - - def test_exact_prompt_type_passed_to_get_current(self): - prompt_model = mock.Mock() - prompt_model.get_current.return_value = mock.Mock() - - self.viewset._get_current_prompt( - prompt_model=prompt_model, - prompt_type='learner_intent', - ) - - prompt_model.get_current.assert_called_once_with(prompt_type='learner_intent') - - def test_missing_prompt_raises_500(self): - prompt_model = mock.Mock() - prompt_model.get_current.return_value = None - - with self.assertRaises(PromptRequestException) as ctx: - self.viewset._get_current_prompt( - prompt_model=prompt_model, - prompt_type='learner_intent', - ) - - self.assertEqual(ctx.exception.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) - - def test_prompt_for_another_type_cannot_satisfy_lookup(self): - prompt_model = mock.Mock() - prompt_model.get_current.return_value = None - - with self.assertRaises(PromptRequestException): - self.viewset._get_current_prompt( - prompt_model=prompt_model, - prompt_type='recommendations_feedback', - ) - - prompt_model.get_current.assert_called_once_with( - prompt_type='recommendations_feedback', - ) - - -@ddt.ddt -class TestBuildSystemPrompt(TestCase): - """Tests for _build_system_prompt.""" - - def setUp(self): - self.viewset = _make_viewset() - - def _make_prompt(self, system_prompt, output_schema=None): - prompt = mock.Mock() - prompt.system_prompt = system_prompt - prompt.output_schema = output_schema - return prompt - - def test_strips_surrounding_whitespace(self): - prompt = self._make_prompt(' Be helpful. ') - result = self.viewset._build_system_prompt(prompt) - self.assertEqual(result, 'Be helpful.') - - def test_non_empty_schema_appended(self): - schema = {'type': 'object', 'properties': {'answer': {'type': 'string'}}} - prompt = self._make_prompt('Be helpful.', output_schema=schema) - - result = self.viewset._build_system_prompt(prompt) - - self.assertIn('\n\nEXPECTED OUTPUT SCHEMA:\n', result) - self.assertIn(json.dumps(schema, indent=2, sort_keys=True), result) - - @ddt.data(None, {}) - def test_empty_schema_not_appended(self, output_schema): - prompt = self._make_prompt('Be helpful.', output_schema=output_schema) - - result = self.viewset._build_system_prompt(prompt) - - self.assertEqual(result, 'Be helpful.') - self.assertNotIn('EXPECTED OUTPUT SCHEMA:', result) - - def test_prompt_instance_not_mutated(self): - schema = {'key': 'value'} - prompt = self._make_prompt(' Original. ', output_schema=schema) - - self.viewset._build_system_prompt(prompt) - - self.assertEqual(prompt.system_prompt, ' Original. ') - self.assertIs(prompt.output_schema, schema) - - -@ddt.ddt -class TestBuildMessages(TestCase): - """Tests for _build_messages.""" - - def setUp(self): - self.viewset = _make_viewset() - - def test_builds_single_user_message_with_string_content(self): - messages = self.viewset._build_messages({'name': 'Alice'}) - - self.assertEqual(messages, [ - {'role': 'user', 'content': '{"name":"Alice"}'}, - ]) - self.assertIsInstance(messages[0]['content'], str) - - def test_content_is_compact_json(self): - messages = self.viewset._build_messages({'name': 'Alice', 'count': 3}) - content = messages[0]['content'] - - self.assertNotIn(': ', content) - self.assertNotIn(', ', content) - self.assertEqual(json.loads(content), {'name': 'Alice', 'count': 3}) - - def test_nested_json_round_trips(self): - data = { - 'name': 'Alice', - 'items': [1, 2, 3], - 'metadata': {'active': True, 'notes': None}, - } - - messages = self.viewset._build_messages(data) - - self.assertEqual(json.loads(messages[0]['content']), data) - - -@ddt.ddt -class TestGetConversationId(TestCase): - """Tests for _get_conversation_id.""" - - def setUp(self): - self.viewset = _make_viewset() - - @mock.patch(PATCH_GET_REQUEST_ID, return_value='from-crum') - def test_repo_request_id_helper_takes_precedence(self, mock_get_request_id): - request = _make_request(headers={'X-Request-ID': 'from-header'}) - - result = self.viewset._get_conversation_id(request) - - self.assertEqual(result, 'enterprise-access:from-crum') - mock_get_request_id.assert_called_once_with() - - @mock.patch(PATCH_GET_REQUEST_ID, return_value=None) - def test_header_used_when_repo_helper_returns_none(self, mock_get_request_id): - request = _make_request(headers={'X-Request-ID': 'from-header'}) - - result = self.viewset._get_conversation_id(request) - - self.assertEqual(result, 'enterprise-access:from-header') - mock_get_request_id.assert_called_once_with() - - @mock.patch(PATCH_UUID4, return_value='generated-uuid') - @mock.patch(PATCH_GET_REQUEST_ID, return_value=None) - def test_uuid_generated_when_no_request_id(self, mock_get_request_id, mock_uuid4): - request = _make_request(headers={}) - - result = self.viewset._get_conversation_id(request) - - self.assertEqual(result, 'enterprise-access:generated-uuid') - mock_get_request_id.assert_called_once_with() - mock_uuid4.assert_called_once_with() - - @ddt.data( - ('from-crum', {'X-Request-ID': 'from-header'}), - (None, {'X-Request-ID': 'from-header'}), - (None, {}), - ) - @ddt.unpack - def test_result_always_has_prefix(self, helper_value, headers): - request = _make_request(headers=headers) - - with mock.patch(PATCH_GET_REQUEST_ID, return_value=helper_value): - result = self.viewset._get_conversation_id(request) - - self.assertTrue(result.startswith('enterprise-access:')) - - -@ddt.ddt -class TestSendXpertMessage(TestCase): - """Tests for _send_xpert_message.""" - - def setUp(self): - self.viewset = _make_viewset() - self.system_prompt = 'You are helpful.' - self.messages = [{'role': 'user', 'content': '{"q":1}'}] - self.conversation_id = 'enterprise-access:test-123' - - @mock.patch(PATCH_XPERT_CLIENT) - def test_client_called_once_with_correct_args(self, mock_client_class): - mock_response = {'role': 'assistant', 'content': '{"answer":"yes"}'} - mock_client_class.return_value.send_message.return_value = mock_response - - result = self.viewset._send_xpert_message( - system_prompt=self.system_prompt, - messages=self.messages, - conversation_id=self.conversation_id, - tags=('tag1', 'tag2'), - prompt_type='learner_intent', - ) - - self.assertEqual(result, mock_response) - mock_client_class.return_value.send_message.assert_called_once_with( - system_prompt=self.system_prompt, - messages=self.messages, - conversation_id=self.conversation_id, - tags=['tag1', 'tag2'], - ) - - @ddt.data(None, [], ()) - @mock.patch(PATCH_XPERT_CLIENT) - def test_empty_tags_passed_as_none(self, tags, mock_client_class): - mock_client_class.return_value.send_message.return_value = {} - - self.viewset._send_xpert_message( - system_prompt=self.system_prompt, - messages=self.messages, - conversation_id=self.conversation_id, - tags=tags, - ) - - self.assertIsNone( - mock_client_class.return_value.send_message.call_args.kwargs['tags'], - ) - - @mock.patch(PATCH_XPERT_CLIENT) - def test_no_second_call_made(self, mock_client_class): - mock_client_class.return_value.send_message.return_value = {} - - self.viewset._send_xpert_message( - system_prompt=self.system_prompt, - messages=self.messages, - conversation_id=self.conversation_id, - ) - - self.assertEqual(mock_client_class.return_value.send_message.call_count, 1) - - -@ddt.ddt -class TestSendXpertMessageErrors(TestCase): - """Tests for XpertAPIError mapping.""" - - def setUp(self): - self.viewset = _make_viewset() - - @ddt.data( - XpertAPIError, - XpertAPIConfigurationError, - XpertAPIRequestError, - XpertAPIResponseError, - ) - @mock.patch(PATCH_XPERT_CLIENT) - def test_xpert_errors_become_prompt_request_exception( - self, - error_class, - mock_client_class, - ): - original = error_class('original error text') - mock_client_class.return_value.send_message.side_effect = original - - with self.assertRaises(PromptRequestException) as ctx: - self.viewset._send_xpert_message( - system_prompt='prompt', - messages=[], - conversation_id='enterprise-access:x', - prompt_type='learner_intent', - ) - - self.assertEqual(ctx.exception.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) - self.assertIs(ctx.exception.__cause__, original) - self.assertIn('original error text', ctx.exception.args[0]) - self.assertEqual(mock_client_class.return_value.send_message.call_count, 1) - - -@ddt.ddt -class TestExtractXpertContent(TestCase): - """Tests for _extract_xpert_content.""" - - def setUp(self): - self.viewset = _make_viewset() - - def test_valid_response_returns_content_string(self): - response = {'role': 'assistant', 'content': '{"answer":"yes"}'} - - result = self.viewset._extract_xpert_content(response) - - self.assertEqual(result, '{"answer":"yes"}') - - @ddt.data( - {'role': 'assistant'}, - {'role': 'assistant', 'content': None}, - ) - def test_invalid_content_raises_500(self, response): - with self.assertRaises(PromptRequestException) as ctx: - self.viewset._extract_xpert_content(response) - - self.assertEqual(ctx.exception.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) - - -@ddt.ddt -class TestParseJsonContent(TestCase): - """Tests for _parse_json_content.""" - - def setUp(self): - self.viewset = _make_viewset() - - @ddt.data( - ('{"answer":42}', {'answer': 42}), - ('[1,2,3]', [1, 2, 3]), - ('"hello"', 'hello'), - ('99', 99), - ('false', False), - ('true', True), - ('null', None), - (' {"trimmed":true} ', {'trimmed': True}), - ) - @ddt.unpack - def test_valid_json_values_returned_unchanged(self, raw_content, expected): - result = self.viewset._parse_json_content(raw_content) - self.assertEqual(result, expected) - - @ddt.data( - 'not valid json', - '```json\n{"key":"value"}\n```', - '```\n{"key":"value"}\n```', - '{"unterminated": true', - '', - ) - def test_invalid_or_fenced_json_raises_500(self, raw_content): - with self.assertRaises(PromptRequestException) as ctx: - self.viewset._parse_json_content(raw_content) - - self.assertEqual(ctx.exception.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) - - def test_invalid_json_exception_is_chained(self): - with self.assertRaises(PromptRequestException) as ctx: - self.viewset._parse_json_content('not valid json') - - self.assertIsNotNone(ctx.exception.__cause__) - self.assertIsInstance(ctx.exception.__cause__, json.JSONDecodeError) - - def test_no_repair_or_fallback_on_bad_json(self): - with self.assertRaises(PromptRequestException): - self.viewset._parse_json_content('garbage') - - def test_parse_json_content_requires_string_contract(self): - with self.assertRaises(AttributeError): - self.viewset._parse_json_content({'not': 'a string'}) +# Domain logic tests are in enterprise_access.apps.prompts.tests.test_api. +# This test module focuses on HTTP-layer behavior in viewsets. # --------------------------------------------------------------------------- @@ -666,7 +311,7 @@ def test_unauthenticated_caller_is_rejected(self, url_name): (_LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), ) @ddt.unpack - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_enterprise_learner_is_allowed( self, url_name, payload, mock_client_class, ): @@ -683,7 +328,7 @@ def test_enterprise_learner_is_allowed( self.assertEqual(response.status_code, status.HTTP_200_OK) @ddt.data(_LEARNING_INTENT_URL_NAME) - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_xpert_not_called_when_auth_fails(self, url_name, mock_client_class): self.set_jwt_cookie([]) url = reverse(url_name) @@ -726,7 +371,7 @@ def test_learning_intent_scope_in_default_throttle_rates(self): @mock.patch.object(ScopedRateThrottle, 'THROTTLE_RATES', { 'learner_pathways_learning_intent': '2/minute', }) - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_learning_intent_throttled_after_rate_exceeded(self, mock_client_class): mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', 'content': '{"r":1}', @@ -765,7 +410,7 @@ def setUp(self): }]) self.url = reverse(_LEARNING_INTENT_URL_NAME) - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_http_200_with_valid_payload(self, mock_client_class): mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', @@ -774,7 +419,7 @@ def test_http_200_with_valid_payload(self, mock_client_class): resp = self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') self.assertEqual(resp.status_code, status.HTTP_200_OK) - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_correct_prompt_type_used(self, mock_client_class): with mock.patch.object( XpertLearnerPathwaysSystemPrompt, 'get_current' @@ -791,7 +436,7 @@ def test_correct_prompt_type_used(self, mock_client_class): prompt_type=PromptType.LEARNER_INTENT, ) - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_server_controlled_tags_passed(self, mock_client_class): mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', 'content': '{"r":1}', @@ -801,7 +446,7 @@ def test_server_controlled_tags_passed(self, mock_client_class): call_kwargs = mock_client_class.return_value.send_message.call_args.kwargs self.assertEqual(call_kwargs['tags'], ['tag-a', 'tag-b']) - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_xpert_called_exactly_once(self, mock_client_class): mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', 'content': '{"r":1}', @@ -809,7 +454,7 @@ def test_xpert_called_exactly_once(self, mock_client_class): self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') self.assertEqual(mock_client_class.return_value.send_message.call_count, 1) - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_full_parsed_json_returned(self, mock_client_class): payload_json = '{"skills_required":["python","ml"],"condensed_algolia_query":"data"}' mock_client_class.return_value.send_message.return_value = { @@ -820,7 +465,7 @@ def test_full_parsed_json_returned(self, mock_client_class): self.assertEqual(resp.status_code, status.HTTP_200_OK) self.assertEqual(resp.json(), json.loads(payload_json)) - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_validated_data_encoded_as_user_message(self, mock_client_class): mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', 'content': '{"r":1}', @@ -834,7 +479,7 @@ def test_validated_data_encoded_as_user_message(self, mock_client_class): parsed = json.loads(messages[0]['content']) self.assertEqual(parsed['selected_goals'], _VALID_LEARNING_INTENT_PAYLOAD['selected_goals']) - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_conversation_id_has_prefix(self, mock_client_class): mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', 'content': '{"r":1}', @@ -843,7 +488,7 @@ def test_conversation_id_has_prefix(self, mock_client_class): call_kwargs = mock_client_class.return_value.send_message.call_args.kwargs self.assertTrue(call_kwargs['conversation_id'].startswith('enterprise-access:')) - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_role_field_not_returned(self, mock_client_class): mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', @@ -880,7 +525,7 @@ def setUp(self): ('learning_intent', _LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), ) @ddt.unpack - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_extra_top_level_fields_preserved( self, _action, url_name, payload, mock_client_class, ): @@ -896,7 +541,7 @@ def test_extra_top_level_fields_preserved( ('learning_intent', _LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), ) @ddt.unpack - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_list_response_returned_as_list( self, _action, url_name, payload, mock_client_class, ): @@ -912,7 +557,7 @@ def test_list_response_returned_as_list( ('learning_intent', _LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), ) @ddt.unpack - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_nested_values_preserved( self, _action, url_name, payload, mock_client_class, ): @@ -963,7 +608,7 @@ def test_missing_prompt_returns_500(self, url_name, payload): XpertAPIRequestError, XpertAPIResponseError, ) - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_xpert_error_returns_500(self, error_class, mock_client_class): mock_client_class.return_value.send_message.side_effect = error_class('xpert error') resp = self.client.post( @@ -978,7 +623,7 @@ def test_xpert_error_returns_500(self, error_class, mock_client_class): ('none', {'role': 'assistant', 'content': None}), ) @ddt.unpack - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_bad_content_returns_500( self, _case, xpert_response, mock_client_class, ): @@ -995,7 +640,7 @@ def test_bad_content_returns_500( '```json\n{"key":"value"}\n```', '{"unterminated": true', ) - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_invalid_json_content_returns_500(self, bad_content, mock_client_class): mock_client_class.return_value.send_message.return_value = { 'role': 'assistant', 'content': bad_content, @@ -1007,7 +652,7 @@ def test_invalid_json_content_returns_500(self, bad_content, mock_client_class): ) self.assertEqual(resp.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_no_second_xpert_call_on_failure(self, mock_client_class): mock_client_class.return_value.send_message.side_effect = XpertAPIRequestError('fail') self.client.post( @@ -1017,7 +662,7 @@ def test_no_second_xpert_call_on_failure(self, mock_client_class): ) self.assertEqual(mock_client_class.return_value.send_message.call_count, 1) - @mock.patch(PATCH_XPERT_CLIENT) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_no_fallback_object_returned(self, mock_client_class): mock_client_class.return_value.send_message.side_effect = XpertAPIRequestError('fail') resp = self.client.post( diff --git a/enterprise_access/apps/api/v1/views/prompt.py b/enterprise_access/apps/api/v1/views/prompt.py index 37a7e3a1..0418068b 100644 --- a/enterprise_access/apps/api/v1/views/prompt.py +++ b/enterprise_access/apps/api/v1/views/prompt.py @@ -1,11 +1,8 @@ """ -Reusable base viewset for prompt-backed Xpert requests. +REST API viewsets for prompt-backed Xpert requests. """ -import json import logging import uuid as uuid_module -from collections.abc import Sequence -from typing import TypeAlias, cast from django.conf import settings from drf_spectacular.utils import extend_schema @@ -21,28 +18,15 @@ from enterprise_access.apps.api import serializers as api_serializers from enterprise_access.apps.api.serializers.learner_pathways import LEARNER_PATHWAYS_API_TAG from enterprise_access.apps.api_client.base_user import get_request_id -from enterprise_access.apps.prompts.api_client import XpertAPIClient, XpertAPIError -from enterprise_access.apps.prompts.models import BaseSystemPrompt, PromptType, XpertLearnerPathwaysSystemPrompt +from enterprise_access.apps.prompts import api as prompts_api +from enterprise_access.apps.prompts.models import PromptType, XpertLearnerPathwaysSystemPrompt logger = logging.getLogger(__name__) -JSONValue: TypeAlias = ( - str | - int | - float | - bool | - None | - list["JSONValue"] | - dict[str, "JSONValue"] -) -ValidatedData: TypeAlias = dict[str, JSONValue] -XpertMessage: TypeAlias = dict[str, str] -XpertResponse: TypeAlias = dict[str, object] -SystemPromptModel: TypeAlias = type[BaseSystemPrompt] +ValidatedData = dict[str, object] _CONVERSATION_ID_PREFIX = 'enterprise-access' _X_REQUEST_ID_HEADER = 'X-Request-ID' -_SCHEMA_SEPARATOR = '\n\nEXPECTED OUTPUT SCHEMA:\n' class PromptRequestException(APIException): @@ -69,8 +53,12 @@ class BasePromptViewSet(ViewSet): """ Reusable helper methods for prompt-backed Xpert requests. + This base class provides HTTP-layer utilities: request validation + and conversation ID generation. Domain logic is delegated to + enterprise_access.apps.prompts.api. + Concrete viewsets compose these helpers inside their individual actions. - This base class intentionally defines no actions, routes, authentication + This class intentionally defines no actions, routes, authentication classes, or permission policies. """ @@ -97,67 +85,6 @@ def _validate_request( return serializer.validated_data - def _get_current_prompt( - self, - *, - prompt_model: type[SystemPromptModel], - prompt_type: str, - ) -> BaseSystemPrompt: - """ - Resolve the current prompt for the exact supplied prompt type. - """ - prompt = prompt_model.get_current( - prompt_type=prompt_type, - ) - if prompt is None: - raise PromptRequestException( - f'No active prompt found for prompt_type={prompt_type!r}.' - ) - - return prompt - - def _build_system_prompt( - self, - prompt: BaseSystemPrompt, - ) -> str: - """ - Build the complete system prompt sent to Xpert. - - The configured prompt text is stripped of surrounding whitespace. - A non-empty output schema is appended as formatted JSON. - """ - system_prompt = prompt.system_prompt.strip() - output_schema = prompt.output_schema - - if output_schema: - system_prompt += _SCHEMA_SEPARATOR + json.dumps( - output_schema, - indent=2, - sort_keys=True, - ) - - return system_prompt - - def _build_messages( - self, - validated_data: ValidatedData, - ) -> list[XpertMessage]: - """ - Build the default Xpert message list. - - The complete validated request payload is encoded as compact JSON in - a single user message. - """ - return [ - { - 'role': 'user', - 'content': json.dumps( - validated_data, - separators=(',', ':'), - ), - }, - ] - def _get_conversation_id( self, request: Request, @@ -179,78 +106,6 @@ def _get_conversation_id( return f'{_CONVERSATION_ID_PREFIX}:{request_id}' - def _send_xpert_message( - self, - *, - system_prompt: str, - messages: list[XpertMessage], - conversation_id: str, - tags: Sequence[str] | None = None, - prompt_type: str | None = None, - ) -> XpertResponse: - """ - Send one prompt-backed request through the existing Xpert client. - - Xpert client failures are logged with tracking metadata and converted - to HTTP 500 prompt request failures. Prompt text, request payloads, and - raw model responses are not logged. - """ - normalized_tags = list(tags) if tags else None - - try: - response = XpertAPIClient().send_message( - system_prompt=system_prompt, - messages=messages, - conversation_id=conversation_id, - tags=normalized_tags, - ) - except XpertAPIError as exc: - logger.exception( - 'Xpert request failed for prompt_type=%r, conversation_id=%r.', - prompt_type, - conversation_id, - ) - raise PromptRequestException(str(exc)) from exc - - return response - - def _extract_xpert_content( - self, - xpert_response: XpertResponse, - ) -> str: - """ - Extract the raw content string from the normalized Xpert response. - """ - content = xpert_response.get('content') - - if content is None: - raise PromptRequestException( - 'Xpert response is missing the "content" field.' - ) - - return content - - def _parse_json_content( - self, - content: str, - ) -> JSONValue: - """ - Parse and return the complete JSON value produced by Xpert. - - The content must be directly parseable as JSON after surrounding - whitespace is removed. Markdown fencing, repair prompts, retries, - fallback parsing, field mapping, and response normalization are - intentionally unsupported. - """ - try: - parsed_content = json.loads(content.strip()) - except json.JSONDecodeError as exc: - raise PromptRequestException( - f'Failed to parse Xpert response content as JSON: {exc}' - ) from exc - - return cast(JSONValue, parsed_content) - class LearnerPathwaysViewSet(BasePromptViewSet): """ @@ -306,24 +161,27 @@ def learning_intent(self, request: Request) -> Response: api_serializers.LearningIntentRequestSerializer, ) - prompt = self._get_current_prompt( - prompt_model=self.model_type, - prompt_type=PromptType.LEARNER_INTENT, - ) - system_prompt = self._build_system_prompt(prompt) - messages = self._build_messages(validated_data) conversation_id = self._get_conversation_id(request) - xpert_response = self._send_xpert_message( - system_prompt=system_prompt, - messages=messages, - conversation_id=conversation_id, - tags=settings.XPERT_LEARNER_PATHWAYS_RAG_TAGS, - prompt_type=PromptType.LEARNER_INTENT, - ) + try: + prompt = prompts_api.get_current_prompt( + prompt_model=self.model_type, + prompt_type=PromptType.LEARNER_INTENT, + ) + system_prompt = prompts_api.build_system_prompt(prompt) + messages = prompts_api.build_messages(validated_data) - content = self._extract_xpert_content(xpert_response) + xpert_response = prompts_api.send_xpert_message( + system_prompt=system_prompt, + messages=messages, + conversation_id=conversation_id, + tags=settings.XPERT_LEARNER_PATHWAYS_RAG_TAGS, + prompt_type=PromptType.LEARNER_INTENT, + ) - response_data = self._parse_json_content(content) + content = prompts_api.extract_xpert_content(xpert_response) + response_data = prompts_api.parse_json_content(content) + except prompts_api.PromptError as exc: + raise PromptRequestException(str(exc)) from exc return Response(response_data, status=status.HTTP_200_OK) diff --git a/enterprise_access/apps/prompts/api/__init__.py b/enterprise_access/apps/prompts/api/__init__.py index e69de29b..0803686a 100644 --- a/enterprise_access/apps/prompts/api/__init__.py +++ b/enterprise_access/apps/prompts/api/__init__.py @@ -0,0 +1,183 @@ +""" +Domain-layer API for Xpert prompt-backed workflows. + +This module handles the orchestration of prompt-based Xpert AI interactions, +including prompt retrieval, message construction, and response parsing. +No HTTP or DRF machinery — suitable for testing in isolation. +""" +import json +import logging +from collections.abc import Sequence +from typing import TypeAlias, cast + +from enterprise_access.apps.prompts.api_client import XpertAPIClient, XpertAPIError +from enterprise_access.apps.prompts.models import BaseSystemPrompt + +logger = logging.getLogger(__name__) + +JSONValue: TypeAlias = ( + str | + int | + float | + bool | + None | + list["JSONValue"] | + dict[str, "JSONValue"] +) +ValidatedData: TypeAlias = dict[str, JSONValue] +XpertMessage: TypeAlias = dict[str, str] +XpertResponse: TypeAlias = dict[str, object] +SystemPromptModel: TypeAlias = type[BaseSystemPrompt] + +_SCHEMA_SEPARATOR = '\n\nEXPECTED OUTPUT SCHEMA:\n' + + +class PromptError(Exception): + """ + Raised when a prompt-backed Xpert request fails at the domain layer. + + These errors are translated by the view layer into HTTP 500 responses. + """ + + +def get_current_prompt( + prompt_model: SystemPromptModel, + prompt_type: str, +) -> BaseSystemPrompt: + """ + Resolve the current prompt for the supplied prompt type. + + Raises: + PromptError: If no active prompt exists for the given prompt_type. + """ + prompt = prompt_model.get_current(prompt_type=prompt_type) + if prompt is None: + raise PromptError( + f'No active prompt found for prompt_type={prompt_type!r}.' + ) + return prompt + + +def build_system_prompt(prompt: BaseSystemPrompt) -> str: + """ + Build the complete system prompt sent to Xpert. + + The configured prompt text is stripped of surrounding whitespace. + A non-empty output schema is appended as formatted JSON. + """ + system_prompt = prompt.system_prompt.strip() + output_schema = prompt.output_schema + + if output_schema: + system_prompt += _SCHEMA_SEPARATOR + json.dumps( + output_schema, + indent=2, + sort_keys=True, + ) + + return system_prompt + + +def build_messages(validated_data: ValidatedData) -> list[XpertMessage]: + """ + Build the default Xpert message list. + + The complete validated request payload is encoded as compact JSON in + a single user message. + """ + return [ + { + 'role': 'user', + 'content': json.dumps( + validated_data, + separators=(',', ':'), + ), + }, + ] + + +def send_xpert_message( + *, + system_prompt: str, + messages: list[XpertMessage], + conversation_id: str, + tags: Sequence[str] | None = None, + prompt_type: str | None = None, +) -> XpertResponse: + """ + Send one prompt-backed request through the Xpert client. + + Xpert client failures are logged with tracking metadata and converted + to PromptError. Prompt text, request payloads, and raw model responses + are not logged. + + Args: + system_prompt: System prompt text for Xpert. + messages: List of messages for Xpert (user role + content). + conversation_id: Unique conversation ID for tracing. + tags: Optional list of tags for Xpert (e.g. RAG tags). + prompt_type: Optional prompt type for logging/tracking. + + Raises: + PromptError: If the Xpert API call fails. + + Returns: + The raw Xpert response dict. + """ + normalized_tags = list(tags) if tags else None + + try: + response = XpertAPIClient().send_message( + system_prompt=system_prompt, + messages=messages, + conversation_id=conversation_id, + tags=normalized_tags, + ) + except XpertAPIError as exc: + logger.exception( + 'Xpert request failed for prompt_type=%r, conversation_id=%r.', + prompt_type, + conversation_id, + ) + raise PromptError(str(exc)) from exc + + return response + + +def extract_xpert_content(xpert_response: XpertResponse) -> str: + """ + Extract the raw content string from the Xpert response. + + Raises: + PromptError: If the 'content' field is missing. + """ + content = xpert_response.get('content') + + if content is None: + raise PromptError( + 'Xpert response is missing the "content" field.' + ) + + return content + + +def parse_json_content(content: str) -> JSONValue: + """ + Parse and return the complete JSON value produced by Xpert. + + The content must be directly parseable as JSON after surrounding + whitespace is removed. Markdown fencing, repair prompts, retries, + fallback parsing, field mapping, and response normalization are + intentionally unsupported. + + Raises: + PromptError: If JSON parsing fails. + """ + try: + parsed_content = json.loads(content.strip()) + except json.JSONDecodeError as exc: + raise PromptError( + f'Failed to parse Xpert response content as JSON: {exc}' + ) from exc + + return cast(JSONValue, parsed_content) diff --git a/enterprise_access/apps/prompts/tests/test_api.py b/enterprise_access/apps/prompts/tests/test_api.py new file mode 100644 index 00000000..ab7ec520 --- /dev/null +++ b/enterprise_access/apps/prompts/tests/test_api.py @@ -0,0 +1,293 @@ +""" +Tests for the prompts domain API module. +""" +import json +from unittest import mock + +import ddt +from django.test import TestCase + +from enterprise_access.apps.prompts import api as prompts_api +from enterprise_access.apps.prompts.api_client import ( + XpertAPIConfigurationError, + XpertAPIError, + XpertAPIRequestError, + XpertAPIResponseError +) + +PATCH_XPERT_CLIENT = 'enterprise_access.apps.prompts.api.XpertAPIClient' + + +@ddt.ddt +class TestGetCurrentPrompt(TestCase): + """Tests for get_current_prompt.""" + + def test_returns_prompt_when_found(self): + prompt = mock.Mock() + prompt_model = mock.Mock() + prompt_model.get_current.return_value = prompt + + result = prompts_api.get_current_prompt( + prompt_model=prompt_model, + prompt_type='learner_intent', + ) + + self.assertIs(result, prompt) + + def test_exact_prompt_type_passed_to_get_current(self): + prompt_model = mock.Mock() + prompt_model.get_current.return_value = mock.Mock() + + prompts_api.get_current_prompt( + prompt_model=prompt_model, + prompt_type='learner_intent', + ) + + prompt_model.get_current.assert_called_once_with(prompt_type='learner_intent') + + def test_missing_prompt_raises_error(self): + prompt_model = mock.Mock() + prompt_model.get_current.return_value = None + + with self.assertRaises(prompts_api.PromptError): + prompts_api.get_current_prompt( + prompt_model=prompt_model, + prompt_type='learner_intent', + ) + + def test_prompt_for_another_type_cannot_satisfy_lookup(self): + prompt_model = mock.Mock() + prompt_model.get_current.return_value = None + + with self.assertRaises(prompts_api.PromptError): + prompts_api.get_current_prompt( + prompt_model=prompt_model, + prompt_type='recommendations_feedback', + ) + + prompt_model.get_current.assert_called_once_with( + prompt_type='recommendations_feedback', + ) + + +@ddt.ddt +class TestBuildSystemPrompt(TestCase): + """Tests for build_system_prompt.""" + + def _make_prompt(self, system_prompt, output_schema=None): + prompt = mock.Mock() + prompt.system_prompt = system_prompt + prompt.output_schema = output_schema + return prompt + + def test_strips_surrounding_whitespace(self): + prompt = self._make_prompt(' Be helpful. ') + result = prompts_api.build_system_prompt(prompt) + self.assertEqual(result, 'Be helpful.') + + def test_non_empty_schema_appended(self): + schema = {'type': 'object', 'properties': {'answer': {'type': 'string'}}} + prompt = self._make_prompt('Be helpful.', output_schema=schema) + + result = prompts_api.build_system_prompt(prompt) + + self.assertIn('\n\nEXPECTED OUTPUT SCHEMA:\n', result) + self.assertIn(json.dumps(schema, indent=2, sort_keys=True), result) + + @ddt.data(None, {}) + def test_empty_schema_not_appended(self, output_schema): + prompt = self._make_prompt('Be helpful.', output_schema=output_schema) + + result = prompts_api.build_system_prompt(prompt) + + self.assertEqual(result, 'Be helpful.') + self.assertNotIn('EXPECTED OUTPUT SCHEMA:', result) + + def test_prompt_instance_not_mutated(self): + schema = {'key': 'value'} + prompt = self._make_prompt(' Original. ', output_schema=schema) + + prompts_api.build_system_prompt(prompt) + + self.assertEqual(prompt.system_prompt, ' Original. ') + self.assertIs(prompt.output_schema, schema) + + +@ddt.ddt +class TestBuildMessages(TestCase): + """Tests for build_messages.""" + + def test_builds_single_user_message_with_string_content(self): + messages = prompts_api.build_messages({'name': 'Alice'}) + + self.assertEqual(messages, [ + {'role': 'user', 'content': '{"name":"Alice"}'}, + ]) + self.assertIsInstance(messages[0]['content'], str) + + def test_content_is_compact_json(self): + messages = prompts_api.build_messages({'name': 'Alice', 'count': 3}) + content = messages[0]['content'] + + self.assertNotIn(': ', content) + self.assertNotIn(', ', content) + self.assertEqual(json.loads(content), {'name': 'Alice', 'count': 3}) + + def test_nested_json_round_trips(self): + data = { + 'name': 'Alice', + 'items': [1, 2, 3], + 'metadata': {'active': True, 'notes': None}, + } + + messages = prompts_api.build_messages(data) + + self.assertEqual(json.loads(messages[0]['content']), data) + + +@ddt.ddt +class TestSendXpertMessage(TestCase): + """Tests for send_xpert_message.""" + + def setUp(self): + self.system_prompt = 'You are helpful.' + self.messages = [{'role': 'user', 'content': '{"q":1}'}] + self.conversation_id = 'enterprise-access:test-123' + + @mock.patch(PATCH_XPERT_CLIENT) + def test_client_called_once_with_correct_args(self, mock_client_class): + mock_response = {'role': 'assistant', 'content': '{"answer":"yes"}'} + mock_client_class.return_value.send_message.return_value = mock_response + + result = prompts_api.send_xpert_message( + system_prompt=self.system_prompt, + messages=self.messages, + conversation_id=self.conversation_id, + tags=('tag1', 'tag2'), + prompt_type='learner_intent', + ) + + self.assertEqual(result, mock_response) + mock_client_class.return_value.send_message.assert_called_once_with( + system_prompt=self.system_prompt, + messages=self.messages, + conversation_id=self.conversation_id, + tags=['tag1', 'tag2'], + ) + + @ddt.data(None, [], ()) + @mock.patch(PATCH_XPERT_CLIENT) + def test_empty_tags_passed_as_none(self, tags, mock_client_class): + mock_client_class.return_value.send_message.return_value = {} + + prompts_api.send_xpert_message( + system_prompt=self.system_prompt, + messages=self.messages, + conversation_id=self.conversation_id, + tags=tags, + ) + + self.assertIsNone( + mock_client_class.return_value.send_message.call_args.kwargs['tags'], + ) + + @mock.patch(PATCH_XPERT_CLIENT) + def test_no_second_call_made(self, mock_client_class): + mock_client_class.return_value.send_message.return_value = {} + + prompts_api.send_xpert_message( + system_prompt=self.system_prompt, + messages=self.messages, + conversation_id=self.conversation_id, + ) + + self.assertEqual(mock_client_class.return_value.send_message.call_count, 1) + + +@ddt.ddt +class TestSendXpertMessageErrors(TestCase): + """Tests for XpertAPIError mapping to PromptError.""" + + @ddt.data( + XpertAPIError, + XpertAPIConfigurationError, + XpertAPIRequestError, + XpertAPIResponseError, + ) + @mock.patch(PATCH_XPERT_CLIENT) + def test_xpert_errors_become_prompt_error( + self, + error_class, + mock_client_class, + ): + original = error_class('original error text') + mock_client_class.return_value.send_message.side_effect = original + + with self.assertRaises(prompts_api.PromptError) as ctx: + prompts_api.send_xpert_message( + system_prompt='prompt', + messages=[], + conversation_id='enterprise-access:x', + prompt_type='learner_intent', + ) + + self.assertIs(ctx.exception.__cause__, original) + self.assertIn('original error text', str(ctx.exception)) + self.assertEqual(mock_client_class.return_value.send_message.call_count, 1) + + +@ddt.ddt +class TestExtractXpertContent(TestCase): + """Tests for extract_xpert_content.""" + + def test_valid_response_returns_content_string(self): + response = {'role': 'assistant', 'content': '{"answer":"yes"}'} + + result = prompts_api.extract_xpert_content(response) + + self.assertEqual(result, '{"answer":"yes"}') + + @ddt.data( + {'role': 'assistant'}, + {'role': 'assistant', 'content': None}, + ) + def test_invalid_content_raises_error(self, response): + with self.assertRaises(prompts_api.PromptError): + prompts_api.extract_xpert_content(response) + + +@ddt.ddt +class TestParseJsonContent(TestCase): + """Tests for parse_json_content.""" + + @ddt.data( + ('{"answer":42}', {'answer': 42}), + ('[1,2,3]', [1, 2, 3]), + ('"hello"', 'hello'), + ('99', 99), + ('false', False), + ('true', True), + ('null', None), + (' {"trimmed":true} ', {'trimmed': True}), + ) + @ddt.unpack + def test_valid_json_values_returned_unchanged(self, raw_content, expected): + result = prompts_api.parse_json_content(raw_content) + self.assertEqual(result, expected) + + @ddt.data( + 'not valid json', + '```json\n{"key":"value"}\n```', + '```\n{"key":"value"}\n```', + '{"unterminated": true', + '', + ) + def test_invalid_or_fenced_json_raises_error(self, raw_content): + with self.assertRaises(prompts_api.PromptError): + prompts_api.parse_json_content(raw_content) + + def test_invalid_json_exception_is_chained(self): + with self.assertRaises(prompts_api.PromptError) as ctx: + prompts_api.parse_json_content('not valid json') + + self.assertIsNotNone(ctx.exception.__cause__) From d00141e58bbb65c3f6acd401cd363736c766e953 Mon Sep 17 00:00:00 2001 From: Hamzah Ullah Date: Wed, 24 Jun 2026 19:33:14 +0000 Subject: [PATCH 5/6] chore: refactor test with modern testing approach --- .../apps/api/v1/tests/test_prompt_views.py | 129 +++++++++--------- .../apps/prompts/tests/test_api.py | 63 +++++---- 2 files changed, 96 insertions(+), 96 deletions(-) diff --git a/enterprise_access/apps/api/v1/tests/test_prompt_views.py b/enterprise_access/apps/api/v1/tests/test_prompt_views.py index e49519f8..4de7a761 100644 --- a/enterprise_access/apps/api/v1/tests/test_prompt_views.py +++ b/enterprise_access/apps/api/v1/tests/test_prompt_views.py @@ -11,6 +11,7 @@ from unittest import mock import ddt +import pytest from django.conf import settings as django_settings from django.core.cache import cache as django_cache from django.test import TestCase, override_settings @@ -85,22 +86,22 @@ class TestPromptRequestException(TestCase): def test_status_code_is_500(self): exc = PromptRequestException('something went wrong') - self.assertEqual(exc.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + assert exc.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR def test_detail_message_is_preserved(self): exc = PromptRequestException('something went wrong') - self.assertIn('something went wrong', str(exc.detail)) + assert 'something went wrong' in str(exc.detail) def test_args_populated_with_message(self): exc = PromptRequestException('my error message') - self.assertEqual(exc.args[0], 'my error message') + assert exc.args[0] == 'my error message' def test_exception_chaining_preserved(self): original = prompts_api.PromptError('original error') try: raise PromptRequestException('wrapped') from original except PromptRequestException as exc: - self.assertIs(exc.__cause__, original) + assert exc.__cause__ is original @ddt.ddt @@ -113,12 +114,12 @@ def setUp(self): def test_valid_data_returns_validated_data(self): request = _make_request({'name': 'Alice', 'count': 3}) result = self.viewset._validate_request(request, _SampleSerializer) - self.assertEqual(result, {'name': 'Alice', 'count': 3}) + assert result == {'name': 'Alice', 'count': 3} def test_valid_data_with_defaults(self): request = _make_request({'name': 'Bob'}) result = self.viewset._validate_request(request, _SampleSerializer) - self.assertEqual(result, {'name': 'Bob', 'count': 0}) + assert result == {'name': 'Bob', 'count': 0} @ddt.ddt class _Unused: @@ -130,9 +131,9 @@ class _Unused: ) def test_invalid_data_raises_validation_error(self, payload): request = _make_request(payload) - with self.assertRaises(ValidationError) as ctx: + with pytest.raises(ValidationError) as exc_info: self.viewset._validate_request(request, _SampleSerializer) - self.assertEqual(ctx.exception.status_code, status.HTTP_400_BAD_REQUEST) + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST def test_serializer_context_includes_request_format_and_view(self): request = _make_request({'name': 'Test'}) @@ -145,9 +146,9 @@ def is_valid(self, *, raise_exception=False): self.viewset._validate_request(request, ContextCapturingSerializer) - self.assertIs(captured['context']['request'], request) - self.assertIs(captured['context']['view'], self.viewset) - self.assertIn('format', captured['context']) + assert captured['context']['request'] is request + assert captured['context']['view'] is self.viewset + assert 'format' in captured['context'] # Domain logic tests are in enterprise_access.apps.prompts.tests.test_api. # This test module focuses on HTTP-layer behavior in viewsets. @@ -165,31 +166,31 @@ def _valid(self): def test_valid_payload_succeeds(self): s = api_serializers.LearningIntentRequestSerializer(data=self._valid()) - self.assertTrue(s.is_valid(), s.errors) + assert s.is_valid() @ddt.data('selected_goals', 'free_text', 'known_context') def test_missing_field_fails(self, field): data = self._valid() del data[field] s = api_serializers.LearningIntentRequestSerializer(data=data) - self.assertFalse(s.is_valid()) - self.assertIn(field, s.errors) + assert not s.is_valid() + assert field in s.errors @ddt.data('selected_goals', 'free_text', 'known_context') def test_blank_field_fails(self, field): data = self._valid() data[field] = '' s = api_serializers.LearningIntentRequestSerializer(data=data) - self.assertFalse(s.is_valid()) - self.assertIn(field, s.errors) + assert not s.is_valid() + assert field in s.errors @ddt.data('selected_goals', 'free_text', 'known_context') def test_whitespace_only_field_fails(self, field): data = self._valid() data[field] = ' ' s = api_serializers.LearningIntentRequestSerializer(data=data) - self.assertFalse(s.is_valid()) - self.assertIn(field, s.errors) + assert not s.is_valid() + assert field in s.errors @ddt.data( ('selected_goals', 123), @@ -204,10 +205,10 @@ def test_non_string_value_coerced_or_fails(self, field, value): # DRF CharField coerces non-strings; result must still be non-blank. # 123 → '123' (valid), [] → '' (invalid blank), {} → repr (valid) if s.is_valid(): - self.assertIsInstance(s.validated_data[field], str) - self.assertGreater(len(s.validated_data[field]), 0) + assert isinstance(s.validated_data[field], str) + assert len(s.validated_data[field]) > 0 else: - self.assertIn(field, s.errors) + assert field in s.errors # --------------------------------------------------------------------------- @@ -219,15 +220,15 @@ class TestLearnerPathwaysRouting(TestCase): def test_learning_intent_url_reverses(self): url = reverse(_LEARNING_INTENT_URL_NAME) - self.assertIn('learner-pathways', url) - self.assertIn('learning-intent', url) + assert 'learner-pathways' in url + assert 'learning-intent' in url def test_learning_intent_post_accepted(self): client = APIClient() url = reverse(_LEARNING_INTENT_URL_NAME) response = client.post(url, data={}, format='json') # Unauthenticated — 401 or 403, but NOT 405. - self.assertNotEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + assert response.status_code != status.HTTP_405_METHOD_NOT_ALLOWED def test_learning_intent_get_rejected(self): url = reverse(_LEARNING_INTENT_URL_NAME) @@ -235,7 +236,7 @@ def test_learning_intent_get_rejected(self): user = UserFactory(is_active=True) client.force_authenticate(user=user) response = client.get(url) - self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED # --------------------------------------------------------------------------- @@ -251,30 +252,30 @@ def _get_action(self, name): def test_learning_intent_authentication_classes(self): ac = self._get_action('learning_intent').kwargs.get('authentication_classes', ()) - self.assertIn(JwtAuthentication, ac) + assert JwtAuthentication in ac def test_learning_intent_is_authenticated_permission(self): pc = self._get_action('learning_intent').kwargs.get('permission_classes', ()) - self.assertIn(permissions.IsAuthenticated, pc) + assert permissions.IsAuthenticated in pc def test_learning_intent_throttle_class(self): tc = self._get_action('learning_intent').kwargs.get('throttle_classes', ()) - self.assertIn(ScopedRateThrottle, tc) + assert ScopedRateThrottle in tc def test_learning_intent_throttle_scope(self): scope = self._get_action('learning_intent').kwargs.get('throttle_scope') - self.assertEqual(scope, 'learner_pathways_learning_intent') + assert scope == 'learner_pathways_learning_intent' def test_no_throttle_on_base_prompt_viewset(self): # throttle_classes must not be explicitly defined on BasePromptViewSet itself - self.assertNotIn('throttle_classes', BasePromptViewSet.__dict__) - self.assertNotIn('throttle_scope', BasePromptViewSet.__dict__) + assert 'throttle_classes' not in BasePromptViewSet.__dict__ + assert 'throttle_scope' not in BasePromptViewSet.__dict__ def test_no_class_level_throttle_classes_on_learner_pathways_viewset(self): - self.assertNotIn('throttle_classes', LearnerPathwaysViewSet.__dict__) + assert 'throttle_classes' not in LearnerPathwaysViewSet.__dict__ def test_throttle_scope_sentinel_is_none(self): - self.assertIsNone(LearnerPathwaysViewSet.throttle_scope) + assert LearnerPathwaysViewSet.throttle_scope is None # --------------------------------------------------------------------------- @@ -302,10 +303,10 @@ def test_unauthenticated_caller_is_rejected(self, url_name): self.client.cookies.clear() url = reverse(url_name) response = self.client.post(url, data={}, format='json') - self.assertIn(response.status_code, [ + assert response.status_code in [ status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN, - ]) + ] @ddt.data( (_LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), @@ -325,7 +326,7 @@ def test_enterprise_learner_is_allowed( }]) url = reverse(url_name) response = self.client.post(url, data=payload, format='json') - self.assertEqual(response.status_code, status.HTTP_200_OK) + assert response.status_code == status.HTTP_200_OK @ddt.data(_LEARNING_INTENT_URL_NAME) @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') @@ -366,7 +367,7 @@ def setUp(self): def test_learning_intent_scope_in_default_throttle_rates(self): rates = django_settings.REST_FRAMEWORK.get('DEFAULT_THROTTLE_RATES', {}) - self.assertIn('learner_pathways_learning_intent', rates) + assert 'learner_pathways_learning_intent' in rates @mock.patch.object(ScopedRateThrottle, 'THROTTLE_RATES', { 'learner_pathways_learning_intent': '2/minute', @@ -379,9 +380,9 @@ def test_learning_intent_throttled_after_rate_exceeded(self, mock_client_class): url = reverse(_LEARNING_INTENT_URL_NAME) for _ in range(2): resp = self.client.post(url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') - self.assertEqual(resp.status_code, status.HTTP_200_OK) + assert resp.status_code == status.HTTP_200_OK resp = self.client.post(url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') - self.assertEqual(resp.status_code, status.HTTP_429_TOO_MANY_REQUESTS) + assert resp.status_code == status.HTTP_429_TOO_MANY_REQUESTS # --------------------------------------------------------------------------- @@ -417,7 +418,7 @@ def test_http_200_with_valid_payload(self, mock_client_class): 'content': '{"skills_required":["python"]}', } resp = self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') - self.assertEqual(resp.status_code, status.HTTP_200_OK) + assert resp.status_code == status.HTTP_200_OK @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_correct_prompt_type_used(self, mock_client_class): @@ -444,7 +445,7 @@ def test_server_controlled_tags_passed(self, mock_client_class): with override_settings(XPERT_LEARNER_PATHWAYS_RAG_TAGS=['tag-a', 'tag-b']): self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') call_kwargs = mock_client_class.return_value.send_message.call_args.kwargs - self.assertEqual(call_kwargs['tags'], ['tag-a', 'tag-b']) + assert call_kwargs['tags'] == ['tag-a', 'tag-b'] @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_xpert_called_exactly_once(self, mock_client_class): @@ -452,7 +453,7 @@ def test_xpert_called_exactly_once(self, mock_client_class): 'role': 'assistant', 'content': '{"r":1}', } self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') - self.assertEqual(mock_client_class.return_value.send_message.call_count, 1) + assert mock_client_class.return_value.send_message.call_count == 1 @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_full_parsed_json_returned(self, mock_client_class): @@ -462,8 +463,8 @@ def test_full_parsed_json_returned(self, mock_client_class): 'content': payload_json, } resp = self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') - self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertEqual(resp.json(), json.loads(payload_json)) + assert resp.status_code == status.HTTP_200_OK + assert resp.json() == json.loads(payload_json) @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_validated_data_encoded_as_user_message(self, mock_client_class): @@ -473,11 +474,11 @@ def test_validated_data_encoded_as_user_message(self, mock_client_class): self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') call_kwargs = mock_client_class.return_value.send_message.call_args.kwargs messages = call_kwargs['messages'] - self.assertEqual(len(messages), 1) - self.assertEqual(messages[0]['role'], 'user') - self.assertIsInstance(messages[0]['content'], str) + assert len(messages) == 1 + assert messages[0]['role'] == 'user' + assert isinstance(messages[0]['content'], str) parsed = json.loads(messages[0]['content']) - self.assertEqual(parsed['selected_goals'], _VALID_LEARNING_INTENT_PAYLOAD['selected_goals']) + assert parsed['selected_goals'] == _VALID_LEARNING_INTENT_PAYLOAD['selected_goals'] @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_conversation_id_has_prefix(self, mock_client_class): @@ -486,7 +487,7 @@ def test_conversation_id_has_prefix(self, mock_client_class): } self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') call_kwargs = mock_client_class.return_value.send_message.call_args.kwargs - self.assertTrue(call_kwargs['conversation_id'].startswith('enterprise-access:')) + assert call_kwargs['conversation_id'].startswith('enterprise-access:') @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_role_field_not_returned(self, mock_client_class): @@ -495,7 +496,7 @@ def test_role_field_not_returned(self, mock_client_class): 'content': '{"answer":"yes"}', } resp = self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') - self.assertNotIn('role', resp.json()) + assert 'role' not in resp.json() # --------------------------------------------------------------------------- @@ -534,8 +535,8 @@ def test_extra_top_level_fields_preserved( 'content': '{"result":"ok","extra_field":"preserved"}', } resp = self.client.post(reverse(url_name), data=payload, format='json') - self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertIn('extra_field', resp.json()) + assert resp.status_code == status.HTTP_200_OK + assert 'extra_field' in resp.json() @ddt.data( ('learning_intent', _LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), @@ -550,8 +551,8 @@ def test_list_response_returned_as_list( 'content': '[1,2,3]', } resp = self.client.post(reverse(url_name), data=payload, format='json') - self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertIsInstance(resp.json(), list) + assert resp.status_code == status.HTTP_200_OK + assert isinstance(resp.json(), list) @ddt.data( ('learning_intent', _LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), @@ -567,8 +568,8 @@ def test_nested_values_preserved( 'content': json.dumps(nested), } resp = self.client.post(reverse(url_name), data=payload, format='json') - self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertEqual(resp.json(), nested) + assert resp.status_code == status.HTTP_200_OK + assert resp.json() == nested # --------------------------------------------------------------------------- @@ -601,7 +602,7 @@ def setUp(self): def test_missing_prompt_returns_500(self, url_name, payload): with mock.patch.object(XpertLearnerPathwaysSystemPrompt, 'get_current', return_value=None): resp = self.client.post(reverse(url_name), data=payload, format='json') - self.assertEqual(resp.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + assert resp.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @ddt.data( XpertAPIConfigurationError, @@ -616,7 +617,7 @@ def test_xpert_error_returns_500(self, error_class, mock_client_class): data=_VALID_LEARNING_INTENT_PAYLOAD, format='json', ) - self.assertEqual(resp.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + assert resp.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @ddt.data( ('missing', {'role': 'assistant'}), @@ -633,7 +634,7 @@ def test_bad_content_returns_500( data=_VALID_LEARNING_INTENT_PAYLOAD, format='json', ) - self.assertEqual(resp.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + assert resp.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @ddt.data( 'not valid json', @@ -650,7 +651,7 @@ def test_invalid_json_content_returns_500(self, bad_content, mock_client_class): data=_VALID_LEARNING_INTENT_PAYLOAD, format='json', ) - self.assertEqual(resp.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + assert resp.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_no_second_xpert_call_on_failure(self, mock_client_class): @@ -660,7 +661,7 @@ def test_no_second_xpert_call_on_failure(self, mock_client_class): data=_VALID_LEARNING_INTENT_PAYLOAD, format='json', ) - self.assertEqual(mock_client_class.return_value.send_message.call_count, 1) + assert mock_client_class.return_value.send_message.call_count == 1 @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') def test_no_fallback_object_returned(self, mock_client_class): @@ -670,8 +671,8 @@ def test_no_fallback_object_returned(self, mock_client_class): data=_VALID_LEARNING_INTENT_PAYLOAD, format='json', ) - self.assertEqual(resp.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + assert resp.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR if resp.get('Content-Type', '').startswith('application/json'): body = resp.json() - self.assertNotIn('skills_required', body) - self.assertNotIn('reasons', body) + assert 'skills_required' not in body + assert 'reasons' not in body diff --git a/enterprise_access/apps/prompts/tests/test_api.py b/enterprise_access/apps/prompts/tests/test_api.py index ab7ec520..d1e41857 100644 --- a/enterprise_access/apps/prompts/tests/test_api.py +++ b/enterprise_access/apps/prompts/tests/test_api.py @@ -5,6 +5,7 @@ from unittest import mock import ddt +import pytest from django.test import TestCase from enterprise_access.apps.prompts import api as prompts_api @@ -32,7 +33,7 @@ def test_returns_prompt_when_found(self): prompt_type='learner_intent', ) - self.assertIs(result, prompt) + assert result is prompt def test_exact_prompt_type_passed_to_get_current(self): prompt_model = mock.Mock() @@ -49,7 +50,7 @@ def test_missing_prompt_raises_error(self): prompt_model = mock.Mock() prompt_model.get_current.return_value = None - with self.assertRaises(prompts_api.PromptError): + with pytest.raises(prompts_api.PromptError): prompts_api.get_current_prompt( prompt_model=prompt_model, prompt_type='learner_intent', @@ -59,7 +60,7 @@ def test_prompt_for_another_type_cannot_satisfy_lookup(self): prompt_model = mock.Mock() prompt_model.get_current.return_value = None - with self.assertRaises(prompts_api.PromptError): + with pytest.raises(prompts_api.PromptError): prompts_api.get_current_prompt( prompt_model=prompt_model, prompt_type='recommendations_feedback', @@ -83,7 +84,7 @@ def _make_prompt(self, system_prompt, output_schema=None): def test_strips_surrounding_whitespace(self): prompt = self._make_prompt(' Be helpful. ') result = prompts_api.build_system_prompt(prompt) - self.assertEqual(result, 'Be helpful.') + assert result == 'Be helpful.' def test_non_empty_schema_appended(self): schema = {'type': 'object', 'properties': {'answer': {'type': 'string'}}} @@ -91,8 +92,8 @@ def test_non_empty_schema_appended(self): result = prompts_api.build_system_prompt(prompt) - self.assertIn('\n\nEXPECTED OUTPUT SCHEMA:\n', result) - self.assertIn(json.dumps(schema, indent=2, sort_keys=True), result) + assert '\n\nEXPECTED OUTPUT SCHEMA:\n' in result + assert json.dumps(schema, indent=2, sort_keys=True) in result @ddt.data(None, {}) def test_empty_schema_not_appended(self, output_schema): @@ -100,8 +101,8 @@ def test_empty_schema_not_appended(self, output_schema): result = prompts_api.build_system_prompt(prompt) - self.assertEqual(result, 'Be helpful.') - self.assertNotIn('EXPECTED OUTPUT SCHEMA:', result) + assert result == 'Be helpful.' + assert 'EXPECTED OUTPUT SCHEMA:' not in result def test_prompt_instance_not_mutated(self): schema = {'key': 'value'} @@ -109,8 +110,8 @@ def test_prompt_instance_not_mutated(self): prompts_api.build_system_prompt(prompt) - self.assertEqual(prompt.system_prompt, ' Original. ') - self.assertIs(prompt.output_schema, schema) + assert prompt.system_prompt == ' Original. ' + assert prompt.output_schema is schema @ddt.ddt @@ -120,18 +121,18 @@ class TestBuildMessages(TestCase): def test_builds_single_user_message_with_string_content(self): messages = prompts_api.build_messages({'name': 'Alice'}) - self.assertEqual(messages, [ + assert messages == [ {'role': 'user', 'content': '{"name":"Alice"}'}, - ]) - self.assertIsInstance(messages[0]['content'], str) + ] + assert isinstance(messages[0]['content'], str) def test_content_is_compact_json(self): messages = prompts_api.build_messages({'name': 'Alice', 'count': 3}) content = messages[0]['content'] - self.assertNotIn(': ', content) - self.assertNotIn(', ', content) - self.assertEqual(json.loads(content), {'name': 'Alice', 'count': 3}) + assert ': ' not in content + assert ', ' not in content + assert json.loads(content) == {'name': 'Alice', 'count': 3} def test_nested_json_round_trips(self): data = { @@ -142,7 +143,7 @@ def test_nested_json_round_trips(self): messages = prompts_api.build_messages(data) - self.assertEqual(json.loads(messages[0]['content']), data) + assert json.loads(messages[0]['content']) == data @ddt.ddt @@ -167,7 +168,7 @@ def test_client_called_once_with_correct_args(self, mock_client_class): prompt_type='learner_intent', ) - self.assertEqual(result, mock_response) + assert result == mock_response mock_client_class.return_value.send_message.assert_called_once_with( system_prompt=self.system_prompt, messages=self.messages, @@ -187,9 +188,7 @@ def test_empty_tags_passed_as_none(self, tags, mock_client_class): tags=tags, ) - self.assertIsNone( - mock_client_class.return_value.send_message.call_args.kwargs['tags'], - ) + assert mock_client_class.return_value.send_message.call_args.kwargs['tags'] is None @mock.patch(PATCH_XPERT_CLIENT) def test_no_second_call_made(self, mock_client_class): @@ -201,7 +200,7 @@ def test_no_second_call_made(self, mock_client_class): conversation_id=self.conversation_id, ) - self.assertEqual(mock_client_class.return_value.send_message.call_count, 1) + assert mock_client_class.return_value.send_message.call_count == 1 @ddt.ddt @@ -223,7 +222,7 @@ def test_xpert_errors_become_prompt_error( original = error_class('original error text') mock_client_class.return_value.send_message.side_effect = original - with self.assertRaises(prompts_api.PromptError) as ctx: + with pytest.raises(prompts_api.PromptError) as exc_info: prompts_api.send_xpert_message( system_prompt='prompt', messages=[], @@ -231,9 +230,9 @@ def test_xpert_errors_become_prompt_error( prompt_type='learner_intent', ) - self.assertIs(ctx.exception.__cause__, original) - self.assertIn('original error text', str(ctx.exception)) - self.assertEqual(mock_client_class.return_value.send_message.call_count, 1) + assert exc_info.value.__cause__ is original + assert 'original error text' in str(exc_info.value) + assert mock_client_class.return_value.send_message.call_count == 1 @ddt.ddt @@ -245,14 +244,14 @@ def test_valid_response_returns_content_string(self): result = prompts_api.extract_xpert_content(response) - self.assertEqual(result, '{"answer":"yes"}') + assert result == '{"answer":"yes"}' @ddt.data( {'role': 'assistant'}, {'role': 'assistant', 'content': None}, ) def test_invalid_content_raises_error(self, response): - with self.assertRaises(prompts_api.PromptError): + with pytest.raises(prompts_api.PromptError): prompts_api.extract_xpert_content(response) @@ -273,7 +272,7 @@ class TestParseJsonContent(TestCase): @ddt.unpack def test_valid_json_values_returned_unchanged(self, raw_content, expected): result = prompts_api.parse_json_content(raw_content) - self.assertEqual(result, expected) + assert result == expected @ddt.data( 'not valid json', @@ -283,11 +282,11 @@ def test_valid_json_values_returned_unchanged(self, raw_content, expected): '', ) def test_invalid_or_fenced_json_raises_error(self, raw_content): - with self.assertRaises(prompts_api.PromptError): + with pytest.raises(prompts_api.PromptError): prompts_api.parse_json_content(raw_content) def test_invalid_json_exception_is_chained(self): - with self.assertRaises(prompts_api.PromptError) as ctx: + with pytest.raises(prompts_api.PromptError) as exc_info: prompts_api.parse_json_content('not valid json') - self.assertIsNotNone(ctx.exception.__cause__) + assert exc_info.value.__cause__ is not None From 998e89ea145aaa14cf5a6e3e412250c04022ce9a Mon Sep 17 00:00:00 2001 From: Hamzah Ullah Date: Wed, 24 Jun 2026 20:05:53 +0000 Subject: [PATCH 6/6] chore: PR feedback 2 --- .../apps/api/v1/tests/test_prompt_views.py | 81 +------------------ enterprise_access/apps/api/v1/views/prompt.py | 39 ++------- .../apps/prompts/api/__init__.py | 26 +++--- 3 files changed, 16 insertions(+), 130 deletions(-) diff --git a/enterprise_access/apps/api/v1/tests/test_prompt_views.py b/enterprise_access/apps/api/v1/tests/test_prompt_views.py index 4de7a761..2c813354 100644 --- a/enterprise_access/apps/api/v1/tests/test_prompt_views.py +++ b/enterprise_access/apps/api/v1/tests/test_prompt_views.py @@ -5,19 +5,16 @@ This module focuses on HTTP-layer behavior: validation, permission checks, throttling, error mapping, and response serialization. """ -# pylint: disable=protected-access import json import uuid from unittest import mock import ddt -import pytest from django.conf import settings as django_settings from django.core.cache import cache as django_cache from django.test import TestCase, override_settings from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication -from rest_framework import permissions, serializers, status -from rest_framework.exceptions import ValidationError +from rest_framework import permissions, status from rest_framework.reverse import reverse from rest_framework.test import APIClient from rest_framework.throttling import ScopedRateThrottle @@ -49,37 +46,6 @@ } -def _make_viewset(): - """Return a bare BasePromptViewSet instance for helper tests.""" - viewset = BasePromptViewSet() - viewset.request = mock.Mock() - viewset.kwargs = {} - viewset.format_kwarg = None - return viewset - - -def _make_request(data=None, headers=None): - """Return a mock DRF request.""" - request = mock.Mock() - request.data = data or {} - request.headers = headers or {} - return request - - -class _SampleSerializer(serializers.Serializer): - """Minimal serializer used in request-validation tests.""" - name = serializers.CharField() - count = serializers.IntegerField(required=False, default=0) - - def create(self, validated_data): - """Create is unused for validation-only tests.""" - return validated_data - - def update(self, instance, validated_data): - """Update is unused for validation-only tests.""" - return validated_data - - @ddt.ddt class TestPromptRequestException(TestCase): """Tests for PromptRequestException.""" @@ -104,51 +70,6 @@ def test_exception_chaining_preserved(self): assert exc.__cause__ is original -@ddt.ddt -class TestValidateRequest(TestCase): - """Tests for _validate_request.""" - - def setUp(self): - self.viewset = _make_viewset() - - def test_valid_data_returns_validated_data(self): - request = _make_request({'name': 'Alice', 'count': 3}) - result = self.viewset._validate_request(request, _SampleSerializer) - assert result == {'name': 'Alice', 'count': 3} - - def test_valid_data_with_defaults(self): - request = _make_request({'name': 'Bob'}) - result = self.viewset._validate_request(request, _SampleSerializer) - assert result == {'name': 'Bob', 'count': 0} - - @ddt.ddt - class _Unused: - """Avoid nested TestCase discovery issues.""" - - @ddt.data( - {}, - {'name': 'Alice', 'count': 'not-an-int'}, - ) - def test_invalid_data_raises_validation_error(self, payload): - request = _make_request(payload) - with pytest.raises(ValidationError) as exc_info: - self.viewset._validate_request(request, _SampleSerializer) - assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST - - def test_serializer_context_includes_request_format_and_view(self): - request = _make_request({'name': 'Test'}) - captured = {} - - class ContextCapturingSerializer(_SampleSerializer): - def is_valid(self, *, raise_exception=False): - captured['context'] = self.context - return super().is_valid(raise_exception=raise_exception) - - self.viewset._validate_request(request, ContextCapturingSerializer) - - assert captured['context']['request'] is request - assert captured['context']['view'] is self.viewset - assert 'format' in captured['context'] # Domain logic tests are in enterprise_access.apps.prompts.tests.test_api. # This test module focuses on HTTP-layer behavior in viewsets. diff --git a/enterprise_access/apps/api/v1/views/prompt.py b/enterprise_access/apps/api/v1/views/prompt.py index 0418068b..9a22b9c1 100644 --- a/enterprise_access/apps/api/v1/views/prompt.py +++ b/enterprise_access/apps/api/v1/views/prompt.py @@ -7,7 +7,7 @@ from django.conf import settings from drf_spectacular.utils import extend_schema from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication -from rest_framework import permissions, serializers, status +from rest_framework import permissions, status from rest_framework.decorators import action from rest_framework.exceptions import APIException from rest_framework.request import Request @@ -23,8 +23,6 @@ logger = logging.getLogger(__name__) -ValidatedData = dict[str, object] - _CONVERSATION_ID_PREFIX = 'enterprise-access' _X_REQUEST_ID_HEADER = 'X-Request-ID' @@ -53,38 +51,14 @@ class BasePromptViewSet(ViewSet): """ Reusable helper methods for prompt-backed Xpert requests. - This base class provides HTTP-layer utilities: request validation - and conversation ID generation. Domain logic is delegated to - enterprise_access.apps.prompts.api. + This base class provides HTTP-layer utilities for conversation ID generation. + Domain logic is delegated to enterprise_access.apps.prompts.api. Concrete viewsets compose these helpers inside their individual actions. This class intentionally defines no actions, routes, authentication classes, or permission policies. """ - def _validate_request( - self, - request: Request, - serializer_class: type[serializers.Serializer], - ) -> ValidatedData: - """ - Validate request data and return the serializer's validated payload. - - Invalid request data follows standard DRF validation behavior and - produces an HTTP 400 response. - """ - serializer = serializer_class( - data=request.data, - context={ - 'request': request, - 'format': None, - 'view': self, - }, - ) - serializer.is_valid(raise_exception=True) - - return serializer.validated_data - def _get_conversation_id( self, request: Request, @@ -156,10 +130,9 @@ def learning_intent(self, request: Request) -> Response: Returns HTTP 500 when the prompt is missing, the Xpert call fails, or the response cannot be parsed as JSON. """ - validated_data = self._validate_request( - request, - api_serializers.LearningIntentRequestSerializer, - ) + serializer = api_serializers.LearningIntentRequestSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + validated_data = serializer.validated_data conversation_id = self._get_conversation_id(request) diff --git a/enterprise_access/apps/prompts/api/__init__.py b/enterprise_access/apps/prompts/api/__init__.py index 0803686a..459c1821 100644 --- a/enterprise_access/apps/prompts/api/__init__.py +++ b/enterprise_access/apps/prompts/api/__init__.py @@ -8,23 +8,14 @@ import json import logging from collections.abc import Sequence -from typing import TypeAlias, cast +from typing import Any, TypeAlias from enterprise_access.apps.prompts.api_client import XpertAPIClient, XpertAPIError from enterprise_access.apps.prompts.models import BaseSystemPrompt logger = logging.getLogger(__name__) -JSONValue: TypeAlias = ( - str | - int | - float | - bool | - None | - list["JSONValue"] | - dict[str, "JSONValue"] -) -ValidatedData: TypeAlias = dict[str, JSONValue] +ValidatedData: TypeAlias = dict[str, Any] XpertMessage: TypeAlias = dict[str, str] XpertResponse: TypeAlias = dict[str, object] SystemPromptModel: TypeAlias = type[BaseSystemPrompt] @@ -41,7 +32,7 @@ class PromptError(Exception): def get_current_prompt( - prompt_model: SystemPromptModel, + prompt_model: type[SystemPromptModel], prompt_type: str, ) -> BaseSystemPrompt: """ @@ -154,14 +145,15 @@ def extract_xpert_content(xpert_response: XpertResponse) -> str: content = xpert_response.get('content') if content is None: - raise PromptError( - 'Xpert response is missing the "content" field.' - ) + raise PromptError('Xpert response is missing the "content" field.') + + if not isinstance(content, str): + raise PromptError('Xpert response "content" must be a string.') return content -def parse_json_content(content: str) -> JSONValue: +def parse_json_content(content: str) -> dict[str, Any]: """ Parse and return the complete JSON value produced by Xpert. @@ -180,4 +172,4 @@ def parse_json_content(content: str) -> JSONValue: f'Failed to parse Xpert response content as JSON: {exc}' ) from exc - return cast(JSONValue, parsed_content) + return parsed_content