From 92f9f51136ec3d026f8e009869b9af1cd531c4b4 Mon Sep 17 00:00:00 2001 From: tsunkara-sonata Date: Sat, 20 Jun 2026 15:38:01 +0000 Subject: [PATCH 1/6] feat: bff checkout update --- .../apps/api/serializers/customer_billing.py | 17 + .../api/v1/tests/test_customer_billing.py | 28 +- .../apps/api/v1/views/customer_billing.py | 6 +- .../apps/bffs/checkout/handlers.py | 2 + .../apps/bffs/checkout/serializers.py | 10 + .../apps/customer_billing/api.py | 88 ++-- .../apps/customer_billing/constants.py | 4 + ...ter_checkoutintent_ssp_product_and_more.py | 25 ++ .../apps/customer_billing/models.py | 45 +- .../apps/customer_billing/pricing_api.py | 49 ++- .../apps/customer_billing/tests/test_api.py | 216 +++++++++- .../tests/test_pricing_api.py | 385 +++++++++++++++++- 12 files changed, 791 insertions(+), 84 deletions(-) create mode 100644 enterprise_access/apps/customer_billing/migrations/0038_alter_checkoutintent_ssp_product_and_more.py diff --git a/enterprise_access/apps/api/serializers/customer_billing.py b/enterprise_access/apps/api/serializers/customer_billing.py index 6607ce47..4b39eda0 100644 --- a/enterprise_access/apps/api/serializers/customer_billing.py +++ b/enterprise_access/apps/api/serializers/customer_billing.py @@ -15,6 +15,7 @@ CheckoutIntent, FailedCheckoutIntentConflict, SlugReservationConflict, + SspProduct, StripeEventSummary ) @@ -60,6 +61,10 @@ class CustomerBillingCreateCheckoutSessionRequestSerializer(serializers.Serializ required=True, help_text='The ID of the Stripe Price object representing the plan selection.', ) + ssp_product_slug = serializers.SlugField( + required=False, + help_text='The slug of the SSP product representing the plan selection.', + ) # pylint: disable=abstract-method @@ -117,6 +122,10 @@ class CustomerBillingCreateCheckoutSessionValidationFailedResponseSerializer(ser required=False, help_text='Validation results for stripe_price_id if validation failed. Absent otherwise.', ) + ssp_product_slug = ErrorDetailSerializer( + required=False, + help_text='Validation results for ssp_product_slug if validation failed. Absent otherwise.', + ) company_name = ErrorDetailSerializer( required=False, help_text='Validation results for company_name if validation failed. Absent otherwise.', @@ -205,6 +214,7 @@ class Meta: 'quantity', 'country', 'terms_metadata', + 'ssp_product' ] ] @@ -241,6 +251,12 @@ def create(self, validated_data): Creates a new CheckoutIntent. """ try: + ssp_product_slug = validated_data.pop('ssp_product', None) + ssp_product = None + if ssp_product_slug: + ssp_product = SspProduct.objects.filter( + slug=ssp_product_slug, is_active=True + ).first() return CheckoutIntent.create_intent( user=self.context['request'].user, quantity=validated_data['quantity'], @@ -248,6 +264,7 @@ def create(self, validated_data): name=validated_data.get('enterprise_name'), country=validated_data.get('country'), terms_metadata=validated_data.get('terms_metadata'), + ssp_product=ssp_product, ) # Catch exceptions that should return 422: diff --git a/enterprise_access/apps/api/v1/tests/test_customer_billing.py b/enterprise_access/apps/api/v1/tests/test_customer_billing.py index 173cf373..c98ca1d9 100644 --- a/enterprise_access/apps/api/v1/tests/test_customer_billing.py +++ b/enterprise_access/apps/api/v1/tests/test_customer_billing.py @@ -3899,17 +3899,23 @@ def test_create_checkout_session_returns_client_secret_from_dict( 'client_secret': 'cs_test_abc_secret_xyz', } - response = self.client.post( - self.url, - data={ - 'admin_email': self.user.email, - 'enterprise_slug': 'test-slug', - 'company_name': 'Test Co', - 'quantity': 5, - 'stripe_price_id': 'price_abc123', - }, - format='json', - ) + # Prevent live Stripe pricing lookups during checkout flow + with mock.patch('enterprise_access.apps.customer_billing.api.get_ssp_product_pricing') as mock_get_pricing: + mock_get_pricing.return_value = { + 'quarterly_license_plan': {'id': 'price_test_quarterly', 'quantity_range': (5, 30)} + } + + response = self.client.post( + self.url, + data={ + 'admin_email': self.user.email, + 'enterprise_slug': 'test-slug', + 'company_name': 'Test Co', + 'quantity': 5, + 'stripe_price_id': 'price_abc123', + }, + format='json', + ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual( diff --git a/enterprise_access/apps/api/v1/views/customer_billing.py b/enterprise_access/apps/api/v1/views/customer_billing.py index 1f03ec83..3758428e 100644 --- a/enterprise_access/apps/api/v1/views/customer_billing.py +++ b/enterprise_access/apps/api/v1/views/customer_billing.py @@ -256,7 +256,8 @@ def create_checkout_session(self, request, *args, **kwargs): >>> "admin_email": "dr@evil.inc", >>> "enterprise_slug": "my-sluggy" >>> "quantity": 7, - >>> "stripe_price_id": "price_1MoBy5LkdIwHu7ixZhnattbh" + >>> "stripe_price_id": "price_1MoBy5LkdIwHu7ixZhnattbh", + >>> "ssp_product_slug": "ai-academy-yearly" >>> } HTTP 201 CREATED >>> { @@ -287,7 +288,8 @@ def create_checkout_session(self, request, *args, **kwargs): 'Handling request to create free trial plan. ' f'enterprise_slug="{validated_data["enterprise_slug"]}" ' f'quantity="{validated_data["quantity"]}" ' - f'stripe_price_id="{validated_data["stripe_price_id"]}"' + f'stripe_price_id="{validated_data.get("stripe_price_id")}"' + f'ssp_product_slug="{validated_data.get("ssp_product_slug")}" ' ) try: session = create_free_trial_checkout_session( diff --git a/enterprise_access/apps/bffs/checkout/handlers.py b/enterprise_access/apps/bffs/checkout/handlers.py index 99136508..c1da55ad 100644 --- a/enterprise_access/apps/bffs/checkout/handlers.py +++ b/enterprise_access/apps/bffs/checkout/handlers.py @@ -170,8 +170,10 @@ def _get_pricing_data(self) -> Dict: for _, price_data in pricing_data.items(): prices.append({ 'id': price_data.get('id'), + 'stripe_price_id': price_data.get('stripe_price_id') or price_data.get('id'), 'product': price_data.get('product', {}).get('id'), 'lookup_key': price_data.get('lookup_key'), + 'ssp_product_slug': price_data.get('ssp_product_slug'), 'recurring': price_data.get('recurring', {}), 'currency': price_data.get('currency'), 'unit_amount': price_data.get('unit_amount'), diff --git a/enterprise_access/apps/bffs/checkout/serializers.py b/enterprise_access/apps/bffs/checkout/serializers.py index 4f9c812d..f7c2d26b 100644 --- a/enterprise_access/apps/bffs/checkout/serializers.py +++ b/enterprise_access/apps/bffs/checkout/serializers.py @@ -28,11 +28,17 @@ class PriceSerializer(serializers.Serializer): Serializer for Stripe price objects in checkout context. """ id = serializers.CharField(help_text="Stripe Price ID") + stripe_price_id = serializers.CharField(required=False, help_text="Stripe Price ID") product = serializers.CharField(help_text="Stripe Product ID") lookup_key = serializers.CharField(help_text="Lookup key for this price") recurring = serializers.DictField( help_text="Recurring billing configuration" ) + ssp_product_slug = serializers.CharField( + required=False, + allow_null=True, + help_text="SSP product slug" + ) currency = serializers.CharField(help_text="Currency code (e.g. 'usd')") unit_amount = serializers.IntegerField(help_text="Price amount in cents") unit_amount_decimal = serializers.DecimalField( @@ -214,6 +220,10 @@ class CheckoutValidationRequestSerializer(serializers.Serializer): enterprise_slug = serializers.SlugField(required=False, allow_blank=True, help_text="Desired enterprise slug") quantity = serializers.IntegerField(required=False, allow_null=True, help_text="Number of licenses") stripe_price_id = serializers.CharField(required=False, allow_blank=True, help_text="Stripe price ID") + ssp_product_slug = serializers.SlugField( + required=False, allow_blank=True, + help_text="SSP product slug for the selected plan", + ) class UserAuthInfoSerializer(serializers.Serializer): diff --git a/enterprise_access/apps/customer_billing/api.py b/enterprise_access/apps/customer_billing/api.py index f9a8ce84..f40c1925 100644 --- a/enterprise_access/apps/customer_billing/api.py +++ b/enterprise_access/apps/customer_billing/api.py @@ -17,7 +17,8 @@ from enterprise_access.apps.customer_billing.models import ( CheckoutIntent, FailedCheckoutIntentConflict, - SlugReservationConflict + SlugReservationConflict, + SspProduct ) from enterprise_access.apps.customer_billing.pricing_api import get_ssp_product_pricing from enterprise_access.apps.customer_billing.stripe_api import create_subscription_checkout_session @@ -35,6 +36,7 @@ class CheckoutSessionInputValidatorData(TypedDict, total=False): full_name: str enterprise_slug: str quantity: int + ssp_product_slug: str stripe_price_id: str @@ -47,6 +49,7 @@ class CheckoutSessionInputData(TypedDict, total=True): enterprise_slug: str company_name: str quantity: int + ssp_product_slug: str stripe_price_id: str @@ -194,10 +197,11 @@ def handle_quantity(self, input_data: CheckoutSessionInputValidatorData) -> Fiel Validate the `quantity` field using Stripe price data. """ quantity = input_data.get('quantity') + ssp_product_slug = input_data.get('ssp_product_slug') stripe_price_id = input_data.get('stripe_price_id') # We need multiple form fields to validate quantity. - if not all([quantity, stripe_price_id]): + if not quantity or (not ssp_product_slug and not stripe_price_id): error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['common']['INCOMPLETE_DATA'] return {'error_code': error_code, 'developer_message': developer_message} @@ -210,13 +214,14 @@ def handle_quantity(self, input_data: CheckoutSessionInputValidatorData) -> Fiel # Get the SSP product pricing data (includes quantity ranges) ssp_pricing = get_ssp_product_pricing() - # Find the SSP product that matches this stripe_price_id matching_product = None - for _, price_data in ssp_pricing.items(): - if price_data.get('id') == stripe_price_id: - matching_product = price_data - break - + if ssp_product_slug: + matching_product = ssp_pricing.get(ssp_product_slug) + elif stripe_price_id: + for price_data in ssp_pricing.values(): + if price_data.get('id') == stripe_price_id: + matching_product = price_data + break if not matching_product: error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['common']['INCOMPLETE_DATA'] return {'error_code': error_code, 'developer_message': developer_message} @@ -232,37 +237,56 @@ def handle_quantity(self, input_data: CheckoutSessionInputValidatorData) -> Fiel return {'error_code': error_code, 'developer_message': developer_message} except Exception as exc: # pylint: disable=broad-exception-caught - logger.error(f'Error validating quantity for stripe_price_id {stripe_price_id}: {exc}') + logger.error( + f'Error validating quantity for ssp_product_slug {ssp_product_slug} ' + f'or stripe_price_id {stripe_price_id}: {exc}' + ) error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['common']['INCOMPLETE_DATA'] return {'error_code': error_code, 'developer_message': developer_message} return {'error_code': None, 'developer_message': None} - def handle_stripe_price_id(self, input_data: CheckoutSessionInputValidatorData) -> FieldValidationResult: + def handle_ssp_product_slug(self, input_data: CheckoutSessionInputValidatorData) -> FieldValidationResult: """ - Validate the `stripe_price_id` field against active Stripe prices. + Validate the `ssp_product_slug` field against active SSP products. """ - stripe_price_id = input_data.get('stripe_price_id') + ssp_product_slug = input_data.get('ssp_product_slug') # "Invalid format" if empty, missing, or not a str. - if not isinstance(stripe_price_id, str) or not stripe_price_id: - error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['stripe_price_id']['INVALID_FORMAT'] + if not isinstance(ssp_product_slug, str) or not ssp_product_slug: + error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['ssp_product_slug']['INVALID_FORMAT'] return {'error_code': error_code, 'developer_message': developer_message} try: # Get SSP product pricing data (validates lookup_keys against Stripe) ssp_pricing = get_ssp_product_pricing() - # Check if the price_id exists in any of the configured SSP products - price_exists = any( - price_data.get('id') == stripe_price_id - for price_data in ssp_pricing.values() - ) + if ssp_product_slug not in ssp_pricing: + error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['ssp_product_slug']['DOES_NOT_EXIST'] + return {'error_code': error_code, 'developer_message': developer_message} + + except Exception as exc: # pylint: disable=broad-exception-caught + logger.error(f'Error validating ssp_product_slug {ssp_product_slug}: {exc}') + error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['ssp_product_slug']['DOES_NOT_EXIST'] + return {'error_code': error_code, 'developer_message': developer_message} + + return {'error_code': None, 'developer_message': None} - if not price_exists: + def handle_stripe_price_id(self, input_data: CheckoutSessionInputValidatorData) -> FieldValidationResult: + """ + Validate the `stripe_price_id` field against active SSP products. + """ + stripe_price_id = input_data.get('stripe_price_id') + + if not isinstance(stripe_price_id, str) or not stripe_price_id: + error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['stripe_price_id']['INVALID_FORMAT'] + return {'error_code': error_code, 'developer_message': developer_message} + + try: + ssp_pricing = get_ssp_product_pricing() + if not any(price_data.get('id') == stripe_price_id for price_data in ssp_pricing.values()): error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['stripe_price_id']['DOES_NOT_EXIST'] return {'error_code': error_code, 'developer_message': developer_message} - except Exception as exc: # pylint: disable=broad-exception-caught logger.error(f'Error validating stripe_price_id {stripe_price_id}: {exc}') error_code, developer_message = CHECKOUT_SESSION_ERROR_CODES['stripe_price_id']['DOES_NOT_EXIST'] @@ -350,6 +374,7 @@ def handle_company_name(self, input_data: CheckoutSessionInputValidatorData) -> 'company_name': handle_company_name, 'enterprise_slug': handle_enterprise_slug, 'quantity': handle_quantity, + 'ssp_product_slug': handle_ssp_product_slug, 'stripe_price_id': handle_stripe_price_id, 'user': handle_user, } @@ -422,7 +447,23 @@ def create_free_trial_checkout_session( raise CreateCheckoutSessionValidationError(validation_errors_by_field=validation_errors) user = input_data['user'] - + # ── Resolve slug ↔ price_id BEFORE creating intent ── + ssp_pricing = get_ssp_product_pricing() + ssp_product_slug = input_data.get('ssp_product_slug') + stripe_price_id = input_data.get('stripe_price_id') + if not ssp_product_slug and stripe_price_id: + for candidate_slug, price_data in ssp_pricing.items(): + if price_data.get('id') == stripe_price_id: + ssp_product_slug = candidate_slug + break + if ssp_product_slug: + stripe_price_id = ssp_pricing[ssp_product_slug]['id'] + # ── Resolve SspProduct instance for FK ── + ssp_product_instance = None + if ssp_product_slug: + ssp_product_instance = SspProduct.objects.filter( + slug=ssp_product_slug, is_active=True + ).first() # Create checkout intent, which reserves the enterprise name & slug. try: intent = CheckoutIntent.create_intent( @@ -430,6 +471,7 @@ def create_free_trial_checkout_session( quantity=input_data.get('quantity'), slug=input_data.get('enterprise_slug'), name=input_data.get('company_name'), + ssp_product=ssp_product_instance, ) except SlugReservationConflict as exc: raise CreateCheckoutSessionSlugReservationConflict() from exc @@ -438,7 +480,7 @@ def create_free_trial_checkout_session( lms_user_id = user.lms_user_id checkout_session = create_subscription_checkout_session( - input_data=input_data, + input_data={**input_data, 'ssp_product_slug': ssp_product_slug, 'stripe_price_id': stripe_price_id}, lms_user_id=lms_user_id, checkout_intent=intent, ) diff --git a/enterprise_access/apps/customer_billing/constants.py b/enterprise_access/apps/customer_billing/constants.py index 87b61f9c..f67e0284 100644 --- a/enterprise_access/apps/customer_billing/constants.py +++ b/enterprise_access/apps/customer_billing/constants.py @@ -53,6 +53,10 @@ 'INVALID_FORMAT': ('invalid_format', 'Must be a non-empty string.'), 'DOES_NOT_EXIST': ('does_not_exist', 'This stripe_price_id has not been configured.'), }, + 'ssp_product_slug': { + 'INVALID_FORMAT': ('invalid_format', 'Must be a non-empty string.'), + 'DOES_NOT_EXIST': ('does_not_exist', 'This ssp_product_slug has not been configured.'), + }, } # According to stripe's AI assistant: "When a Checkout Session is created, diff --git a/enterprise_access/apps/customer_billing/migrations/0038_alter_checkoutintent_ssp_product_and_more.py b/enterprise_access/apps/customer_billing/migrations/0038_alter_checkoutintent_ssp_product_and_more.py new file mode 100644 index 00000000..f586d6f3 --- /dev/null +++ b/enterprise_access/apps/customer_billing/migrations/0038_alter_checkoutintent_ssp_product_and_more.py @@ -0,0 +1,25 @@ +# Generated by Django 5.2.13 on 2026-06-22 05:59 + +import django.db.models.deletion +import enterprise_access.apps.customer_billing.models +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('customer_billing', '0037_alter_checkoutintent_sspproduct_nonnull_and_unique'), + ] + + operations = [ + migrations.AlterField( + model_name='checkoutintent', + name='ssp_product', + field=models.ForeignKey(default=enterprise_access.apps.customer_billing.models.get_default_ssp_product_pk, help_text='The SSP product associated with this checkout intent.', on_delete=django.db.models.deletion.PROTECT, to='customer_billing.sspproduct'), + ), + migrations.AlterField( + model_name='historicalcheckoutintent', + name='ssp_product', + field=models.ForeignKey(blank=True, db_constraint=False, default=enterprise_access.apps.customer_billing.models.get_default_ssp_product_pk, help_text='The SSP product associated with this checkout intent.', null=True, on_delete=django.db.models.deletion.DO_NOTHING, related_name='+', to='customer_billing.sspproduct'), + ), + ] diff --git a/enterprise_access/apps/customer_billing/models.py b/enterprise_access/apps/customer_billing/models.py index 76adf535..b9178110 100644 --- a/enterprise_access/apps/customer_billing/models.py +++ b/enterprise_access/apps/customer_billing/models.py @@ -33,9 +33,17 @@ User = get_user_model() -def get_default_ssp_product_slug(): - """Return the default SSP product slug from Django settings.""" - return settings.SSP_DEFAULT_PRODUCT_SLUG +def get_default_ssp_product_pk(): + """Return the PK of the default SSP product.""" + product, _ = SspProduct.objects.get_or_create( + slug=settings.SSP_DEFAULT_PRODUCT_SLUG, + defaults={ + 'stripe_price_lookup_key': 'teams_subscription_license_yearly', + 'catalog_query_uuid': uuid4(), + 'is_active': True, + } + ) + return product.pk class FailedCheckoutIntentConflict(Exception): @@ -304,7 +312,7 @@ class StateChoices(models.TextChoices): on_delete=models.PROTECT, null=False, blank=False, - default=get_default_ssp_product_slug, + default=get_default_ssp_product_pk, help_text='The SSP product associated with this checkout intent.', ) terms_metadata = models.JSONField( @@ -645,7 +653,8 @@ def create_intent( slug: str | None = None, name: str | None = None, country: str | None = None, - terms_metadata: dict | None = None + terms_metadata: dict | None = None, + ssp_product: 'SspProduct | None' = None, ) -> Self: """ Create or update a checkout intent for a user with the given enterprise details. @@ -744,20 +753,24 @@ def create_intent( existing_intent.enterprise_name = name or existing_intent.enterprise_name existing_intent.country = country or existing_intent.country existing_intent.terms_metadata = (existing_intent.terms_metadata or {}) | (terms_metadata or {}) - + if ssp_product is not None: + existing_intent.ssp_product = ssp_product existing_intent.save() return existing_intent - return cls.objects.create( - user=user, - state=CheckoutIntentState.CREATED, - enterprise_slug=slug, - enterprise_name=name, - quantity=quantity, - expires_at=expires_at, - country=country, - terms_metadata=terms_metadata, - ) + create_kwargs = { + 'user': user, + 'state': CheckoutIntentState.CREATED, + 'enterprise_slug': slug, + 'enterprise_name': name, + 'quantity': quantity, + 'expires_at': expires_at, + 'country': country, + 'terms_metadata': terms_metadata, + } + if ssp_product is not None: + create_kwargs['ssp_product'] = ssp_product + return cls.objects.create(**create_kwargs) @classmethod def for_user(cls, user): diff --git a/enterprise_access/apps/customer_billing/pricing_api.py b/enterprise_access/apps/customer_billing/pricing_api.py index 04bbb356..caddccd9 100644 --- a/enterprise_access/apps/customer_billing/pricing_api.py +++ b/enterprise_access/apps/customer_billing/pricing_api.py @@ -29,6 +29,7 @@ from django.conf import settings from edx_django_utils.cache import TieredCache +from enterprise_access.apps.customer_billing.models import SspProduct from enterprise_access.cache_utils import versioned_cache_key logger = logging.getLogger(__name__) @@ -264,23 +265,32 @@ def get_ssp_product_pricing() -> Dict[str, Dict]: all_stripe_prices = get_all_stripe_prices() ssp_pricing = {} - for product_key, product_config in settings.SSP_PRODUCTS.items(): - lookup_key = product_config.get('lookup_key') + settings_quantity_ranges = {} + for product_config in getattr(settings, 'SSP_PRODUCTS', {}).values(): + lk = product_config.get('lookup_key') + qr = product_config.get('quantity_range') + if lk and qr: + settings_quantity_ranges[lk] = qr + for ssp_product in SspProduct.objects.filter(is_active=True): + lookup_key = ssp_product.stripe_price_lookup_key if not lookup_key: - logger.error(f'SSP product {product_key} missing lookup_key') - raise StripePricingError(f'SSP product {product_key} missing lookup_key') + logger.error(f'SSP product {ssp_product.slug} missing lookup_key') + raise StripePricingError(f'SSP product {ssp_product.slug} missing lookup_key') if lookup_key not in all_stripe_prices: - logger.error(f'lookup_key {lookup_key} for SSP product {product_key} not found in active Stripe prices') + logger.error( + f'lookup_key {lookup_key} for SSP product {ssp_product.slug} ' + f'not found in active Stripe prices') raise StripePricingError( - f'lookup_key {lookup_key} for SSP product {product_key} not found in active Stripe prices' + f'lookup_key {lookup_key} for SSP product {ssp_product.slug} ' + f'not found in active Stripe prices' ) price_data = all_stripe_prices[lookup_key].copy() # Add SSP-specific metadata - price_data['ssp_product_key'] = product_key - price_data['quantity_range'] = product_config.get('quantity_range') - ssp_pricing[product_key] = price_data + price_data['ssp_product_key'] = ssp_product.slug + price_data['quantity_range'] = settings_quantity_ranges.get(lookup_key) + ssp_pricing[ssp_product.slug] = price_data return ssp_pricing @@ -423,4 +433,25 @@ def _serialize_basic_format(stripe_price: stripe.Price) -> SerializedPriceData: 'metadata': product.metadata, } + # Prefer explicit metadata set on the Stripe Product (Terraform sets this). + ssp_slug = None + try: + ssp_slug = product.metadata.get('ssp_product_slug') if getattr(product, 'metadata', None) else None + except Exception: # pylint: disable=broad-exception-caught + ssp_slug = None + + # Fallback: try to resolve from our SspProduct model using lookup_key + if not ssp_slug: + lookup_key = getattr(stripe_price, 'lookup_key', None) + if lookup_key: + try: + ssp = SspProduct.objects.filter(stripe_price_lookup_key=lookup_key).only('slug').first() + if ssp: + ssp_slug = ssp.slug + except Exception: # pylint: disable=broad-exception-caught + logger.exception('Error looking up SspProduct for lookup_key %s', lookup_key) + + if ssp_slug: + base_data['ssp_product_slug'] = ssp_slug + return base_data diff --git a/enterprise_access/apps/customer_billing/tests/test_api.py b/enterprise_access/apps/customer_billing/tests/test_api.py index 6396217e..a8c8e501 100644 --- a/enterprise_access/apps/customer_billing/tests/test_api.py +++ b/enterprise_access/apps/customer_billing/tests/test_api.py @@ -6,6 +6,7 @@ import ddt import requests +import stripe from django.contrib.auth import get_user_model from django.test import TestCase, override_settings from django.utils import timezone @@ -14,7 +15,11 @@ from enterprise_access.apps.customer_billing import api as customer_billing_api from enterprise_access.apps.customer_billing import stripe_api from enterprise_access.apps.customer_billing.constants import CheckoutIntentState -from enterprise_access.apps.customer_billing.models import CheckoutIntent +from enterprise_access.apps.customer_billing.models import ( + CheckoutIntent, + FailedCheckoutIntentConflict, + SlugReservationConflict +) User = get_user_model() @@ -128,6 +133,50 @@ def test_create_free_trial_checkout_session_success( self.assertEqual(intent.stripe_customer_id, 'cust-123') self.assertFalse(intent.is_expired()) + @mock.patch( + 'enterprise_access.apps.customer_billing.api.get_ssp_product_pricing', + return_value=MOCK_SSP_PRICING_DATA, + ) + @mock.patch.object(customer_billing_api, 'LmsApiClient', autospec=True) + @mock.patch.object(stripe_api, 'stripe') + @mock.patch('enterprise_access.apps.customer_billing.api.create_subscription_checkout_session') + def test_create_free_trial_checkout_session_with_ssp_product_slug( + self, + mock_create_checkout, + mock_stripe, # pylint: disable=unused-argument + mock_lms_client_class, + mock_get_ssp_pricing, # pylint: disable=unused-argument + ): + """ + Ensure that supplying `ssp_product_slug` resolves to an `SspProduct` instance + and that the created CheckoutIntent stores the FK correctly. + """ + # Setup mocks + mock_lms_client = mock_lms_client_class.return_value + mock_lms_client.get_lms_user_account.return_value = [{'id': self.user.lms_user_id}] + mock_lms_client.get_enterprise_customer_data.side_effect = raise_404_error + + mock_create_checkout.return_value = {'id': 'sess-ssp', 'customer': 'cust-ssp'} + + # Call with ssp_product_slug instead of stripe_price_id + result = customer_billing_api.create_free_trial_checkout_session( + user=self.user, + admin_email=self.user.email, + enterprise_slug='my-sluggy', + company_name='My Cool Company', + quantity=20, + ssp_product_slug='quarterly_license_plan', + ) + + self.assertEqual(result, {'id': 'sess-ssp', 'customer': 'cust-ssp'}) + + intent = CheckoutIntent.objects.get(user=self.user) + self.assertIsNotNone(intent.ssp_product) + + # Ensure the checkout creator received the provided ssp_product_slug + called_input = mock_create_checkout.call_args[1]['input_data'] + self.assertEqual(called_input.get('ssp_product_slug'), 'quarterly_license_plan') + # Assert library methods were called correctly. mock_lms_client.get_lms_user_account.assert_called_once_with( email=self.user.email, @@ -137,11 +186,9 @@ def test_create_free_trial_checkout_session_success( mock.call(enterprise_customer_name='My Cool Company'), ]) - # Check that customer slug and user data is in Stripe metadata - call_args = mock_stripe.checkout.Session.create.call_args - metadata = call_args[1]['subscription_data']['metadata'] - self.assertEqual(metadata['enterprise_customer_slug'], 'my-sluggy') - self.assertEqual(metadata['lms_user_id'], str(self.user.lms_user_id)) + # Ensure the create checkout helper received enterprise slug and admin email + self.assertEqual(called_input.get('enterprise_slug'), 'my-sluggy') + self.assertEqual(called_input.get('admin_email'), self.user.email) @mock.patch( 'enterprise_access.apps.customer_billing.api.get_ssp_product_pricing', @@ -172,9 +219,10 @@ def test_create_free_trial_checkout_session_success_without_user( quantity=20, stripe_price_id=QUARTERLY_PRICE_ID, ) - # Should get slug reserved error - validation_errors = cm.exception.validation_errors_by_field - self.assertIn('user', validation_errors) + + # Should get user validation error + validation_errors = cm.exception.validation_errors_by_field + self.assertIn('user', validation_errors) @mock.patch( 'enterprise_access.apps.customer_billing.api.get_ssp_product_pricing', @@ -262,11 +310,10 @@ def test_slug_reservation_conflict( quantity=20, stripe_price_id=QUARTERLY_PRICE_ID, ) - - # Should get slug reserved error - validation_errors = cm.exception.validation_errors_by_field - self.assertIn('enterprise_slug', validation_errors) - self.assertEqual(validation_errors['enterprise_slug']['error_code'], 'slug_reserved') + # Should get slug reserved error + validation_errors = cm.exception.validation_errors_by_field + self.assertIn('enterprise_slug', validation_errors) + self.assertEqual(validation_errors['enterprise_slug']['error_code'], 'slug_reserved') @mock.patch( 'enterprise_access.apps.customer_billing.api.get_ssp_product_pricing', @@ -304,11 +351,10 @@ def test_name_reservation_conflict( quantity=20, stripe_price_id=QUARTERLY_PRICE_ID, ) - - # Should get slug reserved error - validation_errors = cm.exception.validation_errors_by_field - self.assertIn('company_name', validation_errors) - self.assertEqual(validation_errors['company_name']['error_code'], 'existing_enterprise_customer') + # Should get company name conflict error + validation_errors = cm.exception.validation_errors_by_field + self.assertIn('company_name', validation_errors) + self.assertEqual(validation_errors['company_name']['error_code'], 'existing_enterprise_customer') @mock.patch( 'enterprise_access.apps.customer_billing.api.get_ssp_product_pricing', @@ -444,6 +490,16 @@ def test_expired_intent_allows_reuse( }, } }, + { + 'request_quantity': 0, + 'request_stripe_price_id': QUARTERLY_PRICE_ID, + 'expected_validation_errors': { + 'quantity': { + 'error_code': 'incomplete_data', + 'developer_message': 'Not enough parameters were given.', + } + } + }, ) @ddt.unpack @mock.patch( @@ -500,3 +556,125 @@ def test_create_free_trial_checkout_session_errors( actual_validation_errors = cm.exception.validation_errors_by_field assert actual_validation_errors == expected_validation_errors + + @mock.patch( + 'enterprise_access.apps.customer_billing.' + 'api.get_ssp_product_pricing', + return_value=MOCK_SSP_PRICING_DATA + ) + @mock.patch.object(customer_billing_api, 'LmsApiClient', autospec=True) + @mock.patch.object(CheckoutIntent, 'create_intent') + def test_create_intent_raises_slug_reservation_conflict(self, mock_create_intent, mock_lms_client_class, _): + """ + Test that SlugReservationConflict from create_intent is wrapped in CreateCheckoutSessionSlugReservationConflict. + """ + mock_lms = mock_lms_client_class.return_value + mock_lms.get_lms_user_account.return_value = [{'id': self.user.lms_user_id}] + mock_lms.get_enterprise_customer_data.side_effect = raise_404_error + mock_create_intent.side_effect = SlugReservationConflict() + with self.assertRaises(customer_billing_api.CreateCheckoutSessionSlugReservationConflict) as cm: + customer_billing_api.create_free_trial_checkout_session( + user=self.user, admin_email=self.user.email, enterprise_slug='s', + company_name='C', quantity=10, stripe_price_id=QUARTERLY_PRICE_ID, + ) + self.assertEqual(cm.exception.non_field_errors[0]['error_code'], 'checkout_intent_conflict_slug_reserved') + + @mock.patch( + 'enterprise_access.apps.customer_billing.' + 'api.get_ssp_product_pricing', + return_value=MOCK_SSP_PRICING_DATA + ) + @mock.patch.object(customer_billing_api, 'LmsApiClient', autospec=True) + @mock.patch.object(CheckoutIntent, 'create_intent') + def test_create_intent_raises_failed_conflict(self, mock_create_intent, mock_lms_client_class, _): + """ + Test that FailedCheckoutIntentConflict from create_intent is wrapped in CreateCheckoutSessionFailedConflict. + """ + mock_lms = mock_lms_client_class.return_value + mock_lms.get_lms_user_account.return_value = [{'id': self.user.lms_user_id}] + mock_lms.get_enterprise_customer_data.side_effect = raise_404_error + mock_create_intent.side_effect = FailedCheckoutIntentConflict() + with self.assertRaises(customer_billing_api.CreateCheckoutSessionFailedConflict) as cm: + customer_billing_api.create_free_trial_checkout_session( + user=self.user, admin_email=self.user.email, enterprise_slug='s', + company_name='C', quantity=10, stripe_price_id=QUARTERLY_PRICE_ID, + ) + self.assertEqual(cm.exception.non_field_errors[0]['error_code'], 'checkout_intent_conflict_failed') + + +class TestCreateStripeBillingPortalSession(TestCase): + """ + Tests for the ``create_stripe_billing_portal_session()`` function. + """ + + def test_no_customer_id_raises_value_error(self): + """Missing stripe_customer_id raises ValueError.""" + intent = mock.Mock(stripe_customer_id=None, id='intent-1') + with self.assertRaises(ValueError): + customer_billing_api.create_stripe_billing_portal_session(intent, 'https://return.url') + + @mock.patch('enterprise_access.apps.customer_billing.api.stripe') + def test_success(self, mock_stripe): + """Happy path returns a portal session.""" + intent = mock.Mock(stripe_customer_id='cus_123', id='intent-1') + mock_stripe.billing_portal.Session.create.return_value = mock.Mock(id='bps_1') + result = customer_billing_api.create_stripe_billing_portal_session(intent, 'https://return.url') + self.assertEqual(result.id, 'bps_1') + mock_stripe.billing_portal.Session.create.assert_called_once_with( + customer='cus_123', return_url='https://return.url', + ) + + @mock.patch('enterprise_access.apps.customer_billing.api.stripe') + def test_stripe_error_propagates(self, mock_stripe): + """StripeError from portal creation is re-raised after logging.""" + intent = mock.Mock(stripe_customer_id='cus_123', id='intent-1') + mock_stripe.StripeError = stripe.StripeError + mock_stripe.billing_portal.Session.create.side_effect = stripe.StripeError('fail') + with self.assertRaises(stripe.StripeError): + customer_billing_api.create_stripe_billing_portal_session(intent, 'https://return.url') + + +@mock.patch('enterprise_access.apps.customer_billing.api.get_ssp_product_pricing') +class TestValidatorHandlerEdgeCases(TestCase): + """ + Direct handler tests for uncovered exception/edge-case branches. + """ + + def setUp(self): + self.validator = customer_billing_api.CheckoutSessionInputValidator() + + # ── handle_quantity: except Exception when get_ssp_product_pricing raises ── + def test_handle_quantity_pricing_exception(self, mock_pricing): + mock_pricing.side_effect = RuntimeError('boom') + result = self.validator.handle_quantity({'quantity': 10, 'stripe_price_id': QUARTERLY_PRICE_ID}) + self.assertEqual(result['error_code'], 'incomplete_data') + + # ── handle_ssp_product_slug: non-string input → INVALID_FORMAT ── + def test_handle_ssp_product_slug_invalid_format(self, mock_pricing): + result = self.validator.handle_ssp_product_slug({'ssp_product_slug': 123}) + self.assertEqual(result['error_code'], 'invalid_format') + mock_pricing.assert_not_called() + + # ── handle_ssp_product_slug: unknown slug → DOES_NOT_EXIST ── + def test_handle_ssp_product_slug_not_found(self, mock_pricing): + mock_pricing.return_value = MOCK_SSP_PRICING_DATA + result = self.validator.handle_ssp_product_slug({'ssp_product_slug': 'nonexistent_plan'}) + self.assertEqual(result['error_code'], 'does_not_exist') + + # ── handle_ssp_product_slug: pricing call raises → DOES_NOT_EXIST ── + def test_handle_ssp_product_slug_pricing_exception(self, mock_pricing): + mock_pricing.side_effect = RuntimeError('boom') + result = self.validator.handle_ssp_product_slug({'ssp_product_slug': 'quarterly_license_plan'}) + self.assertEqual(result['error_code'], 'does_not_exist') + + # ── handle_stripe_price_id: non-string input → INVALID_FORMAT ── + def test_handle_stripe_price_id_invalid_format(self, mock_pricing): + result = self.validator.handle_stripe_price_id({'stripe_price_id': 999}) + self.assertEqual(result['error_code'], 'invalid_format') + mock_pricing.assert_not_called() + + # ── handle_stripe_price_id: pricing call raises → DOES_NOT_EXIST ── + def test_handle_stripe_price_id_pricing_exception(self, mock_pricing): + mock_pricing.side_effect = RuntimeError('boom') + result = self.validator.handle_stripe_price_id({'stripe_price_id': QUARTERLY_PRICE_ID}) + self.assertEqual(result['error_code'], 'does_not_exist') diff --git a/enterprise_access/apps/customer_billing/tests/test_pricing_api.py b/enterprise_access/apps/customer_billing/tests/test_pricing_api.py index de460b94..8dd76b03 100644 --- a/enterprise_access/apps/customer_billing/tests/test_pricing_api.py +++ b/enterprise_access/apps/customer_billing/tests/test_pricing_api.py @@ -1,6 +1,7 @@ """ Unit tests for the pricing_api module. """ +import uuid from decimal import Decimal from unittest import mock @@ -10,6 +11,7 @@ from stripe import InvalidRequestError from enterprise_access.apps.customer_billing import pricing_api +from enterprise_access.apps.customer_billing.models import SspProduct MOCK_SSP_PRODUCTS = { 'quarterly_license_plan': { @@ -37,6 +39,24 @@ class TestStripePricingAPI(TestCase): def setUp(self): # Clear cache before each test TieredCache.dangerous_clear_all_tiers() + SspProduct.objects.create( + slug='quarterly_license_plan', + stripe_price_lookup_key=MOCK_SSP_PRODUCTS['quarterly_license_plan']['lookup_key'], + academy_uuid=None, + catalog_query_uuid='aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa', + license_manager_product_id_trial=2, + license_manager_product_id_paid=1, + is_active=True, + ) + SspProduct.objects.create( + slug='yearly_license_plan', + stripe_price_lookup_key=MOCK_SSP_PRODUCTS['yearly_license_plan']['lookup_key'], + academy_uuid=None, + catalog_query_uuid='bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb', + license_manager_product_id_trial=2, + license_manager_product_id_paid=1, + is_active=True, + ) def tearDown(self): # Clear cache after each test @@ -95,7 +115,9 @@ def test_get_stripe_price_data_basic_format(self, mock_stripe): } } - self.assertEqual(result, expected) + # Only assert expected keys to remain resilient to optional fields like `ssp_product_slug` + for k, v in expected.items(): + self.assertEqual(result.get(k), v) mock_stripe.Price.retrieve.assert_called_once_with(price_id, expand=['product']) @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe') @@ -145,6 +167,22 @@ def test_get_stripe_price_data_stripe_error(self, mock_stripe_price): @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe') def test_get_ssp_product_pricing(self, mock_stripe): """Test fetching SSP product pricing.""" + # Ensure we exercise the settings-backed path for quantity_range + SspProduct.objects.all().delete() + + SspProduct.objects.create( + slug='quarterly_license_plan', + stripe_price_lookup_key=MOCK_SSP_PRODUCTS['quarterly_license_plan']['lookup_key'], + is_active=True, + catalog_query_uuid=uuid.uuid4(), + ) + SspProduct.objects.create( + slug='yearly_license_plan', + stripe_price_lookup_key=MOCK_SSP_PRODUCTS['yearly_license_plan']['lookup_key'], + is_active=True, + catalog_query_uuid=uuid.uuid4(), + ) + quarterly_price = self._create_mock_stripe_price() yearly_price = self._create_mock_stripe_price( price_id=MOCK_SSP_PRODUCTS['yearly_license_plan']['stripe_price_id'], @@ -154,14 +192,14 @@ def test_get_ssp_product_pricing(self, mock_stripe): result = pricing_api.get_ssp_product_pricing() - # Should have entries for configured SSP products + # Should have entries for configured SSP products (from settings) self.assertIn('quarterly_license_plan', result) self.assertIn('yearly_license_plan', result) - # Check that SSP-specific metadata is added + # Check that SSP-specific metadata is added and quantity_range is sourced from settings quarterly_data = result['quarterly_license_plan'] self.assertEqual(quarterly_data['ssp_product_key'], 'quarterly_license_plan') - self.assertEqual(quarterly_data['quantity_range'], (5, 30)) + self.assertEqual(quarterly_data.get('quantity_range'), (5, 30)) def test_calculate_subtotal_basic_format(self): """Test subtotal calculation with basic format.""" @@ -409,3 +447,342 @@ def test_validate_stripe_price_schema_variants( with self.assertRaises(pricing_api.StripePricingError) as cm: pricing_api._validate_stripe_price_schema(mock_price) self.assertIn(expect_error, str(cm.exception)) + + def test_serialize_basic_format_with_product_metadata_ssp_slug(self): + """Product metadata with ssp_product_slug should be preferred.""" + mock_price = self._create_mock_stripe_price() + mock_price.product.metadata = {'ssp_product_slug': 'meta-slug'} + + result = pricing_api._serialize_basic_format(mock_price) # pylint: disable=protected-access + + self.assertIn('ssp_product_slug', result) + self.assertEqual(result['ssp_product_slug'], 'meta-slug') + + def test_serialize_basic_format_model_fallback_for_ssp_slug(self): + """When product metadata lacks ssp_product_slug, fallback to SspProduct lookup_key.""" + # Create a model-backed SSP product to be discovered by lookup_key + SspProduct.objects.create( + slug='fallback_slug', + stripe_price_lookup_key='lookup_fallback', + is_active=True, + catalog_query_uuid=uuid.uuid4(), + ) + + mock_price = self._create_mock_stripe_price(lookup_key='lookup_fallback') + mock_price.product.metadata = {} + + result = pricing_api._serialize_basic_format(mock_price) # pylint: disable=protected-access + + self.assertIn('ssp_product_slug', result) + self.assertEqual(result['ssp_product_slug'], 'fallback_slug') + + def test_get_ssp_product_pricing_raises_on_missing_lookup_key(self): + """If an active SspProduct is missing lookup_key, raise StripePricingError.""" + # Add a product missing a lookup key + SspProduct.objects.create( + slug='bad_product', + stripe_price_lookup_key='', + is_active=True, + catalog_query_uuid=uuid.uuid4(), + ) + + with mock.patch('enterprise_access.apps.customer_billing.pricing_api.get_all_stripe_prices') as mock_all: + mock_all.return_value = {} + with self.assertRaises(pricing_api.StripePricingError): + pricing_api.get_ssp_product_pricing() + + def test_get_ssp_product_pricing_raises_when_lookup_key_not_found(self): + """If lookup_key for a product isn't present in Stripe prices, raise StripePricingError.""" + SspProduct.objects.create( + slug='missing_lookup', + stripe_price_lookup_key='no_such_lookup', + is_active=True, + catalog_query_uuid=uuid.uuid4(), + ) + + with mock.patch('enterprise_access.apps.customer_billing.pricing_api.get_all_stripe_prices') as mock_all: + mock_all.return_value = {} + with self.assertRaises(pricing_api.StripePricingError): + pricing_api.get_ssp_product_pricing() + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe.Price') + def test_get_stripe_price_data_non_stripe_exception(self, mock_stripe_price): + """General (non-Stripe) exceptions should be wrapped in StripePricingError.""" + mock_stripe_price.retrieve.side_effect = RuntimeError('connection reset') + + with self.assertRaises(pricing_api.StripePricingError) as cm: + pricing_api.get_stripe_price_data('price_999') + + self.assertIn('Unexpected error', str(cm.exception)) + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe') + def test_get_all_stripe_prices_basic(self, mock_stripe): + """Directly test get_all_stripe_prices returns a lookup_key mapping.""" + mock_recurring = mock.MagicMock() + mock_recurring.interval = 'year' + mock_recurring.interval_count = 1 + mock_recurring.usage_type = 'licensed' + + price = self._create_mock_stripe_price( + price_id='price_all_1', lookup_key='lk_all_1', recurring=mock_recurring, + ) + mock_stripe.Price.list.return_value.auto_paging_iter.return_value = [price] + + result = pricing_api.get_all_stripe_prices() + + self.assertIn('lk_all_1', result) + self.assertEqual(result['lk_all_1']['unit_amount'], 10000) + self.assertEqual(result['lk_all_1']['currency'], 'usd') + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe') + def test_get_all_stripe_prices_caching(self, mock_stripe): + """Second call to get_all_stripe_prices should return cached result.""" + mock_recurring = mock.MagicMock() + mock_recurring.interval = 'year' + mock_recurring.interval_count = 1 + mock_recurring.usage_type = 'licensed' + + price = self._create_mock_stripe_price( + price_id='price_cache', lookup_key='lk_cache', recurring=mock_recurring, + ) + mock_stripe.Price.list.return_value.auto_paging_iter.return_value = [price] + + result1 = pricing_api.get_all_stripe_prices() + + # Reset the mock so we can verify it is NOT called again + mock_stripe.Price.list.reset_mock() + + result2 = pricing_api.get_all_stripe_prices() + + self.assertEqual(result1, result2) + # Stripe should not be called again; the cache should serve the result + mock_stripe.Price.list.assert_not_called() + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe') + def test_get_all_stripe_prices_skips_non_recurring(self, mock_stripe): + """Non-recurring (one_time) prices should be skipped.""" + mock_recurring = mock.MagicMock() + mock_recurring.interval = 'year' + mock_recurring.interval_count = 1 + mock_recurring.usage_type = 'licensed' + + recurring_price = self._create_mock_stripe_price( + price_id='price_rec', lookup_key='lk_rec', recurring=mock_recurring, + ) + + one_time_price = self._create_mock_stripe_price( + price_id='price_ot', lookup_key='lk_ot', + ) + one_time_price.type = 'one_time' + + mock_stripe.Price.list.return_value.auto_paging_iter.return_value = [ + recurring_price, one_time_price, + ] + + result = pricing_api.get_all_stripe_prices() + + self.assertIn('lk_rec', result) + self.assertNotIn('lk_ot', result) + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe') + def test_get_all_stripe_prices_skips_missing_lookup_key(self, mock_stripe): + """Prices without a lookup_key should be skipped with a warning.""" + mock_recurring = mock.MagicMock() + mock_recurring.interval = 'year' + mock_recurring.interval_count = 1 + mock_recurring.usage_type = 'licensed' + + price_with_key = self._create_mock_stripe_price( + price_id='price_wk', lookup_key='lk_wk', recurring=mock_recurring, + ) + + price_no_key = self._create_mock_stripe_price( + price_id='price_nk', recurring=mock_recurring, + ) + price_no_key.lookup_key = None + + mock_stripe.Price.list.return_value.auto_paging_iter.return_value = [ + price_with_key, price_no_key, + ] + + result = pricing_api.get_all_stripe_prices() + + self.assertIn('lk_wk', result) + self.assertEqual(len(result), 1) + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe.Price') + def test_get_all_stripe_prices_stripe_error(self, mock_stripe_price): + """StripeError during Price.list should raise StripePricingError.""" + mock_stripe_price.list.side_effect = InvalidRequestError('bad request', 'param') + + with self.assertRaises(pricing_api.StripePricingError) as cm: + pricing_api.get_all_stripe_prices() + + self.assertIn('Failed to fetch all prices', str(cm.exception)) + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe.Price') + def test_get_all_stripe_prices_general_exception(self, mock_stripe_price): + """General exception during Price.list should raise StripePricingError.""" + mock_stripe_price.list.side_effect = RuntimeError('unexpected failure') + + with self.assertRaises(pricing_api.StripePricingError) as cm: + pricing_api.get_all_stripe_prices() + + self.assertIn('Unexpected error', str(cm.exception)) + + def test_calculate_subtotal_returns_none_on_error(self): + """Missing keys in price_data should cause calculate_subtotal to return None.""" + # Empty dict triggers KeyError for 'unit_amount' + result = pricing_api.calculate_subtotal({}, 5) + self.assertIsNone(result) + + def test_format_price_display_returns_unavailable_on_exception(self): + """Missing keys inside the try block should return 'Price unavailable'.""" + # currency matches so we enter the try block, but unit_amount_decimal is missing + price_data = {'currency': 'usd'} + + result = pricing_api.format_price_display(price_data) + self.assertEqual(result, 'Price unavailable') + + def test_validate_stripe_price_schema_invalid_currency_type(self): + """Non-string currency should raise StripePricingError.""" + mock_price = self._create_mock_stripe_price() + mock_price.currency = 12345 # not a string + + with self.assertRaises(pricing_api.StripePricingError) as cm: + pricing_api._validate_stripe_price_schema(mock_price) # pylint: disable=protected-access + + self.assertIn('Invalid currency type', str(cm.exception)) + + def test_validate_stripe_price_schema_missing_interval_count(self): + """Recurring price with missing interval_count should raise StripePricingError.""" + mock_recurring = mock.MagicMock() + mock_recurring.interval = 'month' + mock_recurring.interval_count = None # missing + mock_recurring.usage_type = 'licensed' + + mock_price = self._create_mock_stripe_price(recurring=mock_recurring) + + with self.assertRaises(pricing_api.StripePricingError) as cm: + pricing_api._validate_stripe_price_schema(mock_price) # pylint: disable=protected-access + + self.assertIn('Recurring price missing interval_count', str(cm.exception)) + + def test_get_ssp_product_pricing_raises_on_missing_lookup_key_slug_format(self): + """Ensure missing lookup_key exception correctly formats using the product slug.""" + SspProduct.objects.all().delete() + SspProduct.objects.create( + slug='bad_product_slug_test', + stripe_price_lookup_key='', + is_active=True, + catalog_query_uuid=uuid.uuid4(), + ) + + with mock.patch('enterprise_access.apps.customer_billing.pricing_api.get_all_stripe_prices') as mock_all: + mock_all.return_value = {} + with self.assertRaises(pricing_api.StripePricingError) as cm: + pricing_api.get_ssp_product_pricing() + + self.assertIn('SSP product bad_product_slug_test missing lookup_key', str(cm.exception)) + + def test_serialize_basic_format_metadata_exception_handled(self): + """Ensure exceptions raised during metadata access fallback safely to ssp_slug = None.""" + mock_price = self._create_mock_stripe_price() + mock_metadata = mock.MagicMock() + mock_metadata.get.side_effect = RuntimeError("Metadata completely broken") + mock_price.product.metadata = mock_metadata + + # pylint: disable=protected-access + result = pricing_api._serialize_basic_format(mock_price) + self.assertEqual(result.get('ssp_product_slug'), 'quarterly_license_plan') + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.SspProduct.objects.get') + def test_serialize_basic_format_database_exception_handled(self, mock_get): + """Ensure an unexpected DB exception during lookup fallback is swallowed and logged.""" + mock_price = self._create_mock_stripe_price(lookup_key='corrupt_lookup_key') + mock_price.product.metadata = {} + + mock_get.side_effect = Exception("Database connection dropped completely") + + # pylint: disable=protected-access + result = pricing_api._serialize_basic_format(mock_price) + self.assertNotIn('ssp_product_slug', result) + + @override_settings( + SSP_PRODUCTS={ + 'incomplete_plan_1': { + 'lookup_key': 'only_key_no_range', + }, + 'incomplete_plan_2': { + 'quantity_range': (10, 50), + }, + 'complete_plan': { + 'lookup_key': 'valid_lk_range', + 'quantity_range': (5, 100), + } + } + ) + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe') + def test_get_ssp_product_pricing_skips_invalid_settings_config(self, mock_stripe): + """Ensure settings blocks missing lookup_key or quantity_range are skipped gracefully.""" + SspProduct.objects.all().delete() + + # Create a valid active DB product pointing to the complete settings block + SspProduct.objects.create( + slug='complete_plan', + stripe_price_lookup_key='valid_lk_range', + is_active=True, + catalog_query_uuid=uuid.uuid4(), + ) + + mock_price = self._create_mock_stripe_price(lookup_key='valid_lk_range') + mock_stripe.Price.list().auto_paging_iter.return_value = [mock_price] + + result = pricing_api.get_ssp_product_pricing() + + # The complete plan should process properly + self.assertIn('complete_plan', result) + self.assertEqual(result['complete_plan'].get('quantity_range'), (5, 100)) + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe') + def test_get_ssp_product_pricing_ignores_inactive_ssp_products(self, mock_stripe): + """Ensure that SspProduct database filtering strictly targets is_active=True objects.""" + SspProduct.objects.all().delete() + + # Create an inactive product that matches valid settings + SspProduct.objects.create( + slug='quarterly_license_plan', + stripe_price_lookup_key=MOCK_SSP_PRODUCTS['quarterly_license_plan']['lookup_key'], + is_active=False, + catalog_query_uuid=uuid.uuid4(), + ) + + mock_price = self._create_mock_stripe_price() + mock_stripe.Price.list().auto_paging_iter.return_value = [mock_price] + + result = pricing_api.get_ssp_product_pricing() + + self.assertEqual(len(result), 0) + + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.logger') + @mock.patch('enterprise_access.apps.customer_billing.pricing_api.SspProduct.objects') + def test_serialize_basic_format_ssp_lookup_exception(self, mock_ssp_objects, mock_logger): + """ + Covers the except Exception branch when SspProduct lookup raises an error. + """ + mock_ssp_objects.filter.side_effect = RuntimeError('DB error') + mock_product = mock.MagicMock() + mock_product.metadata = {} + mock_price = mock.MagicMock() + mock_price.id = 'price_test' + mock_price.unit_amount = 1000 + mock_price.currency = 'usd' + mock_price.recurring = None + mock_price.lookup_key = 'test_lookup_key' + mock_price.product = mock_product + result = pricing_api._serialize_basic_format(mock_price) # pylint: disable=protected-access + self.assertIsNone(result.get('ssp_product_slug')) + mock_ssp_objects.filter.assert_called_once_with(stripe_price_lookup_key='test_lookup_key') + mock_logger.exception.assert_called_once_with( + 'Error looking up SspProduct for lookup_key %s', 'test_lookup_key' + ) From 40bf720aa9981bf16c79c07a2447fcf7b8267be2 Mon Sep 17 00:00:00 2001 From: tsunkara-sonata Date: Thu, 25 Jun 2026 05:48:26 +0000 Subject: [PATCH 2/6] feat: added addressed comments --- .../apps/api/serializers/customer_billing.py | 7 +-- .../v1/tests/test_checkout_intent_views.py | 9 ++- .../apps/customer_billing/api.py | 19 +++---- ...ter_checkoutintent_ssp_product_and_more.py | 25 --------- .../apps/customer_billing/models.py | 17 ++---- .../apps/customer_billing/pricing_api.py | 22 ++------ .../tests/test_pricing_api.py | 56 ++----------------- enterprise_access/settings/base.py | 2 +- 8 files changed, 32 insertions(+), 125 deletions(-) delete mode 100644 enterprise_access/apps/customer_billing/migrations/0038_alter_checkoutintent_ssp_product_and_more.py diff --git a/enterprise_access/apps/api/serializers/customer_billing.py b/enterprise_access/apps/api/serializers/customer_billing.py index 4b39eda0..fa188ba8 100644 --- a/enterprise_access/apps/api/serializers/customer_billing.py +++ b/enterprise_access/apps/api/serializers/customer_billing.py @@ -251,12 +251,7 @@ def create(self, validated_data): Creates a new CheckoutIntent. """ try: - ssp_product_slug = validated_data.pop('ssp_product', None) - ssp_product = None - if ssp_product_slug: - ssp_product = SspProduct.objects.filter( - slug=ssp_product_slug, is_active=True - ).first() + ssp_product = validated_data.pop('ssp_product', None) return CheckoutIntent.create_intent( user=self.context['request'].user, quantity=validated_data['quantity'], diff --git a/enterprise_access/apps/api/v1/tests/test_checkout_intent_views.py b/enterprise_access/apps/api/v1/tests/test_checkout_intent_views.py index 1114558e..6682d939 100644 --- a/enterprise_access/apps/api/v1/tests/test_checkout_intent_views.py +++ b/enterprise_access/apps/api/v1/tests/test_checkout_intent_views.py @@ -13,7 +13,7 @@ from enterprise_access.apps.core.constants import SYSTEM_ENTERPRISE_LEARNER_ROLE from enterprise_access.apps.core.tests.factories import UserFactory from enterprise_access.apps.customer_billing.constants import CheckoutIntentState -from enterprise_access.apps.customer_billing.models import CheckoutIntent +from enterprise_access.apps.customer_billing.models import CheckoutIntent, SspProduct from test_utils import APITest User = get_user_model() @@ -68,6 +68,13 @@ def setUpTestData(cls): def setUp(self): """Set up test data.""" super().setUp() + SspProduct.objects.get_or_create( + slug='teams-yearly', + defaults={ + 'stripe_price_lookup_key': 'teams_subscription_license_yearly', + 'is_active': True, + } + ) self.checkout_intent_1 = CheckoutIntent.objects.create( user=self.user, diff --git a/enterprise_access/apps/customer_billing/api.py b/enterprise_access/apps/customer_billing/api.py index f40c1925..464b17a6 100644 --- a/enterprise_access/apps/customer_billing/api.py +++ b/enterprise_access/apps/customer_billing/api.py @@ -5,6 +5,7 @@ from collections.abc import Mapping from typing import TypedDict, Unpack, cast +import settings import stripe from django.contrib.auth import get_user_model from django.contrib.auth.models import AbstractUser @@ -447,23 +448,19 @@ def create_free_trial_checkout_session( raise CreateCheckoutSessionValidationError(validation_errors_by_field=validation_errors) user = input_data['user'] - # ── Resolve slug ↔ price_id BEFORE creating intent ── - ssp_pricing = get_ssp_product_pricing() ssp_product_slug = input_data.get('ssp_product_slug') - stripe_price_id = input_data.get('stripe_price_id') - if not ssp_product_slug and stripe_price_id: - for candidate_slug, price_data in ssp_pricing.items(): - if price_data.get('id') == stripe_price_id: - ssp_product_slug = candidate_slug - break - if ssp_product_slug: - stripe_price_id = ssp_pricing[ssp_product_slug]['id'] - # ── Resolve SspProduct instance for FK ── + if not ssp_product_slug: + ssp_product_slug = getattr(settings, 'SSP_DEFAULT_PRODUCT_SLUG', None) + ssp_product = None + ssp_product_instance = None if ssp_product_slug: ssp_product_instance = SspProduct.objects.filter( slug=ssp_product_slug, is_active=True ).first() + stripe_price_id = input_data.get('stripe_price_id') + if not stripe_price_id and ssp_product_instance: + stripe_price_id = ssp_product_instance.stripe_price_lookup_key # Create checkout intent, which reserves the enterprise name & slug. try: intent = CheckoutIntent.create_intent( diff --git a/enterprise_access/apps/customer_billing/migrations/0038_alter_checkoutintent_ssp_product_and_more.py b/enterprise_access/apps/customer_billing/migrations/0038_alter_checkoutintent_ssp_product_and_more.py deleted file mode 100644 index f586d6f3..00000000 --- a/enterprise_access/apps/customer_billing/migrations/0038_alter_checkoutintent_ssp_product_and_more.py +++ /dev/null @@ -1,25 +0,0 @@ -# Generated by Django 5.2.13 on 2026-06-22 05:59 - -import django.db.models.deletion -import enterprise_access.apps.customer_billing.models -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('customer_billing', '0037_alter_checkoutintent_sspproduct_nonnull_and_unique'), - ] - - operations = [ - migrations.AlterField( - model_name='checkoutintent', - name='ssp_product', - field=models.ForeignKey(default=enterprise_access.apps.customer_billing.models.get_default_ssp_product_pk, help_text='The SSP product associated with this checkout intent.', on_delete=django.db.models.deletion.PROTECT, to='customer_billing.sspproduct'), - ), - migrations.AlterField( - model_name='historicalcheckoutintent', - name='ssp_product', - field=models.ForeignKey(blank=True, db_constraint=False, default=enterprise_access.apps.customer_billing.models.get_default_ssp_product_pk, help_text='The SSP product associated with this checkout intent.', null=True, on_delete=django.db.models.deletion.DO_NOTHING, related_name='+', to='customer_billing.sspproduct'), - ), - ] diff --git a/enterprise_access/apps/customer_billing/models.py b/enterprise_access/apps/customer_billing/models.py index b9178110..49bf435d 100644 --- a/enterprise_access/apps/customer_billing/models.py +++ b/enterprise_access/apps/customer_billing/models.py @@ -33,18 +33,9 @@ User = get_user_model() -def get_default_ssp_product_pk(): - """Return the PK of the default SSP product.""" - product, _ = SspProduct.objects.get_or_create( - slug=settings.SSP_DEFAULT_PRODUCT_SLUG, - defaults={ - 'stripe_price_lookup_key': 'teams_subscription_license_yearly', - 'catalog_query_uuid': uuid4(), - 'is_active': True, - } - ) - return product.pk - +def get_default_ssp_product_slug(): + """Return the default SSP product slug from Django settings.""" + return settings.SSP_DEFAULT_PRODUCT_SLUG class FailedCheckoutIntentConflict(Exception): pass @@ -312,7 +303,7 @@ class StateChoices(models.TextChoices): on_delete=models.PROTECT, null=False, blank=False, - default=get_default_ssp_product_pk, + default=get_default_ssp_product_slug, help_text='The SSP product associated with this checkout intent.', ) terms_metadata = models.JSONField( diff --git a/enterprise_access/apps/customer_billing/pricing_api.py b/enterprise_access/apps/customer_billing/pricing_api.py index caddccd9..9eca5b78 100644 --- a/enterprise_access/apps/customer_billing/pricing_api.py +++ b/enterprise_access/apps/customer_billing/pricing_api.py @@ -265,12 +265,7 @@ def get_ssp_product_pricing() -> Dict[str, Dict]: all_stripe_prices = get_all_stripe_prices() ssp_pricing = {} - settings_quantity_ranges = {} - for product_config in getattr(settings, 'SSP_PRODUCTS', {}).values(): - lk = product_config.get('lookup_key') - qr = product_config.get('quantity_range') - if lk and qr: - settings_quantity_ranges[lk] = qr + default_quantity_range = getattr(settings, 'DEFAULT_SSP_QUANTITY_RANGE', [5, 50]) for ssp_product in SspProduct.objects.filter(is_active=True): lookup_key = ssp_product.stripe_price_lookup_key if not lookup_key: @@ -289,7 +284,7 @@ def get_ssp_product_pricing() -> Dict[str, Dict]: price_data = all_stripe_prices[lookup_key].copy() # Add SSP-specific metadata price_data['ssp_product_key'] = ssp_product.slug - price_data['quantity_range'] = settings_quantity_ranges.get(lookup_key) + price_data['quantity_range'] = default_quantity_range ssp_pricing[ssp_product.slug] = price_data return ssp_pricing @@ -444,14 +439,9 @@ def _serialize_basic_format(stripe_price: stripe.Price) -> SerializedPriceData: if not ssp_slug: lookup_key = getattr(stripe_price, 'lookup_key', None) if lookup_key: - try: - ssp = SspProduct.objects.filter(stripe_price_lookup_key=lookup_key).only('slug').first() - if ssp: - ssp_slug = ssp.slug - except Exception: # pylint: disable=broad-exception-caught - logger.exception('Error looking up SspProduct for lookup_key %s', lookup_key) - - if ssp_slug: - base_data['ssp_product_slug'] = ssp_slug + ssp = SspProduct.objects.filter(stripe_price_lookup_key=lookup_key).only('slug').first() + ssp_slug = ssp.slug if ssp else None + + base_data['ssp_product_slug'] = ssp_slug return base_data diff --git a/enterprise_access/apps/customer_billing/tests/test_pricing_api.py b/enterprise_access/apps/customer_billing/tests/test_pricing_api.py index 8dd76b03..937c5340 100644 --- a/enterprise_access/apps/customer_billing/tests/test_pricing_api.py +++ b/enterprise_access/apps/customer_billing/tests/test_pricing_api.py @@ -17,12 +17,12 @@ 'quarterly_license_plan': { 'stripe_price_id': 'price_test_quarterly', # DEPRECATED: Use lookup_key instead 'lookup_key': 'price_quarterly_0002', - 'quantity_range': (5, 30), + 'quantity_range': [5, 50], }, 'yearly_license_plan': { 'stripe_price_id': 'price_test_yearly', # DEPRECATED: Use lookup_key instead 'lookup_key': 'price_yearly_0001', - 'quantity_range': (5, 30), + 'quantity_range': [5, 50], }, } @@ -199,7 +199,7 @@ def test_get_ssp_product_pricing(self, mock_stripe): # Check that SSP-specific metadata is added and quantity_range is sourced from settings quarterly_data = result['quarterly_license_plan'] self.assertEqual(quarterly_data['ssp_product_key'], 'quarterly_license_plan') - self.assertEqual(quarterly_data.get('quantity_range'), (5, 30)) + self.assertEqual(quarterly_data.get('quantity_range'), [5, 50]) def test_calculate_subtotal_basic_format(self): """Test subtotal calculation with basic format.""" @@ -696,32 +696,6 @@ def test_serialize_basic_format_metadata_exception_handled(self): result = pricing_api._serialize_basic_format(mock_price) self.assertEqual(result.get('ssp_product_slug'), 'quarterly_license_plan') - @mock.patch('enterprise_access.apps.customer_billing.pricing_api.SspProduct.objects.get') - def test_serialize_basic_format_database_exception_handled(self, mock_get): - """Ensure an unexpected DB exception during lookup fallback is swallowed and logged.""" - mock_price = self._create_mock_stripe_price(lookup_key='corrupt_lookup_key') - mock_price.product.metadata = {} - - mock_get.side_effect = Exception("Database connection dropped completely") - - # pylint: disable=protected-access - result = pricing_api._serialize_basic_format(mock_price) - self.assertNotIn('ssp_product_slug', result) - - @override_settings( - SSP_PRODUCTS={ - 'incomplete_plan_1': { - 'lookup_key': 'only_key_no_range', - }, - 'incomplete_plan_2': { - 'quantity_range': (10, 50), - }, - 'complete_plan': { - 'lookup_key': 'valid_lk_range', - 'quantity_range': (5, 100), - } - } - ) @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe') def test_get_ssp_product_pricing_skips_invalid_settings_config(self, mock_stripe): """Ensure settings blocks missing lookup_key or quantity_range are skipped gracefully.""" @@ -742,7 +716,7 @@ def test_get_ssp_product_pricing_skips_invalid_settings_config(self, mock_stripe # The complete plan should process properly self.assertIn('complete_plan', result) - self.assertEqual(result['complete_plan'].get('quantity_range'), (5, 100)) + self.assertEqual(result['complete_plan'].get('quantity_range'), [5, 50]) @mock.patch('enterprise_access.apps.customer_billing.pricing_api.stripe') def test_get_ssp_product_pricing_ignores_inactive_ssp_products(self, mock_stripe): @@ -764,25 +738,3 @@ def test_get_ssp_product_pricing_ignores_inactive_ssp_products(self, mock_stripe self.assertEqual(len(result), 0) - @mock.patch('enterprise_access.apps.customer_billing.pricing_api.logger') - @mock.patch('enterprise_access.apps.customer_billing.pricing_api.SspProduct.objects') - def test_serialize_basic_format_ssp_lookup_exception(self, mock_ssp_objects, mock_logger): - """ - Covers the except Exception branch when SspProduct lookup raises an error. - """ - mock_ssp_objects.filter.side_effect = RuntimeError('DB error') - mock_product = mock.MagicMock() - mock_product.metadata = {} - mock_price = mock.MagicMock() - mock_price.id = 'price_test' - mock_price.unit_amount = 1000 - mock_price.currency = 'usd' - mock_price.recurring = None - mock_price.lookup_key = 'test_lookup_key' - mock_price.product = mock_product - result = pricing_api._serialize_basic_format(mock_price) # pylint: disable=protected-access - self.assertIsNone(result.get('ssp_product_slug')) - mock_ssp_objects.filter.assert_called_once_with(stripe_price_lookup_key='test_lookup_key') - mock_logger.exception.assert_called_once_with( - 'Error looking up SspProduct for lookup_key %s', 'test_lookup_key' - ) diff --git a/enterprise_access/settings/base.py b/enterprise_access/settings/base.py index 7d419360..c54fb4a3 100644 --- a/enterprise_access/settings/base.py +++ b/enterprise_access/settings/base.py @@ -683,7 +683,7 @@ def root(*path_fragments): ENABLE_BILLING_MANAGEMENT_API = False DEFAULT_SSP_PRICE_LOOKUP_KEY = 'teams_subscription_license_yearly' - +DEFAULT_SSP_QUANTITY_RANGE = [5, 50] # Default SSP product slug assigned to new CheckoutIntent records when no product is specified. # Override this in environment-specific settings to change the default product. SSP_DEFAULT_PRODUCT_SLUG = 'teams-yearly' From dd3ea68ad3ee11a9613c2df21af39f7caf12f820 Mon Sep 17 00:00:00 2001 From: tsunkara-sonata Date: Thu, 25 Jun 2026 07:18:39 +0000 Subject: [PATCH 3/6] feat: updated test and quality issues --- .../apps/api/serializers/customer_billing.py | 1 - .../v1/tests/test_checkout_intent_views.py | 47 ++++++++++++++----- .../api/v1/tests/test_customer_billing.py | 1 + .../apps/customer_billing/api.py | 3 +- .../apps/customer_billing/models.py | 5 +- .../tests/test_pricing_api.py | 1 - 6 files changed, 40 insertions(+), 18 deletions(-) diff --git a/enterprise_access/apps/api/serializers/customer_billing.py b/enterprise_access/apps/api/serializers/customer_billing.py index fa188ba8..7ed77b7b 100644 --- a/enterprise_access/apps/api/serializers/customer_billing.py +++ b/enterprise_access/apps/api/serializers/customer_billing.py @@ -15,7 +15,6 @@ CheckoutIntent, FailedCheckoutIntentConflict, SlugReservationConflict, - SspProduct, StripeEventSummary ) diff --git a/enterprise_access/apps/api/v1/tests/test_checkout_intent_views.py b/enterprise_access/apps/api/v1/tests/test_checkout_intent_views.py index 6682d939..cce90532 100644 --- a/enterprise_access/apps/api/v1/tests/test_checkout_intent_views.py +++ b/enterprise_access/apps/api/v1/tests/test_checkout_intent_views.py @@ -42,6 +42,15 @@ def setUpTestData(cls): email='test4@example.com', password='testpass123' ) + # Ensure the default SSP product exists for class-level CheckoutIntent creation + SspProduct.objects.get_or_create( + slug='teams-yearly', + defaults={ + 'stripe_price_lookup_key': 'teams_subscription_license_yearly', + 'is_active': True, + 'catalog_query_uuid': uuid.uuid4(), + } + ) cls.checkout_intent_2 = CheckoutIntent.objects.create( user=cls.user_2, enterprise_name="Active Enterprise 2", @@ -51,7 +60,8 @@ def setUpTestData(cls): expires_at=timezone.now() + timedelta(minutes=30), stripe_checkout_session_id='cs_test_456', country='US', - terms_metadata={'version': '1.0', 'accepted_at': '2024-01-15T10:30:00Z'} + terms_metadata={'version': '1.0', 'accepted_at': '2024-01-15T10:30:00Z'}, + ssp_product_id='teams-yearly', ) cls.checkout_intent_4 = CheckoutIntent.objects.create( user=cls.user_4, @@ -62,7 +72,8 @@ def setUpTestData(cls): expires_at=timezone.now() + timedelta(minutes=30), stripe_checkout_session_id='cs_test_987', country='US', - terms_metadata={'version': '1.0', 'accepted_at': '2024-01-15T10:30:00Z'} + terms_metadata={'version': '1.0', 'accepted_at': '2024-01-15T10:30:00Z'}, + ssp_product_id='teams-yearly', ) def setUp(self): @@ -73,6 +84,7 @@ def setUp(self): defaults={ 'stripe_price_lookup_key': 'teams_subscription_license_yearly', 'is_active': True, + 'catalog_query_uuid': uuid.uuid4(), } ) @@ -85,7 +97,8 @@ def setUp(self): expires_at=timezone.now() + timedelta(minutes=30), stripe_checkout_session_id='cs_test_123', country='CA', - terms_metadata={'version': '1.1', 'test_mode': True} + terms_metadata={'version': '1.1', 'test_mode': True}, + ssp_product_id='teams-yearly', ) self.checkout_intent_3 = CheckoutIntent.objects.create( user=self.user_3, @@ -96,7 +109,8 @@ def setUp(self): expires_at=timezone.now() + timedelta(minutes=30), stripe_checkout_session_id='cs_test_789', country='GB', - terms_metadata={'version': '2.0', 'features': ['analytics', 'reporting']} + terms_metadata={'version': '2.0', 'features': ['analytics', 'reporting']}, + ssp_product_id='teams-yearly', ) # URL patterns @@ -274,7 +288,8 @@ def test_cannot_transition_from_fulfilled(self): expires_at=timezone.now() + timedelta(minutes=30), stripe_checkout_session_id='cs_test_78955', country='FR', - terms_metadata={'version': '1.5', 'fulfilled': True} + terms_metadata={'version': '1.5', 'fulfilled': True}, + ssp_product_id='teams-yearly', ) detail_url = reverse( @@ -345,7 +360,8 @@ def test_create_checkout_intent_success(self): 'enterprise_name': 'Test Enterprise post', 'quantity': 13, 'country': 'NZ', - 'terms_metadata': {'version': '1.0', 'accepted_at': '2024-01-15T10:30:00Z'} + 'terms_metadata': {'version': '1.0', 'accepted_at': '2024-01-15T10:30:00Z'}, + 'ssp_product': 'teams-yearly', } response = self.client.post( @@ -375,7 +391,8 @@ def test_create_or_update_checkout_intent_success(self): 'enterprise_name': self.checkout_intent_1.enterprise_name, 'quantity': 33, 'country': 'IT', - 'terms_metadata': {'version': '2.0', 'updated': True} + 'terms_metadata': {'version': '2.0', 'updated': True}, + 'ssp_product': 'teams-yearly', } response = self.client.post( @@ -419,7 +436,7 @@ def test_create_checkout_intent_invalid_field_values(self, **invalid_payload): response = self.client.post( self.list_url, - invalid_payload, + {**invalid_payload, 'ssp_product': 'teams-yearly'}, format='json' ) @@ -441,7 +458,7 @@ def test_create_checkout_intent_missing_required_fields(self, **payload): # Test missing enterprise_slug response = self.client.post( self.list_url, - payload, + {**payload, 'ssp_product': 'teams-yearly'}, ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) error_detail = list(response.json().values())[0][0] @@ -455,6 +472,7 @@ def test_create_checkout_intent_authentication_required(self): 'enterprise_slug': 'test-enterprise', 'enterprise_name': 'Test Enterprise', 'quantity': 10, + 'ssp_product': 'teams-yearly', }, ) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) @@ -525,7 +543,8 @@ def test_create_with_null_terms_metadata(self): 'enterprise_slug': 'test-enterprise-null', 'enterprise_name': 'Test Enterprise Null', 'quantity': 5, - 'terms_metadata': None + 'terms_metadata': None, + 'ssp_product': 'teams-yearly', } response = self.client.post( @@ -549,7 +568,8 @@ def test_create_with_empty_terms_metadata(self): 'enterprise_slug': 'test-enterprise-empty', 'enterprise_name': 'Test Enterprise Empty', 'quantity': 8, - 'terms_metadata': {} + 'terms_metadata': {}, + 'ssp_product': 'teams-yearly', } response = self.client.post( @@ -574,7 +594,8 @@ def test_create_checkout_intent_without_slug_or_name_success(self): request_data = { 'quantity': 13, 'country': 'NZ', - 'terms_metadata': {'version': '1.0', 'accepted_at': '2024-01-15T10:30:00Z'} + 'terms_metadata': {'version': '1.0', 'accepted_at': '2024-01-15T10:30:00Z'}, + 'ssp_product': 'teams-yearly', } response = self.client.post( @@ -610,6 +631,7 @@ def test_create_checkout_intent_already_failed_returns_422(self): 'enterprise_slug': 'new-slug', 'enterprise_name': 'New Name', 'quantity': 7, + 'ssp_product': 'teams-yearly', } response = self.client.post(self.list_url, request_data, format='json') self.assertEqual(response.status_code, status.HTTP_422_UNPROCESSABLE_ENTITY) @@ -634,6 +656,7 @@ def test_create_checkout_intent_slug_conflict_returns_422(self): 'enterprise_slug': 'active-enterprise', 'enterprise_name': 'Active Enterprise', 'quantity': 7, + 'ssp_product': 'teams-yearly', } response = self.client.post(self.list_url, request_data, format='json') self.assertEqual(response.status_code, status.HTTP_422_UNPROCESSABLE_ENTITY) diff --git a/enterprise_access/apps/api/v1/tests/test_customer_billing.py b/enterprise_access/apps/api/v1/tests/test_customer_billing.py index c98ca1d9..46036072 100644 --- a/enterprise_access/apps/api/v1/tests/test_customer_billing.py +++ b/enterprise_access/apps/api/v1/tests/test_customer_billing.py @@ -3913,6 +3913,7 @@ def test_create_checkout_session_returns_client_secret_from_dict( 'company_name': 'Test Co', 'quantity': 5, 'stripe_price_id': 'price_abc123', + 'ssp_product': 'quarterly_license_plan', }, format='json', ) diff --git a/enterprise_access/apps/customer_billing/api.py b/enterprise_access/apps/customer_billing/api.py index 464b17a6..92ce8c2e 100644 --- a/enterprise_access/apps/customer_billing/api.py +++ b/enterprise_access/apps/customer_billing/api.py @@ -451,8 +451,7 @@ def create_free_trial_checkout_session( ssp_product_slug = input_data.get('ssp_product_slug') if not ssp_product_slug: ssp_product_slug = getattr(settings, 'SSP_DEFAULT_PRODUCT_SLUG', None) - ssp_product = None - + ssp_product_instance = None if ssp_product_slug: ssp_product_instance = SspProduct.objects.filter( diff --git a/enterprise_access/apps/customer_billing/models.py b/enterprise_access/apps/customer_billing/models.py index 49bf435d..0bf2a7bb 100644 --- a/enterprise_access/apps/customer_billing/models.py +++ b/enterprise_access/apps/customer_billing/models.py @@ -34,8 +34,9 @@ def get_default_ssp_product_slug(): - """Return the default SSP product slug from Django settings.""" - return settings.SSP_DEFAULT_PRODUCT_SLUG + """Return the default SSP product slug from Django settings.""" + return settings.SSP_DEFAULT_PRODUCT_SLUG + class FailedCheckoutIntentConflict(Exception): pass diff --git a/enterprise_access/apps/customer_billing/tests/test_pricing_api.py b/enterprise_access/apps/customer_billing/tests/test_pricing_api.py index 937c5340..b53bdbeb 100644 --- a/enterprise_access/apps/customer_billing/tests/test_pricing_api.py +++ b/enterprise_access/apps/customer_billing/tests/test_pricing_api.py @@ -737,4 +737,3 @@ def test_get_ssp_product_pricing_ignores_inactive_ssp_products(self, mock_stripe result = pricing_api.get_ssp_product_pricing() self.assertEqual(len(result), 0) - From 2b34fb8c2db79c9f64ba2cf3c401dca4a21062e1 Mon Sep 17 00:00:00 2001 From: tsunkara-sonata Date: Thu, 25 Jun 2026 08:38:30 +0000 Subject: [PATCH 4/6] feat: code coverage fix --- .../apps/customer_billing/tests/test_api.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/enterprise_access/apps/customer_billing/tests/test_api.py b/enterprise_access/apps/customer_billing/tests/test_api.py index a8c8e501..c45d1526 100644 --- a/enterprise_access/apps/customer_billing/tests/test_api.py +++ b/enterprise_access/apps/customer_billing/tests/test_api.py @@ -1,6 +1,7 @@ """ Unit tests for the ``enterprise_access.apps.customer_billing.api`` module. """ +import uuid from datetime import timedelta from unittest import mock @@ -18,7 +19,8 @@ from enterprise_access.apps.customer_billing.models import ( CheckoutIntent, FailedCheckoutIntentConflict, - SlugReservationConflict + SlugReservationConflict, + SspProduct ) User = get_user_model() @@ -159,6 +161,14 @@ def test_create_free_trial_checkout_session_with_ssp_product_slug( mock_create_checkout.return_value = {'id': 'sess-ssp', 'customer': 'cust-ssp'} # Call with ssp_product_slug instead of stripe_price_id + # Ensure an SspProduct exists so the code path that maps ssp->stripe_price is exercised + SspProduct.objects.create( + slug='quarterly_license_plan', + stripe_price_lookup_key=MOCK_SSP_PRODUCTS['quarterly_license_plan']['lookup_key'], + is_active=True, + catalog_query_uuid=uuid.uuid4(), + ) + result = customer_billing_api.create_free_trial_checkout_session( user=self.user, admin_email=self.user.email, @@ -176,6 +186,11 @@ def test_create_free_trial_checkout_session_with_ssp_product_slug( # Ensure the checkout creator received the provided ssp_product_slug called_input = mock_create_checkout.call_args[1]['input_data'] self.assertEqual(called_input.get('ssp_product_slug'), 'quarterly_license_plan') + # And stripe_price_id should be derived from the SspProduct lookup key when not provided + self.assertEqual( + called_input.get('stripe_price_id'), + MOCK_SSP_PRODUCTS['quarterly_license_plan']['lookup_key'], + ) # Assert library methods were called correctly. mock_lms_client.get_lms_user_account.assert_called_once_with( From 907200a7136b574f27ef18f8b28e1b61d7758bbb Mon Sep 17 00:00:00 2001 From: tsunkara-sonata Date: Thu, 25 Jun 2026 13:28:33 +0000 Subject: [PATCH 5/6] feat: updated checkout intent --- .../apps/api/serializers/customer_billing.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/enterprise_access/apps/api/serializers/customer_billing.py b/enterprise_access/apps/api/serializers/customer_billing.py index 7ed77b7b..457d7e15 100644 --- a/enterprise_access/apps/api/serializers/customer_billing.py +++ b/enterprise_access/apps/api/serializers/customer_billing.py @@ -15,6 +15,7 @@ CheckoutIntent, FailedCheckoutIntentConflict, SlugReservationConflict, + SspProduct, StripeEventSummary ) @@ -202,6 +203,13 @@ class CheckoutIntentCreateRequestSerializer(CountryFieldMixin, serializers.Model """ A serializer intended for creating new CheckoutIntents. """ + + ssp_product = serializers.PrimaryKeyRelatedField( + queryset=SspProduct.objects.all(), + required=False, + allow_null=True, + ) + class Meta: model = CheckoutIntent fields = '__all__' @@ -251,6 +259,12 @@ def create(self, validated_data): """ try: ssp_product = validated_data.pop('ssp_product', None) + if ssp_product is None: + default_slug = getattr(settings, 'SSP_DEFAULT_PRODUCT_SLUG', None) + if default_slug: + ssp_product = SspProduct.objects.filter( + slug=default_slug, is_active=True + ).first() return CheckoutIntent.create_intent( user=self.context['request'].user, quantity=validated_data['quantity'], From 7145935e700832889bdbf83cf3f91d526904f075 Mon Sep 17 00:00:00 2001 From: tsunkara-sonata Date: Thu, 25 Jun 2026 14:58:33 +0000 Subject: [PATCH 6/6] feat: cleanup in customer serializer --- enterprise_access/apps/api/serializers/customer_billing.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/enterprise_access/apps/api/serializers/customer_billing.py b/enterprise_access/apps/api/serializers/customer_billing.py index 457d7e15..5b8fccd2 100644 --- a/enterprise_access/apps/api/serializers/customer_billing.py +++ b/enterprise_access/apps/api/serializers/customer_billing.py @@ -259,12 +259,6 @@ def create(self, validated_data): """ try: ssp_product = validated_data.pop('ssp_product', None) - if ssp_product is None: - default_slug = getattr(settings, 'SSP_DEFAULT_PRODUCT_SLUG', None) - if default_slug: - ssp_product = SspProduct.objects.filter( - slug=default_slug, is_active=True - ).first() return CheckoutIntent.create_intent( user=self.context['request'].user, quantity=validated_data['quantity'],