diff --git a/src/ad_buyer/booking/pricing.py b/src/ad_buyer/booking/pricing.py index 9bb470f..3cfdfb1 100644 --- a/src/ad_buyer/booking/pricing.py +++ b/src/ad_buyer/booking/pricing.py @@ -94,14 +94,16 @@ class PricingCalculator: # Volume discount thresholds (only for agency/advertiser tiers) VOLUME_DISCOUNT_THRESHOLDS: list[tuple[int, float]] = [ (10_000_000, 10.0), # 10M+ impressions: 10% discount - (5_000_000, 5.0), # 5M+ impressions: 5% discount + (5_000_000, 5.0), # 5M+ impressions: 5% discount ] # Tiers eligible for volume discounts - VOLUME_ELIGIBLE_TIERS: frozenset[AccessTier] = frozenset({ - AccessTier.AGENCY, - AccessTier.ADVERTISER, - }) + VOLUME_ELIGIBLE_TIERS: frozenset[AccessTier] = frozenset( + { + AccessTier.AGENCY, + AccessTier.ADVERTISER, + } + ) def calculate( self, diff --git a/src/ad_buyer/clients/capability_client.py b/src/ad_buyer/clients/capability_client.py index 39b5d79..43750b2 100644 --- a/src/ad_buyer/clients/capability_client.py +++ b/src/ad_buyer/clients/capability_client.py @@ -219,9 +219,7 @@ def invalidate(self, seller_endpoint: str | None = None) -> None: return self._cache.pop(self._cache_key(seller_endpoint), None) - async def discover_capabilities( - self, seller_endpoint: str - ) -> CapabilityDiscoveryResult: + async def discover_capabilities(self, seller_endpoint: str) -> CapabilityDiscoveryResult: """Discover a seller's audience capabilities. Hits the cache first, returns immediately on a fresh hit. On a @@ -275,8 +273,7 @@ async def discover_capabilities( await client.aclose() except (httpx.HTTPError, ValueError) as exc: logger.warning( - "capability_client fetch failed endpoint=%s err=%s -- " - "treating as legacy", + "capability_client fetch failed endpoint=%s err=%s -- treating as legacy", seller_endpoint, exc, ) @@ -288,8 +285,7 @@ async def discover_capabilities( if response.status_code != 200: logger.warning( - "capability_client non-200 endpoint=%s status=%d -- " - "treating as legacy", + "capability_client non-200 endpoint=%s status=%d -- treating as legacy", seller_endpoint, response.status_code, ) @@ -304,8 +300,7 @@ async def discover_capabilities( payload = response.json() except ValueError as exc: logger.warning( - "capability_client invalid JSON endpoint=%s err=%s -- " - "treating as legacy", + "capability_client invalid JSON endpoint=%s err=%s -- treating as legacy", seller_endpoint, exc, ) @@ -322,8 +317,7 @@ async def discover_capabilities( # fallback -- standard segments only, no constraints, no # extensions, no exclusions, no agentic. logger.info( - "capability_client legacy seller (no audience_capabilities) " - "endpoint=%s", + "capability_client legacy seller (no audience_capabilities) endpoint=%s", seller_endpoint, ) caps = _legacy_default_capabilities() @@ -356,8 +350,7 @@ async def discover_capabilities( self._store(key, caps, fetched_at=now, max_age=max_age) logger.info( - "capability_client %s endpoint=%s schema=%s agentic=%s " - "supports=(c=%s,e=%s,x=%s)", + "capability_client %s endpoint=%s schema=%s agentic=%s supports=(c=%s,e=%s,x=%s)", cache_status, seller_endpoint, caps.schema_version, diff --git a/src/ad_buyer/clients/deals_client.py b/src/ad_buyer/clients/deals_client.py index 14bc661..3a14b17 100644 --- a/src/ad_buyer/clients/deals_client.py +++ b/src/ad_buyer/clients/deals_client.py @@ -62,9 +62,7 @@ # Code-internal naming continues to use `ucp_*` (no rename per §5.6 lock). _UCP_CONTENT_TYPE = "application/vnd.ucp.embedding+json; v=1" _AGENTIC_AUDIENCES_CONTENT_TYPE = "application/vnd.iab.agentic-audiences+json; v=1" -_AUDIENCE_PLAN_ACCEPT = ( - f"{_UCP_CONTENT_TYPE}, {_AGENTIC_AUDIENCES_CONTENT_TYPE}" -) +_AUDIENCE_PLAN_ACCEPT = f"{_UCP_CONTENT_TYPE}, {_AGENTIC_AUDIENCES_CONTENT_TYPE}" class DealsClientError(Exception): @@ -389,7 +387,10 @@ async def _request_with_retry( if attempt < self._max_retries: logger.warning( "Timeout on attempt %d/%d for %s %s", - attempt, self._max_retries, method, path, + attempt, + self._max_retries, + method, + path, ) continue raise last_error from exc @@ -416,7 +417,11 @@ async def _request_with_retry( if attempt < self._max_retries: logger.warning( "Retryable error %d on attempt %d/%d for %s %s", - response.status_code, attempt, self._max_retries, method, path, + response.status_code, + attempt, + self._max_retries, + method, + path, ) continue raise last_error @@ -460,16 +465,10 @@ def _build_error_from_response(response: httpx.Response) -> DealsClientError: if isinstance(inner, dict): error_code = str(inner.get("error", "") or "") # Surface the inner "message" / "detail" / repr for humans. - detail = str( - inner.get("message") - or inner.get("detail") - or "" - ) + detail = str(inner.get("message") or inner.get("detail") or "") raw_unsupported = inner.get("unsupported") if isinstance(raw_unsupported, list): - unsupported = [ - u for u in raw_unsupported if isinstance(u, dict) - ] + unsupported = [u for u in raw_unsupported if isinstance(u, dict)] else: # Flat shape: {"error": "...", "detail": "..."} error_code = str(data.get("error", "") or "") @@ -511,15 +510,19 @@ def _persist_quote(self, quote: QuoteResponse, request: QuoteRequest) -> None: deal_type=request.deal_type, status="quoted", price=quote.pricing.final_cpm if quote.pricing.final_cpm is not None else 0.0, - original_price=quote.pricing.base_cpm if quote.pricing.base_cpm is not None else 0.0, # noqa: E501 + original_price=quote.pricing.base_cpm + if quote.pricing.base_cpm is not None + else 0.0, # noqa: E501 impressions=quote.terms.impressions, flight_start=quote.terms.flight_start, flight_end=quote.terms.flight_end, - metadata=json.dumps({ - "quote_id": quote.quote_id, - "buyer_tier": quote.buyer_tier, - "expires_at": quote.expires_at, - }), + metadata=json.dumps( + { + "quote_id": quote.quote_id, + "buyer_tier": quote.buyer_tier, + "expires_at": quote.expires_at, + } + ), ) except (sqlite3.Error, OSError, ValueError, AttributeError): logger.exception("Failed to persist quote %s to DealStore", quote.quote_id) @@ -544,15 +547,17 @@ def _persist_deal(self, deal: DealResponse) -> None: impressions=deal.terms.impressions, flight_start=deal.terms.flight_start, flight_end=deal.terms.flight_end, - metadata=json.dumps({ - "quote_id": deal.quote_id, - "buyer_tier": deal.buyer_tier, - "expires_at": deal.expires_at, - "activation_instructions": deal.activation_instructions, - "openrtb_params": ( - deal.openrtb_params.model_dump() if deal.openrtb_params else None - ), - }), + metadata=json.dumps( + { + "quote_id": deal.quote_id, + "buyer_tier": deal.buyer_tier, + "expires_at": deal.expires_at, + "activation_instructions": deal.activation_instructions, + "openrtb_params": ( + deal.openrtb_params.model_dump() if deal.openrtb_params else None + ), + } + ), ) except (sqlite3.Error, OSError, ValueError, AttributeError): logger.exception("Failed to persist deal %s to DealStore", deal.deal_id) @@ -577,6 +582,4 @@ def _update_stored_deal_status(self, deal: DealResponse) -> None: ) break except (sqlite3.Error, OSError, ValueError, AttributeError): - logger.exception( - "Failed to update stored deal status for %s", deal.deal_id - ) + logger.exception("Failed to update stored deal status for %s", deal.deal_id) diff --git a/src/ad_buyer/clients/openrtb_builder.py b/src/ad_buyer/clients/openrtb_builder.py index 9c6982d..06031be 100644 --- a/src/ad_buyer/clients/openrtb_builder.py +++ b/src/ad_buyer/clients/openrtb_builder.py @@ -207,8 +207,7 @@ def build_openrtb_audience_targeting( if agentic_refs: if not enable_agentic_ext: logger.warning( - "openrtb_builder skipping agentic refs: " - "enable_agentic_openrtb_ext flag disabled", + "openrtb_builder skipping agentic refs: enable_agentic_openrtb_ext flag disabled", extra={ "openrtb_drop": { "reason": "agentic_ext_feature_flag_disabled", diff --git a/src/ad_buyer/clients/seller_order_client.py b/src/ad_buyer/clients/seller_order_client.py index 8aa9a81..daec7f1 100644 --- a/src/ad_buyer/clients/seller_order_client.py +++ b/src/ad_buyer/clients/seller_order_client.py @@ -113,7 +113,5 @@ async def get_order_history(self, order_id: str) -> dict[str, Any] | None: ) return None except (httpx.RequestError, OSError) as e: - logger.error( - "Failed to reach seller for order %s history: %s", order_id, e - ) + logger.error("Failed to reach seller for order %s history: %s", order_id, e) return None diff --git a/src/ad_buyer/clients/ucp_client.py b/src/ad_buyer/clients/ucp_client.py index 414aa9b..0c0d5ce 100644 --- a/src/ad_buyer/clients/ucp_client.py +++ b/src/ad_buyer/clients/ucp_client.py @@ -37,9 +37,7 @@ UCP_CONTENT_TYPE = "application/vnd.ucp.embedding+json; v=1" # Embedding provenance literal -- mirrors ComplianceContext.embedding_provenance. -EmbeddingProvenance = Literal[ - "mock", "local_buyer", "advertiser_supplied", "hosted_external" -] +EmbeddingProvenance = Literal["mock", "local_buyer", "advertiser_supplied", "hosted_external"] # Local model details for "local" / "hybrid" embedding modes. # Locked in docs/decisions/EMBEDDING_STRATEGY_2026-04-25.md (E2-1). @@ -54,10 +52,10 @@ # follow the same convention as the buyer's local model. Re-derive these # from `ad_buyer.eval.evaluate_embedding_modes()` whenever the model swaps. _SIMILARITY_THRESHOLDS: dict[str, dict[str, float]] = { - "mock": {"strong": 0.85, "moderate": 0.65, "weak": 0.40}, - "local": {"strong": 0.70, "moderate": 0.50, "weak": 0.30}, - "advertiser": {"strong": 0.70, "moderate": 0.50, "weak": 0.30}, - "hybrid": {"strong": 0.70, "moderate": 0.50, "weak": 0.30}, + "mock": {"strong": 0.85, "moderate": 0.65, "weak": 0.40}, + "local": {"strong": 0.70, "moderate": 0.50, "weak": 0.30}, + "advertiser": {"strong": 0.70, "moderate": 0.50, "weak": 0.30}, + "hybrid": {"strong": 0.70, "moderate": 0.50, "weak": 0.30}, } _DEFAULT_THRESHOLDS = _SIMILARITY_THRESHOLDS["mock"] @@ -69,6 +67,7 @@ def _similarity_thresholds_for_mode() -> dict[str, float]: return _SIMILARITY_THRESHOLDS.get(settings.embedding_mode, _DEFAULT_THRESHOLDS) + # Process-wide cached SentenceTransformer instance. Lazy-loaded on first # use to avoid paying ~80MB model download cost at import time. _LOCAL_MODEL: Any = None @@ -88,6 +87,7 @@ def _get_local_embedding_model() -> Any: return None try: from sentence_transformers import SentenceTransformer # type: ignore + _LOCAL_MODEL = SentenceTransformer(LOCAL_EMBEDDING_MODEL_NAME) return _LOCAL_MODEL except Exception as exc: # ImportError, network errors, etc. diff --git a/src/ad_buyer/clients/unified_client.py b/src/ad_buyer/clients/unified_client.py index 483eed4..3f6347b 100644 --- a/src/ad_buyer/clients/unified_client.py +++ b/src/ad_buyer/clients/unified_client.py @@ -217,8 +217,7 @@ async def call_tool( "list_creatives": "List all creatives", # Create tools (args required) "create_account": lambda a: ( - f"Create an account named '{a.get('name')}' " - f"of type {a.get('type', 'advertiser')}" + f"Create an account named '{a.get('name')}' of type {a.get('type', 'advertiser')}" ), "create_order": lambda a: ( f"Create an order named '{a.get('name')}' " @@ -544,7 +543,6 @@ async def get_pricing( if result.data and isinstance(result.data, dict): base_price = result.data.get("basePrice", result.data.get("price")) if isinstance(base_price, (int, float)) and self.buyer_identity: - tier_obj = self.buyer_identity.get_access_tier() discount = self.buyer_identity.get_discount_percentage() @@ -559,7 +557,9 @@ async def get_pricing( result.data["pricing"] = { "base_price": pricing.base_price, - "tiered_price": round(pricing.final_price, 2) if pricing.final_price is not None else None, # noqa: E501 + "tiered_price": round(pricing.final_price, 2) + if pricing.final_price is not None + else None, # noqa: E501 "tier": tier_obj.value if self.buyer_identity else "public", "tier_discount": discount if self.buyer_identity else 0, "volume_discount": pricing.volume_discount, @@ -572,8 +572,12 @@ async def get_pricing( result.data["pricing"] = { "base_price": None, "tiered_price": None, - "tier": self.buyer_identity.get_access_tier().value if self.buyer_identity else "public", # noqa: E501 - "tier_discount": self.buyer_identity.get_discount_percentage() if self.buyer_identity else 0, # noqa: E501 + "tier": self.buyer_identity.get_access_tier().value + if self.buyer_identity + else "public", # noqa: E501 + "tier_discount": self.buyer_identity.get_discount_percentage() + if self.buyer_identity + else 0, # noqa: E501 "volume_discount": 0.0, "requested_volume": volume, "deal_type": deal_type, diff --git a/src/ad_buyer/crews/channel_crews.py b/src/ad_buyer/crews/channel_crews.py index 95a3442..4d06f0a 100644 --- a/src/ad_buyer/crews/channel_crews.py +++ b/src/ad_buyer/crews/channel_crews.py @@ -547,9 +547,7 @@ def kickoff_channel_crew_with_audience( factory = _CHANNEL_FACTORIES.get(channel) if factory is None: valid = sorted(_CHANNEL_FACTORIES.keys()) - raise ValueError( - f"Unknown channel {channel!r}; expected one of {valid}" - ) + raise ValueError(f"Unknown channel {channel!r}; expected one of {valid}") # If the caller passed a CampaignBrief but no plan, run the planner # step inline. The import is local because the planner module pulls diff --git a/src/ad_buyer/data/taxonomy_loader.py b/src/ad_buyer/data/taxonomy_loader.py index b771a3b..76d16bc 100644 --- a/src/ad_buyer/data/taxonomy_loader.py +++ b/src/ad_buyer/data/taxonomy_loader.py @@ -262,10 +262,7 @@ def validate_ref(ref: AudienceRef) -> ValidationResult: if entry is None: return ValidationResult( valid=False, - reason=( - f"identifier {ref.identifier!r} not found in " - f"{ref.taxonomy} v{ref.version}" - ), + reason=(f"identifier {ref.identifier!r} not found in {ref.taxonomy} v{ref.version}"), ) if ref.type == "agentic": # Agentic loader returns a stub, not a real validation. diff --git a/src/ad_buyer/demo/campaign_demo.py b/src/ad_buyer/demo/campaign_demo.py index 6ab5144..0e0faf8 100644 --- a/src/ad_buyer/demo/campaign_demo.py +++ b/src/ad_buyer/demo/campaign_demo.py @@ -81,8 +81,16 @@ def _build_sample_briefs() -> list[dict[str, Any]]: "flight_start": "2026-07-01", "flight_end": "2026-09-30", "channels": [ - {"channel": "CTV", "budget_pct": 60, "format_prefs": ["video_30s", "video_15s"]}, # noqa: E501 - {"channel": "DISPLAY", "budget_pct": 40, "format_prefs": ["300x250", "728x90", "160x600"]}, # noqa: E501 + { + "channel": "CTV", + "budget_pct": 60, + "format_prefs": ["video_30s", "video_15s"], + }, # noqa: E501 + { + "channel": "DISPLAY", + "budget_pct": 40, + "format_prefs": ["300x250", "728x90", "160x600"], + }, # noqa: E501 ], "target_audience": ["IAB-AUD-1001", "IAB-AUD-1045"], "kpis": [ @@ -132,7 +140,11 @@ def _build_sample_briefs() -> list[dict[str, Any]]: "flight_start": "2026-10-01", "flight_end": "2026-12-31", "channels": [ - {"channel": "DISPLAY", "budget_pct": 100, "format_prefs": ["300x250", "320x50"]}, # noqa: E501 + { + "channel": "DISPLAY", + "budget_pct": 100, + "format_prefs": ["300x250", "320x50"], + }, # noqa: E501 ], "target_audience": ["IAB-AUD-3001"], "kpis": [ @@ -179,8 +191,9 @@ def __init__( self._booking_results: dict[str, dict] = {} self._creative_results: dict[str, list] = {} - def _emit_sync(self, event_type: EventType, campaign_id: str = "", - payload: dict | None = None) -> None: + def _emit_sync( + self, event_type: EventType, campaign_id: str = "", payload: dict | None = None + ) -> None: """Emit an event synchronously to the InMemoryEventBus.""" event = Event( event_type=event_type, @@ -205,9 +218,7 @@ def ingest_brief(self, brief_data: dict[str, Any]) -> str: if brief.target_audience is None: target_audience_json = json.dumps(None) else: - target_audience_json = json.dumps( - brief.target_audience.model_dump(mode="json") - ) + target_audience_json = json.dumps(brief.target_audience.model_dump(mode="json")) store_brief = { "advertiser_id": brief.advertiser_id, "campaign_name": brief.campaign_name, @@ -215,15 +226,11 @@ def ingest_brief(self, brief_data: dict[str, Any]) -> str: "currency": brief.currency, "flight_start": brief.flight_start.isoformat(), "flight_end": brief.flight_end.isoformat(), - "channels": json.dumps( - [ch.model_dump(mode="json") for ch in brief.channels] - ), + "channels": json.dumps([ch.model_dump(mode="json") for ch in brief.channels]), "target_audience": target_audience_json, } if brief.kpis: - store_brief["kpis"] = json.dumps( - [k.model_dump(mode="json") for k in brief.kpis] - ) + store_brief["kpis"] = json.dumps([k.model_dump(mode="json") for k in brief.kpis]) if brief.approval_config: store_brief["approval_config"] = json.dumps( brief.approval_config.model_dump(mode="json") @@ -265,13 +272,20 @@ def plan_campaign(self, campaign_id: str) -> dict[str, Any]: # Channel-to-media-type mapping media_type_map = { - "CTV": "ctv", "DISPLAY": "display", "AUDIO": "audio", - "NATIVE": "native", "DOOH": "dooh", "LINEAR_TV": "linear_tv", + "CTV": "ctv", + "DISPLAY": "display", + "AUDIO": "audio", + "NATIVE": "native", + "DOOH": "dooh", + "LINEAR_TV": "linear_tv", } deal_type_map = { - "CTV": ["PG", "PD"], "DISPLAY": ["PD", "PA"], - "AUDIO": ["PD", "PA"], "NATIVE": ["PD", "PA"], - "DOOH": ["PG", "PD"], "LINEAR_TV": ["PG"], + "CTV": ["PG", "PD"], + "DISPLAY": ["PD", "PA"], + "AUDIO": ["PD", "PA"], + "NATIVE": ["PD", "PA"], + "DOOH": ["PG", "PD"], + "LINEAR_TV": ["PG"], } channel_plans = [] @@ -280,14 +294,16 @@ def plan_campaign(self, campaign_id: str) -> dict[str, Any]: budget_pct = ch.get("budget_pct", 0) budget = round(total_budget * budget_pct / 100.0, 2) - channel_plans.append({ - "channel": channel, - "budget": budget, - "budget_pct": budget_pct, - "media_type": media_type_map.get(channel, channel.lower()), - "deal_types": deal_type_map.get(channel, ["PD"]), - "format_prefs": ch.get("format_prefs", []), - }) + channel_plans.append( + { + "channel": channel, + "budget": budget, + "budget_pct": budget_pct, + "media_type": media_type_map.get(channel, channel.lower()), + "deal_types": deal_type_map.get(channel, ["PD"]), + "format_prefs": ch.get("format_prefs", []), + } + ) plan = { "campaign_id": campaign_id, @@ -304,8 +320,7 @@ def plan_campaign(self, campaign_id: str) -> dict[str, Any]: campaign_id=campaign_id, payload={ "channels": [ - {"channel": cp["channel"], "budget": cp["budget"]} - for cp in channel_plans + {"channel": cp["channel"], "budget": cp["budget"]} for cp in channel_plans ], "total_budget": total_budget, }, @@ -385,15 +400,17 @@ def execute_booking(self, campaign_id: str) -> dict[str, Any]: impressions = int((deal_budget / cpm) * 1000) if cpm > 0 else 0 deal_id = f"DEAL-{str(uuid.uuid4())[:8].upper()}" - channel_deals.append({ - "deal_id": deal_id, - "seller": seller_name, - "seller_domain": seller_domain, - "deal_type": deal_type, - "cpm": cpm, - "impressions": impressions, - "spend": round(deal_budget, 2), - }) + channel_deals.append( + { + "deal_id": deal_id, + "seller": seller_name, + "seller_domain": seller_domain, + "deal_type": deal_type, + "cpm": cpm, + "impressions": impressions, + "spend": round(deal_budget, 2), + } + ) channel_spend += deal_budget remaining -= deal_budget @@ -432,8 +449,18 @@ def match_creatives(self, campaign_id: str) -> list[dict[str, Any]]: # Generate simulated creative assets matched to channels creative_specs = { "CTV": [ - ("CTV Hero Spot 30s", "video", {"duration": "30s", "resolution": "1920x1080"}, "valid"), # noqa: E501 - ("CTV Bumper 15s", "video", {"duration": "15s", "resolution": "1920x1080"}, "valid"), # noqa: E501 + ( + "CTV Hero Spot 30s", + "video", + {"duration": "30s", "resolution": "1920x1080"}, + "valid", + ), # noqa: E501 + ( + "CTV Bumper 15s", + "video", + {"duration": "15s", "resolution": "1920x1080"}, + "valid", + ), # noqa: E501 ], "DISPLAY": [ ("Leaderboard 728x90", "display", {"width": 728, "height": 90}, "valid"), @@ -444,7 +471,12 @@ def match_creatives(self, campaign_id: str) -> list[dict[str, Any]]: ("Audio Spot 30s", "audio", {"duration": "30s", "format": "mp3"}, "valid"), ], "NATIVE": [ - ("Native Article Card", "native", {"headline_max": 50, "image": "1200x627"}, "valid"), # noqa: E501 + ( + "Native Article Card", + "native", + {"headline_max": 50, "image": "1200x627"}, + "valid", + ), # noqa: E501 ], "DOOH": [ ("DOOH Full Screen", "display", {"width": 1920, "height": 1080}, "valid"), @@ -457,9 +489,12 @@ def match_creatives(self, campaign_id: str) -> list[dict[str, Any]]: creatives = [] for channel, deals in booking.items(): - specs = creative_specs.get(channel, [ - ("Generic Creative", "display", {"width": 300, "height": 250}, "valid"), - ]) + specs = creative_specs.get( + channel, + [ + ("Generic Creative", "display", {"width": 300, "height": 250}, "valid"), + ], + ) for spec_name, asset_type, format_spec, status in specs: asset_id = self._store.save_creative_asset( @@ -474,15 +509,17 @@ def match_creatives(self, campaign_id: str) -> list[dict[str, Any]]: # Match creative to deals in this channel matched_deals = [d["deal_id"] for d in deals] - creatives.append({ - "asset_id": asset_id, - "asset_name": spec_name, - "asset_type": asset_type, - "format_spec": format_spec, - "validation_status": status, - "channel": channel, - "matched_deals": matched_deals, - }) + creatives.append( + { + "asset_id": asset_id, + "asset_name": spec_name, + "asset_type": asset_type, + "format_spec": format_spec, + "validation_status": status, + "channel": channel, + "matched_deals": matched_deals, + } + ) self._emit_sync( EventType.CREATIVE_MATCHED, @@ -518,26 +555,30 @@ def finalize(self, campaign_id: str) -> None: for channel, deals in booking.items(): ch_budget = sum(d["spend"] for d in deals) - channel_snapshots.append(ChannelSnapshot( - channel=channel, - allocated_budget=ch_budget, - spend=0.0, - pacing_pct=0.0, - impressions=0, - effective_cpm=0.0, - fill_rate=0.0, - )) - - for d in deals: - deal_snapshots.append(DealSnapshot( - deal_id=d["deal_id"], - allocated_budget=d["spend"], + channel_snapshots.append( + ChannelSnapshot( + channel=channel, + allocated_budget=ch_budget, spend=0.0, + pacing_pct=0.0, impressions=0, effective_cpm=0.0, fill_rate=0.0, - win_rate=0.0, - )) + ) + ) + + for d in deals: + deal_snapshots.append( + DealSnapshot( + deal_id=d["deal_id"], + allocated_budget=d["spend"], + spend=0.0, + impressions=0, + effective_cpm=0.0, + fill_rate=0.0, + win_rate=0.0, + ) + ) snapshot = PacingSnapshot( campaign_id=campaign_id, @@ -592,12 +633,8 @@ def activate_campaign(self, campaign_id: str) -> PacingSnapshot: # Parse flight dates for pacing calculation flight_start_str = campaign["flight_start"] flight_end_str = campaign["flight_end"] - flight_start = datetime.fromisoformat(flight_start_str).replace( - tzinfo=UTC - ) - flight_end = datetime.fromisoformat(flight_end_str).replace( - tzinfo=UTC - ) + flight_start = datetime.fromisoformat(flight_start_str).replace(tzinfo=UTC) + flight_end = datetime.fromisoformat(flight_end_str).replace(tzinfo=UTC) # Simulate "current time" as 35% through the flight flight_duration = (flight_end - flight_start).total_seconds() @@ -607,11 +644,11 @@ def activate_campaign(self, campaign_id: str) -> PacingSnapshot: # Pre-defined pacing multipliers per channel to create varied scenarios. # Values <1.0 = underpacing, >1.0 = overpacing. pacing_multipliers = { - "CTV": 0.72, # Underpacing (critical, -28%) - "DISPLAY": 1.35, # Overpacing (critical, +35%) - "AUDIO": 0.88, # Slightly underpacing (warning, -12%) - "NATIVE": 1.15, # Slightly overpacing (warning, +15%) - "DOOH": 0.60, # Heavily underpacing + "CTV": 0.72, # Underpacing (critical, -28%) + "DISPLAY": 1.35, # Overpacing (critical, +35%) + "AUDIO": 0.88, # Slightly underpacing (warning, -12%) + "NATIVE": 1.15, # Slightly overpacing (warning, +15%) + "DOOH": 0.60, # Heavily underpacing "LINEAR_TV": 1.05, # On pace } @@ -660,15 +697,17 @@ def activate_campaign(self, campaign_id: str) -> PacingSnapshot: deal_win = round(rng.uniform(0.30, 0.85), 2) deal_ecpm = round(d["cpm"] * rng.uniform(0.90, 1.10), 2) - deal_data.append({ - "deal_id": d["deal_id"], - "allocated_budget": deal_budget, - "spend": deal_spend, - "impressions": deal_imps, - "effective_cpm": deal_ecpm, - "fill_rate": deal_fill, - "win_rate": deal_win, - }) + deal_data.append( + { + "deal_id": d["deal_id"], + "allocated_budget": deal_budget, + "spend": deal_spend, + "impressions": deal_imps, + "effective_cpm": deal_ecpm, + "fill_rate": deal_fill, + "win_rate": deal_win, + } + ) # Use BudgetPacingEngine to generate the official snapshot engine = BudgetPacingEngine( @@ -708,9 +747,7 @@ def _extract_alerts(snapshot: PacingSnapshot) -> list: alerts: list[PacingAlert] = [] # Campaign-level alert - campaign_alert = engine.detect_deviation( - snapshot.total_spend, snapshot.expected_spend - ) + campaign_alert = engine.detect_deviation(snapshot.total_spend, snapshot.expected_spend) if campaign_alert is not None: alerts.append(campaign_alert) @@ -718,9 +755,7 @@ def _extract_alerts(snapshot: PacingSnapshot) -> list: for ch in snapshot.channel_snapshots: if ch.allocated_budget <= 0 or snapshot.total_budget <= 0: continue - ch_expected = snapshot.expected_spend * ( - ch.allocated_budget / snapshot.total_budget - ) + ch_expected = snapshot.expected_spend * (ch.allocated_budget / snapshot.total_budget) ch_alert = engine.detect_deviation(ch.spend, ch_expected) if ch_alert is not None: # Add channel context to the message @@ -815,10 +850,12 @@ def index(): def api_sample_briefs(): """Return pre-built sample briefs for the dropdown.""" samples = _build_sample_briefs() - return jsonify({ - "briefs": [s["brief"] for s in samples], - "names": [s["name"] for s in samples], - }) + return jsonify( + { + "briefs": [s["brief"] for s in samples], + "names": [s["name"] for s in samples], + } + ) # -- API: Submit brief (Stage 1) --------------------------------------- @@ -836,12 +873,14 @@ def api_submit_brief(): campaign = campaign_store.get_campaign(campaign_id) - return jsonify({ - "success": True, - "campaign_id": campaign_id, - "status": campaign["status"].lower() if campaign else "draft", - "campaign_name": campaign["campaign_name"] if campaign else "", - }) + return jsonify( + { + "success": True, + "campaign_id": campaign_id, + "status": campaign["status"].lower() if campaign else "draft", + "campaign_name": campaign["campaign_name"] if campaign else "", + } + ) # -- API: Get campaign state ------------------------------------------- @@ -854,8 +893,14 @@ def api_campaign_state(campaign_id: str): # Parse JSON fields for the response result = dict(campaign) - for field in ("channels", "target_audience", "kpis", "approval_config", - "target_geo", "brand_safety"): + for field in ( + "channels", + "target_audience", + "kpis", + "approval_config", + "target_geo", + "brand_safety", + ): if result.get(field) and isinstance(result[field], str): try: result[field] = json.loads(result[field]) @@ -883,11 +928,13 @@ def api_approve_plan(): except (ValueError, TypeError, OSError) as exc: return jsonify({"success": False, "error": str(exc)}), 400 - return jsonify({ - "success": True, - "campaign_id": campaign_id, - "plan": plan, - }) + return jsonify( + { + "success": True, + "campaign_id": campaign_id, + "plan": plan, + } + ) # -- API: Approve booking (Stage 3) ------------------------------------ @@ -907,11 +954,13 @@ def api_approve_booking(): except (ValueError, TypeError, OSError) as exc: return jsonify({"success": False, "error": str(exc)}), 400 - return jsonify({ - "success": True, - "campaign_id": campaign_id, - "deals": deals, - }) + return jsonify( + { + "success": True, + "campaign_id": campaign_id, + "deals": deals, + } + ) # -- API: Approve creative (Stage 4) ----------------------------------- @@ -932,12 +981,14 @@ def api_approve_creative(): except (ValueError, TypeError, OSError) as exc: return jsonify({"success": False, "error": str(exc)}), 400 - return jsonify({ - "success": True, - "campaign_id": campaign_id, - "status": "ready", - "creatives": creatives, - }) + return jsonify( + { + "success": True, + "campaign_id": campaign_id, + "status": "ready", + "creatives": creatives, + } + ) # -- API: Activate campaign (Stage 6) ---------------------------------- @@ -964,12 +1015,8 @@ def api_activate_campaign(): "expected_spend": snapshot.expected_spend, "pacing_pct": snapshot.pacing_pct, "deviation_pct": snapshot.deviation_pct, - "channel_snapshots": [ - ch.model_dump() for ch in snapshot.channel_snapshots - ], - "deal_snapshots": [ - ds.model_dump() for ds in snapshot.deal_snapshots - ], + "channel_snapshots": [ch.model_dump() for ch in snapshot.channel_snapshots], + "deal_snapshots": [ds.model_dump() for ds in snapshot.deal_snapshots], "alerts": [ { "level": alert.level.value if hasattr(alert.level, "value") else alert.level, @@ -979,17 +1026,17 @@ def api_activate_campaign(): } for alert in _extract_alerts(snapshot) ], - "recommendations": [ - rec.model_dump() for rec in snapshot.recommendations - ], + "recommendations": [rec.model_dump() for rec in snapshot.recommendations], } - return jsonify({ - "success": True, - "campaign_id": campaign_id, - "status": "active", - "pacing": pacing_data, - }) + return jsonify( + { + "success": True, + "campaign_id": campaign_id, + "status": "active", + "pacing": pacing_data, + } + ) # -- API: Pause campaign (Stage 6 control) ----------------------------- @@ -1012,11 +1059,13 @@ def api_pause_campaign(): except (ValueError, TypeError, OSError) as exc: return jsonify({"success": False, "error": str(exc)}), 400 - return jsonify({ - "success": True, - "campaign_id": campaign_id, - "status": "paused", - }) + return jsonify( + { + "success": True, + "campaign_id": campaign_id, + "status": "paused", + } + ) # -- API: Resume campaign (Stage 6 control) ---------------------------- @@ -1039,11 +1088,13 @@ def api_resume_campaign(): except (ValueError, TypeError, OSError) as exc: return jsonify({"success": False, "error": str(exc)}), 400 - return jsonify({ - "success": True, - "campaign_id": campaign_id, - "status": "active", - }) + return jsonify( + { + "success": True, + "campaign_id": campaign_id, + "status": "active", + } + ) # -- API: Complete campaign (Stage 6 control) -------------------------- @@ -1066,11 +1117,13 @@ def api_complete_campaign(): except (ValueError, TypeError, OSError) as exc: return jsonify({"success": False, "error": str(exc)}), 400 - return jsonify({ - "success": True, - "campaign_id": campaign_id, - "status": "completed", - }) + return jsonify( + { + "success": True, + "campaign_id": campaign_id, + "status": "completed", + } + ) # -- API: Campaign report (Stage 5) ------------------------------------ @@ -1083,60 +1136,64 @@ def api_campaign_report(campaign_id: str): try: report = reporter.full_report(campaign_id) - return jsonify({ - "campaign_id": campaign_id, - "status_summary": report.status_summary._to_dict(), - "pacing_dashboard": report.pacing_dashboard._to_dict(), - "creative_performance": report.creative_performance._to_dict(), - "deal_report": report.deal_report._to_dict(), - }) + return jsonify( + { + "campaign_id": campaign_id, + "status_summary": report.status_summary._to_dict(), + "pacing_dashboard": report.pacing_dashboard._to_dict(), + "creative_performance": report.creative_performance._to_dict(), + "deal_report": report.deal_report._to_dict(), + } + ) except (ValueError, TypeError, KeyError, OSError) as exc: logger.warning("Report generation failed: %s", exc) # Fall back to basic campaign data - return jsonify({ - "campaign_id": campaign_id, - "status_summary": { - "campaign_id": campaign_id, - "campaign_name": campaign["campaign_name"], - "advertiser_id": campaign["advertiser_id"], - "status": campaign["status"].lower(), - "total_budget": campaign["total_budget"], - "currency": campaign.get("currency", "USD"), - "total_spend": 0.0, - "delivery_pct": 0.0, - "pacing_pct": 0.0, - "flight_start": campaign["flight_start"], - "flight_end": campaign["flight_end"], - "channels": [], - }, - "pacing_dashboard": { - "campaign_id": campaign_id, - "total_budget": campaign["total_budget"], - "total_spend": 0.0, - "expected_spend": 0.0, - "pacing_pct": 0.0, - "deviation_pct": 0.0, - "channel_pacing": [], - "alerts": [], - }, - "creative_performance": { - "campaign_id": campaign_id, - "creatives": [], - "total_assets": 0, - "valid_assets": 0, - "pending_assets": 0, - "invalid_assets": 0, - }, - "deal_report": { + return jsonify( + { "campaign_id": campaign_id, - "deals": [], - "total_deals": 0, - "total_spend": 0.0, - "total_impressions": 0, - "avg_fill_rate": 0.0, - "avg_win_rate": 0.0, - }, - }) + "status_summary": { + "campaign_id": campaign_id, + "campaign_name": campaign["campaign_name"], + "advertiser_id": campaign["advertiser_id"], + "status": campaign["status"].lower(), + "total_budget": campaign["total_budget"], + "currency": campaign.get("currency", "USD"), + "total_spend": 0.0, + "delivery_pct": 0.0, + "pacing_pct": 0.0, + "flight_start": campaign["flight_start"], + "flight_end": campaign["flight_end"], + "channels": [], + }, + "pacing_dashboard": { + "campaign_id": campaign_id, + "total_budget": campaign["total_budget"], + "total_spend": 0.0, + "expected_spend": 0.0, + "pacing_pct": 0.0, + "deviation_pct": 0.0, + "channel_pacing": [], + "alerts": [], + }, + "creative_performance": { + "campaign_id": campaign_id, + "creatives": [], + "total_assets": 0, + "valid_assets": 0, + "pending_assets": 0, + "invalid_assets": 0, + }, + "deal_report": { + "campaign_id": campaign_id, + "deals": [], + "total_deals": 0, + "total_spend": 0.0, + "total_impressions": 0, + "avg_fill_rate": 0.0, + "avg_win_rate": 0.0, + }, + } + ) # -- API: Events ------------------------------------------------------- @@ -1150,19 +1207,21 @@ def api_events(): if campaign_id: events = [e for e in events if e.campaign_id == campaign_id] - return jsonify({ - "events": [ - { - "event_id": e.event_id, - "event_type": e.event_type.value, - "campaign_id": e.campaign_id, - "timestamp": e.timestamp.isoformat() if e.timestamp else "", - "payload": e.payload, - } - for e in events - ], - "count": len(events), - }) + return jsonify( + { + "events": [ + { + "event_id": e.event_id, + "event_type": e.event_type.value, + "campaign_id": e.campaign_id, + "timestamp": e.timestamp.isoformat() if e.timestamp else "", + "payload": e.payload, + } + for e in events + ], + "count": len(events), + } + ) # -- API: List campaigns ----------------------------------------------- @@ -1170,18 +1229,20 @@ def api_events(): def api_list_campaigns(): """List all campaigns.""" campaigns = campaign_store.list_campaigns(limit=50) - return jsonify({ - "campaigns": [ - { - "campaign_id": c["campaign_id"], - "campaign_name": c["campaign_name"], - "status": c["status"].lower(), - "total_budget": c["total_budget"], - "advertiser_id": c["advertiser_id"], - } - for c in campaigns - ], - }) + return jsonify( + { + "campaigns": [ + { + "campaign_id": c["campaign_id"], + "campaign_name": c["campaign_name"], + "status": c["status"].lower(), + "total_budget": c["total_budget"], + "advertiser_id": c["advertiser_id"], + } + for c in campaigns + ], + } + ) # --------------------------------------------------------------------------- @@ -1214,9 +1275,13 @@ def _run_headless(database_url: str, sample_index: int = 0, output: str = "json" print(json.dumps({"error": "no sample briefs available"})) return 2 if not 0 <= sample_index < len(samples): - print(json.dumps({ - "error": f"sample_index {sample_index} out of range; have {len(samples)}", - })) + print( + json.dumps( + { + "error": f"sample_index {sample_index} out of range; have {len(samples)}", + } + ) + ) return 2 brief = samples[sample_index]["brief"] @@ -1225,8 +1290,9 @@ def _emit(stage: str, payload: dict[str, Any]) -> None: if output == "json": print(json.dumps({"stage": stage, **payload}, default=str)) else: - print(f"[{stage}] {payload.get('status', '?')} " - f"campaign={payload.get('campaign_id', '-')}") + print( + f"[{stage}] {payload.get('status', '?')} campaign={payload.get('campaign_id', '-')}" + ) client = app.test_client() @@ -1234,7 +1300,9 @@ def _emit(stage: str, payload: dict[str, Any]) -> None: r = client.post("/api/submit-brief", json=brief) if r.status_code != 200 or not (r.get_json() or {}).get("success"): body = r.get_json() or {} - _emit("1-brief", {"status": "failed", "http": r.status_code, "error": body.get("error", "?")}) # noqa: E501 + _emit( + "1-brief", {"status": "failed", "http": r.status_code, "error": body.get("error", "?")} + ) # noqa: E501 return 1 campaign_id = r.get_json().get("campaign_id") _emit("1-brief", {"status": "submitted", "campaign_id": campaign_id}) @@ -1244,12 +1312,15 @@ def _approve(stage: str, route: str) -> int: rr = client.post(route, json={"campaign_id": campaign_id}) body = rr.get_json() or {} ok = rr.status_code == 200 and body.get("success") is True - _emit(stage, { - "status": "approved" if ok else "failed", - "campaign_id": campaign_id, - "http": rr.status_code, - **({"error": body.get("error", "?")} if not ok else {}), - }) + _emit( + stage, + { + "status": "approved" if ok else "failed", + "campaign_id": campaign_id, + "http": rr.status_code, + **({"error": body.get("error", "?")} if not ok else {}), + }, + ) return 0 if ok else 1 # Stage 2: Approve plan @@ -1271,12 +1342,15 @@ def _approve(stage: str, route: str) -> int: # Stage 6: Final report r = client.get(f"/api/campaign/{campaign_id}/report") body = r.get_json() if r.status_code == 200 else {} - _emit("6-report", { - "status": "ok" if r.status_code == 200 else "failed", - "campaign_id": campaign_id, - "http": r.status_code, - "report_keys": list(body.keys()) if isinstance(body, dict) else None, - }) + _emit( + "6-report", + { + "status": "ok" if r.status_code == 200 else "failed", + "campaign_id": campaign_id, + "http": r.status_code, + "report_keys": list(body.keys()) if isinstance(body, dict) else None, + }, + ) return 0 if r.status_code == 200 else 1 @@ -1291,14 +1365,24 @@ def main(argv: list[str] | None = None) -> int: import argparse parser = argparse.ArgumentParser(prog="campaign_demo") - parser.add_argument("--headless", action="store_true", - help="Run all stages programmatically (no Flask server, no browser)") - parser.add_argument("--json", action="store_true", - help="With --headless: emit one JSON object per stage to stdout (default)") - parser.add_argument("--summary", action="store_true", - help="With --headless: emit one short human-readable line per stage") - parser.add_argument("--sample-index", type=int, default=0, - help="With --headless: which sample brief (0-based)") + parser.add_argument( + "--headless", + action="store_true", + help="Run all stages programmatically (no Flask server, no browser)", + ) + parser.add_argument( + "--json", + action="store_true", + help="With --headless: emit one JSON object per stage to stdout (default)", + ) + parser.add_argument( + "--summary", + action="store_true", + help="With --headless: emit one short human-readable line per stage", + ) + parser.add_argument( + "--sample-index", type=int, default=0, help="With --headless: which sample brief (0-based)" + ) args = parser.parse_args(argv) logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") diff --git a/src/ad_buyer/eval/embedding_eval.py b/src/ad_buyer/eval/embedding_eval.py index 1c6f7ae..997dbd5 100644 --- a/src/ad_buyer/eval/embedding_eval.py +++ b/src/ad_buyer/eval/embedding_eval.py @@ -28,10 +28,10 @@ # semantically related pairs (so distinctiveness has signal). EMBEDDING_EVAL_FIXTURES: list[dict[str, Any]] = [ {"name": "auto_intenders", "interest": "auto", "age": "25-54", "income": "high"}, - {"name": "auto_owners", "interest": "auto", "age": "35-65", "income": "high"}, - {"name": "sports_fans", "interest": "sports", "age": "18-44"}, - {"name": "news_readers", "interest": "news", "age": "35-65"}, - {"name": "young_gamers", "interest": "gaming", "age": "18-24"}, + {"name": "auto_owners", "interest": "auto", "age": "35-65", "income": "high"}, + {"name": "sports_fans", "interest": "sports", "age": "18-44"}, + {"name": "news_readers", "interest": "news", "age": "35-65"}, + {"name": "young_gamers", "interest": "gaming", "age": "18-24"}, ] @@ -41,10 +41,10 @@ class PerModeMetrics: mode: str n_fixtures: int - deterministic: bool # repeat-call returns same vector for each fixture - dimension: int # all fixtures produce the same dim - distinctiveness: float # mean pairwise cosine distance across fixtures - provenance: str # provenance reported by the client + deterministic: bool # repeat-call returns same vector for each fixture + dimension: int # all fixtures produce the same dim + distinctiveness: float # mean pairwise cosine distance across fixtures + provenance: str # provenance reported by the client def as_dict(self) -> dict[str, Any]: return { @@ -95,17 +95,11 @@ def _eval_single_mode( client = UCPClient() with patch.object(settings, "embedding_mode", mode): # First pass: gather vectors - first = [ - client.create_query_embedding_with_provenance(f) for f in fixtures - ] + first = [client.create_query_embedding_with_provenance(f) for f in fixtures] # Second pass: gather again to check determinism - second = [ - client.create_query_embedding_with_provenance(f) for f in fixtures - ] + second = [client.create_query_embedding_with_provenance(f) for f in fixtures] - deterministic = all( - f.embedding.vector == s.embedding.vector for f, s in zip(first, second) - ) + deterministic = all(f.embedding.vector == s.embedding.vector for f, s in zip(first, second)) dims = {len(r.embedding.vector) for r in first} dimension = dims.pop() if len(dims) == 1 else -1 @@ -114,9 +108,7 @@ def _eval_single_mode( distances: list[float] = [] for i in range(len(first)): for j in range(i + 1, len(first)): - distances.append( - _cosine_distance(first[i].embedding.vector, first[j].embedding.vector) - ) + distances.append(_cosine_distance(first[i].embedding.vector, first[j].embedding.vector)) distinctiveness = sum(distances) / len(distances) if distances else 0.0 # Provenance: should be consistent across fixtures within a mode. diff --git a/src/ad_buyer/flows/buyer_deal_flow.py b/src/ad_buyer/flows/buyer_deal_flow.py index 2e716b2..ca73d66 100644 --- a/src/ad_buyer/flows/buyer_deal_flow.py +++ b/src/ad_buyer/flows/buyer_deal_flow.py @@ -316,8 +316,7 @@ def receive_request(self) -> dict[str, Any]: self.state.audience_plan = planner_result.plan if planner_result.plan is not None: logger.info( - "buyer_deal_flow: audience plan resolved " - "(audience_plan_id=%s)", + "buyer_deal_flow: audience plan resolved (audience_plan_id=%s)", planner_result.plan.audience_plan_id, ) except Exception as e: # noqa: BLE001 - audience is additive; do not abort the deal flow @@ -325,8 +324,7 @@ def receive_request(self) -> dict[str, Any]: # not break the deal flow -- record the warning and keep # going audience-blind so legacy callers see no regression. logger.warning( - "buyer_deal_flow: audience planner failed (%s); " - "continuing audience-blind", + "buyer_deal_flow: audience planner failed (%s); continuing audience-blind", e, ) self.state.errors.append(f"Audience planner warning: {e}") @@ -439,12 +437,12 @@ def evaluate_and_select(self, discovery_result: dict[str, Any]) -> dict[str, Any for the following request: {self.state.request} Discovery results: -{discovery_result.get('discovery_result', 'No results')} +{discovery_result.get("discovery_result", "No results")} Criteria: - Deal type: {self.state.deal_type.value} -- Max CPM: {self.state.max_cpm or 'No limit'} -- Volume: {self.state.impressions or 'Flexible'} +- Max CPM: {self.state.max_cpm or "No limit"} +- Volume: {self.state.impressions or "Flexible"} Return the product_id of the best matching product and explain why.""", expected_output="Product ID and selection rationale", @@ -588,9 +586,7 @@ def get_status(self) -> dict[str, Any]: "request": self.state.request, "deal_type": self.state.deal_type.value, "access_tier": ( - self._buyer_context.get_access_tier().value - if self._buyer_context - else "unknown" + self._buyer_context.get_access_tier().value if self._buyer_context else "unknown" ), "selected_product_id": self.state.selected_product_id, "deal_response": self.state.deal_response, @@ -598,9 +594,7 @@ def get_status(self) -> dict[str, Any]: "updated_at": self.state.updated_at.isoformat(), # Surface the audience_plan_id when one was resolved so callers # can correlate logs / audit trails by hash (proposal §5.1). - "audience_plan_id": ( - plan.audience_plan_id if plan is not None else None - ), + "audience_plan_id": (plan.audience_plan_id if plan is not None else None), } def get_audience_planner_result(self) -> AudiencePlannerResult | None: @@ -654,6 +648,7 @@ async def run_buyer_deal_flow( # Resolve server URL from Settings if not provided if base_url is None: from ..config.settings import get_settings + base_url = get_settings().iab_server_url # Create buyer context diff --git a/src/ad_buyer/interfaces/api/main.py b/src/ad_buyer/interfaces/api/main.py index 5450195..6c45e04 100644 --- a/src/ad_buyer/interfaces/api/main.py +++ b/src/ad_buyer/interfaces/api/main.py @@ -35,6 +35,7 @@ def _current_settings(): """ return sys.modules[__name__].settings + app = FastAPI( title="Ad Buyer Agent API", description=( @@ -469,13 +470,15 @@ async def list_bookings( for job_id, job in jobs.items(): if status and job["status"] != status: continue - job_list.append({ - "job_id": job_id, - "status": job["status"], - "campaign_name": job["brief"].get("name"), - "budget": job["brief"].get("budget"), - "created_at": job["created_at"], - }) + job_list.append( + { + "job_id": job_id, + "status": job["status"], + "campaign_name": job["brief"].get("name"), + "budget": job["brief"].get("budget"), + "created_at": job["created_at"], + } + ) # Sort by created_at descending job_list.sort(key=lambda x: x["created_at"], reverse=True) @@ -574,9 +577,7 @@ async def _run_booking_flow(job_id: str, request: BookingRequest) -> None: job["budget_allocations"] = { k: v.model_dump() for k, v in flow.state.budget_allocations.items() } - job["recommendations"] = [ - r.model_dump() for r in flow.state.pending_approvals - ] + job["recommendations"] = [r.model_dump() for r in flow.state.pending_approvals] if request.auto_approve: flow.approve_all() diff --git a/src/ad_buyer/interfaces/chat/main.py b/src/ad_buyer/interfaces/chat/main.py index d72d27b..211c72a 100644 --- a/src/ad_buyer/interfaces/chat/main.py +++ b/src/ad_buyer/interfaces/chat/main.py @@ -43,6 +43,7 @@ class SellerConnection: def check_health(self) -> bool: """Synchronously check if seller is reachable and discover tools.""" import httpx + try: response = httpx.get(f"{self.url}/health", timeout=5.0) if response.status_code == 200: @@ -111,9 +112,7 @@ def __init__(self, sellers: list[SellerConnection], **kwargs): def _run(self, query: str = "", channel: str = "", max_cpm: float = 0) -> str: """Synchronous wrapper.""" - return asyncio.get_event_loop().run_until_complete( - self._arun(query, channel, max_cpm) - ) + return asyncio.get_event_loop().run_until_complete(self._arun(query, channel, max_cpm)) async def _arun(self, query: str = "", channel: str = "", max_cpm: float = 0) -> str: """Search all sellers asynchronously.""" @@ -140,21 +139,31 @@ async def _arun(self, query: str = "", channel: str = "", max_cpm: float = 0) -> # Apply filters if specified if channel and isinstance(products, list): - products = [p for p in products if p.get("channel", "").lower() == channel.lower()] # noqa: E501 + products = [ + p for p in products if p.get("channel", "").lower() == channel.lower() + ] # noqa: E501 if max_cpm > 0 and isinstance(products, list): - products = [p for p in products if p.get("base_cpm", p.get("floor_cpm", 0)) <= max_cpm] # noqa: E501 - - results.append({ + products = [ + p + for p in products + if p.get("base_cpm", p.get("floor_cpm", 0)) <= max_cpm + ] # noqa: E501 + + results.append( + { + "seller": seller.name, + "url": seller.url, + "products": products, + } + ) + except (OSError, ValueError, KeyError) as e: + results.append( + { "seller": seller.name, "url": seller.url, - "products": products, - }) - except (OSError, ValueError, KeyError) as e: - results.append({ - "seller": seller.name, - "url": seller.url, - "error": str(e), - }) + "error": str(e), + } + ) if not results: return "No sellers connected. Configure SELLER_ENDPOINTS in .env" @@ -171,12 +180,17 @@ async def _arun(self, query: str = "", channel: str = "", max_cpm: float = 0) -> for p in products[:5]: # Limit to 5 per seller name = p.get("name", "Unknown") # Try various price field names - price = p.get("base_cpm", p.get("floor_cpm", p.get("basePrice", p.get("price", "N/A")))) # noqa: E501 + price = p.get( + "base_cpm", + p.get("floor_cpm", p.get("basePrice", p.get("price", "N/A"))), + ) # noqa: E501 channel = p.get("channel", "") publisher = p.get("publisher", "") avail = p.get("available_impressions", 0) - avail_str = f"{avail/1_000_000:.0f}M" if avail else "" - output.append(f" - {name} | {publisher} | {channel} | ${price} CPM | {avail_str} avail") # noqa: E501 + avail_str = f"{avail / 1_000_000:.0f}M" if avail else "" + output.append( + f" - {name} | {publisher} | {channel} | ${price} CPM | {avail_str} avail" + ) # noqa: E501 else: output.append(f" {products}") @@ -186,8 +200,12 @@ async def _arun(self, query: str = "", channel: str = "", max_cpm: float = 0) -> class CallSellerToolInput(BaseModel): """Input for calling any tool on a seller.""" - seller_name: str = Field(..., description="Name of the seller agent (e.g., 'Publisher Seller Agent')") # noqa: E501 - tool_name: str = Field(..., description="Name of the tool to call (e.g., 'book_programmatic_guaranteed')") # noqa: E501 + seller_name: str = Field( + ..., description="Name of the seller agent (e.g., 'Publisher Seller Agent')" + ) # noqa: E501 + tool_name: str = Field( + ..., description="Name of the tool to call (e.g., 'book_programmatic_guaranteed')" + ) # noqa: E501 arguments: str = Field(default="{}", description="JSON string of arguments to pass to the tool") @@ -285,21 +303,46 @@ def __init__(self, sellers: list[SellerConnection], **kwargs): super().__init__(**kwargs) self._sellers = sellers - def _run(self, seller_name: str, product_id: str, impressions: int, cpm_price: float, - start_date: str = "", end_date: str = "", advertiser_name: str = "Demo Advertiser", - campaign_name: str = "Demo Campaign") -> str: + def _run( + self, + seller_name: str, + product_id: str, + impressions: int, + cpm_price: float, + start_date: str = "", + end_date: str = "", + advertiser_name: str = "Demo Advertiser", + campaign_name: str = "Demo Campaign", + ) -> str: try: loop = asyncio.get_event_loop() except RuntimeError: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) return loop.run_until_complete( - self._arun(seller_name, product_id, impressions, cpm_price, start_date, end_date, advertiser_name, campaign_name) # noqa: E501 + self._arun( + seller_name, + product_id, + impressions, + cpm_price, + start_date, + end_date, + advertiser_name, + campaign_name, + ) # noqa: E501 ) - async def _arun(self, seller_name: str, product_id: str, impressions: int, cpm_price: float, - start_date: str = "", end_date: str = "", advertiser_name: str = "Demo Advertiser", # noqa: E501 - campaign_name: str = "Demo Campaign") -> str: + async def _arun( + self, + seller_name: str, + product_id: str, + impressions: int, + cpm_price: float, + start_date: str = "", + end_date: str = "", + advertiser_name: str = "Demo Advertiser", # noqa: E501 + campaign_name: str = "Demo Campaign", + ) -> str: import json as json_module from datetime import datetime, timedelta @@ -335,7 +378,7 @@ async def _arun(self, seller_name: str, product_id: str, impressions: int, cpm_p result = await client.call_tool("book_programmatic_guaranteed", args) if result.success: - return f"✓ PG DEAL BOOKED SUCCESSFULLY!\n\nSeller: {seller.name}\nProduct: {product_id}\nImpressions: {impressions:,}\nCPM: ${cpm_price}\nTotal Cost: ${(impressions/1000)*cpm_price:,.2f}\n\nBooking Details:\n{json_module.dumps(result.data, indent=2)}" # noqa: E501 + return f"✓ PG DEAL BOOKED SUCCESSFULLY!\n\nSeller: {seller.name}\nProduct: {product_id}\nImpressions: {impressions:,}\nCPM: ${cpm_price}\nTotal Cost: ${(impressions / 1000) * cpm_price:,.2f}\n\nBooking Details:\n{json_module.dumps(result.data, indent=2)}" # noqa: E501 else: return f"✗ Failed to book PG deal: {result.error}" except (OSError, ValueError, RuntimeError) as e: @@ -364,8 +407,14 @@ def __init__(self, sellers: list[SellerConnection], **kwargs): super().__init__(**kwargs) self._sellers = sellers - def _run(self, seller_name: str, product_id: str, floor_price: float, - impressions: int = 0, buyer_seat_id: str = "buyer-seat-001") -> str: + def _run( + self, + seller_name: str, + product_id: str, + floor_price: float, + impressions: int = 0, + buyer_seat_id: str = "buyer-seat-001", + ) -> str: try: loop = asyncio.get_event_loop() except RuntimeError: @@ -375,8 +424,14 @@ def _run(self, seller_name: str, product_id: str, floor_price: float, self._arun(seller_name, product_id, floor_price, impressions, buyer_seat_id) ) - async def _arun(self, seller_name: str, product_id: str, floor_price: float, - impressions: int = 0, buyer_seat_id: str = "buyer-seat-001") -> str: + async def _arun( + self, + seller_name: str, + product_id: str, + floor_price: float, + impressions: int = 0, + buyer_seat_id: str = "buyer-seat-001", + ) -> str: import json as json_module # Find seller @@ -402,7 +457,11 @@ async def _arun(self, seller_name: str, product_id: str, floor_price: float, if result.success: deal_data = result.data - deal_id = deal_data.get("deal", {}).get("deal_id", "N/A") if isinstance(deal_data, dict) else "N/A" # noqa: E501 + deal_id = ( + deal_data.get("deal", {}).get("deal_id", "N/A") + if isinstance(deal_data, dict) + else "N/A" + ) # noqa: E501 return f"✓ PMP DEAL CREATED!\n\nDeal ID: {deal_id}\nSeller: {seller.name}\nProduct: {product_id}\nFloor: ${floor_price} CPM\n\nFull Details:\n{json_module.dumps(deal_data, indent=2)}" # noqa: E501 else: return f"✗ Failed to create PMP deal: {result.error}" @@ -500,7 +559,11 @@ def _get_seller_info(self) -> str: lines = [] for i, seller in enumerate(self._sellers, 1): status = "Connected" if seller.connected else f"Failed: {seller.error}" - caps = ", ".join(seller.capabilities.get("tools", [])[:5]) if seller.capabilities else "N/A" # noqa: E501 + caps = ( + ", ".join(seller.capabilities.get("tools", [])[:5]) + if seller.capabilities + else "N/A" + ) # noqa: E501 lines.append(f"{i}. {seller.url}") lines.append(f" Status: {status}") if seller.connected: @@ -517,9 +580,7 @@ def process_message(self, user_message: str) -> str: Returns: The agent's response """ - self.conversation_history.append( - ConversationMessage(role="user", content=user_message) - ) + self.conversation_history.append(ConversationMessage(role="user", content=user_message)) # Build context from conversation history history_text = self._format_history() @@ -565,9 +626,7 @@ def process_message(self, user_message: str) -> str: response = str(result) # Store response - self.conversation_history.append( - ConversationMessage(role="assistant", content=response) - ) + self.conversation_history.append(ConversationMessage(role="assistant", content=response)) return response diff --git a/src/ad_buyer/interfaces/mcp_server.py b/src/ad_buyer/interfaces/mcp_server.py index 2209c0b..b3b167f 100644 --- a/src/ad_buyer/interfaces/mcp_server.py +++ b/src/ad_buyer/interfaces/mcp_server.py @@ -230,7 +230,7 @@ def get_setup_status() -> str: db_url = settings.database_url # Strip sqlite:/// prefix for direct connection test if db_url.startswith("sqlite:///"): - db_path = db_url[len("sqlite:///"):] + db_path = db_url[len("sqlite:///") :] else: db_path = db_url @@ -251,10 +251,7 @@ def get_setup_status() -> str: # Overall setup completeness # Minimum required: seller endpoints + database - setup_complete = ( - checks["seller_endpoints_configured"] - and checks["database_accessible"] - ) + setup_complete = checks["seller_endpoints_configured"] and checks["database_accessible"] result = { "setup_complete": setup_complete, @@ -282,7 +279,7 @@ def health_check() -> str: try: db_url = settings.database_url if db_url.startswith("sqlite:///"): - db_path = db_url[len("sqlite:///"):] + db_path = db_url[len("sqlite:///") :] else: db_path = db_url @@ -310,9 +307,7 @@ def health_check() -> str: services["event_bus"] = {"status": "healthy"} # Determine overall status - unhealthy_count = sum( - 1 for s in services.values() if s.get("status") == "unhealthy" - ) + unhealthy_count = sum(1 for s in services.values() if s.get("status") == "unhealthy") if unhealthy_count == 0: overall_status = "healthy" elif unhealthy_count < len(services): @@ -540,16 +535,18 @@ def list_campaigns(status: str | None = None) -> str: campaign_summaries = [] for c in campaigns: - campaign_summaries.append({ - "campaign_id": c["campaign_id"], - "campaign_name": c["campaign_name"], - "advertiser_id": c["advertiser_id"], - "status": c["status"], - "total_budget": c["total_budget"], - "currency": c.get("currency", "USD"), - "flight_start": c["flight_start"], - "flight_end": c["flight_end"], - }) + campaign_summaries.append( + { + "campaign_id": c["campaign_id"], + "campaign_name": c["campaign_name"], + "advertiser_id": c["advertiser_id"], + "status": c["status"], + "total_budget": c["total_budget"], + "currency": c.get("currency", "USD"), + "flight_start": c["flight_start"], + "flight_end": c["flight_end"], + } + ) result = { "total": len(campaign_summaries), @@ -687,13 +684,15 @@ def check_pacing(campaign_id: str) -> str: # Build channel pacing breakdown channel_pacing = [] for ch in latest.channel_snapshots: - channel_pacing.append({ - "channel": ch.channel, - "allocated_budget": ch.allocated_budget, - "spend": ch.spend, - "pacing_pct": ch.pacing_pct, - "impressions": ch.impressions, - }) + channel_pacing.append( + { + "channel": ch.channel, + "allocated_budget": ch.allocated_budget, + "spend": ch.spend, + "pacing_pct": ch.pacing_pct, + "impressions": ch.impressions, + } + ) result = { "campaign_id": campaign_id, @@ -748,22 +747,23 @@ def review_budgets() -> str: # Calculate delivery percentage delivery_pct = (spend / budget * 100.0) if budget > 0 else 0.0 - campaign_budgets.append({ - "campaign_id": c["campaign_id"], - "campaign_name": c["campaign_name"], - "status": c["status"], - "total_budget": budget, - "total_spend": spend, - "delivery_pct": round(delivery_pct, 1), - "currency": c.get("currency", "USD"), - }) + campaign_budgets.append( + { + "campaign_id": c["campaign_id"], + "campaign_name": c["campaign_name"], + "status": c["status"], + "total_budget": budget, + "total_spend": spend, + "delivery_pct": round(delivery_pct, 1), + "currency": c.get("currency", "USD"), + } + ) result = { "total_budget": total_budget, "total_spend": total_spend, "overall_delivery_pct": ( - round(total_spend / total_budget * 100.0, 1) - if total_budget > 0 else 0.0 + round(total_spend / total_budget * 100.0, 1) if total_budget > 0 else 0.0 ), "campaign_count": len(campaign_budgets), "campaigns": campaign_budgets, @@ -819,19 +819,21 @@ def list_deals( deal_summaries = [] for d in deals: - deal_summaries.append({ - "deal_id": d["id"], - "display_name": d.get("display_name") or d.get("product_name") or "(unnamed)", - "status": d.get("status", "unknown"), - "deal_type": d.get("deal_type", "unknown"), - "media_type": d.get("media_type"), - "seller_org": d.get("seller_org"), - "seller_domain": d.get("seller_domain"), - "price": d.get("price"), - "impressions": d.get("impressions"), - "flight_start": d.get("flight_start"), - "flight_end": d.get("flight_end"), - }) + deal_summaries.append( + { + "deal_id": d["id"], + "display_name": d.get("display_name") or d.get("product_name") or "(unnamed)", + "status": d.get("status", "unknown"), + "deal_type": d.get("deal_type", "unknown"), + "media_type": d.get("media_type"), + "seller_org": d.get("seller_org"), + "seller_domain": d.get("seller_domain"), + "price": d.get("price"), + "impressions": d.get("impressions"), + "flight_start": d.get("flight_start"), + "flight_end": d.get("flight_end"), + } + ) result = { "total": len(deal_summaries), @@ -890,21 +892,21 @@ def search_deals(query: str) -> str: if value and query_lower in str(value).lower(): matched_fields.append(field_label) if matched_fields: - matches.append({ - "deal_id": deal["id"], - "display_name": ( - deal.get("display_name") - or deal.get("product_name") - or "(unnamed)" - ), - "status": deal.get("status", "unknown"), - "deal_type": deal.get("deal_type", "unknown"), - "media_type": deal.get("media_type"), - "seller_org": deal.get("seller_org"), - "seller_domain": deal.get("seller_domain"), - "price": deal.get("price"), - "matched_in": matched_fields, - }) + matches.append( + { + "deal_id": deal["id"], + "display_name": ( + deal.get("display_name") or deal.get("product_name") or "(unnamed)" + ), + "status": deal.get("status", "unknown"), + "deal_type": deal.get("deal_type", "unknown"), + "media_type": deal.get("media_type"), + "seller_org": deal.get("seller_org"), + "seller_domain": deal.get("seller_domain"), + "price": deal.get("price"), + "matched_in": matched_fields, + } + ) result = { "total": len(matches), @@ -948,17 +950,19 @@ async def discover_sellers(capability: str | None = None) -> str: seller_list = [] for seller in sellers: - seller_list.append({ - "agent_id": seller.agent_id, - "name": seller.name, - "url": seller.url, - "capabilities": [ - {"name": c.name, "description": c.description, "tags": c.tags} - for c in seller.capabilities - ], - "trust_level": seller.trust_level.value, - "protocols": seller.protocols, - }) + seller_list.append( + { + "agent_id": seller.agent_id, + "name": seller.name, + "url": seller.url, + "capabilities": [ + {"name": c.name, "description": c.description, "tags": c.tags} + for c in seller.capabilities + ], + "trust_level": seller.trust_level.value, + "protocols": seller.protocols, + } + ) result = { "total": len(seller_list), @@ -1001,18 +1005,20 @@ async def get_seller_media_kit(seller_url: str) -> str: packages = [] for pkg in kit.all_packages: - packages.append({ - "package_id": pkg.package_id, - "name": pkg.name, - "description": pkg.description, - "ad_formats": pkg.ad_formats, - "device_types": pkg.device_types, - "price_range": pkg.price_range, - "rate_type": pkg.rate_type, - "is_featured": pkg.is_featured, - "geo_targets": pkg.geo_targets, - "tags": pkg.tags, - }) + packages.append( + { + "package_id": pkg.package_id, + "name": pkg.name, + "description": pkg.description, + "ad_formats": pkg.ad_formats, + "device_types": pkg.device_types, + "price_range": pkg.price_range, + "rate_type": pkg.rate_type, + "is_featured": pkg.is_featured, + "geo_targets": pkg.geo_targets, + "tags": pkg.tags, + } + ) result = { "seller_name": kit.seller_name, @@ -1034,7 +1040,9 @@ async def get_seller_media_kit(seller_url: str) -> str: except Exception as exc: logger.warning( - "Unexpected error fetching media kit from %s: %s", seller_url, exc, + "Unexpected error fetching media kit from %s: %s", + seller_url, + exc, ) result = { "error": f"Unexpected error: {exc}", @@ -1074,34 +1082,40 @@ async def compare_sellers(seller_urls: list[str]) -> str: packages = [] for pkg in kit.all_packages: seller_formats.update(pkg.ad_formats) - packages.append({ - "package_id": pkg.package_id, - "name": pkg.name, - "price_range": pkg.price_range, - "ad_formats": pkg.ad_formats, - "rate_type": pkg.rate_type, - }) + packages.append( + { + "package_id": pkg.package_id, + "name": pkg.name, + "price_range": pkg.price_range, + "ad_formats": pkg.ad_formats, + "rate_type": pkg.rate_type, + } + ) all_ad_formats.update(seller_formats) total_packages += len(packages) - sellers_data.append({ - "seller_url": url, - "seller_name": kit.seller_name, - "total_packages": len(packages), - "ad_formats": sorted(seller_formats), - "packages": packages, - }) + sellers_data.append( + { + "seller_url": url, + "seller_name": kit.seller_name, + "total_packages": len(packages), + "ad_formats": sorted(seller_formats), + "packages": packages, + } + ) except (MediaKitError, Exception) as exc: logger.warning("Failed to fetch media kit from %s: %s", url, exc) - sellers_data.append({ - "seller_url": url, - "error": f"Failed to fetch media kit: {exc}", - "total_packages": 0, - "ad_formats": [], - "packages": [], - }) + sellers_data.append( + { + "seller_url": url, + "error": f"Failed to fetch media kit: {exc}", + "total_packages": 0, + "ad_formats": [], + "packages": [], + } + ) result = { "sellers_compared": len(seller_urls), @@ -1109,12 +1123,8 @@ async def compare_sellers(seller_urls: list[str]) -> str: "summary": { "total_packages_across_sellers": total_packages, "all_ad_formats": sorted(all_ad_formats), - "sellers_reachable": sum( - 1 for s in sellers_data if "error" not in s - ), - "sellers_unreachable": sum( - 1 for s in sellers_data if "error" in s - ), + "sellers_reachable": sum(1 for s in sellers_data if "error" not in s), + "sellers_unreachable": sum(1 for s in sellers_data if "error" in s), }, "timestamp": datetime.now(UTC).isoformat(), } @@ -1217,13 +1227,15 @@ def get_negotiation_status(deal_id: str) -> str: round_summaries = [] for r in rounds: - round_summaries.append({ - "round_number": r["round_number"], - "buyer_price": r["buyer_price"], - "seller_price": r["seller_price"], - "action": r["action"], - "rationale": r.get("rationale", ""), - }) + round_summaries.append( + { + "round_number": r["round_number"], + "buyer_price": r["buyer_price"], + "seller_price": r["seller_price"], + "action": r["action"], + "rationale": r.get("rationale", ""), + } + ) result = { "deal_id": deal_id, @@ -1567,21 +1579,27 @@ def create_deal_manual( tags=tags, ) except (ValueError, TypeError) as exc: - return json.dumps({ - "success": False, - "errors": [str(exc)], - "timestamp": datetime.now(UTC).isoformat(), - }, indent=2) + return json.dumps( + { + "success": False, + "errors": [str(exc)], + "timestamp": datetime.now(UTC).isoformat(), + }, + indent=2, + ) # Validate and prepare entry_result = create_manual_deal(entry) if not entry_result.success: - return json.dumps({ - "success": False, - "errors": entry_result.errors, - "timestamp": datetime.now(UTC).isoformat(), - }, indent=2) + return json.dumps( + { + "success": False, + "errors": entry_result.errors, + "timestamp": datetime.now(UTC).isoformat(), + }, + indent=2, + ) # Save the deal store = _get_deal_store() @@ -1590,9 +1608,7 @@ def create_deal_manual( # Save portfolio metadata tags_json = ( - json.dumps(entry_result.metadata["tags"]) - if entry_result.metadata.get("tags") - else None + json.dumps(entry_result.metadata["tags"]) if entry_result.metadata.get("tags") else None ) store.save_portfolio_metadata( deal_id=deal_id, @@ -1602,12 +1618,15 @@ def create_deal_manual( tags=tags_json, ) - return json.dumps({ - "success": True, - "deal_id": deal_id, - "display_name": display_name, - "timestamp": datetime.now(UTC).isoformat(), - }, indent=2) + return json.dumps( + { + "success": True, + "deal_id": deal_id, + "display_name": display_name, + "timestamp": datetime.now(UTC).isoformat(), + }, + indent=2, + ) finally: if _deal_store_override is None: store.disconnect() @@ -1644,16 +1663,19 @@ def get_portfolio_summary( total = len(deals) if total == 0: - return json.dumps({ - "total_deals": 0, - "total_value": 0.0, - "by_status": {}, - "by_deal_type": {}, - "by_media_type": {}, - "top_sellers": [], - "expiring_deals": [], - "timestamp": datetime.now(UTC).isoformat(), - }, indent=2) + return json.dumps( + { + "total_deals": 0, + "total_value": 0.0, + "by_status": {}, + "by_deal_type": {}, + "by_media_type": {}, + "top_sellers": [], + "expiring_deals": [], + "timestamp": datetime.now(UTC).isoformat(), + }, + indent=2, + ) # Count by status status_counts: dict[str, int] = {} @@ -1679,7 +1701,9 @@ def get_portfolio_summary( seller = deal.get("seller_org") or deal.get("seller_domain") or "Unknown" seller_counts[seller] = seller_counts.get(seller, 0) + 1 top_sellers = sorted( - seller_counts.items(), key=lambda x: x[1], reverse=True, + seller_counts.items(), + key=lambda x: x[1], + reverse=True, )[:top_sellers_count] # Total portfolio value: sum of (price * impressions / 1000) @@ -1702,15 +1726,15 @@ def get_portfolio_summary( continue flight_end = deal.get("flight_end") if flight_end and now_str <= flight_end <= cutoff_str: - expiring_deals.append({ - "deal_id": deal["id"], - "display_name": ( - deal.get("display_name") - or deal.get("product_name") - or "(unnamed)" - ), - "flight_end": flight_end, - }) + expiring_deals.append( + { + "deal_id": deal["id"], + "display_name": ( + deal.get("display_name") or deal.get("product_name") or "(unnamed)" + ), + "flight_end": flight_end, + } + ) result = { "total_deals": total, @@ -1718,10 +1742,7 @@ def get_portfolio_summary( "by_status": status_counts, "by_deal_type": type_counts, "by_media_type": media_counts, - "top_sellers": [ - {"seller": name, "deal_count": count} - for name, count in top_sellers - ], + "top_sellers": [{"seller": name, "deal_count": count} for name, count in top_sellers], "expiring_deals": expiring_deals, "timestamp": datetime.now(UTC).isoformat(), } @@ -1751,16 +1772,18 @@ def list_active_negotiations() -> str: deal_id = d["id"] rounds = store.get_negotiation_history(deal_id) - negotiations.append({ - "deal_id": deal_id, - "product_id": d.get("product_id", ""), - "product_name": d.get("product_name", ""), - "seller_url": d.get("seller_url", ""), - "price": d.get("price"), - "status": d.get("status", "negotiating"), - "rounds_count": len(rounds), - "created_at": d.get("created_at", ""), - }) + negotiations.append( + { + "deal_id": deal_id, + "product_id": d.get("product_id", ""), + "product_name": d.get("product_name", ""), + "seller_url": d.get("seller_url", ""), + "price": d.get("price"), + "status": d.get("status", "negotiating"), + "rounds_count": len(rounds), + "created_at": d.get("created_at", ""), + } + ) result = { "total": len(negotiations), @@ -1910,14 +1933,16 @@ def list_pending_approvals(campaign_id: str | None = None) -> str: pending = [] for row in rows: - pending.append({ - "approval_request_id": row["approval_request_id"], - "campaign_id": row["campaign_id"], - "stage": row["stage"], - "status": row["status"], - "requested_at": row["requested_at"], - "context": json.loads(row.get("context") or "{}"), - }) + pending.append( + { + "approval_request_id": row["approval_request_id"], + "campaign_id": row["campaign_id"], + "stage": row["stage"], + "status": row["status"], + "requested_at": row["requested_at"], + "context": json.loads(row.get("context") or "{}"), + } + ) result = { "total": len(pending), @@ -2029,10 +2054,12 @@ def list_api_keys() -> str: keys = [] for seller_url in sellers: raw_key = key_store.get_key(seller_url) - keys.append({ - "seller_url": seller_url, - "masked_key": _mask_key(raw_key) if raw_key else "****", - }) + keys.append( + { + "seller_url": seller_url, + "masked_key": _mask_key(raw_key) if raw_key else "****", + } + ) result = { "total": len(keys), @@ -2130,24 +2157,28 @@ def list_templates(template_type: str | None = None) -> str: if template_type is None or template_type == "deal": raw = store.list_deal_templates() for t in raw: - deal_templates.append({ - "template_id": t["id"], - "name": t["name"], - "deal_type_pref": t.get("deal_type_pref"), - "advertiser_id": t.get("advertiser_id"), - "max_cpm": t.get("max_cpm"), - "created_at": t.get("created_at"), - }) + deal_templates.append( + { + "template_id": t["id"], + "name": t["name"], + "deal_type_pref": t.get("deal_type_pref"), + "advertiser_id": t.get("advertiser_id"), + "max_cpm": t.get("max_cpm"), + "created_at": t.get("created_at"), + } + ) if template_type is None or template_type == "supply_path": raw = store.list_supply_path_templates() for t in raw: - spo_templates.append({ - "template_id": t["id"], - "name": t["name"], - "max_reseller_hops": t.get("max_reseller_hops"), - "created_at": t.get("created_at"), - }) + spo_templates.append( + { + "template_id": t["id"], + "name": t["name"], + "max_reseller_hops": t.get("max_reseller_hops"), + "created_at": t.get("created_at"), + } + ) result = { "deal_templates": deal_templates, @@ -2499,25 +2530,29 @@ def get_pacing_report(campaign_id: str) -> str: # Build channel pacing with full details channel_pacing = [] for ch in dashboard.channel_pacing: - channel_pacing.append({ - "channel": ch.channel, - "allocated_budget": ch.allocated_budget, - "spend": ch.spend, - "pacing_pct": ch.pacing_pct, - "impressions": ch.impressions, - "effective_cpm": ch.effective_cpm, - "fill_rate": ch.fill_rate, - }) + channel_pacing.append( + { + "channel": ch.channel, + "allocated_budget": ch.allocated_budget, + "spend": ch.spend, + "pacing_pct": ch.pacing_pct, + "impressions": ch.impressions, + "effective_cpm": ch.effective_cpm, + "fill_rate": ch.fill_rate, + } + ) # Build alerts alerts = [] for alert in dashboard.alerts: - alerts.append({ - "severity": alert.severity, - "message": alert.message, - "channel": alert.channel, - "deviation_pct": alert.deviation_pct, - }) + alerts.append( + { + "severity": alert.severity, + "message": alert.message, + "channel": alert.channel, + "deviation_pct": alert.deviation_pct, + } + ) result = { "campaign_id": campaign_id, @@ -2565,6 +2600,7 @@ def _get_ssp_connector_class(name: str) -> type | None: if class_name is None: return None import sys + module = sys.modules[__name__] return getattr(module, class_name, None) @@ -2591,12 +2627,14 @@ def list_ssp_connectors() -> str: continue instance = cls() required = instance.get_required_config() - connectors.append({ - "name": name, - "display_name": instance.ssp_name, - "configured": instance.is_configured(), - "required_env_vars": required, - }) + connectors.append( + { + "name": name, + "display_name": instance.ssp_name, + "configured": instance.is_configured(), + "required_env_vars": required, + } + ) result = { "total": len(connectors), @@ -2635,10 +2673,7 @@ def import_deals_ssp(ssp_name: str) -> str: known = ", ".join(sorted(_SSP_CLASS_NAMES.keys())) return json.dumps( { - "error": ( - f"Unknown SSP connector: '{ssp_name}'. " - f"Known connectors: {known}" - ), + "error": (f"Unknown SSP connector: '{ssp_name}'. Known connectors: {known}"), "timestamp": datetime.now(UTC).isoformat(), }, indent=2, @@ -2718,10 +2753,7 @@ def test_ssp_connection(ssp_name: str) -> str: known = ", ".join(sorted(_SSP_CLASS_NAMES.keys())) return json.dumps( { - "error": ( - f"Unknown SSP connector: '{ssp_name}'. " - f"Known connectors: {known}" - ), + "error": (f"Unknown SSP connector: '{ssp_name}'. Known connectors: {known}"), "timestamp": datetime.now(UTC).isoformat(), }, indent=2, @@ -2768,112 +2800,132 @@ def test_ssp_connection(ssp_name: str) -> str: @mcp.prompt(name="setup", description="First-time guided setup wizard") async def setup_prompt() -> list[Message]: - return [Message( - role="user", - content="Check my setup status and walk me through configuring everything " - "that's incomplete. Go step by step through all 8 wizard steps: " - "deployment, seller connections, credentials, buyer identity, deal " - "preferences, campaign defaults, approval gates, and review. " - "Ask me one question at a time.", - )] + return [ + Message( + role="user", + content="Check my setup status and walk me through configuring everything " + "that's incomplete. Go step by step through all 8 wizard steps: " + "deployment, seller connections, credentials, buyer identity, deal " + "preferences, campaign defaults, approval gates, and review. " + "Ask me one question at a time.", + ) + ] @mcp.prompt(name="status", description="Configuration and health overview") async def status_prompt() -> list[Message]: - return [Message( - role="user", - content="Show me a complete status overview: setup state, system health, " - "seller connections, database status, and any issues that need " - "attention.", - )] + return [ + Message( + role="user", + content="Show me a complete status overview: setup state, system health, " + "seller connections, database status, and any issues that need " + "attention.", + ) + ] @mcp.prompt(name="campaigns", description="Campaign portfolio with budget pacing") async def campaigns_prompt() -> list[Message]: - return [Message( - role="user", - content="Show me all my campaigns with their current status and budget " - "pacing. Highlight any campaigns that are behind or ahead on " - "pacing, and flag anything that needs attention. Include a budget " - "summary across all campaigns.", - )] + return [ + Message( + role="user", + content="Show me all my campaigns with their current status and budget " + "pacing. Highlight any campaigns that are behind or ahead on " + "pacing, and flag anything that needs attention. Include a budget " + "summary across all campaigns.", + ) + ] @mcp.prompt(name="deals", description="Deal portfolio dashboard") async def deals_prompt() -> list[Message]: - return [Message( - role="user", - content="Give me a full dashboard of my deal portfolio: total deals, " - "breakdown by status and deal type, top sellers, portfolio value, " - "and any deals expiring in the next 30 days. Include recent " - "activity.", - )] + return [ + Message( + role="user", + content="Give me a full dashboard of my deal portfolio: total deals, " + "breakdown by status and deal type, top sellers, portfolio value, " + "and any deals expiring in the next 30 days. Include recent " + "activity.", + ) + ] @mcp.prompt(name="discover", description="Find and compare seller agents") async def discover_prompt() -> list[Message]: - return [Message( - role="user", - content="Search the IAB registry for available seller agents. Show me " - "who's out there, what they offer, and their capabilities. If I'm " - "interested in specific sellers, help me compare their media kits " - "and pricing side by side.", - )] + return [ + Message( + role="user", + content="Search the IAB registry for available seller agents. Show me " + "who's out there, what they offer, and their capabilities. If I'm " + "interested in specific sellers, help me compare their media kits " + "and pricing side by side.", + ) + ] @mcp.prompt(name="negotiate", description="Negotiation status and actions") async def negotiate_prompt() -> list[Message]: - return [Message( - role="user", - content="Show me all active negotiations: where each one stands, how many " - "rounds we've been through, the current price positions, and what " - "action is needed next. If there are no active negotiations, help " - "me start one by discovering sellers and their inventory.", - )] + return [ + Message( + role="user", + content="Show me all active negotiations: where each one stands, how many " + "rounds we've been through, the current price positions, and what " + "action is needed next. If there are no active negotiations, help " + "me start one by discovering sellers and their inventory.", + ) + ] @mcp.prompt(name="orders", description="Active orders and execution status") async def orders_prompt() -> list[Message]: - return [Message( - role="user", - content="Show me all my orders: their current status, any pending " - "transitions, and orders that need my action. Group them by " - "status and highlight anything stuck or overdue.", - )] + return [ + Message( + role="user", + content="Show me all my orders: their current status, any pending " + "transitions, and orders that need my action. Group them by " + "status and highlight anything stuck or overdue.", + ) + ] @mcp.prompt(name="approvals", description="Pending approvals queue") async def approvals_prompt() -> list[Message]: - return [Message( - role="user", - content="Show me everything waiting for my approval: pending deal " - "approvals, campaign approvals, and any budget or order changes " - "that need my decision. Most urgent first. For each item, show " - "me the context I need to decide.", - )] + return [ + Message( + role="user", + content="Show me everything waiting for my approval: pending deal " + "approvals, campaign approvals, and any budget or order changes " + "that need my decision. Most urgent first. For each item, show " + "me the context I need to decide.", + ) + ] @mcp.prompt(name="configure", description="Settings, templates, and SSP connectors") async def configure_prompt() -> list[Message]: - return [Message( - role="user", - content="Show me my current configuration: deal and supply path templates, " - "SSP connector status, API keys (masked), and campaign defaults. " - "Help me create new templates, configure SSP connectors, or update " - "settings.", - )] + return [ + Message( + role="user", + content="Show me my current configuration: deal and supply path templates, " + "SSP connector status, API keys (masked), and campaign defaults. " + "Help me create new templates, configure SSP connectors, or update " + "settings.", + ) + ] @mcp.prompt(name="help", description="What can this agent do?") async def help_prompt() -> list[Message]: - return [Message( - role="user", - content="List everything I can do with this buyer agent, organized by " - "category. Include all slash commands with descriptions, and " - "summarize the tool categories: campaigns, deals, seller discovery, " - "negotiation, orders, approvals, templates, reporting, SSP " - "connectors, and API keys.", - )] + return [ + Message( + role="user", + content="List everything I can do with this buyer agent, organized by " + "category. Include all slash commands with descriptions, and " + "summarize the tool categories: campaigns, deals, seller discovery, " + "negotiation, orders, approvals, templates, reporting, SSP " + "connectors, and API keys.", + ) + ] # --------------------------------------------------------------------------- diff --git a/src/ad_buyer/models/audience_plan.py b/src/ad_buyer/models/audience_plan.py index d49e38e..1f4f7de 100644 --- a/src/ad_buyer/models/audience_plan.py +++ b/src/ad_buyer/models/audience_plan.py @@ -59,9 +59,9 @@ class ComplianceContext(BaseModel): default=None, description="Hash or signature carrying any required attestation", ) - embedding_provenance: Literal[ - "local_buyer", "advertiser_supplied", "hosted_external", "mock" - ] | None = Field( + embedding_provenance: ( + Literal["local_buyer", "advertiser_supplied", "hosted_external", "mock"] | None + ) = Field( default=None, description=( "Provenance of the embedding bytes (E2-7 Gap 6). Populated by " @@ -124,9 +124,7 @@ def _validate_compliance_for_agentic(self) -> AudienceRef: """ if self.type == "agentic" and self.compliance_context is None: - raise ValueError( - "AudienceRef.compliance_context is required when type='agentic'" - ) + raise ValueError("AudienceRef.compliance_context is required when type='agentic'") return self @model_validator(mode="after") @@ -137,9 +135,7 @@ def _validate_confidence_provenance(self) -> AudienceRef: """ if self.source == "explicit" and self.confidence is not None: - raise ValueError( - "AudienceRef.confidence must be None when source='explicit'" - ) + raise ValueError("AudienceRef.confidence must be None when source='explicit'") return self @@ -601,10 +597,7 @@ def __init__(self, issues: list[dict[str, Any]]) -> None: if not issues: msg = "Global agentic refs are unsupported (no specific issues)" else: - heads = [ - f"{i['role']}[{i['index']}] id={i['identifier']!r}" - for i in issues - ] + heads = [f"{i['role']}[{i['index']}] id={i['identifier']!r}" for i in issues] msg = ( "Brief carries agentic refs with jurisdiction='GLOBAL', " "which is unsupported until per-jurisdiction consent fan-out " @@ -626,8 +619,7 @@ def __init__(self, issues: list[dict[str, Any]]) -> None: msg = "Content Taxonomy migration required (no specific issues)" else: heads = [ - f"{i['role']}[{i['index']}] id={i['identifier']!r} " - f"version={i['version']!r}" + f"{i['role']}[{i['index']}] id={i['identifier']!r} version={i['version']!r}" for i in issues ] msg = ( diff --git a/src/ad_buyer/orchestration/audience_degradation.py b/src/ad_buyer/orchestration/audience_degradation.py index 5d64202..184c0fa 100644 --- a/src/ad_buyer/orchestration/audience_degradation.py +++ b/src/ad_buyer/orchestration/audience_degradation.py @@ -188,9 +188,7 @@ def legacy_default(cls) -> SellerAudienceCapabilities: supports_constraints=False, supports_extensions=False, supports_exclusions=False, - max_refs_per_role=_MaxRefsPerRole( - primary=1, constraints=0, extensions=0, exclusions=0 - ), + max_refs_per_role=_MaxRefsPerRole(primary=1, constraints=0, extensions=0, exclusions=0), ) @@ -346,8 +344,7 @@ def _trim_to_max( DegradationLogEntry( path=f"{role}[{idx}]", reason=( - f"max_refs_per_role.{role}={max_for_role} exceeded " - f"(plan had {len(refs)} refs)" + f"max_refs_per_role.{role}={max_for_role} exceeded (plan had {len(refs)} refs)" ), original_ref=_ref_dump(ref), action="dropped", @@ -401,9 +398,7 @@ def degrade_plan_for_seller( # Agentic primary is a special case: dropping it leaves the plan with no # primary at all, which is fatal. primary = plan.primary - primary_kept = _filter_refs( - [primary], role="primary", capabilities=capabilities, log=log - ) + primary_kept = _filter_refs([primary], role="primary", capabilities=capabilities, log=log) if not primary_kept: # The most recent log entry describes why the primary was dropped. last_reason = log[-1].reason if log else "primary ref unsupported" @@ -418,9 +413,7 @@ def degrade_plan_for_seller( # ---- constraints ---- if plan.constraints: if not capabilities.supports_constraints: - constraints = _drop_role_unsupported( - plan.constraints, role="constraints", log=log - ) + constraints = _drop_role_unsupported(plan.constraints, role="constraints", log=log) else: constraints = _filter_refs( plan.constraints, @@ -440,9 +433,7 @@ def degrade_plan_for_seller( # ---- extensions ---- if plan.extensions: if not capabilities.supports_extensions: - extensions = _drop_role_unsupported( - plan.extensions, role="extensions", log=log - ) + extensions = _drop_role_unsupported(plan.extensions, role="extensions", log=log) else: extensions = _filter_refs( plan.extensions, @@ -462,9 +453,7 @@ def degrade_plan_for_seller( # ---- exclusions ---- if plan.exclusions: if not capabilities.supports_exclusions: - exclusions = _drop_role_unsupported( - plan.exclusions, role="exclusions", log=log - ) + exclusions = _drop_role_unsupported(plan.exclusions, role="exclusions", log=log) else: exclusions = _filter_refs( plan.exclusions, @@ -534,10 +523,7 @@ def synthesize_capabilities_from_unsupported( A `SellerAudienceCapabilities` with the relevant flags toggled off. """ - caps = ( - base.model_copy(deep=True) if base is not None - else SellerAudienceCapabilities() - ) + caps = base.model_copy(deep=True) if base is not None else SellerAudienceCapabilities() for entry in unsupported: path = (entry.get("path") or "").strip() diff --git a/src/ad_buyer/orchestration/multi_seller.py b/src/ad_buyer/orchestration/multi_seller.py index cdf99aa..c579503 100644 --- a/src/ad_buyer/orchestration/multi_seller.py +++ b/src/ad_buyer/orchestration/multi_seller.py @@ -124,9 +124,7 @@ def _entry_role(entry: DegradationLogEntry) -> str: path = entry.path or "" for prefix, role in _ROLE_PREFIXES: - if path == prefix or path.startswith(f"{prefix}.") or path.startswith( - f"{prefix}[" - ): + if path == prefix or path.startswith(f"{prefix}.") or path.startswith(f"{prefix}["): return role return "primary" @@ -410,17 +408,13 @@ async def _emit( ) await self._event_bus.publish(event) except Exception as exc: # noqa: BLE001 - event emission is fail-open by design - logger.warning( - "Failed to emit event %s: %s", event_type, exc - ) + logger.warning("Failed to emit event %s: %s", event_type, exc) # ------------------------------------------------------------------ # Stage 1: Discover sellers # ------------------------------------------------------------------ - async def discover_sellers( - self, requirements: InventoryRequirements - ) -> list[AgentCard]: + async def discover_sellers(self, requirements: InventoryRequirements) -> list[AgentCard]: """Discover qualifying sellers from the agent registry. Queries the registry for sellers matching the media type and @@ -443,16 +437,10 @@ async def discover_sellers( # Filter out excluded sellers excluded_set = set(requirements.excluded_sellers) - sellers = [ - s for s in sellers - if s.agent_id not in excluded_set - ] + sellers = [s for s in sellers if s.agent_id not in excluded_set] # Filter out blocked sellers - sellers = [ - s for s in sellers - if s.trust_level != TrustLevel.BLOCKED - ] + sellers = [s for s in sellers if s.trust_level != TrustLevel.BLOCKED] # Emit discovery event await self._emit( @@ -519,7 +507,11 @@ async def _request_one(seller: AgentCard) -> SellerQuoteResult: timeout=self._quote_timeout, ) - cpm_display = f"{quote.pricing.final_cpm:.2f}" if quote.pricing.final_cpm is not None else "unavailable" # noqa: E501 + cpm_display = ( + f"{quote.pricing.final_cpm:.2f}" + if quote.pricing.final_cpm is not None + else "unavailable" + ) # noqa: E501 logger.info( "Received quote %s from seller %s (CPM: %s)", quote.quote_id, @@ -537,9 +529,7 @@ async def _request_one(seller: AgentCard) -> SellerQuoteResult: except TimeoutError: msg = f"Quote request timed out after {self._quote_timeout}s" - logger.warning( - "Seller %s timed out on quote request", seller.agent_id - ) + logger.warning("Seller %s timed out on quote request", seller.agent_id) return SellerQuoteResult( seller_id=seller.agent_id, seller_url=seller_url, @@ -612,9 +602,7 @@ async def evaluate_and_rank( (best quote first). """ # Filter to successful quotes only - valid_results = [ - r for r in quote_results if r.quote is not None - ] + valid_results = [r for r in quote_results if r.quote is not None] if not valid_results: return [] @@ -630,8 +618,7 @@ async def evaluate_and_rank( # Apply max CPM filter (skip unpriced quotes — they have effective_cpm=None) if max_cpm is not None: ranked = [ - nq for nq in ranked - if nq.effective_cpm is not None and nq.effective_cpm <= max_cpm + nq for nq in ranked if nq.effective_cpm is not None and nq.effective_cpm <= max_cpm ] logger.info( @@ -704,8 +691,7 @@ async def select_and_book( # Skip if minimum spend exceeds remaining budget if nq.minimum_spend > 0 and nq.minimum_spend > remaining_budget: logger.info( - "Skipping quote %s: minimum spend %.2f exceeds " - "remaining budget %.2f", + "Skipping quote %s: minimum spend %.2f exceeds remaining budget %.2f", nq.quote_id, nq.minimum_spend, remaining_budget, @@ -714,21 +700,18 @@ async def select_and_book( seller_url = quote_seller_map.get(nq.quote_id) if seller_url is None: - logger.warning( - "No seller URL for quote %s, skipping", nq.quote_id + logger.warning("No seller URL for quote %s, skipping", nq.quote_id) + failed_bookings.append( + { + "quote_id": nq.quote_id, + "error": "No seller URL mapping found", + } ) - failed_bookings.append({ - "quote_id": nq.quote_id, - "error": "No seller URL mapping found", - }) continue try: client = self._deals_client_factory(seller_url) - if ( - self._capability_client is not None - and audience_plan is not None - ): + if self._capability_client is not None and audience_plan is not None: deal, deg_log = await self._book_with_preflight_then_retry( client=client, quote_id=nq.quote_id, @@ -765,7 +748,11 @@ async def select_and_book( }, ) - deal_cpm_display = f"{deal.pricing.final_cpm:.2f}" if deal.pricing.final_cpm is not None else "unavailable" # noqa: E501 + deal_cpm_display = ( + f"{deal.pricing.final_cpm:.2f}" + if deal.pricing.final_cpm is not None + else "unavailable" + ) # noqa: E501 logger.info( "Booked deal %s from seller %s (CPM: %s)", deal.deal_id, @@ -786,12 +773,14 @@ async def select_and_book( ) if nq.seller_id not in incompatible_sellers: incompatible_sellers.append(nq.seller_id) - failed_bookings.append({ - "quote_id": nq.quote_id, - "error": str(exc), - "error_code": "audience_plan_unsupported", - "seller_id": nq.seller_id, - }) + failed_bookings.append( + { + "quote_id": nq.quote_id, + "error": str(exc), + "error_code": "audience_plan_unsupported", + "seller_id": nq.seller_id, + } + ) except Exception as exc: # noqa: BLE001 - per-deal isolation; continue booking remaining deals logger.warning( @@ -799,10 +788,12 @@ async def select_and_book( nq.quote_id, exc, ) - failed_bookings.append({ - "quote_id": nq.quote_id, - "error": str(exc), - }) + failed_bookings.append( + { + "quote_id": nq.quote_id, + "error": str(exc), + } + ) return DealSelection( booked_deals=booked_deals, @@ -893,9 +884,7 @@ async def _book_with_audience_retry( # Synthesize what the seller doesn't support, run degradation, retry. try: caps = synthesize_capabilities_from_unsupported(unsupported) - degraded_plan, degradation_log = degrade_plan_for_seller( - audience_plan, caps - ) + degraded_plan, degradation_log = degrade_plan_for_seller(audience_plan, caps) except CannotFulfillPlan as cfp: # Degradation stripped the primary -- no usable plan to retry. # Record the would-be degradation log so the audit trail still @@ -915,8 +904,7 @@ async def _book_with_audience_retry( }, ) raise _SellerIncompatibleForCampaign( - f"Cannot reconcile audience_plan with seller {seller_id}: " - f"{cfp.reason}" + f"Cannot reconcile audience_plan with seller {seller_id}: {cfp.reason}" ) from cfp retry_request = DealBookingRequest( @@ -930,13 +918,11 @@ async def _book_with_audience_retry( # this campaign so the higher-level error path can route around # it. We do NOT auto-route here. raise _SellerIncompatibleForCampaign( - f"Seller {seller_id} rejected even the degraded plan: " - f"{retry_exc}" + f"Seller {seller_id} rejected even the degraded plan: {retry_exc}" ) from retry_exc logger.info( - "Booked deal from seller %s on retry after degrading " - "audience_plan (%d log entries)", + "Booked deal from seller %s on retry after degrading audience_plan (%d log entries)", seller_id, len(degradation_log), ) @@ -1024,8 +1010,8 @@ async def _book_with_preflight_then_retry( assert self._capability_client is not None # narrowed by select_and_book # ---- 1. capability discovery ---- - discovery: CapabilityDiscoveryResult = ( - await self._capability_client.discover_capabilities(seller_url) + discovery: CapabilityDiscoveryResult = await self._capability_client.discover_capabilities( + seller_url ) # Audit: every pre-flight call lands in the trail keyed by plan id. @@ -1053,8 +1039,7 @@ async def _book_with_preflight_then_retry( # Pre-flight stripped the primary entirely. No retry would help; # the seller advertises caps that can't carry this campaign. raise _SellerIncompatibleForCampaign( - f"Pre-flight: seller {seller_id} cannot fulfill plan: " - f"{cfp.reason}" + f"Pre-flight: seller {seller_id} cannot fulfill plan: {cfp.reason}" ) from cfp # ---- 3. apply strictness gate ---- @@ -1066,9 +1051,7 @@ async def _book_with_preflight_then_retry( quote_id, reason, ) - raise _SellerIncompatibleForCampaign( - f"Pre-flight strictness gate: {reason}" - ) + raise _SellerIncompatibleForCampaign(f"Pre-flight strictness gate: {reason}") # Audit: when the pre-flight produced any drops we record them. # The retry path emits its own degradation event keyed by the same @@ -1087,8 +1070,7 @@ async def _book_with_preflight_then_retry( }, ) logger.info( - "Pre-flight degraded plan for seller=%s quote=%s " - "(%d log entries)", + "Pre-flight degraded plan for seller=%s quote=%s (%d log entries)", seller_id, quote_id, len(preflight_log), diff --git a/src/ad_buyer/pipelines/audience_planner_reasoning.py b/src/ad_buyer/pipelines/audience_planner_reasoning.py index 49514df..cbd8d58 100644 --- a/src/ad_buyer/pipelines/audience_planner_reasoning.py +++ b/src/ad_buyer/pipelines/audience_planner_reasoning.py @@ -76,30 +76,68 @@ # Demographic / intent-driven tokens => prefer Standard primary. _DEMOGRAPHIC_TOKENS = { - "men", "women", "male", "female", - "kids", "children", "parent", "parents", - "millennials", "gen z", "gen x", "boomers", "seniors", - "household", "households", - "intender", "intenders", "in-market", "in market", - "demographic", "age", "income", + "men", + "women", + "male", + "female", + "kids", + "children", + "parent", + "parents", + "millennials", + "gen z", + "gen x", + "boomers", + "seniors", + "household", + "households", + "intender", + "intenders", + "in-market", + "in market", + "demographic", + "age", + "income", } # Content-adjacent tokens => prefer Contextual primary. _CONTEXTUAL_TOKENS = { - "content", "adjacent", "alongside", "next to", - "premium", "automotive content", "automotive blog", - "news", "sports", "lifestyle", "category", - "context", "contextual", + "content", + "adjacent", + "alongside", + "next to", + "premium", + "automotive content", + "automotive blog", + "news", + "sports", + "lifestyle", + "category", + "context", + "contextual", } # First-party / lookalike tokens => prefer Agentic primary. _AGENTIC_TOKENS = { - "our converters", "our customers", "our buyers", - "lookalike", "look-alike", "look alike", - "first-party", "first party", "1p data", "1p audience", - "previous campaign", "last campaign", "past campaign", - "crm", "advertiser data", "advertiser-supplied", - "high-ltv", "high ltv", "ltv lookalike", + "our converters", + "our customers", + "our buyers", + "lookalike", + "look-alike", + "look alike", + "first-party", + "first party", + "1p data", + "1p audience", + "previous campaign", + "last campaign", + "past campaign", + "crm", + "advertiser data", + "advertiser-supplied", + "high-ltv", + "high ltv", + "ltv lookalike", } @@ -162,12 +200,7 @@ class ClassificationResult: unmatched_tokens: list[str] = field(default_factory=list) def is_empty(self) -> bool: - return not ( - self.standard - or self.contextual - or self.agentic_seeds - or self.unmatched_tokens - ) + return not (self.standard or self.contextual or self.agentic_seeds or self.unmatched_tokens) @dataclass @@ -420,40 +453,55 @@ def pick_primary( if chosen == "standard" and classification.standard: cand = classification.standard[0] - return cand.to_ref(source="resolved", confidence=cand.score), "standard", ( - f"primary=Standard (id={cand.identifier} {cand.name!r}); " - "demographic / intent-driven brief" + return ( + cand.to_ref(source="resolved", confidence=cand.score), + "standard", + ( + f"primary=Standard (id={cand.identifier} {cand.name!r}); " + "demographic / intent-driven brief" + ), ) if chosen == "contextual" and classification.contextual: cand = classification.contextual[0] - return cand.to_ref(source="resolved", confidence=cand.score), "contextual", ( - f"primary=Contextual (id={cand.identifier} {cand.name!r}); " - "content-adjacent brief" + return ( + cand.to_ref(source="resolved", confidence=cand.score), + "contextual", + (f"primary=Contextual (id={cand.identifier} {cand.name!r}); content-adjacent brief"), ) if chosen == "agentic" and classification.agentic_seeds: # The caller will mint via EmbeddingMintTool; we return None for # the ref but signal the choice via the type. - return None, "agentic", ( - f"primary=Agentic (seed={classification.agentic_seeds[0]!r}); " - "first-party / lookalike-driven brief" + return ( + None, + "agentic", + ( + f"primary=Agentic (seed={classification.agentic_seeds[0]!r}); " + "first-party / lookalike-driven brief" + ), ) # Fallbacks: pick whatever we have. if classification.standard: cand = classification.standard[0] - return cand.to_ref(source="resolved", confidence=cand.score), "standard", ( - f"primary=Standard (id={cand.identifier}, fallback)" + return ( + cand.to_ref(source="resolved", confidence=cand.score), + "standard", + (f"primary=Standard (id={cand.identifier}, fallback)"), ) if classification.contextual: cand = classification.contextual[0] - return cand.to_ref(source="resolved", confidence=cand.score), "contextual", ( - f"primary=Contextual (id={cand.identifier}, fallback)" + return ( + cand.to_ref(source="resolved", confidence=cand.score), + "contextual", + (f"primary=Contextual (id={cand.identifier}, fallback)"), ) if classification.agentic_seeds: - return None, "agentic", ( - f"primary=Agentic (seed={classification.agentic_seeds[0]!r}, fallback)" + return ( + None, + "agentic", + (f"primary=Agentic (seed={classification.agentic_seeds[0]!r}, fallback)"), ) return None, "none", "no usable audience signals found" @@ -584,7 +632,7 @@ def add_extensions( description=seed, ) # We need source=inferred (mint tool emits source=inferred - # already, but be defensive in case that ever changes). + # already, but be defensive in case that ever changes). if ref.source != "inferred": ref = ref.model_copy(update={"source": "inferred"}) key = (ref.type, ref.identifier) @@ -616,9 +664,7 @@ def add_extensions( break if not refs: - rationale.append( - "no extensions added -- no broader candidates available" - ) + rationale.append("no extensions added -- no broader candidates available") return refs, rationale @@ -685,18 +731,13 @@ def validate_plan( "primary_type": plan_refs.get("primary_type"), } coverage_tool._run(targeting=targeting) - rationale.append( - "validation: discovery + coverage estimates ran successfully" - ) + rationale.append("validation: discovery + coverage estimates ran successfully") except Exception as exc: # noqa: BLE001 - tolerate tool flakiness rationale.append( - f"validation: coverage tool raised {type(exc).__name__}; " - "reach estimate skipped" + f"validation: coverage tool raised {type(exc).__name__}; reach estimate skipped" ) elif discovery_available: - rationale.append( - "validation: discovery ran; coverage tool not provided" - ) + rationale.append("validation: discovery ran; coverage tool not provided") return discovery_available, rationale @@ -761,8 +802,7 @@ def run_audience_reasoning( brief_notes = getattr(brief, "notes", None) if brief_audience is None and not brief_description and not brief_notes: rationale_lines.append( - "no target_audience and no advertiser context on brief; " - "needs human review" + "no target_audience and no advertiser context on brief; needs human review" ) return ReasoningResult( plan=None, @@ -785,8 +825,7 @@ def run_audience_reasoning( primary_type = primary_ref.type if primary_ref.source == "explicit": rationale_lines.append( - f"primary=preserved (explicit {primary_type} " - f"{primary_ref.identifier})" + f"primary=preserved (explicit {primary_type} {primary_ref.identifier})" ) else: rationale_lines.append( @@ -797,8 +836,7 @@ def run_audience_reasoning( # No brief plan at all -- compose from classification. if classification.is_empty(): rationale_lines.append( - "no audience signals classified from advertiser context; " - "needs human review" + "no audience signals classified from advertiser context; needs human review" ) return ReasoningResult( plan=None, @@ -821,8 +859,7 @@ def run_audience_reasoning( ) primary_ref = minted rationale_lines.append( - f"primary minted from seed {seed!r} -> " - f"{minted.identifier[:32]}..." + f"primary minted from seed {seed!r} -> {minted.identifier[:32]}..." ) except Exception as exc: # noqa: BLE001 logger.warning( @@ -838,22 +875,18 @@ def run_audience_reasoning( primary_ref = cand.to_ref(source="resolved", confidence=cand.score) primary_type = "standard" rationale_lines.append( - f"primary=Standard fallback {cand.identifier} " - "(agentic mint unavailable)" + f"primary=Standard fallback {cand.identifier} (agentic mint unavailable)" ) elif classification.contextual: cand = classification.contextual[0] primary_ref = cand.to_ref(source="resolved", confidence=cand.score) primary_type = "contextual" rationale_lines.append( - f"primary=Contextual fallback {cand.identifier} " - "(agentic mint unavailable)" + f"primary=Contextual fallback {cand.identifier} (agentic mint unavailable)" ) if primary_ref is None: - rationale_lines.append( - "could not compose primary ref; needs human review" - ) + rationale_lines.append("could not compose primary ref; needs human review") return ReasoningResult( plan=None, rationale_lines=rationale_lines, @@ -879,9 +912,7 @@ def run_audience_reasoning( # Phases 3 and 4: orient by KPI, then enrich. orientation = _kpi_orientation(brief) - rationale_lines.append( - f"KPI orientation: {orientation} (objective={brief.objective.value})" - ) + rationale_lines.append(f"KPI orientation: {orientation} (objective={brief.objective.value})") inferred_constraints: list[AudienceRef] = [] inferred_extensions: list[AudienceRef] = [] diff --git a/src/ad_buyer/pipelines/audience_planner_step.py b/src/ad_buyer/pipelines/audience_planner_step.py index cc484a3..d6ea187 100644 --- a/src/ad_buyer/pipelines/audience_planner_step.py +++ b/src/ad_buyer/pipelines/audience_planner_step.py @@ -177,8 +177,7 @@ def run_audience_planner_step( if reasoning.plan is None: logger.warning( - "audience_planner_step: reasoning produced no plan; " - "rationale=%s", + "audience_planner_step: reasoning produced no plan; rationale=%s", " | ".join(reasoning.rationale_lines), ) else: diff --git a/src/ad_buyer/pipelines/campaign_pipeline.py b/src/ad_buyer/pipelines/campaign_pipeline.py index 682fbc7..97691b4 100644 --- a/src/ad_buyer/pipelines/campaign_pipeline.py +++ b/src/ad_buyer/pipelines/campaign_pipeline.py @@ -189,17 +189,13 @@ async def _emit( ) await self._event_bus.publish(event) except Exception as exc: # noqa: BLE001 - event emission is fail-open by design - logger.warning( - "Failed to emit event %s: %s", event_type, exc - ) + logger.warning("Failed to emit event %s: %s", event_type, exc) # ------------------------------------------------------------------ # Stage 1: Ingest brief # ------------------------------------------------------------------ - async def ingest_brief( - self, brief_input: str | dict[str, Any] - ) -> str: + async def ingest_brief(self, brief_input: str | dict[str, Any]) -> str: """Parse and validate a campaign brief, create campaign in DRAFT. Args: @@ -229,9 +225,7 @@ async def ingest_brief( if brief.target_audience is None: target_audience_json = json.dumps(None) else: - target_audience_json = json.dumps( - brief.target_audience.model_dump(mode="json") - ) + target_audience_json = json.dumps(brief.target_audience.model_dump(mode="json")) store_brief = { "advertiser_id": brief.advertiser_id, "campaign_name": brief.campaign_name, @@ -239,9 +233,7 @@ async def ingest_brief( "currency": brief.currency, "flight_start": brief.flight_start.isoformat(), "flight_end": brief.flight_end.isoformat(), - "channels": json.dumps( - [ch.model_dump(mode="json") for ch in brief.channels] - ), + "channels": json.dumps([ch.model_dump(mode="json") for ch in brief.channels]), "target_audience": target_audience_json, } @@ -251,13 +243,9 @@ async def ingest_brief( [g.model_dump(mode="json") for g in brief.target_geo] ) if brief.kpis: - store_brief["kpis"] = json.dumps( - [k.model_dump(mode="json") for k in brief.kpis] - ) + store_brief["kpis"] = json.dumps([k.model_dump(mode="json") for k in brief.kpis]) if brief.brand_safety: - store_brief["brand_safety"] = json.dumps( - brief.brand_safety.model_dump(mode="json") - ) + store_brief["brand_safety"] = json.dumps(brief.brand_safety.model_dump(mode="json")) if brief.approval_config: store_brief["approval_config"] = json.dumps( brief.approval_config.model_dump(mode="json") @@ -324,14 +312,16 @@ async def plan_campaign(self, campaign_id: str) -> CampaignPlan: deal_types = _CHANNEL_DEAL_TYPES.get(ch.channel, ["PD"]) budget = round(brief.total_budget * ch.budget_pct / 100.0, 2) - channel_plans.append(ChannelPlan( - channel=ch.channel, - budget=budget, - budget_pct=ch.budget_pct, - media_type=media_type, - deal_types=deal_types, - format_prefs=ch.format_prefs, - )) + channel_plans.append( + ChannelPlan( + channel=ch.channel, + budget=budget, + budget_pct=ch.budget_pct, + media_type=media_type, + deal_types=deal_types, + format_prefs=ch.format_prefs, + ) + ) # Run the Audience Planner step BEFORE building the CampaignPlan # so the resolved plan rides on `target_audience` from this point @@ -385,9 +375,7 @@ async def plan_campaign(self, campaign_id: str) -> CampaignPlan: # Stage 3: Execute booking # ------------------------------------------------------------------ - async def execute_booking( - self, campaign_id: str - ) -> dict[str, OrchestrationResult]: + async def execute_booking(self, campaign_id: str) -> dict[str, OrchestrationResult]: """Transition to BOOKING and orchestrate deals for each channel. For each channel in the plan, invokes MultiSellerOrchestrator @@ -418,10 +406,7 @@ async def execute_booking( # Get the cached plan plan = self._plans.get(campaign_id) if plan is None: - raise KeyError( - f"No plan found for campaign {campaign_id}. " - "Call plan_campaign() first." - ) + raise KeyError(f"No plan found for campaign {campaign_id}. Call plan_campaign() first.") # Get the brief for excluded_sellers and other params brief = self._briefs.get(campaign_id) @@ -442,9 +427,7 @@ async def execute_booking( deal_types=cp.deal_types, excluded_sellers=excluded_sellers, max_cpm=( - brief.deal_preferences.max_cpm - if brief and brief.deal_preferences - else None + brief.deal_preferences.max_cpm if brief and brief.deal_preferences else None ), audience_plan=plan.target_audience, ) @@ -477,9 +460,7 @@ async def execute_booking( ) except Exception as exc: # noqa: BLE001 - per-channel isolation; one failure must not abort pipeline - logger.warning( - "Channel %s booking failed: %s", channel_key, exc - ) + logger.warning("Channel %s booking failed: %s", channel_key, exc) # Record empty result for failed channels rather than # aborting the entire pipeline results[channel_key] = OrchestrationResult( @@ -498,12 +479,8 @@ async def execute_booking( self._booking_results[campaign_id] = results # Emit booking completed event - total_deals = sum( - len(r.selection.booked_deals) for r in results.values() - ) - total_spend = sum( - r.selection.total_spend for r in results.values() - ) + total_deals = sum(len(r.selection.booked_deals) for r in results.values()) + total_spend = sum(r.selection.total_spend for r in results.values()) await self._emit( EventType.CAMPAIGN_BOOKING_COMPLETED, @@ -569,9 +546,7 @@ async def finalize(self, campaign_id: str) -> None: # End-to-end: run # ------------------------------------------------------------------ - async def run( - self, brief_input: str | dict[str, Any] - ) -> dict[str, Any]: + async def run(self, brief_input: str | dict[str, Any]) -> dict[str, Any]: """Run the complete pipeline: ingest -> plan -> book -> finalize. Args: @@ -601,9 +576,7 @@ async def run( for ch_key, result in booking_results.items(): channels_summary[ch_key] = { "deals_booked": len(result.selection.booked_deals), - "deal_ids": [ - d.deal_id for d in result.selection.booked_deals - ], + "deal_ids": [d.deal_id for d in result.selection.booked_deals], "total_spend": result.selection.total_spend, "remaining_budget": result.selection.remaining_budget, "sellers_discovered": len(result.discovered_sellers), @@ -620,8 +593,7 @@ async def run( } logger.info( - "Pipeline complete: campaign %s is READY " - "(%d channels, %d total deals)", + "Pipeline complete: campaign %s is READY (%d channels, %d total deals)", campaign_id, len(channels_summary), sum(ch["deals_booked"] for ch in channels_summary.values()), @@ -632,9 +604,7 @@ async def run( # Public accessors (Audience Planner introspection) # ------------------------------------------------------------------ - def get_audience_planner_result( - self, campaign_id: str - ) -> AudiencePlannerResult | None: + def get_audience_planner_result(self, campaign_id: str) -> AudiencePlannerResult | None: """Return the Audience Planner output for `campaign_id`, if any. Populated by `plan_campaign`. Returns None when planning has not @@ -681,22 +651,22 @@ def _reconstruct_brief(self, campaign: dict[str, Any]) -> CampaignBrief: source_context="campaign_pipeline._reconstruct_brief", ) - return parse_campaign_brief({ - "advertiser_id": campaign["advertiser_id"], - "campaign_name": campaign["campaign_name"], - "objective": "AWARENESS", # default when not stored - "total_budget": campaign["total_budget"], - "currency": campaign.get("currency", "USD"), - "flight_start": campaign["flight_start"], - "flight_end": campaign["flight_end"], - "channels": channels_raw or [], - "target_audience": audience_raw, - }) + return parse_campaign_brief( + { + "advertiser_id": campaign["advertiser_id"], + "campaign_name": campaign["campaign_name"], + "objective": "AWARENESS", # default when not stored + "total_budget": campaign["total_budget"], + "currency": campaign.get("currency", "USD"), + "flight_start": campaign["flight_start"], + "flight_end": campaign["flight_end"], + "channels": channels_raw or [], + "target_audience": audience_raw, + } + ) @staticmethod - def _estimate_impressions( - budget: float, assumed_cpm: float | None = None - ) -> int: + def _estimate_impressions(budget: float, assumed_cpm: float | None = None) -> int: """Estimate impression count from budget and CPM. When no CPM is available (assumed_cpm is None), returns 0 diff --git a/src/ad_buyer/services/setup_wizard.py b/src/ad_buyer/services/setup_wizard.py index c73412a..47bdd66 100644 --- a/src/ad_buyer/services/setup_wizard.py +++ b/src/ad_buyer/services/setup_wizard.py @@ -163,9 +163,7 @@ def progress_pct(self) -> float: @property def current_phase(self) -> WizardPhase: """Current phase based on developer step completion.""" - developer_steps = [ - s for s in self.steps if s.phase == WizardPhase.DEVELOPER - ] + developer_steps = [s for s in self.steps if s.phase == WizardPhase.DEVELOPER] all_dev_done = all( s.status in ( @@ -378,9 +376,7 @@ def get_step(self, step_number: int) -> WizardStep: ValueError: If step_number is not 1-8. """ if step_number < 1 or step_number > 8: - raise ValueError( - f"Invalid step number: {step_number}. Must be 1-8." - ) + raise ValueError(f"Invalid step number: {step_number}. Must be 1-8.") return self._steps[step_number - 1] def get_state(self) -> WizardState: @@ -389,9 +385,7 @@ def get_state(self) -> WizardState: # -- Step operations ---------------------------------------------------- - def complete_step( - self, step_number: int, config: dict[str, Any] - ) -> WizardStep: + def complete_step(self, step_number: int, config: dict[str, Any]) -> WizardStep: """Mark a step as completed with the given configuration. Args: @@ -426,9 +420,7 @@ def skip_step(self, step_number: int) -> WizardStep: """ step = self.get_step(step_number) if step_number == 8: - raise ValueError( - "Step 8 (Review & Launch) cannot be skipped." - ) + raise ValueError("Step 8 (Review & Launch) cannot be skipped.") step.status = WizardStepStatus.SKIPPED step.config = dict(step.defaults) step.completed_at = datetime.now(UTC).isoformat() diff --git a/src/ad_buyer/storage/audience_audit_log.py b/src/ad_buyer/storage/audience_audit_log.py index 3387d6f..0a95737 100644 --- a/src/ad_buyer/storage/audience_audit_log.py +++ b/src/ad_buyer/storage/audience_audit_log.py @@ -220,9 +220,7 @@ def log_event( """ if not plan_id: - logger.warning( - "audience_audit_log.log_event called with empty plan_id; skipping" - ) + logger.warning("audience_audit_log.log_event called with empty plan_id; skipping") return if event_type not in KNOWN_EVENT_TYPES: @@ -260,8 +258,7 @@ def log_event( conn.commit() except sqlite3.Error as exc: # noqa: BLE001 -- audit log is fail-open logger.warning( - "audience_audit_log.log_event: failed to insert " - "plan_id=%s event_type=%s: %s", + "audience_audit_log.log_event: failed to insert plan_id=%s event_type=%s: %s", plan_id, event_type, exc, diff --git a/src/ad_buyer/storage/base.py b/src/ad_buyer/storage/base.py index 27863f1..004a90a 100644 --- a/src/ad_buyer/storage/base.py +++ b/src/ad_buyer/storage/base.py @@ -139,9 +139,7 @@ async def get_session(self, session_id: str) -> dict | None: """Get a session by ID.""" return await self.get(f"session:{session_id}") - async def set_session( - self, session_id: str, data: dict, ttl: int | None = None - ) -> None: + async def set_session(self, session_id: str, data: dict, ttl: int | None = None) -> None: """Store a session with optional TTL.""" await self.set(f"session:{session_id}", data, ttl=ttl) @@ -184,7 +182,10 @@ async def list_conversions(self, filters: dict | None = None) -> list[dict]: if filters: if "deal_id" in filters and conversion.get("deal_id") != filters["deal_id"]: continue - if "campaign_id" in filters and conversion.get("campaign_id") != filters["campaign_id"]: # noqa: E501 + if ( + "campaign_id" in filters + and conversion.get("campaign_id") != filters["campaign_id"] + ): # noqa: E501 continue conversions.append(conversion) return conversions @@ -210,7 +211,10 @@ async def list_optimization_decisions(self, filters: dict | None = None) -> list if decision is None: continue if filters: - if "campaign_id" in filters and decision.get("campaign_id") != filters["campaign_id"]: # noqa: E501 + if ( + "campaign_id" in filters + and decision.get("campaign_id") != filters["campaign_id"] + ): # noqa: E501 continue decisions.append(decision) return decisions @@ -238,7 +242,10 @@ async def list_experiments(self, filters: dict | None = None) -> list[dict]: if experiment is None: continue if filters: - if "campaign_id" in filters and experiment.get("campaign_id") != filters["campaign_id"]: # noqa: E501 + if ( + "campaign_id" in filters + and experiment.get("campaign_id") != filters["campaign_id"] + ): # noqa: E501 continue if "status" in filters and experiment.get("status") != filters["status"]: continue diff --git a/src/ad_buyer/storage/deal_store.py b/src/ad_buyer/storage/deal_store.py index 272f0f5..a8e48e9 100644 --- a/src/ad_buyer/storage/deal_store.py +++ b/src/ad_buyer/storage/deal_store.py @@ -1545,7 +1545,24 @@ def delete_creative_asset(self, asset_id: str) -> bool: # Deal Templates (v5, Strategic Plan Section 6.3) # ------------------------------------------------------------------ - def save_deal_template(self, *, template_id: str | None = None, name: str, deal_type_pref: str | None = None, inventory_types: str | None = None, preferred_publishers: str | None = None, excluded_publishers: str | None = None, targeting_defaults: str | None = None, default_price: float | None = None, max_cpm: float | None = None, min_impressions: int | None = None, default_flight_days: int | None = None, supply_path_prefs: str | None = None, advertiser_id: str | None = None, agency_id: str | None = None) -> str: # noqa: E501 + def save_deal_template( + self, + *, + template_id: str | None = None, + name: str, + deal_type_pref: str | None = None, + inventory_types: str | None = None, + preferred_publishers: str | None = None, + excluded_publishers: str | None = None, + targeting_defaults: str | None = None, + default_price: float | None = None, + max_cpm: float | None = None, + min_impressions: int | None = None, + default_flight_days: int | None = None, + supply_path_prefs: str | None = None, + advertiser_id: str | None = None, + agency_id: str | None = None, + ) -> str: # noqa: E501 """Insert a new deal template. Returns the template ID.""" if template_id is None: template_id = str(uuid.uuid4()) @@ -1560,7 +1577,24 @@ def save_deal_template(self, *, template_id: str | None = None, name: str, deal_ supply_path_prefs, advertiser_id, agency_id, created_at, updated_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - (template_id, name, deal_type_pref, inventory_types, preferred_publishers, excluded_publishers, targeting_defaults, default_price, max_cpm, min_impressions, default_flight_days, supply_path_prefs, advertiser_id, agency_id, now, now), # noqa: E501 + ( + template_id, + name, + deal_type_pref, + inventory_types, + preferred_publishers, + excluded_publishers, + targeting_defaults, + default_price, + max_cpm, + min_impressions, + default_flight_days, + supply_path_prefs, + advertiser_id, + agency_id, + now, + now, + ), # noqa: E501 ) self._conn.commit() logger.info("Saved deal template %s: %s", template_id, name) @@ -1573,7 +1607,13 @@ def get_deal_template(self, template_id: str) -> dict[str, Any] | None: row = cursor.fetchone() return dict(row) if row else None - def list_deal_templates(self, *, advertiser_id: str | None = None, deal_type_pref: str | None = None, limit: int = 100) -> list[dict[str, Any]]: # noqa: E501 + def list_deal_templates( + self, + *, + advertiser_id: str | None = None, + deal_type_pref: str | None = None, + limit: int = 100, + ) -> list[dict[str, Any]]: # noqa: E501 """List deal templates with optional filters.""" conditions: list[str] = [] params: list[Any] = [] @@ -1595,7 +1635,21 @@ def update_deal_template(self, template_id: str, **kwargs: Any) -> bool: """Update fields on an existing deal template.""" if not kwargs: return False - allowed = {"name", "deal_type_pref", "inventory_types", "preferred_publishers", "excluded_publishers", "targeting_defaults", "default_price", "max_cpm", "min_impressions", "default_flight_days", "supply_path_prefs", "advertiser_id", "agency_id"} # noqa: E501 + allowed = { + "name", + "deal_type_pref", + "inventory_types", + "preferred_publishers", + "excluded_publishers", + "targeting_defaults", + "default_price", + "max_cpm", + "min_impressions", + "default_flight_days", + "supply_path_prefs", + "advertiser_id", + "agency_id", + } # noqa: E501 updates = {k: v for k, v in kwargs.items() if k in allowed} if not updates: return False @@ -1603,7 +1657,9 @@ def update_deal_template(self, template_id: str, **kwargs: Any) -> bool: set_clause = ", ".join(f"{col} = ?" for col in updates) values = list(updates.values()) + [template_id] with self._lock: - cursor = self._conn.execute(f"UPDATE deal_templates SET {set_clause} WHERE id = ?", values) # noqa: E501 + cursor = self._conn.execute( + f"UPDATE deal_templates SET {set_clause} WHERE id = ?", values + ) # noqa: E501 self._conn.commit() return cursor.rowcount > 0 @@ -1618,7 +1674,19 @@ def delete_deal_template(self, template_id: str) -> bool: # Supply Path Templates (v5, Strategic Plan Section 6.4) # ------------------------------------------------------------------ - def save_supply_path_template(self, *, template_id: str | None = None, name: str, scoring_weights: str | None = None, max_reseller_hops: int | None = None, require_sellers_json: int | None = None, preferred_ssps: str | None = None, blocked_ssps: str | None = None, preferred_curators: str | None = None, rules: str | None = None) -> str: # noqa: E501 + def save_supply_path_template( + self, + *, + template_id: str | None = None, + name: str, + scoring_weights: str | None = None, + max_reseller_hops: int | None = None, + require_sellers_json: int | None = None, + preferred_ssps: str | None = None, + blocked_ssps: str | None = None, + preferred_curators: str | None = None, + rules: str | None = None, + ) -> str: # noqa: E501 """Insert a new supply path template. Returns the template ID.""" if template_id is None: template_id = str(uuid.uuid4()) @@ -1630,7 +1698,19 @@ def save_supply_path_template(self, *, template_id: str | None = None, name: str require_sellers_json, preferred_ssps, blocked_ssps, preferred_curators, rules, created_at, updated_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - (template_id, name, scoring_weights, max_reseller_hops, require_sellers_json, preferred_ssps, blocked_ssps, preferred_curators, rules, now, now), # noqa: E501 + ( + template_id, + name, + scoring_weights, + max_reseller_hops, + require_sellers_json, + preferred_ssps, + blocked_ssps, + preferred_curators, + rules, + now, + now, + ), # noqa: E501 ) self._conn.commit() logger.info("Saved supply path template %s: %s", template_id, name) @@ -1639,14 +1719,18 @@ def save_supply_path_template(self, *, template_id: str | None = None, name: str def get_supply_path_template(self, template_id: str) -> dict[str, Any] | None: """Retrieve a supply path template by ID.""" with self._lock: - cursor = self._conn.execute("SELECT * FROM supply_path_templates WHERE id = ?", (template_id,)) # noqa: E501 + cursor = self._conn.execute( + "SELECT * FROM supply_path_templates WHERE id = ?", (template_id,) + ) # noqa: E501 row = cursor.fetchone() return dict(row) if row else None def list_supply_path_templates(self, *, limit: int = 100) -> list[dict[str, Any]]: """List supply path templates.""" with self._lock: - cursor = self._conn.execute("SELECT * FROM supply_path_templates ORDER BY created_at DESC LIMIT ?", (limit,)) # noqa: E501 + cursor = self._conn.execute( + "SELECT * FROM supply_path_templates ORDER BY created_at DESC LIMIT ?", (limit,) + ) # noqa: E501 rows = cursor.fetchall() return [dict(row) for row in rows] @@ -1654,7 +1738,16 @@ def update_supply_path_template(self, template_id: str, **kwargs: Any) -> bool: """Update fields on an existing supply path template.""" if not kwargs: return False - allowed = {"name", "scoring_weights", "max_reseller_hops", "require_sellers_json", "preferred_ssps", "blocked_ssps", "preferred_curators", "rules"} # noqa: E501 + allowed = { + "name", + "scoring_weights", + "max_reseller_hops", + "require_sellers_json", + "preferred_ssps", + "blocked_ssps", + "preferred_curators", + "rules", + } # noqa: E501 updates = {k: v for k, v in kwargs.items() if k in allowed} if not updates: return False @@ -1662,14 +1755,18 @@ def update_supply_path_template(self, template_id: str, **kwargs: Any) -> bool: set_clause = ", ".join(f"{col} = ?" for col in updates) values = list(updates.values()) + [template_id] with self._lock: - cursor = self._conn.execute(f"UPDATE supply_path_templates SET {set_clause} WHERE id = ?", values) # noqa: E501 + cursor = self._conn.execute( + f"UPDATE supply_path_templates SET {set_clause} WHERE id = ?", values + ) # noqa: E501 self._conn.commit() return cursor.rowcount > 0 def delete_supply_path_template(self, template_id: str) -> bool: """Delete a supply path template by ID.""" with self._lock: - cursor = self._conn.execute("DELETE FROM supply_path_templates WHERE id = ?", (template_id,)) # noqa: E501 + cursor = self._conn.execute( + "DELETE FROM supply_path_templates WHERE id = ?", (template_id,) + ) # noqa: E501 self._conn.commit() return cursor.rowcount > 0 diff --git a/src/ad_buyer/storage/factory.py b/src/ad_buyer/storage/factory.py index 5ff1493..7911278 100644 --- a/src/ad_buyer/storage/factory.py +++ b/src/ad_buyer/storage/factory.py @@ -7,7 +7,6 @@ Supports SQLite (default), Redis, and Hybrid backends. """ - from ad_buyer.storage.base import StorageBackend from ad_buyer.storage.sqlite_backend import SQLiteBackend diff --git a/src/ad_buyer/storage/order_store.py b/src/ad_buyer/storage/order_store.py index cb52de5..93e5045 100644 --- a/src/ad_buyer/storage/order_store.py +++ b/src/ad_buyer/storage/order_store.py @@ -71,9 +71,9 @@ def __init__(self, database_url: str) -> None: def _parse_url(url: str) -> str: """Extract the file path from a sqlite:/// URL.""" if url.startswith("sqlite:///"): - return url[len("sqlite:///"):] + return url[len("sqlite:///") :] if url.startswith("sqlite://"): - path = url[len("sqlite://"):] + path = url[len("sqlite://") :] return path if path else ":memory:" return url @@ -149,9 +149,7 @@ def get_order(self, order_id: str) -> dict[str, Any] | None: """ key = self._make_key(order_id) with self._lock: - cursor = self._conn.execute( - "SELECT data FROM orders WHERE key = ?", (key,) - ) + cursor = self._conn.execute("SELECT data FROM orders WHERE key = ?", (key,)) row = cursor.fetchone() if row is None: return None diff --git a/src/ad_buyer/storage/redis_backend.py b/src/ad_buyer/storage/redis_backend.py index 8f1a565..05360e7 100644 --- a/src/ad_buyer/storage/redis_backend.py +++ b/src/ad_buyer/storage/redis_backend.py @@ -55,7 +55,7 @@ def _prefixed_key(self, key: str) -> str: def _unprefixed_key(self, key: str) -> str: """Remove prefix from key.""" if key.startswith(self.key_prefix): - return key[len(self.key_prefix):] + return key[len(self.key_prefix) :] return key async def connect(self) -> None: diff --git a/src/ad_buyer/storage/sqlite_backend.py b/src/ad_buyer/storage/sqlite_backend.py index 1fd746b..1682db4 100644 --- a/src/ad_buyer/storage/sqlite_backend.py +++ b/src/ad_buyer/storage/sqlite_backend.py @@ -32,9 +32,9 @@ def __init__(self, database_url: str): database_url: SQLite connection string (e.g., sqlite:///./ad_buyer.db) """ if database_url.startswith("sqlite:///"): - self.db_path = database_url[len("sqlite:///"):] + self.db_path = database_url[len("sqlite:///") :] elif database_url.startswith("sqlite://"): - self.db_path = database_url[len("sqlite://"):] + self.db_path = database_url[len("sqlite://") :] else: self.db_path = database_url @@ -123,9 +123,7 @@ async def delete(self, key: str) -> bool: if not self._connection: raise RuntimeError("Storage not connected. Call connect() first.") - async with self._connection.execute( - "DELETE FROM kv_store WHERE key = ?", (key,) - ) as cursor: + async with self._connection.execute("DELETE FROM kv_store WHERE key = ?", (key,)) as cursor: await self._connection.commit() return cursor.rowcount > 0 diff --git a/src/ad_buyer/sync/order_sync.py b/src/ad_buyer/sync/order_sync.py index e73fe88..e6dc21f 100644 --- a/src/ad_buyer/sync/order_sync.py +++ b/src/ad_buyer/sync/order_sync.py @@ -55,9 +55,7 @@ async def sync_order(self, order_id: str) -> bool: """ seller_data = await self._client.get_order_status(order_id) if seller_data is None: - logger.info( - "Sync skipped for order %s: seller returned None", order_id - ) + logger.info("Sync skipped for order %s: seller returned None", order_id) return False # Update local store with seller's data diff --git a/src/ad_buyer/tools/audience/audience_discovery.py b/src/ad_buyer/tools/audience/audience_discovery.py index 8d35b97..a0adf48 100644 --- a/src/ad_buyer/tools/audience/audience_discovery.py +++ b/src/ad_buyer/tools/audience/audience_discovery.py @@ -3,7 +3,6 @@ """Audience Discovery Tool - Discover available audience signals from sellers.""" - import httpx from crewai.tools import BaseTool from pydantic import BaseModel, Field @@ -16,9 +15,7 @@ class AudienceDiscoveryInput(BaseModel): """Input schema for audience discovery tool.""" - seller_endpoint: str = Field( - description="Seller's capability discovery endpoint URL" - ) + seller_endpoint: str = Field(description="Seller's capability discovery endpoint URL") signal_types: list[str] | None = Field( default=None, description="Filter by signal types: identity, contextual, reinforcement", @@ -52,9 +49,7 @@ def _run( min_coverage: float | None = None, ) -> str: """Execute the audience discovery.""" - return run_async( - self._arun(seller_endpoint, signal_types, min_coverage) - ) + return run_async(self._arun(seller_endpoint, signal_types, min_coverage)) async def _arun( self, @@ -86,17 +81,11 @@ async def _arun( pass if valid_types: - capabilities = [ - cap for cap in capabilities - if cap.signal_type in valid_types - ] + capabilities = [cap for cap in capabilities if cap.signal_type in valid_types] # Filter by minimum coverage if min_coverage is not None: - capabilities = [ - cap for cap in capabilities - if cap.coverage_percentage >= min_coverage - ] + capabilities = [cap for cap in capabilities if cap.coverage_percentage >= min_coverage] return self._format_results(capabilities) diff --git a/src/ad_buyer/tools/audience/audience_matching.py b/src/ad_buyer/tools/audience/audience_matching.py index 58d5c89..4facd1f 100644 --- a/src/ad_buyer/tools/audience/audience_matching.py +++ b/src/ad_buyer/tools/audience/audience_matching.py @@ -17,9 +17,7 @@ class AudienceMatchingInput(BaseModel): """Input schema for audience matching tool.""" - seller_endpoint: str = Field( - description="Seller's UCP exchange endpoint URL" - ) + seller_endpoint: str = Field(description="Seller's UCP exchange endpoint URL") demographics: dict[str, Any] | None = Field( default=None, description="Demographic targeting (age, gender, income, etc.)", @@ -166,10 +164,12 @@ def _get_mock_validation(self, requirements: dict[str, Any]) -> Any: alternatives = [] if has_behaviors: gaps.append("behavioral_targeting") - alternatives.append({ - "gap": "behavioral_targeting", - "suggestion": "Use contextual signals with frequency capping as proxy", - }) + alternatives.append( + { + "gap": "behavioral_targeting", + "suggestion": "Use contextual signals with frequency capping as proxy", + } + ) return AudienceValidationResult( validation_status=status, @@ -177,7 +177,8 @@ def _get_mock_validation(self, requirements: dict[str, Any]) -> Any: matched_capabilities=[ "cap_ctx_categories", "cap_ctx_keywords", - ] + (["cap_demo_age", "cap_demo_gender"] if has_demographics else []), + ] + + (["cap_demo_age", "cap_demo_gender"] if has_demographics else []), gaps=gaps, alternatives=alternatives, ucp_similarity_score=score, diff --git a/src/ad_buyer/tools/audience/embedding_mint.py b/src/ad_buyer/tools/audience/embedding_mint.py index 6365243..aab3ac0 100644 --- a/src/ad_buyer/tools/audience/embedding_mint.py +++ b/src/ad_buyer/tools/audience/embedding_mint.py @@ -69,9 +69,7 @@ def embedding_mode_label() -> str: # (keeps test fixtures that patch settings simple). from ...config.settings import settings - return _EMBEDDING_MODE_LABELS.get( - settings.embedding_mode, EMBEDDING_MODE_LABEL_MOCK - ) + return _EMBEDDING_MODE_LABELS.get(settings.embedding_mode, EMBEDDING_MODE_LABEL_MOCK) class EmbeddingMintInput(BaseModel): @@ -92,16 +90,12 @@ class EmbeddingMintInput(BaseModel): ) jurisdiction: str = Field( default="GLOBAL", - description=( - "Jurisdiction code for the consent context " - "(e.g. 'US', 'EU', 'GLOBAL')." - ), + description=("Jurisdiction code for the consent context (e.g. 'US', 'EU', 'GLOBAL')."), ) consent_framework: str = Field( default="advertiser-1p", description=( - "Consent framework backing the mint: 'IAB-TCFv2', 'GPP', " - "'advertiser-1p', or 'none'." + "Consent framework backing the mint: 'IAB-TCFv2', 'GPP', 'advertiser-1p', or 'none'." ), ) diff --git a/src/ad_buyer/tools/audience/taxonomy_lookup.py b/src/ad_buyer/tools/audience/taxonomy_lookup.py index 5faef86..ffaa81c 100644 --- a/src/ad_buyer/tools/audience/taxonomy_lookup.py +++ b/src/ad_buyer/tools/audience/taxonomy_lookup.py @@ -8,7 +8,6 @@ step 1) to map raw `target_audience` strings into typed `AudienceRef`s. """ - from crewai.tools import BaseTool from pydantic import BaseModel, Field diff --git a/src/ad_buyer/tools/buyer_deals/discover_inventory.py b/src/ad_buyer/tools/buyer_deals/discover_inventory.py index da71c47..21264d1 100644 --- a/src/ad_buyer/tools/buyer_deals/discover_inventory.py +++ b/src/ad_buyer/tools/buyer_deals/discover_inventory.py @@ -181,9 +181,7 @@ async def _arun( approvals = await self._fetch_approvals(result.data) filtered, filter_summary = self._apply_enforcement(result.data, approvals) - return self._format_results( - filtered, identity_context, approvals, filter_summary - ) + return self._format_results(filtered, identity_context, approvals, filter_summary) except SGPClientError as e: # Reached only when enforcement is on; _fetch_approvals swallows @@ -223,9 +221,7 @@ def _approval_line( return f" SGP Approval: ✓ APPROVED — {normalized}" return f" SGP Approval: ✗ NOT APPROVED — {normalized}" - async def _fetch_approvals( - self, products: Any - ) -> dict[str, ApprovalRecord | None]: + async def _fetch_approvals(self, products: Any) -> dict[str, ApprovalRecord | None]: """Batch-check SGP approvals for the distinct seller domains in the result. Returns a dict keyed by normalized domain. Empty dict when no @@ -253,8 +249,7 @@ async def _fetch_approvals( if self._sgp_enforce: raise logger.warning( - "SGP approval lookup failed during discovery; " - "continuing without annotations", + "SGP approval lookup failed during discovery; continuing without annotations", exc_info=True, ) return {} diff --git a/src/ad_buyer/tools/buyer_deals/request_deal.py b/src/ad_buyer/tools/buyer_deals/request_deal.py index 64d6987..15dd9e1 100644 --- a/src/ad_buyer/tools/buyer_deals/request_deal.py +++ b/src/ad_buyer/tools/buyer_deals/request_deal.py @@ -253,9 +253,7 @@ async def _arun( except (OSError, ValueError, RuntimeError) as e: return f"Error requesting deal: {e}" - async def _check_sgp_approval( - self, product: dict - ) -> tuple[str | None, str | None]: + async def _check_sgp_approval(self, product: dict) -> tuple[str | None, str | None]: """Gate a deal request against IAB Diligence Platform approval. Returns ``(error_message, banner)``: @@ -284,7 +282,8 @@ async def _check_sgp_approval( approvals = await self._sgp_client.check_approvals([raw_domain]) except SGPClientError as exc: logger.warning( - "IAB Diligence Platform lookup failed for %s during deal request", domain, + "IAB Diligence Platform lookup failed for %s during deal request", + domain, exc_info=True, ) # Fail closed — enforcement is on, so we must not issue a Deal ID @@ -472,9 +471,7 @@ def _format_deal_response( # the human reviewer (and audit trail) a stable handle linking # buyer state to seller-side records (proposal §5.1 step 2). if audience_plan is not None: - output_lines.append( - f"Audience Plan ID: {audience_plan.audience_plan_id}" - ) + output_lines.append(f"Audience Plan ID: {audience_plan.audience_plan_id}") output_lines.extend( [ diff --git a/src/ad_buyer/tools/deal_library/connectors/index_exchange.py b/src/ad_buyer/tools/deal_library/connectors/index_exchange.py index 3e0b692..fd30721 100644 --- a/src/ad_buyer/tools/deal_library/connectors/index_exchange.py +++ b/src/ad_buyer/tools/deal_library/connectors/index_exchange.py @@ -61,7 +61,7 @@ _DEAL_TYPE_MAP: dict[str, str] = { "PG": "PG", "PD": "PD", - "PMP": "PA", # Private Marketplace → Private Auction + "PMP": "PA", # Private Marketplace → Private Auction "PA": "PA", # Lowercase aliases (defensive; IX docs show uppercase) "pg": "PG", @@ -436,9 +436,7 @@ def _fetch_page( try: response = self._client.get(url, params=params, headers=headers) except httpx.TransportError as exc: - raise SSPConnectionError( - f"Index Exchange API network error: {exc}" - ) from exc + raise SSPConnectionError(f"Index Exchange API network error: {exc}") from exc if response.status_code in (401, 403): raise SSPAuthError( @@ -462,8 +460,7 @@ def _fetch_page( if response.status_code >= 500: raise SSPConnectionError( - f"Index Exchange API server error (HTTP {response.status_code}): " - f"{response.text}", + f"Index Exchange API server error (HTTP {response.status_code}): {response.text}", status_code=response.status_code, ) diff --git a/src/ad_buyer/tools/deal_library/connectors/magnite.py b/src/ad_buyer/tools/deal_library/connectors/magnite.py index 8c769e8..93b8925 100644 --- a/src/ad_buyer/tools/deal_library/connectors/magnite.py +++ b/src/ad_buyer/tools/deal_library/connectors/magnite.py @@ -145,11 +145,7 @@ class MagniteConnector(SSPConnector): def __init__(self, *, platform: str | None = None) -> None: # Resolve platform: constructor arg > env var > default - resolved_platform = ( - platform - or os.environ.get("MAGNITE_PLATFORM", "") - or PLATFORM_STREAMING - ) + resolved_platform = platform or os.environ.get("MAGNITE_PLATFORM", "") or PLATFORM_STREAMING if resolved_platform not in _VALID_PLATFORMS: raise ValueError( @@ -256,7 +252,9 @@ def _login(self, client: httpx.Client) -> str: # The session cookie is set by the server; httpx stores it in the # client's cookie jar automatically. Return it for logging/debug. session_cookie = response.cookies.get("SESSION", "") - logger.debug("Magnite: authentication successful (session cookie present: %s)", bool(session_cookie)) # noqa: E501 + logger.debug( + "Magnite: authentication successful (session cookie present: %s)", bool(session_cookie) + ) # noqa: E501 return session_cookie # ------------------------------------------------------------------ diff --git a/src/ad_buyer/tools/deal_library/connectors/pubmatic.py b/src/ad_buyer/tools/deal_library/connectors/pubmatic.py index cc91efb..2de8497 100644 --- a/src/ad_buyer/tools/deal_library/connectors/pubmatic.py +++ b/src/ad_buyer/tools/deal_library/connectors/pubmatic.py @@ -369,9 +369,7 @@ def _fetch_page( try: response = self._client.get(url, params=params, headers=headers) except httpx.TransportError as exc: - raise SSPConnectionError( - f"PubMatic API network error: {exc}" - ) from exc + raise SSPConnectionError(f"PubMatic API network error: {exc}") from exc if response.status_code in (401, 403): raise SSPAuthError( @@ -395,8 +393,7 @@ def _fetch_page( if response.status_code >= 500: raise SSPConnectionError( - f"PubMatic API server error (HTTP {response.status_code}): " - f"{response.text}", + f"PubMatic API server error (HTTP {response.status_code}): {response.text}", status_code=response.status_code, ) diff --git a/src/ad_buyer/tools/deal_library/templates.py b/src/ad_buyer/tools/deal_library/templates.py index a593364..94354b1 100644 --- a/src/ad_buyer/tools/deal_library/templates.py +++ b/src/ad_buyer/tools/deal_library/templates.py @@ -48,8 +48,7 @@ class ManageDealTemplateInput(BaseModel): action: str = Field( ..., description=( - "The CRUD action to perform: 'create', 'read', 'list', " - "'update', or 'delete'." + "The CRUD action to perform: 'create', 'read', 'list', 'update', or 'delete'." ), ) params_json: str = Field( @@ -74,8 +73,7 @@ class ManageSupplyPathTemplateInput(BaseModel): action: str = Field( ..., description=( - "The CRUD action to perform: 'create', 'read', 'list', " - "'update', or 'delete'." + "The CRUD action to perform: 'create', 'read', 'list', 'update', or 'delete'." ), ) params_json: str = Field( @@ -261,9 +259,7 @@ def _format_supply_path_template(tmpl: dict[str, Any]) -> str: lines.append(f" Max Reseller Hops: {tmpl['max_reseller_hops']}") if tmpl.get("require_sellers_json") is not None: - lines.append( - f" Require sellers.json: {'Yes' if tmpl['require_sellers_json'] else 'No'}" - ) + lines.append(f" Require sellers.json: {'Yes' if tmpl['require_sellers_json'] else 'No'}") # JSON array fields for field_name, label in [ @@ -366,33 +362,21 @@ def _create(self, params: dict[str, Any]) -> str: name=name, deal_type_pref=params.get("deal_type_pref"), inventory_types=_serialize_list_field(params.get("inventory_types")), - preferred_publishers=_serialize_list_field( - params.get("preferred_publishers") - ), - excluded_publishers=_serialize_list_field( - params.get("excluded_publishers") - ), - targeting_defaults=_serialize_dict_field( - params.get("targeting_defaults") - ), + preferred_publishers=_serialize_list_field(params.get("preferred_publishers")), + excluded_publishers=_serialize_list_field(params.get("excluded_publishers")), + targeting_defaults=_serialize_dict_field(params.get("targeting_defaults")), default_price=params.get("default_price"), max_cpm=params.get("max_cpm"), min_impressions=params.get("min_impressions"), default_flight_days=params.get("default_flight_days"), - supply_path_prefs=_serialize_dict_field( - params.get("supply_path_prefs") - ), + supply_path_prefs=_serialize_dict_field(params.get("supply_path_prefs")), advertiser_id=params.get("advertiser_id"), agency_id=params.get("agency_id"), ) except Exception as exc: return f"Error creating deal template: {exc}" - return ( - f"Deal template created successfully.\n" - f" ID: {template_id}\n" - f" Name: {name}" - ) + return f"Deal template created successfully.\n ID: {template_id}\n Name: {name}" def _read(self, params: dict[str, Any]) -> str: """Read a deal template by ID.""" @@ -552,33 +536,22 @@ def _create(self, params: dict[str, Any]) -> str: name=name, scoring_weights=_serialize_dict_field(scoring_weights), max_reseller_hops=params.get("max_reseller_hops"), - require_sellers_json=( - 1 if params.get("require_sellers_json") else None - ), + require_sellers_json=(1 if params.get("require_sellers_json") else None), preferred_ssps=_serialize_list_field(params.get("preferred_ssps")), blocked_ssps=_serialize_list_field(params.get("blocked_ssps")), - preferred_curators=_serialize_list_field( - params.get("preferred_curators") - ), + preferred_curators=_serialize_list_field(params.get("preferred_curators")), rules=_serialize_list_field(params.get("rules")), ) except Exception as exc: return f"Error creating supply path template: {exc}" - return ( - f"Supply path template created successfully.\n" - f" ID: {template_id}\n" - f" Name: {name}" - ) + return f"Supply path template created successfully.\n ID: {template_id}\n Name: {name}" def _read(self, params: dict[str, Any]) -> str: """Read a supply path template by ID.""" template_id = params.get("template_id") if not template_id: - return ( - "Error: 'template_id' is required for reading a " - "supply path template." - ) + return "Error: 'template_id' is required for reading a supply path template." tmpl = self.deal_store.get_supply_path_template(template_id) if tmpl is None: @@ -608,9 +581,7 @@ def _list(self, params: dict[str, Any]) -> str: if weights_raw: try: weights = ( - json.loads(weights_raw) - if isinstance(weights_raw, str) - else weights_raw + json.loads(weights_raw) if isinstance(weights_raw, str) else weights_raw ) weight_parts = [ f"{k}={weights.get(k, 0):.1f}" @@ -628,10 +599,7 @@ def _update(self, params: dict[str, Any]) -> str: """Update a supply path template.""" template_id = params.pop("template_id", None) if not template_id: - return ( - "Error: 'template_id' is required for updating a " - "supply path template." - ) + return "Error: 'template_id' is required for updating a supply path template." # Validate scoring weights if being updated scoring_weights = params.get("scoring_weights") @@ -657,9 +625,7 @@ def _update(self, params: dict[str, Any]) -> str: else: update_kwargs[key] = val - result = self.deal_store.update_supply_path_template( - template_id, **update_kwargs - ) + result = self.deal_store.update_supply_path_template(template_id, **update_kwargs) if not result: return f"Supply path template not found: {template_id}" @@ -669,10 +635,7 @@ def _delete(self, params: dict[str, Any]) -> str: """Delete a supply path template.""" template_id = params.get("template_id") if not template_id: - return ( - "Error: 'template_id' is required for deleting a " - "supply path template." - ) + return "Error: 'template_id' is required for deleting a supply path template." result = self.deal_store.delete_supply_path_template(template_id) if not result: diff --git a/tests/integration/test_path_a_audience_e2e.py b/tests/integration/test_path_a_audience_e2e.py index 2d05310..5e75b67 100644 --- a/tests/integration/test_path_a_audience_e2e.py +++ b/tests/integration/test_path_a_audience_e2e.py @@ -105,9 +105,7 @@ def _three_type_plan_dict() -> dict[str, Any]: "extensions": [ { "type": "agentic", - "identifier": ( - "emb://buyer.example.com/audiences/auto-converters-q1" - ), + "identifier": ("emb://buyer.example.com/audiences/auto-converters-q1"), "taxonomy": "agentic-audiences", "version": "draft-2026-01", "source": "explicit", @@ -150,9 +148,7 @@ def _base_brief_dict(**overrides: Any) -> dict[str, Any]: def _three_type_brief() -> CampaignBrief: """Brief carrying an explicit 3-type AudiencePlan.""" - return parse_campaign_brief( - _base_brief_dict(target_audience=_three_type_plan_dict()) - ) + return parse_campaign_brief(_base_brief_dict(target_audience=_three_type_plan_dict())) # --------------------------------------------------------------------------- @@ -315,9 +311,7 @@ def test_happy_path_three_types_through_path_a( campaign_id = loop.run_until_complete( pipeline.ingest_brief(brief.model_dump(mode="json")) ) - campaign_plan = loop.run_until_complete( - pipeline.plan_campaign(campaign_id) - ) + campaign_plan = loop.run_until_complete(pipeline.plan_campaign(campaign_id)) loop.run_until_complete(pipeline.execute_booking(campaign_id)) finally: loop.close() @@ -361,29 +355,18 @@ def test_happy_path_three_types_through_path_a( # MUST survive plan -> seller (no in-flight mutation). This # is the §5.1 step-2 wire-format guarantee for the buyer # side of Path A. - assert ( - inv_req.audience_plan.audience_plan_id == post_planner_plan_id - ) - assert ( - deal_params.audience_plan.audience_plan_id - == post_planner_plan_id - ) + assert inv_req.audience_plan.audience_plan_id == post_planner_plan_id + assert deal_params.audience_plan.audience_plan_id == post_planner_plan_id # All three types still present at the seller boundary. assert inv_req.audience_plan.primary.type == "standard" assert inv_req.audience_plan.primary.identifier == "3-7" - assert any( - c.type == "contextual" for c in inv_req.audience_plan.constraints - ) - assert any( - e.type == "agentic" for e in inv_req.audience_plan.extensions - ) + assert any(c.type == "contextual" for c in inv_req.audience_plan.constraints) + assert any(e.type == "agentic" for e in inv_req.audience_plan.extensions) # Compliance context survives for the agentic extension -- # required by §5.2's consent-regime guarantee. - agentic = next( - e for e in inv_req.audience_plan.extensions if e.type == "agentic" - ) + agentic = next(e for e in inv_req.audience_plan.extensions if e.type == "agentic") assert isinstance(agentic.compliance_context, ComplianceContext) assert agentic.compliance_context.jurisdiction == "US" assert agentic.compliance_context.consent_framework == "IAB-TCFv2" @@ -488,16 +471,10 @@ def __init__(self, caps_by_url: dict[str, SellerAudienceCapabilities]): self._caps_by_url = caps_by_url self.calls: list[str] = [] - async def discover_capabilities( - self, seller_endpoint: str - ) -> CapabilityDiscoveryResult: + async def discover_capabilities(self, seller_endpoint: str) -> CapabilityDiscoveryResult: self.calls.append(seller_endpoint) - caps = self._caps_by_url.get( - seller_endpoint, SellerAudienceCapabilities.legacy_default() - ) - return CapabilityDiscoveryResult( - capabilities=caps, cache_status="miss", fetched_at=0.0 - ) + caps = self._caps_by_url.get(seller_endpoint, SellerAudienceCapabilities.legacy_default()) + return CapabilityDiscoveryResult(capabilities=caps, cache_status="miss", fetched_at=0.0) @pytest.fixture @@ -661,18 +638,12 @@ def test_capability_degradation_legacy_seller( # -- the pre-degradation id, by design (so a reviewer can correlate # the original plan with everything that happened to it downstream). events = audience_audit_log.get_events(original_plan_id) - assert events, ( - f"Expected audit events for original plan_id={original_plan_id!r}; " - f"got none" - ) + assert events, f"Expected audit events for original plan_id={original_plan_id!r}; got none" event_types = [e["event_type"] for e in events] assert audience_audit_log.EVENT_DEGRADATION in event_types # Find the degradation event and confirm it carries the seller and # the structured drop log. - deg_events = [ - e for e in events - if e["event_type"] == audience_audit_log.EVENT_DEGRADATION - ] + deg_events = [e for e in events if e["event_type"] == audience_audit_log.EVENT_DEGRADATION] assert len(deg_events) >= 1 deg_payload = deg_events[0]["payload"] assert deg_payload.get("phase") == "preflight" @@ -822,15 +793,9 @@ def test_cross_repo_audience_plan_json_round_trip(self) -> None: break if worktree_name is not None: sibling_worktree = ( - agent_range_root - / "ad_seller_system" - / ".worktrees" - / worktree_name - / "src" - ) - seller_src = str( - sibling_worktree if sibling_worktree.is_dir() else seller_main + agent_range_root / "ad_seller_system" / ".worktrees" / worktree_name / "src" ) + seller_src = str(sibling_worktree if sibling_worktree.is_dir() else seller_main) else: seller_src = str(seller_main) # Even with an ad_buyer_system-named ancestor, the sibling @@ -863,9 +828,7 @@ def _assert_ref_round_trips(ref_dict: dict[str, Any], where: str) -> None: buyer_canon = json.dumps(ref_dict, sort_keys=True) seller_canon = json.dumps(re_serialized, sort_keys=True) assert buyer_canon == seller_canon, ( - f"Schema drift at {where}:\n" - f" buyer: {buyer_canon}\n" - f" seller: {seller_canon}" + f"Schema drift at {where}:\n buyer: {buyer_canon}\n seller: {seller_canon}" ) # 3. Round-trip every ref slot. @@ -879,9 +842,7 @@ def _assert_ref_round_trips(ref_dict: dict[str, Any], where: str) -> None: # 4. Confirm the agentic compliance_context survived the round-trip # (it's the most failure-prone nested field). - agentic_dict = next( - r for r in buyer_dict["extensions"] if r["type"] == "agentic" - ) + agentic_dict = next(r for r in buyer_dict["extensions"] if r["type"] == "agentic") seller_agentic = SellerRef(**agentic_dict) assert seller_agentic.compliance_context is not None assert seller_agentic.compliance_context.jurisdiction == "US" diff --git a/tests/integration/test_path_b_audience_e2e.py b/tests/integration/test_path_b_audience_e2e.py index 319174c..81610c3 100644 --- a/tests/integration/test_path_b_audience_e2e.py +++ b/tests/integration/test_path_b_audience_e2e.py @@ -103,9 +103,7 @@ def _three_type_plan_dict() -> dict[str, Any]: "extensions": [ { "type": "agentic", - "identifier": ( - "emb://buyer.example.com/audiences/auto-converters-q1" - ), + "identifier": ("emb://buyer.example.com/audiences/auto-converters-q1"), "taxonomy": "agentic-audiences", "version": "draft-2026-01", "source": "explicit", @@ -148,9 +146,7 @@ def _base_brief_dict(**overrides: Any) -> dict[str, Any]: def _three_type_brief() -> CampaignBrief: """Brief carrying an explicit 3-type AudiencePlan.""" - return parse_campaign_brief( - _base_brief_dict(target_audience=_three_type_plan_dict()) - ) + return parse_campaign_brief(_base_brief_dict(target_audience=_three_type_plan_dict())) def _legacy_list_brief() -> CampaignBrief: @@ -228,9 +224,7 @@ def channel_brief() -> dict[str, Any]: class TestBuyerDealFlowThreeTypeHappyPath: """3-type plan (Standard + Contextual + Agentic) flows end to end.""" - def test_brief_yields_three_type_plan_on_state( - self, mock_unified_client: MagicMock - ) -> None: + def test_brief_yields_three_type_plan_on_state(self, mock_unified_client: MagicMock) -> None: """Brief -> planner runs -> 3-type plan attached to flow state.""" brief = _three_type_brief() @@ -297,9 +291,7 @@ def test_three_type_plan_threaded_into_dealrequest( # observe the plan that crosses the flow -> tool boundary. flow.state.selected_product_id = "ctv-pkg-pathb" flow._deal_tool = MagicMock() - flow._deal_tool._run = MagicMock( - return_value="DEAL CREATED: deal-pathb-3type-001" - ) + flow._deal_tool._run = MagicMock(return_value="DEAL CREATED: deal-pathb-3type-001") outcome = flow.request_deal_id({"status": "success"}) assert outcome["status"] == "success" @@ -366,9 +358,7 @@ def test_legacy_list_brief_propagates_source_inferred( # source=inferred reached the DealRequest payload. flow.state.selected_product_id = "ctv-pkg-legacy" flow._deal_tool = MagicMock() - flow._deal_tool._run = MagicMock( - return_value="DEAL CREATED: deal-pathb-legacy-001" - ) + flow._deal_tool._run = MagicMock(return_value="DEAL CREATED: deal-pathb-legacy-001") flow.request_deal_id({"status": "success"}) observed = flow._deal_tool._run.call_args.kwargs.get("audience_plan") @@ -376,8 +366,7 @@ def test_legacy_list_brief_propagates_source_inferred( assert observed.primary.source == "inferred" # Extension carries source=inferred too -- whole-plan provenance. assert any( - e.source == "inferred" and e.identifier == "luxury_buyers" - for e in observed.extensions + e.source == "inferred" and e.identifier == "luxury_buyers" for e in observed.extensions ) @@ -389,9 +378,7 @@ def test_legacy_list_brief_propagates_source_inferred( class TestBuyerDealFlowSerializationParity: """AudiencePlan survives JSON serialization at the flow -> seller boundary.""" - def test_dealrequest_roundtrip_preserves_plan_id( - self, mock_unified_client: MagicMock - ) -> None: + def test_dealrequest_roundtrip_preserves_plan_id(self, mock_unified_client: MagicMock) -> None: """Mock the seller, capture the payload, deserialize, compare. This is the §5.1 step-2 wire-format guarantee: the buyer's @@ -440,9 +427,7 @@ def test_dealrequest_roundtrip_preserves_plan_id( assert any(c.type == "contextual" for c in rebuilt.audience_plan.constraints) assert any(e.type == "agentic" for e in rebuilt.audience_plan.extensions) # Compliance context survives for agentic refs. - agentic = next( - e for e in rebuilt.audience_plan.extensions if e.type == "agentic" - ) + agentic = next(e for e in rebuilt.audience_plan.extensions if e.type == "agentic") assert agentic.compliance_context is not None assert agentic.compliance_context.jurisdiction == "US" @@ -480,9 +465,7 @@ def test_full_flow_to_seller_payload_preserves_plan_id( flow.state.selected_product_id = "ctv-pkg-pathb" flow._deal_tool = MagicMock() - flow._deal_tool._run = MagicMock( - return_value="DEAL CREATED: deal-pathb-roundtrip" - ) + flow._deal_tool._run = MagicMock(return_value="DEAL CREATED: deal-pathb-roundtrip") flow.request_deal_id({"status": "success"}) observed = flow._deal_tool._run.call_args.kwargs.get("audience_plan") @@ -552,13 +535,9 @@ def test_legacy_seller_capability_reachable_no_crash( # The agentic extension is still on the plan -- §12 will # decide whether to drop it. For §20 we just confirm the # extension is there for §12 to act on. - assert any( - e.type == "agentic" for e in flow.state.audience_plan.extensions - ) + assert any(e.type == "agentic" for e in flow.state.audience_plan.extensions) - def test_capability_degradation_seam_observable( - self, mock_unified_client: MagicMock - ) -> None: + def test_capability_degradation_seam_observable(self, mock_unified_client: MagicMock) -> None: """A capability response advertising no agentic is observable. Records the JSON shape §12 will consume: an audience_capabilities @@ -611,12 +590,11 @@ async def _fake_discover(endpoint: str) -> list[Any]: # the structure is JSON-serializable and carries the §5.7 # required fields, so §12's design has a concrete fixture. import json as _json + wire = _json.dumps(legacy_caps_payload) rebuilt = _json.loads(wire) assert rebuilt["audience_capabilities"]["agentic"]["supported"] is False - assert ( - rebuilt["audience_capabilities"]["supports_extensions"] is False - ) + assert rebuilt["audience_capabilities"]["supports_extensions"] is False # =========================================================================== @@ -627,9 +605,7 @@ async def _fake_discover(endpoint: str) -> list[Any]: class TestBuyerDealFlowPreSetPlanPrecedence: """Pre-seeded ``state.audience_plan`` must NOT be overwritten by the planner.""" - def test_preset_plan_skips_planner_run( - self, mock_unified_client: MagicMock - ) -> None: + def test_preset_plan_skips_planner_run(self, mock_unified_client: MagicMock) -> None: """When state.audience_plan is already set, the planner does not run. Used when a parent pipeline (e.g. CampaignPipeline / Path A) ran @@ -666,9 +642,7 @@ def test_preset_plan_skips_planner_run( # The planner did NOT run -- no cached planner result. assert flow.get_audience_planner_result() is None - def test_preset_plan_threaded_to_seller_payload( - self, mock_unified_client: MagicMock - ) -> None: + def test_preset_plan_threaded_to_seller_payload(self, mock_unified_client: MagicMock) -> None: """Pre-seeded plan must reach the seller-bound DealRequest unchanged. Closes the loop for parent-pipeline integrations: not only does @@ -701,9 +675,7 @@ def test_preset_plan_threaded_to_seller_payload( flow.state.selected_product_id = "ctv-pkg-preset" flow._deal_tool = MagicMock() - flow._deal_tool._run = MagicMock( - return_value="DEAL CREATED: deal-pathb-preset-001" - ) + flow._deal_tool._run = MagicMock(return_value="DEAL CREATED: deal-pathb-preset-001") flow.request_deal_id({"status": "success"}) observed = flow._deal_tool._run.call_args.kwargs.get("audience_plan") diff --git a/tests/integration/test_real_model_path_e2e.py b/tests/integration/test_real_model_path_e2e.py index 95a0ae3..ed8727f 100644 --- a/tests/integration/test_real_model_path_e2e.py +++ b/tests/integration/test_real_model_path_e2e.py @@ -54,9 +54,7 @@ class TestRealModelPath: - @pytest.mark.skipif( - not SBERT_AVAILABLE, reason="sentence-transformers not installed" - ) + @pytest.mark.skipif(not SBERT_AVAILABLE, reason="sentence-transformers not installed") def test_local_model_produces_384_dim_or_falls_back(self): with patch.object(settings, "embedding_mode", "local"): client = UCPClient() @@ -172,7 +170,4 @@ def test_full_plan_with_local_path_serializes(self): plan_json = plan.model_dump_json() reconstructed = AudiencePlan.model_validate_json(plan_json) assert reconstructed.audience_plan_id == plan.audience_plan_id - assert ( - reconstructed.extensions[0].compliance_context.embedding_provenance - == r.provenance - ) + assert reconstructed.extensions[0].compliance_context.embedding_provenance == r.provenance diff --git a/tests/smoke/test_mcp_e2e.py b/tests/smoke/test_mcp_e2e.py index b8ca004..5771cec 100644 --- a/tests/smoke/test_mcp_e2e.py +++ b/tests/smoke/test_mcp_e2e.py @@ -29,6 +29,7 @@ try: from mcp import ClientSession from mcp.client.sse import sse_client + MCP_AVAILABLE = True except ImportError: MCP_AVAILABLE = False @@ -48,6 +49,7 @@ # Shared helpers # ------------------------------------------------------------------------- + async def _call(session: "ClientSession", name: str, args: dict | None = None): """Call an MCP tool and return (is_error, data) tuple.""" result = await session.call_tool(name, arguments=args or {}) @@ -87,6 +89,7 @@ async def mcp_session(): # 1. Foundation (3 tools) # ------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_health_check(mcp_session): err, data = await _call(mcp_session, "health_check") @@ -118,6 +121,7 @@ async def test_get_config(mcp_session): # 2. Setup Wizard (4 tools) # ------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_run_setup_wizard(mcp_session): err, data = await _call(mcp_session, "run_setup_wizard") @@ -134,8 +138,7 @@ async def test_get_wizard_step(mcp_session): @pytest.mark.asyncio async def test_complete_wizard_step(mcp_session): - err, data = await _call(mcp_session, "complete_wizard_step", - {"step_number": 1, "config": "{}"}) + err, data = await _call(mcp_session, "complete_wizard_step", {"step_number": 1, "config": "{}"}) assert not err, f"complete_wizard_step raised: {data}" # Either succeeds or returns a structured error (step already done) assert "success" in data or "error" in data @@ -152,6 +155,7 @@ async def test_skip_wizard_step(mcp_session): # 3. Campaign Management (4 tools) # ------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_list_campaigns(mcp_session): err, data = await _call(mcp_session, "list_campaigns") @@ -163,16 +167,16 @@ async def test_list_campaigns(mcp_session): @pytest.mark.asyncio async def test_get_campaign_status_not_found(mcp_session): """Verify graceful error for missing campaign.""" - err, data = await _call(mcp_session, "get_campaign_status", - {"campaign_id": "nonexistent-uat-001"}) + err, data = await _call( + mcp_session, "get_campaign_status", {"campaign_id": "nonexistent-uat-001"} + ) assert not err, f"get_campaign_status raised: {data}" assert "error" in data @pytest.mark.asyncio async def test_check_pacing_not_found(mcp_session): - err, data = await _call(mcp_session, "check_pacing", - {"campaign_id": "nonexistent-uat-001"}) + err, data = await _call(mcp_session, "check_pacing", {"campaign_id": "nonexistent-uat-001"}) assert not err, f"check_pacing raised: {data}" assert "error" in data @@ -188,19 +192,26 @@ async def test_review_budgets(mcp_session): # 4. Deal Library (6 tools) + multi-step workflow # ------------------------------------------------------------------------- + @pytest.fixture(scope="module") def created_deal_id(mcp_session): """Create a test deal once for the module and return its ID.""" loop = asyncio.get_event_loop() - err, data = loop.run_until_complete(_call(mcp_session, "create_deal_manual", { - "display_name": "Quinn UAT Smoke Test Deal", - "seller_url": "http://smoke-test-seller.example.com", - "deal_type": "PD", - "price": 12.50, - "currency": "USD", - "media_type": "DIGITAL", - "description": "Created by MCP E2E smoke test — safe to delete", - })) + err, data = loop.run_until_complete( + _call( + mcp_session, + "create_deal_manual", + { + "display_name": "Quinn UAT Smoke Test Deal", + "seller_url": "http://smoke-test-seller.example.com", + "deal_type": "PD", + "price": 12.50, + "currency": "USD", + "media_type": "DIGITAL", + "description": "Created by MCP E2E smoke test — safe to delete", + }, + ) + ) assert not err and data.get("success"), f"Deal creation failed: {data}" return data["deal_id"] @@ -225,14 +236,18 @@ async def test_get_portfolio_summary(mcp_session): async def test_create_deal_manual(mcp_session): """Workflow: create deal → inspect → search → verify portfolio grows.""" # Create - err, data = await _call(mcp_session, "create_deal_manual", { - "display_name": "Quinn Workflow Test Deal", - "seller_url": "http://workflow-test.example.com", - "deal_type": "PG", - "price": 25.00, - "currency": "USD", - "media_type": "CTV", - }) + err, data = await _call( + mcp_session, + "create_deal_manual", + { + "display_name": "Quinn Workflow Test Deal", + "seller_url": "http://workflow-test.example.com", + "deal_type": "PG", + "price": 25.00, + "currency": "USD", + "media_type": "CTV", + }, + ) assert not err and data.get("success"), f"create failed: {data}" deal_id = data["deal_id"] @@ -244,8 +259,7 @@ async def test_create_deal_manual(mcp_session): assert inspect.get("display_name") == "Quinn Workflow Test Deal" # Search - err, search = await _call(mcp_session, "search_deals", - {"query": "Quinn Workflow Test"}) + err, search = await _call(mcp_session, "search_deals", {"query": "Quinn Workflow Test"}) assert not err, f"search_deals raised: {search}" assert "deals" in search found = any(d.get("deal_id") == deal_id for d in search["deals"]) @@ -262,8 +276,7 @@ async def test_search_deals(mcp_session): @pytest.mark.asyncio async def test_inspect_deal_not_found(mcp_session): - err, data = await _call(mcp_session, "inspect_deal", - {"deal_id": "does-not-exist-uat"}) + err, data = await _call(mcp_session, "inspect_deal", {"deal_id": "does-not-exist-uat"}) assert not err, f"inspect_deal raised: {data}" assert "error" in data @@ -276,8 +289,7 @@ async def test_import_deals_csv(mcp_session): "Quinn CSV Import Deal,http://csv-seller.example.com," "CSVPublisher,csv-seller.example.com,PD,9.50,USD" ) - err, data = await _call(mcp_session, "import_deals_csv", - {"csv_data": csv_data}) + err, data = await _call(mcp_session, "import_deals_csv", {"csv_data": csv_data}) assert not err, f"import_deals_csv raised: {data}" assert "total_rows" in data assert data.get("successful", 0) >= 1 or data.get("failed", 0) >= 0 # graceful @@ -287,6 +299,7 @@ async def test_import_deals_csv(mcp_session): # 5. Seller Discovery (3 tools) # ------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_discover_sellers(mcp_session): err, data = await _call(mcp_session, "discover_sellers") @@ -297,17 +310,20 @@ async def test_discover_sellers(mcp_session): @pytest.mark.asyncio async def test_get_seller_media_kit_unreachable(mcp_session): """Unreachable seller returns a structured error, not a crash.""" - err, data = await _call(mcp_session, "get_seller_media_kit", - {"seller_url": "http://127.0.0.1:19999"}) + err, data = await _call( + mcp_session, "get_seller_media_kit", {"seller_url": "http://127.0.0.1:19999"} + ) assert not err, f"get_seller_media_kit raised: {data}" assert "error" in data @pytest.mark.asyncio async def test_compare_sellers(mcp_session): - err, data = await _call(mcp_session, "compare_sellers", - {"seller_urls": ["http://s1.example.com", - "http://s2.example.com"]}) + err, data = await _call( + mcp_session, + "compare_sellers", + {"seller_urls": ["http://s1.example.com", "http://s2.example.com"]}, + ) assert not err, f"compare_sellers raised: {data}" assert "sellers_compared" in data or "sellers" in data @@ -316,6 +332,7 @@ async def test_compare_sellers(mcp_session): # 6. Negotiation (3 tools) # ------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_list_active_negotiations(mcp_session): err, data = await _call(mcp_session, "list_active_negotiations") @@ -326,12 +343,16 @@ async def test_list_active_negotiations(mcp_session): @pytest.mark.asyncio async def test_start_negotiation(mcp_session): - err, data = await _call(mcp_session, "start_negotiation", { - "seller_url": "http://neg-test-seller.example.com", - "product_id": "pkg-ctv-premium", - "product_name": "Quinn UAT Negotiation Test", - "initial_price": 18.00, - }) + err, data = await _call( + mcp_session, + "start_negotiation", + { + "seller_url": "http://neg-test-seller.example.com", + "product_id": "pkg-ctv-premium", + "product_name": "Quinn UAT Negotiation Test", + "initial_price": 18.00, + }, + ) assert not err, f"start_negotiation raised: {data}" assert "deal_id" in data assert data.get("status") == "negotiating" @@ -341,16 +362,19 @@ async def test_start_negotiation(mcp_session): @pytest.mark.asyncio async def test_get_negotiation_status(mcp_session): # Start one so we have a deal to check - _, neg = await _call(mcp_session, "start_negotiation", { - "seller_url": "http://neg-status-test.example.com", - "product_id": "pkg-test", - "initial_price": 10.00, - }) + _, neg = await _call( + mcp_session, + "start_negotiation", + { + "seller_url": "http://neg-status-test.example.com", + "product_id": "pkg-test", + "initial_price": 10.00, + }, + ) deal_id = neg.get("deal_id") assert deal_id, "start_negotiation did not return deal_id" - err, data = await _call(mcp_session, "get_negotiation_status", - {"deal_id": deal_id}) + err, data = await _call(mcp_session, "get_negotiation_status", {"deal_id": deal_id}) assert not err, f"get_negotiation_status raised: {data}" assert data.get("deal_id") == deal_id assert "status" in data @@ -360,6 +384,7 @@ async def test_get_negotiation_status(mcp_session): # 7. Orders (3 tools) # ------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_list_orders(mcp_session): err, data = await _call(mcp_session, "list_orders") @@ -369,19 +394,22 @@ async def test_list_orders(mcp_session): @pytest.mark.asyncio async def test_get_order_status_not_found(mcp_session): - err, data = await _call(mcp_session, "get_order_status", - {"order_id": "nonexistent-order-uat"}) + err, data = await _call(mcp_session, "get_order_status", {"order_id": "nonexistent-order-uat"}) assert not err, f"get_order_status raised: {data}" assert "error" in data @pytest.mark.asyncio async def test_transition_order_not_found(mcp_session): - err, data = await _call(mcp_session, "transition_order", { - "order_id": "nonexistent-order-uat", - "to_status": "confirmed", - "reason": "Quinn UAT test", - }) + err, data = await _call( + mcp_session, + "transition_order", + { + "order_id": "nonexistent-order-uat", + "to_status": "confirmed", + "reason": "Quinn UAT test", + }, + ) assert not err, f"transition_order raised: {data}" assert "error" in data @@ -390,6 +418,7 @@ async def test_transition_order_not_found(mcp_session): # 8. Approval Gate (2 tools) # ------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_list_pending_approvals(mcp_session): err, data = await _call(mcp_session, "list_pending_approvals") @@ -400,12 +429,16 @@ async def test_list_pending_approvals(mcp_session): @pytest.mark.asyncio async def test_approve_or_reject_not_found(mcp_session): - err, data = await _call(mcp_session, "approve_or_reject", { - "approval_request_id": "nonexistent-request-uat", - "decision": "approved", - "reviewer": "quinn-uat", - "reason": "UAT test approval", - }) + err, data = await _call( + mcp_session, + "approve_or_reject", + { + "approval_request_id": "nonexistent-request-uat", + "decision": "approved", + "reviewer": "quinn-uat", + "reason": "UAT test approval", + }, + ) assert not err, f"approve_or_reject raised: {data}" assert "error" in data @@ -414,6 +447,7 @@ async def test_approve_or_reject_not_found(mcp_session): # 9. API Keys (3 tools) + lifecycle workflow # ------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_api_key_lifecycle(mcp_session): """Full workflow: create → list (verify masked) → revoke → list (verify gone).""" @@ -421,10 +455,14 @@ async def test_api_key_lifecycle(mcp_session): test_key = "quinn-test-key-abcdefgh9999" # Create - err, create_data = await _call(mcp_session, "create_api_key", { - "seller_url": test_seller, - "api_key": test_key, - }) + err, create_data = await _call( + mcp_session, + "create_api_key", + { + "seller_url": test_seller, + "api_key": test_key, + }, + ) assert not err and create_data.get("created"), f"create_api_key failed: {create_data}" assert "masked_key" in create_data # Verify masking: real key should NOT appear in masked version @@ -440,15 +478,13 @@ async def test_api_key_lifecycle(mcp_session): assert "****" in matching[0]["masked_key"], "Key is not masked" # Revoke - err, revoke_data = await _call(mcp_session, "revoke_api_key", - {"seller_url": test_seller}) + err, revoke_data = await _call(mcp_session, "revoke_api_key", {"seller_url": test_seller}) assert not err and revoke_data.get("revoked"), f"revoke_api_key failed: {revoke_data}" # List — key should be gone err, list_after = await _call(mcp_session, "list_api_keys") assert not err, f"list_api_keys after revoke failed: {list_after}" - remaining = [k for k in list_after.get("keys", []) - if k.get("seller_url") == test_seller] + remaining = [k for k in list_after.get("keys", []) if k.get("seller_url") == test_seller] assert len(remaining) == 0, f"Key still present after revoke: {remaining}" @@ -456,6 +492,7 @@ async def test_api_key_lifecycle(mcp_session): # 10. Templates (3 tools) + workflow # ------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_list_templates(mcp_session): err, data = await _call(mcp_session, "list_templates") @@ -468,26 +505,36 @@ async def test_list_templates(mcp_session): async def test_template_workflow(mcp_session): """Create deal template → instantiate from it → verify deal created.""" # Create template - err, tmpl = await _call(mcp_session, "create_template", { - "template_type": "deal", - "name": "Quinn UAT CTV Deal Template", - "deal_type_pref": "PD", - "max_cpm": 30.0, - "default_price": 15.0, - }) + err, tmpl = await _call( + mcp_session, + "create_template", + { + "template_type": "deal", + "name": "Quinn UAT CTV Deal Template", + "deal_type_pref": "PD", + "max_cpm": 30.0, + "default_price": 15.0, + }, + ) assert not err, f"create_template raised: {tmpl}" assert "template_id" in tmpl, f"No template_id in response: {tmpl}" template_id = tmpl["template_id"] # Instantiate - overrides_json = json.dumps({ - "display_name": "Quinn UAT Instantiated Deal", - "seller_url": "http://template-seller.example.com", - }) - err, inst = await _call(mcp_session, "instantiate_from_template", { - "template_id": template_id, - "overrides": overrides_json, - }) + overrides_json = json.dumps( + { + "display_name": "Quinn UAT Instantiated Deal", + "seller_url": "http://template-seller.example.com", + } + ) + err, inst = await _call( + mcp_session, + "instantiate_from_template", + { + "template_id": template_id, + "overrides": overrides_json, + }, + ) assert not err, f"instantiate_from_template raised: {inst}" assert inst.get("success"), f"instantiate_from_template failed: {inst}" assert "deal_id" in inst @@ -504,19 +551,23 @@ async def test_template_workflow(mcp_session): # 11. Reporting (3 tools) # ------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_get_deal_performance(mcp_session): """Create a deal and verify performance report is returned for it.""" - err, deal = await _call(mcp_session, "create_deal_manual", { - "display_name": "Quinn Perf Report Test Deal", - "seller_url": "http://perf-test.example.com", - "deal_type": "PD", - }) + err, deal = await _call( + mcp_session, + "create_deal_manual", + { + "display_name": "Quinn Perf Report Test Deal", + "seller_url": "http://perf-test.example.com", + "deal_type": "PD", + }, + ) assert not err and deal.get("success"), f"create failed: {deal}" deal_id = deal["deal_id"] - err, data = await _call(mcp_session, "get_deal_performance", - {"deal_id": deal_id}) + err, data = await _call(mcp_session, "get_deal_performance", {"deal_id": deal_id}) assert not err, f"get_deal_performance raised: {data}" # Should have deal info even if no spend yet assert "deal_id" in data or "error" in data @@ -524,16 +575,18 @@ async def test_get_deal_performance(mcp_session): @pytest.mark.asyncio async def test_get_campaign_report_not_found(mcp_session): - err, data = await _call(mcp_session, "get_campaign_report", - {"campaign_id": "nonexistent-uat-001"}) + err, data = await _call( + mcp_session, "get_campaign_report", {"campaign_id": "nonexistent-uat-001"} + ) assert not err, f"get_campaign_report raised: {data}" assert "error" in data @pytest.mark.asyncio async def test_get_pacing_report_not_found(mcp_session): - err, data = await _call(mcp_session, "get_pacing_report", - {"campaign_id": "nonexistent-uat-001"}) + err, data = await _call( + mcp_session, "get_pacing_report", {"campaign_id": "nonexistent-uat-001"} + ) assert not err, f"get_pacing_report raised: {data}" assert "error" in data @@ -542,6 +595,7 @@ async def test_get_pacing_report_not_found(mcp_session): # 12. SSP Connectors (3 tools) # ------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_list_ssp_connectors(mcp_session): err, data = await _call(mcp_session, "list_ssp_connectors") @@ -558,8 +612,7 @@ async def test_list_ssp_connectors(mcp_session): @pytest.mark.asyncio async def test_import_deals_ssp_unconfigured(mcp_session): """Unconfigured SSP returns structured error, not crash.""" - err, data = await _call(mcp_session, "import_deals_ssp", - {"ssp_name": "pubmatic"}) + err, data = await _call(mcp_session, "import_deals_ssp", {"ssp_name": "pubmatic"}) assert not err, f"import_deals_ssp raised: {data}" assert "error" in data or "deals" in data # error if not configured @@ -567,8 +620,7 @@ async def test_import_deals_ssp_unconfigured(mcp_session): @pytest.mark.asyncio async def test_ssp_connection_test_unconfigured(mcp_session): """Test SSP connection for unconfigured connector.""" - err, data = await _call(mcp_session, "test_ssp_connection", - {"ssp_name": "index_exchange"}) + err, data = await _call(mcp_session, "test_ssp_connection", {"ssp_name": "index_exchange"}) assert not err, f"test_ssp_connection raised: {data}" assert "connected" in data or "error" in data @@ -577,6 +629,7 @@ async def test_ssp_connection_test_unconfigured(mcp_session): # 13. Edge Case Tests # ------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_search_deals_empty_query(mcp_session): """Empty string search is handled gracefully.""" @@ -596,10 +649,14 @@ async def test_inspect_deal_empty_id(mcp_session): @pytest.mark.asyncio async def test_create_deal_minimal_fields(mcp_session): """Minimal required fields only: display_name + seller_url.""" - err, data = await _call(mcp_session, "create_deal_manual", { - "display_name": "Quinn Minimal Deal", - "seller_url": "http://minimal.example.com", - }) + err, data = await _call( + mcp_session, + "create_deal_manual", + { + "display_name": "Quinn Minimal Deal", + "seller_url": "http://minimal.example.com", + }, + ) assert not err, f"create_deal_manual minimal raised: {data}" assert data.get("success"), f"Minimal deal creation failed: {data}" @@ -607,7 +664,6 @@ async def test_create_deal_minimal_fields(mcp_session): @pytest.mark.asyncio async def test_list_deals_with_status_filter(mcp_session): """Filtering list_deals by status is handled.""" - err, data = await _call(mcp_session, "list_deals", - {"status": "active"}) + err, data = await _call(mcp_session, "list_deals", {"status": "active"}) assert not err, f"list_deals with status filter raised: {data}" assert "deals" in data diff --git a/tests/smoke/test_mcp_streamable.py b/tests/smoke/test_mcp_streamable.py index 2c8289b..36d99a2 100644 --- a/tests/smoke/test_mcp_streamable.py +++ b/tests/smoke/test_mcp_streamable.py @@ -30,6 +30,7 @@ try: from mcp import ClientSession from mcp.client.streamable_http import streamable_http_client + MCP_HTTP_AVAILABLE = True except ImportError: try: @@ -38,6 +39,7 @@ from mcp.client.streamable_http import ( streamablehttp_client as streamable_http_client, # type: ignore[no-redef] ) + MCP_HTTP_AVAILABLE = True except ImportError: MCP_HTTP_AVAILABLE = False @@ -55,6 +57,7 @@ # Session helper — context manager, not a fixture, avoids AUTO-mode doubling # --------------------------------------------------------------------------- + @asynccontextmanager async def _mcp_session(): """Open a fresh Streamable HTTP MCP session for one test.""" @@ -93,6 +96,7 @@ async def _call(session: "ClientSession", name: str, args: dict | None = None): # Connection # --------------------------------------------------------------------------- + async def test_streamable_http_connection(): """/mcp must accept a session and initialize successfully.""" async with _mcp_session() as session: @@ -114,6 +118,7 @@ async def test_streamable_http_tool_list(): # Foundation tools # --------------------------------------------------------------------------- + async def test_health_check(): async with _mcp_session() as session: err, data = await _call(session, "health_check") @@ -143,6 +148,7 @@ async def test_get_config(): # Deal library # --------------------------------------------------------------------------- + async def test_list_deals(): async with _mcp_session() as session: err, data = await _call(session, "list_deals") @@ -154,13 +160,17 @@ async def test_list_deals(): async def test_create_and_inspect_deal(): """Create a deal via /mcp, then inspect it — verifies round-trip.""" async with _mcp_session() as session: - err, data = await _call(session, "create_deal_manual", { - "display_name": "Streamable HTTP Test Deal", - "seller_url": "http://mcp-http-test.example.com", - "deal_type": "PD", - "price": 18.0, - "currency": "USD", - }) + err, data = await _call( + session, + "create_deal_manual", + { + "display_name": "Streamable HTTP Test Deal", + "seller_url": "http://mcp-http-test.example.com", + "deal_type": "PD", + "price": 18.0, + "currency": "USD", + }, + ) assert not err and data.get("success"), f"create_deal_manual failed: {data}" deal_id = data["deal_id"] @@ -181,6 +191,7 @@ async def test_get_portfolio_summary(): # Seller discovery # --------------------------------------------------------------------------- + async def test_discover_sellers(): async with _mcp_session() as session: err, data = await _call(session, "discover_sellers") @@ -191,8 +202,9 @@ async def test_discover_sellers(): async def test_get_seller_media_kit_unreachable(): """Unreachable seller must return structured error, not crash.""" async with _mcp_session() as session: - err, data = await _call(session, "get_seller_media_kit", - {"seller_url": "http://127.0.0.1:19999"}) + err, data = await _call( + session, "get_seller_media_kit", {"seller_url": "http://127.0.0.1:19999"} + ) assert not err, f"get_seller_media_kit raised: {data}" assert "error" in data @@ -201,6 +213,7 @@ async def test_get_seller_media_kit_unreachable(): # Campaigns & Orders # --------------------------------------------------------------------------- + async def test_list_campaigns(): async with _mcp_session() as session: err, data = await _call(session, "list_campaigns") @@ -219,14 +232,16 @@ async def test_list_orders(): # API keys # --------------------------------------------------------------------------- + async def test_api_key_lifecycle(): """Full create → list → revoke lifecycle over /mcp.""" seller = "http://mcp-http-key-test.example.com" raw_key = "mcp-http-test-key-xyz999" async with _mcp_session() as session: - err, created = await _call(session, "create_api_key", - {"seller_url": seller, "api_key": raw_key}) + err, created = await _call( + session, "create_api_key", {"seller_url": seller, "api_key": raw_key} + ) assert not err and created.get("created"), f"create_api_key failed: {created}" assert raw_key not in created["masked_key"], "Raw key must be masked" diff --git a/tests/unit/test_audience_audit_log.py b/tests/unit/test_audience_audit_log.py index 3a5e552..15f3095 100644 --- a/tests/unit/test_audience_audit_log.py +++ b/tests/unit/test_audience_audit_log.py @@ -133,15 +133,9 @@ def test_known_event_types_includes_documented_types(self): class TestOrdering: def test_multiple_events_returned_in_insertion_order(self, temp_audit_db): - audience_audit_log.log_event( - "plan-multi", EVENT_CAPABILITY_REJECTION, {"step": 1} - ) - audience_audit_log.log_event( - "plan-multi", EVENT_DEGRADATION, {"step": 2} - ) - audience_audit_log.log_event( - "plan-multi", EVENT_DEGRADATION, {"step": 3} - ) + audience_audit_log.log_event("plan-multi", EVENT_CAPABILITY_REJECTION, {"step": 1}) + audience_audit_log.log_event("plan-multi", EVENT_DEGRADATION, {"step": 2}) + audience_audit_log.log_event("plan-multi", EVENT_DEGRADATION, {"step": 3}) events = audience_audit_log.get_events("plan-multi") assert len(events) == 3 @@ -189,9 +183,7 @@ def test_empty_plan_id_is_ignored(self, temp_audit_db): def test_unknown_event_type_still_writes(self, temp_audit_db): # Forward-compat: callers can experiment with new types ahead of # constants landing here. The helper logs a WARN but does NOT drop. - audience_audit_log.log_event( - "plan-fc", "future_event_type", {"x": 1} - ) + audience_audit_log.log_event("plan-fc", "future_event_type", {"x": 1}) events = audience_audit_log.get_events("plan-fc") assert len(events) == 1 assert events[0]["event_type"] == "future_event_type" @@ -221,17 +213,14 @@ def test_log_event_creates_table_on_existing_db_without_it(self, tmp_path): # NOT `audience_audit_log`. We use `deals` from the schema module so # the legacy DB is realistic. legacy = sqlite3.connect(str(db_path)) - legacy.execute( - "CREATE TABLE pretend_other_table (id INTEGER PRIMARY KEY)" - ) + legacy.execute("CREATE TABLE pretend_other_table (id INTEGER PRIMARY KEY)") legacy.commit() legacy.close() # Confirm the table is genuinely missing before the helper touches it. check = sqlite3.connect(str(db_path)) cursor = check.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND " - "name='audience_audit_log'" + "SELECT name FROM sqlite_master WHERE type='table' AND name='audience_audit_log'" ) assert cursor.fetchone() is None check.close() @@ -239,15 +228,12 @@ def test_log_event_creates_table_on_existing_db_without_it(self, tmp_path): # Point the helper at this legacy file and write an event. audience_audit_log.configure(f"sqlite:///{db_path}") try: - audience_audit_log.log_event( - "plan-legacy", EVENT_DEGRADATION, {"first": True} - ) + audience_audit_log.log_event("plan-legacy", EVENT_DEGRADATION, {"first": True}) # The table now exists. check2 = sqlite3.connect(str(db_path)) cursor = check2.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND " - "name='audience_audit_log'" + "SELECT name FROM sqlite_master WHERE type='table' AND name='audience_audit_log'" ) assert cursor.fetchone() is not None check2.close() @@ -336,9 +322,7 @@ async def book_deal(self, request: DealBookingRequest) -> DealResponse: class TestOrchestratorEmitsAuditEvents: @pytest.mark.asyncio - async def test_degrade_and_retry_emits_degradation_and_rejection( - self, temp_audit_db - ): + async def test_degrade_and_retry_emits_degradation_and_rejection(self, temp_audit_db): plan = _make_audience_plan() plan_id = plan.audience_plan_id assert plan_id # sanity: auto-computed @@ -374,9 +358,7 @@ async def test_degrade_and_retry_emits_degradation_and_rejection( assert EVENT_DEGRADATION in types # Capability-rejection event preserves the seller's structured list. - rejection = next( - e for e in events if e["event_type"] == EVENT_CAPABILITY_REJECTION - ) + rejection = next(e for e in events if e["event_type"] == EVENT_CAPABILITY_REJECTION) assert rejection["payload"]["seller_id"] == "seller-1" assert rejection["payload"]["unsupported"] == [ { @@ -386,9 +368,7 @@ async def test_degrade_and_retry_emits_degradation_and_rejection( ] # Degradation event captures what was stripped. - degradation = next( - e for e in events if e["event_type"] == EVENT_DEGRADATION - ) + degradation = next(e for e in events if e["event_type"] == EVENT_DEGRADATION) assert degradation["payload"]["seller_id"] == "seller-1" assert degradation["payload"]["deal_id"] == "deal-001" assert isinstance(degradation["payload"]["log"], list) diff --git a/tests/unit/test_audience_degradation.py b/tests/unit/test_audience_degradation.py index b8dda7e..bc6d5cc 100644 --- a/tests/unit/test_audience_degradation.py +++ b/tests/unit/test_audience_degradation.py @@ -55,9 +55,7 @@ def _standard(identifier: str = "3-7", version: str = "1.1") -> AudienceRef: ) -def _contextual( - identifier: str = "IAB1-2", version: str = "3.1" -) -> AudienceRef: +def _contextual(identifier: str = "IAB1-2", version: str = "3.1") -> AudienceRef: return AudienceRef( type="contextual", identifier=identifier, @@ -384,7 +382,8 @@ def test_constraints_trimmed_to_max(self): plan = AudiencePlan( primary=_standard(), constraints=[ - _contextual(f"IAB1-{i}") for i in range(1, 6) # 5 constraints + _contextual(f"IAB1-{i}") + for i in range(1, 6) # 5 constraints ], ) # Seller accepts only 2 constraints. diff --git a/tests/unit/test_audience_planner_reasoning.py b/tests/unit/test_audience_planner_reasoning.py index 7c09988..ab5df55 100644 --- a/tests/unit/test_audience_planner_reasoning.py +++ b/tests/unit/test_audience_planner_reasoning.py @@ -154,8 +154,7 @@ class TestContextualBrief: def test_content_adjacent_description_biases_contextual(self): brief = _make_brief( description=( - "ads next to automotive content on premium news sites; " - "contextual-led campaign" + "ads next to automotive content on premium news sites; contextual-led campaign" ), ) # No usable taxonomy candidates resolve from prose alone -- but @@ -202,13 +201,10 @@ class TestFirstPartyBrief: def test_first_party_description_mints_agentic_primary(self, mint_tool): brief = _make_brief( description=( - "lookalike of our converters from last campaign; " - "advertiser first-party data" + "lookalike of our converters from last campaign; advertiser first-party data" ), ) - result = run_audience_reasoning( - brief, embedding_mint_tool=mint_tool - ) + result = run_audience_reasoning(brief, embedding_mint_tool=mint_tool) assert result.plan is not None, " | ".join(result.rationale_lines) assert result.plan.primary.type == "agentic" @@ -250,9 +246,7 @@ def test_mixed_brief_with_strong_agentic_phrase(self, mint_tool): "campaign; show alongside automotive content" ), ) - result = run_audience_reasoning( - brief, embedding_mint_tool=mint_tool - ) + result = run_audience_reasoning(brief, embedding_mint_tool=mint_tool) # The strong "lookalike of our converters" phrase counts as # 2 agentic phrases (lookalike + our converters), beating @@ -264,8 +258,7 @@ def test_mixed_brief_with_strong_agentic_phrase(self, mint_tool): def test_mixed_brief_demographic_dominant(self): brief = _make_brief( description=( - "women, men, parents, kids, household income brief; " - "no first-party data this time" + "women, men, parents, kids, household income brief; no first-party data this time" ), ) # Many demographic phrases, no agentic, no contextual. With no @@ -298,10 +291,7 @@ def test_explicit_primary_with_cpa_kpi_picks_constraints_branch(self): brief_dict = _base_brief_dict( objective="CONVERSION", kpis=[{"metric": "ROAS", "target_value": 3.0}], - description=( - "Auto intenders; show on automotive content. " - "ROAS-driven optimization." - ), + description=("Auto intenders; show on automotive content. ROAS-driven optimization."), target_audience={ "primary": { "type": "standard", @@ -344,17 +334,14 @@ def test_explicit_primary_with_cpa_kpi_picks_constraints_branch(self): class TestExplicitPlanReachAddsExtensions: """Explicit primary preserved; reach KPI -> inferred extensions.""" - def test_explicit_primary_with_reach_objective_adds_agentic_extension( - self, mint_tool - ): + def test_explicit_primary_with_reach_objective_adds_agentic_extension(self, mint_tool): # REACH objective signals the reach branch; the planner mints an # Agentic extension from a "lookalike" seed in the description. brief_dict = _base_brief_dict( objective="REACH", kpis=[{"metric": "CPM", "target_value": 12.0}], description=( - "Big-reach awareness push; lookalike of our converters " - "for additional scale." + "Big-reach awareness push; lookalike of our converters for additional scale." ), target_audience={ "primary": { @@ -368,9 +355,7 @@ def test_explicit_primary_with_reach_objective_adds_agentic_extension( ) brief = parse_campaign_brief(brief_dict) # Inject the mint tool so the planner can produce an agentic ext. - result = run_audience_reasoning( - brief, embedding_mint_tool=mint_tool - ) + result = run_audience_reasoning(brief, embedding_mint_tool=mint_tool) assert result.plan is not None assert result.plan.primary.identifier == "243" @@ -384,9 +369,7 @@ def test_explicit_primary_with_reach_objective_adds_agentic_extension( # mint tool wired in and "lookalike" / "our converters" in the # description, the extension is Agentic. assert len(result.plan.extensions) >= 1 - agentic_exts = [ - e for e in result.plan.extensions if e.type == "agentic" - ] + agentic_exts = [e for e in result.plan.extensions if e.type == "agentic"] assert len(agentic_exts) >= 1 # Inferred provenance is the mark of agent-added refs. assert agentic_exts[0].source == "inferred" diff --git a/tests/unit/test_audience_planner_wiring.py b/tests/unit/test_audience_planner_wiring.py index e693e7d..f4e7eed 100644 --- a/tests/unit/test_audience_planner_wiring.py +++ b/tests/unit/test_audience_planner_wiring.py @@ -430,9 +430,7 @@ def _research_agent_tools(self, crew): for agent in crew.agents: if agent.role == research_role: return {type(t) for t in agent.tools} - raise AssertionError( - f"Could not find Research Agent (role={research_role!r}) in crew" - ) + raise AssertionError(f"Could not find Research Agent (role={research_role!r}) in crew") def test_branding_crew_research_agent_has_no_audience_tools( self, opendirect_client, channel_brief @@ -452,9 +450,7 @@ def test_mobile_crew_research_agent_has_no_audience_tools( assert AudienceMatchingTool not in types assert CoverageEstimationTool not in types - def test_ctv_crew_research_agent_has_no_audience_tools( - self, opendirect_client, channel_brief - ): + def test_ctv_crew_research_agent_has_no_audience_tools(self, opendirect_client, channel_brief): crew = create_ctv_crew(opendirect_client, channel_brief) types = self._research_agent_tools(crew) assert AudienceDiscoveryTool not in types diff --git a/tests/unit/test_audience_strictness.py b/tests/unit/test_audience_strictness.py index 763b7a1..4ab9b45 100644 --- a/tests/unit/test_audience_strictness.py +++ b/tests/unit/test_audience_strictness.py @@ -83,11 +83,7 @@ def test_brief_accepts_strictness_override_dict(): def test_brief_accepts_strictness_partial_override(): - brief = CampaignBrief( - **_minimal_brief( - audience_strictness={"agentic": "required"} - ) - ) + brief = CampaignBrief(**_minimal_brief(audience_strictness={"agentic": "required"})) s = brief.audience_strictness # Overridden field assert s.agentic == "required" diff --git a/tests/unit/test_buyer_deal_flow_audience.py b/tests/unit/test_buyer_deal_flow_audience.py index fc3029a..6c6d9f2 100644 --- a/tests/unit/test_buyer_deal_flow_audience.py +++ b/tests/unit/test_buyer_deal_flow_audience.py @@ -132,9 +132,7 @@ def _seed_request_state(flow: BuyerDealFlow) -> None: class TestPlannerRunsOnReceiveRequest: """When a brief is supplied, receive_request must run the planner.""" - def test_brief_yields_audience_plan_on_state( - self, mock_unified_client: MagicMock - ) -> None: + def test_brief_yields_audience_plan_on_state(self, mock_unified_client: MagicMock) -> None: brief = _make_brief() flow = BuyerDealFlow( client=mock_unified_client, @@ -154,9 +152,7 @@ def test_brief_yields_audience_plan_on_state( assert planner_result is not None assert planner_result.plan is flow.state.audience_plan - def test_no_brief_keeps_flow_audience_blind( - self, mock_unified_client: MagicMock - ) -> None: + def test_no_brief_keeps_flow_audience_blind(self, mock_unified_client: MagicMock) -> None: """Legacy callers (no brief) must keep the original audience-blind path.""" flow = BuyerDealFlow( @@ -180,9 +176,7 @@ def test_no_brief_keeps_flow_audience_blind( class TestExplicitBriefPreserved: """Explicit user-supplied AudiencePlans are NEVER mutated by the planner.""" - def test_explicit_primary_preserved_verbatim( - self, mock_unified_client: MagicMock - ) -> None: + def test_explicit_primary_preserved_verbatim(self, mock_unified_client: MagicMock) -> None: brief = _make_brief() # Capture the explicit plan as authored by the user. original = brief.target_audience @@ -223,9 +217,7 @@ def test_explicit_primary_preserved_verbatim( class TestLegacyBriefMigration: """Legacy `list[str]` audience field must round-trip through the flow.""" - def test_legacy_brief_yields_inferred_primary( - self, mock_unified_client: MagicMock - ) -> None: + def test_legacy_brief_yields_inferred_primary(self, mock_unified_client: MagicMock) -> None: brief = _make_legacy_brief() # Confirm the parser already migrated the list[str] to a typed plan # marked source=inferred (the contract from §4 / coerce_audience_field). @@ -272,9 +264,7 @@ def _make_minimal_plan(identifier: str = "3-7") -> AudiencePlan: class TestAudiencePlanCrossesSellerBoundary: """The plan threaded onto state must reach the seller-bound DealRequest.""" - def test_request_deal_id_threads_plan_into_tool( - self, mock_unified_client: MagicMock - ) -> None: + def test_request_deal_id_threads_plan_into_tool(self, mock_unified_client: MagicMock) -> None: """request_deal_id must call the deal tool with the AudiencePlan.""" flow = BuyerDealFlow( @@ -289,9 +279,7 @@ def test_request_deal_id_threads_plan_into_tool( # Mock the deal tool so we can inspect the call. flow._deal_tool = MagicMock() - flow._deal_tool._run = MagicMock( - return_value="DEAL CREATED: deal-test-001" - ) + flow._deal_tool._run = MagicMock(return_value="DEAL CREATED: deal-test-001") result = flow.request_deal_id({"status": "success"}) @@ -303,9 +291,7 @@ def test_request_deal_id_threads_plan_into_tool( # Deal type / impressions / flights came along too. assert call_kwargs.get("product_id") == "ctv-pkg-1" - def test_request_deal_payload_carries_plan( - self, mock_unified_client: MagicMock - ) -> None: + def test_request_deal_payload_carries_plan(self, mock_unified_client: MagicMock) -> None: """The seller-bound DealRequest payload must carry the plan.""" from ad_buyer.tools.buyer_deals.request_deal import RequestDealTool @@ -336,9 +322,7 @@ def test_request_deal_payload_carries_plan( assert rebuilt.audience_plan is not None assert rebuilt.audience_plan.audience_plan_id == plan.audience_plan_id - def test_legacy_payload_still_works_without_plan( - self, mock_unified_client: MagicMock - ) -> None: + def test_legacy_payload_still_works_without_plan(self, mock_unified_client: MagicMock) -> None: """No plan supplied -> DealRequest carries audience_plan=None.""" from ad_buyer.tools.buyer_deals.request_deal import RequestDealTool @@ -388,22 +372,16 @@ def test_plan_id_preserved_from_brief_to_tool_call( # observe the plan that crosses the boundary. flow.state.selected_product_id = "ctv-pkg-1" flow._deal_tool = MagicMock() - flow._deal_tool._run = MagicMock( - return_value="DEAL CREATED: deal-test-002" - ) + flow._deal_tool._run = MagicMock(return_value="DEAL CREATED: deal-test-002") flow.request_deal_id({"status": "success"}) - observed_plan = flow._deal_tool._run.call_args.kwargs.get( - "audience_plan" - ) + observed_plan = flow._deal_tool._run.call_args.kwargs.get("audience_plan") assert observed_plan is not None # Same audience_plan_id from brief through state to tool kwargs. assert observed_plan.audience_plan_id == plan_id_after_receive - def test_plan_id_surfaced_on_status( - self, mock_unified_client: MagicMock - ) -> None: + def test_plan_id_surfaced_on_status(self, mock_unified_client: MagicMock) -> None: """get_status() exposes audience_plan_id once the planner has run.""" brief = _make_brief() diff --git a/tests/unit/test_buyer_preflight.py b/tests/unit/test_buyer_preflight.py index 065cc95..93362c9 100644 --- a/tests/unit/test_buyer_preflight.py +++ b/tests/unit/test_buyer_preflight.py @@ -167,9 +167,7 @@ def _make_deal_response( ) -def _ranked_quote( - quote_id: str = "q-1", seller_id: str = "seller-a" -) -> NormalizedQuote: +def _ranked_quote(quote_id: str = "q-1", seller_id: str = "seller-a") -> NormalizedQuote: return NormalizedQuote( seller_id=seller_id, quote_id=quote_id, @@ -476,13 +474,9 @@ def __init__(self, caps_by_url: dict[str, SellerAudienceCapabilities]): self._caps_by_url = caps_by_url self.calls: list[str] = [] - async def discover_capabilities( - self, seller_endpoint: str - ) -> CapabilityDiscoveryResult: + async def discover_capabilities(self, seller_endpoint: str) -> CapabilityDiscoveryResult: self.calls.append(seller_endpoint) - caps = self._caps_by_url.get( - seller_endpoint, SellerAudienceCapabilities.legacy_default() - ) + caps = self._caps_by_url.get(seller_endpoint, SellerAudienceCapabilities.legacy_default()) return CapabilityDiscoveryResult( capabilities=caps, cache_status="miss", @@ -578,9 +572,7 @@ class TestPreflightStrictnessGate: """Tests 7a / 7b / 7c: pre-flight degrades plan, applies strictness.""" @pytest.mark.asyncio - async def test_primary_required_with_version_mismatch_skips_seller( - self, deals_client_factory - ): + async def test_primary_required_with_version_mismatch_skips_seller(self, deals_client_factory): """primary=required + standard taxonomy version mismatch -> seller skipped. The seller advertises only Audience Taxonomy v2.0 (which the buyer's @@ -620,10 +612,7 @@ async def test_primary_required_with_version_mismatch_skips_seller( assert "seller-a" in selection.incompatible_sellers assert selection.booked_deals == [] assert len(selection.failed_bookings) == 1 - assert ( - selection.failed_bookings[0]["error_code"] - == "audience_plan_unsupported" - ) + assert selection.failed_bookings[0]["error_code"] == "audience_plan_unsupported" @pytest.mark.asyncio async def test_extensions_optional_dropped_proceeds(self, deals_client_factory): @@ -703,9 +692,7 @@ async def test_constraints_preferred_dropped_proceeds(self, deals_client_factory assert booking_arg.audience_plan.primary.identifier == "3-7" @pytest.mark.asyncio - async def test_constraints_required_dropped_skips_seller( - self, deals_client_factory - ): + async def test_constraints_required_dropped_skips_seller(self, deals_client_factory): """constraints=required + dropped -> seller skipped. Promotes the optional-by-default constraint policy to required; the @@ -816,9 +803,7 @@ async def test_stale_cache_seller_rejects_retry_fires(self, deals_client_factory assert retry_request.audience_plan.extensions == [] @pytest.mark.asyncio - async def test_preflight_dropped_extensions_no_retry_needed( - self, deals_client_factory - ): + async def test_preflight_dropped_extensions_no_retry_needed(self, deals_client_factory): """When pre-flight already strips ext, the seller never sees them.""" seller_url = "https://seller-no-ext.example.com" diff --git a/tests/unit/test_campaign_brief_migration.py b/tests/unit/test_campaign_brief_migration.py index e5692b0..dc94963 100644 --- a/tests/unit/test_campaign_brief_migration.py +++ b/tests/unit/test_campaign_brief_migration.py @@ -117,12 +117,8 @@ def test_migrate_empty_list_raises(): def test_migrate_emits_structured_log(caplog): caplog.set_level(logging.INFO, logger="ad_buyer.audience.migration") - plan = migrate_legacy_audience_list( - ["3-7", "3-8"], source_context="test_emits" - ) - records = [ - r for r in caplog.records if r.name == "ad_buyer.audience.migration" - ] + plan = migrate_legacy_audience_list(["3-7", "3-8"], source_context="test_emits") + records = [r for r in caplog.records if r.name == "ad_buyer.audience.migration"] assert len(records) == 1 payload = getattr(records[0], "audience_migration", None) assert payload is not None @@ -236,9 +232,7 @@ def test_brief_omitting_target_audience_yields_none(): def test_brief_logs_legacy_conversion(caplog): caplog.set_level(logging.INFO, logger="ad_buyer.audience.migration") CampaignBrief(**_minimal_brief()) - records = [ - r for r in caplog.records if r.name == "ad_buyer.audience.migration" - ] + records = [r for r in caplog.records if r.name == "ad_buyer.audience.migration"] assert len(records) >= 1 payload = getattr(records[0], "audience_migration", None) assert payload is not None diff --git a/tests/unit/test_channel_crew_audience_invocation.py b/tests/unit/test_channel_crew_audience_invocation.py index 36ae8c6..6974da8 100644 --- a/tests/unit/test_channel_crew_audience_invocation.py +++ b/tests/unit/test_channel_crew_audience_invocation.py @@ -170,9 +170,7 @@ def test_typed_plan_uses_typed_renderer(self, typed_plan: AudiencePlan) -> None: assert "Plan ID:" in result assert "Primary:" in result - def test_legacy_dict_uses_legacy_renderer( - self, legacy_dict_plan: dict[str, Any] - ) -> None: + def test_legacy_dict_uses_legacy_renderer(self, legacy_dict_plan: dict[str, Any]) -> None: result = _format_audience_context(legacy_dict_plan) # Legacy renderer header (no "typed" qualifier). assert "Audience Plan Context:" in result @@ -220,9 +218,7 @@ def test_primary_carries_type_tag(self, typed_plan: AudiencePlan) -> None: assert "iab-audience" in result assert "version=1.1" in result - def test_contextual_constraint_carries_type_tag( - self, typed_plan: AudiencePlan - ) -> None: + def test_contextual_constraint_carries_type_tag(self, typed_plan: AudiencePlan) -> None: result = _format_typed_audience_plan(typed_plan) assert "[contextual]" in result assert "IAB1-2" in result @@ -231,9 +227,7 @@ def test_contextual_constraint_carries_type_tag( # Resolved ref -> confidence rendered. assert "confidence=0.92" in result - def test_agentic_extension_carries_compliance( - self, typed_plan: AudiencePlan - ) -> None: + def test_agentic_extension_carries_compliance(self, typed_plan: AudiencePlan) -> None: result = _format_typed_audience_plan(typed_plan) assert "[agentic]" in result assert "emb://buyer.example.com/audiences/auto-converters-q1" in result diff --git a/tests/unit/test_cpm_fallback_removal.py b/tests/unit/test_cpm_fallback_removal.py index 0212a3f..83790ab 100644 --- a/tests/unit/test_cpm_fallback_removal.py +++ b/tests/unit/test_cpm_fallback_removal.py @@ -71,9 +71,7 @@ async def test_no_base_price_returns_error(self, mock_client, agency_context): "channel": "ctv", # No basePrice, no price } - mock_client.get_product.return_value = MagicMock( - success=True, data=product_no_price - ) + mock_client.get_product.return_value = MagicMock(success=True, data=product_no_price) tool = RequestDealTool(client=mock_client, buyer_context=agency_context) result = await tool._arun(product_id="prod-001") @@ -93,9 +91,7 @@ async def test_non_numeric_base_price_returns_error(self, mock_client, agency_co "name": "Premium Display", "basePrice": "contact_sales", } - mock_client.get_product.return_value = MagicMock( - success=True, data=product_bad_price - ) + mock_client.get_product.return_value = MagicMock(success=True, data=product_bad_price) tool = RequestDealTool(client=mock_client, buyer_context=agency_context) result = await tool._arun(product_id="prod-002") @@ -113,9 +109,7 @@ async def test_null_base_price_returns_error(self, mock_client, agency_context): "name": "Premium Audio", "basePrice": None, } - mock_client.get_product.return_value = MagicMock( - success=True, data=product_null_price - ) + mock_client.get_product.return_value = MagicMock(success=True, data=product_null_price) tool = RequestDealTool(client=mock_client, buyer_context=agency_context) result = await tool._arun(product_id="prod-003") @@ -132,9 +126,7 @@ async def test_valid_price_still_works(self, mock_client, agency_context): "name": "Premium Display", "basePrice": 25.0, } - mock_client.get_product.return_value = MagicMock( - success=True, data=product_with_price - ) + mock_client.get_product.return_value = MagicMock(success=True, data=product_with_price) tool = RequestDealTool(client=mock_client, buyer_context=agency_context) result = await tool._arun(product_id="prod-004") @@ -161,9 +153,7 @@ async def test_no_base_price_shows_unavailable(self, mock_client, agency_context "availableImpressions": 5_000_000, # No basePrice, no price } - mock_client.list_products.return_value = MagicMock( - success=True, data=[product_no_price] - ) + mock_client.list_products.return_value = MagicMock(success=True, data=[product_no_price]) tool = DiscoverInventoryTool(client=mock_client, buyer_context=agency_context) result = await tool._arun() @@ -183,9 +173,7 @@ async def test_null_base_price_shows_unavailable(self, mock_client, agency_conte "channel": "display", "availableImpressions": 3_000_000, } - mock_client.list_products.return_value = MagicMock( - success=True, data=[product_null_price] - ) + mock_client.list_products.return_value = MagicMock(success=True, data=[product_null_price]) tool = DiscoverInventoryTool(client=mock_client, buyer_context=agency_context) result = await tool._arun() @@ -202,9 +190,7 @@ async def test_valid_price_still_displays(self, mock_client, agency_context): "channel": "display", "availableImpressions": 5_000_000, } - mock_client.list_products.return_value = MagicMock( - success=True, data=[product_with_price] - ) + mock_client.list_products.return_value = MagicMock(success=True, data=[product_with_price]) tool = DiscoverInventoryTool(client=mock_client, buyer_context=agency_context) result = await tool._arun() @@ -340,9 +326,7 @@ def test_estimate_impressions_requires_cpm(self): def test_estimate_impressions_with_explicit_cpm(self): """_estimate_impressions with an explicit CPM should still work.""" - result = CampaignPipeline._estimate_impressions( - budget=60_000.0, assumed_cpm=20.0 - ) + result = CampaignPipeline._estimate_impressions(budget=60_000.0, assumed_cpm=20.0) # (60000 / 20) * 1000 = 3,000,000 assert result == 3_000_000 diff --git a/tests/unit/test_deal_library_templates.py b/tests/unit/test_deal_library_templates.py index 1a37211..007b699 100644 --- a/tests/unit/test_deal_library_templates.py +++ b/tests/unit/test_deal_library_templates.py @@ -58,13 +58,15 @@ def test_create_deal_template_returns_success(self, deal_template_tool): """Creating a deal template returns a success message with template ID.""" result = deal_template_tool._run( action="create", - params_json=json.dumps({ - "name": "Standard Sports Video PG", - "deal_type_pref": "PG", - "inventory_types": ["DIGITAL", "CTV"], - "preferred_publishers": ["espn.com", "nfl.com"], - "max_cpm": 25.00, - }), + params_json=json.dumps( + { + "name": "Standard Sports Video PG", + "deal_type_pref": "PG", + "inventory_types": ["DIGITAL", "CTV"], + "preferred_publishers": ["espn.com", "nfl.com"], + "max_cpm": 25.00, + } + ), ) assert "successfully" in result.lower() or "created" in result.lower() assert "Standard Sports Video PG" in result @@ -73,20 +75,22 @@ def test_create_deal_template_with_all_fields(self, deal_template_tool): """Creating a template with all fields stores them correctly.""" result = deal_template_tool._run( action="create", - params_json=json.dumps({ - "name": "Full Template", - "deal_type_pref": "PD", - "inventory_types": ["DIGITAL"], - "preferred_publishers": ["nyt.com"], - "excluded_publishers": ["sketchy.com"], - "targeting_defaults": {"geo": ["US"], "audience": ["sports"]}, - "max_cpm": 18.50, - "min_impressions": 100000, - "default_flight_days": 30, - "supply_path_prefs": {"max_hops": 2}, - "advertiser_id": "adv-001", - "agency_id": "agency-001", - }), + params_json=json.dumps( + { + "name": "Full Template", + "deal_type_pref": "PD", + "inventory_types": ["DIGITAL"], + "preferred_publishers": ["nyt.com"], + "excluded_publishers": ["sketchy.com"], + "targeting_defaults": {"geo": ["US"], "audience": ["sports"]}, + "max_cpm": 18.50, + "min_impressions": 100000, + "default_flight_days": 30, + "supply_path_prefs": {"max_hops": 2}, + "advertiser_id": "adv-001", + "agency_id": "agency-001", + } + ), ) assert "created" in result.lower() or "successfully" in result.lower() @@ -94,10 +98,12 @@ def test_create_deal_template_agency_wide(self, deal_template_tool): """Creating a template without advertiser_id makes it agency-wide.""" result = deal_template_tool._run( action="create", - params_json=json.dumps({ - "name": "Agency-Wide Template", - "deal_type_pref": "PG", - }), + params_json=json.dumps( + { + "name": "Agency-Wide Template", + "deal_type_pref": "PG", + } + ), ) assert "created" in result.lower() or "successfully" in result.lower() # Verify it's retrievable and has no advertiser_id @@ -111,11 +117,13 @@ def test_create_deal_template_advertiser_scoped(self, deal_template_tool): """Creating a template with advertiser_id scopes it to that advertiser.""" result = deal_template_tool._run( action="create", - params_json=json.dumps({ - "name": "Advertiser Template", - "deal_type_pref": "PD", - "advertiser_id": "adv-nike", - }), + params_json=json.dumps( + { + "name": "Advertiser Template", + "deal_type_pref": "PD", + "advertiser_id": "adv-nike", + } + ), ) assert "created" in result.lower() or "successfully" in result.lower() @@ -123,9 +131,11 @@ def test_create_deal_template_requires_name(self, deal_template_tool): """Creating a template without a name returns an error.""" result = deal_template_tool._run( action="create", - params_json=json.dumps({ - "deal_type_pref": "PG", - }), + params_json=json.dumps( + { + "deal_type_pref": "PG", + } + ), ) assert "error" in result.lower() @@ -151,11 +161,13 @@ def test_read_deal_template_by_id(self, deal_template_tool): # Create first create_result = deal_template_tool._run( action="create", - params_json=json.dumps({ - "name": "Readable Template", - "deal_type_pref": "PG", - "max_cpm": 20.00, - }), + params_json=json.dumps( + { + "name": "Readable Template", + "deal_type_pref": "PG", + "max_cpm": 20.00, + } + ), ) # Extract template ID from result template_id = _extract_template_id(create_result) @@ -216,26 +228,32 @@ def test_list_deal_templates_filter_by_advertiser(self, deal_template_tool): """Listing templates with advertiser_id filter returns only matching.""" deal_template_tool._run( action="create", - params_json=json.dumps({ - "name": "Nike Template", - "deal_type_pref": "PG", - "advertiser_id": "adv-nike", - }), + params_json=json.dumps( + { + "name": "Nike Template", + "deal_type_pref": "PG", + "advertiser_id": "adv-nike", + } + ), ) deal_template_tool._run( action="create", - params_json=json.dumps({ - "name": "Agency Wide", - "deal_type_pref": "PD", - }), + params_json=json.dumps( + { + "name": "Agency Wide", + "deal_type_pref": "PD", + } + ), ) deal_template_tool._run( action="create", - params_json=json.dumps({ - "name": "Adidas Template", - "deal_type_pref": "PG", - "advertiser_id": "adv-adidas", - }), + params_json=json.dumps( + { + "name": "Adidas Template", + "deal_type_pref": "PG", + "advertiser_id": "adv-adidas", + } + ), ) result = deal_template_tool._run( @@ -276,21 +294,25 @@ def test_update_deal_template(self, deal_template_tool): """Updating a template changes the specified fields.""" create_result = deal_template_tool._run( action="create", - params_json=json.dumps({ - "name": "Old Name", - "deal_type_pref": "PG", - "max_cpm": 15.00, - }), + params_json=json.dumps( + { + "name": "Old Name", + "deal_type_pref": "PG", + "max_cpm": 15.00, + } + ), ) template_id = _extract_template_id(create_result) update_result = deal_template_tool._run( action="update", - params_json=json.dumps({ - "template_id": template_id, - "name": "New Name", - "max_cpm": 20.00, - }), + params_json=json.dumps( + { + "template_id": template_id, + "name": "New Name", + "max_cpm": 20.00, + } + ), ) assert "updated" in update_result.lower() @@ -305,11 +327,13 @@ def test_update_deal_template_default_price(self, deal_template_tool): """Updating default_price succeeds and persists the new value.""" create_result = deal_template_tool._run( action="create", - params_json=json.dumps({ - "name": "Price Update Test", - "deal_type_pref": "PG", - "default_price": 15.00, - }), + params_json=json.dumps( + { + "name": "Price Update Test", + "deal_type_pref": "PG", + "default_price": 15.00, + } + ), ) template_id = _extract_template_id(create_result) assert template_id is not None @@ -317,10 +341,12 @@ def test_update_deal_template_default_price(self, deal_template_tool): # Update the default_price update_result = deal_template_tool._run( action="update", - params_json=json.dumps({ - "template_id": template_id, - "default_price": 32.00, - }), + params_json=json.dumps( + { + "template_id": template_id, + "default_price": 32.00, + } + ), ) assert "updated" in update_result.lower(), ( f"Expected 'updated' in result but got: {update_result}" @@ -339,10 +365,12 @@ def test_update_deal_template_not_found(self, deal_template_tool): """Updating a nonexistent template returns a not-found message.""" result = deal_template_tool._run( action="update", - params_json=json.dumps({ - "template_id": "nonexistent", - "name": "Won't Work", - }), + params_json=json.dumps( + { + "template_id": "nonexistent", + "name": "Won't Work", + } + ), ) assert "not found" in result.lower() @@ -359,10 +387,12 @@ def test_delete_deal_template(self, deal_template_tool): """Deleting a template removes it.""" create_result = deal_template_tool._run( action="create", - params_json=json.dumps({ - "name": "Doomed Template", - "deal_type_pref": "PG", - }), + params_json=json.dumps( + { + "name": "Doomed Template", + "deal_type_pref": "PG", + } + ), ) template_id = _extract_template_id(create_result) @@ -417,36 +447,38 @@ def test_create_supply_path_template_success(self, supply_path_tool): """Creating a supply path template with valid weights returns success.""" result = supply_path_tool._run( action="create", - params_json=json.dumps({ - "name": "Low Fee Direct Paths", - "scoring_weights": { - "transparency": 0.3, - "fee": 0.4, - "trust": 0.2, - "performance": 0.1, - }, - "max_reseller_hops": 2, - "preferred_ssps": ["index", "pubmatic"], - "blocked_ssps": ["shady-exchange"], - }), + params_json=json.dumps( + { + "name": "Low Fee Direct Paths", + "scoring_weights": { + "transparency": 0.3, + "fee": 0.4, + "trust": 0.2, + "performance": 0.1, + }, + "max_reseller_hops": 2, + "preferred_ssps": ["index", "pubmatic"], + "blocked_ssps": ["shady-exchange"], + } + ), ) assert "created" in result.lower() or "successfully" in result.lower() - def test_create_supply_path_template_weights_must_sum_to_one( - self, supply_path_tool - ): + def test_create_supply_path_template_weights_must_sum_to_one(self, supply_path_tool): """Weights that don't sum to 1.0 are rejected.""" result = supply_path_tool._run( action="create", - params_json=json.dumps({ - "name": "Bad Weights", - "scoring_weights": { - "transparency": 0.3, - "fee": 0.3, - "trust": 0.3, - "performance": 0.3, - }, - }), + params_json=json.dumps( + { + "name": "Bad Weights", + "scoring_weights": { + "transparency": 0.3, + "fee": 0.3, + "trust": 0.3, + "performance": 0.3, + }, + } + ), ) assert "error" in result.lower() assert "sum" in result.lower() or "1.0" in result @@ -455,14 +487,16 @@ def test_create_supply_path_template_requires_name(self, supply_path_tool): """Creating without a name returns an error.""" result = supply_path_tool._run( action="create", - params_json=json.dumps({ - "scoring_weights": { - "transparency": 0.25, - "fee": 0.25, - "trust": 0.25, - "performance": 0.25, - }, - }), + params_json=json.dumps( + { + "scoring_weights": { + "transparency": 0.25, + "fee": 0.25, + "trust": 0.25, + "performance": 0.25, + }, + } + ), ) assert "error" in result.lower() @@ -479,16 +513,18 @@ def test_read_supply_path_template_by_id(self, supply_path_tool): """Reading a template by ID returns its details.""" create_result = supply_path_tool._run( action="create", - params_json=json.dumps({ - "name": "Readable SPO Template", - "scoring_weights": { - "transparency": 0.25, - "fee": 0.25, - "trust": 0.25, - "performance": 0.25, - }, - "max_reseller_hops": 2, - }), + params_json=json.dumps( + { + "name": "Readable SPO Template", + "scoring_weights": { + "transparency": 0.25, + "fee": 0.25, + "trust": 0.25, + "performance": 0.25, + }, + "max_reseller_hops": 2, + } + ), ) template_id = _extract_template_id(create_result) assert template_id is not None @@ -516,29 +552,33 @@ def test_update_supply_path_template_revalidates_weights(self, supply_path_tool) """Updating scoring_weights still validates sum = 1.0.""" create_result = supply_path_tool._run( action="create", - params_json=json.dumps({ - "name": "Weight Test", - "scoring_weights": { - "transparency": 0.25, - "fee": 0.25, - "trust": 0.25, - "performance": 0.25, - }, - }), + params_json=json.dumps( + { + "name": "Weight Test", + "scoring_weights": { + "transparency": 0.25, + "fee": 0.25, + "trust": 0.25, + "performance": 0.25, + }, + } + ), ) template_id = _extract_template_id(create_result) update_result = supply_path_tool._run( action="update", - params_json=json.dumps({ - "template_id": template_id, - "scoring_weights": { - "transparency": 0.5, - "fee": 0.5, - "trust": 0.5, - "performance": 0.5, - }, - }), + params_json=json.dumps( + { + "template_id": template_id, + "scoring_weights": { + "transparency": 0.5, + "fee": 0.5, + "trust": 0.5, + "performance": 0.5, + }, + } + ), ) assert "error" in update_result.lower() @@ -550,15 +590,17 @@ def test_delete_supply_path_template(self, supply_path_tool): """Deleting a template removes it.""" create_result = supply_path_tool._run( action="create", - params_json=json.dumps({ - "name": "Doomed SPO Template", - "scoring_weights": { - "transparency": 0.25, - "fee": 0.25, - "trust": 0.25, - "performance": 0.25, - }, - }), + params_json=json.dumps( + { + "name": "Doomed SPO Template", + "scoring_weights": { + "transparency": 0.25, + "fee": 0.25, + "trust": 0.25, + "performance": 0.25, + }, + } + ), ) template_id = _extract_template_id(create_result) diff --git a/tests/unit/test_deals_client_dual_content_type.py b/tests/unit/test_deals_client_dual_content_type.py index e84a420..cc137b6 100644 --- a/tests/unit/test_deals_client_dual_content_type.py +++ b/tests/unit/test_deals_client_dual_content_type.py @@ -194,9 +194,7 @@ async def test_logs_audience_plan_id_at_info(self, caplog): await c.close() # Exactly one record on the booking logger; carries the canonical id. - records = [ - r for r in caplog.records if r.name == "ad_buyer.audience.booking" - ] + records = [r for r in caplog.records if r.name == "ad_buyer.audience.booking"] assert len(records) == 1 msg = records[0].getMessage() assert plan.audience_plan_id in msg @@ -216,9 +214,7 @@ async def test_no_audience_plan_does_not_log(self, caplog): await c.book_deal(booking) await c.close() - assert [ - r for r in caplog.records if r.name == "ad_buyer.audience.booking" - ] == [] + assert [r for r in caplog.records if r.name == "ad_buyer.audience.booking"] == [] class TestSnapshotResponseParsing: diff --git a/tests/unit/test_iac_templates.py b/tests/unit/test_iac_templates.py index 5d94012..7782a12 100644 --- a/tests/unit/test_iac_templates.py +++ b/tests/unit/test_iac_templates.py @@ -28,19 +28,35 @@ class CfnLoader(yaml.SafeLoader): # Register handlers for all common CloudFormation tags _CFN_TAGS = [ - "!Ref", "!Sub", "!GetAtt", "!Select", "!GetAZs", "!Join", - "!If", "!Not", "!Equals", "!And", "!Or", "!FindInMap", - "!Base64", "!Cidr", "!ImportValue", "!Split", "!Transform", + "!Ref", + "!Sub", + "!GetAtt", + "!Select", + "!GetAZs", + "!Join", + "!If", + "!Not", + "!Equals", + "!And", + "!Or", + "!FindInMap", + "!Base64", + "!Cidr", + "!ImportValue", + "!Split", + "!Transform", ] for _tag in _CFN_TAGS: CfnLoader.add_multi_constructor( _tag, - lambda loader, suffix, node: loader.construct_mapping(node) - if isinstance(node, yaml.MappingNode) - else loader.construct_sequence(node) - if isinstance(node, yaml.SequenceNode) - else loader.construct_scalar(node), + lambda loader, suffix, node: ( + loader.construct_mapping(node) + if isinstance(node, yaml.MappingNode) + else loader.construct_sequence(node) + if isinstance(node, yaml.SequenceNode) + else loader.construct_scalar(node) + ), ) @@ -61,9 +77,7 @@ def cfn_template(self, request): def test_has_format_version(self, cfn_template): name, template = cfn_template - assert "AWSTemplateFormatVersion" in template, ( - f"{name}: missing AWSTemplateFormatVersion" - ) + assert "AWSTemplateFormatVersion" in template, f"{name}: missing AWSTemplateFormatVersion" def test_has_description(self, cfn_template): name, template = cfn_template @@ -83,9 +97,7 @@ def test_storage_has_redis(self): assert "RedisReplicationGroup" in resources, ( "storage.yaml should define a Redis replication group" ) - assert "RedisSubnetGroup" in resources, ( - "storage.yaml should define a Redis subnet group" - ) + assert "RedisSubnetGroup" in resources, "storage.yaml should define a Redis subnet group" def test_network_has_redis_security_group(self): template = load_cfn_yaml(CFN_DIR / "network.yaml") @@ -97,21 +109,15 @@ def test_network_has_redis_security_group(self): def test_main_wires_storage_stack(self): template = load_cfn_yaml(CFN_DIR / "main.yaml") resources = template["Resources"] - assert "StorageStack" in resources, ( - "main.yaml should include a StorageStack" - ) + assert "StorageStack" in resources, "main.yaml should include a StorageStack" assert "NetworkStack" in resources assert "ComputeStack" in resources def test_compute_accepts_redis_params(self): template = load_cfn_yaml(CFN_DIR / "compute.yaml") params = template["Parameters"] - assert "RedisEndpoint" in params, ( - "compute.yaml should accept RedisEndpoint parameter" - ) - assert "RedisPort" in params, ( - "compute.yaml should accept RedisPort parameter" - ) + assert "RedisEndpoint" in params, "compute.yaml should accept RedisEndpoint parameter" + assert "RedisPort" in params, "compute.yaml should accept RedisPort parameter" class TestTerraformModules: @@ -139,9 +145,7 @@ def test_module_has_required_files(self, module_name): def test_root_main_references_storage_module(self): content = (TF_DIR / "main.tf").read_text() - assert 'module "storage"' in content, ( - "Root main.tf should reference the storage module" - ) + assert 'module "storage"' in content, "Root main.tf should reference the storage module" assert "./modules/storage" in content def test_root_outputs_include_redis(self): diff --git a/tests/unit/test_mcp_approval_apikey.py b/tests/unit/test_mcp_approval_apikey.py index b39ba39..62b4b76 100644 --- a/tests/unit/test_mcp_approval_apikey.py +++ b/tests/unit/test_mcp_approval_apikey.py @@ -54,10 +54,12 @@ def _make_api_key_store(): def _reconnecting(store): """Return a lambda that reconnects the store before returning it.""" + def _get(): if hasattr(store, "_conn") and store._conn is None: store.connect() return store + return _get @@ -198,9 +200,7 @@ async def test_filter_by_campaign_id(self, monkeypatch): _reconnecting(store), ) - result = await mcp.call_tool( - "list_pending_approvals", {"campaign_id": "camp-001"} - ) + result = await mcp.call_tool("list_pending_approvals", {"campaign_id": "camp-001"}) data = json.loads(_extract_text(result)) assert data["total"] == 1 assert data["pending"][0]["campaign_id"] == "camp-001" @@ -238,11 +238,14 @@ async def test_approve_pending_request(self, monkeypatch): _reconnecting(store), ) - result = await mcp.call_tool("approve_or_reject", { - "approval_request_id": "req-001", - "decision": "approved", - "reviewer": "test-user", - }) + result = await mcp.call_tool( + "approve_or_reject", + { + "approval_request_id": "req-001", + "decision": "approved", + "reviewer": "test-user", + }, + ) data = json.loads(_extract_text(result)) assert data["approval_request_id"] == "req-001" assert data["new_status"] == "approved" @@ -258,12 +261,15 @@ async def test_reject_pending_request(self, monkeypatch): _reconnecting(store), ) - result = await mcp.call_tool("approve_or_reject", { - "approval_request_id": "req-001", - "decision": "rejected", - "reviewer": "test-user", - "reason": "Budget too high", - }) + result = await mcp.call_tool( + "approve_or_reject", + { + "approval_request_id": "req-001", + "decision": "rejected", + "reviewer": "test-user", + "reason": "Budget too high", + }, + ) data = json.loads(_extract_text(result)) assert data["new_status"] == "rejected" assert data["reason"] == "Budget too high" @@ -278,11 +284,14 @@ async def test_request_not_found(self, monkeypatch): _reconnecting(store), ) - result = await mcp.call_tool("approve_or_reject", { - "approval_request_id": "nonexistent", - "decision": "approved", - "reviewer": "test-user", - }) + result = await mcp.call_tool( + "approve_or_reject", + { + "approval_request_id": "nonexistent", + "decision": "approved", + "reviewer": "test-user", + }, + ) data = json.loads(_extract_text(result)) assert "error" in data @@ -296,11 +305,14 @@ async def test_already_decided_returns_error(self, monkeypatch): _reconnecting(store), ) - result = await mcp.call_tool("approve_or_reject", { - "approval_request_id": "req-001", - "decision": "rejected", - "reviewer": "test-user", - }) + result = await mcp.call_tool( + "approve_or_reject", + { + "approval_request_id": "req-001", + "decision": "rejected", + "reviewer": "test-user", + }, + ) data = json.loads(_extract_text(result)) assert "error" in data @@ -314,11 +326,14 @@ async def test_includes_timestamp(self, monkeypatch): _reconnecting(store), ) - result = await mcp.call_tool("approve_or_reject", { - "approval_request_id": "req-001", - "decision": "approved", - "reviewer": "test-user", - }) + result = await mcp.call_tool( + "approve_or_reject", + { + "approval_request_id": "req-001", + "decision": "approved", + "reviewer": "test-user", + }, + ) data = json.loads(_extract_text(result)) assert "timestamp" in data @@ -408,10 +423,13 @@ async def test_creates_key_for_seller(self, monkeypatch): lambda: key_store, ) - result = await mcp.call_tool("create_api_key", { - "seller_url": "http://seller-a.com", - "api_key": "new-secret-key-123", - }) + result = await mcp.call_tool( + "create_api_key", + { + "seller_url": "http://seller-a.com", + "api_key": "new-secret-key-123", + }, + ) data = json.loads(_extract_text(result)) assert data["seller_url"] == "http://seller-a.com" assert data["created"] is True @@ -429,10 +447,13 @@ async def test_replaces_existing_key(self, monkeypatch): lambda: key_store, ) - result = await mcp.call_tool("create_api_key", { - "seller_url": "http://seller-a.com", - "api_key": "new-key", - }) + result = await mcp.call_tool( + "create_api_key", + { + "seller_url": "http://seller-a.com", + "api_key": "new-key", + }, + ) data = json.loads(_extract_text(result)) assert data["created"] is True assert key_store.get_key("http://seller-a.com") == "new-key" @@ -446,10 +467,13 @@ async def test_does_not_expose_full_key(self, monkeypatch): lambda: key_store, ) - result = await mcp.call_tool("create_api_key", { - "seller_url": "http://seller-a.com", - "api_key": "supersecretvalue12345", - }) + result = await mcp.call_tool( + "create_api_key", + { + "seller_url": "http://seller-a.com", + "api_key": "supersecretvalue12345", + }, + ) text = _extract_text(result) assert "supersecretvalue12345" not in text @@ -462,10 +486,13 @@ async def test_includes_timestamp(self, monkeypatch): lambda: key_store, ) - result = await mcp.call_tool("create_api_key", { - "seller_url": "http://seller-a.com", - "api_key": "test-key", - }) + result = await mcp.call_tool( + "create_api_key", + { + "seller_url": "http://seller-a.com", + "api_key": "test-key", + }, + ) data = json.loads(_extract_text(result)) assert "timestamp" in data @@ -488,9 +515,12 @@ async def test_revokes_existing_key(self, monkeypatch): lambda: key_store, ) - result = await mcp.call_tool("revoke_api_key", { - "seller_url": "http://seller-a.com", - }) + result = await mcp.call_tool( + "revoke_api_key", + { + "seller_url": "http://seller-a.com", + }, + ) data = json.loads(_extract_text(result)) assert data["revoked"] is True assert data["seller_url"] == "http://seller-a.com" @@ -507,9 +537,12 @@ async def test_revoke_nonexistent_returns_false(self, monkeypatch): lambda: key_store, ) - result = await mcp.call_tool("revoke_api_key", { - "seller_url": "http://nonexistent.com", - }) + result = await mcp.call_tool( + "revoke_api_key", + { + "seller_url": "http://nonexistent.com", + }, + ) data = json.loads(_extract_text(result)) assert data["revoked"] is False @@ -522,8 +555,11 @@ async def test_includes_timestamp(self, monkeypatch): lambda: key_store, ) - result = await mcp.call_tool("revoke_api_key", { - "seller_url": "http://seller-a.com", - }) + result = await mcp.call_tool( + "revoke_api_key", + { + "seller_url": "http://seller-a.com", + }, + ) data = json.loads(_extract_text(result)) assert "timestamp" in data diff --git a/tests/unit/test_mcp_campaign_tools.py b/tests/unit/test_mcp_campaign_tools.py index 1fa770f..f445a1c 100644 --- a/tests/unit/test_mcp_campaign_tools.py +++ b/tests/unit/test_mcp_campaign_tools.py @@ -59,10 +59,12 @@ def _seed_campaign(store: CampaignStore, **overrides) -> str: "currency": "USD", "flight_start": "2026-03-01", "flight_end": "2026-03-31", - "channels": json.dumps([ - {"channel": "CTV", "budget_pct": 0.6}, - {"channel": "DISPLAY", "budget_pct": 0.4}, - ]), + "channels": json.dumps( + [ + {"channel": "CTV", "budget_pct": 0.6}, + {"channel": "DISPLAY", "budget_pct": 0.4}, + ] + ), } defaults.update(overrides) return store.save_campaign(**defaults) @@ -152,9 +154,7 @@ class TestListCampaigns: async def test_no_campaigns_returns_empty(self, monkeypatch): """list_campaigns should return empty list when no campaigns exist.""" store = _make_campaign_store() - monkeypatch.setattr( - "ad_buyer.interfaces.mcp_server._get_campaign_store", lambda: store - ) + monkeypatch.setattr("ad_buyer.interfaces.mcp_server._get_campaign_store", lambda: store) result = await mcp.call_tool("list_campaigns", {}) data = json.loads(_extract_text(result)) @@ -167,9 +167,7 @@ async def test_lists_all_campaigns(self, monkeypatch): store = _make_campaign_store() _seed_campaign(store, campaign_name="Campaign A", status="ACTIVE") _seed_campaign(store, campaign_name="Campaign B", status="DRAFT") - monkeypatch.setattr( - "ad_buyer.interfaces.mcp_server._get_campaign_store", lambda: store - ) + monkeypatch.setattr("ad_buyer.interfaces.mcp_server._get_campaign_store", lambda: store) result = await mcp.call_tool("list_campaigns", {}) data = json.loads(_extract_text(result)) @@ -184,9 +182,7 @@ async def test_filter_by_status(self, monkeypatch): _seed_campaign(store, campaign_name="Active One", status="ACTIVE") _seed_campaign(store, campaign_name="Draft One", status="DRAFT") _seed_campaign(store, campaign_name="Active Two", status="ACTIVE") - monkeypatch.setattr( - "ad_buyer.interfaces.mcp_server._get_campaign_store", lambda: store - ) + monkeypatch.setattr("ad_buyer.interfaces.mcp_server._get_campaign_store", lambda: store) result = await mcp.call_tool("list_campaigns", {"status": "ACTIVE"}) data = json.loads(_extract_text(result)) @@ -199,9 +195,7 @@ async def test_filter_returns_empty_for_unmatched(self, monkeypatch): """Filtering by a status with no matches should return empty.""" store = _make_campaign_store() _seed_campaign(store, status="ACTIVE") - monkeypatch.setattr( - "ad_buyer.interfaces.mcp_server._get_campaign_store", lambda: store - ) + monkeypatch.setattr("ad_buyer.interfaces.mcp_server._get_campaign_store", lambda: store) result = await mcp.call_tool("list_campaigns", {"status": "COMPLETED"}) data = json.loads(_extract_text(result)) @@ -213,17 +207,19 @@ async def test_campaign_fields_included(self, monkeypatch): """Each campaign in the list should include key fields.""" store = _make_campaign_store() _seed_campaign(store) - monkeypatch.setattr( - "ad_buyer.interfaces.mcp_server._get_campaign_store", lambda: store - ) + monkeypatch.setattr("ad_buyer.interfaces.mcp_server._get_campaign_store", lambda: store) result = await mcp.call_tool("list_campaigns", {}) data = json.loads(_extract_text(result)) campaign = data["campaigns"][0] required_fields = [ - "campaign_id", "campaign_name", "status", - "total_budget", "flight_start", "flight_end", + "campaign_id", + "campaign_name", + "status", + "total_budget", + "flight_start", + "flight_end", ] for field in required_fields: assert field in campaign, f"Missing field: {field}" @@ -232,9 +228,7 @@ async def test_campaign_fields_included(self, monkeypatch): async def test_returns_valid_json(self, monkeypatch): """list_campaigns should return valid JSON.""" store = _make_campaign_store() - monkeypatch.setattr( - "ad_buyer.interfaces.mcp_server._get_campaign_store", lambda: store - ) + monkeypatch.setattr("ad_buyer.interfaces.mcp_server._get_campaign_store", lambda: store) result = await mcp.call_tool("list_campaigns", {}) data = json.loads(_extract_text(result)) @@ -281,9 +275,12 @@ async def test_includes_pacing_data(self, monkeypatch): pacing_store = _make_pacing_store() cid = _seed_campaign(campaign_store) _seed_pacing_snapshot( - pacing_store, cid, - total_spend=30000.0, expected_spend=50000.0, - pacing_pct=60.0, deviation_pct=-40.0, + pacing_store, + cid, + total_spend=30000.0, + expected_spend=50000.0, + pacing_pct=60.0, + deviation_pct=-40.0, ) monkeypatch.setattr( "ad_buyer.interfaces.mcp_server._get_campaign_store", @@ -335,9 +332,7 @@ async def test_campaign_not_found(self, monkeypatch): lambda: pacing_store, ) - result = await mcp.call_tool( - "get_campaign_status", {"campaign_id": "nonexistent-id"} - ) + result = await mcp.call_tool("get_campaign_status", {"campaign_id": "nonexistent-id"}) data = json.loads(_extract_text(result)) assert "error" in data @@ -376,9 +371,12 @@ async def test_on_track_pacing(self, monkeypatch): pacing_store = _make_pacing_store() cid = _seed_campaign(campaign_store) _seed_pacing_snapshot( - pacing_store, cid, - total_spend=50000.0, expected_spend=50000.0, - pacing_pct=100.0, deviation_pct=0.0, + pacing_store, + cid, + total_spend=50000.0, + expected_spend=50000.0, + pacing_pct=100.0, + deviation_pct=0.0, ) monkeypatch.setattr( "ad_buyer.interfaces.mcp_server._get_campaign_store", @@ -402,9 +400,12 @@ async def test_behind_pacing(self, monkeypatch): pacing_store = _make_pacing_store() cid = _seed_campaign(campaign_store) _seed_pacing_snapshot( - pacing_store, cid, - total_spend=30000.0, expected_spend=50000.0, - pacing_pct=60.0, deviation_pct=-40.0, + pacing_store, + cid, + total_spend=30000.0, + expected_spend=50000.0, + pacing_pct=60.0, + deviation_pct=-40.0, ) monkeypatch.setattr( "ad_buyer.interfaces.mcp_server._get_campaign_store", @@ -428,9 +429,12 @@ async def test_ahead_pacing(self, monkeypatch): pacing_store = _make_pacing_store() cid = _seed_campaign(campaign_store) _seed_pacing_snapshot( - pacing_store, cid, - total_spend=70000.0, expected_spend=50000.0, - pacing_pct=140.0, deviation_pct=40.0, + pacing_store, + cid, + total_spend=70000.0, + expected_spend=50000.0, + pacing_pct=140.0, + deviation_pct=40.0, ) monkeypatch.setattr( "ad_buyer.interfaces.mcp_server._get_campaign_store", @@ -481,9 +485,7 @@ async def test_campaign_not_found(self, monkeypatch): lambda: pacing_store, ) - result = await mcp.call_tool( - "check_pacing", {"campaign_id": "nonexistent-id"} - ) + result = await mcp.call_tool("check_pacing", {"campaign_id": "nonexistent-id"}) data = json.loads(_extract_text(result)) assert "error" in data @@ -494,9 +496,13 @@ async def test_includes_budget_info(self, monkeypatch): pacing_store = _make_pacing_store() cid = _seed_campaign(campaign_store, total_budget=200000.0) _seed_pacing_snapshot( - pacing_store, cid, - total_budget=200000.0, total_spend=90000.0, - expected_spend=100000.0, pacing_pct=90.0, deviation_pct=-10.0, + pacing_store, + cid, + total_budget=200000.0, + total_spend=90000.0, + expected_spend=100000.0, + pacing_pct=90.0, + deviation_pct=-10.0, ) monkeypatch.setattr( "ad_buyer.interfaces.mcp_server._get_campaign_store", @@ -521,15 +527,20 @@ async def test_includes_channel_pacing(self, monkeypatch): pacing_store = _make_pacing_store() cid = _seed_campaign(campaign_store) _seed_pacing_snapshot( - pacing_store, cid, + pacing_store, + cid, channel_snapshots=[ ChannelSnapshot( - channel="CTV", allocated_budget=60000.0, - spend=30000.0, pacing_pct=100.0, + channel="CTV", + allocated_budget=60000.0, + spend=30000.0, + pacing_pct=100.0, ), ChannelSnapshot( - channel="DISPLAY", allocated_budget=40000.0, - spend=15000.0, pacing_pct=75.0, + channel="DISPLAY", + allocated_budget=40000.0, + spend=15000.0, + pacing_pct=75.0, ), ], ) @@ -584,22 +595,34 @@ async def test_aggregates_budgets(self, monkeypatch): campaign_store = _make_campaign_store() pacing_store = _make_pacing_store() cid1 = _seed_campaign( - campaign_store, campaign_name="Campaign A", - total_budget=100000.0, status="ACTIVE", + campaign_store, + campaign_name="Campaign A", + total_budget=100000.0, + status="ACTIVE", ) cid2 = _seed_campaign( - campaign_store, campaign_name="Campaign B", - total_budget=50000.0, status="ACTIVE", + campaign_store, + campaign_name="Campaign B", + total_budget=50000.0, + status="ACTIVE", ) _seed_pacing_snapshot( - pacing_store, cid1, - total_budget=100000.0, total_spend=40000.0, - expected_spend=50000.0, pacing_pct=80.0, deviation_pct=-20.0, + pacing_store, + cid1, + total_budget=100000.0, + total_spend=40000.0, + expected_spend=50000.0, + pacing_pct=80.0, + deviation_pct=-20.0, ) _seed_pacing_snapshot( - pacing_store, cid2, - total_budget=50000.0, total_spend=30000.0, - expected_spend=25000.0, pacing_pct=120.0, deviation_pct=20.0, + pacing_store, + cid2, + total_budget=50000.0, + total_spend=30000.0, + expected_spend=25000.0, + pacing_pct=120.0, + deviation_pct=20.0, ) monkeypatch.setattr( "ad_buyer.interfaces.mcp_server._get_campaign_store", @@ -623,13 +646,18 @@ async def test_per_campaign_budget_info(self, monkeypatch): campaign_store = _make_campaign_store() pacing_store = _make_pacing_store() cid = _seed_campaign( - campaign_store, campaign_name="My Campaign", + campaign_store, + campaign_name="My Campaign", total_budget=100000.0, ) _seed_pacing_snapshot( - pacing_store, cid, - total_budget=100000.0, total_spend=45000.0, - expected_spend=50000.0, pacing_pct=90.0, deviation_pct=-10.0, + pacing_store, + cid, + total_budget=100000.0, + total_spend=45000.0, + expected_spend=50000.0, + pacing_pct=90.0, + deviation_pct=-10.0, ) monkeypatch.setattr( "ad_buyer.interfaces.mcp_server._get_campaign_store", @@ -655,7 +683,8 @@ async def test_campaign_without_pacing_shows_zero_spend(self, monkeypatch): campaign_store = _make_campaign_store() pacing_store = _make_pacing_store() _seed_campaign( - campaign_store, campaign_name="No Pacing", + campaign_store, + campaign_name="No Pacing", total_budget=50000.0, ) monkeypatch.setattr( diff --git a/tests/unit/test_mcp_deal_library.py b/tests/unit/test_mcp_deal_library.py index 5b2108f..3cf20fe 100644 --- a/tests/unit/test_mcp_deal_library.py +++ b/tests/unit/test_mcp_deal_library.py @@ -195,7 +195,9 @@ async def test_filter_by_media_type(self): """list_deals should filter by media_type when provided.""" store = _make_deal_store() _seed_deal(store, display_name="CTV Deal", media_type="CTV", seller_deal_id="CTV-001") - _seed_deal(store, display_name="Digital Deal", media_type="DIGITAL", seller_deal_id="DIG-001") # noqa: E501 + _seed_deal( + store, display_name="Digital Deal", media_type="DIGITAL", seller_deal_id="DIG-001" + ) # noqa: E501 _set_deal_store(store) result = await mcp.call_tool("list_deals", {"media_type": "CTV"}) @@ -215,8 +217,13 @@ async def test_deal_fields_included(self): deal = data["deals"][0] required_fields = [ - "deal_id", "display_name", "status", "deal_type", - "seller_org", "media_type", "price", + "deal_id", + "display_name", + "status", + "deal_type", + "seller_org", + "media_type", + "price", ] for field in required_fields: assert field in deal, f"Missing field: {field}" @@ -272,12 +279,16 @@ async def test_search_by_seller_org(self): """search_deals should find deals by seller_org.""" store = _make_deal_store() _seed_deal( - store, display_name="Premium Display", - seller_org="NBCUniversal", seller_deal_id="NBC-001", + store, + display_name="Premium Display", + seller_org="NBCUniversal", + seller_deal_id="NBC-001", ) _seed_deal( - store, display_name="Sports Package", - seller_org="Disney", seller_deal_id="DIS-001", + store, + display_name="Sports Package", + seller_org="Disney", + seller_deal_id="DIS-001", ) _set_deal_store(store) @@ -400,8 +411,14 @@ async def test_inspect_includes_core_fields(self): data = json.loads(_extract_text(result)) required_fields = [ - "deal_id", "display_name", "status", "deal_type", - "seller_url", "price", "flight_start", "flight_end", + "deal_id", + "display_name", + "status", + "deal_type", + "seller_url", + "price", + "flight_start", + "flight_end", ] for field in required_fields: assert field in data, f"Missing field: {field}" @@ -435,7 +452,7 @@ async def test_import_valid_csv(self): csv_data = ( "deal_name,publisher,seller_domain,deal_type,cpm,impressions\n" - "ESPN Sports PMP,ESPN,espn.com,PG,$15.00,\"1,000,000\"\n" + 'ESPN Sports PMP,ESPN,espn.com,PG,$15.00,"1,000,000"\n' "CNN News PMP,CNN,cnn.com,PD,$10.00,500000\n" ) @@ -453,11 +470,7 @@ async def test_import_csv_with_errors(self): _set_deal_store(store) # Row missing both deal_id and name (required by parser) - csv_data = ( - "deal_name,publisher,seller_domain\n" - "Good Deal,ESPN,espn.com\n" - ",,,\n" - ) + csv_data = "deal_name,publisher,seller_domain\nGood Deal,ESPN,espn.com\n,,,\n" result = await mcp.call_tool("import_deals_csv", {"csv_data": csv_data}) data = json.loads(_extract_text(result)) @@ -506,15 +519,15 @@ async def test_import_csv_with_custom_seller_url(self): store = _make_deal_store() _set_deal_store(store) - csv_data = ( - "deal_name,publisher,seller_domain\n" - "Test Deal,TestPub,testpub.com\n" - ) + csv_data = "deal_name,publisher,seller_domain\nTest Deal,TestPub,testpub.com\n" - result = await mcp.call_tool("import_deals_csv", { - "csv_data": csv_data, - "default_seller_url": "https://custom-seller.example.com", - }) + result = await mcp.call_tool( + "import_deals_csv", + { + "csv_data": csv_data, + "default_seller_url": "https://custom-seller.example.com", + }, + ) data = json.loads(_extract_text(result)) assert data["successful"] == 1 @@ -566,10 +579,13 @@ async def test_create_minimal_deal(self): store = _make_deal_store() _set_deal_store(store) - result = await mcp.call_tool("create_deal_manual", { - "display_name": "ESPN Sports PMP", - "seller_url": "https://espn.seller.example.com", - }) + result = await mcp.call_tool( + "create_deal_manual", + { + "display_name": "ESPN Sports PMP", + "seller_url": "https://espn.seller.example.com", + }, + ) data = json.loads(_extract_text(result)) assert data["success"] is True @@ -581,19 +597,22 @@ async def test_create_deal_with_all_fields(self): store = _make_deal_store() _set_deal_store(store) - result = await mcp.call_tool("create_deal_manual", { - "display_name": "Premium CTV Package", - "seller_url": "https://seller.example.com", - "deal_type": "PG", - "media_type": "CTV", - "price": 25.00, - "impressions": 2000000, - "flight_start": "2026-04-01", - "flight_end": "2026-06-30", - "seller_org": "NBCUniversal", - "description": "Premium CTV inventory package", - "tags": ["premium", "ctv"], - }) + result = await mcp.call_tool( + "create_deal_manual", + { + "display_name": "Premium CTV Package", + "seller_url": "https://seller.example.com", + "deal_type": "PG", + "media_type": "CTV", + "price": 25.00, + "impressions": 2000000, + "flight_start": "2026-04-01", + "flight_end": "2026-06-30", + "seller_org": "NBCUniversal", + "description": "Premium CTV inventory package", + "tags": ["premium", "ctv"], + }, + ) data = json.loads(_extract_text(result)) assert data["success"] is True @@ -605,10 +624,13 @@ async def test_create_deal_persists(self): store = _make_deal_store() _set_deal_store(store) - result = await mcp.call_tool("create_deal_manual", { - "display_name": "Test Persistence", - "seller_url": "https://seller.example.com", - }) + result = await mcp.call_tool( + "create_deal_manual", + { + "display_name": "Test Persistence", + "seller_url": "https://seller.example.com", + }, + ) data = json.loads(_extract_text(result)) deal_id = data["deal_id"] @@ -624,11 +646,14 @@ async def test_create_deal_invalid_deal_type(self): store = _make_deal_store() _set_deal_store(store) - result = await mcp.call_tool("create_deal_manual", { - "display_name": "Bad Deal", - "seller_url": "https://seller.example.com", - "deal_type": "INVALID", - }) + result = await mcp.call_tool( + "create_deal_manual", + { + "display_name": "Bad Deal", + "seller_url": "https://seller.example.com", + "deal_type": "INVALID", + }, + ) data = json.loads(_extract_text(result)) assert data["success"] is False @@ -641,10 +666,13 @@ async def test_create_deal_missing_display_name(self): store = _make_deal_store() _set_deal_store(store) - result = await mcp.call_tool("create_deal_manual", { - "display_name": "", - "seller_url": "https://seller.example.com", - }) + result = await mcp.call_tool( + "create_deal_manual", + { + "display_name": "", + "seller_url": "https://seller.example.com", + }, + ) data = json.loads(_extract_text(result)) assert data["success"] is False @@ -655,10 +683,13 @@ async def test_create_deal_returns_valid_json(self): store = _make_deal_store() _set_deal_store(store) - result = await mcp.call_tool("create_deal_manual", { - "display_name": "Test JSON", - "seller_url": "https://seller.example.com", - }) + result = await mcp.call_tool( + "create_deal_manual", + { + "display_name": "Test JSON", + "seller_url": "https://seller.example.com", + }, + ) data = json.loads(_extract_text(result)) assert isinstance(data, dict) assert "timestamp" in data @@ -669,12 +700,15 @@ async def test_create_deal_saves_portfolio_metadata(self): store = _make_deal_store() _set_deal_store(store) - result = await mcp.call_tool("create_deal_manual", { - "display_name": "Test Metadata", - "seller_url": "https://seller.example.com", - "advertiser_id": "adv-001", - "tags": ["premium"], - }) + result = await mcp.call_tool( + "create_deal_manual", + { + "display_name": "Test Metadata", + "seller_url": "https://seller.example.com", + "advertiser_id": "adv-001", + "tags": ["premium"], + }, + ) data = json.loads(_extract_text(result)) deal_id = data["deal_id"] @@ -753,7 +787,9 @@ async def test_summary_by_media_type(self): """get_portfolio_summary should break down deals by media_type.""" store = _make_deal_store() _seed_deal(store, display_name="CTV Deal", media_type="CTV", seller_deal_id="CTV-001") - _seed_deal(store, display_name="Digital Deal", media_type="DIGITAL", seller_deal_id="DIG-001") # noqa: E501 + _seed_deal( + store, display_name="Digital Deal", media_type="DIGITAL", seller_deal_id="DIG-001" + ) # noqa: E501 _set_deal_store(store) result = await mcp.call_tool("get_portfolio_summary", {}) @@ -769,14 +805,18 @@ async def test_summary_portfolio_value(self): store = _make_deal_store() # price=10 CPM, impressions=1M -> value = 10 * 1M / 1000 = $10,000 _seed_deal( - store, display_name="Deal A", - price=10.0, impressions=1000000, + store, + display_name="Deal A", + price=10.0, + impressions=1000000, seller_deal_id="A-001", ) # price=20 CPM, impressions=500K -> value = 20 * 500K / 1000 = $10,000 _seed_deal( - store, display_name="Deal B", - price=20.0, impressions=500000, + store, + display_name="Deal B", + price=20.0, + impressions=500000, seller_deal_id="B-001", ) _set_deal_store(store) @@ -830,27 +870,30 @@ async def test_v2_fields_survive_roundtrip(self): store = _make_deal_store() _set_deal_store(store) - create_result = await mcp.call_tool("create_deal_manual", { - "display_name": "Premium Video PG", - "seller_url": "https://nbcu.seller.example.com", - "deal_type": "PG", - "status": "active", - "seller_org": "NBCUniversal", - "seller_domain": "nbcuniversal.com", - "seller_type": "PUBLISHER", - "buyer_org": "MediaCo Agency", - "buyer_id": "buyer-mediaco-001", - "price": 15.50, - "fixed_price_cpm": 15.50, - "bid_floor_cpm": 12.00, - "price_model": "CPM", - "currency": "EUR", - "media_type": "CTV", - "impressions": 5000000, - "flight_start": "2026-04-01", - "flight_end": "2026-06-30", - "description": "Premium CTV video inventory for Q2 campaign", - }) + create_result = await mcp.call_tool( + "create_deal_manual", + { + "display_name": "Premium Video PG", + "seller_url": "https://nbcu.seller.example.com", + "deal_type": "PG", + "status": "active", + "seller_org": "NBCUniversal", + "seller_domain": "nbcuniversal.com", + "seller_type": "PUBLISHER", + "buyer_org": "MediaCo Agency", + "buyer_id": "buyer-mediaco-001", + "price": 15.50, + "fixed_price_cpm": 15.50, + "bid_floor_cpm": 12.00, + "price_model": "CPM", + "currency": "EUR", + "media_type": "CTV", + "impressions": 5000000, + "flight_start": "2026-04-01", + "flight_end": "2026-06-30", + "description": "Premium CTV video inventory for Q2 campaign", + }, + ) create_data = json.loads(_extract_text(create_result)) assert create_data["success"] is True deal_id = create_data["deal_id"] @@ -880,11 +923,14 @@ async def test_filter_by_media_type_after_create(self): store = _make_deal_store() _set_deal_store(store) - await mcp.call_tool("create_deal_manual", { - "display_name": "CTV Deal", - "seller_url": "https://seller.example.com", - "media_type": "CTV", - }) + await mcp.call_tool( + "create_deal_manual", + { + "display_name": "CTV Deal", + "seller_url": "https://seller.example.com", + "media_type": "CTV", + }, + ) result = await mcp.call_tool("list_deals", {"media_type": "CTV"}) data = json.loads(_extract_text(result)) @@ -897,11 +943,14 @@ async def test_search_by_seller_org_after_create(self): store = _make_deal_store() _set_deal_store(store) - await mcp.call_tool("create_deal_manual", { - "display_name": "NBC Deal", - "seller_url": "https://nbc.example.com", - "seller_org": "NBCUniversal", - }) + await mcp.call_tool( + "create_deal_manual", + { + "display_name": "NBC Deal", + "seller_url": "https://nbc.example.com", + "seller_org": "NBCUniversal", + }, + ) result = await mcp.call_tool("search_deals", {"query": "NBCUniversal"}) data = json.loads(_extract_text(result)) diff --git a/tests/unit/test_mcp_negotiation_orders.py b/tests/unit/test_mcp_negotiation_orders.py index 7f88022..8e9004b 100644 --- a/tests/unit/test_mcp_negotiation_orders.py +++ b/tests/unit/test_mcp_negotiation_orders.py @@ -38,6 +38,7 @@ def _make_deal_store() -> DealStore: """ import os import tempfile + fd, path = tempfile.mkstemp(suffix=".db") os.close(fd) store = DealStore(f"sqlite:///{path}") @@ -53,6 +54,7 @@ def _make_order_store() -> OrderStore: """ import os import tempfile + fd, path = tempfile.mkstemp(suffix=".db") os.close(fd) store = OrderStore(f"sqlite:///{path}") @@ -66,10 +68,12 @@ def _reconnecting(store): MCP tools call store.disconnect() in their finally blocks. For multi-call tests, the store needs to be reconnected on each access. """ + def _get(): if store._conn is None: store.connect() return store + return _get @@ -186,12 +190,15 @@ async def test_start_negotiation_creates_session(self, monkeypatch): "ad_buyer.interfaces.mcp_server._get_deal_store", _reconnecting(deal_store) ) - result = await mcp.call_tool("start_negotiation", { - "seller_url": "http://localhost:8001", - "product_id": "pkg-001", - "product_name": "Premium CTV", - "initial_price": 20.0, - }) + result = await mcp.call_tool( + "start_negotiation", + { + "seller_url": "http://localhost:8001", + "product_id": "pkg-001", + "product_name": "Premium CTV", + "initial_price": 20.0, + }, + ) data = json.loads(_extract_text(result)) assert "deal_id" in data @@ -206,19 +213,20 @@ async def test_start_negotiation_records_first_round(self, monkeypatch): "ad_buyer.interfaces.mcp_server._get_deal_store", _reconnecting(deal_store) ) - result = await mcp.call_tool("start_negotiation", { - "seller_url": "http://localhost:8001", - "product_id": "pkg-001", - "product_name": "Premium CTV", - "initial_price": 20.0, - }) + result = await mcp.call_tool( + "start_negotiation", + { + "seller_url": "http://localhost:8001", + "product_id": "pkg-001", + "product_name": "Premium CTV", + "initial_price": 20.0, + }, + ) data = json.loads(_extract_text(result)) deal_id = data["deal_id"] # Verify round was recorded via the get_negotiation_status tool - status_result = await mcp.call_tool( - "get_negotiation_status", {"deal_id": deal_id} - ) + status_result = await mcp.call_tool("get_negotiation_status", {"deal_id": deal_id}) status_data = json.loads(_extract_text(status_result)) assert status_data["rounds_count"] == 1 @@ -233,12 +241,15 @@ async def test_start_negotiation_returns_json(self, monkeypatch): "ad_buyer.interfaces.mcp_server._get_deal_store", _reconnecting(deal_store) ) - result = await mcp.call_tool("start_negotiation", { - "seller_url": "http://localhost:8001", - "product_id": "pkg-001", - "product_name": "Premium CTV", - "initial_price": 20.0, - }) + result = await mcp.call_tool( + "start_negotiation", + { + "seller_url": "http://localhost:8001", + "product_id": "pkg-001", + "product_name": "Premium CTV", + "initial_price": 20.0, + }, + ) data = json.loads(_extract_text(result)) assert isinstance(data, dict) assert "timestamp" in data @@ -260,9 +271,7 @@ async def test_returns_deal_not_found(self, monkeypatch): "ad_buyer.interfaces.mcp_server._get_deal_store", _reconnecting(deal_store) ) - result = await mcp.call_tool("get_negotiation_status", { - "deal_id": "nonexistent" - }) + result = await mcp.call_tool("get_negotiation_status", {"deal_id": "nonexistent"}) data = json.loads(_extract_text(result)) assert "error" in data @@ -276,9 +285,7 @@ async def test_returns_deal_with_rounds(self, monkeypatch): "ad_buyer.interfaces.mcp_server._get_deal_store", _reconnecting(deal_store) ) - result = await mcp.call_tool("get_negotiation_status", { - "deal_id": deal_id - }) + result = await mcp.call_tool("get_negotiation_status", {"deal_id": deal_id}) data = json.loads(_extract_text(result)) assert data["deal_id"] == deal_id @@ -295,9 +302,7 @@ async def test_returns_deal_with_no_rounds(self, monkeypatch): "ad_buyer.interfaces.mcp_server._get_deal_store", _reconnecting(deal_store) ) - result = await mcp.call_tool("get_negotiation_status", { - "deal_id": deal_id - }) + result = await mcp.call_tool("get_negotiation_status", {"deal_id": deal_id}) data = json.loads(_extract_text(result)) assert data["deal_id"] == deal_id @@ -441,9 +446,7 @@ async def test_returns_order_not_found(self, monkeypatch): "ad_buyer.interfaces.mcp_server._get_order_store", _reconnecting(order_store) ) - result = await mcp.call_tool("get_order_status", { - "order_id": "nonexistent" - }) + result = await mcp.call_tool("get_order_status", {"order_id": "nonexistent"}) data = json.loads(_extract_text(result)) assert "error" in data @@ -461,9 +464,7 @@ async def test_returns_order_details(self, monkeypatch): "ad_buyer.interfaces.mcp_server._get_order_store", _reconnecting(order_store) ) - result = await mcp.call_tool("get_order_status", { - "order_id": "order-001" - }) + result = await mcp.call_tool("get_order_status", {"order_id": "order-001"}) data = json.loads(_extract_text(result)) assert data["order_id"] == "order-001" @@ -479,9 +480,7 @@ async def test_includes_timestamp(self, monkeypatch): "ad_buyer.interfaces.mcp_server._get_order_store", _reconnecting(order_store) ) - result = await mcp.call_tool("get_order_status", { - "order_id": "order-001" - }) + result = await mcp.call_tool("get_order_status", {"order_id": "order-001"}) data = json.loads(_extract_text(result)) assert "timestamp" in data @@ -502,10 +501,13 @@ async def test_returns_order_not_found(self, monkeypatch): "ad_buyer.interfaces.mcp_server._get_order_store", _reconnecting(order_store) ) - result = await mcp.call_tool("transition_order", { - "order_id": "nonexistent", - "to_status": "booked", - }) + result = await mcp.call_tool( + "transition_order", + { + "order_id": "nonexistent", + "to_status": "booked", + }, + ) data = json.loads(_extract_text(result)) assert "error" in data @@ -518,10 +520,13 @@ async def test_successful_transition(self, monkeypatch): "ad_buyer.interfaces.mcp_server._get_order_store", _reconnecting(order_store) ) - result = await mcp.call_tool("transition_order", { - "order_id": "order-001", - "to_status": "booked", - }) + result = await mcp.call_tool( + "transition_order", + { + "order_id": "order-001", + "to_status": "booked", + }, + ) data = json.loads(_extract_text(result)) assert data["order_id"] == "order-001" @@ -529,9 +534,7 @@ async def test_successful_transition(self, monkeypatch): assert data["previous_status"] == "pending" # Verify the order was actually updated via get_order_status - status_result = await mcp.call_tool( - "get_order_status", {"order_id": "order-001"} - ) + status_result = await mcp.call_tool("get_order_status", {"order_id": "order-001"}) status_data = json.loads(_extract_text(status_result)) assert status_data["status"] == "booked" @@ -544,11 +547,14 @@ async def test_transition_with_reason(self, monkeypatch): "ad_buyer.interfaces.mcp_server._get_order_store", _reconnecting(order_store) ) - result = await mcp.call_tool("transition_order", { - "order_id": "order-001", - "to_status": "booked", - "reason": "Seller confirmed booking", - }) + result = await mcp.call_tool( + "transition_order", + { + "order_id": "order-001", + "to_status": "booked", + "reason": "Seller confirmed booking", + }, + ) data = json.loads(_extract_text(result)) assert data["reason"] == "Seller confirmed booking" @@ -562,9 +568,12 @@ async def test_transition_includes_timestamp(self, monkeypatch): "ad_buyer.interfaces.mcp_server._get_order_store", _reconnecting(order_store) ) - result = await mcp.call_tool("transition_order", { - "order_id": "order-001", - "to_status": "booked", - }) + result = await mcp.call_tool( + "transition_order", + { + "order_id": "order-001", + "to_status": "booked", + }, + ) data = json.loads(_extract_text(result)) assert "timestamp" in data diff --git a/tests/unit/test_mcp_seller_discovery.py b/tests/unit/test_mcp_seller_discovery.py index 3aaf381..bb2dddf 100644 --- a/tests/unit/test_mcp_seller_discovery.py +++ b/tests/unit/test_mcp_seller_discovery.py @@ -152,8 +152,10 @@ async def test_filter_by_capability(self, monkeypatch): ctv_cap = AgentCapability(name="ctv", description="CTV inventory") sellers = [ _make_agent_card( - agent_id="s1", name="CTV Publisher", - url="http://ctv.com", capabilities=[ctv_cap], + agent_id="s1", + name="CTV Publisher", + url="http://ctv.com", + capabilities=[ctv_cap], ), ] mock_client = AsyncMock() @@ -163,9 +165,7 @@ async def test_filter_by_capability(self, monkeypatch): lambda: mock_client, ) - result = await mcp.call_tool( - "discover_sellers", {"capability": "ctv"} - ) + result = await mcp.call_tool("discover_sellers", {"capability": "ctv"}) data = json.loads(_extract_text(result)) assert data["total"] == 1 @@ -319,8 +319,11 @@ async def test_package_fields_included(self, monkeypatch): pkg = data["packages"][0] required_fields = [ - "package_id", "name", "ad_formats", - "price_range", "rate_type", + "package_id", + "name", + "ad_formats", + "price_range", + "rate_type", ] for field in required_fields: assert field in pkg, f"Missing field: {field}" @@ -502,9 +505,7 @@ async def test_compare_empty_list(self, monkeypatch): lambda: mock_client, ) - result = await mcp.call_tool( - "compare_sellers", {"seller_urls": []} - ) + result = await mcp.call_tool("compare_sellers", {"seller_urls": []}) data = json.loads(_extract_text(result)) assert data["sellers_compared"] == 0 @@ -549,9 +550,7 @@ async def test_returns_valid_json(self, monkeypatch): lambda: mock_client, ) - result = await mcp.call_tool( - "compare_sellers", {"seller_urls": []} - ) + result = await mcp.call_tool("compare_sellers", {"seller_urls": []}) data = json.loads(_extract_text(result)) assert isinstance(data, dict) assert "timestamp" in data @@ -564,7 +563,8 @@ async def test_compare_sellers_ad_format_breakdown(self, monkeypatch): seller_name="Publisher A", packages=[ _make_package_summary( - name="Display", ad_formats=["display", "native"], + name="Display", + ad_formats=["display", "native"], seller_url="http://a.example.com", ), ], @@ -574,7 +574,8 @@ async def test_compare_sellers_ad_format_breakdown(self, monkeypatch): seller_name="Publisher B", packages=[ _make_package_summary( - name="Video", ad_formats=["video"], + name="Video", + ad_formats=["video"], seller_url="http://b.example.com", ), ], diff --git a/tests/unit/test_mcp_server.py b/tests/unit/test_mcp_server.py index 122386c..364ff84 100644 --- a/tests/unit/test_mcp_server.py +++ b/tests/unit/test_mcp_server.py @@ -29,21 +29,25 @@ class TestMCPServerInitialization: def test_mcp_server_exists(self): """The mcp_server module should be importable.""" from ad_buyer.interfaces.mcp_server import mcp + assert mcp is not None def test_mcp_server_is_fastmcp_instance(self): """The mcp object should be a FastMCP instance.""" from ad_buyer.interfaces.mcp_server import mcp + assert isinstance(mcp, FastMCP) def test_mcp_server_name(self): """The MCP server should identify as the buyer agent.""" from ad_buyer.interfaces.mcp_server import mcp + assert mcp.name == "ad-buyer-agent" def test_mcp_server_has_instructions(self): """The MCP server should have instructions describing the buyer agent.""" from ad_buyer.interfaces.mcp_server import mcp + assert mcp.instructions is not None assert len(mcp.instructions) > 0 @@ -54,6 +58,7 @@ class TestMCPMounting: def test_mount_mcp_function_exists(self): """A mount_mcp function should exist for integrating with FastAPI.""" from ad_buyer.interfaces.mcp_server import mount_mcp + assert callable(mount_mcp) def test_mount_mcp_adds_route(self): @@ -90,7 +95,10 @@ def test_buyer_api_app_has_mcp_mounted(self): route_paths.append(route.path) # Streamable HTTP transport (canonical) - assert any("/mcp" == str(p) or (str(p).startswith("/mcp") and not str(p).startswith("/mcp-sse")) for p in route_paths), ( # noqa: E501 + assert any( + "/mcp" == str(p) or (str(p).startswith("/mcp") and not str(p).startswith("/mcp-sse")) + for p in route_paths + ), ( # noqa: E501 f"Expected /mcp (Streamable HTTP) in buyer API app routes, got: {route_paths}" ) # Legacy SSE transport @@ -110,15 +118,9 @@ async def test_list_tools_includes_foundation_tools(self): tools_result = await mcp.list_tools() tool_names = [t.name for t in tools_result] - assert "get_setup_status" in tool_names, ( - f"get_setup_status not in tools: {tool_names}" - ) - assert "health_check" in tool_names, ( - f"health_check not in tools: {tool_names}" - ) - assert "get_config" in tool_names, ( - f"get_config not in tools: {tool_names}" - ) + assert "get_setup_status" in tool_names, f"get_setup_status not in tools: {tool_names}" + assert "health_check" in tool_names, f"health_check not in tools: {tool_names}" + assert "get_config" in tool_names, f"get_config not in tools: {tool_names}" @pytest.mark.asyncio async def test_foundation_tools_are_present(self): @@ -429,9 +431,7 @@ async def test_all_prompts_registered(self): prompt_names = [p.name for p in prompts_result] for name in self.EXPECTED_PROMPTS: - assert name in prompt_names, ( - f"Prompt '{name}' not registered. Found: {prompt_names}" - ) + assert name in prompt_names, f"Prompt '{name}' not registered. Found: {prompt_names}" @pytest.mark.asyncio async def test_prompt_count(self): @@ -440,8 +440,7 @@ async def test_prompt_count(self): prompts_result = await mcp.list_prompts() assert len(prompts_result) == 10, ( - f"Expected 10 prompts, got {len(prompts_result)}: " - f"{[p.name for p in prompts_result]}" + f"Expected 10 prompts, got {len(prompts_result)}: {[p.name for p in prompts_result]}" ) @pytest.mark.asyncio @@ -451,9 +450,7 @@ async def test_each_prompt_has_description(self): prompts_result = await mcp.list_prompts() for prompt in prompts_result: - assert prompt.description, ( - f"Prompt '{prompt.name}' has no description" - ) + assert prompt.description, f"Prompt '{prompt.name}' has no description" @pytest.mark.asyncio async def test_each_prompt_returns_messages(self): @@ -463,12 +460,8 @@ async def test_each_prompt_returns_messages(self): prompts_result = await mcp.list_prompts() for prompt in prompts_result: result = await mcp.get_prompt(prompt.name) - assert result is not None, ( - f"Prompt '{prompt.name}' returned None" - ) - assert len(result.messages) > 0, ( - f"Prompt '{prompt.name}' returned no messages" - ) + assert result is not None, f"Prompt '{prompt.name}' returned None" + assert len(result.messages) > 0, f"Prompt '{prompt.name}' returned no messages" @pytest.mark.asyncio async def test_each_prompt_has_user_role(self): @@ -480,6 +473,5 @@ async def test_each_prompt_has_user_role(self): result = await mcp.get_prompt(prompt.name) for msg in result.messages: assert msg.role == "user", ( - f"Prompt '{prompt.name}' has role '{msg.role}', " - f"expected 'user'" + f"Prompt '{prompt.name}' has role '{msg.role}', expected 'user'" ) diff --git a/tests/unit/test_mcp_ssp_tools.py b/tests/unit/test_mcp_ssp_tools.py index 2819395..162eb5c 100644 --- a/tests/unit/test_mcp_ssp_tools.py +++ b/tests/unit/test_mcp_ssp_tools.py @@ -246,9 +246,7 @@ async def test_successful_pubmatic_import(self, monkeypatch): ], ) - with patch( - "ad_buyer.interfaces.mcp_server.PubMaticConnector" - ) as MockConnector: + with patch("ad_buyer.interfaces.mcp_server.PubMaticConnector") as MockConnector: instance = MagicMock() instance.is_configured.return_value = True instance.fetch_deals.return_value = fake_result @@ -294,9 +292,7 @@ async def test_successful_magnite_import(self, monkeypatch): ], ) - with patch( - "ad_buyer.interfaces.mcp_server.MagniteConnector" - ) as MockConnector: + with patch("ad_buyer.interfaces.mcp_server.MagniteConnector") as MockConnector: instance = MagicMock() instance.is_configured.return_value = True instance.fetch_deals.return_value = fake_result @@ -328,9 +324,7 @@ async def test_import_result_has_standard_fields(self, monkeypatch): deals=[], ) - with patch( - "ad_buyer.interfaces.mcp_server.PubMaticConnector" - ) as MockConnector: + with patch("ad_buyer.interfaces.mcp_server.PubMaticConnector") as MockConnector: instance = MagicMock() instance.is_configured.return_value = True instance.fetch_deals.return_value = fake_result @@ -341,7 +335,15 @@ async def test_import_result_has_standard_fields(self, monkeypatch): data = json.loads(_extract_text(result)) # Must have same structure as import_deals_csv - for field in ("total_rows", "successful", "failed", "skipped", "errors", "deal_ids", "timestamp"): # noqa: E501 + for field in ( + "total_rows", + "successful", + "failed", + "skipped", + "errors", + "deal_ids", + "timestamp", + ): # noqa: E501 assert field in data, f"Missing field: {field}" @pytest.mark.asyncio @@ -355,9 +357,7 @@ async def test_import_ssp_name_case_insensitive(self, monkeypatch): fake_result = SSPFetchResult(ssp_name="PubMatic", total_fetched=0, successful=0, deals=[]) - with patch( - "ad_buyer.interfaces.mcp_server.PubMaticConnector" - ) as MockConnector: + with patch("ad_buyer.interfaces.mcp_server.PubMaticConnector") as MockConnector: instance = MagicMock() instance.is_configured.return_value = True instance.fetch_deals.return_value = fake_result diff --git a/tests/unit/test_mcp_streamable_http.py b/tests/unit/test_mcp_streamable_http.py index 5790f43..cfe09c9 100644 --- a/tests/unit/test_mcp_streamable_http.py +++ b/tests/unit/test_mcp_streamable_http.py @@ -36,9 +36,7 @@ def test_streamable_http_route_present(self): """POST /mcp (Streamable HTTP) should be mounted in the buyer app.""" from ad_buyer.interfaces.api.main import app - route_paths = [ - getattr(route, "path", "") for route in app.routes - ] + route_paths = [getattr(route, "path", "") for route in app.routes] assert any( p == "/mcp" or (p.startswith("/mcp") and not p.startswith("/mcp-sse")) for p in route_paths @@ -48,9 +46,7 @@ def test_legacy_sse_route_present(self): """GET /mcp-sse/sse (legacy SSE fallback) should be mounted in the buyer app.""" from ad_buyer.interfaces.api.main import app - route_paths = [ - getattr(route, "path", "") for route in app.routes - ] + route_paths = [getattr(route, "path", "") for route in app.routes] assert any("/mcp-sse" in p for p in route_paths), ( f"Expected /mcp-sse (legacy SSE) mount, got: {route_paths}" ) @@ -104,9 +100,7 @@ async def test_streamable_http_initialize_handshake(): } transport = ASGITransport(app=app) # lifespan handled by context manager - async with httpx.AsyncClient( - transport=transport, base_url="http://testserver" - ) as client: + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: response = await client.post( "/mcp", content=json.dumps(initialize_payload), @@ -120,16 +114,12 @@ async def test_streamable_http_initialize_handshake(): assert response.status_code != 404, ( "POST /mcp returned 404 — Streamable HTTP transport not mounted" ) - assert response.status_code != 405, ( - "POST /mcp returned 405 — wrong method for Streamable HTTP" - ) + assert response.status_code != 405, "POST /mcp returned 405 — wrong method for Streamable HTTP" # Happy path: 200 with MCP initialize response if response.status_code == 200: body = response.text # Response may be JSON or SSE-wrapped JSON; either way, check for MCP fields - assert any( - field in body for field in ("protocolVersion", "serverInfo", "capabilities") - ), ( + assert any(field in body for field in ("protocolVersion", "serverInfo", "capabilities")), ( f"POST /mcp returned 200 but body missing MCP negotiation fields: {body[:500]}" ) diff --git a/tests/unit/test_mcp_template_reporting.py b/tests/unit/test_mcp_template_reporting.py index 8482bd0..23b0a3b 100644 --- a/tests/unit/test_mcp_template_reporting.py +++ b/tests/unit/test_mcp_template_reporting.py @@ -69,10 +69,12 @@ def _seed_campaign(store: CampaignStore, **overrides) -> str: "currency": "USD", "flight_start": "2026-03-01", "flight_end": "2026-03-31", - "channels": json.dumps([ - {"channel": "CTV", "budget_pct": 0.6}, - {"channel": "DISPLAY", "budget_pct": 0.4}, - ]), + "channels": json.dumps( + [ + {"channel": "CTV", "budget_pct": 0.6}, + {"channel": "DISPLAY", "budget_pct": 0.4}, + ] + ), } defaults.update(overrides) return store.save_campaign(**defaults) @@ -281,12 +283,15 @@ async def test_create_deal_template(self): store = _make_deal_store() _set_deal_store(store) try: - result = await mcp.call_tool("create_template", { - "template_type": "deal", - "name": "Sports PG", - "deal_type_pref": "PG", - "max_cpm": 25.0, - }) + result = await mcp.call_tool( + "create_template", + { + "template_type": "deal", + "name": "Sports PG", + "deal_type_pref": "PG", + "max_cpm": 25.0, + }, + ) data = json.loads(_extract_text(result)) assert "template_id" in data assert data["name"] == "Sports PG" @@ -306,11 +311,14 @@ async def test_create_supply_path_template(self): store = _make_deal_store() _set_deal_store(store) try: - result = await mcp.call_tool("create_template", { - "template_type": "supply_path", - "name": "Direct Paths", - "max_reseller_hops": 2, - }) + result = await mcp.call_tool( + "create_template", + { + "template_type": "supply_path", + "name": "Direct Paths", + "max_reseller_hops": 2, + }, + ) data = json.loads(_extract_text(result)) assert "template_id" in data assert data["template_type"] == "supply_path" @@ -329,9 +337,12 @@ async def test_create_requires_name(self): store = _make_deal_store() _set_deal_store(store) try: - result = await mcp.call_tool("create_template", { - "template_type": "deal", - }) + result = await mcp.call_tool( + "create_template", + { + "template_type": "deal", + }, + ) data = json.loads(_extract_text(result)) assert "error" in data finally: @@ -343,9 +354,12 @@ async def test_create_requires_template_type(self): store = _make_deal_store() _set_deal_store(store) try: - result = await mcp.call_tool("create_template", { - "name": "Test", - }) + result = await mcp.call_tool( + "create_template", + { + "name": "Test", + }, + ) data = json.loads(_extract_text(result)) assert "error" in data finally: @@ -357,10 +371,13 @@ async def test_create_invalid_template_type(self): store = _make_deal_store() _set_deal_store(store) try: - result = await mcp.call_tool("create_template", { - "template_type": "unknown", - "name": "Test", - }) + result = await mcp.call_tool( + "create_template", + { + "template_type": "unknown", + "name": "Test", + }, + ) data = json.loads(_extract_text(result)) assert "error" in data finally: @@ -388,9 +405,12 @@ async def test_instantiate_creates_deal_from_template(self): ) _set_deal_store(store) try: - result = await mcp.call_tool("instantiate_from_template", { - "template_id": tmpl_id, - }) + result = await mcp.call_tool( + "instantiate_from_template", + { + "template_id": tmpl_id, + }, + ) data = json.loads(_extract_text(result)) assert "deal_id" in data assert data["template_id"] == tmpl_id @@ -423,6 +443,7 @@ async def test_instantiate_with_overrides(self): try: # Call the function directly to bypass MCP's pre-parse from ad_buyer.interfaces.mcp_server import instantiate_from_template + result_str = instantiate_from_template( template_id=tmpl_id, overrides='{"price": 25.0}', @@ -443,9 +464,12 @@ async def test_instantiate_template_not_found(self): store = _make_deal_store() _set_deal_store(store) try: - result = await mcp.call_tool("instantiate_from_template", { - "template_id": "nonexistent-id", - }) + result = await mcp.call_tool( + "instantiate_from_template", + { + "template_id": "nonexistent-id", + }, + ) data = json.loads(_extract_text(result)) assert "error" in data finally: @@ -478,9 +502,12 @@ async def test_deal_not_found(self): store = _make_deal_store() _set_deal_store(store) try: - result = await mcp.call_tool("get_deal_performance", { - "deal_id": "nonexistent", - }) + result = await mcp.call_tool( + "get_deal_performance", + { + "deal_id": "nonexistent", + }, + ) data = json.loads(_extract_text(result)) assert "error" in data finally: @@ -499,9 +526,12 @@ async def test_deal_performance_basic(self): ) _set_deal_store(store) try: - result = await mcp.call_tool("get_deal_performance", { - "deal_id": deal_id, - }) + result = await mcp.call_tool( + "get_deal_performance", + { + "deal_id": deal_id, + }, + ) data = json.loads(_extract_text(result)) assert data["deal_id"] == deal_id assert data["product_name"] == "Premium CTV" @@ -523,9 +553,12 @@ async def test_returns_valid_json(self): ) _set_deal_store(store) try: - result = await mcp.call_tool("get_deal_performance", { - "deal_id": deal_id, - }) + result = await mcp.call_tool( + "get_deal_performance", + { + "deal_id": deal_id, + }, + ) data = json.loads(_extract_text(result)) assert isinstance(data, dict) finally: @@ -545,15 +578,16 @@ async def test_campaign_not_found(self, monkeypatch): """get_campaign_report should error when campaign not found.""" store = _make_campaign_store() pacing_store = _make_pacing_store() - monkeypatch.setattr( - "ad_buyer.interfaces.mcp_server._get_campaign_store", lambda: store - ) + monkeypatch.setattr("ad_buyer.interfaces.mcp_server._get_campaign_store", lambda: store) monkeypatch.setattr( "ad_buyer.interfaces.mcp_server._get_pacing_store", lambda: pacing_store ) - result = await mcp.call_tool("get_campaign_report", { - "campaign_id": "nonexistent", - }) + result = await mcp.call_tool( + "get_campaign_report", + { + "campaign_id": "nonexistent", + }, + ) data = json.loads(_extract_text(result)) assert "error" in data @@ -564,19 +598,28 @@ async def test_campaign_report_basic(self, monkeypatch): pacing_store = _make_pacing_store() cid = _seed_campaign(campaign_store) _seed_pacing_snapshot( - pacing_store, cid, + pacing_store, + cid, channel_snapshots=[ ChannelSnapshot( - channel="CTV", allocated_budget=60000.0, - spend=30000.0, pacing_pct=100.0, - impressions=2000000, effective_cpm=15.0, fill_rate=0.85, + channel="CTV", + allocated_budget=60000.0, + spend=30000.0, + pacing_pct=100.0, + impressions=2000000, + effective_cpm=15.0, + fill_rate=0.85, ), ], deal_snapshots=[ DealSnapshot( - deal_id="deal-001", allocated_budget=40000.0, - spend=20000.0, impressions=1000000, - effective_cpm=20.0, fill_rate=0.9, win_rate=0.3, + deal_id="deal-001", + allocated_budget=40000.0, + spend=20000.0, + impressions=1000000, + effective_cpm=20.0, + fill_rate=0.9, + win_rate=0.3, ), ], ) @@ -586,9 +629,12 @@ async def test_campaign_report_basic(self, monkeypatch): monkeypatch.setattr( "ad_buyer.interfaces.mcp_server._get_pacing_store", lambda: pacing_store ) - result = await mcp.call_tool("get_campaign_report", { - "campaign_id": cid, - }) + result = await mcp.call_tool( + "get_campaign_report", + { + "campaign_id": cid, + }, + ) data = json.loads(_extract_text(result)) assert data["campaign_id"] == cid assert "status_summary" in data @@ -609,9 +655,12 @@ async def test_campaign_report_no_pacing_data(self, monkeypatch): monkeypatch.setattr( "ad_buyer.interfaces.mcp_server._get_pacing_store", lambda: pacing_store ) - result = await mcp.call_tool("get_campaign_report", { - "campaign_id": cid, - }) + result = await mcp.call_tool( + "get_campaign_report", + { + "campaign_id": cid, + }, + ) data = json.loads(_extract_text(result)) assert data["campaign_id"] == cid assert "status_summary" in data @@ -630,15 +679,16 @@ async def test_campaign_not_found(self, monkeypatch): """get_pacing_report should error when campaign not found.""" store = _make_campaign_store() pacing_store = _make_pacing_store() - monkeypatch.setattr( - "ad_buyer.interfaces.mcp_server._get_campaign_store", lambda: store - ) + monkeypatch.setattr("ad_buyer.interfaces.mcp_server._get_campaign_store", lambda: store) monkeypatch.setattr( "ad_buyer.interfaces.mcp_server._get_pacing_store", lambda: pacing_store ) - result = await mcp.call_tool("get_pacing_report", { - "campaign_id": "nonexistent", - }) + result = await mcp.call_tool( + "get_pacing_report", + { + "campaign_id": "nonexistent", + }, + ) data = json.loads(_extract_text(result)) assert "error" in data @@ -649,21 +699,30 @@ async def test_pacing_report_with_data(self, monkeypatch): pacing_store = _make_pacing_store() cid = _seed_campaign(campaign_store) _seed_pacing_snapshot( - pacing_store, cid, + pacing_store, + cid, total_spend=35000.0, expected_spend=50000.0, pacing_pct=70.0, deviation_pct=-30.0, channel_snapshots=[ ChannelSnapshot( - channel="CTV", allocated_budget=60000.0, - spend=18000.0, pacing_pct=60.0, - impressions=1200000, effective_cpm=15.0, fill_rate=0.85, + channel="CTV", + allocated_budget=60000.0, + spend=18000.0, + pacing_pct=60.0, + impressions=1200000, + effective_cpm=15.0, + fill_rate=0.85, ), ChannelSnapshot( - channel="DISPLAY", allocated_budget=40000.0, - spend=17000.0, pacing_pct=85.0, - impressions=850000, effective_cpm=20.0, fill_rate=0.7, + channel="DISPLAY", + allocated_budget=40000.0, + spend=17000.0, + pacing_pct=85.0, + impressions=850000, + effective_cpm=20.0, + fill_rate=0.7, ), ], ) @@ -673,9 +732,12 @@ async def test_pacing_report_with_data(self, monkeypatch): monkeypatch.setattr( "ad_buyer.interfaces.mcp_server._get_pacing_store", lambda: pacing_store ) - result = await mcp.call_tool("get_pacing_report", { - "campaign_id": cid, - }) + result = await mcp.call_tool( + "get_pacing_report", + { + "campaign_id": cid, + }, + ) data = json.loads(_extract_text(result)) assert data["campaign_id"] == cid assert data["pacing_status"] in ("behind", "on_track", "ahead", "no_data") @@ -700,9 +762,12 @@ async def test_pacing_report_no_data(self, monkeypatch): monkeypatch.setattr( "ad_buyer.interfaces.mcp_server._get_pacing_store", lambda: pacing_store ) - result = await mcp.call_tool("get_pacing_report", { - "campaign_id": cid, - }) + result = await mcp.call_tool( + "get_pacing_report", + { + "campaign_id": cid, + }, + ) data = json.loads(_extract_text(result)) assert data["pacing_status"] == "no_data" @@ -713,7 +778,8 @@ async def test_pacing_report_behind(self, monkeypatch): pacing_store = _make_pacing_store() cid = _seed_campaign(campaign_store) _seed_pacing_snapshot( - pacing_store, cid, + pacing_store, + cid, total_spend=30000.0, expected_spend=50000.0, pacing_pct=60.0, @@ -725,9 +791,12 @@ async def test_pacing_report_behind(self, monkeypatch): monkeypatch.setattr( "ad_buyer.interfaces.mcp_server._get_pacing_store", lambda: pacing_store ) - result = await mcp.call_tool("get_pacing_report", { - "campaign_id": cid, - }) + result = await mcp.call_tool( + "get_pacing_report", + { + "campaign_id": cid, + }, + ) data = json.loads(_extract_text(result)) assert data["pacing_status"] == "behind" @@ -738,7 +807,8 @@ async def test_pacing_report_ahead(self, monkeypatch): pacing_store = _make_pacing_store() cid = _seed_campaign(campaign_store) _seed_pacing_snapshot( - pacing_store, cid, + pacing_store, + cid, total_spend=65000.0, expected_spend=50000.0, pacing_pct=130.0, @@ -750,8 +820,11 @@ async def test_pacing_report_ahead(self, monkeypatch): monkeypatch.setattr( "ad_buyer.interfaces.mcp_server._get_pacing_store", lambda: pacing_store ) - result = await mcp.call_tool("get_pacing_report", { - "campaign_id": cid, - }) + result = await mcp.call_tool( + "get_pacing_report", + { + "campaign_id": cid, + }, + ) data = json.loads(_extract_text(result)) assert data["pacing_status"] == "ahead" diff --git a/tests/unit/test_openrtb_builder.py b/tests/unit/test_openrtb_builder.py index 0b8cd2c..7a3a515 100644 --- a/tests/unit/test_openrtb_builder.py +++ b/tests/unit/test_openrtb_builder.py @@ -174,11 +174,14 @@ def test_agentic_extension_dropped_when_flag_disabled( # Agentic extension NOT emitted. assert "ext" not in fragment["user"] # Warning logged citing the flag. - warning_messages = [r.message for r in caplog.records # type: ignore[attr-defined] - if r.levelno >= logging.WARNING] - assert any( - "enable_agentic_openrtb_ext" in m for m in warning_messages - ), f"expected flag-disabled warning, got: {warning_messages}" + warning_messages = [ + r.message + for r in caplog.records # type: ignore[attr-defined] + if r.levelno >= logging.WARNING + ] + assert any("enable_agentic_openrtb_ext" in m for m in warning_messages), ( + f"expected flag-disabled warning, got: {warning_messages}" + ) def test_agentic_only_plan_with_flag_off_returns_empty_user_block() -> None: @@ -255,11 +258,14 @@ def test_contextual_exclusions_dropped_with_warning( # Only the positive contextual ref appears. assert fragment["site"]["cat"] == ["IAB1-2"] - warning_messages = [r.message for r in caplog.records # type: ignore[attr-defined] - if r.levelno >= logging.WARNING] - assert any( - "site.cat" in m or "exclusion" in m.lower() for m in warning_messages - ), f"expected dropped-contextual-exclusion warning, got: {warning_messages}" + warning_messages = [ + r.message + for r in caplog.records # type: ignore[attr-defined] + if r.levelno >= logging.WARNING + ] + assert any("site.cat" in m or "exclusion" in m.lower() for m in warning_messages), ( + f"expected dropped-contextual-exclusion warning, got: {warning_messages}" + ) # --------------------------------------------------------------------------- diff --git a/tests/unit/test_orchestrator_audience_plan.py b/tests/unit/test_orchestrator_audience_plan.py index 400f3f9..7c7a802 100644 --- a/tests/unit/test_orchestrator_audience_plan.py +++ b/tests/unit/test_orchestrator_audience_plan.py @@ -93,9 +93,7 @@ class TestAcceptsAudiencePlan: def test_inventory_requirements_accepts_plan(self) -> None: plan = _build_minimal_plan() - ir = InventoryRequirements( - media_type="ctv", deal_types=["PD"], audience_plan=plan - ) + ir = InventoryRequirements(media_type="ctv", deal_types=["PD"], audience_plan=plan) assert ir.audience_plan is plan assert ir.audience_plan.primary.identifier == "3-7" @@ -140,9 +138,7 @@ def test_quote_request_round_trips(self) -> None: rebuilt = QuoteRequest(**data) assert rebuilt.audience_plan is not None - assert ( - rebuilt.audience_plan.audience_plan_id == plan.audience_plan_id - ) + assert rebuilt.audience_plan.audience_plan_id == plan.audience_plan_id assert rebuilt.audience_plan.primary.identifier == "3-7" def test_deal_booking_request_round_trips(self) -> None: @@ -153,16 +149,12 @@ def test_deal_booking_request_round_trips(self) -> None: rebuilt = DealBookingRequest(**data) assert rebuilt.audience_plan is not None - assert ( - rebuilt.audience_plan.audience_plan_id == plan.audience_plan_id - ) + assert rebuilt.audience_plan.audience_plan_id == plan.audience_plan_id def test_inventory_requirements_round_trips_via_asdict(self) -> None: # InventoryRequirements is a dataclass; round-trip via asdict + ctor. plan = _build_minimal_plan() - ir = InventoryRequirements( - media_type="ctv", deal_types=["PD"], audience_plan=plan - ) + ir = InventoryRequirements(media_type="ctv", deal_types=["PD"], audience_plan=plan) # asdict recursively converts the AudiencePlan to a dict; rebuilding # requires re-validating the plan dict back into an AudiencePlan. @@ -179,9 +171,7 @@ def test_inventory_requirements_round_trips_via_asdict(self) -> None: ) assert rebuilt.audience_plan is not None - assert ( - rebuilt.audience_plan.audience_plan_id == plan.audience_plan_id - ) + assert rebuilt.audience_plan.audience_plan_id == plan.audience_plan_id assert rebuilt.audience_plan.primary.identifier == "3-7" def test_deal_params_round_trips_via_asdict(self) -> None: @@ -209,9 +199,7 @@ def test_deal_params_round_trips_via_asdict(self) -> None: ) assert rebuilt.audience_plan is not None - assert ( - rebuilt.audience_plan.audience_plan_id == plan.audience_plan_id - ) + assert rebuilt.audience_plan.audience_plan_id == plan.audience_plan_id # --------------------------------------------------------------------------- diff --git a/tests/unit/test_pricing_provenance.py b/tests/unit/test_pricing_provenance.py index 524f980..9859149 100644 --- a/tests/unit/test_pricing_provenance.py +++ b/tests/unit/test_pricing_provenance.py @@ -257,10 +257,18 @@ def test_compare_quotes_filters_unpriced(self): both gracefully — unpriced quotes are excluded from ranking.""" normalizer = QuoteNormalizer() quotes = [ - (_make_quote(quote_id="q-priced", seller_id="seller-a", - final_cpm=10.0, base_cpm=12.0), "PD"), - (_make_quote(quote_id="q-unpriced", seller_id="seller-b", - final_cpm=None, base_cpm=None), "PD"), + ( + _make_quote( + quote_id="q-priced", seller_id="seller-a", final_cpm=10.0, base_cpm=12.0 + ), + "PD", + ), + ( + _make_quote( + quote_id="q-unpriced", seller_id="seller-b", final_cpm=None, base_cpm=None + ), + "PD", + ), ] ranked = normalizer.compare_quotes(quotes) # Unpriced quotes should be separated from the ranked list diff --git a/tests/unit/test_random_seed_and_cors.py b/tests/unit/test_random_seed_and_cors.py index d2e6ad1..350858d 100644 --- a/tests/unit/test_random_seed_and_cors.py +++ b/tests/unit/test_random_seed_and_cors.py @@ -68,9 +68,7 @@ def test_settings_default_cors_origins(self): ) origins = s.get_cors_origins() # MCP server default: wildcard — auth is enforced via X-API-Key, not origin - assert origins == ["*"], ( - f"Expected wildcard default for MCP server, got: {origins}" - ) + assert origins == ["*"], f"Expected wildcard default for MCP server, got: {origins}" def test_settings_custom_cors_origins(self): """CORS origins should be configurable.""" diff --git a/tests/unit/test_real_embedding_model.py b/tests/unit/test_real_embedding_model.py index c3efc31..c9ca4e3 100644 --- a/tests/unit/test_real_embedding_model.py +++ b/tests/unit/test_real_embedding_model.py @@ -45,9 +45,7 @@ def test_advertiser_mode_uses_supplied_vector(self): sample = [0.1] * 384 with patch.object(settings, "embedding_mode", "advertiser"): client = UCPClient() - r = client.create_query_embedding_with_provenance( - REQS, advertiser_vector=sample - ) + r = client.create_query_embedding_with_provenance(REQS, advertiser_vector=sample) assert r.provenance == "advertiser_supplied" assert r.embedding.vector == sample assert r.dimension == 384 @@ -57,9 +55,7 @@ def test_advertiser_dim_out_of_range_falls_back(self): bad = [0.5] * 100 with patch.object(settings, "embedding_mode", "advertiser"): client = UCPClient() - r = client.create_query_embedding_with_provenance( - REQS, advertiser_vector=bad - ) + r = client.create_query_embedding_with_provenance(REQS, advertiser_vector=bad) # Out-of-range advertiser vector skipped, mock used (mode=advertiser # has no local fallback configured, so mock is the safe default). assert r.provenance == "mock" @@ -69,9 +65,7 @@ def test_hybrid_mode_advertiser_wins(self): sample = [0.2] * 384 with patch.object(settings, "embedding_mode", "hybrid"): client = UCPClient() - r = client.create_query_embedding_with_provenance( - REQS, advertiser_vector=sample - ) + r = client.create_query_embedding_with_provenance(REQS, advertiser_vector=sample) assert r.provenance == "advertiser_supplied" assert r.embedding.vector == sample @@ -83,9 +77,7 @@ def test_hybrid_mode_no_advertiser_falls_to_local_or_mock(self): assert r.provenance in ("local_buyer", "mock") assert 256 <= r.dimension <= 1024 or r.dimension == 384 - @pytest.mark.skipif( - not SBERT_AVAILABLE, reason="sentence-transformers not installed" - ) + @pytest.mark.skipif(not SBERT_AVAILABLE, reason="sentence-transformers not installed") def test_local_mode_loads_real_model(self): # Best-effort: model download may be blocked in CI. Either way the # function returns a well-formed result. @@ -168,4 +160,5 @@ def test_mint_tool_format_uses_dynamic_label(self): def test_backward_compat_static_constant(self): from ad_buyer.tools.audience import EMBEDDING_MODE_LABEL_MOCK + assert "MOCK" in EMBEDDING_MODE_LABEL_MOCK diff --git a/tests/unit/test_seller_retry_on_rejection.py b/tests/unit/test_seller_retry_on_rejection.py index e02a273..0252da5 100644 --- a/tests/unit/test_seller_retry_on_rejection.py +++ b/tests/unit/test_seller_retry_on_rejection.py @@ -262,16 +262,12 @@ async def test_retry_drops_unsupported_extension_and_succeeds( assert any("extensions" in e.path for e in log) @pytest.mark.asyncio - async def test_retry_succeeds_clean_first_try_no_log( - self, orchestrator, deals_client_factory - ): + async def test_retry_succeeds_clean_first_try_no_log(self, orchestrator, deals_client_factory): """When the first booking succeeds, no retry, no degradation log.""" seller_url = "http://seller-a.example.com" client = deals_client_factory(seller_url) - client.book_deal.return_value = _make_deal_response( - deal_id="deal-1", quote_id="q-1" - ) + client.book_deal.return_value = _make_deal_response(deal_id="deal-1", quote_id="q-1") selection = await orchestrator.select_and_book( ranked_quotes=[_ranked_quote()], @@ -325,10 +321,7 @@ async def test_second_rejection_marks_seller_incompatible( assert selection.booked_deals == [] assert "seller-a" in selection.incompatible_sellers assert len(selection.failed_bookings) == 1 - assert ( - selection.failed_bookings[0]["error_code"] - == "audience_plan_unsupported" - ) + assert selection.failed_bookings[0]["error_code"] == "audience_plan_unsupported" assert selection.failed_bookings[0]["seller_id"] == "seller-a" @pytest.mark.asyncio @@ -534,10 +527,7 @@ def test_flat_error_shape_still_parses(self): from ad_buyer.clients.deals_client import DealsClient - body = ( - b'{"error": "product_not_found", ' - b'"detail": "Product bad-id does not exist"}' - ) + body = b'{"error": "product_not_found", "detail": "Product bad-id does not exist"}' response = httpx.Response( status_code=404, content=body, diff --git a/tests/unit/test_setup_wizard.py b/tests/unit/test_setup_wizard.py index c4cdc29..713e477 100644 --- a/tests/unit/test_setup_wizard.py +++ b/tests/unit/test_setup_wizard.py @@ -319,9 +319,7 @@ class TestAutoDetection: def test_detect_step_1_when_api_key_set(self): """Step 1 should auto-detect as done when API key is configured.""" wizard = SetupWizard() - with patch( - "ad_buyer.services.setup_wizard._get_settings" - ) as mock_settings: + with patch("ad_buyer.services.setup_wizard._get_settings") as mock_settings: mock_settings.return_value.api_key = "test-key-123" mock_settings.return_value.database_url = "sqlite:///./ad_buyer.db" mock_settings.return_value.environment = "development" @@ -332,9 +330,7 @@ def test_detect_step_1_when_api_key_set(self): def test_detect_step_1_not_detected_when_no_key(self): """Step 1 should not auto-detect without API key.""" wizard = SetupWizard() - with patch( - "ad_buyer.services.setup_wizard._get_settings" - ) as mock_settings: + with patch("ad_buyer.services.setup_wizard._get_settings") as mock_settings: mock_settings.return_value.api_key = "" mock_settings.return_value.database_url = "sqlite:///./ad_buyer.db" mock_settings.return_value.environment = "development" @@ -347,9 +343,7 @@ def test_detect_step_1_not_detected_when_no_key(self): def test_detect_step_2_when_sellers_configured(self): """Step 2 should auto-detect when seller endpoints are configured.""" wizard = SetupWizard() - with patch( - "ad_buyer.services.setup_wizard._get_settings" - ) as mock_settings: + with patch("ad_buyer.services.setup_wizard._get_settings") as mock_settings: mock_settings.return_value.api_key = "" mock_settings.return_value.database_url = "sqlite:///./ad_buyer.db" mock_settings.return_value.environment = "development" @@ -365,9 +359,7 @@ def test_auto_detect_does_not_override_completed(self): """Auto-detection should not override a manually completed step.""" wizard = SetupWizard() wizard.complete_step(1, {"deployment_target": "local"}) - with patch( - "ad_buyer.services.setup_wizard._get_settings" - ) as mock_settings: + with patch("ad_buyer.services.setup_wizard._get_settings") as mock_settings: mock_settings.return_value.api_key = "test-key" mock_settings.return_value.database_url = "sqlite:///./ad_buyer.db" mock_settings.return_value.environment = "development" @@ -408,9 +400,7 @@ def test_progress_includes_skipped(self): def test_progress_includes_auto_detected(self): """Auto-detected steps count toward progress.""" wizard = SetupWizard() - with patch( - "ad_buyer.services.setup_wizard._get_settings" - ) as mock_settings: + with patch("ad_buyer.services.setup_wizard._get_settings") as mock_settings: mock_settings.return_value.api_key = "test-key" mock_settings.return_value.database_url = "sqlite:///./ad_buyer.db" mock_settings.return_value.environment = "development" @@ -487,9 +477,7 @@ def test_run_wizard_returns_state(self): def test_run_wizard_auto_detects(self): """run_wizard should auto-detect completed steps first.""" wizard = SetupWizard() - with patch( - "ad_buyer.services.setup_wizard._get_settings" - ) as mock_settings: + with patch("ad_buyer.services.setup_wizard._get_settings") as mock_settings: mock_settings.return_value.api_key = "test-key" mock_settings.return_value.database_url = "sqlite:///./ad_buyer.db" mock_settings.return_value.environment = "development" diff --git a/tests/unit/test_sgp_gate.py b/tests/unit/test_sgp_gate.py index 3e431f3..afa37ba 100644 --- a/tests/unit/test_sgp_gate.py +++ b/tests/unit/test_sgp_gate.py @@ -326,12 +326,8 @@ def discovery_client() -> MagicMock: _product("p2", "denied.example.com"), _product("p3", "unknown.example.com"), ] - client.search_products = AsyncMock( - return_value=MagicMock(success=True, data=products) - ) - client.list_products = AsyncMock( - return_value=MagicMock(success=True, data=products) - ) + client.search_products = AsyncMock(return_value=MagicMock(success=True, data=products)) + client.list_products = AsyncMock(return_value=MagicMock(success=True, data=products)) return client @@ -355,9 +351,7 @@ def _discovery_sgp_mock() -> MagicMock: @pytest.mark.asyncio -async def test_discovery_enforce_filters_not_approved( - discovery_client, agency_context -): +async def test_discovery_enforce_filters_not_approved(discovery_client, agency_context): """When enforcing, NOT APPROVED rows are dropped before formatting.""" from ad_buyer.tools.buyer_deals import DiscoverInventoryTool @@ -377,9 +371,7 @@ async def test_discovery_enforce_filters_not_approved( @pytest.mark.asyncio -async def test_discovery_enforce_warn_keeps_unknowns( - discovery_client, agency_context -): +async def test_discovery_enforce_warn_keeps_unknowns(discovery_client, agency_context): """warn policy keeps unknowns in the result and emits a warning line.""" from ad_buyer.tools.buyer_deals import DiscoverInventoryTool @@ -399,9 +391,7 @@ async def test_discovery_enforce_warn_keeps_unknowns( @pytest.mark.asyncio -async def test_discovery_enforce_allow_keeps_unknowns_silently( - discovery_client, agency_context -): +async def test_discovery_enforce_allow_keeps_unknowns_silently(discovery_client, agency_context): """allow policy keeps unknowns and suppresses the per-row annotation.""" from ad_buyer.tools.buyer_deals import DiscoverInventoryTool @@ -462,9 +452,7 @@ async def test_discovery_fails_closed_when_sgp_unreachable_and_enforcing( @pytest.mark.asyncio -async def test_discovery_no_enforce_swallows_sgp_error( - discovery_client, agency_context, caplog -): +async def test_discovery_no_enforce_swallows_sgp_error(discovery_client, agency_context, caplog): """Without enforcement, transport error returns unannotated results.""" from ad_buyer.tools.buyer_deals import DiscoverInventoryTool @@ -486,9 +474,7 @@ async def test_discovery_no_enforce_swallows_sgp_error( @pytest.mark.asyncio -async def test_discovery_no_sgp_client_pass_through( - discovery_client, agency_context -): +async def test_discovery_no_sgp_client_pass_through(discovery_client, agency_context): """Without an SGP client, discovery behaves as before — no annotations, no filter.""" from ad_buyer.tools.buyer_deals import DiscoverInventoryTool diff --git a/tests/unit/test_ssp_index_exchange.py b/tests/unit/test_ssp_index_exchange.py index d25377f..470d408 100644 --- a/tests/unit/test_ssp_index_exchange.py +++ b/tests/unit/test_ssp_index_exchange.py @@ -574,8 +574,8 @@ def test_duplicate_deal_ids_skipped(self): connector = _connector_with_transport(transport) result = connector.fetch_deals() - assert result.successful == 4 # 4 unique deals - assert result.skipped == 1 # 1 duplicate skipped + assert result.successful == 4 # 4 unique deals + assert result.skipped == 1 # 1 duplicate skipped assert len(result.deals) == 4 diff --git a/tests/unit/test_ssp_magnite.py b/tests/unit/test_ssp_magnite.py index ba7af54..6134f61 100644 --- a/tests/unit/test_ssp_magnite.py +++ b/tests/unit/test_ssp_magnite.py @@ -293,9 +293,7 @@ def test_normalize_open_auction_no_fixed_price( assert result["fixed_price_cpm"] is None assert result["bid_floor_cpm"] == 8.00 - def test_normalize_null_dates_become_none( - self, streaming_connector, raw_open_auction_deal - ): + def test_normalize_null_dates_become_none(self, streaming_connector, raw_open_auction_deal): """Null dates in the Magnite response map to None.""" result = streaming_connector._normalize_deal(raw_open_auction_deal) assert result["flight_start"] is None @@ -371,9 +369,7 @@ def test_normalize_pd_deal(self, streaming_connector): result = streaming_connector._normalize_deal(raw) assert result["deal_type"] == "PD" - def test_normalized_deal_has_all_required_fields( - self, streaming_connector, raw_ctv_deal - ): + def test_normalized_deal_has_all_required_fields(self, streaming_connector, raw_ctv_deal): """Normalized deal contains all fields required by DealStore.save_deal().""" result = streaming_connector._normalize_deal(raw_ctv_deal) required_fields = [ @@ -414,9 +410,7 @@ def _make_deals_response(self, fixture_data: dict) -> MagicMock: mock_resp.raise_for_status = MagicMock() return mock_resp - def test_fetch_deals_returns_ssp_fetch_result( - self, streaming_connector, magnite_api_response - ): + def test_fetch_deals_returns_ssp_fetch_result(self, streaming_connector, magnite_api_response): """fetch_deals() returns an SSPFetchResult instance.""" with patch("httpx.Client") as mock_client_cls: mock_client = MagicMock() @@ -429,9 +423,7 @@ def test_fetch_deals_returns_ssp_fetch_result( assert isinstance(result, SSPFetchResult) - def test_fetch_deals_returns_correct_count( - self, streaming_connector, magnite_api_response - ): + def test_fetch_deals_returns_correct_count(self, streaming_connector, magnite_api_response): """fetch_deals() returns all 3 deals from the fixture.""" with patch("httpx.Client") as mock_client_cls: mock_client = MagicMock() @@ -447,9 +439,7 @@ def test_fetch_deals_returns_correct_count( assert result.total_fetched == 3 assert result.raw_response_count == 3 - def test_fetch_deals_sets_ssp_name( - self, streaming_connector, magnite_api_response - ): + def test_fetch_deals_sets_ssp_name(self, streaming_connector, magnite_api_response): """fetch_deals() sets ssp_name to 'Magnite' in the result.""" with patch("httpx.Client") as mock_client_cls: mock_client = MagicMock() @@ -462,9 +452,7 @@ def test_fetch_deals_sets_ssp_name( assert result.ssp_name == "Magnite" - def test_fetch_deals_normalizes_all_deals( - self, streaming_connector, magnite_api_response - ): + def test_fetch_deals_normalizes_all_deals(self, streaming_connector, magnite_api_response): """fetch_deals() normalizes all deals in the response.""" with patch("httpx.Client") as mock_client_cls: mock_client = MagicMock() @@ -480,9 +468,7 @@ def test_fetch_deals_normalizes_all_deals( assert "MAG-CTV-002" in deal_ids assert "MAG-CTV-003" in deal_ids - def test_fetch_deals_posts_to_login_endpoint( - self, streaming_connector, magnite_api_response - ): + def test_fetch_deals_posts_to_login_endpoint(self, streaming_connector, magnite_api_response): """fetch_deals() calls the Magnite login endpoint for session auth.""" with patch("httpx.Client") as mock_client_cls: mock_client = MagicMock() @@ -516,9 +502,7 @@ def test_fetch_deals_sends_credentials_in_login( body = post_kwargs[1].get("json", {}) assert "access-key" in body or "access_key" in body or "accessKey" in body - def test_fetch_deals_uses_seat_id_in_url( - self, streaming_connector, magnite_api_response - ): + def test_fetch_deals_uses_seat_id_in_url(self, streaming_connector, magnite_api_response): """fetch_deals() uses the seat ID in the deals endpoint URL.""" with patch("httpx.Client") as mock_client_cls: mock_client = MagicMock() @@ -576,9 +560,7 @@ def test_fetch_deals_rate_limit_raises_ssp_rate_limit_error( with pytest.raises(SSPRateLimitError): streaming_connector.fetch_deals() - def test_fetch_deals_connection_error_raises_ssp_connection_error( - self, streaming_connector - ): + def test_fetch_deals_connection_error_raises_ssp_connection_error(self, streaming_connector): """fetch_deals() raises SSPConnectionError on network failure.""" import httpx @@ -591,21 +573,32 @@ def test_fetch_deals_connection_error_raises_ssp_connection_error( with pytest.raises(SSPConnectionError): streaming_connector.fetch_deals() - def test_fetch_deals_bad_deal_captured_as_error( - self, streaming_connector - ): + def test_fetch_deals_bad_deal_captured_as_error(self, streaming_connector): """fetch_deals() captures normalization errors without crashing.""" bad_response = { "data": { "deals": [ - {"id": "VALID-001", "name": "Good Deal", "dealType": "PD", - "publisherName": "Test Publisher", "publisherDomain": "test.com", - "currency": "USD", "price": {"type": "floor"}, "floor": 5.0, - "mediaType": "CTV", "status": "active", "seatId": "seat-12345", - "impressions": None, "startDate": None, "endDate": None, - "description": None, - "targeting": {"geo": [], "contentCategories": [], "audiences": []}, - "formats": [], "publisherId": "pub-001", "buyerSeatId": "bseat"}, + { + "id": "VALID-001", + "name": "Good Deal", + "dealType": "PD", + "publisherName": "Test Publisher", + "publisherDomain": "test.com", + "currency": "USD", + "price": {"type": "floor"}, + "floor": 5.0, + "mediaType": "CTV", + "status": "active", + "seatId": "seat-12345", + "impressions": None, + "startDate": None, + "endDate": None, + "description": None, + "targeting": {"geo": [], "contentCategories": [], "audiences": []}, + "formats": [], + "publisherId": "pub-001", + "buyerSeatId": "bseat", + }, {"name": "Missing ID", "dealType": "PD"}, # Missing required id ], "totalCount": 2, @@ -633,9 +626,7 @@ def test_fetch_deals_bad_deal_captured_as_error( def test_fetch_deals_empty_response(self, streaming_connector): """fetch_deals() handles an empty deals list gracefully.""" - empty_response = { - "data": {"deals": [], "totalCount": 0, "page": 1, "pageSize": 100} - } + empty_response = {"data": {"deals": [], "totalCount": 0, "page": 1, "pageSize": 100}} with patch("httpx.Client") as mock_client_cls: mock_client = MagicMock() mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client) @@ -744,9 +735,7 @@ def test_all_fixture_deals_normalize_successfully( assert errors == [], f"Normalization errors: {errors}" assert len(normalized) == 3 - def test_fixture_pg_deal_has_fixed_price( - self, streaming_connector, magnite_deals_list - ): + def test_fixture_pg_deal_has_fixed_price(self, streaming_connector, magnite_deals_list): """The PG deal in the fixture has fixed_price_cpm set.""" pg_deal = next(d for d in magnite_deals_list if d["dealType"] == "PG") normalized = streaming_connector._normalize_deal(pg_deal) @@ -760,25 +749,19 @@ def test_fixture_open_auction_deal_has_no_fixed_price( normalized = streaming_connector._normalize_deal(oa_deal) assert normalized["fixed_price_cpm"] is None - def test_fixture_all_deals_have_ctv_media_type( - self, streaming_connector, magnite_deals_list - ): + def test_fixture_all_deals_have_ctv_media_type(self, streaming_connector, magnite_deals_list): """All fixture deals normalize to CTV media type.""" for deal in magnite_deals_list: normalized = streaming_connector._normalize_deal(deal) assert normalized["media_type"] == "CTV" - def test_fixture_all_deals_have_seller_type_ssp( - self, streaming_connector, magnite_deals_list - ): + def test_fixture_all_deals_have_seller_type_ssp(self, streaming_connector, magnite_deals_list): """All fixture deals have seller_type set to SSP.""" for deal in magnite_deals_list: normalized = streaming_connector._normalize_deal(deal) assert normalized["seller_type"] == "SSP" - def test_fixture_fetch_full_flow( - self, streaming_connector, magnite_api_response - ): + def test_fixture_fetch_full_flow(self, streaming_connector, magnite_api_response): """Full fetch flow using fixture returns 3 successful deals.""" with patch("httpx.Client") as mock_client_cls: mock_client = MagicMock() @@ -825,4 +808,5 @@ def test_magnite_connector_importable_from_connectors_package(self): def test_magnite_is_ssp_connector_subclass(self): """MagniteConnector is a subclass of SSPConnector.""" from ad_buyer.tools.deal_library.ssp_connector_base import SSPConnector + assert issubclass(MagniteConnector, SSPConnector) diff --git a/tests/unit/test_ssp_pubmatic.py b/tests/unit/test_ssp_pubmatic.py index 46ca0fe..632f2ab 100644 --- a/tests/unit/test_ssp_pubmatic.py +++ b/tests/unit/test_ssp_pubmatic.py @@ -386,7 +386,10 @@ def test_null_impressions_is_none(self): def test_notes_mapped_to_description(self): """notes → description.""" normalized = self.connector._normalize_deal(self._pg_deal()) - assert normalized["description"] == "Premium sports programming package with guaranteed delivery" # noqa: E501 + assert ( + normalized["description"] + == "Premium sports programming package with guaranteed delivery" + ) # noqa: E501 def test_null_notes_is_none(self): """null notes → description is None.""" @@ -463,7 +466,9 @@ def test_fetch_deals_deal_fields_correct(self): def test_fetch_deals_empty_response(self): """fetch_deals() handles empty deals list.""" transport = _MockTransport( - _make_response(200, {"status": "success", "total": 0, "page": 1, "page_size": 100, "deals": []}) # noqa: E501 + _make_response( + 200, {"status": "success", "total": 0, "page": 1, "page_size": 100, "deals": []} + ) # noqa: E501 ) connector = _connector_with_transport(transport) result = connector.fetch_deals() diff --git a/tests/unit/test_tool_return_type_hints.py b/tests/unit/test_tool_return_type_hints.py index 3434bdd..60c91ec 100644 --- a/tests/unit/test_tool_return_type_hints.py +++ b/tests/unit/test_tool_return_type_hints.py @@ -41,9 +41,7 @@ def _discover_tool_methods() -> list[tuple[str, str]]: method = getattr(cls, method_name) if not callable(method): continue - discovered.append( - (f"{cls.__module__}.{cls_name}.{method_name}", method_name) - ) + discovered.append((f"{cls.__module__}.{cls_name}.{method_name}", method_name)) return discovered diff --git a/tests/unit/test_tool_to_natural_language.py b/tests/unit/test_tool_to_natural_language.py index 58af840..35b2691 100644 --- a/tests/unit/test_tool_to_natural_language.py +++ b/tests/unit/test_tool_to_natural_language.py @@ -27,6 +27,7 @@ def client() -> UnifiedClient: # Registry coverage: every registered tool returns a non-empty string. # --------------------------------------------------------------------------- + def test_every_registered_tool_has_non_empty_mapping(client: UnifiedClient): """Each entry in _TOOL_NL_REGISTRY must produce a non-empty string.""" sample_args = { @@ -55,15 +56,14 @@ def test_every_registered_tool_has_non_empty_mapping(client: UnifiedClient): ) # Sanity: callable-backed entries should differ from no-arg fallback if callable(entry): - assert with_arg_msg != "", ( - f"Callable entry returned empty for {tool_name!r}" - ) + assert with_arg_msg != "", f"Callable entry returned empty for {tool_name!r}" # --------------------------------------------------------------------------- # Unknown tool fallback: never empty, never raises. # --------------------------------------------------------------------------- + def test_unknown_tool_returns_generic_fallback(client: UnifiedClient): msg = client._tool_to_natural_language("totally_made_up_tool", {}) assert msg @@ -72,9 +72,7 @@ def test_unknown_tool_returns_generic_fallback(client: UnifiedClient): def test_unknown_tool_with_args_includes_args(client: UnifiedClient): - msg = client._tool_to_natural_language( - "totally_made_up_tool", {"foo": "bar", "n": 3} - ) + msg = client._tool_to_natural_language("totally_made_up_tool", {"foo": "bar", "n": 3}) assert "totally_made_up_tool" in msg assert "foo=bar" in msg assert "n=3" in msg @@ -106,6 +104,7 @@ def test_none_args_is_safe(client: UnifiedClient): # Case insensitivity. # --------------------------------------------------------------------------- + @pytest.mark.parametrize( "name", ["list_products", "LIST_PRODUCTS", "List_Products", " list_products "], @@ -119,6 +118,7 @@ def test_case_insensitive_lookup(client: UnifiedClient, name: str): # Backward-compatible behavior with the previous test suite expectations. # --------------------------------------------------------------------------- + def test_create_account_includes_name_and_type(client: UnifiedClient): msg = client._tool_to_natural_language( "create_account", {"name": "TestCo", "type": "advertiser"} @@ -163,10 +163,9 @@ def test_get_by_id_tools_render_id(client: UnifiedClient): # preserved) but still returns a useful, non-empty message. # --------------------------------------------------------------------------- + def test_listing_tool_with_args_falls_through_to_generic(client: UnifiedClient): - msg = client._tool_to_natural_language( - "list_orders", {"accountId": "acct-9"} - ) + msg = client._tool_to_natural_language("list_orders", {"accountId": "acct-9"}) assert msg # Should mention either the tool name or the arg; generic renderer does both assert "list_orders" in msg