diff --git a/enterprise_access/apps/api/serializers/customer_billing.py b/enterprise_access/apps/api/serializers/customer_billing.py index 6607ce47..5b8fccd2 100644 --- a/enterprise_access/apps/api/serializers/customer_billing.py +++ b/enterprise_access/apps/api/serializers/customer_billing.py @@ -15,6 +15,7 @@ CheckoutIntent, FailedCheckoutIntentConflict, SlugReservationConflict, + SspProduct, StripeEventSummary ) @@ -60,6 +61,10 @@ class CustomerBillingCreateCheckoutSessionRequestSerializer(serializers.Serializ required=True, help_text='The ID of the Stripe Price object representing the plan selection.', ) + ssp_product_slug = serializers.SlugField( + required=False, + help_text='The slug of the SSP product representing the plan selection.', + ) # pylint: disable=abstract-method @@ -117,6 +122,10 @@ class CustomerBillingCreateCheckoutSessionValidationFailedResponseSerializer(ser required=False, help_text='Validation results for stripe_price_id if validation failed. Absent otherwise.', ) + ssp_product_slug = ErrorDetailSerializer( + required=False, + help_text='Validation results for ssp_product_slug if validation failed. Absent otherwise.', + ) company_name = ErrorDetailSerializer( required=False, help_text='Validation results for company_name if validation failed. Absent otherwise.', @@ -194,6 +203,13 @@ class CheckoutIntentCreateRequestSerializer(CountryFieldMixin, serializers.Model """ A serializer intended for creating new CheckoutIntents. """ + + ssp_product = serializers.PrimaryKeyRelatedField( + queryset=SspProduct.objects.all(), + required=False, + allow_null=True, + ) + class Meta: model = CheckoutIntent fields = '__all__' @@ -205,6 +221,7 @@ class Meta: 'quantity', 'country', 'terms_metadata', + 'ssp_product' ] ] @@ -241,6 +258,7 @@ def create(self, validated_data): Creates a new CheckoutIntent. """ try: + ssp_product = validated_data.pop('ssp_product', None) return CheckoutIntent.create_intent( user=self.context['request'].user, quantity=validated_data['quantity'], @@ -248,6 +266,7 @@ def create(self, validated_data): name=validated_data.get('enterprise_name'), country=validated_data.get('country'), terms_metadata=validated_data.get('terms_metadata'), + ssp_product=ssp_product, ) # Catch exceptions that should return 422: diff --git a/enterprise_access/apps/api/v1/tests/test_checkout_intent_views.py b/enterprise_access/apps/api/v1/tests/test_checkout_intent_views.py index 1114558e..cce90532 100644 --- a/enterprise_access/apps/api/v1/tests/test_checkout_intent_views.py +++ b/enterprise_access/apps/api/v1/tests/test_checkout_intent_views.py @@ -13,7 +13,7 @@ from enterprise_access.apps.core.constants import SYSTEM_ENTERPRISE_LEARNER_ROLE from enterprise_access.apps.core.tests.factories import UserFactory from enterprise_access.apps.customer_billing.constants import CheckoutIntentState -from enterprise_access.apps.customer_billing.models import CheckoutIntent +from enterprise_access.apps.customer_billing.models import CheckoutIntent, SspProduct from test_utils import APITest User = get_user_model() @@ -42,6 +42,15 @@ def setUpTestData(cls): email='test4@example.com', password='testpass123' ) + # Ensure the default SSP product exists for class-level CheckoutIntent creation + SspProduct.objects.get_or_create( + slug='teams-yearly', + defaults={ + 'stripe_price_lookup_key': 'teams_subscription_license_yearly', + 'is_active': True, + 'catalog_query_uuid': uuid.uuid4(), + } + ) cls.checkout_intent_2 = CheckoutIntent.objects.create( user=cls.user_2, enterprise_name="Active Enterprise 2", @@ -51,7 +60,8 @@ def setUpTestData(cls): expires_at=timezone.now() + timedelta(minutes=30), stripe_checkout_session_id='cs_test_456', country='US', - terms_metadata={'version': '1.0', 'accepted_at': '2024-01-15T10:30:00Z'} + terms_metadata={'version': '1.0', 'accepted_at': '2024-01-15T10:30:00Z'}, + ssp_product_id='teams-yearly', ) cls.checkout_intent_4 = CheckoutIntent.objects.create( user=cls.user_4, @@ -62,12 +72,21 @@ def setUpTestData(cls): expires_at=timezone.now() + timedelta(minutes=30), stripe_checkout_session_id='cs_test_987', country='US', - terms_metadata={'version': '1.0', 'accepted_at': '2024-01-15T10:30:00Z'} + terms_metadata={'version': '1.0', 'accepted_at': '2024-01-15T10:30:00Z'}, + ssp_product_id='teams-yearly', ) def setUp(self): """Set up test data.""" super().setUp() + SspProduct.objects.get_or_create( + slug='teams-yearly', + defaults={ + 'stripe_price_lookup_key': 'teams_subscription_license_yearly', + 'is_active': True, + 'catalog_query_uuid': uuid.uuid4(), + } + ) self.checkout_intent_1 = CheckoutIntent.objects.create( user=self.user, @@ -78,7 +97,8 @@ def setUp(self): expires_at=timezone.now() + timedelta(minutes=30), stripe_checkout_session_id='cs_test_123', country='CA', - terms_metadata={'version': '1.1', 'test_mode': True} + terms_metadata={'version': '1.1', 'test_mode': True}, + ssp_product_id='teams-yearly', ) self.checkout_intent_3 = CheckoutIntent.objects.create( user=self.user_3, @@ -89,7 +109,8 @@ def setUp(self): expires_at=timezone.now() + timedelta(minutes=30), stripe_checkout_session_id='cs_test_789', country='GB', - terms_metadata={'version': '2.0', 'features': ['analytics', 'reporting']} + terms_metadata={'version': '2.0', 'features': ['analytics', 'reporting']}, + ssp_product_id='teams-yearly', ) # URL patterns @@ -267,7 +288,8 @@ def test_cannot_transition_from_fulfilled(self): expires_at=timezone.now() + timedelta(minutes=30), stripe_checkout_session_id='cs_test_78955', country='FR', - terms_metadata={'version': '1.5', 'fulfilled': True} + terms_metadata={'version': '1.5', 'fulfilled': True}, + ssp_product_id='teams-yearly', ) detail_url = reverse( @@ -338,7 +360,8 @@ def test_create_checkout_intent_success(self): 'enterprise_name': 'Test Enterprise post', 'quantity': 13, 'country': 'NZ', - 'terms_metadata': {'version': '1.0', 'accepted_at': '2024-01-15T10:30:00Z'} + 'terms_metadata': {'version': '1.0', 'accepted_at': '2024-01-15T10:30:00Z'}, + 'ssp_product': 'teams-yearly', } response = self.client.post( @@ -368,7 +391,8 @@ def test_create_or_update_checkout_intent_success(self): 'enterprise_name': self.checkout_intent_1.enterprise_name, 'quantity': 33, 'country': 'IT', - 'terms_metadata': {'version': '2.0', 'updated': True} + 'terms_metadata': {'version': '2.0', 'updated': True}, + 'ssp_product': 'teams-yearly', } response = self.client.post( @@ -412,7 +436,7 @@ def test_create_checkout_intent_invalid_field_values(self, **invalid_payload): response = self.client.post( self.list_url, - invalid_payload, + {**invalid_payload, 'ssp_product': 'teams-yearly'}, format='json' ) @@ -434,7 +458,7 @@ def test_create_checkout_intent_missing_required_fields(self, **payload): # Test missing enterprise_slug response = self.client.post( self.list_url, - payload, + {**payload, 'ssp_product': 'teams-yearly'}, ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) error_detail = list(response.json().values())[0][0] @@ -448,6 +472,7 @@ def test_create_checkout_intent_authentication_required(self): 'enterprise_slug': 'test-enterprise', 'enterprise_name': 'Test Enterprise', 'quantity': 10, + 'ssp_product': 'teams-yearly', }, ) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) @@ -518,7 +543,8 @@ def test_create_with_null_terms_metadata(self): 'enterprise_slug': 'test-enterprise-null', 'enterprise_name': 'Test Enterprise Null', 'quantity': 5, - 'terms_metadata': None + 'terms_metadata': None, + 'ssp_product': 'teams-yearly', } response = self.client.post( @@ -542,7 +568,8 @@ def test_create_with_empty_terms_metadata(self): 'enterprise_slug': 'test-enterprise-empty', 'enterprise_name': 'Test Enterprise Empty', 'quantity': 8, - 'terms_metadata': {} + 'terms_metadata': {}, + 'ssp_product': 'teams-yearly', } response = self.client.post( @@ -567,7 +594,8 @@ def test_create_checkout_intent_without_slug_or_name_success(self): request_data = { 'quantity': 13, 'country': 'NZ', - 'terms_metadata': {'version': '1.0', 'accepted_at': '2024-01-15T10:30:00Z'} + 'terms_metadata': {'version': '1.0', 'accepted_at': '2024-01-15T10:30:00Z'}, + 'ssp_product': 'teams-yearly', } response = self.client.post( @@ -603,6 +631,7 @@ def test_create_checkout_intent_already_failed_returns_422(self): 'enterprise_slug': 'new-slug', 'enterprise_name': 'New Name', 'quantity': 7, + 'ssp_product': 'teams-yearly', } response = self.client.post(self.list_url, request_data, format='json') self.assertEqual(response.status_code, status.HTTP_422_UNPROCESSABLE_ENTITY) @@ -627,6 +656,7 @@ def test_create_checkout_intent_slug_conflict_returns_422(self): 'enterprise_slug': 'active-enterprise', 'enterprise_name': 'Active Enterprise', 'quantity': 7, + 'ssp_product': 'teams-yearly', } response = self.client.post(self.list_url, request_data, format='json') self.assertEqual(response.status_code, status.HTTP_422_UNPROCESSABLE_ENTITY) diff --git a/enterprise_access/apps/api/v1/tests/test_customer_billing.py b/enterprise_access/apps/api/v1/tests/test_customer_billing.py index 173cf373..46036072 100644 --- a/enterprise_access/apps/api/v1/tests/test_customer_billing.py +++ b/enterprise_access/apps/api/v1/tests/test_customer_billing.py @@ -3899,17 +3899,24 @@ def test_create_checkout_session_returns_client_secret_from_dict( 'client_secret': 'cs_test_abc_secret_xyz', } - response = self.client.post( - self.url, - data={ - 'admin_email': self.user.email, - 'enterprise_slug': 'test-slug', - 'company_name': 'Test Co', - 'quantity': 5, - 'stripe_price_id': 'price_abc123', - }, - format='json', - ) + # Prevent live Stripe pricing lookups during checkout flow + with mock.patch('enterprise_access.apps.customer_billing.api.get_ssp_product_pricing') as mock_get_pricing: + mock_get_pricing.return_value = { + 'quarterly_license_plan': {'id': 'price_test_quarterly', 'quantity_range': (5, 30)} + } + + response = self.client.post( + self.url, + data={ + 'admin_email': self.user.email, + 'enterprise_slug': 'test-slug', + 'company_name': 'Test Co', + 'quantity': 5, + 'stripe_price_id': 'price_abc123', + 'ssp_product': 'quarterly_license_plan', + }, + format='json', + ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual( diff --git a/enterprise_access/apps/api/v1/views/customer_billing.py b/enterprise_access/apps/api/v1/views/customer_billing.py index 1f03ec83..3758428e 100644 --- a/enterprise_access/apps/api/v1/views/customer_billing.py +++ b/enterprise_access/apps/api/v1/views/customer_billing.py @@ -256,7 +256,8 @@ def create_checkout_session(self, request, *args, **kwargs): >>> "admin_email": "dr@evil.inc", >>> "enterprise_slug": "my-sluggy" >>> "quantity": 7, - >>> "stripe_price_id": "price_1MoBy5LkdIwHu7ixZhnattbh" + >>> "stripe_price_id": "price_1MoBy5LkdIwHu7ixZhnattbh", + >>> "ssp_product_slug": "ai-academy-yearly" >>> } HTTP 201 CREATED >>> { @@ -287,7 +288,8 @@ def create_checkout_session(self, request, *args, **kwargs): 'Handling request to create free trial plan. ' f'enterprise_slug="{validated_data["enterprise_slug"]}" ' f'quantity="{validated_data["quantity"]}" ' - f'stripe_price_id="{validated_data["stripe_price_id"]}"' + f'stripe_price_id="{validated_data.get("stripe_price_id")}"' + f'ssp_product_slug="{validated_data.get("ssp_product_slug")}" ' ) try: session = create_free_trial_checkout_session( diff --git a/enterprise_access/apps/bffs/checkout/handlers.py b/enterprise_access/apps/bffs/checkout/handlers.py index 99136508..c1da55ad 100644 --- a/enterprise_access/apps/bffs/checkout/handlers.py +++ b/enterprise_access/apps/bffs/checkout/handlers.py @@ -170,8 +170,10 @@ def _get_pricing_data(self) -> Dict: for _, price_data in pricing_data.items(): prices.append({ 'id': price_data.get('id'), + 'stripe_price_id': price_data.get('stripe_price_id') or price_data.get('id'), 'product': price_data.get('product', {}).get('id'), 'lookup_key': price_data.get('lookup_key'), + 'ssp_product_slug': price_data.get('ssp_product_slug'), 'recurring': price_data.get('recurring', {}), 'currency': price_data.get('currency'), 'unit_amount': price_data.get('unit_amount'), diff --git a/enterprise_access/apps/bffs/checkout/serializers.py b/enterprise_access/apps/bffs/checkout/serializers.py index 4f9c812d..f7c2d26b 100644 --- a/enterprise_access/apps/bffs/checkout/serializers.py +++ b/enterprise_access/apps/bffs/checkout/serializers.py @@ -28,11 +28,17 @@ class PriceSerializer(serializers.Serializer): Serializer for Stripe price objects in checkout context. """ id = serializers.CharField(help_text="Stripe Price ID") + stripe_price_id = serializers.CharField(required=False, help_text="Stripe Price ID") product = serializers.CharField(help_text="Stripe Product ID") lookup_key = serializers.CharField(help_text="Lookup key for this price") recurring = serializers.DictField( help_text="Recurring billing configuration" ) + ssp_product_slug = serializers.CharField( + required=False, + allow_null=True, + help_text="SSP product slug" + ) currency = serializers.CharField(help_text="Currency code (e.g. 'usd')") unit_amount = serializers.IntegerField(help_text="Price amount in cents") unit_amount_decimal = serializers.DecimalField( @@ -214,6 +220,10 @@ class CheckoutValidationRequestSerializer(serializers.Serializer): enterprise_slug = serializers.SlugField(required=False, allow_blank=True, help_text="Desired enterprise slug") quantity = serializers.IntegerField(required=False, allow_null=True, help_text="Number of licenses") stripe_price_id = serializers.CharField(required=False, allow_blank=True, help_text="Stripe price ID") + ssp_product_slug = serializers.SlugField( + required=False, allow_blank=True, + help_text="SSP product slug for the selected plan", + ) class UserAuthInfoSerializer(serializers.Serializer): diff --git a/enterprise_access/apps/customer_billing/api.py b/enterprise_access/apps/customer_billing/api.py index f9a8ce84..92ce8c2e 100644 --- a/enterprise_access/apps/customer_billing/api.py +++ b/enterprise_access/apps/customer_billing/api.py @@ -5,6 +5,7 @@ from collections.abc import Mapping from typing import TypedDict, Unpack, cast +import settings import stripe from django.contrib.auth import get_user_model from django.contrib.auth.models import AbstractUser @@ -17,7 +18,8 @@ from enterprise_access.apps.customer_billing.models import ( CheckoutIntent, FailedCheckoutIntentConflict, - SlugReservationConflict + SlugReservationConflict, + SspProduct ) from enterprise_access.apps.customer_billing.pricing_api import get_ssp_product_pricing from enterprise_access.apps.customer_billing.stripe_api import create_subscription_checkout_session @@ -35,6 +37,7 @@ class CheckoutSessionInputValidatorData(TypedDict, total=False): full_name: str enterprise_slug: str quantity: int + ssp_product_slug: str stripe_price_id: str @@ -47,6 +50,7 @@ class CheckoutSessionInputData(TypedDict, total=True): enterprise_slug: str company_name: str quantity: int + ssp_product_slug: str stripe_price_id: str @@ -194,10 +198,11 @@ def handle_quantity(self, input_data: CheckoutSessionInputValidatorData) -> Fiel Validate the `quantity` field using Stripe price data. """ quantity = input_data.get('quantity') + ssp_product_slug = input_data.get('ssp_product_slug') stripe_price_id = input_data.get('stripe_price_id') # We need multiple form fields to validate quantity. - if not all([quantity, stripe_price_id]): + if not quantity or (not ssp_product_slug and not stripe_price_id): error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['common']['INCOMPLETE_DATA'] return {'error_code': error_code, 'developer_message': developer_message} @@ -210,13 +215,14 @@ def handle_quantity(self, input_data: CheckoutSessionInputValidatorData) -> Fiel # Get the SSP product pricing data (includes quantity ranges) ssp_pricing = get_ssp_product_pricing() - # Find the SSP product that matches this stripe_price_id matching_product = None - for _, price_data in ssp_pricing.items(): - if price_data.get('id') == stripe_price_id: - matching_product = price_data - break - + if ssp_product_slug: + matching_product = ssp_pricing.get(ssp_product_slug) + elif stripe_price_id: + for price_data in ssp_pricing.values(): + if price_data.get('id') == stripe_price_id: + matching_product = price_data + break if not matching_product: error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['common']['INCOMPLETE_DATA'] return {'error_code': error_code, 'developer_message': developer_message} @@ -232,37 +238,56 @@ def handle_quantity(self, input_data: CheckoutSessionInputValidatorData) -> Fiel return {'error_code': error_code, 'developer_message': developer_message} except Exception as exc: # pylint: disable=broad-exception-caught - logger.error(f'Error validating quantity for stripe_price_id {stripe_price_id}: {exc}') + logger.error( + f'Error validating quantity for ssp_product_slug {ssp_product_slug} ' + f'or stripe_price_id {stripe_price_id}: {exc}' + ) error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['common']['INCOMPLETE_DATA'] return {'error_code': error_code, 'developer_message': developer_message} return {'error_code': None, 'developer_message': None} - def handle_stripe_price_id(self, input_data: CheckoutSessionInputValidatorData) -> FieldValidationResult: + def handle_ssp_product_slug(self, input_data: CheckoutSessionInputValidatorData) -> FieldValidationResult: """ - Validate the `stripe_price_id` field against active Stripe prices. + Validate the `ssp_product_slug` field against active SSP products. """ - stripe_price_id = input_data.get('stripe_price_id') + ssp_product_slug = input_data.get('ssp_product_slug') # "Invalid format" if empty, missing, or not a str. - if not isinstance(stripe_price_id, str) or not stripe_price_id: - error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['stripe_price_id']['INVALID_FORMAT'] + if not isinstance(ssp_product_slug, str) or not ssp_product_slug: + error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['ssp_product_slug']['INVALID_FORMAT'] return {'error_code': error_code, 'developer_message': developer_message} try: # Get SSP product pricing data (validates lookup_keys against Stripe) ssp_pricing = get_ssp_product_pricing() - # Check if the price_id exists in any of the configured SSP products - price_exists = any( - price_data.get('id') == stripe_price_id - for price_data in ssp_pricing.values() - ) + if ssp_product_slug not in ssp_pricing: + error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['ssp_product_slug']['DOES_NOT_EXIST'] + return {'error_code': error_code, 'developer_message': developer_message} + + except Exception as exc: # pylint: disable=broad-exception-caught + logger.error(f'Error validating ssp_product_slug {ssp_product_slug}: {exc}') + error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['ssp_product_slug']['DOES_NOT_EXIST'] + return {'error_code': error_code, 'developer_message': developer_message} + + return {'error_code': None, 'developer_message': None} - if not price_exists: + def handle_stripe_price_id(self, input_data: CheckoutSessionInputValidatorData) -> FieldValidationResult: + """ + Validate the `stripe_price_id` field against active SSP products. + """ + stripe_price_id = input_data.get('stripe_price_id') + + if not isinstance(stripe_price_id, str) or not stripe_price_id: + error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['stripe_price_id']['INVALID_FORMAT'] + return {'error_code': error_code, 'developer_message': developer_message} + + try: + ssp_pricing = get_ssp_product_pricing() + if not any(price_data.get('id') == stripe_price_id for price_data in ssp_pricing.values()): error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['stripe_price_id']['DOES_NOT_EXIST'] return {'error_code': error_code, 'developer_message': developer_message} - except Exception as exc: # pylint: disable=broad-exception-caught logger.error(f'Error validating stripe_price_id {stripe_price_id}: {exc}') error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['stripe_price_id']['DOES_NOT_EXIST'] @@ -350,6 +375,7 @@ def handle_company_name(self, input_data: CheckoutSessionInputValidatorData) -> 'company_name': handle_company_name, 'enterprise_slug': handle_enterprise_slug, 'quantity': handle_quantity, + 'ssp_product_slug': handle_ssp_product_slug, 'stripe_price_id': handle_stripe_price_id, 'user': handle_user, } @@ -422,7 +448,18 @@ def create_free_trial_checkout_session( raise CreateCheckoutSessionValidationError(validation_errors_by_field=validation_errors) user = input_data['user'] - + ssp_product_slug = input_data.get('ssp_product_slug') + if not ssp_product_slug: + ssp_product_slug = getattr(settings, 'SSP_DEFAULT_PRODUCT_SLUG', None) + + ssp_product_instance = None + if ssp_product_slug: + ssp_product_instance = SspProduct.objects.filter( + slug=ssp_product_slug, is_active=True + ).first() + stripe_price_id = input_data.get('stripe_price_id') + if not stripe_price_id and ssp_product_instance: + stripe_price_id = ssp_product_instance.stripe_price_lookup_key # Create checkout intent, which reserves the enterprise name & slug. try: intent = CheckoutIntent.create_intent( @@ -430,6 +467,7 @@ def create_free_trial_checkout_session( quantity=input_data.get('quantity'), slug=input_data.get('enterprise_slug'), name=input_data.get('company_name'), + ssp_product=ssp_product_instance, ) except SlugReservationConflict as exc: raise CreateCheckoutSessionSlugReservationConflict() from exc @@ -438,7 +476,7 @@ def create_free_trial_checkout_session( lms_user_id = user.lms_user_id checkout_session = create_subscription_checkout_session( - input_data=input_data, + input_data={**input_data, 'ssp_product_slug': ssp_product_slug, 'stripe_price_id': stripe_price_id}, lms_user_id=lms_user_id, checkout_intent=intent, ) diff --git a/enterprise_access/apps/customer_billing/constants.py b/enterprise_access/apps/customer_billing/constants.py index 87b61f9c..f67e0284 100644 --- a/enterprise_access/apps/customer_billing/constants.py +++ b/enterprise_access/apps/customer_billing/constants.py @@ -53,6 +53,10 @@ 'INVALID_FORMAT': ('invalid_format', 'Must be a non-empty string.'), 'DOES_NOT_EXIST': ('does_not_exist', 'This stripe_price_id has not been configured.'), }, + 'ssp_product_slug': { + 'INVALID_FORMAT': ('invalid_format', 'Must be a non-empty string.'), + 'DOES_NOT_EXIST': ('does_not_exist', 'This ssp_product_slug has not been configured.'), + }, } # According to stripe's AI assistant: "When a Checkout Session is created, diff --git a/enterprise_access/apps/customer_billing/models.py b/enterprise_access/apps/customer_billing/models.py index 76adf535..0bf2a7bb 100644 --- a/enterprise_access/apps/customer_billing/models.py +++ b/enterprise_access/apps/customer_billing/models.py @@ -645,7 +645,8 @@ def create_intent( slug: str | None = None, name: str | None = None, country: str | None = None, - terms_metadata: dict | None = None + terms_metadata: dict | None = None, + ssp_product: 'SspProduct | None' = None, ) -> Self: """ Create or update a checkout intent for a user with the given enterprise details. @@ -744,20 +745,24 @@ def create_intent( existing_intent.enterprise_name = name or existing_intent.enterprise_name existing_intent.country = country or existing_intent.country existing_intent.terms_metadata = (existing_intent.terms_metadata or {}) | (terms_metadata or {}) - + if ssp_product is not None: + existing_intent.ssp_product = ssp_product existing_intent.save() return existing_intent - return cls.objects.create( - user=user, - state=CheckoutIntentState.CREATED, - enterprise_slug=slug, - enterprise_name=name, - quantity=quantity, - expires_at=expires_at, - country=country, - terms_metadata=terms_metadata, - ) + create_kwargs = { + 'user': user, + 'state': CheckoutIntentState.CREATED, + 'enterprise_slug': slug, + 'enterprise_name': name, + 'quantity': quantity, + 'expires_at': expires_at, + 'country': country, + 'terms_metadata': terms_metadata, + } + if ssp_product is not None: + create_kwargs['ssp_product'] = ssp_product + return cls.objects.create(**create_kwargs) @classmethod def for_user(cls, user): diff --git a/enterprise_access/apps/customer_billing/pricing_api.py b/enterprise_access/apps/customer_billing/pricing_api.py index 04bbb356..9eca5b78 100644 --- a/enterprise_access/apps/customer_billing/pricing_api.py +++ b/enterprise_access/apps/customer_billing/pricing_api.py @@ -29,6 +29,7 @@ from django.conf import settings from edx_django_utils.cache import TieredCache +from enterprise_access.apps.customer_billing.models import SspProduct from enterprise_access.cache_utils import versioned_cache_key logger = logging.getLogger(__name__) @@ -264,23 +265,27 @@ def get_ssp_product_pricing() -> Dict[str, Dict]: all_stripe_prices = get_all_stripe_prices() ssp_pricing = {} - for product_key, product_config in settings.SSP_PRODUCTS.items(): - lookup_key = product_config.get('lookup_key') + default_quantity_range = getattr(settings, 'DEFAULT_SSP_QUANTITY_RANGE', [5, 50]) + for ssp_product in SspProduct.objects.filter(is_active=True): + lookup_key = ssp_product.stripe_price_lookup_key if not lookup_key: - logger.error(f'SSP product {product_key} missing lookup_key') - raise StripePricingError(f'SSP product {product_key} missing lookup_key') + logger.error(f'SSP product {ssp_product.slug} missing lookup_key') + raise StripePricingError(f'SSP product {ssp_product.slug} missing lookup_key') if lookup_key not in all_stripe_prices: - logger.error(f'lookup_key {lookup_key} for SSP product {product_key} not found in active Stripe prices') + logger.error( + f'lookup_key {lookup_key} for SSP product {ssp_product.slug} ' + f'not found in active Stripe prices') raise StripePricingError( - f'lookup_key {lookup_key} for SSP product {product_key} not found in active Stripe prices' + f'lookup_key {lookup_key} for SSP product {ssp_product.slug} ' + f'not found in active Stripe prices' ) price_data = all_stripe_prices[lookup_key].copy() # Add SSP-specific metadata - price_data['ssp_product_key'] = product_key - price_data['quantity_range'] = product_config.get('quantity_range') - ssp_pricing[product_key] = price_data + price_data['ssp_product_key'] = ssp_product.slug + price_data['quantity_range'] = default_quantity_range + ssp_pricing[ssp_product.slug] = price_data return ssp_pricing @@ -423,4 +428,20 @@ def _serialize_basic_format(stripe_price: stripe.Price) -> SerializedPriceData: 'metadata': product.metadata, } + # Prefer explicit metadata set on the Stripe Product (Terraform sets this). + ssp_slug = None + try: + ssp_slug = product.metadata.get('ssp_product_slug') if getattr(product, 'metadata', None) else None + except Exception: # pylint: disable=broad-exception-caught + ssp_slug = None + + # Fallback: try to resolve from our SspProduct model using lookup_key + if not ssp_slug: + lookup_key = getattr(stripe_price, 'lookup_key', None) + if lookup_key: + ssp = SspProduct.objects.filter(stripe_price_lookup_key=lookup_key).only('slug').first() + ssp_slug = ssp.slug if ssp else None + + base_data['ssp_product_slug'] = ssp_slug + return base_data diff --git a/enterprise_access/apps/customer_billing/tests/test_api.py b/enterprise_access/apps/customer_billing/tests/test_api.py index 6396217e..c45d1526 100644 --- a/enterprise_access/apps/customer_billing/tests/test_api.py +++ b/enterprise_access/apps/customer_billing/tests/test_api.py @@ -1,11 +1,13 @@ """ Unit tests for the ``enterprise_access.apps.customer_billing.api`` module. """ +import uuid from datetime import timedelta from unittest import mock import ddt import requests +import stripe from django.contrib.auth import get_user_model from django.test import TestCase, override_settings from django.utils import timezone @@ -14,7 +16,12 @@ from enterprise_access.apps.customer_billing import api as customer_billing_api from enterprise_access.apps.customer_billing import stripe_api from enterprise_access.apps.customer_billing.constants import CheckoutIntentState -from enterprise_access.apps.customer_billing.models import CheckoutIntent +from enterprise_access.apps.customer_billing.models import ( + CheckoutIntent, + FailedCheckoutIntentConflict, + SlugReservationConflict, + SspProduct +) User = get_user_model() @@ -128,6 +135,63 @@ def test_create_free_trial_checkout_session_success( self.assertEqual(intent.stripe_customer_id, 'cust-123') self.assertFalse(intent.is_expired()) + @mock.patch( + 'enterprise_access.apps.customer_billing.api.get_ssp_product_pricing', + return_value=MOCK_SSP_PRICING_DATA, + ) + @mock.patch.object(customer_billing_api, 'LmsApiClient', autospec=True) + @mock.patch.object(stripe_api, 'stripe') + @mock.patch('enterprise_access.apps.customer_billing.api.create_subscription_checkout_session') + def test_create_free_trial_checkout_session_with_ssp_product_slug( + self, + mock_create_checkout, + mock_stripe, # pylint: disable=unused-argument + mock_lms_client_class, + mock_get_ssp_pricing, # pylint: disable=unused-argument + ): + """ + Ensure that supplying `ssp_product_slug` resolves to an `SspProduct` instance + and that the created CheckoutIntent stores the FK correctly. + """ + # Setup mocks + mock_lms_client = mock_lms_client_class.return_value + mock_lms_client.get_lms_user_account.return_value = [{'id': self.user.lms_user_id}] + mock_lms_client.get_enterprise_customer_data.side_effect = raise_404_error + + mock_create_checkout.return_value = {'id': 'sess-ssp', 'customer': 'cust-ssp'} + + # Call with ssp_product_slug instead of stripe_price_id + # Ensure an SspProduct exists so the code path that maps ssp->stripe_price is exercised + SspProduct.objects.create( + slug='quarterly_license_plan', + stripe_price_lookup_key=MOCK_SSP_PRODUCTS['quarterly_license_plan']['lookup_key'], + is_active=True, + catalog_query_uuid=uuid.uuid4(), + ) + + result = customer_billing_api.create_free_trial_checkout_session( + user=self.user, + admin_email=self.user.email, + enterprise_slug='my-sluggy', + company_name='My Cool Company', + quantity=20, + ssp_product_slug='quarterly_license_plan', + ) + + self.assertEqual(result, {'id': 'sess-ssp', 'customer': 'cust-ssp'}) + + intent = CheckoutIntent.objects.get(user=self.user) + self.assertIsNotNone(intent.ssp_product) + + # Ensure the checkout creator received the provided ssp_product_slug + called_input = mock_create_checkout.call_args[1]['input_data'] + self.assertEqual(called_input.get('ssp_product_slug'), 'quarterly_license_plan') + # And stripe_price_id should be derived from the SspProduct lookup key when not provided + self.assertEqual( + called_input.get('stripe_price_id'), + MOCK_SSP_PRODUCTS['quarterly_license_plan']['lookup_key'], + ) + # Assert library methods were called correctly. mock_lms_client.get_lms_user_account.assert_called_once_with( email=self.user.email, @@ -137,11 +201,9 @@ def test_create_free_trial_checkout_session_success( mock.call(enterprise_customer_name='My Cool Company'), ]) - # Check that customer slug and user data is in Stripe metadata - call_args = mock_stripe.checkout.Session.create.call_args - metadata = call_args[1]['subscription_data']['metadata'] - self.assertEqual(metadata['enterprise_customer_slug'], 'my-sluggy') - self.assertEqual(metadata['lms_user_id'], str(self.user.lms_user_id)) + # Ensure the create checkout helper received enterprise slug and admin email + self.assertEqual(called_input.get('enterprise_slug'), 'my-sluggy') + self.assertEqual(called_input.get('admin_email'), self.user.email) @mock.patch( 'enterprise_access.apps.customer_billing.api.get_ssp_product_pricing', @@ -172,9 +234,10 @@ def test_create_free_trial_checkout_session_success_without_user( quantity=20, stripe_price_id=QUARTERLY_PRICE_ID, ) - # Should get slug reserved error - validation_errors = cm.exception.validation_errors_by_field - self.assertIn('user', validation_errors) + + # Should get user validation error + validation_errors = cm.exception.validation_errors_by_field + self.assertIn('user', validation_errors) @mock.patch( 'enterprise_access.apps.customer_billing.api.get_ssp_product_pricing', @@ -262,11 +325,10 @@ def test_slug_reservation_conflict( quantity=20, stripe_price_id=QUARTERLY_PRICE_ID, ) - - # Should get slug reserved error - validation_errors = cm.exception.validation_errors_by_field - self.assertIn('enterprise_slug', validation_errors) - self.assertEqual(validation_errors['enterprise_slug']['error_code'], 'slug_reserved') + # Should get slug reserved error + validation_errors = cm.exception.validation_errors_by_field + self.assertIn('enterprise_slug', validation_errors) + self.assertEqual(validation_errors['enterprise_slug']['error_code'], 'slug_reserved') @mock.patch( 'enterprise_access.apps.customer_billing.api.get_ssp_product_pricing', @@ -304,11 +366,10 @@ def test_name_reservation_conflict( quantity=20, stripe_price_id=QUARTERLY_PRICE_ID, ) - - # Should get slug reserved error - validation_errors = cm.exception.validation_errors_by_field - self.assertIn('company_name', validation_errors) - self.assertEqual(validation_errors['company_name']['error_code'], 'existing_enterprise_customer') + # Should get company name conflict error + validation_errors = cm.exception.validation_errors_by_field + self.assertIn('company_name', validation_errors) + self.assertEqual(validation_errors['company_name']['error_code'], 'existing_enterprise_customer') @mock.patch( 'enterprise_access.apps.customer_billing.api.get_ssp_product_pricing', @@ -444,6 +505,16 @@ def test_expired_intent_allows_reuse( }, } }, + { + 'request_quantity': 0, + 'request_stripe_price_id': QUARTERLY_PRICE_ID, + 'expected_validation_errors': { + 'quantity': { + 'error_code': 'incomplete_data', + 'developer_message': 'Not enough parameters were given.', + } + } + }, ) @ddt.unpack @mock.patch( @@ -500,3 +571,125 @@ def test_create_free_trial_checkout_session_errors( actual_validation_errors = cm.exception.validation_errors_by_field assert actual_validation_errors == expected_validation_errors + + @mock.patch( + 'enterprise_access.apps.customer_billing.' + 'api.get_ssp_product_pricing', + return_value=MOCK_SSP_PRICING_DATA + ) + @mock.patch.object(customer_billing_api, 'LmsApiClient', autospec=True) + @mock.patch.object(CheckoutIntent, 'create_intent') + def test_create_intent_raises_slug_reservation_conflict(self, mock_create_intent, mock_lms_client_class, _): + """ + Test that SlugReservationConflict from create_intent is wrapped in CreateCheckoutSessionSlugReservationConflict. + """ + mock_lms = mock_lms_client_class.return_value + mock_lms.get_lms_user_account.return_value = [{'id': self.user.lms_user_id}] + mock_lms.get_enterprise_customer_data.side_effect = raise_404_error + mock_create_intent.side_effect = SlugReservationConflict() + with self.assertRaises(customer_billing_api.CreateCheckoutSessionSlugReservationConflict) as cm: + customer_billing_api.create_free_trial_checkout_session( + user=self.user, admin_email=self.user.email, enterprise_slug='s', + company_name='C', quantity=10, stripe_price_id=QUARTERLY_PRICE_ID, + ) + self.assertEqual(cm.exception.non_field_errors[0]['error_code'], 'checkout_intent_conflict_slug_reserved') + + @mock.patch( + 'enterprise_access.apps.customer_billing.' + 'api.get_ssp_product_pricing', + return_value=MOCK_SSP_PRICING_DATA + ) + @mock.patch.object(customer_billing_api, 'LmsApiClient', autospec=True) + @mock.patch.object(CheckoutIntent, 'create_intent') + def test_create_intent_raises_failed_conflict(self, mock_create_intent, mock_lms_client_class, _): + """ + Test that FailedCheckoutIntentConflict from create_intent is wrapped in CreateCheckoutSessionFailedConflict. + """ + mock_lms = mock_lms_client_class.return_value + mock_lms.get_lms_user_account.return_value = [{'id': self.user.lms_user_id}] + mock_lms.get_enterprise_customer_data.side_effect = raise_404_error + mock_create_intent.side_effect = FailedCheckoutIntentConflict() + with self.assertRaises(customer_billing_api.CreateCheckoutSessionFailedConflict) as cm: + customer_billing_api.create_free_trial_checkout_session( + user=self.user, admin_email=self.user.email, enterprise_slug='s', + company_name='C', quantity=10, stripe_price_id=QUARTERLY_PRICE_ID, + ) + self.assertEqual(cm.exception.non_field_errors[0]['error_code'], 'checkout_intent_conflict_failed') + + +class TestCreateStripeBillingPortalSession(TestCase): + """ + Tests for the ``create_stripe_billing_portal_session()`` function. + """ + + def test_no_customer_id_raises_value_error(self): + """Missing stripe_customer_id raises ValueError.""" + intent = mock.Mock(stripe_customer_id=None, id='intent-1') + with self.assertRaises(ValueError): + customer_billing_api.create_stripe_billing_portal_session(intent, 'https://return.url') + + @mock.patch('enterprise_access.apps.customer_billing.api.stripe') + def test_success(self, mock_stripe): + """Happy path returns a portal session.""" + intent = mock.Mock(stripe_customer_id='cus_123', id='intent-1') + mock_stripe.billing_portal.Session.create.return_value = mock.Mock(id='bps_1') + result = customer_billing_api.create_stripe_billing_portal_session(intent, 'https://return.url') + self.assertEqual(result.id, 'bps_1') + mock_stripe.billing_portal.Session.create.assert_called_once_with( + customer='cus_123', return_url='https://return.url', + ) + + @mock.patch('enterprise_access.apps.customer_billing.api.stripe') + def test_stripe_error_propagates(self, mock_stripe): + """StripeError from portal creation is re-raised after logging.""" + intent = mock.Mock(stripe_customer_id='cus_123', id='intent-1') + mock_stripe.StripeError = stripe.StripeError + mock_stripe.billing_portal.Session.create.side_effect = stripe.StripeError('fail') + with self.assertRaises(stripe.StripeError): + customer_billing_api.create_stripe_billing_portal_session(intent, 'https://return.url') + + +@mock.patch('enterprise_access.apps.customer_billing.api.get_ssp_product_pricing') +class TestValidatorHandlerEdgeCases(TestCase): + """ + Direct handler tests for uncovered exception/edge-case branches. + """ + + def setUp(self): + self.validator = customer_billing_api.CheckoutSessionInputValidator() + + # ── handle_quantity: except Exception when get_ssp_product_pricing raises ── + def test_handle_quantity_pricing_exception(self, mock_pricing): + mock_pricing.side_effect = RuntimeError('boom') + result = self.validator.handle_quantity({'quantity': 10, 'stripe_price_id': QUARTERLY_PRICE_ID}) + self.assertEqual(result['error_code'], 'incomplete_data') + + # ── handle_ssp_product_slug: non-string input → INVALID_FORMAT ── + def test_handle_ssp_product_slug_invalid_format(self, mock_pricing): + result = self.validator.handle_ssp_product_slug({'ssp_product_slug': 123}) + self.assertEqual(result['error_code'], 'invalid_format') + mock_pricing.assert_not_called() + + # ── handle_ssp_product_slug: unknown slug → DOES_NOT_EXIST ── + def test_handle_ssp_product_slug_not_found(self, mock_pricing): + mock_pricing.return_value = MOCK_SSP_PRICING_DATA + result = self.validator.handle_ssp_product_slug({'ssp_product_slug': 'nonexistent_plan'}) + self.assertEqual(result['error_code'], 'does_not_exist') + + # ── handle_ssp_product_slug: pricing call raises → DOES_NOT_EXIST ── + def test_handle_ssp_product_slug_pricing_exception(self, mock_pricing): + mock_pricing.side_effect = RuntimeError('boom') + result = self.validator.handle_ssp_product_slug({'ssp_product_slug': 'quarterly_license_plan'}) + self.assertEqual(result['error_code'], 'does_not_exist') + + # ── handle_stripe_price_id: non-string input → INVALID_FORMAT ── + def test_handle_stripe_price_id_invalid_format(self, mock_pricing): + result = self.validator.handle_stripe_price_id({'stripe_price_id': 999}) + self.assertEqual(result['error_code'], 'invalid_format') + mock_pricing.assert_not_called() + + # ── handle_stripe_price_id: pricing call raises → DOES_NOT_EXIST ── + def test_handle_stripe_price_id_pricing_exception(self, mock_pricing): + mock_pricing.side_effect = RuntimeError('boom') + result = self.validator.handle_stripe_price_id({'stripe_price_id': QUARTERLY_PRICE_ID}) + self.assertEqual(result['error_code'], 'does_not_exist') diff --git a/enterprise_access/apps/customer_billing/tests/test_pricing_api.py b/enterprise_access/apps/customer_billing/tests/test_pricing_api.py index de460b94..b53bdbeb 100644 --- a/enterprise_access/apps/customer_billing/tests/test_pricing_api.py +++ b/enterprise_access/apps/customer_billing/tests/test_pricing_api.py @@ -1,6 +1,7 @@ """ Unit tests for the pricing_api module. """ +import uuid from decimal import Decimal from unittest import mock @@ -10,17 +11,18 @@ from stripe import InvalidRequestError from enterprise_access.apps.customer_billing import pricing_api +from enterprise_access.apps.customer_billing.models import SspProduct MOCK_SSP_PRODUCTS = { 'quarterly_license_plan': { 'stripe_price_id': 'price_test_quarterly', # DEPRECATED: Use lookup_key instead 'lookup_key': 'price_quarterly_0002', - 'quantity_range': (5, 30), + 'quantity_range': [5, 50], }, 'yearly_license_plan': { 'stripe_price_id': 'price_test_yearly', # DEPRECATED: Use lookup_key instead 'lookup_key': 'price_yearly_0001', - 'quantity_range': (5, 30), + 'quantity_range': [5, 50], }, } @@ -37,6 +39,24 @@ class TestStripePricingAPI(TestCase): def setUp(self): # Clear cache before each test TieredCache.dangerous_clear_all_tiers() + SspProduct.objects.create( + slug='quarterly_license_plan', + stripe_price_lookup_key=MOCK_SSP_PRODUCTS['quarterly_license_plan']['lookup_key'], + academy_uuid=None, + catalog_query_uuid='aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa', + license_manager_product_id_trial=2, + license_manager_product_id_paid=1, + is_active=True, + ) + SspProduct.objects.create( + slug='yearly_license_plan', + stripe_price_lookup_key=MOCK_SSP_PRODUCTS['yearly_license_plan']['lookup_key'], + academy_uuid=None, + catalog_query_uuid='bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb', + license_manager_product_id_trial=2, + license_manager_product_id_paid=1, + is_active=True, + ) def tearDown(self): # Clear cache after each test @@ -95,7 +115,9 @@ def test_get_stripe_price_data_basic_format(self, mock_stripe): } } - self.assertEqual(result, expected) + # Only assert expected keys to remain resilient to optional fields like `ssp_product_slug` + for k, v in expected.items(): + self.assertEqual(result.get(k), v) mock_stripe.Price.retrieve.assert_called_once_with(price_id, expand=['product']) @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe') @@ -145,6 +167,22 @@ def test_get_stripe_price_data_stripe_error(self, mock_stripe_price): @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe') def test_get_ssp_product_pricing(self, mock_stripe): """Test fetching SSP product pricing.""" + # Ensure we exercise the settings-backed path for quantity_range + SspProduct.objects.all().delete() + + SspProduct.objects.create( + slug='quarterly_license_plan', + stripe_price_lookup_key=MOCK_SSP_PRODUCTS['quarterly_license_plan']['lookup_key'], + is_active=True, + catalog_query_uuid=uuid.uuid4(), + ) + SspProduct.objects.create( + slug='yearly_license_plan', + stripe_price_lookup_key=MOCK_SSP_PRODUCTS['yearly_license_plan']['lookup_key'], + is_active=True, + catalog_query_uuid=uuid.uuid4(), + ) + quarterly_price = self._create_mock_stripe_price() yearly_price = self._create_mock_stripe_price( price_id=MOCK_SSP_PRODUCTS['yearly_license_plan']['stripe_price_id'], @@ -154,14 +192,14 @@ def test_get_ssp_product_pricing(self, mock_stripe): result = pricing_api.get_ssp_product_pricing() - # Should have entries for configured SSP products + # Should have entries for configured SSP products (from settings) self.assertIn('quarterly_license_plan', result) self.assertIn('yearly_license_plan', result) - # Check that SSP-specific metadata is added + # Check that SSP-specific metadata is added and quantity_range is sourced from settings quarterly_data = result['quarterly_license_plan'] self.assertEqual(quarterly_data['ssp_product_key'], 'quarterly_license_plan') - self.assertEqual(quarterly_data['quantity_range'], (5, 30)) + self.assertEqual(quarterly_data.get('quantity_range'), [5, 50]) def test_calculate_subtotal_basic_format(self): """Test subtotal calculation with basic format.""" @@ -409,3 +447,293 @@ def test_validate_stripe_price_schema_variants( with self.assertRaises(pricing_api.StripePricingError) as cm: pricing_api._validate_stripe_price_schema(mock_price) self.assertIn(expect_error, str(cm.exception)) + + def test_serialize_basic_format_with_product_metadata_ssp_slug(self): + """Product metadata with ssp_product_slug should be preferred.""" + mock_price = self._create_mock_stripe_price() + mock_price.product.metadata = {'ssp_product_slug': 'meta-slug'} + + result = pricing_api._serialize_basic_format(mock_price) # pylint: disable=protected-access + + self.assertIn('ssp_product_slug', result) + self.assertEqual(result['ssp_product_slug'], 'meta-slug') + + def test_serialize_basic_format_model_fallback_for_ssp_slug(self): + """When product metadata lacks ssp_product_slug, fallback to SspProduct lookup_key.""" + # Create a model-backed SSP product to be discovered by lookup_key + SspProduct.objects.create( + slug='fallback_slug', + stripe_price_lookup_key='lookup_fallback', + is_active=True, + catalog_query_uuid=uuid.uuid4(), + ) + + mock_price = self._create_mock_stripe_price(lookup_key='lookup_fallback') + mock_price.product.metadata = {} + + result = pricing_api._serialize_basic_format(mock_price) # pylint: disable=protected-access + + self.assertIn('ssp_product_slug', result) + self.assertEqual(result['ssp_product_slug'], 'fallback_slug') + + def test_get_ssp_product_pricing_raises_on_missing_lookup_key(self): + """If an active SspProduct is missing lookup_key, raise StripePricingError.""" + # Add a product missing a lookup key + SspProduct.objects.create( + slug='bad_product', + stripe_price_lookup_key='', + is_active=True, + catalog_query_uuid=uuid.uuid4(), + ) + + with mock.patch('enterprise_access.apps.customer_billing.pricing_api.get_all_stripe_prices') as mock_all: + mock_all.return_value = {} + with self.assertRaises(pricing_api.StripePricingError): + pricing_api.get_ssp_product_pricing() + + def test_get_ssp_product_pricing_raises_when_lookup_key_not_found(self): + """If lookup_key for a product isn't present in Stripe prices, raise StripePricingError.""" + SspProduct.objects.create( + slug='missing_lookup', + stripe_price_lookup_key='no_such_lookup', + is_active=True, + catalog_query_uuid=uuid.uuid4(), + ) + + with mock.patch('enterprise_access.apps.customer_billing.pricing_api.get_all_stripe_prices') as mock_all: + mock_all.return_value = {} + with self.assertRaises(pricing_api.StripePricingError): + pricing_api.get_ssp_product_pricing() + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe.Price') + def test_get_stripe_price_data_non_stripe_exception(self, mock_stripe_price): + """General (non-Stripe) exceptions should be wrapped in StripePricingError.""" + mock_stripe_price.retrieve.side_effect = RuntimeError('connection reset') + + with self.assertRaises(pricing_api.StripePricingError) as cm: + pricing_api.get_stripe_price_data('price_999') + + self.assertIn('Unexpected error', str(cm.exception)) + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe') + def test_get_all_stripe_prices_basic(self, mock_stripe): + """Directly test get_all_stripe_prices returns a lookup_key mapping.""" + mock_recurring = mock.MagicMock() + mock_recurring.interval = 'year' + mock_recurring.interval_count = 1 + mock_recurring.usage_type = 'licensed' + + price = self._create_mock_stripe_price( + price_id='price_all_1', lookup_key='lk_all_1', recurring=mock_recurring, + ) + mock_stripe.Price.list.return_value.auto_paging_iter.return_value = [price] + + result = pricing_api.get_all_stripe_prices() + + self.assertIn('lk_all_1', result) + self.assertEqual(result['lk_all_1']['unit_amount'], 10000) + self.assertEqual(result['lk_all_1']['currency'], 'usd') + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe') + def test_get_all_stripe_prices_caching(self, mock_stripe): + """Second call to get_all_stripe_prices should return cached result.""" + mock_recurring = mock.MagicMock() + mock_recurring.interval = 'year' + mock_recurring.interval_count = 1 + mock_recurring.usage_type = 'licensed' + + price = self._create_mock_stripe_price( + price_id='price_cache', lookup_key='lk_cache', recurring=mock_recurring, + ) + mock_stripe.Price.list.return_value.auto_paging_iter.return_value = [price] + + result1 = pricing_api.get_all_stripe_prices() + + # Reset the mock so we can verify it is NOT called again + mock_stripe.Price.list.reset_mock() + + result2 = pricing_api.get_all_stripe_prices() + + self.assertEqual(result1, result2) + # Stripe should not be called again; the cache should serve the result + mock_stripe.Price.list.assert_not_called() + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe') + def test_get_all_stripe_prices_skips_non_recurring(self, mock_stripe): + """Non-recurring (one_time) prices should be skipped.""" + mock_recurring = mock.MagicMock() + mock_recurring.interval = 'year' + mock_recurring.interval_count = 1 + mock_recurring.usage_type = 'licensed' + + recurring_price = self._create_mock_stripe_price( + price_id='price_rec', lookup_key='lk_rec', recurring=mock_recurring, + ) + + one_time_price = self._create_mock_stripe_price( + price_id='price_ot', lookup_key='lk_ot', + ) + one_time_price.type = 'one_time' + + mock_stripe.Price.list.return_value.auto_paging_iter.return_value = [ + recurring_price, one_time_price, + ] + + result = pricing_api.get_all_stripe_prices() + + self.assertIn('lk_rec', result) + self.assertNotIn('lk_ot', result) + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe') + def test_get_all_stripe_prices_skips_missing_lookup_key(self, mock_stripe): + """Prices without a lookup_key should be skipped with a warning.""" + mock_recurring = mock.MagicMock() + mock_recurring.interval = 'year' + mock_recurring.interval_count = 1 + mock_recurring.usage_type = 'licensed' + + price_with_key = self._create_mock_stripe_price( + price_id='price_wk', lookup_key='lk_wk', recurring=mock_recurring, + ) + + price_no_key = self._create_mock_stripe_price( + price_id='price_nk', recurring=mock_recurring, + ) + price_no_key.lookup_key = None + + mock_stripe.Price.list.return_value.auto_paging_iter.return_value = [ + price_with_key, price_no_key, + ] + + result = pricing_api.get_all_stripe_prices() + + self.assertIn('lk_wk', result) + self.assertEqual(len(result), 1) + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe.Price') + def test_get_all_stripe_prices_stripe_error(self, mock_stripe_price): + """StripeError during Price.list should raise StripePricingError.""" + mock_stripe_price.list.side_effect = InvalidRequestError('bad request', 'param') + + with self.assertRaises(pricing_api.StripePricingError) as cm: + pricing_api.get_all_stripe_prices() + + self.assertIn('Failed to fetch all prices', str(cm.exception)) + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe.Price') + def test_get_all_stripe_prices_general_exception(self, mock_stripe_price): + """General exception during Price.list should raise StripePricingError.""" + mock_stripe_price.list.side_effect = RuntimeError('unexpected failure') + + with self.assertRaises(pricing_api.StripePricingError) as cm: + pricing_api.get_all_stripe_prices() + + self.assertIn('Unexpected error', str(cm.exception)) + + def test_calculate_subtotal_returns_none_on_error(self): + """Missing keys in price_data should cause calculate_subtotal to return None.""" + # Empty dict triggers KeyError for 'unit_amount' + result = pricing_api.calculate_subtotal({}, 5) + self.assertIsNone(result) + + def test_format_price_display_returns_unavailable_on_exception(self): + """Missing keys inside the try block should return 'Price unavailable'.""" + # currency matches so we enter the try block, but unit_amount_decimal is missing + price_data = {'currency': 'usd'} + + result = pricing_api.format_price_display(price_data) + self.assertEqual(result, 'Price unavailable') + + def test_validate_stripe_price_schema_invalid_currency_type(self): + """Non-string currency should raise StripePricingError.""" + mock_price = self._create_mock_stripe_price() + mock_price.currency = 12345 # not a string + + with self.assertRaises(pricing_api.StripePricingError) as cm: + pricing_api._validate_stripe_price_schema(mock_price) # pylint: disable=protected-access + + self.assertIn('Invalid currency type', str(cm.exception)) + + def test_validate_stripe_price_schema_missing_interval_count(self): + """Recurring price with missing interval_count should raise StripePricingError.""" + mock_recurring = mock.MagicMock() + mock_recurring.interval = 'month' + mock_recurring.interval_count = None # missing + mock_recurring.usage_type = 'licensed' + + mock_price = self._create_mock_stripe_price(recurring=mock_recurring) + + with self.assertRaises(pricing_api.StripePricingError) as cm: + pricing_api._validate_stripe_price_schema(mock_price) # pylint: disable=protected-access + + self.assertIn('Recurring price missing interval_count', str(cm.exception)) + + def test_get_ssp_product_pricing_raises_on_missing_lookup_key_slug_format(self): + """Ensure missing lookup_key exception correctly formats using the product slug.""" + SspProduct.objects.all().delete() + SspProduct.objects.create( + slug='bad_product_slug_test', + stripe_price_lookup_key='', + is_active=True, + catalog_query_uuid=uuid.uuid4(), + ) + + with mock.patch('enterprise_access.apps.customer_billing.pricing_api.get_all_stripe_prices') as mock_all: + mock_all.return_value = {} + with self.assertRaises(pricing_api.StripePricingError) as cm: + pricing_api.get_ssp_product_pricing() + + self.assertIn('SSP product bad_product_slug_test missing lookup_key', str(cm.exception)) + + def test_serialize_basic_format_metadata_exception_handled(self): + """Ensure exceptions raised during metadata access fallback safely to ssp_slug = None.""" + mock_price = self._create_mock_stripe_price() + mock_metadata = mock.MagicMock() + mock_metadata.get.side_effect = RuntimeError("Metadata completely broken") + mock_price.product.metadata = mock_metadata + + # pylint: disable=protected-access + result = pricing_api._serialize_basic_format(mock_price) + self.assertEqual(result.get('ssp_product_slug'), 'quarterly_license_plan') + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe') + def test_get_ssp_product_pricing_skips_invalid_settings_config(self, mock_stripe): + """Ensure settings blocks missing lookup_key or quantity_range are skipped gracefully.""" + SspProduct.objects.all().delete() + + # Create a valid active DB product pointing to the complete settings block + SspProduct.objects.create( + slug='complete_plan', + stripe_price_lookup_key='valid_lk_range', + is_active=True, + catalog_query_uuid=uuid.uuid4(), + ) + + mock_price = self._create_mock_stripe_price(lookup_key='valid_lk_range') + mock_stripe.Price.list().auto_paging_iter.return_value = [mock_price] + + result = pricing_api.get_ssp_product_pricing() + + # The complete plan should process properly + self.assertIn('complete_plan', result) + self.assertEqual(result['complete_plan'].get('quantity_range'), [5, 50]) + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe') + def test_get_ssp_product_pricing_ignores_inactive_ssp_products(self, mock_stripe): + """Ensure that SspProduct database filtering strictly targets is_active=True objects.""" + SspProduct.objects.all().delete() + + # Create an inactive product that matches valid settings + SspProduct.objects.create( + slug='quarterly_license_plan', + stripe_price_lookup_key=MOCK_SSP_PRODUCTS['quarterly_license_plan']['lookup_key'], + is_active=False, + catalog_query_uuid=uuid.uuid4(), + ) + + mock_price = self._create_mock_stripe_price() + mock_stripe.Price.list().auto_paging_iter.return_value = [mock_price] + + result = pricing_api.get_ssp_product_pricing() + + self.assertEqual(len(result), 0) diff --git a/enterprise_access/settings/base.py b/enterprise_access/settings/base.py index 7d419360..c54fb4a3 100644 --- a/enterprise_access/settings/base.py +++ b/enterprise_access/settings/base.py @@ -683,7 +683,7 @@ def root(*path_fragments): ENABLE_BILLING_MANAGEMENT_API = False DEFAULT_SSP_PRICE_LOOKUP_KEY = 'teams_subscription_license_yearly' - +DEFAULT_SSP_QUANTITY_RANGE = [5, 50] # Default SSP product slug assigned to new CheckoutIntent records when no product is specified. # Override this in environment-specific settings to change the default product. SSP_DEFAULT_PRODUCT_SLUG = 'teams-yearly'