diff --git a/enterprise_access/apps/api/serializers/__init__.py b/enterprise_access/apps/api/serializers/__init__.py index 3c3436ef..c7eba80a 100644 --- a/enterprise_access/apps/api/serializers/__init__.py +++ b/enterprise_access/apps/api/serializers/__init__.py @@ -37,6 +37,7 @@ TransactionResponseSerializer, TransactionsListResponseSerializer ) +from .learner_pathways import LearningIntentRequestSerializer, LearningIntentResponseSerializer from .provisioning import ( ProvisioningRequestSerializer, ProvisioningResponseSerializer, diff --git a/enterprise_access/apps/api/serializers/learner_pathways.py b/enterprise_access/apps/api/serializers/learner_pathways.py new file mode 100644 index 00000000..cb24f729 --- /dev/null +++ b/enterprise_access/apps/api/serializers/learner_pathways.py @@ -0,0 +1,38 @@ +""" +Request and documentation-only response serializers for the Learner Pathways API. +""" +from rest_framework import serializers + +LEARNER_PATHWAYS_API_TAG = 'Learner Pathways' + + +class LearningIntentRequestSerializer(serializers.Serializer): + """ + Validates the request body for the learning-intent endpoint. + """ + selected_goals = serializers.CharField(allow_blank=False) + free_text = serializers.CharField(allow_blank=False) + known_context = serializers.CharField(allow_blank=False) + + def create(self, validated_data): + return validated_data + + def update(self, instance, validated_data): + return validated_data + + +class LearningIntentResponseSerializer(serializers.Serializer): + """ + Documents the expected HTTP 200 response shape for the learning-intent endpoint. + + For OpenAPI schema generation only — never instantiated at runtime. + """ + skills_required = serializers.ListField(child=serializers.CharField()) + skills_preferred = serializers.ListField(child=serializers.CharField()) + condensed_algolia_query = serializers.CharField() + + def create(self, validated_data): + return validated_data + + def update(self, instance, validated_data): + return validated_data diff --git a/enterprise_access/apps/api/v1/tests/test_prompt_views.py b/enterprise_access/apps/api/v1/tests/test_prompt_views.py index 03312082..2c813354 100644 --- a/enterprise_access/apps/api/v1/tests/test_prompt_views.py +++ b/enterprise_access/apps/api/v1/tests/test_prompt_views.py @@ -1,57 +1,49 @@ """ -Tests for BasePromptViewSet and PromptRequestException. +Tests for PromptRequestException, BasePromptViewSet, and LearnerPathwaysViewSet. + +Domain logic tests are in enterprise_access.apps.prompts.tests.test_api. +This module focuses on HTTP-layer behavior: validation, permission checks, +throttling, error mapping, and response serialization. """ -# pylint: disable=protected-access import json +import uuid from unittest import mock import ddt -from django.test import TestCase -from rest_framework import serializers, status -from rest_framework.exceptions import ValidationError - -from enterprise_access.apps.api.v1.views.prompt import BasePromptViewSet, PromptRequestException +from django.conf import settings as django_settings +from django.core.cache import cache as django_cache +from django.test import TestCase, override_settings +from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication +from rest_framework import permissions, status +from rest_framework.reverse import reverse +from rest_framework.test import APIClient +from rest_framework.throttling import ScopedRateThrottle + +from enterprise_access.apps.api import serializers as api_serializers +from enterprise_access.apps.api.v1.views.prompt import BasePromptViewSet, LearnerPathwaysViewSet, PromptRequestException +from enterprise_access.apps.core.constants import SYSTEM_ENTERPRISE_LEARNER_ROLE +from enterprise_access.apps.core.tests.factories import UserFactory +from enterprise_access.apps.prompts import api as prompts_api from enterprise_access.apps.prompts.api_client import ( XpertAPIConfigurationError, - XpertAPIError, XpertAPIRequestError, XpertAPIResponseError ) +from enterprise_access.apps.prompts.models import PromptType, XpertLearnerPathwaysSystemPrompt +from enterprise_access.apps.prompts.tests.factories import XpertLearnerPathwaysSystemPromptFactory +from test_utils import APITest -PATCH_XPERT_CLIENT = 'enterprise_access.apps.api.v1.views.prompt.XpertAPIClient' +PATCH_PROMPTS_API = 'enterprise_access.apps.api.v1.views.prompt.prompts_api' PATCH_GET_REQUEST_ID = 'enterprise_access.apps.api.v1.views.prompt.get_request_id' PATCH_UUID4 = 'enterprise_access.apps.api.v1.views.prompt.uuid_module.uuid4' +_LEARNING_INTENT_URL_NAME = 'api:v1:learner-pathways-learning-intent' -def _make_viewset(): - """Return a bare BasePromptViewSet instance for helper tests.""" - viewset = BasePromptViewSet() - viewset.request = mock.Mock() - viewset.kwargs = {} - viewset.format_kwarg = None - return viewset - - -def _make_request(data=None, headers=None): - """Return a mock DRF request.""" - request = mock.Mock() - request.data = data or {} - request.headers = headers or {} - return request - - -class _SampleSerializer(serializers.Serializer): - """Minimal serializer used in request-validation tests.""" - name = serializers.CharField() - count = serializers.IntegerField(required=False, default=0) - - def create(self, validated_data): - """Create is unused for validation-only tests.""" - return validated_data - - def update(self, instance, validated_data): - """Update is unused for validation-only tests.""" - return validated_data +_VALID_LEARNING_INTENT_PAYLOAD = { + 'selected_goals': 'data science', + 'free_text': 'I want to become a data scientist', + 'known_context': 'currently a software engineer', +} @ddt.ddt @@ -60,427 +52,548 @@ class TestPromptRequestException(TestCase): def test_status_code_is_500(self): exc = PromptRequestException('something went wrong') - self.assertEqual(exc.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + assert exc.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR def test_detail_message_is_preserved(self): exc = PromptRequestException('something went wrong') - self.assertIn('something went wrong', str(exc.detail)) + assert 'something went wrong' in str(exc.detail) def test_args_populated_with_message(self): exc = PromptRequestException('my error message') - self.assertEqual(exc.args[0], 'my error message') + assert exc.args[0] == 'my error message' def test_exception_chaining_preserved(self): - original = XpertAPIError('original error') + original = prompts_api.PromptError('original error') try: raise PromptRequestException('wrapped') from original except PromptRequestException as exc: - self.assertIs(exc.__cause__, original) - + assert exc.__cause__ is original -@ddt.ddt -class TestValidateRequest(TestCase): - """Tests for _validate_request.""" - def setUp(self): - self.viewset = _make_viewset() +# Domain logic tests are in enterprise_access.apps.prompts.tests.test_api. +# This test module focuses on HTTP-layer behavior in viewsets. - def test_valid_data_returns_validated_data(self): - request = _make_request({'name': 'Alice', 'count': 3}) - result = self.viewset._validate_request(request, _SampleSerializer) - self.assertEqual(result, {'name': 'Alice', 'count': 3}) - def test_valid_data_with_defaults(self): - request = _make_request({'name': 'Bob'}) - result = self.viewset._validate_request(request, _SampleSerializer) - self.assertEqual(result, {'name': 'Bob', 'count': 0}) +# --------------------------------------------------------------------------- +# Serializer tests +# --------------------------------------------------------------------------- - @ddt.ddt - class _Unused: - """Avoid nested TestCase discovery issues.""" +@ddt.ddt +class TestLearningIntentRequestSerializer(TestCase): + """Tests for LearningIntentRequestSerializer.""" + + def _valid(self): + return dict(_VALID_LEARNING_INTENT_PAYLOAD) + + def test_valid_payload_succeeds(self): + s = api_serializers.LearningIntentRequestSerializer(data=self._valid()) + assert s.is_valid() + + @ddt.data('selected_goals', 'free_text', 'known_context') + def test_missing_field_fails(self, field): + data = self._valid() + del data[field] + s = api_serializers.LearningIntentRequestSerializer(data=data) + assert not s.is_valid() + assert field in s.errors + + @ddt.data('selected_goals', 'free_text', 'known_context') + def test_blank_field_fails(self, field): + data = self._valid() + data[field] = '' + s = api_serializers.LearningIntentRequestSerializer(data=data) + assert not s.is_valid() + assert field in s.errors + + @ddt.data('selected_goals', 'free_text', 'known_context') + def test_whitespace_only_field_fails(self, field): + data = self._valid() + data[field] = ' ' + s = api_serializers.LearningIntentRequestSerializer(data=data) + assert not s.is_valid() + assert field in s.errors @ddt.data( - {}, - {'name': 'Alice', 'count': 'not-an-int'}, + ('selected_goals', 123), + ('free_text', []), + ('known_context', {'nested': True}), ) - def test_invalid_data_raises_validation_error(self, payload): - request = _make_request(payload) - with self.assertRaises(ValidationError) as ctx: - self.viewset._validate_request(request, _SampleSerializer) - self.assertEqual(ctx.exception.status_code, status.HTTP_400_BAD_REQUEST) - - def test_serializer_context_includes_request_format_and_view(self): - request = _make_request({'name': 'Test'}) - captured = {} - - class ContextCapturingSerializer(_SampleSerializer): - def is_valid(self, *, raise_exception=False): - captured['context'] = self.context - return super().is_valid(raise_exception=raise_exception) - - self.viewset._validate_request(request, ContextCapturingSerializer) - - self.assertIs(captured['context']['request'], request) - self.assertIs(captured['context']['view'], self.viewset) - self.assertIn('format', captured['context']) - + @ddt.unpack + def test_non_string_value_coerced_or_fails(self, field, value): + data = self._valid() + data[field] = value + s = api_serializers.LearningIntentRequestSerializer(data=data) + # DRF CharField coerces non-strings; result must still be non-blank. + # 123 → '123' (valid), [] → '' (invalid blank), {} → repr (valid) + if s.is_valid(): + assert isinstance(s.validated_data[field], str) + assert len(s.validated_data[field]) > 0 + else: + assert field in s.errors + + +# --------------------------------------------------------------------------- +# Routing tests +# --------------------------------------------------------------------------- + +class TestLearnerPathwaysRouting(TestCase): + """Tests for URL resolution of LearnerPathwaysViewSet.""" + + def test_learning_intent_url_reverses(self): + url = reverse(_LEARNING_INTENT_URL_NAME) + assert 'learner-pathways' in url + assert 'learning-intent' in url + + def test_learning_intent_post_accepted(self): + client = APIClient() + url = reverse(_LEARNING_INTENT_URL_NAME) + response = client.post(url, data={}, format='json') + # Unauthenticated — 401 or 403, but NOT 405. + assert response.status_code != status.HTTP_405_METHOD_NOT_ALLOWED + + def test_learning_intent_get_rejected(self): + url = reverse(_LEARNING_INTENT_URL_NAME) + client = APIClient() + user = UserFactory(is_active=True) + client.force_authenticate(user=user) + response = client.get(url) + assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED + + +# --------------------------------------------------------------------------- +# Route configuration tests +# --------------------------------------------------------------------------- + +class TestLearnerPathwaysRouteConfig(TestCase): + """Assert route-level configuration for each action.""" + # pylint: disable=no-member # DRF @action adds .kwargs at decoration time; pylint can't see it. + + def _get_action(self, name): + return getattr(LearnerPathwaysViewSet, name) + + def test_learning_intent_authentication_classes(self): + ac = self._get_action('learning_intent').kwargs.get('authentication_classes', ()) + assert JwtAuthentication in ac + + def test_learning_intent_is_authenticated_permission(self): + pc = self._get_action('learning_intent').kwargs.get('permission_classes', ()) + assert permissions.IsAuthenticated in pc + + def test_learning_intent_throttle_class(self): + tc = self._get_action('learning_intent').kwargs.get('throttle_classes', ()) + assert ScopedRateThrottle in tc + + def test_learning_intent_throttle_scope(self): + scope = self._get_action('learning_intent').kwargs.get('throttle_scope') + assert scope == 'learner_pathways_learning_intent' + + def test_no_throttle_on_base_prompt_viewset(self): + # throttle_classes must not be explicitly defined on BasePromptViewSet itself + assert 'throttle_classes' not in BasePromptViewSet.__dict__ + assert 'throttle_scope' not in BasePromptViewSet.__dict__ + + def test_no_class_level_throttle_classes_on_learner_pathways_viewset(self): + assert 'throttle_classes' not in LearnerPathwaysViewSet.__dict__ + + def test_throttle_scope_sentinel_is_none(self): + assert LearnerPathwaysViewSet.throttle_scope is None + + +# --------------------------------------------------------------------------- +# Authorization tests +# --------------------------------------------------------------------------- @ddt.ddt -class TestGetCurrentPrompt(TestCase): - """Tests for _get_current_prompt.""" +class TestLearnerPathwaysAuthorization(APITest): + """Authorization tests for the learning-intent endpoint.""" + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.learning_intent_prompt = XpertLearnerPathwaysSystemPromptFactory( + prompt_type=PromptType.LEARNER_INTENT, + ) def setUp(self): - self.viewset = _make_viewset() - - def test_returns_prompt_when_found(self): - prompt = mock.Mock() - prompt_model = mock.Mock() - prompt_model.get_current.return_value = prompt + super().setUp() + self.addCleanup(django_cache.clear) + + @ddt.data(_LEARNING_INTENT_URL_NAME) + def test_unauthenticated_caller_is_rejected(self, url_name): + self.client.logout() + self.client.cookies.clear() + url = reverse(url_name) + response = self.client.post(url, data={}, format='json') + assert response.status_code in [ + status.HTTP_401_UNAUTHORIZED, + status.HTTP_403_FORBIDDEN, + ] - result = self.viewset._get_current_prompt( - prompt_model=prompt_model, - prompt_type='learner_intent', + @ddt.data( + (_LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), + ) + @ddt.unpack + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_enterprise_learner_is_allowed( + self, url_name, payload, mock_client_class, + ): + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', + 'content': '{"result":"ok"}', + } + self.set_jwt_cookie([{ + 'system_wide_role': SYSTEM_ENTERPRISE_LEARNER_ROLE, + 'context': str(uuid.uuid4()), + }]) + url = reverse(url_name) + response = self.client.post(url, data=payload, format='json') + assert response.status_code == status.HTTP_200_OK + + @ddt.data(_LEARNING_INTENT_URL_NAME) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_xpert_not_called_when_auth_fails(self, url_name, mock_client_class): + self.set_jwt_cookie([]) + url = reverse(url_name) + self.client.post(url, data={}, format='json') + mock_client_class.return_value.send_message.assert_not_called() + + +# --------------------------------------------------------------------------- +# Throttle tests +# --------------------------------------------------------------------------- + +@override_settings(REST_FRAMEWORK={ + 'DEFAULT_THROTTLE_RATES': { + 'learner_pathways_learning_intent': '2/minute', + }, +}) +@ddt.ddt +class TestLearnerPathwaysThrottle(APITest): + """Throttle tests for the learning-intent endpoint.""" + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.learning_intent_prompt = XpertLearnerPathwaysSystemPromptFactory( + prompt_type=PromptType.LEARNER_INTENT, ) - self.assertIs(result, prompt) - - def test_exact_prompt_type_passed_to_get_current(self): - prompt_model = mock.Mock() - prompt_model.get_current.return_value = mock.Mock() - - self.viewset._get_current_prompt( - prompt_model=prompt_model, - prompt_type='learner_intent', + def setUp(self): + super().setUp() + self.addCleanup(django_cache.clear) + self.set_jwt_cookie([{ + 'system_wide_role': SYSTEM_ENTERPRISE_LEARNER_ROLE, + 'context': str(uuid.uuid4()), + }]) + + def test_learning_intent_scope_in_default_throttle_rates(self): + rates = django_settings.REST_FRAMEWORK.get('DEFAULT_THROTTLE_RATES', {}) + assert 'learner_pathways_learning_intent' in rates + + @mock.patch.object(ScopedRateThrottle, 'THROTTLE_RATES', { + 'learner_pathways_learning_intent': '2/minute', + }) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_learning_intent_throttled_after_rate_exceeded(self, mock_client_class): + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', 'content': '{"r":1}', + } + url = reverse(_LEARNING_INTENT_URL_NAME) + for _ in range(2): + resp = self.client.post(url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + assert resp.status_code == status.HTTP_200_OK + resp = self.client.post(url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + assert resp.status_code == status.HTTP_429_TOO_MANY_REQUESTS + + +# --------------------------------------------------------------------------- +# Happy path tests — learning intent +# --------------------------------------------------------------------------- + +class TestLearningIntentHappyPath(APITest): + """Full happy-path tests for the learning-intent action.""" + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.prompt = XpertLearnerPathwaysSystemPromptFactory( + prompt_type=PromptType.LEARNER_INTENT, ) - - prompt_model.get_current.assert_called_once_with(prompt_type='learner_intent') - - def test_missing_prompt_raises_500(self): - prompt_model = mock.Mock() - prompt_model.get_current.return_value = None - - with self.assertRaises(PromptRequestException) as ctx: - self.viewset._get_current_prompt( - prompt_model=prompt_model, - prompt_type='learner_intent', - ) - - self.assertEqual(ctx.exception.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) - - def test_prompt_for_another_type_cannot_satisfy_lookup(self): - prompt_model = mock.Mock() - prompt_model.get_current.return_value = None - - with self.assertRaises(PromptRequestException): - self.viewset._get_current_prompt( - prompt_model=prompt_model, - prompt_type='recommendations_feedback', - ) - - prompt_model.get_current.assert_called_once_with( - prompt_type='recommendations_feedback', + cls.other_prompt = XpertLearnerPathwaysSystemPromptFactory( + prompt_type=PromptType.RECOMMENDATIONS_FEEDBACK, ) - -@ddt.ddt -class TestBuildSystemPrompt(TestCase): - """Tests for _build_system_prompt.""" - - def setUp(self): - self.viewset = _make_viewset() - - def _make_prompt(self, system_prompt, output_schema=None): - prompt = mock.Mock() - prompt.system_prompt = system_prompt - prompt.output_schema = output_schema - return prompt - - def test_strips_surrounding_whitespace(self): - prompt = self._make_prompt(' Be helpful. ') - result = self.viewset._build_system_prompt(prompt) - self.assertEqual(result, 'Be helpful.') - - def test_non_empty_schema_appended(self): - schema = {'type': 'object', 'properties': {'answer': {'type': 'string'}}} - prompt = self._make_prompt('Be helpful.', output_schema=schema) - - result = self.viewset._build_system_prompt(prompt) - - self.assertIn('\n\nEXPECTED OUTPUT SCHEMA:\n', result) - self.assertIn(json.dumps(schema, indent=2, sort_keys=True), result) - - @ddt.data(None, {}) - def test_empty_schema_not_appended(self, output_schema): - prompt = self._make_prompt('Be helpful.', output_schema=output_schema) - - result = self.viewset._build_system_prompt(prompt) - - self.assertEqual(result, 'Be helpful.') - self.assertNotIn('EXPECTED OUTPUT SCHEMA:', result) - - def test_prompt_instance_not_mutated(self): - schema = {'key': 'value'} - prompt = self._make_prompt(' Original. ', output_schema=schema) - - self.viewset._build_system_prompt(prompt) - - self.assertEqual(prompt.system_prompt, ' Original. ') - self.assertIs(prompt.output_schema, schema) - - -@ddt.ddt -class TestBuildMessages(TestCase): - """Tests for _build_messages.""" - def setUp(self): - self.viewset = _make_viewset() - - def test_builds_single_user_message_with_string_content(self): - messages = self.viewset._build_messages({'name': 'Alice'}) - - self.assertEqual(messages, [ - {'role': 'user', 'content': '{"name":"Alice"}'}, - ]) - self.assertIsInstance(messages[0]['content'], str) - - def test_content_is_compact_json(self): - messages = self.viewset._build_messages({'name': 'Alice', 'count': 3}) - content = messages[0]['content'] - - self.assertNotIn(': ', content) - self.assertNotIn(', ', content) - self.assertEqual(json.loads(content), {'name': 'Alice', 'count': 3}) - - def test_nested_json_round_trips(self): - data = { - 'name': 'Alice', - 'items': [1, 2, 3], - 'metadata': {'active': True, 'notes': None}, + super().setUp() + self.addCleanup(django_cache.clear) + self.set_jwt_cookie([{ + 'system_wide_role': SYSTEM_ENTERPRISE_LEARNER_ROLE, + 'context': str(uuid.uuid4()), + }]) + self.url = reverse(_LEARNING_INTENT_URL_NAME) + + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_http_200_with_valid_payload(self, mock_client_class): + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', + 'content': '{"skills_required":["python"]}', } + resp = self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + assert resp.status_code == status.HTTP_200_OK + + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_correct_prompt_type_used(self, mock_client_class): + with mock.patch.object( + XpertLearnerPathwaysSystemPrompt, 'get_current' + ) as mock_get_current: + mock_prompt = mock.Mock() + mock_prompt.system_prompt = 'Be helpful.' + mock_prompt.output_schema = None + mock_get_current.return_value = mock_prompt + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', 'content': '{"r":1}', + } + self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + mock_get_current.assert_called_once_with( + prompt_type=PromptType.LEARNER_INTENT, + ) - messages = self.viewset._build_messages(data) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_server_controlled_tags_passed(self, mock_client_class): + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', 'content': '{"r":1}', + } + with override_settings(XPERT_LEARNER_PATHWAYS_RAG_TAGS=['tag-a', 'tag-b']): + self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + call_kwargs = mock_client_class.return_value.send_message.call_args.kwargs + assert call_kwargs['tags'] == ['tag-a', 'tag-b'] + + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_xpert_called_exactly_once(self, mock_client_class): + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', 'content': '{"r":1}', + } + self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + assert mock_client_class.return_value.send_message.call_count == 1 + + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_full_parsed_json_returned(self, mock_client_class): + payload_json = '{"skills_required":["python","ml"],"condensed_algolia_query":"data"}' + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', + 'content': payload_json, + } + resp = self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + assert resp.status_code == status.HTTP_200_OK + assert resp.json() == json.loads(payload_json) + + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_validated_data_encoded_as_user_message(self, mock_client_class): + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', 'content': '{"r":1}', + } + self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + call_kwargs = mock_client_class.return_value.send_message.call_args.kwargs + messages = call_kwargs['messages'] + assert len(messages) == 1 + assert messages[0]['role'] == 'user' + assert isinstance(messages[0]['content'], str) + parsed = json.loads(messages[0]['content']) + assert parsed['selected_goals'] == _VALID_LEARNING_INTENT_PAYLOAD['selected_goals'] + + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_conversation_id_has_prefix(self, mock_client_class): + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', 'content': '{"r":1}', + } + self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + call_kwargs = mock_client_class.return_value.send_message.call_args.kwargs + assert call_kwargs['conversation_id'].startswith('enterprise-access:') + + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_role_field_not_returned(self, mock_client_class): + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', + 'content': '{"answer":"yes"}', + } + resp = self.client.post(self.url, data=_VALID_LEARNING_INTENT_PAYLOAD, format='json') + assert 'role' not in resp.json() - self.assertEqual(json.loads(messages[0]['content']), data) +# --------------------------------------------------------------------------- +# Response passthrough tests +# --------------------------------------------------------------------------- @ddt.ddt -class TestGetConversationId(TestCase): - """Tests for _get_conversation_id.""" +class TestLearnerPathwaysResponsePassthrough(APITest): + """Assert Xpert response content is returned verbatim without filtering.""" + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.learning_intent_prompt = XpertLearnerPathwaysSystemPromptFactory( + prompt_type=PromptType.LEARNER_INTENT, + ) def setUp(self): - self.viewset = _make_viewset() - - @mock.patch(PATCH_GET_REQUEST_ID, return_value='from-crum') - def test_repo_request_id_helper_takes_precedence(self, mock_get_request_id): - request = _make_request(headers={'X-Request-ID': 'from-header'}) - - result = self.viewset._get_conversation_id(request) - - self.assertEqual(result, 'enterprise-access:from-crum') - mock_get_request_id.assert_called_once_with() - - @mock.patch(PATCH_GET_REQUEST_ID, return_value=None) - def test_header_used_when_repo_helper_returns_none(self, mock_get_request_id): - request = _make_request(headers={'X-Request-ID': 'from-header'}) - - result = self.viewset._get_conversation_id(request) - - self.assertEqual(result, 'enterprise-access:from-header') - mock_get_request_id.assert_called_once_with() - - @mock.patch(PATCH_UUID4, return_value='generated-uuid') - @mock.patch(PATCH_GET_REQUEST_ID, return_value=None) - def test_uuid_generated_when_no_request_id(self, mock_get_request_id, mock_uuid4): - request = _make_request(headers={}) - - result = self.viewset._get_conversation_id(request) - - self.assertEqual(result, 'enterprise-access:generated-uuid') - mock_get_request_id.assert_called_once_with() - mock_uuid4.assert_called_once_with() + super().setUp() + self.addCleanup(django_cache.clear) + self.set_jwt_cookie([{ + 'system_wide_role': SYSTEM_ENTERPRISE_LEARNER_ROLE, + 'context': str(uuid.uuid4()), + }]) @ddt.data( - ('from-crum', {'X-Request-ID': 'from-header'}), - (None, {'X-Request-ID': 'from-header'}), - (None, {}), + ('learning_intent', _LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), ) @ddt.unpack - def test_result_always_has_prefix(self, helper_value, headers): - request = _make_request(headers=headers) - - with mock.patch(PATCH_GET_REQUEST_ID, return_value=helper_value): - result = self.viewset._get_conversation_id(request) - - self.assertTrue(result.startswith('enterprise-access:')) - - -@ddt.ddt -class TestSendXpertMessage(TestCase): - """Tests for _send_xpert_message.""" - - def setUp(self): - self.viewset = _make_viewset() - self.system_prompt = 'You are helpful.' - self.messages = [{'role': 'user', 'content': '{"q":1}'}] - self.conversation_id = 'enterprise-access:test-123' - - @mock.patch(PATCH_XPERT_CLIENT) - def test_client_called_once_with_correct_args(self, mock_client_class): - mock_response = {'role': 'assistant', 'content': '{"answer":"yes"}'} - mock_client_class.return_value.send_message.return_value = mock_response - - result = self.viewset._send_xpert_message( - system_prompt=self.system_prompt, - messages=self.messages, - conversation_id=self.conversation_id, - tags=('tag1', 'tag2'), - prompt_type='learner_intent', - ) - - self.assertEqual(result, mock_response) - mock_client_class.return_value.send_message.assert_called_once_with( - system_prompt=self.system_prompt, - messages=self.messages, - conversation_id=self.conversation_id, - tags=['tag1', 'tag2'], - ) - - @ddt.data(None, [], ()) - @mock.patch(PATCH_XPERT_CLIENT) - def test_empty_tags_passed_as_none(self, tags, mock_client_class): - mock_client_class.return_value.send_message.return_value = {} - - self.viewset._send_xpert_message( - system_prompt=self.system_prompt, - messages=self.messages, - conversation_id=self.conversation_id, - tags=tags, - ) - - self.assertIsNone( - mock_client_class.return_value.send_message.call_args.kwargs['tags'], - ) - - @mock.patch(PATCH_XPERT_CLIENT) - def test_no_second_call_made(self, mock_client_class): - mock_client_class.return_value.send_message.return_value = {} - - self.viewset._send_xpert_message( - system_prompt=self.system_prompt, - messages=self.messages, - conversation_id=self.conversation_id, - ) - - self.assertEqual(mock_client_class.return_value.send_message.call_count, 1) - - -@ddt.ddt -class TestSendXpertMessageErrors(TestCase): - """Tests for XpertAPIError mapping.""" + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_extra_top_level_fields_preserved( + self, _action, url_name, payload, mock_client_class, + ): + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', + 'content': '{"result":"ok","extra_field":"preserved"}', + } + resp = self.client.post(reverse(url_name), data=payload, format='json') + assert resp.status_code == status.HTTP_200_OK + assert 'extra_field' in resp.json() - def setUp(self): - self.viewset = _make_viewset() + @ddt.data( + ('learning_intent', _LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), + ) + @ddt.unpack + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_list_response_returned_as_list( + self, _action, url_name, payload, mock_client_class, + ): + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', + 'content': '[1,2,3]', + } + resp = self.client.post(reverse(url_name), data=payload, format='json') + assert resp.status_code == status.HTTP_200_OK + assert isinstance(resp.json(), list) @ddt.data( - XpertAPIError, - XpertAPIConfigurationError, - XpertAPIRequestError, - XpertAPIResponseError, + ('learning_intent', _LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), ) - @mock.patch(PATCH_XPERT_CLIENT) - def test_xpert_errors_become_prompt_request_exception( - self, - error_class, - mock_client_class, + @ddt.unpack + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_nested_values_preserved( + self, _action, url_name, payload, mock_client_class, ): - original = error_class('original error text') - mock_client_class.return_value.send_message.side_effect = original - - with self.assertRaises(PromptRequestException) as ctx: - self.viewset._send_xpert_message( - system_prompt='prompt', - messages=[], - conversation_id='enterprise-access:x', - prompt_type='learner_intent', - ) + nested = {'a': {'b': {'c': [1, 2, 3]}}} + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', + 'content': json.dumps(nested), + } + resp = self.client.post(reverse(url_name), data=payload, format='json') + assert resp.status_code == status.HTTP_200_OK + assert resp.json() == nested - self.assertEqual(ctx.exception.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) - self.assertIs(ctx.exception.__cause__, original) - self.assertIn('original error text', ctx.exception.args[0]) - self.assertEqual(mock_client_class.return_value.send_message.call_count, 1) +# --------------------------------------------------------------------------- +# Failure tests +# --------------------------------------------------------------------------- @ddt.ddt -class TestExtractXpertContent(TestCase): - """Tests for _extract_xpert_content.""" +class TestLearnerPathwaysFailures(APITest): + """500-series failure paths for the learning-intent endpoint.""" + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.learning_intent_prompt = XpertLearnerPathwaysSystemPromptFactory( + prompt_type=PromptType.LEARNER_INTENT, + ) def setUp(self): - self.viewset = _make_viewset() - - def test_valid_response_returns_content_string(self): - response = {'role': 'assistant', 'content': '{"answer":"yes"}'} - - result = self.viewset._extract_xpert_content(response) - - self.assertEqual(result, '{"answer":"yes"}') + super().setUp() + self.addCleanup(django_cache.clear) + self.set_jwt_cookie([{ + 'system_wide_role': SYSTEM_ENTERPRISE_LEARNER_ROLE, + 'context': str(uuid.uuid4()), + }]) @ddt.data( - {'role': 'assistant'}, - {'role': 'assistant', 'content': None}, + (_LEARNING_INTENT_URL_NAME, _VALID_LEARNING_INTENT_PAYLOAD), ) - def test_invalid_content_raises_500(self, response): - with self.assertRaises(PromptRequestException) as ctx: - self.viewset._extract_xpert_content(response) - - self.assertEqual(ctx.exception.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) - - -@ddt.ddt -class TestParseJsonContent(TestCase): - """Tests for _parse_json_content.""" + @ddt.unpack + def test_missing_prompt_returns_500(self, url_name, payload): + with mock.patch.object(XpertLearnerPathwaysSystemPrompt, 'get_current', return_value=None): + resp = self.client.post(reverse(url_name), data=payload, format='json') + assert resp.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR - def setUp(self): - self.viewset = _make_viewset() + @ddt.data( + XpertAPIConfigurationError, + XpertAPIRequestError, + XpertAPIResponseError, + ) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_xpert_error_returns_500(self, error_class, mock_client_class): + mock_client_class.return_value.send_message.side_effect = error_class('xpert error') + resp = self.client.post( + reverse(_LEARNING_INTENT_URL_NAME), + data=_VALID_LEARNING_INTENT_PAYLOAD, + format='json', + ) + assert resp.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @ddt.data( - ('{"answer":42}', {'answer': 42}), - ('[1,2,3]', [1, 2, 3]), - ('"hello"', 'hello'), - ('99', 99), - ('false', False), - ('true', True), - ('null', None), - (' {"trimmed":true} ', {'trimmed': True}), + ('missing', {'role': 'assistant'}), + ('none', {'role': 'assistant', 'content': None}), ) @ddt.unpack - def test_valid_json_values_returned_unchanged(self, raw_content, expected): - result = self.viewset._parse_json_content(raw_content) - self.assertEqual(result, expected) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_bad_content_returns_500( + self, _case, xpert_response, mock_client_class, + ): + mock_client_class.return_value.send_message.return_value = xpert_response + resp = self.client.post( + reverse(_LEARNING_INTENT_URL_NAME), + data=_VALID_LEARNING_INTENT_PAYLOAD, + format='json', + ) + assert resp.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @ddt.data( 'not valid json', '```json\n{"key":"value"}\n```', - '```\n{"key":"value"}\n```', '{"unterminated": true', - '', ) - def test_invalid_or_fenced_json_raises_500(self, raw_content): - with self.assertRaises(PromptRequestException) as ctx: - self.viewset._parse_json_content(raw_content) - - self.assertEqual(ctx.exception.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) - - def test_invalid_json_exception_is_chained(self): - with self.assertRaises(PromptRequestException) as ctx: - self.viewset._parse_json_content('not valid json') - - self.assertIsNotNone(ctx.exception.__cause__) - self.assertIsInstance(ctx.exception.__cause__, json.JSONDecodeError) - - def test_no_repair_or_fallback_on_bad_json(self): - with self.assertRaises(PromptRequestException): - self.viewset._parse_json_content('garbage') - - def test_parse_json_content_requires_string_contract(self): - with self.assertRaises(AttributeError): - self.viewset._parse_json_content({'not': 'a string'}) + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_invalid_json_content_returns_500(self, bad_content, mock_client_class): + mock_client_class.return_value.send_message.return_value = { + 'role': 'assistant', 'content': bad_content, + } + resp = self.client.post( + reverse(_LEARNING_INTENT_URL_NAME), + data=_VALID_LEARNING_INTENT_PAYLOAD, + format='json', + ) + assert resp.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_no_second_xpert_call_on_failure(self, mock_client_class): + mock_client_class.return_value.send_message.side_effect = XpertAPIRequestError('fail') + self.client.post( + reverse(_LEARNING_INTENT_URL_NAME), + data=_VALID_LEARNING_INTENT_PAYLOAD, + format='json', + ) + assert mock_client_class.return_value.send_message.call_count == 1 + + @mock.patch('enterprise_access.apps.prompts.api.XpertAPIClient') + def test_no_fallback_object_returned(self, mock_client_class): + mock_client_class.return_value.send_message.side_effect = XpertAPIRequestError('fail') + resp = self.client.post( + reverse(_LEARNING_INTENT_URL_NAME), + data=_VALID_LEARNING_INTENT_PAYLOAD, + format='json', + ) + assert resp.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + if resp.get('Content-Type', '').startswith('application/json'): + body = resp.json() + assert 'skills_required' not in body + assert 'reasons' not in body diff --git a/enterprise_access/apps/api/v1/urls.py b/enterprise_access/apps/api/v1/urls.py index ca52d9fa..fe1a2e9f 100644 --- a/enterprise_access/apps/api/v1/urls.py +++ b/enterprise_access/apps/api/v1/urls.py @@ -11,6 +11,7 @@ router = DefaultRouter() router.register("testimonials", views.TestimonialViewSet, "testimonials") +router.register('learner-pathways', views.LearnerPathwaysViewSet, 'learner-pathways') router.register("policy-redemption", views.SubsidyAccessPolicyRedeemViewset, 'policy-redemption') router.register("policy-allocation", views.SubsidyAccessPolicyAllocateViewset, 'policy-allocation') router.register("subsidy-access-policies", views.SubsidyAccessPolicyViewSet, 'subsidy-access-policies') diff --git a/enterprise_access/apps/api/v1/views/__init__.py b/enterprise_access/apps/api/v1/views/__init__.py index 18a5c090..db04c27f 100644 --- a/enterprise_access/apps/api/v1/views/__init__.py +++ b/enterprise_access/apps/api/v1/views/__init__.py @@ -23,6 +23,7 @@ SspProductViewSet, StripeEventSummaryViewSet ) +from .prompt import LearnerPathwaysViewSet from .provisioning import ProvisioningCreateView, SubscriptionPlanOLIUpdateView from .subsidy_access_policy import ( SubsidyAccessPolicyAllocateViewset, diff --git a/enterprise_access/apps/api/v1/views/prompt.py b/enterprise_access/apps/api/v1/views/prompt.py index ef1afd5b..9a22b9c1 100644 --- a/enterprise_access/apps/api/v1/views/prompt.py +++ b/enterprise_access/apps/api/v1/views/prompt.py @@ -1,41 +1,30 @@ """ -Reusable base viewset for prompt-backed Xpert requests. +REST API viewsets for prompt-backed Xpert requests. """ -import json import logging import uuid as uuid_module -from collections.abc import Sequence -from typing import TypeAlias, cast -from rest_framework import serializers, status +from django.conf import settings +from drf_spectacular.utils import extend_schema +from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication +from rest_framework import permissions, status +from rest_framework.decorators import action from rest_framework.exceptions import APIException from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.throttling import ScopedRateThrottle from rest_framework.viewsets import ViewSet +from enterprise_access.apps.api import serializers as api_serializers +from enterprise_access.apps.api.serializers.learner_pathways import LEARNER_PATHWAYS_API_TAG from enterprise_access.apps.api_client.base_user import get_request_id -from enterprise_access.apps.prompts.api_client import XpertAPIClient, XpertAPIError -from enterprise_access.apps.prompts.models import BaseSystemPrompt +from enterprise_access.apps.prompts import api as prompts_api +from enterprise_access.apps.prompts.models import PromptType, XpertLearnerPathwaysSystemPrompt logger = logging.getLogger(__name__) -JSONValue: TypeAlias = ( - str | - int | - float | - bool | - None | - list["JSONValue"] | - dict[str, "JSONValue"] -) - -ValidatedData: TypeAlias = dict[str, object] -XpertMessage: TypeAlias = dict[str, str] -XpertResponse: TypeAlias = dict[str, object] -SystemPromptModel: TypeAlias = type[BaseSystemPrompt] - _CONVERSATION_ID_PREFIX = 'enterprise-access' _X_REQUEST_ID_HEADER = 'X-Request-ID' -_SCHEMA_SEPARATOR = '\n\nEXPECTED OUTPUT SCHEMA:\n' class PromptRequestException(APIException): @@ -62,95 +51,14 @@ class BasePromptViewSet(ViewSet): """ Reusable helper methods for prompt-backed Xpert requests. + This base class provides HTTP-layer utilities for conversation ID generation. + Domain logic is delegated to enterprise_access.apps.prompts.api. + Concrete viewsets compose these helpers inside their individual actions. - This base class intentionally defines no actions, routes, authentication + This class intentionally defines no actions, routes, authentication classes, or permission policies. """ - def _validate_request( - self, - request: Request, - serializer_class: type[serializers.Serializer], - ) -> ValidatedData: - """ - Validate request data and return the serializer's validated payload. - - Invalid request data follows standard DRF validation behavior and - produces an HTTP 400 response. - """ - serializer = serializer_class( - data=request.data, - context={ - 'request': request, - 'format': None, - 'view': self, - }, - ) - serializer.is_valid(raise_exception=True) - - return serializer.validated_data - - def _get_current_prompt( - self, - *, - prompt_model: type[SystemPromptModel], - prompt_type: str, - ) -> BaseSystemPrompt: - """ - Resolve the current prompt for the exact supplied prompt type. - """ - prompt = prompt_model.get_current( - prompt_type=prompt_type, - ) - if prompt is None: - raise PromptRequestException( - f'No active prompt found for prompt_type={prompt_type!r}.' - ) - - return prompt - - def _build_system_prompt( - self, - prompt: BaseSystemPrompt, - ) -> str: - """ - Build the complete system prompt sent to Xpert. - - The configured prompt text is stripped of surrounding whitespace. - A non-empty output schema is appended as formatted JSON. - """ - system_prompt = prompt.system_prompt.strip() - output_schema = prompt.output_schema - - if output_schema: - system_prompt += _SCHEMA_SEPARATOR + json.dumps( - output_schema, - indent=2, - sort_keys=True, - ) - - return system_prompt - - def _build_messages( - self, - validated_data: ValidatedData, - ) -> list[XpertMessage]: - """ - Build the default Xpert message list. - - The complete validated request payload is encoded as compact JSON in - a single user message. - """ - return [ - { - 'role': 'user', - 'content': json.dumps( - validated_data, - separators=(',', ':'), - ), - }, - ] - def _get_conversation_id( self, request: Request, @@ -172,74 +80,81 @@ def _get_conversation_id( return f'{_CONVERSATION_ID_PREFIX}:{request_id}' - def _send_xpert_message( - self, - *, - system_prompt: str, - messages: list[XpertMessage], - conversation_id: str, - tags: Sequence[str] | None = None, - prompt_type: str | None = None, - ) -> XpertResponse: - """ - Send one prompt-backed request through the existing Xpert client. - Xpert client failures are logged with tracking metadata and converted - to HTTP 500 prompt request failures. Prompt text, request payloads, and - raw model responses are not logged. - """ - normalized_tags = list(tags) if tags else None +class LearnerPathwaysViewSet(BasePromptViewSet): + """ + Endpoints for the Learner Pathways Xpert-backed feature. - try: - response = XpertAPIClient().send_message( - system_prompt=system_prompt, - messages=messages, - conversation_id=conversation_id, - tags=normalized_tags, - ) - except XpertAPIError as exc: - logger.exception( - 'Xpert request failed for prompt_type=%r, conversation_id=%r.', - prompt_type, - conversation_id, - ) - raise PromptRequestException(str(exc)) from exc + Each action defines its own authentication, permissions, and throttle configuration + explicitly. No shared authentication, permissions, or throttle classes are defined + at the class level. + """ - return response + model_type = XpertLearnerPathwaysSystemPrompt + throttle_scope: str | None = None + + @extend_schema( + tags=[LEARNER_PATHWAYS_API_TAG], + summary='Derive learning intent from learner input.', + description=( + 'Calls Xpert with the learner\'s stated goals, free-text input, and known context ' + 'to derive skills and a search query. Returns the raw JSON produced by Xpert.' + ), + request=api_serializers.LearningIntentRequestSerializer, + responses={ + status.HTTP_200_OK: api_serializers.LearningIntentResponseSerializer, + status.HTTP_400_BAD_REQUEST: None, + status.HTTP_401_UNAUTHORIZED: None, + status.HTTP_403_FORBIDDEN: None, + status.HTTP_429_TOO_MANY_REQUESTS: None, + status.HTTP_500_INTERNAL_SERVER_ERROR: None, + }, + ) + @action( + detail=False, + methods=['post'], + url_path='learning-intent', + url_name='learning-intent', + authentication_classes=(JwtAuthentication,), + permission_classes=(permissions.IsAuthenticated,), + throttle_classes=(ScopedRateThrottle,), + throttle_scope='learner_pathways_learning_intent', + ) + def learning_intent(self, request: Request) -> Response: + """ + Derive learning intent from the learner's stated goals, free-text input, and known context. + + Returns HTTP 400 for invalid request input. + Returns HTTP 401/403 when the caller is unauthenticated or not an enterprise learner. + Returns HTTP 429 when the per-endpoint rate limit is exceeded. + Returns HTTP 500 when the prompt is missing, the Xpert call fails, or the response + cannot be parsed as JSON. + """ + serializer = api_serializers.LearningIntentRequestSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + validated_data = serializer.validated_data - def _extract_xpert_content( - self, - xpert_response: XpertResponse, - ) -> str: - """ - Extract the raw content string from the normalized Xpert response. - """ - content = xpert_response.get('content') + conversation_id = self._get_conversation_id(request) - if content is None: - raise PromptRequestException( - 'Xpert response is missing the "content" field.' + try: + prompt = prompts_api.get_current_prompt( + prompt_model=self.model_type, + prompt_type=PromptType.LEARNER_INTENT, ) + system_prompt = prompts_api.build_system_prompt(prompt) + messages = prompts_api.build_messages(validated_data) - return content - - def _parse_json_content( - self, - content: str, - ) -> JSONValue: - """ - Parse and return the complete JSON value produced by Xpert. + xpert_response = prompts_api.send_xpert_message( + system_prompt=system_prompt, + messages=messages, + conversation_id=conversation_id, + tags=settings.XPERT_LEARNER_PATHWAYS_RAG_TAGS, + prompt_type=PromptType.LEARNER_INTENT, + ) - The content must be directly parseable as JSON after surrounding - whitespace is removed. Markdown fencing, repair prompts, retries, - fallback parsing, field mapping, and response normalization are - intentionally unsupported. - """ - try: - parsed_content = json.loads(content.strip()) - except json.JSONDecodeError as exc: - raise PromptRequestException( - f'Failed to parse Xpert response content as JSON: {exc}' - ) from exc + content = prompts_api.extract_xpert_content(xpert_response) + response_data = prompts_api.parse_json_content(content) + except prompts_api.PromptError as exc: + raise PromptRequestException(str(exc)) from exc - return cast(JSONValue, parsed_content) + return Response(response_data, status=status.HTTP_200_OK) diff --git a/enterprise_access/apps/prompts/api/__init__.py b/enterprise_access/apps/prompts/api/__init__.py index e69de29b..459c1821 100644 --- a/enterprise_access/apps/prompts/api/__init__.py +++ b/enterprise_access/apps/prompts/api/__init__.py @@ -0,0 +1,175 @@ +""" +Domain-layer API for Xpert prompt-backed workflows. + +This module handles the orchestration of prompt-based Xpert AI interactions, +including prompt retrieval, message construction, and response parsing. +No HTTP or DRF machinery — suitable for testing in isolation. +""" +import json +import logging +from collections.abc import Sequence +from typing import Any, TypeAlias + +from enterprise_access.apps.prompts.api_client import XpertAPIClient, XpertAPIError +from enterprise_access.apps.prompts.models import BaseSystemPrompt + +logger = logging.getLogger(__name__) + +ValidatedData: TypeAlias = dict[str, Any] +XpertMessage: TypeAlias = dict[str, str] +XpertResponse: TypeAlias = dict[str, object] +SystemPromptModel: TypeAlias = type[BaseSystemPrompt] + +_SCHEMA_SEPARATOR = '\n\nEXPECTED OUTPUT SCHEMA:\n' + + +class PromptError(Exception): + """ + Raised when a prompt-backed Xpert request fails at the domain layer. + + These errors are translated by the view layer into HTTP 500 responses. + """ + + +def get_current_prompt( + prompt_model: type[SystemPromptModel], + prompt_type: str, +) -> BaseSystemPrompt: + """ + Resolve the current prompt for the supplied prompt type. + + Raises: + PromptError: If no active prompt exists for the given prompt_type. + """ + prompt = prompt_model.get_current(prompt_type=prompt_type) + if prompt is None: + raise PromptError( + f'No active prompt found for prompt_type={prompt_type!r}.' + ) + return prompt + + +def build_system_prompt(prompt: BaseSystemPrompt) -> str: + """ + Build the complete system prompt sent to Xpert. + + The configured prompt text is stripped of surrounding whitespace. + A non-empty output schema is appended as formatted JSON. + """ + system_prompt = prompt.system_prompt.strip() + output_schema = prompt.output_schema + + if output_schema: + system_prompt += _SCHEMA_SEPARATOR + json.dumps( + output_schema, + indent=2, + sort_keys=True, + ) + + return system_prompt + + +def build_messages(validated_data: ValidatedData) -> list[XpertMessage]: + """ + Build the default Xpert message list. + + The complete validated request payload is encoded as compact JSON in + a single user message. + """ + return [ + { + 'role': 'user', + 'content': json.dumps( + validated_data, + separators=(',', ':'), + ), + }, + ] + + +def send_xpert_message( + *, + system_prompt: str, + messages: list[XpertMessage], + conversation_id: str, + tags: Sequence[str] | None = None, + prompt_type: str | None = None, +) -> XpertResponse: + """ + Send one prompt-backed request through the Xpert client. + + Xpert client failures are logged with tracking metadata and converted + to PromptError. Prompt text, request payloads, and raw model responses + are not logged. + + Args: + system_prompt: System prompt text for Xpert. + messages: List of messages for Xpert (user role + content). + conversation_id: Unique conversation ID for tracing. + tags: Optional list of tags for Xpert (e.g. RAG tags). + prompt_type: Optional prompt type for logging/tracking. + + Raises: + PromptError: If the Xpert API call fails. + + Returns: + The raw Xpert response dict. + """ + normalized_tags = list(tags) if tags else None + + try: + response = XpertAPIClient().send_message( + system_prompt=system_prompt, + messages=messages, + conversation_id=conversation_id, + tags=normalized_tags, + ) + except XpertAPIError as exc: + logger.exception( + 'Xpert request failed for prompt_type=%r, conversation_id=%r.', + prompt_type, + conversation_id, + ) + raise PromptError(str(exc)) from exc + + return response + + +def extract_xpert_content(xpert_response: XpertResponse) -> str: + """ + Extract the raw content string from the Xpert response. + + Raises: + PromptError: If the 'content' field is missing. + """ + content = xpert_response.get('content') + + if content is None: + raise PromptError('Xpert response is missing the "content" field.') + + if not isinstance(content, str): + raise PromptError('Xpert response "content" must be a string.') + + return content + + +def parse_json_content(content: str) -> dict[str, Any]: + """ + Parse and return the complete JSON value produced by Xpert. + + The content must be directly parseable as JSON after surrounding + whitespace is removed. Markdown fencing, repair prompts, retries, + fallback parsing, field mapping, and response normalization are + intentionally unsupported. + + Raises: + PromptError: If JSON parsing fails. + """ + try: + parsed_content = json.loads(content.strip()) + except json.JSONDecodeError as exc: + raise PromptError( + f'Failed to parse Xpert response content as JSON: {exc}' + ) from exc + + return parsed_content diff --git a/enterprise_access/apps/prompts/tests/test_api.py b/enterprise_access/apps/prompts/tests/test_api.py new file mode 100644 index 00000000..d1e41857 --- /dev/null +++ b/enterprise_access/apps/prompts/tests/test_api.py @@ -0,0 +1,292 @@ +""" +Tests for the prompts domain API module. +""" +import json +from unittest import mock + +import ddt +import pytest +from django.test import TestCase + +from enterprise_access.apps.prompts import api as prompts_api +from enterprise_access.apps.prompts.api_client import ( + XpertAPIConfigurationError, + XpertAPIError, + XpertAPIRequestError, + XpertAPIResponseError +) + +PATCH_XPERT_CLIENT = 'enterprise_access.apps.prompts.api.XpertAPIClient' + + +@ddt.ddt +class TestGetCurrentPrompt(TestCase): + """Tests for get_current_prompt.""" + + def test_returns_prompt_when_found(self): + prompt = mock.Mock() + prompt_model = mock.Mock() + prompt_model.get_current.return_value = prompt + + result = prompts_api.get_current_prompt( + prompt_model=prompt_model, + prompt_type='learner_intent', + ) + + assert result is prompt + + def test_exact_prompt_type_passed_to_get_current(self): + prompt_model = mock.Mock() + prompt_model.get_current.return_value = mock.Mock() + + prompts_api.get_current_prompt( + prompt_model=prompt_model, + prompt_type='learner_intent', + ) + + prompt_model.get_current.assert_called_once_with(prompt_type='learner_intent') + + def test_missing_prompt_raises_error(self): + prompt_model = mock.Mock() + prompt_model.get_current.return_value = None + + with pytest.raises(prompts_api.PromptError): + prompts_api.get_current_prompt( + prompt_model=prompt_model, + prompt_type='learner_intent', + ) + + def test_prompt_for_another_type_cannot_satisfy_lookup(self): + prompt_model = mock.Mock() + prompt_model.get_current.return_value = None + + with pytest.raises(prompts_api.PromptError): + prompts_api.get_current_prompt( + prompt_model=prompt_model, + prompt_type='recommendations_feedback', + ) + + prompt_model.get_current.assert_called_once_with( + prompt_type='recommendations_feedback', + ) + + +@ddt.ddt +class TestBuildSystemPrompt(TestCase): + """Tests for build_system_prompt.""" + + def _make_prompt(self, system_prompt, output_schema=None): + prompt = mock.Mock() + prompt.system_prompt = system_prompt + prompt.output_schema = output_schema + return prompt + + def test_strips_surrounding_whitespace(self): + prompt = self._make_prompt(' Be helpful. ') + result = prompts_api.build_system_prompt(prompt) + assert result == 'Be helpful.' + + def test_non_empty_schema_appended(self): + schema = {'type': 'object', 'properties': {'answer': {'type': 'string'}}} + prompt = self._make_prompt('Be helpful.', output_schema=schema) + + result = prompts_api.build_system_prompt(prompt) + + assert '\n\nEXPECTED OUTPUT SCHEMA:\n' in result + assert json.dumps(schema, indent=2, sort_keys=True) in result + + @ddt.data(None, {}) + def test_empty_schema_not_appended(self, output_schema): + prompt = self._make_prompt('Be helpful.', output_schema=output_schema) + + result = prompts_api.build_system_prompt(prompt) + + assert result == 'Be helpful.' + assert 'EXPECTED OUTPUT SCHEMA:' not in result + + def test_prompt_instance_not_mutated(self): + schema = {'key': 'value'} + prompt = self._make_prompt(' Original. ', output_schema=schema) + + prompts_api.build_system_prompt(prompt) + + assert prompt.system_prompt == ' Original. ' + assert prompt.output_schema is schema + + +@ddt.ddt +class TestBuildMessages(TestCase): + """Tests for build_messages.""" + + def test_builds_single_user_message_with_string_content(self): + messages = prompts_api.build_messages({'name': 'Alice'}) + + assert messages == [ + {'role': 'user', 'content': '{"name":"Alice"}'}, + ] + assert isinstance(messages[0]['content'], str) + + def test_content_is_compact_json(self): + messages = prompts_api.build_messages({'name': 'Alice', 'count': 3}) + content = messages[0]['content'] + + assert ': ' not in content + assert ', ' not in content + assert json.loads(content) == {'name': 'Alice', 'count': 3} + + def test_nested_json_round_trips(self): + data = { + 'name': 'Alice', + 'items': [1, 2, 3], + 'metadata': {'active': True, 'notes': None}, + } + + messages = prompts_api.build_messages(data) + + assert json.loads(messages[0]['content']) == data + + +@ddt.ddt +class TestSendXpertMessage(TestCase): + """Tests for send_xpert_message.""" + + def setUp(self): + self.system_prompt = 'You are helpful.' + self.messages = [{'role': 'user', 'content': '{"q":1}'}] + self.conversation_id = 'enterprise-access:test-123' + + @mock.patch(PATCH_XPERT_CLIENT) + def test_client_called_once_with_correct_args(self, mock_client_class): + mock_response = {'role': 'assistant', 'content': '{"answer":"yes"}'} + mock_client_class.return_value.send_message.return_value = mock_response + + result = prompts_api.send_xpert_message( + system_prompt=self.system_prompt, + messages=self.messages, + conversation_id=self.conversation_id, + tags=('tag1', 'tag2'), + prompt_type='learner_intent', + ) + + assert result == mock_response + mock_client_class.return_value.send_message.assert_called_once_with( + system_prompt=self.system_prompt, + messages=self.messages, + conversation_id=self.conversation_id, + tags=['tag1', 'tag2'], + ) + + @ddt.data(None, [], ()) + @mock.patch(PATCH_XPERT_CLIENT) + def test_empty_tags_passed_as_none(self, tags, mock_client_class): + mock_client_class.return_value.send_message.return_value = {} + + prompts_api.send_xpert_message( + system_prompt=self.system_prompt, + messages=self.messages, + conversation_id=self.conversation_id, + tags=tags, + ) + + assert mock_client_class.return_value.send_message.call_args.kwargs['tags'] is None + + @mock.patch(PATCH_XPERT_CLIENT) + def test_no_second_call_made(self, mock_client_class): + mock_client_class.return_value.send_message.return_value = {} + + prompts_api.send_xpert_message( + system_prompt=self.system_prompt, + messages=self.messages, + conversation_id=self.conversation_id, + ) + + assert mock_client_class.return_value.send_message.call_count == 1 + + +@ddt.ddt +class TestSendXpertMessageErrors(TestCase): + """Tests for XpertAPIError mapping to PromptError.""" + + @ddt.data( + XpertAPIError, + XpertAPIConfigurationError, + XpertAPIRequestError, + XpertAPIResponseError, + ) + @mock.patch(PATCH_XPERT_CLIENT) + def test_xpert_errors_become_prompt_error( + self, + error_class, + mock_client_class, + ): + original = error_class('original error text') + mock_client_class.return_value.send_message.side_effect = original + + with pytest.raises(prompts_api.PromptError) as exc_info: + prompts_api.send_xpert_message( + system_prompt='prompt', + messages=[], + conversation_id='enterprise-access:x', + prompt_type='learner_intent', + ) + + assert exc_info.value.__cause__ is original + assert 'original error text' in str(exc_info.value) + assert mock_client_class.return_value.send_message.call_count == 1 + + +@ddt.ddt +class TestExtractXpertContent(TestCase): + """Tests for extract_xpert_content.""" + + def test_valid_response_returns_content_string(self): + response = {'role': 'assistant', 'content': '{"answer":"yes"}'} + + result = prompts_api.extract_xpert_content(response) + + assert result == '{"answer":"yes"}' + + @ddt.data( + {'role': 'assistant'}, + {'role': 'assistant', 'content': None}, + ) + def test_invalid_content_raises_error(self, response): + with pytest.raises(prompts_api.PromptError): + prompts_api.extract_xpert_content(response) + + +@ddt.ddt +class TestParseJsonContent(TestCase): + """Tests for parse_json_content.""" + + @ddt.data( + ('{"answer":42}', {'answer': 42}), + ('[1,2,3]', [1, 2, 3]), + ('"hello"', 'hello'), + ('99', 99), + ('false', False), + ('true', True), + ('null', None), + (' {"trimmed":true} ', {'trimmed': True}), + ) + @ddt.unpack + def test_valid_json_values_returned_unchanged(self, raw_content, expected): + result = prompts_api.parse_json_content(raw_content) + assert result == expected + + @ddt.data( + 'not valid json', + '```json\n{"key":"value"}\n```', + '```\n{"key":"value"}\n```', + '{"unterminated": true', + '', + ) + def test_invalid_or_fenced_json_raises_error(self, raw_content): + with pytest.raises(prompts_api.PromptError): + prompts_api.parse_json_content(raw_content) + + def test_invalid_json_exception_is_chained(self): + with pytest.raises(prompts_api.PromptError) as exc_info: + prompts_api.parse_json_content('not valid json') + + assert exc_info.value.__cause__ is not None diff --git a/enterprise_access/settings/base.py b/enterprise_access/settings/base.py index 7d419360..c466c3eb 100644 --- a/enterprise_access/settings/base.py +++ b/enterprise_access/settings/base.py @@ -183,6 +183,7 @@ def root(*path_fragments): 'DEFAULT_THROTTLE_RATES': { 'bff_unauthenticated': '100/hour', 'ssp_product': '120/hour', + 'learner_pathways_learning_intent': '100/hour', }, }