diff --git a/enterprise_access/apps/api_client/enterprise_catalog_client.py b/enterprise_access/apps/api_client/enterprise_catalog_client.py index a2db35e3..6d6aca40 100644 --- a/enterprise_access/apps/api_client/enterprise_catalog_client.py +++ b/enterprise_access/apps/api_client/enterprise_catalog_client.py @@ -1,6 +1,7 @@ """ API client for enterprise-catalog service. """ +import logging from urllib.parse import urljoin import backoff @@ -9,17 +10,20 @@ from enterprise_access.apps.api_client.base_oauth import BaseOAuthClient from enterprise_access.apps.api_client.base_user import BaseUserApiClient from enterprise_access.apps.api_client.constants import autoretry_for_exceptions +from enterprise_access.apps.api_client.utils import fetch_all_results + +logger = logging.getLogger(__name__) class EnterpriseCatalogApiClient(BaseOAuthClient): """ - V2 API client for calls to the enterprise catalog service. + v2 API client for calls to the enterprise catalog service. """ api_version = 'v2' def __init__(self): self.api_base_url = urljoin(settings.ENTERPRISE_CATALOG_URL, f'api/{self.api_version}/') - # Academies are exposed on v1 of the enterprise-catalog API, not v2. + # Academies are exposed on v1 of the enterprise-catalog API, not v2 self.academies_endpoint = urljoin(settings.ENTERPRISE_CATALOG_URL, 'api/v1/academies/') self.enterprise_catalog_endpoint = urljoin(self.api_base_url, 'enterprise-catalogs/') super().__init__() @@ -40,6 +44,27 @@ def get_academy(self, academy_uuid): response.raise_for_status() return response.json() + @backoff.on_exception(wait_gen=backoff.expo, exception=autoretry_for_exceptions) + def get_academies(self, academy_uuid: str = None, is_active: bool | None = None) -> dict | list: + """ + Fetch a list of academies. + + Optionally filters results by academy UUID and active status. + + Returns: + dict: Paginated response containing academy results. + """ + # Defensive: if no endpoint configured, return empty paginated shape + if not self.academies_endpoint: + return {'count': 0, 'next': None, 'previous': None, 'results': []} + + params = {} + if academy_uuid: + params['academy_uuid'] = academy_uuid + if is_active is not None: + params['is_active'] = bool(is_active) + return fetch_all_results(self.client, self.academies_endpoint, params=params) + @backoff.on_exception(wait_gen=backoff.expo, exception=autoretry_for_exceptions) def associate_academy_with_catalog(self, academy_uuid, enterprise_catalog_uuid): """ @@ -133,6 +158,21 @@ def get_content_metadata_count(self, catalog_uuid): response.raise_for_status() return response.json()['count'] + @backoff.on_exception(wait_gen=backoff.expo, exception=autoretry_for_exceptions) + def get_catalogs(self, enterprise_customer_uuid: str = None) -> dict | list: + """ + Fetch a list of enterprise catalogs. + + Optionally filters results by enterprise_customer_uuid. + + Returns: + dict: Paginated response containing enterprise catalog results. + """ + params = {} + if enterprise_customer_uuid: + params['enterprise_customer'] = enterprise_customer_uuid + return fetch_all_results(self.client, self.enterprise_catalog_endpoint, params=params) + def content_metadata(self, content_id): raise NotImplementedError('There is currently no v2 API implementation for this endpoint.') diff --git a/enterprise_access/apps/api_client/tests/test_enterprise_catalog_client.py b/enterprise_access/apps/api_client/tests/test_enterprise_catalog_client.py index 04aafb7b..115c0f9b 100644 --- a/enterprise_access/apps/api_client/tests/test_enterprise_catalog_client.py +++ b/enterprise_access/apps/api_client/tests/test_enterprise_catalog_client.py @@ -19,6 +19,7 @@ EnterpriseCatalogUserV1ApiClient ) from enterprise_access.apps.api_client.tests.test_constants import DATE_FORMAT_ISO_8601 +from enterprise_access.apps.api_client.utils import fetch_all_results from enterprise_access.apps.core.tests.factories import UserFactory from enterprise_access.utils import _days_from_now @@ -151,6 +152,387 @@ def test_get_content_metadata_count(self, mock_oauth_client): f'http://enterprise-catalog.example.com/api/v2/enterprise-catalogs/{catalog_uuid}/get_content_metadata/', ) + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') + def test_get_academies(self, mock_oauth_client): + mock_response_json = {'count': 1, 'next': None, 'previous': None, 'results': [{'title': 'AI Academy'}]} + mock_oauth_client.return_value.get.return_value.json.return_value = mock_response_json + + client = EnterpriseCatalogApiClient() + fetched = client.get_academies() + + self.assertEqual(fetched, mock_response_json) + mock_oauth_client.return_value.get.assert_called_with( + 'http://enterprise-catalog.example.com/api/v1/academies/', + params={}, + ) + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') + def test_get_academies_with_uuid(self, mock_oauth_client): + mock_response_json = {'count': 0, 'next': None, 'previous': None, 'results': []} + mock_oauth_client.return_value.get.return_value.json.return_value = mock_response_json + + academy_uuid = uuid4() + client = EnterpriseCatalogApiClient() + fetched = client.get_academies(academy_uuid=str(academy_uuid)) + + self.assertEqual(fetched, mock_response_json) + mock_oauth_client.return_value.get.assert_called_with( + 'http://enterprise-catalog.example.com/api/v1/academies/', + params={'academy_uuid': str(academy_uuid)}, + ) + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') + def test_get_academy_fetches_single_record_from_v1(self, mock_oauth_client): + """Ensure `get_academy` calls the v1 academies endpoint and returns the academy JSON.""" + academy_uuid = uuid4() + mock_resp = mock.Mock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {'uuid': str(academy_uuid), 'title': 'Test Academy'} + mock_resp.raise_for_status = mock.Mock() + mock_oauth_client.return_value.get.return_value = mock_resp + + client = EnterpriseCatalogApiClient() + result = client.get_academy(academy_uuid) + + self.assertEqual(result, {'uuid': str(academy_uuid), 'title': 'Test Academy'}) + mock_oauth_client.return_value.get.assert_called_with( + f'http://enterprise-catalog.example.com/api/v1/academies/{academy_uuid}/' + ) + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') + def test_get_catalogs(self, mock_oauth_client): + mock_response_json = {'count': 1, 'next': None, 'previous': None, 'results': [{'uuid': str(uuid4())}]} + mock_oauth_client.return_value.get.return_value.json.return_value = mock_response_json + + client = EnterpriseCatalogApiClient() + fetched = client.get_catalogs() + + self.assertEqual(fetched, mock_response_json) + mock_oauth_client.return_value.get.assert_called_with( + 'http://enterprise-catalog.example.com/api/v2/enterprise-catalogs/', + params={}, + ) + + @mock.patch("enterprise_access.apps.api_client.base_oauth.OAuthAPIClient") + def test_fetch_all_results_with_none_params(self, mock_oauth_client): + first_page = { + "count": 1, + "next": None, + "previous": None, + "results": [{"uuid": "123"}], + } + + mock_oauth_client.return_value.get.return_value.json.return_value = first_page + + result = fetch_all_results( + mock_oauth_client.return_value, + "http://enterprise-catalog.example.com/api/v2/enterprise-catalogs/", + params=None, + ) + + self.assertEqual(result, first_page) + + mock_oauth_client.return_value.get.assert_called_once_with( + "http://enterprise-catalog.example.com/api/v2/enterprise-catalogs/", + params={}, + ) + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') + def test_get_catalogs_with_enterprise_customer(self, mock_oauth_client): + mock_response_json = {'count': 1, 'next': None, 'previous': None, 'results': [{'uuid': str(uuid4())}]} + mock_oauth_client.return_value.get.return_value.json.return_value = mock_response_json + + customer_uuid = str(uuid4()) + client = EnterpriseCatalogApiClient() + fetched = client.get_catalogs(enterprise_customer_uuid=customer_uuid) + + self.assertEqual(fetched, mock_response_json) + mock_oauth_client.return_value.get.assert_called_with( + 'http://enterprise-catalog.example.com/api/v2/enterprise-catalogs/', + params={"enterprise_customer": customer_uuid, } + ) + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') + def test_get_academies_merges_paginated_results(self, mock_oauth_client): + page_1 = { + 'count': 2, + 'next': 'http://enterprise-catalog.example.com/api/v1/academies/?page=2', + 'previous': None, + 'results': [{'title': 'AI Academy'}], + } + page_2 = { + 'count': 2, + 'next': None, + 'previous': 'http://enterprise-catalog.example.com/api/v1/academies/?page=1', + 'results': [{'title': 'Data Academy'}], + } + mock_oauth_client.return_value.get.side_effect = [ + mock.Mock(json=mock.Mock(return_value=page_1), raise_for_status=mock.Mock()), + mock.Mock(json=mock.Mock(return_value=page_2), raise_for_status=mock.Mock()), + ] + + client = EnterpriseCatalogApiClient() + fetched = client.get_academies() + + self.assertEqual(fetched['count'], 2) + self.assertEqual(len(fetched['results']), 2) + self.assertIsNone(fetched['next']) + self.assertEqual(mock_oauth_client.return_value.get.call_count, 2) + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') + def test_get_catalogs_merges_paginated_results(self, mock_oauth_client): + page_1 = { + 'count': 2, + 'next': 'http://enterprise-catalog.example.com/api/v2/enterprise-catalogs/?page=2', + 'previous': None, + 'results': [{'uuid': str(uuid4())}], + } + page_2 = { + 'count': 2, + 'next': None, + 'previous': 'http://enterprise-catalog.example.com/api/v2/enterprise-catalogs/?page=1', + 'results': [{'uuid': str(uuid4())}], + } + mock_oauth_client.return_value.get.side_effect = [ + mock.Mock(json=mock.Mock(return_value=page_1), raise_for_status=mock.Mock()), + mock.Mock(json=mock.Mock(return_value=page_2), raise_for_status=mock.Mock()), + ] + + client = EnterpriseCatalogApiClient() + fetched = client.get_catalogs() + + self.assertEqual(fetched['count'], 2) + self.assertEqual(len(fetched['results']), 2) + self.assertIsNone(fetched['next']) + self.assertEqual(mock_oauth_client.return_value.get.call_count, 2) + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient', autospec=True) + def test_get_academies_with_empty_endpoint_returns_empty_payload(self, mock_oauth_client): + client = EnterpriseCatalogApiClient() + client.academies_endpoint = '' + + result = client.get_academies() + + self.assertEqual(result, {'count': 0, 'next': None, 'previous': None, 'results': []}) + mock_oauth_client.return_value.get.assert_not_called() + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient', autospec=True) + def test_get_academies_with_is_active_param(self, mock_oauth_client): + mock_oauth_client.return_value.get.return_value = mock.Mock( + json=mock.Mock(return_value={'count': 0, 'next': None, 'previous': None, 'results': []}), + raise_for_status=mock.Mock(), + ) + + client = EnterpriseCatalogApiClient() + result = client.get_academies(is_active=True) + + self.assertEqual(result['results'], []) + mock_oauth_client.return_value.get.assert_called_with( + 'http://enterprise-catalog.example.com/api/v1/academies/', + params={'is_active': True}, + ) + + def test_catalog_content_metadata_raises_for_empty_content_keys_with_traversal(self): + client = EnterpriseCatalogApiClient() + + with self.assertRaisesRegex(Exception, 'Cannot request all metadata for a catalog'): + client.catalog_content_metadata(uuid4(), content_keys=[], traverse_pagination=True) + + def test_content_metadata_not_implemented_for_v1_client(self): + client = EnterpriseCatalogApiClient() + + with self.assertRaises(NotImplementedError): + client.content_metadata('some-content-key') + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') + def test_get_catalogs_three_page_merge(self, mock_oauth_client): + page_1 = {'results': [{'uuid': '1'}], 'next': 'http://p2', 'previous': None, 'count': None} + page_2 = {'results': [{'uuid': '2'}], 'next': 'http://p3', 'previous': 'http://p1', 'count': None} + page_3 = {'results': [{'uuid': '3'}], 'next': None, 'previous': 'http://p2', 'count': None} + mock_oauth_client.return_value.get.side_effect = [ + mock.Mock(json=mock.Mock(return_value=page_1), raise_for_status=mock.Mock()), + mock.Mock(json=mock.Mock(return_value=page_2), raise_for_status=mock.Mock()), + mock.Mock(json=mock.Mock(return_value=page_3), raise_for_status=mock.Mock()), + ] + + client = EnterpriseCatalogApiClient() + fetched = client.get_catalogs() + + self.assertEqual(len(fetched['results']), 3) + self.assertIsNone(fetched['next']) + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') + def test_catalog_content_metadata_traverse_false_allows_empty_keys(self, mock_oauth_client): + mock_oauth_client.return_value.get.return_value = mock.Mock( + json=mock.Mock(return_value={}), raise_for_status=mock.Mock() + ) + client = EnterpriseCatalogApiClient() + # Should not raise when traverse_pagination is False and content_keys empty + res = client.catalog_content_metadata('catalog', [], traverse_pagination=False) + self.assertEqual(res, {}) + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') + def test_contains_content_items_true_value(self, mock_oauth_client): + mock_oauth_client.return_value.get.return_value = mock.Mock( + json=mock.Mock(return_value={'contains_content_items': True}), raise_for_status=mock.Mock() + ) + client = EnterpriseCatalogApiClient() + self.assertTrue(client.contains_content_items('catalog', ['x'])) + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') + def test_get_content_metadata_count_raises_when_key_missing(self, mock_oauth_client): + mock_oauth_client.return_value.get.return_value = mock.Mock( + json=mock.Mock(return_value={}), raise_for_status=mock.Mock() + ) + client = EnterpriseCatalogApiClient() + with self.assertRaises(KeyError): + client.get_content_metadata_count('catalog') + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') + def test_get_catalogs_preserve_results_when_next_page_empty(self, mock_oauth_client): + page_1 = {'results': [{'uuid': '1'}], 'next': 'http://p2', 'previous': None, 'count': 2} + page_2 = {'results': [], 'next': None, 'previous': 'http://p1', 'count': 2} + mock_oauth_client.return_value.get.side_effect = [ + mock.Mock(json=mock.Mock(return_value=page_1), raise_for_status=mock.Mock()), + mock.Mock(json=mock.Mock(return_value=page_2), raise_for_status=mock.Mock()), + ] + + client = EnterpriseCatalogApiClient() + fetched = client.get_catalogs() + self.assertEqual(fetched['count'], 2) + self.assertEqual(len(fetched['results']), 1) + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') + def test_catalog_content_metadata_returns_json_payload(self, mock_oauth_client): + payload = {'next': None, 'results': [{'key': 'k'}], 'count': 1} + mock_oauth_client.return_value.get.return_value = mock.Mock( + json=mock.Mock(return_value=payload), raise_for_status=mock.Mock() + ) + client = EnterpriseCatalogApiClient() + self.assertEqual(client.catalog_content_metadata('cat', ['k']), payload) + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient', autospec=True) + def test_get_academies_merges_when_both_results_are_lists(self, mock_oauth_client): + """Ensure results are merged when both pages return list 'results'.""" + page_1 = { + 'count': None, + 'next': 'http://example.com/?page=2', + 'previous': None, + 'results': [{'id': 1}, {'id': 2}], + } + page_2 = { + 'count': None, + 'next': None, + 'previous': 'http://example.com/?page=1', + 'results': [{'id': 3}], + } + mock_oauth_client.return_value.get.side_effect = [ + mock.Mock(json=mock.Mock(return_value=page_1), raise_for_status=mock.Mock()), + mock.Mock(json=mock.Mock(return_value=page_2), raise_for_status=mock.Mock()), + ] + + client = EnterpriseCatalogApiClient() + res = client.get_academies() + + self.assertEqual(len(res['results']), 3) + self.assertIsNone(res['next']) + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient', autospec=True) + def test_get_academies_returns_paginated_response(self, mock_oauth_client): + mock_response = { + 'count': 1, + 'next': None, + 'previous': None, + 'results': [{'uuid': 'academy-1'}], + } + + mock_oauth_client.return_value.get.return_value = mock.Mock( + json=mock.Mock(return_value=mock_response), + raise_for_status=mock.Mock(), + ) + + client = EnterpriseCatalogApiClient() + result = client.get_academies() + + self.assertEqual(result, mock_response) + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient', autospec=True) + def test_get_academies_merges_paginated_results_(self, mock_oauth_client): + first_page = mock.Mock( + json=mock.Mock( + return_value={ + 'count': 2, + 'next': 'next-page', + 'previous': None, + 'results': [{'uuid': 'academy-1'}], + } + ), + raise_for_status=mock.Mock(), + ) + + second_page = mock.Mock( + json=mock.Mock( + return_value={ + 'count': 2, + 'next': None, + 'previous': None, + 'results': [{'uuid': 'academy-2'}], + } + ), + raise_for_status=mock.Mock(), + ) + + mock_oauth_client.return_value.get.side_effect = [ + first_page, + second_page, + ] + + client = EnterpriseCatalogApiClient() + result = client.get_academies() + + self.assertEqual(result['count'], 2) + self.assertEqual( + result['results'], + [ + {'uuid': 'academy-1'}, + {'uuid': 'academy-2'}, + ], + ) + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') + def test_get_academies_preserves_explicit_count_across_pages(self, mock_oauth_client): + # First page reports an explicit larger count; second page has fewer results + page_1 = { + 'results': [{'id': 1}], + 'count': 4, + 'next': 'http://next', + 'previous': None, + } + page_2 = { + 'results': [{'id': 2}, {'id': 3}], + 'count': 2, + 'next': None, + 'previous': None, + } + mock_oauth_client.return_value.get.side_effect = [ + mock.Mock(json=mock.Mock(return_value=page_1), raise_for_status=mock.Mock()), + mock.Mock(json=mock.Mock(return_value=page_2), raise_for_status=mock.Mock()), + ] + + client = EnterpriseCatalogApiClient() + res = client.get_academies() + self.assertEqual(res['count'], 4) + self.assertEqual(len(res['results']), 3) + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') + def test_contains_content_items_missing_key_returns_false(self, mock_oauth_client): + mock_oauth_client.return_value.get.return_value = mock.Mock( + json=mock.Mock(return_value={}), + raise_for_status=mock.Mock(), + ) + client = EnterpriseCatalogApiClient() + self.assertFalse(client.contains_content_items('catalog', ['x'])) + @ddt.ddt class TestEnterpriseCatalogApiV1Client(TestCase): @@ -303,6 +685,31 @@ def test_secured_algolia_api_key(self, mock_crum_get_current_request, mock_send) # Assert the response is as expected self.assertEqual(result, expected_result) + @mock.patch('requests.Session.send') + @mock.patch('crum.get_current_request') + def test_secured_algolia_api_key_raises_on_http_error(self, mock_crum_get_current_request, mock_send): + """Ensure HTTP errors from the backend are propagated.""" + expected_url = ( + f'http://enterprise-catalog.example.com/api/v1' + f'/enterprise-customer/{self.mock_enterprise_customer_uuid}/secured-algolia-api-key/' + ) + request = self.factory.get(expected_url) + request.headers = { + "Authorization": 'test-auth', + self.request_id_key: 'test-request-id' + } + request.user = self.user + mock_crum_get_current_request.return_value = request + + mock_response = mock.Mock() + mock_response.status_code = 400 + mock_response.raise_for_status.side_effect = HTTPError('bad') + mock_send.return_value = mock_response + + client = EnterpriseCatalogUserV1ApiClient(request) + with self.assertRaises(HTTPError): + client.get_secured_algolia_api_key(enterprise_customer_uuid=self.mock_enterprise_customer_uuid) + class TestEnterpriseCatalogApiClientGetAcademy(TestCase): """Tests for EnterpriseCatalogApiClient.get_academy().""" diff --git a/enterprise_access/apps/api_client/utils.py b/enterprise_access/apps/api_client/utils.py new file mode 100644 index 00000000..2eaa8f06 --- /dev/null +++ b/enterprise_access/apps/api_client/utils.py @@ -0,0 +1,36 @@ +""" +Utility helpers for the API client implementation. +""" + + +def fetch_all_results(client, url, params=None): + """ + Fetch all paginated results. + + Args: + client: HTTP client. + url (str): Endpoint URL. + params (dict | None): Optional query parameters. + + Returns: + dict: Response payload with all pages merged. + """ + if params is None: + params = {} + + response = client.get(url, params=params) + response.raise_for_status() + + data = response.json() + + while data.get("next"): + response = client.get(data["next"]) + response.raise_for_status() + + page = response.json() + + data["results"].extend(page.get("results", [])) + data["next"] = page.get("next") + data["previous"] = data.get("previous") or page.get("previous") + + return data