From b206eaa1fbeeabef869e2493bcf5b9313fd343b6 Mon Sep 17 00:00:00 2001 From: Marcin Zieba Date: Mon, 25 May 2026 11:00:33 +0200 Subject: [PATCH] feat: per-vendor fetch, export-diff, verify-images, remove-unmanaged-types New flags: - --export-diff / --export-diff-dir / --force-export-overwrite: compare NetBox device types against a local DTL repo and export any that differ or are absent, writing them as YAML files - --verify-images: physically check images are present on the NetBox server (not just in DB) and re-upload if missing or changed - --remove-unmanaged-types: remove device/module/rack types from NetBox that are no longer present in the repo (requires --remove-components) - --slugs fast-path: resolve matching files via upstream pickle indexes before scanning, narrowing the vendor list for faster runs Core additions: - core/export.py: Exporter class with parallel component fetch, repo YAML comparison, order-insensitive diff, manufacturer slug normalisation, deterministic sorted YAML loading, duplicate-key detection, and --force-export-overwrite guard - core/export_manifest.py: manifest helpers for tracking exported files - core/nb_serializer.py: NetBox-to-DTL YAML serialiser with legacy front-port fallback and numeric-string coercion - core/graphql_client.py: last_updated field; get_module_type_image* endpoints; GraphQLCountMismatchError no longer swallowed; per-thread sessions closed after ThreadPoolExecutor; slug stripping - core/netbox_api.py: atomic image-hash cache writes (tempfile+fsync+ os.replace); _try_delete_stale_attachment helper; verify-images HTTP check with same-host auth guard; _fmt_connection_error helper - core/repo.py: _vendor_slugs_from_pickle returns None when unavailable (distinguishes 'no matches' from 'hint missing') - nb-dt-import.py: _parse_vendor_types / _parse_vendor_racks helpers; _apply_slug_fast_path with None-safe module/rack vendor hint guards; _validate_argument_combinations rejects import-only flags with --export-diff; per-vendor Vendors progress counter; NoPulseBarColumn; cumulative task registry; deduplicated GraphQL page-size warning Security / robustness: - Upstream pickle loads are guarded (security) - Image-hash cache migrated from pickle to JSON - Mixed str/float name lists sorted safely - UnicodeDecodeError caught in load_manifest - ConnectionError caught mid-run - YAML loader except narrowed to (yaml.YAMLError, OSError) so duplicate- key ValueError propagates Tests: 929 tests, all passing --- .gitignore | 11 + .markdownlint.json | 3 +- .pre-commit-config.yaml | 4 - CHANGELOG.md | 2 +- README.md | 37 +- core/change_detector.py | 27 +- core/export.py | 908 ++++++++++ core/export_manifest.py | 51 + core/graphql_client.py | 406 +++-- core/nb_serializer.py | 298 ++++ core/netbox_api.py | 843 ++++++++-- core/repo.py | 191 +++ nb-dt-import.py | 829 ++++++--- tests/conftest.py | 10 + tests/test_change_detector.py | 65 +- tests/test_export_manifest.py | 127 ++ tests/test_exporter.py | 1395 ++++++++++++++++ tests/test_graphql_client.py | 664 +++++++- tests/test_nb_dt_import.py | 995 ++++++++--- tests/test_nb_serializer.py | 679 ++++++++ tests/test_netbox_api.py | 2231 +++++++++++++++++++++---- tests/test_outcomes.py | 16 + tests/test_repo.py | 459 ++++- tests/test_schema_reader.py | 7 + tests/test_update_failure_resolver.py | 11 + 25 files changed, 9251 insertions(+), 1018 deletions(-) create mode 100644 core/export.py create mode 100644 core/export_manifest.py create mode 100644 core/nb_serializer.py create mode 100644 tests/test_export_manifest.py create mode 100644 tests/test_exporter.py create mode 100644 tests/test_nb_serializer.py diff --git a/.gitignore b/.gitignore index 1b1766a18..67573fca1 100644 --- a/.gitignore +++ b/.gitignore @@ -137,3 +137,14 @@ CLAUDE.md CODE_REVIEW_PLAN.md repo + +# Image hash cache for --verify-images (local state, not for commit) +tests/known-image-hashes.pickle +tests/known-image-hashes.json + +# Superpowers skill/spec files (local tooling, not for commit) +docs/superpowers/ + +# Export-diff output directory and manifest (local state, not for commit) +extra/ +**/.export-manifest.json diff --git a/.markdownlint.json b/.markdownlint.json index 38f1ed817..e6342f41b 100644 --- a/.markdownlint.json +++ b/.markdownlint.json @@ -1,6 +1,7 @@ { "MD013": { - "line_length": 120 + "line_length": 120, + "tables": false }, "MD024": { "siblings_only": true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f4f29299e..9d3046ff7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,8 +1,4 @@ repos: - - repo: https://github.com/igorshubovych/markdownlint-cli - rev: v0.47.0 - hooks: - - id: markdownlint - repo: local hooks: - id: ruff-check diff --git a/CHANGELOG.md b/CHANGELOG.md index 19eb22c72..7023d795f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,7 +50,7 @@ ### Features -- Schema-driven property comparison for device/module types (#64, [`2af54a0`](https://github.com/marcinpsk/Device-Type-Library-Import/commit/2af54a09255225f75457d159bbb6c5afbdf0f1e7)) +- `--verify-images`: verify image physical presence and content hash, re-upload if missing or changed - Full component comparison for module types with description/color/rf_role coverage (#64, [`2af54a0`](https://github.com/marcinpsk/Device-Type-Library-Import/commit/2af54a09255225f75457d159bbb6c5afbdf0f1e7)) - Validate graphql component fetch counts against rest api (#64, [`2af54a0`](https://github.com/marcinpsk/Device-Type-Library-Import/commit/2af54a09255225f75457d159bbb6c5afbdf0f1e7)) - Graphql count mismatch retry logic, 100% docstring coverage (#64, [`2af54a0`](https://github.com/marcinpsk/Device-Type-Library-Import/commit/2af54a09255225f75457d159bbb6c5afbdf0f1e7)) diff --git a/README.md b/README.md index 7b4f8359b..b42c5ab5f 100644 --- a/README.md +++ b/README.md @@ -119,7 +119,9 @@ uv run nb-dt-import.py --vendors "Palo Alto" --slugs 440 | `--only-new` | off | Only create new types, skip all existing ones (mutually exclusive with `--update`) | | `--update` | off | Update existing types with changes from the repo (mutually exclusive with `--only-new`) | | `--remove-components` | off | Delete components missing from YAML when used with `--update`. **Destructive.** | +| `--remove-unmanaged-types` | off | Also delete components whose entire YAML section is missing (e.g. NetBox has interfaces but YAML defines none). Requires `--remove-components`. **Aggressive.** | | `--force-resolve-conflicts` | off | Automatically resolve NetBox constraint failures during `--update`. **Destructive.** See below. | +| `--verify-images` | off | Verify images recorded in NetBox are physically present on the server. Uses an HTTP presence check per image and a local SHA-256 cache to detect local file changes (does not hash the remote file). Re-uploads any image that is missing on the server or whose local file has changed. Useful after recreating a devcontainer or updating local image files. **Makes one HTTP request per image.** | #### Update Mode @@ -162,6 +164,15 @@ no longer present in the YAML definition. - Components attached to actual device instances may prevent deletion - Review the change detection report before enabling component removal - Test on a staging NetBox instance first if possible +- By default, `--remove-components` only removes components from YAML sections that are + *present but no longer list a given component*. If a YAML omits an entire section + (for example, a chassis with no `interfaces:` key), pre-existing NetBox interfaces are + left untouched. Add `--remove-unmanaged-types` to treat a missing section the same as an + empty list and remove every component of that type from NetBox. + +```shell +uv run nb-dt-import.py --update --remove-components --remove-unmanaged-types +``` #### Conflict Resolution (Use with Caution) @@ -196,7 +207,31 @@ uv run nb-dt-import.py --update --force-resolve-conflicts - After converting device types from parent to child (or vice versa) - When the script reports constraint failures that block property updates -## Contributing +#### Image Verification (`--verify-images`) + +By default, the script skips uploading images that already have a URL recorded in the NetBox +database. This means physically missing images (e.g. after recreating a devcontainer) or updated +local image files are not re-uploaded. Use `--verify-images` to re-check: + +```shell +uv run nb-dt-import.py --vendors nokia --verify-images +``` + +**What it does**: + +- For each device type / module type whose image is already recorded in NetBox, issues an HTTP + GET to verify the file is physically accessible on the server +- Compares the local file's SHA-256 hash against a persistent local cache (the remote file is + **not** downloaded or hashed; a 2xx HTTP response is treated as "present") +- Re-uploads the image if it is **missing** (server returned a non-2xx response) or + **changed** (the local file's hash differs from the cached value recorded at last upload) + +**When to use**: + +- After recreating a devcontainer or restoring NetBox without its media volume — the database + still knows about images, but the files are gone +- After replacing a local image file with a higher-quality version and wanting NetBox to pick + it up We're happy about any pull requests! diff --git a/core/change_detector.py b/core/change_detector.py index 6fa273df3..c819b9364 100644 --- a/core/change_detector.py +++ b/core/change_detector.py @@ -166,15 +166,20 @@ def get_device_type_properties(): class ChangeDetector: """Detects changes between YAML device types and NetBox cached data.""" - def __init__(self, device_types_instance, handle): + def __init__(self, device_types_instance, handle, remove_unmanaged_types: bool = False): """Initialize the change detector. Args: device_types_instance: DeviceTypes instance with cached data handle: LogHandler for logging + remove_unmanaged_types: When True, propose removal of components whose entire + YAML section is missing (not just those listed in an empty/partial section). + Only honoured when callers also pass ``remove_components=True`` to the + applier; this flag controls *detection*, not application. """ self.device_types = device_types_instance self.handle = handle + self.remove_unmanaged_types = remove_unmanaged_types def detect_changes(self, device_types: List[dict], progress=None) -> ChangeReport: """Analyze all device types and generate a change report. @@ -333,8 +338,11 @@ def _compare_components( # Check for removed components (exist in NetBox but not in YAML) # Only flag removals when the YAML explicitly defines this component type; - # a missing key means the YAML doesn't manage this type at all. - if yaml_key in yaml_data: + # a missing key normally means the YAML doesn't manage this type at all. + # When remove_unmanaged_types is True, the missing-key case is treated the + # same as an empty list so chassis YAMLs that omit (e.g.) interfaces can + # still drive cleanup of stale templates in NetBox. + if yaml_key in yaml_data or self.remove_unmanaged_types: for existing_name in existing_components.keys(): if existing_name not in yaml_component_names: changes.append( @@ -571,7 +579,18 @@ def _log_modified_device_details(self, dt: DeviceTypeChange): self.handle.verbose_log(f" - {comp.component_type}: {comp.component_name}") def log_change_report(self, report: ChangeReport): - """Log the change report in a clear, readable format.""" + """Log the change report in a clear, readable format. + + Suppresses the full banner when there are no new or modified types — + emits a single verbose-level summary instead to avoid flooding the + terminal with empty reports during a multi-vendor run. + """ + has_changes = report.new_device_types or report.modified_device_types + if not has_changes: + if report.unchanged_count: + self.handle.verbose_log(f"No device type changes ({report.unchanged_count} unchanged).") + return + self.handle.log("=" * 60) self.handle.log("CHANGE DETECTION REPORT") self.handle.log("=" * 60) diff --git a/core/export.py b/core/export.py new file mode 100644 index 000000000..064aae557 --- /dev/null +++ b/core/export.py @@ -0,0 +1,908 @@ +"""Export-diff feature: export NetBox types absent from or differing vs. the local repo. + +Entry point: ``Exporter(settings, handle, export_dir, force_overwrite, vendor_slugs).run()`` +""" + +import hashlib +import os +import re +import threading +from dataclasses import dataclass +from pathlib import Path +from typing import Any, List, Optional + +import requests +import yaml + +from core.export_manifest import ( + is_entry_fresh, + load_manifest, + save_manifest, + update_entry, +) +from core.graphql_client import NetBoxGraphQLClient +from core.nb_serializer import ( + COMPONENT_ENDPOINTS, + serialize_device_type, + serialize_module_type, + serialize_rack_type, +) +from core.netbox_api import IMAGE_EXTENSIONS, _build_auth_header + +_SKIP = object() # sentinel: image already exists, no download needed + +# Maps Content-Type to a canonical extension for extension-less attachments. +_CONTENT_TYPE_EXT = { + "image/png": ".png", + "image/jpeg": ".jpg", + "image/gif": ".gif", + "image/webp": ".webp", + "image/bmp": ".bmp", + "image/tiff": ".tiff", + "image/svg+xml": ".svg", +} + + +def _canon_mfr_slug(mfr: Any) -> str: + """Return a canonical lowercase slug for a manufacturer value from a repo YAML. + + Handles three forms produced by ``yaml.safe_load()``: + - ``{"slug": "nokia"}`` → ``"nokia"`` + - ``{"name": "Nokia", "slug": "nokia"}`` → ``"nokia"`` + - ``{"name": "Nokia"}`` → ``"nokia"`` (derived from name) + - ``"Nokia"`` → ``"nokia"`` (plain string) + """ + if isinstance(mfr, dict): + raw = mfr.get("slug") or mfr.get("name") or "" + elif isinstance(mfr, str): + raw = mfr + else: + return "" + return re.sub(r"[^a-z0-9]+", "-", raw.lower()).strip("-") + + +def _sanitize_attachment_filename(att_name: str, url_path: str, content_type: str) -> str: + """Return a safe, extension-bearing filename for a NetBox image attachment. + + 1. Strips any directory components (prevents path traversal). + 2. If the name already carries a recognised image extension, returns it as-is. + 3. Otherwise derives an extension from *content_type* (preferred) or from + the URL suffix, falling back to ``.bin`` if nothing matches. + """ + safe = Path(att_name).name # strip leading "../" or "subdir/" + ext = Path(safe).suffix.lower() + if ext in IMAGE_EXTENSIONS: + return safe + # Try to derive extension from Content-Type + ct_base = content_type.split(";")[0].strip().lower() + derived = _CONTENT_TYPE_EXT.get(ct_base) + if derived is None: + # Fall back to URL path suffix + url_ext = Path(url_path.split("?")[0]).suffix.lower() + derived = url_ext if url_ext in IMAGE_EXTENSIONS else ".bin" + return safe + derived if safe else f"attachment{derived}" + + +def _make_filename(model: str) -> str: + """Sanitize *model* into a valid flat filename (no path separators, no spaces). + + Replaces spaces, forward/back-slashes with dashes and collapses duplicate dashes. + Preserves original casing to stay consistent with DTL conventions. + """ + name = re.sub(r"[ /\\]", "-", model) + name = re.sub(r"-{2,}", "-", name) + return name.strip("-") + + +@dataclass +class ExportItem: + """A single type that should be written to the export directory.""" + + kind: str # "device-type" | "module-type" | "rack-type" + nb_record: Any + repo_yaml: Optional[dict] # None when absent from repo + serialized: dict # What we will write + reason: str # "absent" | "differs" | "images-missing" + mfr_name: str + filename: str # e.g. "nokia-7750-sr-7s.yaml" + manifest_key: str # e.g. "Nokia/nokia-7750-sr-7s" + + +def _normalize_for_compare(obj: Any) -> Any: + """Recursively normalize a dict/list for equality comparison. + + - float with integer value → int (handles u_height=1.0 vs u_height=1) + - empty string → None + - lists of named components → sorted by ``name`` (DTL ordering is cosmetic) + """ + if isinstance(obj, dict): + return {k: _normalize_for_compare(v) for k, v in obj.items()} + if isinstance(obj, list): + normalized = [_normalize_for_compare(item) for item in obj] + if normalized and all(isinstance(i, dict) and "name" in i for i in normalized): + normalized.sort(key=lambda d: str(d["name"])) + return normalized + if isinstance(obj, float) and not isinstance(obj, bool) and obj.is_integer(): + return int(obj) + if isinstance(obj, str) and obj == "": + return None + return obj + + +def _yaml_equal(a: dict, b: dict) -> bool: + """Return True when *a* and *b* are semantically equal YAML dicts.""" + return _normalize_for_compare(a) == _normalize_for_compare(b) + + +def _repo_supersedes(repo_yaml: dict, nb_serialized: dict) -> bool: + """Return True when *repo_yaml* already contains every field NetBox would write. + + Used to suppress exports where the repo is a strict superset of NetBox: if + every field/value in ``nb_serialized`` is also present (and equal, after + normalization) in ``repo_yaml``, the export would only delete information + (e.g. drop the repo-only ``profile`` field). In that case the export is + skipped — the repo is considered the better source of truth. + + Lists of named components are matched element-wise by ``name``: each NB + component must be present in the repo with the same fields/values, but the + repo may carry additional components or extra per-component fields. + """ + + # Normalize manufacturer to canonical slug so dict-form repo values + # (e.g. {name: Nokia, slug: Nokia}) compare equal to NB's plain slug string. + def _norm_mfr(d: dict) -> dict: + if "manufacturer" not in d: + return d + return {**d, "manufacturer": _canon_mfr_slug(d["manufacturer"])} + + nrepo = _normalize_for_compare(_norm_mfr(repo_yaml)) + nnb = _normalize_for_compare(_norm_mfr(nb_serialized)) + return _is_subset(nnb, nrepo) + + +def _is_subset(sub: Any, sup: Any) -> bool: + """Return True if every leaf value in *sub* is present and equal in *sup*.""" + if isinstance(sub, dict): + if not isinstance(sup, dict): + return False + for k, v in sub.items(): + if k not in sup: + return False + if not _is_subset(v, sup[k]): + return False + return True + if isinstance(sub, list): + if not isinstance(sup, list): + return False + # If items are named components, match by name; otherwise require + # exact equality (positional list). + if sub and all(isinstance(i, dict) and "name" in i for i in sub): + sup_by_name = {i["name"]: i for i in sup if isinstance(i, dict) and "name" in i} + for item in sub: + other = sup_by_name.get(item["name"]) + if other is None or not _is_subset(item, other): + return False + return True + return sub == sup + return sub == sup + + +class Exporter: + """Exports NetBox device/module/rack types to a local directory in DTL format.""" + + def __init__(self, settings, handle, export_dir: str, force_overwrite: bool, vendor_slugs: Optional[List[str]]): + """Initialize the Exporter with settings and configuration.""" + self.settings = settings + self.handle = handle + self.export_dir = Path(export_dir) + self.force_overwrite = force_overwrite + self.vendor_slugs = vendor_slugs # None means all vendors + self.repo_path = Path(settings.REPO_PATH) + self.base_url = settings.NETBOX_URL.rstrip("/") + self.token = settings.NETBOX_TOKEN + self.ignore_ssl = settings.IGNORE_SSL_ERRORS + self.graphql = NetBoxGraphQLClient( + url=settings.NETBOX_URL, + token=settings.NETBOX_TOKEN, + ignore_ssl=settings.IGNORE_SSL_ERRORS, + ) + self._module_image_details: Optional[dict] = None + + def _get_module_image_details(self) -> dict: + """Return module type image details, fetching from NetBox at most once per run.""" + if self._module_image_details is None: + self._module_image_details = self.graphql.get_module_type_image_details() + return self._module_image_details + + def run(self, progress=None) -> None: + """Run the export-diff workflow.""" + self._module_image_details = None + self._verify_export_dir_writable() + manifest_path = self.export_dir / ".export-manifest.json" + manifest = load_manifest(manifest_path) + + scope = ( + f" for {len(self.vendor_slugs)} vendor(s): {', '.join(self.vendor_slugs)}" + if self.vendor_slugs + else " (all vendors)" + ) + self.handle.log(f"Export-diff: fetching NetBox device/module/rack types{scope}") + + # ── Fetch all types from NetBox ────────────────────────────────────── + by_model, by_slug = self.graphql.get_device_types( + manufacturer_slugs=self.vendor_slugs if self.vendor_slugs else None + ) + all_mt = self.graphql.get_module_types(manufacturer_slugs=self.vendor_slugs if self.vendor_slugs else None) + all_rt = self.graphql.get_rack_types(manufacturer_slugs=self.vendor_slugs if self.vendor_slugs else None) + + total_dt = len(by_model) + total_mt = sum(len(v) for v in all_mt.values()) + total_rt = sum(len(v) for v in all_rt.values()) + self.handle.log( + f"Fetched type metadata: {total_dt} device-types, " + f"{total_mt} module-types, {total_rt} rack-types. " + f"Component templates fetched per vendor below " + f"({len(COMPONENT_ENDPOINTS)} endpoints/vendor)." + ) + + # ── Load repo YAML dicts ───────────────────────────────────────────── + repo_dt_by_slug = self._load_repo_device_types() + repo_mt_by_key = self._load_repo_module_types() + repo_rt_by_key = self._load_repo_rack_types() + self.handle.verbose_log( + f"Loaded repo: {len(repo_dt_by_slug)} device-types, " + f"{len(repo_mt_by_key)} module-types, {len(repo_rt_by_key)} rack-types" + ) + + # Collect all vendor slugs that have device types OR module types + dt_by_vendor: dict = {} + for (mfr_slug, _model), record in by_model.items(): + dt_by_vendor.setdefault(mfr_slug, []).append(record) + + mt_by_vendor: dict = {} + for mfr_slug, models in all_mt.items(): + mt_by_vendor[mfr_slug] = list(models.values()) + + all_vendor_slugs = sorted(set(dt_by_vendor) | set(mt_by_vendor)) + items, skipped_fresh = self._compare_vendors_to_items( + all_vendor_slugs=all_vendor_slugs, + dt_by_vendor=dt_by_vendor, + mt_by_vendor=mt_by_vendor, + manifest=manifest, + repo_dt_by_slug=repo_dt_by_slug, + repo_mt_by_key=repo_mt_by_key, + progress=progress, + ) + rack_items, rack_skipped_fresh = self._compare_racks_to_items( + all_rt=all_rt, + manifest=manifest, + repo_rt_by_key=repo_rt_by_key, + progress=progress, + ) + items.extend(rack_items) + skipped_fresh += rack_skipped_fresh + + if skipped_fresh: + self.handle.verbose_log(f"Skipped {skipped_fresh} record(s) unchanged since last export (manifest fresh)") + + if not items: + self.handle.log( + "Nothing to export: every NetBox type is already represented in the repo " + "(or fresh in manifest). Use --force-export-overwrite or delete the manifest " + "to re-check." + ) + save_manifest(manifest_path, manifest) + return + + self._write_export_items(items, manifest, manifest_path, progress) + + def _compare_vendors_to_items( + self, + all_vendor_slugs, + dt_by_vendor, + mt_by_vendor, + manifest, + repo_dt_by_slug, + repo_mt_by_key, + progress, + ) -> tuple[List[ExportItem], int]: + """Compare stale device/module types per vendor and return export items.""" + items: List[ExportItem] = [] + skipped_fresh = 0 + compare_task = ( + progress.add_task("Comparing vendors", total=len(all_vendor_slugs)) + if progress is not None and all_vendor_slugs + else None + ) + + for mfr_slug in all_vendor_slugs: + stale_dts = [] + for record in dt_by_vendor.get(mfr_slug, []): + if is_entry_fresh( + manifest, + "device-types", + f"{record.manufacturer.name}/{record.slug}", + record.last_updated, + ): + skipped_fresh += 1 + continue + stale_dts.append(record) + + stale_mts = [] + for record in mt_by_vendor.get(mfr_slug, []): + if is_entry_fresh( + manifest, + "module-types", + f"{record.manufacturer.name}/{record.model}", + record.last_updated, + ): + skipped_fresh += 1 + continue + stale_mts.append(record) + + if not stale_dts and not stale_mts: + if compare_task is not None: + progress.advance(compare_task) + continue + + self.handle.verbose_log( + f" {mfr_slug}: {len(stale_dts)} device-type(s), " + f"{len(stale_mts)} module-type(s) to compare; " + f"fetching {len(COMPONENT_ENDPOINTS)} component-template endpoints…" + ) + dt_components, mt_components = self._fetch_vendor_components(mfr_slug) + + for record in stale_dts: + items.extend( + self._determine_export_set_for_device_types( + nb_records=[record], + repo_dt_by_slug=repo_dt_by_slug, + components_by_dt_id=dt_components, + ) + ) + + for record in stale_mts: + items.extend( + self._determine_export_set_for_module_types( + nb_records=[record], + repo_mt_by_key=repo_mt_by_key, + components_by_mt_id=mt_components, + ) + ) + + if compare_task is not None: + progress.advance(compare_task) + + return items, skipped_fresh + + def _compare_racks_to_items(self, all_rt, manifest, repo_rt_by_key, progress) -> tuple[List[ExportItem], int]: + """Compare stale rack types and return export items.""" + items: List[ExportItem] = [] + skipped_fresh = 0 + rack_records = [record for models in all_rt.values() for record in models.values()] + rack_task = ( + progress.add_task("Comparing rack types", total=len(rack_records)) + if progress is not None and rack_records + else None + ) + + for record in rack_records: + if is_entry_fresh( + manifest, + "rack-types", + f"{record.manufacturer.name}/{record.model}", + record.last_updated, + ): + skipped_fresh += 1 + else: + items.extend( + self._determine_export_set_for_rack_types( + nb_records=[record], + repo_rt_by_key=repo_rt_by_key, + ) + ) + + if rack_task is not None: + progress.advance(rack_task) + + return items, skipped_fresh + + def _write_export_items(self, items, manifest, manifest_path, progress) -> None: + """Write export items, update the manifest, and log the final summary.""" + absent = sum(1 for item in items if item.reason == "absent") + differs = sum(1 for item in items if item.reason == "differs") + img_missing = sum(1 for item in items if item.reason == "images-missing") + self.handle.log( + f"Will export {len(items)} item(s) to {self.export_dir}: " + f"{absent} absent, {differs} differs, {img_missing} images-missing" + ) + + write_task = progress.add_task("Writing exports", total=len(items)) if progress is not None else None + written_count = 0 + skipped_overwrite = 0 + for item in items: + self.handle.verbose_log(f"Export [{item.reason}] {item.kind}: {item.mfr_name}/{item.filename}") + subdir = { + "device-type": "device-types", + "module-type": "module-types", + "rack-type": "rack-types", + }[item.kind] + dest = self.export_dir / subdir / item.mfr_name / item.filename + # For "differs" items that have a repo counterpart, preserve repo-only + # top-level fields (e.g. comments, profile) that NetBox does not return + # in its serialized output. Component lists are left as NB authoritative. + to_write = item.serialized + if item.reason == "differs" and item.repo_yaml: + # Only preserve scalar/metadata repo fields not present in the NB output. + # Exclude list-valued keys (component sections such as interfaces, power-ports, + # console-ports, etc.) so that NB remains authoritative for all components. + extra = { + k: v for k, v in item.repo_yaml.items() if k not in item.serialized and not isinstance(v, list) + } + if extra: + to_write = {**item.serialized, **extra} + written = self._write_yaml(dest, to_write) + if not written: + skipped_overwrite += 1 + self.handle.log( + f"[yellow]Skipped (overwrite guard): {dest}. Use --force-export-overwrite to overwrite.[/yellow]" + ) + if write_task is not None: + progress.advance(write_task) + continue + + written_count += 1 + images_ok = self._download_type_images(item) + if images_ok: + update_entry(manifest, f"{item.kind}s", item.manifest_key, item.nb_record.last_updated) + if write_task is not None: + progress.advance(write_task) + + save_manifest(manifest_path, manifest) + self.handle.log( + f"Export-diff complete: wrote {written_count} file(s)" + + (f", skipped {skipped_overwrite} (overwrite guard)" if skipped_overwrite else "") + ) + + # ── Directory helpers ──────────────────────────────────────────────────── + + def _verify_export_dir_writable(self) -> None: + """Raise PermissionError if export dir cannot be created or written to.""" + self.export_dir.mkdir(parents=True, exist_ok=True) + if not os.access(self.export_dir, os.W_OK): + raise PermissionError(f"Export directory {self.export_dir} is not writable") + + # ── Repo loading ───────────────────────────────────────────────────────── + + def _vendor_dirs(self, root: Path): + """Yield child dirs of *root*, optionally filtered by ``self.vendor_slugs``. + + Matches against the directory name converted to a slug (lowercase, + non-alphanumeric runs replaced with ``-``) so that directories like + ``Extreme Networks`` match the CLI slug ``extreme-networks``. + """ + if not root.exists(): + return + + def _to_slug(name: str) -> str: + return re.sub(r"[^a-z0-9]+", "-", name.lower()).strip("-") + + if self.vendor_slugs: + wanted = {_to_slug(v) for v in self.vendor_slugs} + for d in root.iterdir(): + if d.is_dir() and _to_slug(d.name) in wanted: + yield d + else: + for d in root.iterdir(): + if d.is_dir(): + yield d + + def _load_repo_device_types(self) -> dict: + """Return ``{(mfr_slug, slug): yaml_dict}`` for repo device types (filtered by vendor).""" + result: dict = {} + seen_files: dict = {} # key -> Path that produced it + for vdir in self._vendor_dirs(self.repo_path / "device-types"): + mfr_slug = re.sub(r"[^a-z0-9]+", "-", vdir.name.lower()).strip("-") + yaml_files = sorted(set(vdir.rglob("*.yaml")) | set(vdir.rglob("*.yml"))) + for yaml_file in yaml_files: + try: + data = yaml.safe_load(yaml_file.read_text(encoding="utf-8")) + except (yaml.YAMLError, OSError) as exc: + self.handle.verbose_log(f"[yellow]Skipping malformed YAML {yaml_file}: {exc}[/yellow]") + continue + if isinstance(data, dict) and "slug" in data: + key = (mfr_slug, data["slug"]) + if key in result: + raise ValueError(f"Duplicate repo device-type key {key!r}: {seen_files[key]} and {yaml_file}") + result[key] = data + seen_files[key] = yaml_file + return result + + def _load_repo_module_types(self) -> dict: + """Return ``{(mfr_slug, model): yaml_dict}`` for repo module types (filtered by vendor).""" + result: dict = {} + seen_files: dict = {} + for vdir in self._vendor_dirs(self.repo_path / "module-types"): + yaml_files = sorted(set(vdir.rglob("*.yaml")) | set(vdir.rglob("*.yml"))) + for yaml_file in yaml_files: + try: + data = yaml.safe_load(yaml_file.read_text(encoding="utf-8")) + except (yaml.YAMLError, OSError) as exc: + self.handle.verbose_log(f"[yellow]Skipping malformed YAML {yaml_file}: {exc}[/yellow]") + continue + if isinstance(data, dict) and "model" in data and "manufacturer" in data: + mfr_slug = _canon_mfr_slug(data["manufacturer"]) + if mfr_slug: + key = (mfr_slug, data["model"]) + if key in result: + raise ValueError( + f"Duplicate repo module-type key {key!r}: {seen_files[key]} and {yaml_file}" + ) + result[key] = data + seen_files[key] = yaml_file + return result + + def _load_repo_rack_types(self) -> dict: + """Return ``{(mfr_slug, model): yaml_dict}`` for repo rack types (filtered by vendor).""" + result: dict = {} + seen_files: dict = {} + for vdir in self._vendor_dirs(self.repo_path / "rack-types"): + yaml_files = sorted(set(vdir.rglob("*.yaml")) | set(vdir.rglob("*.yml"))) + for yaml_file in yaml_files: + try: + data = yaml.safe_load(yaml_file.read_text(encoding="utf-8")) + except (yaml.YAMLError, OSError) as exc: + self.handle.verbose_log(f"[yellow]Skipping malformed YAML {yaml_file}: {exc}[/yellow]") + continue + if isinstance(data, dict) and "model" in data and "manufacturer" in data: + mfr_slug = _canon_mfr_slug(data["manufacturer"]) + if mfr_slug: + key = (mfr_slug, data["model"]) + if key in result: + raise ValueError(f"Duplicate repo rack-type key {key!r}: {seen_files[key]} and {yaml_file}") + result[key] = data + seen_files[key] = yaml_file + return result + + # ── Component fetching ─────────────────────────────────────────────────── + + def _fetch_vendor_components(self, mfr_slug: str) -> tuple: + """Fetch component templates for *mfr_slug* and group by type id. + + Returns ``(dt_components, mt_components)`` where each is + ``{type_id: {endpoint_name: [records]}}``. + + Keeps device-type and module-type IDs in separate dicts to prevent + collisions (both use PostgreSQL auto-increment, so id=5 can exist in + both dcim_devicetype and dcim_moduletype simultaneously). + + The 9 endpoint queries are issued concurrently — for a vendor with + many records (e.g. Juniper) this turns ~40s of sequential paginated + GraphQL calls into ~5s. + """ + from concurrent.futures import ThreadPoolExecutor + + dt_result: dict = {} + mt_result: dict = {} + + # Each worker gets its own GraphQL client to avoid sharing a single + # requests.Session across threads (Session is not thread-safe). + _thread_local = threading.local() + _clients: list = [] + _clients_lock = threading.Lock() + + def _fetch_one(endpoint_name): + if not getattr(_thread_local, "graphql", None): + client = NetBoxGraphQLClient( + self.graphql.url, + self.graphql.token, + self.graphql.ignore_ssl, + self.graphql._log_handler, + self.graphql.DEFAULT_PAGE_SIZE, + ) + _thread_local.graphql = client + with _clients_lock: + _clients.append(client) + return endpoint_name, _thread_local.graphql.get_component_templates( + endpoint_name, manufacturer_slug=mfr_slug + ) + + try: + with ThreadPoolExecutor(max_workers=len(COMPONENT_ENDPOINTS)) as pool: + results = list(pool.map(_fetch_one, [ep_name for _, ep_name in COMPONENT_ENDPOINTS])) + finally: + for client in _clients: + try: + client.close() + except Exception: + pass + + for endpoint_name, records in results: + for rec in records: + dt = getattr(rec, "device_type", None) + mt = getattr(rec, "module_type", None) + if dt and getattr(dt, "id", None): + dt_result.setdefault(dt.id, {}).setdefault(endpoint_name, []).append(rec) + if mt and getattr(mt, "id", None): + mt_result.setdefault(mt.id, {}).setdefault(endpoint_name, []).append(rec) + return dt_result, mt_result + + # ── Export set determination ───────────────────────────────────────────── + + def _determine_export_set_for_device_types( + self, nb_records: list, repo_dt_by_slug: dict, components_by_dt_id: dict + ) -> List[ExportItem]: + items = [] + for rec in nb_records: + serialized = serialize_device_type(rec, components_by_dt_id) + mfr_name = rec.manufacturer.name + mfr_slug = rec.manufacturer.slug + filename = f"{_make_filename(rec.model)}.yaml" + manifest_key = f"{mfr_name}/{rec.slug}" + + repo_yaml = repo_dt_by_slug.get((mfr_slug, rec.slug)) + if repo_yaml is None: + reason = "absent" + elif _repo_supersedes(repo_yaml, serialized): + reason = self._check_missing_images(rec.front_image, rec.rear_image, mfr_name, rec.slug) + if reason is None: + continue + else: + reason = "differs" + + items.append( + ExportItem( + kind="device-type", + nb_record=rec, + repo_yaml=repo_yaml, + serialized=serialized, + reason=reason, + mfr_name=mfr_name, + filename=filename, + manifest_key=manifest_key, + ) + ) + return items + + def _determine_export_set_for_module_types( + self, nb_records: list, repo_mt_by_key: dict, components_by_mt_id: dict + ) -> List[ExportItem]: + items = [] + for rec in nb_records: + serialized = serialize_module_type(rec, components_by_mt_id) + mfr_name = rec.manufacturer.name + mfr_slug = rec.manufacturer.slug + filename = f"{_make_filename(rec.model)}.yaml" + manifest_key = f"{mfr_name}/{rec.model}" + + repo_yaml = repo_mt_by_key.get((mfr_slug, rec.model)) + if repo_yaml is None: + reason = "absent" + elif _repo_supersedes(repo_yaml, serialized): + continue + else: + reason = "differs" + + items.append( + ExportItem( + kind="module-type", + nb_record=rec, + repo_yaml=repo_yaml, + serialized=serialized, + reason=reason, + mfr_name=mfr_name, + filename=filename, + manifest_key=manifest_key, + ) + ) + return items + + def _determine_export_set_for_rack_types(self, nb_records: list, repo_rt_by_key: dict) -> List[ExportItem]: + items = [] + for rec in nb_records: + serialized = serialize_rack_type(rec) + mfr_name = rec.manufacturer.name + mfr_slug = rec.manufacturer.slug + filename = f"{_make_filename(rec.model)}.yaml" + manifest_key = f"{mfr_name}/{rec.model}" + + repo_yaml = repo_rt_by_key.get((mfr_slug, rec.model)) + if repo_yaml is None: + reason = "absent" + elif _repo_supersedes(repo_yaml, serialized): + continue + else: + reason = "differs" + + items.append( + ExportItem( + kind="rack-type", + nb_record=rec, + repo_yaml=repo_yaml, + serialized=serialized, + reason=reason, + mfr_name=mfr_name, + filename=filename, + manifest_key=manifest_key, + ) + ) + return items + + def _check_missing_images(self, front_url, rear_url, mfr_name: str, slug: str) -> Optional[str]: + """Return ``'images-missing'`` if any expected local image is absent; else None. + + DTL stores images under ``elevation-images//.{front,rear}.{png,jpg,jpeg,gif}`` + so we accept any of those extensions when probing the repo. + """ + img_dir = self.repo_path / "elevation-images" / mfr_name + exts = tuple(ext.lstrip(".") for ext in IMAGE_EXTENSIONS) + if front_url and not any((img_dir / f"{slug}.front.{e}").exists() for e in exts): + return "images-missing" + if rear_url and not any((img_dir / f"{slug}.rear.{e}").exists() for e in exts): + return "images-missing" + return None + + # ── File writing ───────────────────────────────────────────────────────── + + def _write_yaml(self, dest: Path, data: dict) -> bool: + """Write *data* as YAML to *dest* with overwrite guard. + + Returns True on write success, False when the overwrite guard blocked. + """ + dest.parent.mkdir(parents=True, exist_ok=True) + content = yaml.dump(data, default_flow_style=False, allow_unicode=True, sort_keys=False) + if dest.exists(): + try: + existing = dest.read_text(encoding="utf-8") + except (UnicodeDecodeError, OSError): + existing = None # treat as different content + if existing == content: + return True # same content — no need to overwrite + if not self.force_overwrite: + return False # blocked by overwrite guard + dest.write_text(content, encoding="utf-8") + return True + + # ── Image downloading ──────────────────────────────────────────────────── + + def _download_type_images(self, item: ExportItem) -> bool: + """Download images for *item*. Returns True if all downloads succeeded.""" + if item.kind == "device-type": + return self._download_device_type_images(item) + elif item.kind == "module-type": + return self._download_module_type_images(item) + return True # rack types have no images + + def _download_device_type_images(self, item: ExportItem) -> bool: + img_dir = self.export_dir / "elevation-images" / item.mfr_name + ok = True + for suffix, url_path in ( + ("front", item.nb_record.front_image), + ("rear", item.nb_record.rear_image), + ): + if not url_path: + continue + url_ext = Path(url_path.split("?")[0]).suffix.lower() + ext = url_ext if url_ext in IMAGE_EXTENSIONS else ".png" + content_type_out: list[str] = [] + dest = img_dir / f"{item.nb_record.slug}.{suffix}{ext}" + result = self._download_image(url_path, dest, content_type_out=content_type_out) + # If the URL carried no recognised extension, try to rename based on Content-Type. + if result not in (None, _SKIP) and url_ext not in IMAGE_EXTENSIONS: + ct = content_type_out[0] if content_type_out else "" + new_name = _sanitize_attachment_filename(f"{item.nb_record.slug}.{suffix}", url_path, ct) + new_dest = img_dir / new_name + if new_dest != dest: + try: + if new_dest.exists() and not self.force_overwrite: + dest.unlink(missing_ok=True) + else: + dest.replace(new_dest) + except OSError as exc: + self.handle.verbose_log(f"Could not rename {dest.name!r} → {new_dest.name!r}: {exc}") + ok = False + if result is None: # actual failure + ok = False + return ok + + def _download_module_type_images(self, item: ExportItem) -> bool: + """Download image attachments for a module type. + + Sanitizes each attachment filename to: + - Strip directory components (prevents path traversal). + - Ensure a recognised image extension (derived from the URL suffix or + the response Content-Type header when the name carries none). + """ + try: + details = self._get_module_image_details() + except Exception as exc: + self.handle.log(f"[yellow]Could not fetch module image details: {exc}[/yellow]") + return False + type_images = details.get(item.nb_record.id, {}) + img_dir = self.export_dir / "module-images" / item.mfr_name + ok = True + for att_name, att in type_images.items(): + url_path = att.get("url") if isinstance(att, dict) else getattr(att, "url", None) + if not url_path: + continue + + # First pass: sanitize using URL suffix (no extra HTTP request). + content_type_out: list[str] = [] + safe_name = _sanitize_attachment_filename(att_name, url_path, "") + dest = img_dir / safe_name + + # Path-escape guard: resolved dest must remain under img_dir. + try: + dest.resolve().relative_to(img_dir.resolve()) + except ValueError: + self.handle.log(f"[yellow]Skipping attachment with unsafe path: {att_name!r}[/yellow]") + ok = False + continue + + result = self._download_image(url_path, dest, content_type_out=content_type_out) + + # If extension was unknown and we now have a Content-Type from the response, + # rename the written file to the correct extension. + if result not in (None, _SKIP) and Path(safe_name).suffix.lower() not in IMAGE_EXTENSIONS: + content_type = content_type_out[0] if content_type_out else "" + new_name = _sanitize_attachment_filename(att_name, url_path, content_type) + if new_name != safe_name: + new_dest = img_dir / new_name + try: + new_dest.resolve().relative_to(img_dir.resolve()) + if new_dest.exists() and not self.force_overwrite: + # Respect the overwrite guard; discard the provisional file. + dest.unlink(missing_ok=True) + else: + dest.replace(new_dest) + except (ValueError, OSError) as exc: + self.handle.verbose_log(f"Could not rename {safe_name!r} → {new_name!r}: {exc}") + ok = False + + if result is None: # actual failure + ok = False + return ok + + def _download_image(self, url_path: str, dest: Path, content_type_out: "Optional[list]" = None) -> "Optional[str]": + """Download an image from NetBox and write to *dest*. + + Returns SHA-256 hex digest on success, None on failure, or the module-level + ``_SKIP`` sentinel when the destination already exists and ``--force-export-overwrite`` + is not set (callers must compare ``result is None`` to distinguish failure from skip). + Respects overwrite guard: skips if dest exists and --force-export-overwrite is not set. + """ + if dest.exists() and not self.force_overwrite: + return _SKIP + + full_url = self.base_url + url_path if not url_path.startswith("http") else url_path + # Only send the auth header when the effective URL resolves to the same + # host as base_url — prevents credential leakage to off-host storage + # backends (e.g. S3 redirect, custom CDN). + from urllib.parse import urlparse + + base = urlparse(self.base_url) + target = urlparse(full_url) + headers = {} + if (base.scheme, base.netloc) == (target.scheme, target.netloc): + headers["Authorization"] = _build_auth_header(self.token) + try: + resp = requests.get( + full_url, + headers=headers, + verify=not self.ignore_ssl, + timeout=30, + ) + except requests.RequestException as exc: + self.handle.log(f"[yellow]Image download failed {full_url}: {exc}[/yellow]") + return None + + content_type = resp.headers.get("Content-Type", "") + if not resp.ok or "text" in content_type or "json" in content_type: + self.handle.log(f"[yellow]Image not available at {full_url} (status {resp.status_code})[/yellow]") + return None + + if content_type_out is not None: + content_type_out.append(content_type) + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_bytes(resp.content) + return hashlib.sha256(resp.content).hexdigest() diff --git a/core/export_manifest.py b/core/export_manifest.py new file mode 100644 index 000000000..9d632420e --- /dev/null +++ b/core/export_manifest.py @@ -0,0 +1,51 @@ +"""Manifest helpers for the --export-diff feature. + +The manifest (`.export-manifest.json`) lives in the export directory and +records the NetBox ``last_updated`` timestamp for each exported type so that +repeat runs can skip re-exporting unchanged types. +""" + +import json +import os +from pathlib import Path + +_EMPTY = {"device-types": {}, "module-types": {}, "rack-types": {}} + + +def load_manifest(path: Path) -> dict: + """Load manifest from *path*. Returns an empty manifest on any error.""" + try: + loaded = json.loads(Path(path).read_text(encoding="utf-8")) + if not isinstance(loaded, dict): + return {k: {} for k in _EMPTY} + # Ensure each expected section exists and is a dict. + return {k: (loaded[k] if isinstance(loaded.get(k), dict) else {}) for k in _EMPTY} + except (OSError, json.JSONDecodeError, ValueError, UnicodeDecodeError): + return {k: {} for k in _EMPTY} + + +def save_manifest(path: Path, data: dict) -> None: + """Atomically write *data* to *path* (write-then-rename).""" + path = Path(path) + tmp = path.with_suffix(".tmp") + tmp.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8") + os.replace(tmp, path) + + +def is_entry_fresh(manifest: dict, kind: str, key: str, last_updated: str) -> bool: + """Return True if the manifest entry for *key* matches *last_updated*.""" + section = manifest.get(kind) + if not isinstance(section, dict): + return False + entry = section.get(key) + if not isinstance(entry, dict): + return False + return entry.get("last_updated") == last_updated + + +def update_entry(manifest: dict, kind: str, key: str, last_updated: str) -> None: + """Write (or overwrite) the manifest entry for *key*.""" + section = manifest.get(kind) + if not isinstance(section, dict): + manifest[kind] = {} + manifest[kind][key] = {"last_updated": last_updated} diff --git a/core/graphql_client.py b/core/graphql_client.py index 201d931c0..0faadc4ff 100644 --- a/core/graphql_client.py +++ b/core/graphql_client.py @@ -10,17 +10,22 @@ import requests +# Module-level dedup: tracks (url, requested_page_size) pairs that have already +# emitted the page-size clamping warning so the message appears at most once +# even when multiple client instances share the same server. +_CLAMPING_WARNED: set = set() +_CLAMPING_WARNED_LOCK = threading.Lock() + class GraphQLError(Exception): """Raised when a GraphQL query fails (HTTP error or GraphQL-level errors).""" class GraphQLCountMismatchError(GraphQLError): - """Raised when the number of records returned by GraphQL does not match the REST count. + """Raised when a GraphQL-cached component count differs from the REST API count. - This indicates a silent truncation in the GraphQL response — e.g. the server - returned far fewer records than it reports via the REST API. The run is aborted - to prevent processing an incomplete cache. + Indicates that GraphQL silently truncated results, which would lead to + incomplete data being imported. The import should be aborted and retried. """ @@ -176,8 +181,6 @@ def __init__(self, url, token, ignore_ssl=False, log_handler=None, page_size=500 self.token = token self.ignore_ssl = ignore_ssl self._log_handler = log_handler - self._page_size_clamping_warned = False - self._page_size_clamping_lock = threading.Lock() self._session = requests.Session() # v2 tokens start with "nbt_" prefix (format: nbt_.); @@ -318,9 +321,10 @@ def query_all(self, graphql_query, list_key, page_size=None, variables=None, on_ elif n > 0 and effective_page_size < page_size: # Second page arrived and the first page was smaller than # requested — clamping confirmed. - with self._page_size_clamping_lock: - if not self._page_size_clamping_warned: - self._page_size_clamping_warned = True + _key = (self.url, page_size) + with _CLAMPING_WARNED_LOCK: + if _key not in _CLAMPING_WARNED: + _CLAMPING_WARNED.add(_key) msg = ( f"WARNING: NetBox capped the GraphQL page size at " f"{effective_page_size} (requested {page_size}). " @@ -359,17 +363,61 @@ def get_manufacturers(self): items = self.query_all(query, list_key="manufacturer_list") return {item["name"]: _to_dotdict(item) for item in items} - def get_device_types(self): + def _build_manufacturer_filter(self, slugs): + """Return ``(var_decl, filter_fragment, extra_variables)`` for manufacturer filtering. + + ``var_decl`` is appended to the query's variable list (e.g. + ``', $manufacturer_slug: String!'``). ``filter_fragment`` is placed + before the ``pagination`` argument in the list field call. + ``extra_variables`` is the dict to pass as ``variables`` to + :meth:`query_all`. + + Args: + slugs: ``None`` or a non-empty list of manufacturer slug strings. + + Returns: + tuple[str, str, dict] + """ + if not slugs: + return "", "", {} + if not isinstance(slugs, list) or any(not isinstance(s, str) or not s.strip() for s in slugs): + raise ValueError("manufacturer_slugs must be None or a non-empty list of non-empty strings") + slugs = [s.strip() for s in slugs] + if len(slugs) == 1: + return ( + ", $manufacturer_slug: String!", + "filters: {manufacturer: {slug: {exact: $manufacturer_slug}}}, ", + {"manufacturer_slug": slugs[0]}, + ) + return ( + ", $manufacturer_slugs: [String!]!", + "filters: {manufacturer: {slug: {in_list: $manufacturer_slugs}}}, ", + {"manufacturer_slugs": slugs}, + ) + + def get_device_types(self, manufacturer_slugs=None): """Fetch all device types and return two lookup indexes. + Args: + manufacturer_slugs: Optional list of manufacturer slugs to filter by. + When provided, only device types from the specified manufacturers are returned. + Returns: tuple[dict, dict]: - ``by_model``: ``{(manufacturer_slug, model): record}`` - ``by_slug``: ``{(manufacturer_slug, slug): record}`` + + Raises: + ValueError: If *manufacturer_slugs* is an empty list. """ - query = """ - query($pagination: OffsetPaginationInput) { - device_type_list(pagination: $pagination) { + if manufacturer_slugs is not None and len(manufacturer_slugs) == 0: + raise ValueError("manufacturer_slugs must be None or a non-empty list") + + var_decl, filter_fragment, extra_vars = self._build_manufacturer_filter(manufacturer_slugs) + + query = f""" + query($pagination: OffsetPaginationInput{var_decl}) {{ + device_type_list({filter_fragment}pagination: $pagination) {{ id model slug @@ -382,17 +430,18 @@ def get_device_types(self): weight_unit description comments - front_image { url } - rear_image { url } - manufacturer { + last_updated + front_image {{ url }} + rear_image {{ url }} + manufacturer {{ id name slug - } - } - } + }} + }} + }} """ - items = self.query_all(query, list_key="device_type_list") + items = self.query_all(query, list_key="device_type_list", variables=extra_vars or None) by_model = {} by_slug = {} @@ -409,15 +458,27 @@ def get_device_types(self): return by_model, by_slug - def get_module_types(self): + def get_module_types(self, manufacturer_slugs=None): """Fetch all module types and return them indexed by manufacturer slug and model. + Args: + manufacturer_slugs: Optional list of manufacturer slugs to filter by. + When provided, only module types from the specified manufacturers are returned. + Returns: dict: ``{manufacturer_slug: {model: record}}`` + + Raises: + ValueError: If *manufacturer_slugs* is an empty list. """ - query = """ - query($pagination: OffsetPaginationInput) { - module_type_list(pagination: $pagination) { + if manufacturer_slugs is not None and len(manufacturer_slugs) == 0: + raise ValueError("manufacturer_slugs must be None or a non-empty list") + + var_decl, filter_fragment, extra_vars = self._build_manufacturer_filter(manufacturer_slugs) + + query = f""" + query($pagination: OffsetPaginationInput{var_decl}) {{ + module_type_list({filter_fragment}pagination: $pagination) {{ id model part_number @@ -426,15 +487,16 @@ def get_module_types(self): comments weight weight_unit - manufacturer { + last_updated + manufacturer {{ id name slug - } - } - } + }} + }} + }} """ - items = self.query_all(query, list_key="module_type_list") + items = self.query_all(query, list_key="module_type_list", variables=extra_vars or None) result = {} for item in items: @@ -444,15 +506,27 @@ def get_module_types(self): return result - def get_rack_types(self): + def get_rack_types(self, manufacturer_slugs=None): """Fetch all rack types and return them indexed by manufacturer slug and model. + Args: + manufacturer_slugs: Optional list of manufacturer slugs to filter by. + When provided, only rack types from the specified manufacturers are returned. + Returns: dict: ``{manufacturer_slug: {model: record}}`` + + Raises: + ValueError: If *manufacturer_slugs* is an empty list. """ - query = """ - query($pagination: OffsetPaginationInput) { - rack_type_list(pagination: $pagination) { + if manufacturer_slugs is not None and len(manufacturer_slugs) == 0: + raise ValueError("manufacturer_slugs must be None or a non-empty list") + + var_decl, filter_fragment, extra_vars = self._build_manufacturer_filter(manufacturer_slugs) + + query = f""" + query($pagination: OffsetPaginationInput{var_decl}) {{ + rack_type_list({filter_fragment}pagination: $pagination) {{ id model slug @@ -471,15 +545,16 @@ def get_rack_types(self): desc_units comments description - manufacturer { + last_updated + manufacturer {{ id name slug - } - } - } + }} + }} + }} """ - items = self.query_all(query, list_key="rack_type_list") + items = self.query_all(query, list_key="rack_type_list", variables=extra_vars or None) result = {} for item in items: record = _to_dotdict(item) @@ -513,7 +588,9 @@ def get_module_type_images(self): """ try: items = self.query_all(query, list_key="image_attachment_list") - except GraphQLError: + except GraphQLError as e: + if isinstance(e, GraphQLCountMismatchError): + raise # Fallback: fetch all attachments and filter in Python fallback_query = """ query($pagination: OffsetPaginationInput) { @@ -548,11 +625,159 @@ def get_module_type_images(self): return result - def get_component_templates(self, endpoint_name, on_page=None): + def get_module_type_image_details(self): + """Fetch image attachments for module types including attachment IDs and URLs. + + Used by the ``--verify-images`` code path to obtain the information needed to + check physical presence, compare content hashes, and delete stale attachments. + + Returns: + dict: ``{module_type_id: {attachment_name: {"att_id": id, "url": url}}}`` + """ + query = """ + query($pagination: OffsetPaginationInput) { + image_attachment_list( + pagination: $pagination, + filters: {object_type: {app_label: {exact: "dcim"}, model: {exact: "moduletype"}}} + ) { + id + name + object_id + image { url } + } + } + """ + try: + items = self.query_all(query, list_key="image_attachment_list") + except GraphQLError as e: + if isinstance(e, GraphQLCountMismatchError): + raise + fallback_query = """ + query($pagination: OffsetPaginationInput) { + image_attachment_list(pagination: $pagination) { + id + name + object_id + image { url } + object_type { app_label model } + } + } + """ + all_items = self.query_all(fallback_query, list_key="image_attachment_list") + items = [ + i + for i in all_items + if (i.get("object_type") or {}).get("app_label") == "dcim" + and (i.get("object_type") or {}).get("model") == "moduletype" + ] + + result = {} + for item in items: + name = item.get("name") + if not name: + continue + obj_id = item["object_id"] + if isinstance(obj_id, str): + try: + obj_id = int(obj_id) + except ValueError: + continue + att_id = item.get("id") + if isinstance(att_id, str): + try: + att_id = int(att_id) + except ValueError: + att_id = None + image_field = item.get("image") or {} + url = image_field.get("url", "") if isinstance(image_field, dict) else str(image_field or "") + result.setdefault(obj_id, {})[name] = {"att_id": att_id, "url": url} + + return result + + @staticmethod + def _front_port_field_variants(fields): + """Yield successive field-list tiers for the front_port_templates fallback. + + Tier 1: mappings block (NetBox 4.5+) + Tier 2: rear_port_position scalar (<4.5) + Tier 3: neither (field removed entirely) + """ + yield fields + fallback = [] + for f in fields: + if "mappings" in f: + fallback.extend(["rear_port_position", "rear_port { id name }"]) + else: + fallback.append(f) + yield fallback + stripped = [f for f in fallback if f != "rear_port_position" and "rear_port" not in f] + if len(stripped) < len(fallback): + yield stripped + + def _query_component_endpoint( + self, list_key, filter_clause, endpoint_name, fields, parent_fields, on_page, var_decl="", extra_variables=None + ): + """Query a component template endpoint with fallback logic for front_port_templates. + + Args: + list_key: GraphQL list key (e.g., "interface_template_list"). + filter_clause: GraphQL filter clause string (empty or with trailing comma/space). + endpoint_name: Endpoint name for fallback handling. + fields: List of field strings. + parent_fields: String with device_type/module_type fields. + on_page: Optional callable for progress reporting. + var_decl: Optional variable declaration string appended to the query signature. + extra_variables: Optional dict of extra variables to pass to query_all. + + Returns: + list: Raw items (not DotDict-wrapped). + """ + + def _build_query(field_list): + field_block = "\n ".join(field_list) + return f""" + query($pagination: OffsetPaginationInput{var_decl}) {{ + {list_key}({filter_clause}pagination: $pagination) {{ + {field_block} + {parent_fields} + }} + }} + """ + + is_front_port_with_mappings = endpoint_name == "front_port_templates" and any("mappings" in f for f in fields) + field_variants = list(self._front_port_field_variants(fields)) if is_front_port_with_mappings else [fields] + + # Three-tier fallback for front_port_templates: + # Tier 1: mappings { ... } (NetBox 4.5+) + # Tier 2: rear_port_position (<4.5 direct scalar field) + # Tier 3: neither (future: field removed entirely) + original_exc = last_exc = None + for variant in field_variants: + try: + return self.query_all( + _build_query(variant), list_key=list_key, on_page=on_page, variables=extra_variables + ) + except GraphQLError as exc: + if isinstance(exc, GraphQLCountMismatchError): + raise + last_exc = exc + if original_exc is None: + original_exc = exc + + if original_exc is last_exc: + raise last_exc + raise last_exc from original_exc + + def get_component_templates(self, endpoint_name, manufacturer_slug=None, on_page=None): """Fetch component template records for the given endpoint. Args: endpoint_name: Endpoint name as used by DeviceTypes (e.g. ``"interface_templates"``). + manufacturer_slug: Optional manufacturer slug to filter by. When provided, + fetches templates for both device types and module types from that manufacturer. + This is intentionally a single slug (not a list) because component preloading + operates one vendor at a time, unlike :meth:`get_device_types` / + :meth:`get_module_types` which accept a ``manufacturer_slugs`` list. on_page: Optional callable passed to :meth:`query_all` to receive the item count after each page is fetched. @@ -561,71 +786,62 @@ def get_component_templates(self, endpoint_name, on_page=None): Raises: ValueError: If *endpoint_name* is not a recognized component template endpoint. + ValueError: If *manufacturer_slug* is an empty string. """ if endpoint_name not in COMPONENT_TEMPLATE_FIELDS or endpoint_name not in ENDPOINT_TO_LIST_KEY: raise ValueError(f"Unknown component endpoint: {endpoint_name}") + if manufacturer_slug is not None and len(manufacturer_slug) == 0: + raise ValueError("manufacturer_slug must be None or a non-empty string") + fields = COMPONENT_TEMPLATE_FIELDS[endpoint_name] list_key = ENDPOINT_TO_LIST_KEY[endpoint_name] - field_block = "\n ".join(fields) parent_fields = "device_type { id }" if endpoint_name not in _NO_MODULE_TYPE: parent_fields += "\n module_type { id }" - query = f""" - query($pagination: OffsetPaginationInput) {{ - {list_key}(pagination: $pagination) {{ - {field_block} - {parent_fields} - }} - }} - """ - - try: - items = self.query_all(query, list_key=list_key, on_page=on_page) - except GraphQLError as original_exc: - if endpoint_name == "front_port_templates": - # Three-tier fallback for front_port_templates: - # 1. Primary: mappings { ... } (NetBox 4.5+) - # 2. First fallback: rear_port_position (<4.5 direct scalar field) - # 3. Second fallback: neither (future: field removed entirely) - has_mappings = any("mappings" in f for f in fields) - if not has_mappings: - raise - - # First fallback: replace the mappings block with the scalar rear_port_position - fallback_fields = ["rear_port_position" if "mappings" in f else f for f in fields] - field_block = "\n ".join(fallback_fields) - fallback_query = f""" - query($pagination: OffsetPaginationInput) {{ - {list_key}(pagination: $pagination) {{ - {field_block} - {parent_fields} - }} - }} - """ - try: - items = self.query_all(fallback_query, list_key=list_key, on_page=on_page) - except GraphQLError as fallback_exc: - # Second fallback: strip rear_port_position too - second_fallback_fields = [f for f in fallback_fields if f != "rear_port_position"] - if len(second_fallback_fields) == len(fallback_fields): - # rear_port_position wasn't in fallback_fields — nothing more to try - raise fallback_exc from original_exc - field_block = "\n ".join(second_fallback_fields) - second_fallback_query = f""" - query($pagination: OffsetPaginationInput) {{ - {list_key}(pagination: $pagination) {{ - {field_block} - {parent_fields} - }} - }} - """ - try: - items = self.query_all(second_fallback_query, list_key=list_key, on_page=on_page) - except GraphQLError as second_exc: - raise second_exc from original_exc + if manufacturer_slug is None: + # Unfiltered query (original behavior) + items = self._query_component_endpoint( + list_key=list_key, + filter_clause="", + endpoint_name=endpoint_name, + fields=fields, + parent_fields=parent_fields, + on_page=on_page, + ) + else: + var_decl = ", $manufacturer_slug: String!" + extra_vars = {"manufacturer_slug": manufacturer_slug} + # Vendor-scoped: query device-type-filtered templates + device_filter = "filters: {device_type: {manufacturer: {slug: {exact: $manufacturer_slug}}}}, " + device_items = self._query_component_endpoint( + list_key=list_key, + filter_clause=device_filter, + endpoint_name=endpoint_name, + fields=fields, + parent_fields=parent_fields, + on_page=on_page, + var_decl=var_decl, + extra_variables=extra_vars, + ) + + # If endpoint supports module_type, also query module-type-filtered templates + if endpoint_name not in _NO_MODULE_TYPE: + module_filter = "filters: {module_type: {manufacturer: {slug: {exact: $manufacturer_slug}}}}, " + module_items = self._query_component_endpoint( + list_key=list_key, + filter_clause=module_filter, + endpoint_name=endpoint_name, + fields=fields, + parent_fields=parent_fields, + on_page=on_page, + var_decl=var_decl, + extra_variables=extra_vars, + ) + items = device_items + module_items else: - raise + items = device_items + return [_to_dotdict(item) for item in items] diff --git a/core/nb_serializer.py b/core/nb_serializer.py new file mode 100644 index 000000000..003608df1 --- /dev/null +++ b/core/nb_serializer.py @@ -0,0 +1,298 @@ +"""Serialize NetBox API records to DTL-compatible YAML dicts (export-diff feature). + +Direction: NetBox record → Python dict suitable for ``yaml.dump()`` and +comparison against existing repo YAML files. +""" + +import warnings +from typing import Any + +# Maps YAML component key → NetBox endpoint name. +# Order defines the output key order in the serialized YAML. +COMPONENT_ENDPOINTS = [ + ("interfaces", "interface_templates"), + ("power-ports", "power_port_templates"), + ("console-ports", "console_port_templates"), + ("power-outlets", "power_outlet_templates"), + ("console-server-ports", "console_server_port_templates"), + ("rear-ports", "rear_port_templates"), + ("front-ports", "front_port_templates"), + ("device-bays", "device_bay_templates"), + ("module-bays", "module_bay_templates"), +] + +# The DTL endpoint names (for use with get_component_templates()) +COMPONENT_ENDPOINT_NAMES = [ep_name for _, ep_name in COMPONENT_ENDPOINTS] + +# Field lists per component type (same fields as graphql_client.COMPONENT_TEMPLATE_FIELDS, +# minus "id" and parent fields). +_IFACE_FIELDS = ["name", "type", "label", "description", "mgmt_only", "enabled", "poe_mode", "poe_type", "rf_role"] +_POWER_PORT_FIELDS = ["name", "type", "label", "description", "maximum_draw", "allocated_draw"] +_CONSOLE_FIELDS = ["name", "type", "label", "description"] +_POWER_OUTLET_FIELDS = ["name", "type", "label", "description", "feed_leg"] +_REAR_PORT_FIELDS = ["name", "type", "label", "description", "positions", "color"] +_FRONT_PORT_FIELDS = ["name", "type", "label", "description", "color"] # rear_port handled separately +_DEVICE_BAY_FIELDS = ["name", "label", "description"] +_MODULE_BAY_FIELDS = ["name", "position", "label", "description"] + +_COMPONENT_FIELDS = { + "interface_templates": _IFACE_FIELDS, + "power_port_templates": _POWER_PORT_FIELDS, + "console_port_templates": _CONSOLE_FIELDS, + "console_server_port_templates": _CONSOLE_FIELDS, + "power_outlet_templates": _POWER_OUTLET_FIELDS, + "rear_port_templates": _REAR_PORT_FIELDS, + "front_port_templates": _FRONT_PORT_FIELDS, + "device_bay_templates": _DEVICE_BAY_FIELDS, + "module_bay_templates": _MODULE_BAY_FIELDS, +} + +# Values that are defaults — omit from output to keep YAML clean. +_OMIT_IF_EQUAL = { + "label": "", + "description": "", + "comments": "", + "mgmt_only": False, + "enabled": True, # True is the interface default; include only when False + "color": "", + "poe_mode": None, + "poe_type": None, + "rf_role": None, + "feed_leg": None, + "maximum_draw": None, + "allocated_draw": None, + "positions": 1, # rear port default; include only when > 1 +} + +# Device type scalar field order for output. +_DT_SCALAR_FIELDS = [ + "manufacturer", + "model", + "slug", + "part_number", + "u_height", + "is_full_depth", + "airflow", + "weight", + "weight_unit", + "description", + "comments", +] + +# Module type scalar field order for output. +_MT_SCALAR_FIELDS = [ + "manufacturer", + "model", + "part_number", + "airflow", + "weight", + "weight_unit", + "description", + "comments", +] + +# Rack type scalar field order for output. +_RT_SCALAR_FIELDS = [ + "manufacturer", + "model", + "slug", + "form_factor", + "description", + "width", + "u_height", + "starting_unit", + "outer_width", + "outer_height", + "outer_depth", + "outer_unit", + "mounting_depth", + "weight", + "max_weight", + "weight_unit", + "desc_units", + "comments", +] + + +def _coerce_numeric(val: Any) -> Any: + """Coerce float-with-integer-value or numeric string to a Python numeric type. + + - ``1.0`` → ``1`` (float integer → int) + - ``'12.0'`` → ``12`` (string integer → int) + - ``'13.60'`` → ``13.6`` (string float → float, trailing zeros dropped) + """ + if isinstance(val, float) and not isinstance(val, bool) and val.is_integer(): + return int(val) + # Only coerce strings that look like decimals (contain '.') — NetBox + # DecimalField values come back as e.g. '13.60' or '1.0'. Plain integer + # strings like '1' are preserved (they may belong to CharField columns + # such as ``position`` where DTL convention keeps them quoted). + if isinstance(val, str) and "." in val: + try: + f = float(val) + if f.is_integer(): + return int(f) + return f + except (ValueError, TypeError): + pass + return val + + +def _should_include(field: str, val: Any) -> bool: + """Return True when *val* should be written to the YAML output.""" + if val is None: + return False + if isinstance(val, str) and val == "": + return False + if field in _OMIT_IF_EQUAL and val == _OMIT_IF_EQUAL[field]: + return False + return True + + +def _serialize_component(record: Any, fields: list) -> dict: + """Serialize a single component template record to a YAML-ready dict.""" + result = {} + for field in fields: + val = getattr(record, field, None) + val = _coerce_numeric(val) + if _should_include(field, val): + result[field] = val + return result + + +def _serialize_front_port(record: Any) -> dict: + """Serialize a front port template, including rear_port mapping.""" + result = _serialize_component(record, _FRONT_PORT_FIELDS) + mappings = getattr(record, "mappings", None) or [] + if mappings: + if len(mappings) > 1: + port_name = getattr(record, "name", "") + warnings.warn( + f"Front port '{port_name}' has {len(mappings)} mappings; " + "only the first will be exported. " + "Full multi-mapping support requires DTL schema update (see issue #78).", + UserWarning, + stacklevel=4, + ) + m = mappings[0] + rear_port = getattr(m, "rear_port", None) + if rear_port: + result["rear_port"] = rear_port.name + rear_pos = getattr(m, "rear_port_position", None) + rear_pos = _coerce_numeric(rear_pos) + if rear_pos is not None and rear_pos > 1: + result["rear_port_position"] = rear_pos + else: + # Legacy: pre-4.5 NetBox returns rear_port / rear_port_position as direct scalar fields + legacy_rp = getattr(record, "rear_port", None) + if legacy_rp: + result["rear_port"] = legacy_rp.name + legacy_pos = getattr(record, "rear_port_position", None) + legacy_pos = _coerce_numeric(legacy_pos) + if legacy_pos is not None and legacy_pos > 1: + result["rear_port_position"] = legacy_pos + return result + + +def _serialize_component_list(endpoint_name: str, records: list) -> list: + """Serialize a list of component template records for a given endpoint.""" + out = [] + for record in sorted(records, key=lambda r: str(getattr(r, "name", "") or "")): + if endpoint_name == "front_port_templates": + out.append(_serialize_front_port(record)) + else: + out.append(_serialize_component(record, _COMPONENT_FIELDS[endpoint_name])) + return out + + +def _add_components(result: dict, type_id: int, components_by_id: dict) -> None: + """Append serialized component lists to *result* for a given type id.""" + type_components = components_by_id.get(type_id, {}) + for yaml_key, endpoint_name in COMPONENT_ENDPOINTS: + records = type_components.get(endpoint_name, []) + if records: + result[yaml_key] = _serialize_component_list(endpoint_name, records) + + +def serialize_device_type(nb_record: Any, components_by_dt_id: dict) -> dict: + """Convert a NetBox device type record to a DTL-compatible YAML dict. + + Args: + nb_record: DotDict returned by ``NetBoxGraphQLClient.get_device_types()``. + components_by_dt_id: ``{device_type_id: {endpoint_name: [records]}}``. + + Returns: + Ordered dict suitable for ``yaml.dump()``. + """ + result = {} + for field in _DT_SCALAR_FIELDS: + if field == "manufacturer": + mfr = getattr(nb_record, "manufacturer", None) + if mfr is not None: + result["manufacturer"] = mfr.name + continue + val = getattr(nb_record, field, None) + val = _coerce_numeric(val) + if field in ("u_height", "is_full_depth"): + # Always include — commonly explicit in DTL files + if val is not None: + result[field] = val + elif _should_include(field, val): + result[field] = val + + if getattr(nb_record, "front_image", None): + result["front_image"] = True + if getattr(nb_record, "rear_image", None): + result["rear_image"] = True + + _add_components(result, nb_record.id, components_by_dt_id) + return result + + +def serialize_module_type(nb_record: Any, components_by_mt_id: dict) -> dict: + """Convert a NetBox module type record to a DTL-compatible YAML dict. + + Args: + nb_record: DotDict returned by ``NetBoxGraphQLClient.get_module_types()``. + components_by_mt_id: ``{module_type_id: {endpoint_name: [records]}}``. + + Returns: + Ordered dict suitable for ``yaml.dump()``. + """ + result = {} + for field in _MT_SCALAR_FIELDS: + if field == "manufacturer": + mfr = getattr(nb_record, "manufacturer", None) + if mfr is not None: + result["manufacturer"] = mfr.name + continue + val = getattr(nb_record, field, None) + val = _coerce_numeric(val) + if _should_include(field, val): + result[field] = val + + _add_components(result, nb_record.id, components_by_mt_id) + return result + + +def serialize_rack_type(nb_record: Any) -> dict: + """Convert a NetBox rack type record to a DTL-compatible YAML dict. + + Rack types have no component templates. + """ + result = {} + for field in _RT_SCALAR_FIELDS: + if field == "manufacturer": + mfr = getattr(nb_record, "manufacturer", None) + if mfr is not None: + result["manufacturer"] = mfr.name + continue + val = getattr(nb_record, field, None) + val = _coerce_numeric(val) + # desc_units is bool — include regardless of value (explicit design choice) + if field == "desc_units": + if val is not None: + result[field] = val + elif _should_include(field, val): + result[field] = val + return result diff --git a/core/netbox_api.py b/core/netbox_api.py index bc06427eb..258e2bdca 100644 --- a/core/netbox_api.py +++ b/core/netbox_api.py @@ -3,9 +3,12 @@ from collections import Counter import concurrent.futures from functools import lru_cache +import hashlib import itertools +import json import queue import re +import tempfile import time import pynetbox import requests @@ -15,7 +18,12 @@ from pathlib import Path from core.change_detector import ChangeDetector, ChangeType -from core.compat import device_type_filter_kwargs, module_type_filter_kwargs, module_type_filter_key +from core.compat import ( + device_type_filter_key, + device_type_filter_kwargs, + module_type_filter_key, + module_type_filter_kwargs, +) from core.formatting import log_property_diffs from core.graphql_client import GraphQLCountMismatchError, GraphQLError, NetBoxGraphQLClient from core.normalization import values_equal @@ -33,6 +41,27 @@ def _build_auth_header(token): return f"{scheme} {token}" +def _fmt_connection_error(url: str, exc: Exception) -> str: + """Return a human-friendly message for a connection-level network error. + + Used wherever a ``requests.exceptions.ConnectionError`` (which wraps + ``urllib3`` ``ProtocolError`` / ``RemoteDisconnected`` etc.) is caught, so + that the message format is consistent across all call sites. + + Args: + url: The NetBox base URL that was being contacted. + exc: The caught exception. + + Returns: + A single multi-line string suitable for printing to stderr or a log. + """ + return ( + f"Connection error while contacting NetBox at {url}: {exc}\n" + "The remote end closed the connection unexpectedly. " + "Verify that NetBox is running, reachable, and not being restarted." + ) + + # Transient connection errors that warrant a retry _RETRYABLE_EXCEPTIONS = (requests.exceptions.ConnectionError, requests.exceptions.Timeout) @@ -44,6 +73,177 @@ def _build_auth_header(token): _UNKNOWN_SRC = "Unknown" +def _check_image_url( + base_url: str, + image_url_path: str, + ignore_ssl: bool, + token: str = "", + log_fn=None, +) -> str: + """Check whether a remote image URL is physically accessible. + + Issues an authenticated HTTP GET and reports whether the image exists on the server. + Content/byte comparison is intentionally omitted: NetBox re-encodes images on + upload so remote bytes never match the originals. Use + :func:`_is_image_hash_changed` for local-file change detection instead. + + Returns "ok" only when the server returns a 2xx response *and* the Content-Type + indicates an actual image. A 2xx with a non-image Content-Type (e.g. ``text/html`` + from a login-redirect) is treated as "missing" so that files absent from the + filesystem but still recorded in the database are re-uploaded. + + Returns: + "missing": the server returned a non-2xx response, or a 2xx but with a + non-image Content-Type (image not physically present / auth redirect) + "ok": image exists (2xx with image Content-Type) or a network error + occurred (conservative — avoids spurious re-uploads on transient + failures; network error is logged at verbose level when *log_fn* + is provided so operators can spot degraded runs) + + Args: + base_url: NetBox base URL (e.g. "https://netbox.example.com"). + image_url_path: Relative path from NetBox (e.g. "/media/devicetype-images/foo.png") + or a full URL starting with "http". + ignore_ssl: When True, SSL certificate verification is skipped. + token: NetBox API token. When non-empty, sent using the same + ``Authorization`` scheme as ``_build_auth_header`` (``Bearer`` for + ``nbt_…`` tokens, ``Token`` otherwise) to support all NetBox token + types. Auth is only sent when the URL resolves to the same host as + *base_url*, preventing credential leakage to off-host URLs. + log_fn: Optional callable ``(msg: str) -> None`` invoked at verbose level + when a network error is swallowed. Pass ``handle.verbose_log``. + """ + full_url = image_url_path if image_url_path.startswith("http") else base_url.rstrip("/") + image_url_path + headers = {} + if token: + # Only send auth header when the effective URL is on the same host as base_url. + from urllib.parse import urlparse + + base_host = urlparse(base_url).netloc + target_host = urlparse(full_url).netloc + if base_host == target_host: + headers["Authorization"] = _build_auth_header(token) + try: + response = requests.get(full_url, headers=headers, verify=(not ignore_ssl), timeout=30) + except requests.RequestException as exc: + if log_fn is not None: + log_fn( + f"[yellow]Network error checking image {full_url}: {exc} " + f"— treating as present to avoid spurious re-upload[/yellow]" + ) + return "ok" + if not response.ok: + return "missing" + content_type = response.headers.get("Content-Type", "") + if content_type.startswith("text/") or content_type.startswith("application/json"): + return "missing" + return "ok" + + +def _is_image_hash_changed(local_path: str, hash_cache: dict) -> bool: + """Return True if the local file's SHA-256 hash differs from the cached value. + + The cache maps local file paths to the SHA-256 hex-digest recorded at the time + the file was last uploaded. Comparing local-to-local (rather than local-to-remote) + avoids the unreliability caused by NetBox re-encoding images on upload. + + Returns False when *local_path* is absent from *hash_cache* (conservative: avoids + re-uploading images that have never been tracked). + + Args: + local_path: Absolute filesystem path to the local image file. + hash_cache: Dict mapping local path strings to SHA-256 hex-digests. + """ + cached = hash_cache.get(local_path) + if cached is None: + return False + try: + with open(local_path, "rb") as fh: + current = hashlib.sha256(fh.read()).hexdigest() + except OSError: + return False + return current != cached + + +def _load_image_hash_cache(path: str) -> dict: + """Load the image-hash cache from *path* (JSON). Returns an empty dict on any error.""" + try: + with open(path, encoding="utf-8") as fh: + data = json.load(fh) + return data if isinstance(data, dict) else {} + except Exception: + return {} + + +def _save_image_hash_cache(path: str, cache: dict) -> bool: + """Persist *cache* to *path* as a JSON file, written atomically. + + Writes to a temporary file in the same directory, fsyncs it, then + replaces *path* with ``os.replace`` so callers never see a truncated file. + + Returns True on success, False on I/O failure. Callers should warn when + False is returned: a missing cache entry causes ``_is_image_hash_changed`` + to report "unchanged", which would suppress re-uploads for locally edited + images on the next run. + """ + dir_ = os.path.dirname(os.path.abspath(path)) + tmp_path = None + try: + fd, tmp_path = tempfile.mkstemp(dir=dir_, suffix=".tmp") + with os.fdopen(fd, "w", encoding="utf-8") as fh: + json.dump(cache, fh) + fh.flush() + os.fsync(fh.fileno()) + os.replace(tmp_path, path) + tmp_path = None # successfully replaced; skip cleanup + return True + except Exception: + if tmp_path is not None: + try: + os.unlink(tmp_path) + except OSError: + pass + return False + + +def _store_image_hashes(cache: dict, images: dict) -> None: + """Compute and store SHA-256 hashes for each local image path in *images*. + + *images* maps arbitrary string keys to local file paths. Entries that cannot + be read are silently skipped. Updates *cache* in-place. + """ + for path in images.values(): + try: + with open(path, "rb") as fh: + cache[path] = hashlib.sha256(fh.read()).hexdigest() + except OSError: + pass + + +def _delete_image_attachment(base_url: str, token: str, att_id: int, ignore_ssl: bool, handle) -> bool: + """Delete a NetBox image attachment by ID via DELETE /api/extras/image-attachments/{id}/. + + Args: + base_url: NetBox base URL. + token: API token used for the Authorization header. + att_id: Numeric ID of the image attachment to delete. + ignore_ssl: When True, skip SSL certificate verification. + handle: Log handler with a ``log`` method for error reporting. + + Returns: + bool: True on success, False on any HTTP or network error. + """ + url = f"{base_url}/api/extras/image-attachments/{att_id}/" + headers = {"Authorization": _build_auth_header(token)} + try: + response = requests.delete(url, headers=headers, verify=(not ignore_ssl), timeout=30) + response.raise_for_status() + return True + except requests.RequestException as e: + handle.log(f"Error deleting image attachment {att_id}: {e}") + return False + + def _retry_on_connection_error(func, *args, **kwargs): """Call *func* with retries on transient connection errors. @@ -207,6 +407,25 @@ def __init__(self, settings, handle): self.m2m_front_ports = False # True for NetBox >= 4.5 (M2M port mappings) self.rack_types = False self.force_resolve_conflicts = False + self.remove_unmanaged_types = False + self.verify_images = False + self._module_image_details: dict = {} # populated by _fetch_module_type_existing_images in verify mode + # Image hash cache: local file path -> SHA-256 hex-digest at last upload time. + # Used by --verify-images to detect whether the local file changed since last upload, + # avoiding the unreliability of comparing local bytes to NetBox-served bytes (NetBox + # re-encodes images). Stored under ~/.cache/nb-dt-import/ (XDG_CACHE_HOME respected). + _cache_dir = Path(os.environ.get("XDG_CACHE_HOME", Path.home() / ".cache")) / "nb-dt-import" + try: + _cache_dir.mkdir(parents=True, exist_ok=True) + self._image_hash_cache_path = str(_cache_dir / "image-hashes.json") + except OSError: + self.handle.verbose_log( + "[yellow]Warning: could not create image hash cache directory " + f"({_cache_dir}); hash-based re-upload detection will be disabled " + "for this run.[/yellow]" + ) + self._image_hash_cache_path = None + self._image_hash_cache: dict = _load_image_hash_cache(self._image_hash_cache_path) self.connect_api() self.verify_compatibility() self.graphql = NetBoxGraphQLClient( @@ -231,17 +450,45 @@ def __init__(self, settings, handle): m2m_front_ports=self.m2m_front_ports, max_threads=settings.PRELOAD_THREADS, ) - except GraphQLError as e: - system_exit(f"GraphQL error fetching device types: {e}") + except Exception as e: + system_exit(f"Error initializing device types: {e}") self._change_detector: ChangeDetector | None = None @property def change_detector(self) -> "ChangeDetector": """Lazily initialised, reused :class:`ChangeDetector` instance.""" if self._change_detector is None: - self._change_detector = ChangeDetector(self.device_types, self.handle) + self._change_detector = ChangeDetector( + self.device_types, + self.handle, + remove_unmanaged_types=self.remove_unmanaged_types, + ) return self._change_detector + def load_vendor(self, manufacturer_slug: str): + """Load device types for *manufacturer_slug* and reset per-vendor state. + + Delegates to :meth:`DeviceTypes.load_for_vendor` to populate the device + type lookup indexes, then clears the cached :class:`ChangeDetector` so + that the next access constructs a fresh instance against the new data. + + Args: + manufacturer_slug (str): Manufacturer slug to load. + """ + self.device_types.load_for_vendor(manufacturer_slug) + self._change_detector = None + self._module_image_details = {} # stale module entries must not bleed across vendors + + def _persist_hash_cache(self) -> None: + """Save the image hash cache and warn once if the write fails.""" + if self._image_hash_cache_path is None: + return + if not _save_image_hash_cache(self._image_hash_cache_path, self._image_hash_cache): + self.handle.verbose_log( + "[yellow]Warning: failed to persist image hash cache; " + "local image edits may not be detected on the next run.[/yellow]" + ) + def connect_api(self): """Connect to the NetBox API using the stored URL and token credentials.""" try: @@ -281,10 +528,7 @@ def verify_compatibility(self): f"(both with and without brackets for IPv6, e.g. '::1,[::1]')." ) except requests.exceptions.ConnectionError as e: - system_exit( - f"Connection error while connecting to NetBox at {self.url}: {e}\n" - f"Hint: Verify that NetBox is running and reachable at {self.url}." - ) + system_exit(_fmt_connection_error(self.url, e)) except pynetbox.core.query.RequestError as e: endpoint = getattr(e, "base", self.url) status = getattr(e.req, "status_code", "?") if hasattr(e, "req") else "?" @@ -594,6 +838,57 @@ def _log_device_type_change_outcome( "No property or component changes applied." ) + def _filter_images_for_upload(self, dt, saved_images): + """Remove from *saved_images* any image that does not need uploading. + + For each image kind present in *saved_images* that already has a record in NetBox, + either removes the entry unconditionally (default mode) or verifies physical + presence and local-file hash (``--verify-images`` mode) before deciding. + + In ``--verify-images`` mode two independent checks are run: + + 1. **HTTP accessibility** — an HTTP GET confirms the file exists on the server. + A non-2xx response means the file is physically missing. + 2. **Local-file hash** — the current SHA-256 of the local image is compared to + the hash recorded in the image-hash cache at the time of the last upload. + A mismatch means the local source file was updated since the last import. + + NetBox re-encodes images on upload so comparing local bytes to remote bytes is + unreliable; the local-hash cache approach is used instead. + + Args: + dt: pynetbox device type record for the existing device type. + saved_images (dict): Mapping of image kind to local file path; modified in-place. + """ + for image_kind in ("front_image", "rear_image"): + if image_kind not in saved_images: + continue + db_url = getattr(dt, image_kind, None) + if not db_url: + continue # no record in NetBox yet → keep for upload + label = image_kind.replace("_", " ").capitalize() + if not self.verify_images: + self.handle.verbose_log(f"{label} already exists for {dt.model}, skipping upload.") + del saved_images[image_kind] + continue + # --verify-images: Step 1 — check physical presence via HTTP + status = _check_image_url(self.url, db_url, self.ignore_ssl, self.token, log_fn=self.handle.verbose_log) + if status == "missing": + self.handle.verbose_log(f"{label} is missing on server for {dt.model}, will re-upload.") + continue # keep in saved_images for upload + # --verify-images: Step 2 — check if local file changed since last upload + if _is_image_hash_changed(saved_images[image_kind], self._image_hash_cache): + self.handle.verbose_log(f"{label} content has changed for {dt.model}, will re-upload.") + continue # keep in saved_images for upload + # Both checks passed — image is present and unchanged; + # seed hash cache so future local edits will be detected. + local_path = saved_images[image_kind] + if local_path not in self._image_hash_cache: + _store_image_hashes(self._image_hash_cache, {image_kind: local_path}) + self._persist_hash_cache() + self.handle.verbose_log(f"{label} verified OK for {dt.model}, skipping upload.") + del saved_images[image_kind] + def _handle_existing_device_type( self, dt, @@ -621,16 +916,11 @@ def _handle_existing_device_type( absent from the YAML. """ if saved_images: - if "front_image" in saved_images and getattr(dt, "front_image", None): - self.handle.verbose_log(f"Front image already exists for {dt.model}, skipping upload.") - del saved_images["front_image"] - - if "rear_image" in saved_images and getattr(dt, "rear_image", None): - self.handle.verbose_log(f"Rear image already exists for {dt.model}, skipping upload.") - del saved_images["rear_image"] - + self._filter_images_for_upload(dt, saved_images) if saved_images: self.device_types.upload_images(self.url, self.token, saved_images, dt.id) + _store_image_hashes(self._image_hash_cache, saved_images) + self._persist_hash_cache() if only_new: self.handle.verbose_log( @@ -777,6 +1067,8 @@ def _create_device_type_components(self, device_type, dt_id, src_file, saved_ima self.device_types.create_module_bays(device_type["module-bays"], dt_id) if saved_images: self.device_types.upload_images(self.url, self.token, saved_images, dt_id) + _store_image_hashes(self._image_hash_cache, saved_images) + self._persist_hash_cache() def create_device_types( self, @@ -1105,6 +1397,12 @@ def filter_actionable_module_types(self, module_types, all_module_types, only_ne image_changed = any( os.path.splitext(os.path.basename(path))[0] not in existing_images for path in image_files ) + # With --verify-images, images whose names already exist in NetBox also need + # to be re-examined for physical presence and local-file hash changes. + # _upload_module_type_images contains all the probe + decision logic; we just + # need to ensure this module type is considered actionable so it reaches that path. + if not image_changed and self.verify_images and image_files and existing_images: + image_changed = True changed_fields_info = [] for f in _load_module_type_properties(): @@ -1149,10 +1447,21 @@ def log_module_type_changes(self, changed_property_log): def _fetch_module_type_existing_images(self): """Query NetBox for all image attachments on module types via GraphQL and return a mapping. + When ``self.verify_images`` is True the richer attachment metadata (ID + URL) is fetched + via :meth:`~core.graphql_client.NetBoxGraphQLClient.get_module_type_image_details` and + stored on ``self._module_image_details`` for use by + :meth:`_upload_module_type_images`. + Returns: dict: ``{module_type_id: set_of_attachment_names}`` """ - module_type_existing_images = self.graphql.get_module_type_images() + if self.verify_images: + details = self.graphql.get_module_type_image_details() + self._module_image_details = details + module_type_existing_images = {obj_id: set(names.keys()) for obj_id, names in details.items()} + else: + self._module_image_details = {} + module_type_existing_images = self.graphql.get_module_type_images() self.handle.verbose_log( f"Found {len(module_type_existing_images)} module type(s) with existing image attachments." ) @@ -1456,8 +1765,9 @@ def count_device_type_images(self, device_types_to_add): for i in ["front_image", "rear_image"]: if device_type.get(i): - # Skip if existing device type already has this image - if dt is not None and getattr(dt, i, None): + # Skip if existing device type already has this image, unless verify_images + # is active (in that case we may re-upload even existing images so count them). + if not self.verify_images and dt is not None and getattr(dt, i, None): continue image_glob = f"{image_base}/{device_slug}.{i.split('_')[0]}.*" if glob.glob(image_glob, recursive=False): @@ -1538,6 +1848,31 @@ def _discover_module_image_files(src_file): image_files = glob.glob(str(image_dir / f"{src_path.stem}.*")) return [f for f in image_files if os.path.splitext(f)[1].lower() in IMAGE_EXTENSIONS] + def _try_delete_stale_attachment(self, detail, img_path, module_type_res, existing, img_name) -> bool: + """Delete the stale attachment for *img_name* so a fresh upload can follow. + + Returns True when the attachment was successfully deleted (caller should + proceed to re-upload). Returns False when deletion is skipped or fails + (caller should ``continue`` without re-uploading to avoid duplicates). + """ + att_id = detail.get("att_id") if isinstance(detail, dict) else None + if not isinstance(att_id, int): + self.handle.verbose_log( + f"Cannot delete stale attachment for " + f"'{os.path.basename(img_path)}' on {module_type_res.model}: " + "missing or invalid att_id, skipping upload to avoid duplicates." + ) + return False + if not _delete_image_attachment(self.url, self.token, att_id, self.ignore_ssl, self.handle): + self.handle.verbose_log( + f"Failed to delete stale attachment for " + f"'{os.path.basename(img_path)}' on {module_type_res.model}, " + "skipping upload to avoid duplicates." + ) + return False + existing.discard(img_name) + return True + def _upload_module_type_images(self, module_type_res, src_file, module_type_existing_images): """Discover and upload images for a module type, skipping duplicates. @@ -1548,6 +1883,11 @@ def _upload_module_type_images(self, module_type_res, src_file, module_type_exis (basename without extension) is not already present in module_type_existing_images for this module type. + When ``self.verify_images`` is True, existing attachments are verified via + HTTP GET. If an attachment is missing on the server or its content differs + from the local file, the stale attachment is deleted and the image is + re-uploaded. + Args: module_type_res: pynetbox Record for the module type. src_file (str): Source YAML file path used to derive the image directory. @@ -1561,14 +1901,70 @@ def _upload_module_type_images(self, module_type_res, src_file, module_type_exis for img_path in image_files: img_name = os.path.splitext(os.path.basename(img_path))[0] if img_name in existing: - self.handle.verbose_log( - f"Image '{os.path.basename(img_path)}' already exists for {module_type_res.model}, skipping." - ) - continue + if self.verify_images: + detail = self._module_image_details.get(module_type_res.id, {}).get(img_name) + if detail: + img_url = detail.get("url", "") + full_url = img_url if img_url.startswith("http") else self.url.rstrip("/") + img_url + # Step 1: HTTP accessibility check + status = _check_image_url( + self.url, + full_url, + self.ignore_ssl, + self.token, + log_fn=self.handle.verbose_log, + ) + if status == "missing": + self.handle.verbose_log( + f"Image '{os.path.basename(img_path)}' missing on server for " + f"{module_type_res.model}, re-uploading." + ) + deleted = self._try_delete_stale_attachment( + detail, img_path, module_type_res, existing, img_name + ) + if not deleted: + continue + # Step 2: local-file hash check + elif _is_image_hash_changed(img_path, self._image_hash_cache): + self.handle.verbose_log( + f"Image '{os.path.basename(img_path)}' content has changed for " + f"{module_type_res.model}, re-uploading." + ) + deleted = self._try_delete_stale_attachment( + detail, img_path, module_type_res, existing, img_name + ) + if not deleted: + continue + else: + # Verify OK: image present and hash unchanged. + # Seed hash cache so future local edits will be detected. + if img_path not in self._image_hash_cache: + _store_image_hashes(self._image_hash_cache, {"image": img_path}) + self._persist_hash_cache() + self.handle.verbose_log( + f"Image '{os.path.basename(img_path)}' verified OK for " + f"{module_type_res.model}, skipping." + ) + continue + else: + # If no detail available, skip upload to avoid creating duplicate attachments. + self.handle.verbose_log( + f"Image '{os.path.basename(img_path)}' already exists for " + f"{module_type_res.model} but detail is unavailable; " + "skipping upload to avoid duplicates." + ) + continue + else: + self.handle.verbose_log( + f"Image '{os.path.basename(img_path)}' already exists for {module_type_res.model}, skipping." + ) + continue if self.device_types.upload_image_attachment( self.url, self.token, img_path, "dcim.moduletype", module_type_res.id ): existing.add(img_name) + _store_image_hashes(self._image_hash_cache, {"image": img_path}) + self._persist_hash_cache() # Component type -> (dcim endpoint attribute name, cache key name). @@ -1673,7 +2069,10 @@ def __init__( m2m_front_ports=False, max_threads=8, ): - """Initialize the DeviceTypes cache and load all existing device types from NetBox. + """Initialize empty DeviceTypes cache structures; no data is fetched at construction time. + + Device type data is loaded lazily via :meth:`load_for_vendor` on a per-vendor + basis rather than eagerly at startup. Args: netbox: Connected pynetbox API instance. @@ -1696,7 +2095,8 @@ def __init__( self.cached_components = {} self._global_preload_done = False self._image_progress = None - self.existing_device_types, self.existing_device_types_by_slug = self.get_device_types() + self.existing_device_types = {} + self.existing_device_types_by_slug = {} def get_device_types(self): """Fetch all device types from NetBox via GraphQL and build two lookup indexes. @@ -1708,12 +2108,31 @@ def get_device_types(self): """ return self.graphql.get_device_types() + def load_for_vendor(self, manufacturer_slug: str): + """Fetch device types for a single vendor and populate the lookup indexes. + + Replaces any previously loaded data so that state from a prior vendor + does not bleed into the current one. + + Args: + manufacturer_slug (str): Manufacturer slug to load device types for. + """ + self.cached_components = {} + self._global_preload_done = False + by_model, by_slug = self.graphql.get_device_types(manufacturer_slugs=[manufacturer_slug]) + self.existing_device_types = by_model + self.existing_device_types_by_slug = by_slug + # Endpoints whose GraphQL schema is missing fields required for accurate # change detection and where the REST API provides the missing data. # Add endpoint names here if a future NetBox version drops a field from # GraphQL but keeps it in REST (or vice-versa). REST_ONLY_ENDPOINTS: frozenset = frozenset() + # Endpoints that only apply to device types (no module-type path). + # Matches _NO_MODULE_TYPE in graphql_client.py. + _NO_MODULE_TYPE_ENDPOINTS: frozenset = frozenset({"device_bay_templates"}) + @staticmethod def _component_preload_targets(): """Return the list of ``(endpoint_attr, display_label)`` pairs used for component preloading.""" @@ -1729,79 +2148,41 @@ def _component_preload_targets(): ("module_bay_templates", "Module Bays"), ] - def _get_rest_component_count(self, endpoint_name): - """Return the REST API count for *endpoint_name*, or ``None`` on failure. - - Issues a single lightweight ``?limit=1`` REST call that returns just the - total record count — no item data is transferred. Used to validate that - subsequent GraphQL fetches returned the expected number of records. - - Args: - endpoint_name (str): Component template endpoint name (e.g. ``"interface_templates"``). - - Returns: - int | None: Total record count, or ``None`` if the request fails. - """ - try: - return getattr(self.netbox.dcim, endpoint_name).count() - except Exception as exc: - self.handle.verbose_log( - f"REST count unavailable for {endpoint_name}; skipping GraphQL count verification: {exc}" - ) - return None - - def _get_endpoint_totals(self, components): - """Fetch REST record counts for each component endpoint in parallel. - - Issues one lightweight ``?limit=1`` REST call per endpoint to obtain the - expected total before the GraphQL fetch begins. These totals are later - used to detect silent truncation in :meth:`_fetch_global_endpoint_records`. - - REST-only endpoints (see :attr:`REST_ONLY_ENDPOINTS`) are excluded because - their counts will be determined by the REST fetch itself. + def start_component_preload( + self, + progress=None, + manufacturer_slug: str | None = None, + task_registry: dict | None = None, + ): + """Start concurrent component prefetch and return a preload job handle. Args: - components: Iterable of ``(endpoint_name, label)`` tuples. - - Returns: - dict: ``{endpoint_name: count_or_None}`` for graphql endpoints (``None`` - preserves the "count unavailable" sentinel from - :meth:`_get_rest_component_count`), and ``0`` for REST-only endpoints - (whose authoritative count is established by the REST fetch itself). + progress: Optional Rich Progress instance for task tracking. + manufacturer_slug (str | None): When provided, fetch only component templates + belonging to this manufacturer's device types and module types. + task_registry (dict | None): When provided, task_ids are looked up or created + in this shared registry so they persist and accumulate counts across all + vendors rather than appearing and disappearing per vendor. """ - graphql_endpoints = [ep for ep, _label in components if ep not in self.REST_ONLY_ENDPOINTS] - totals = {ep: (0 if ep in self.REST_ONLY_ENDPOINTS else None) for ep, _label in components} - - max_workers = max(1, min(len(graphql_endpoints), self.max_threads)) - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_ep = {executor.submit(self._get_rest_component_count, ep): ep for ep in graphql_endpoints} - for future in concurrent.futures.as_completed(future_to_ep): - ep = future_to_ep[future] - result = future.result() - if isinstance(result, int) and not isinstance(result, bool): - totals[ep] = result - - return totals - - def start_component_preload(self, progress=None): - """Start concurrent component prefetch and return a preload job handle.""" components = self._component_preload_targets() max_workers = max(1, min(len(components), self.max_threads)) executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) try: - endpoint_totals = self._get_endpoint_totals(components) + endpoint_totals = {endpoint_name: None for endpoint_name, _label in components} progress_updates = queue.Queue() task_ids = None if progress is not None: - task_ids = { - endpoint_name: progress.add_task( - f"Caching {label}", - total=endpoint_totals.get(endpoint_name), - ) - for endpoint_name, label in components - } + task_ids = {} + for endpoint_name, label in components: + desc = f"Caching {label}" + if task_registry is not None: + if desc not in task_registry: + task_registry[desc] = progress.add_task(desc, total=None) + task_ids[endpoint_name] = task_registry[desc] + else: + task_ids[endpoint_name] = progress.add_task(desc, total=None) def update_progress(endpoint_name, advance): """Put a progress update onto the queue for the main thread to consume.""" @@ -1812,7 +2193,7 @@ def update_progress(endpoint_name, advance): self._fetch_global_endpoint_records, endpoint_name, update_progress, - endpoint_totals.get(endpoint_name), + manufacturer_slug, ) for endpoint_name, _label in components } @@ -1825,17 +2206,22 @@ def update_progress(endpoint_name, advance): "task_ids": task_ids, "finished_endpoints": set(), "executor": executor, + # When task_registry is provided the caller owns the tasks; + # this job must not stop or remove them on completion. + "owns_tasks": task_registry is None, } except Exception: executor.shutdown(wait=False, cancel_futures=True) raise @staticmethod - def stop_component_preload(preload_job): + def stop_component_preload(preload_job, progress=None): """Cancel any pending futures in *preload_job* and shut down its executor. Args: preload_job (dict | None): Preload job returned by :meth:`start_component_preload`; no-op if None. + progress: Optional Rich Progress instance; if provided, any remaining progress + tasks in the job are removed from the display. """ if not preload_job: return @@ -1850,6 +2236,17 @@ def stop_component_preload(preload_job): executor.shutdown(wait=False, cancel_futures=True) preload_job["executor"] = None + if progress is not None: + task_ids = preload_job.get("task_ids") or {} + owns_tasks = preload_job.get("owns_tasks", True) + if owns_tasks: + for task_id in task_ids.values(): + try: + progress.stop_task(task_id) + progress.remove_task(task_id) + except Exception: + pass + @staticmethod def _apply_progress_updates(progress_updates, progress, task_ids, allowed_endpoints=None): """Drain the progress queue and advance the corresponding Rich progress tasks. @@ -1923,6 +2320,7 @@ def pump_preload_progress(self, preload_job, progress): ) task_ids = preload_job.get("task_ids") or {} + owns_tasks = preload_job.get("owns_tasks", True) for endpoint_name in pending_endpoints: future = futures.get(endpoint_name) if future is None or not future.done(): @@ -1933,8 +2331,10 @@ def pump_preload_progress(self, preload_job, progress): final_total = max(len(records), 1) except Exception: final_total = 1 - progress.update(task_ids[endpoint_name], total=final_total, completed=final_total) - progress.stop_task(task_ids[endpoint_name]) + if owns_tasks: + progress.update(task_ids[endpoint_name], total=final_total, completed=final_total) + progress.stop_task(task_ids[endpoint_name]) + progress.remove_task(task_ids[endpoint_name]) finished_endpoints.add(endpoint_name) advanced = True @@ -1945,17 +2345,20 @@ def preload_all_components( progress_wrapper=None, preload_job=None, progress=None, + manufacturer_slug: str | None = None, + task_registry: dict | None = None, ): """Pre-fetch component templates to avoid N+1 queries during updates. - Always fetches all components globally via GraphQL — fast enough - that vendor/device-type scoping is unnecessary. - Args: - progress_wrapper: Optional callable to wrap iterables with progress bars - preload_job: Optional preload job from start_component_preload(). + progress_wrapper: Optional callable to wrap iterables with progress bars. + preload_job: Optional preload job from :meth:`start_component_preload`. progress: Optional shared Rich Progress instance used to render all caching tasks inside a single progress panel. + manufacturer_slug (str | None): When provided, only fetch component templates + for device/module types belonging to this manufacturer. + task_registry (dict | None): Shared registry for cumulative progress tasks. + When provided, "Caching X" tasks persist across all vendors. """ components = self._component_preload_targets() @@ -1965,11 +2368,31 @@ def preload_all_components( progress_wrapper, preload_job=preload_job, progress=progress, + task_registry=task_registry, + ) + else: + self._preload_global( + components, + progress_wrapper, + progress=progress, + manufacturer_slug=manufacturer_slug, + task_registry=task_registry, ) - self._global_preload_done = True - return - self._preload_global(components, progress_wrapper, progress=progress) + if manufacturer_slug is not None: + try: + vendor_dt_ids = {record.id for record in self.existing_device_types.values()} + vendor_mt_data = self.graphql.get_module_types(manufacturer_slugs=[manufacturer_slug]) + vendor_mt_ids = {record.id for models in vendor_mt_data.values() for record in models.values()} + except Exception as exc: + self.handle.log(f"WARNING: Component cache integrity check skipped: {exc}") + else: + self._verify_component_cache_integrity(vendor_dt_ids, vendor_mt_ids) + # Count check is intentionally outside the warning try/except above: + # a mismatch means GraphQL silently truncated results and the import + # must not proceed with incomplete data. + self._check_component_counts_against_rest(vendor_dt_ids, vendor_mt_ids) + self._global_preload_done = True def _preload_track_progress( @@ -1981,6 +2404,7 @@ def _preload_track_progress( preload_job, progress_updates, endpoint_totals, + owns_tasks=True, ): """Collect preload results and advance progress tasks as each endpoint future completes. @@ -1995,6 +2419,7 @@ def _preload_track_progress( preload_job (dict | None): Shared preload-job state dict, or None. progress_updates (queue.Queue | None): Queue carrying ``(endpoint_name, advance)`` tuples. endpoint_totals (dict): Expected record count per endpoint. + owns_tasks (bool): Whether this call owns the progress tasks and should stop/remove them. Returns: dict: ``{endpoint_name: [records]}`` populated as futures complete. @@ -2008,8 +2433,6 @@ def _preload_track_progress( for endpoint_name in already_done: try: records_by_endpoint[endpoint_name] = future_map[endpoint_name].result() - except GraphQLCountMismatchError: - raise except Exception as exc: self.handle.log(f"Preload failed for {endpoint_name}: {exc}") raise @@ -2020,12 +2443,14 @@ def _preload_track_progress( len(records_by_endpoint[endpoint_name]), 1, ) - progress.update( - task_ids[endpoint_name], - total=final_total, - completed=final_total, - ) - progress.stop_task(task_ids[endpoint_name]) + if owns_tasks: + progress.update( + task_ids[endpoint_name], + total=final_total, + completed=final_total, + ) + progress.stop_task(task_ids[endpoint_name]) + progress.remove_task(task_ids[endpoint_name]) except Exception: pass # Exclude from pending to avoid double stop_task. @@ -2038,6 +2463,7 @@ def _preload_track_progress( progress_updates, endpoint_totals, records_by_endpoint, + owns_tasks=owns_tasks, ) return records_by_endpoint @@ -2050,6 +2476,7 @@ def _drain_pending( progress_updates, endpoint_totals, records_by_endpoint, + owns_tasks=True, ): """Wait for pending endpoint futures to complete, collecting results and updating progress. @@ -2064,6 +2491,7 @@ def _drain_pending( progress_updates (queue.Queue | None): Queue of ``(endpoint_name, advance)`` tuples. endpoint_totals (dict): Expected record count per endpoint. records_by_endpoint (dict): Accumulator dict updated in-place with results. + owns_tasks (bool): Whether this call owns the progress tasks and should stop/remove them. """ while pending: had_updates = self._apply_progress_updates( @@ -2077,8 +2505,6 @@ def _drain_pending( pending.remove(endpoint_name) try: records_by_endpoint[endpoint_name] = future_map[endpoint_name].result() - except GraphQLCountMismatchError: - raise except Exception as exc: self.handle.log(f"Preload failed for {endpoint_name}: {exc}") raise @@ -2088,9 +2514,10 @@ def _drain_pending( 1, ) task_id = task_ids.get(endpoint_name) - if task_id is not None: + if task_id is not None and owns_tasks: progress.update(task_id, total=final_total, completed=final_total) progress.stop_task(task_id) + progress.remove_task(task_id) if pending and not had_updates: if progress_updates is not None: try: @@ -2137,15 +2564,21 @@ def _preload_no_progress(self, components, futures): self.handle.verbose_log(f"Pre-fetching {label}...") try: records_by_endpoint[endpoint] = futures[endpoint].result() - except GraphQLCountMismatchError: - raise except Exception as exc: self.handle.log(f"Preload failed for {label}: {exc}") raise return records_by_endpoint - def _preload_global(self, components, progress_wrapper=None, preload_job=None, progress=None): - """Fetch all component templates globally (no vendor/device filter).""" + def _preload_global( + self, + components, + progress_wrapper=None, + preload_job=None, + progress=None, + manufacturer_slug=None, + task_registry=None, + ): + """Fetch all component templates, optionally scoped to a single manufacturer.""" own_executor = preload_job is None if preload_job: executor = preload_job.get("executor") @@ -2154,7 +2587,7 @@ def _preload_global(self, components, progress_wrapper=None, preload_job=None, p endpoint_totals = preload_job.get("endpoint_totals", {}) else: max_workers = max(1, min(len(components), self.max_threads)) - endpoint_totals = self._get_endpoint_totals(components) + endpoint_totals = {endpoint_name: None for endpoint_name, _label in components} executor = None futures = {} progress_updates = None @@ -2174,7 +2607,7 @@ def update_progress(endpoint_name, advance): self._fetch_global_endpoint_records, endpoint, update_progress, - endpoint_totals.get(endpoint), + manufacturer_slug, ) for endpoint, _label in components } @@ -2184,20 +2617,23 @@ def update_progress(endpoint_name, advance): self._fetch_global_endpoint_records, endpoint, None, - endpoint_totals.get(endpoint), + manufacturer_slug, ) for endpoint, _label in components } if progress is not None: task_ids = preload_job.get("task_ids") if preload_job else None + owns_tasks = preload_job.get("owns_tasks", True) if preload_job else (task_registry is None) if not task_ids: - task_ids = { - endpoint: progress.add_task( - f"Caching {label}", - total=endpoint_totals.get(endpoint), - ) - for endpoint, label in components - } + task_ids = {} + for endpoint, label in components: + desc = f"Caching {label}" + if task_registry is not None: + if desc not in task_registry: + task_registry[desc] = progress.add_task(desc, total=None) + task_ids[endpoint] = task_registry[desc] + else: + task_ids[endpoint] = progress.add_task(desc, total=None) records_by_endpoint = self._preload_track_progress( components, futures, @@ -2206,6 +2642,7 @@ def update_progress(endpoint_name, advance): preload_job, progress_updates, endpoint_totals, + owns_tasks=owns_tasks, ) else: records_by_endpoint = self._preload_no_progress(components, futures) @@ -2223,7 +2660,7 @@ def update_progress(endpoint_name, advance): executor.shutdown(wait=True) preload_job["executor"] = None - def _fetch_global_endpoint_records(self, endpoint_name, progress_callback=None, expected_total=None): + def _fetch_global_endpoint_records(self, endpoint_name, progress_callback=None, manufacturer_slug=None): """Fetch all records for *endpoint_name* from NetBox. Most endpoints are fetched via GraphQL for speed. Endpoints listed in @@ -2242,12 +2679,9 @@ def _fetch_global_endpoint_records(self, endpoint_name, progress_callback=None, progress_callback (callable | None): Called with ``(endpoint_name, advance)`` once per page during the GraphQL fetch (or once after the batch fetch completes for REST endpoints). *advance* is a positive integer equal to - the number of records on that page. On a count-mismatch retry the - callback is invoked once with a **negative** advance to rewind the - progress bar by the same amount, keeping the display consistent across - attempts. - expected_total (int | None): Expected record count obtained from the REST API before the - GraphQL fetch. If provided and the fetched count differs, a warning is logged. + the number of records on that page. + manufacturer_slug (str | None): When provided, only templates belonging to + device types or module types of this manufacturer are fetched. Returns: list: All component template records. @@ -2261,45 +2695,16 @@ def _fetch_global_endpoint_records(self, endpoint_name, progress_callback=None, progress_callback(endpoint_name, len(records)) return records - for attempt in range(_MAX_RETRIES + 1): - # Forward per-page advances LIVE so the progress bar moves while a - # large endpoint (e.g. interfaces, 100k+ records) is fetching. Track - # fetched_this_attempt so that on a count-mismatch retry we can - # rewind by the same amount, keeping the bar consistent. - fetched_this_attempt = 0 - - def _live_advance(n): - nonlocal fetched_this_attempt - fetched_this_attempt += n - if progress_callback is not None and n: - progress_callback(endpoint_name, n) - - on_page = _live_advance if progress_callback is not None else None - records = self.graphql.get_component_templates(endpoint_name, on_page=on_page) - if endpoint_name == "front_port_templates": - records = [_FrontPortRecordWithMappings(r) for r in records] - - if expected_total is not None and len(records) != expected_total: - if attempt < _MAX_RETRIES: - backoff = _RETRY_BACKOFF[attempt] - self.handle.log( - f"WARNING: GraphQL returned {len(records)} {endpoint_name} " - f"but REST API expected {expected_total}. " - f"Retrying in {backoff}s (attempt {attempt + 1}/{_MAX_RETRIES})…" - ) - if progress_callback is not None and fetched_this_attempt: - # Rewind the bar so the next attempt's live advances do - # not double-count. - progress_callback(endpoint_name, -fetched_this_attempt) - time.sleep(backoff) - continue - raise GraphQLCountMismatchError( - f"GraphQL returned {len(records)} {endpoint_name} " - f"but REST API expected {expected_total} " - f"after {_MAX_RETRIES} retries. " - "Run aborted to prevent processing an incomplete component cache." - ) - break + def _live_advance(n): + if progress_callback is not None and n: + progress_callback(endpoint_name, n) + + on_page = _live_advance if progress_callback is not None else None + records = self.graphql.get_component_templates( + endpoint_name, manufacturer_slug=manufacturer_slug, on_page=on_page + ) + if endpoint_name == "front_port_templates": + records = [_FrontPortRecordWithMappings(r) for r in records] return records @@ -2338,6 +2743,106 @@ def _build_component_cache(items): return cache, count + def _verify_component_cache_integrity(self, vendor_dt_ids: set, vendor_mt_ids: set) -> bool: + """Check that cached component records belong to the current vendor. + + For each endpoint in :attr:`cached_components`, verifies that at least one + record has a parent ID (``device_type.id`` or ``module_type.id``) that + appears in *vendor_dt_ids* or *vendor_mt_ids* respectively. A non-empty + endpoint whose records contain **no** matching IDs is treated as garbage + data and cleared. + + Args: + vendor_dt_ids (set): Device type IDs belonging to the current vendor. + vendor_mt_ids (set): Module type IDs belonging to the current vendor. + + Returns: + bool: ``True`` if all non-empty endpoints passed the check, + ``False`` if any were cleared. + """ + all_ok = True + for endpoint_name, entries in list(self.cached_components.items()): + if not entries: + continue + has_valid = any( + (parent_type == "device" and parent_id in vendor_dt_ids) + or (parent_type == "module" and parent_id in vendor_mt_ids) + for (parent_type, parent_id) in entries + ) + if not has_valid: + self.handle.log( + f"ERROR: Cached {endpoint_name} contains no records matching the current vendor — " + "clearing to prevent cross-vendor contamination." + ) + self.cached_components[endpoint_name] = {} + all_ok = False + return all_ok + + def _rest_count_chunked(self, rest_endpoint, filter_key, ids, chunk_size=100): + """Return REST count for *ids* using *filter_key*, chunked to avoid URL-length limits. + + Args: + rest_endpoint: pynetbox endpoint object (e.g. ``self.netbox.dcim.interface_templates``). + filter_key (str): Filter parameter name (e.g. ``"device_type_id"``). + ids (list): List of integer IDs to filter by. + chunk_size (int): Maximum IDs per REST request. + + Returns: + int: Total count across all chunks. + """ + total = 0 + for i in range(0, len(ids), chunk_size): + chunk = ids[i : i + chunk_size] + total += rest_endpoint.count(**{filter_key: chunk}) + return total + + def _check_component_counts_against_rest(self, vendor_dt_ids: set, vendor_mt_ids: set): + """Verify that GraphQL-cached component counts match REST API counts for this vendor. + + For each preloaded component endpoint, counts cached records belonging to the + current vendor and compares with pynetbox REST counts. A discrepancy means + GraphQL silently truncated the fetch and the import should not proceed. + + Args: + vendor_dt_ids: Device type IDs for the current vendor. + vendor_mt_ids: Module type IDs for the current vendor. + + Raises: + GraphQLCountMismatchError: If any endpoint's cached count differs from REST. + """ + dt_filter_key = device_type_filter_key(self.new_filters) + mt_filter_key = module_type_filter_key(self.new_filters) + dt_ids = list(vendor_dt_ids) + mt_ids = list(vendor_mt_ids) + + for endpoint_name, _label in self._component_preload_targets(): + if endpoint_name in self.REST_ONLY_ENDPOINTS: + # REST-only endpoints are fetched via REST already — comparing REST + # count to REST count is tautological and adds no value. + continue + + endpoint_cache = self.cached_components.get(endpoint_name, {}) + cached_count = sum( + len(records) + for (parent_type, parent_id), records in endpoint_cache.items() + if (parent_type == "device" and parent_id in vendor_dt_ids) + or (parent_type == "module" and parent_id in vendor_mt_ids) + ) + + rest_ep = getattr(self.netbox.dcim, endpoint_name) + rest_count = 0 + if dt_ids: + rest_count += self._rest_count_chunked(rest_ep, dt_filter_key, dt_ids) + if mt_ids and endpoint_name not in self._NO_MODULE_TYPE_ENDPOINTS: + rest_count += self._rest_count_chunked(rest_ep, mt_filter_key, mt_ids) + + if cached_count != rest_count: + raise GraphQLCountMismatchError( + f"{endpoint_name}: GraphQL returned {cached_count} records " + f"but REST reports {rest_count} — " + "GraphQL may have silently truncated the result set." + ) + def _get_filter_kwargs(self, parent_id, parent_type="device"): """Build endpoint filter keyword arguments for the given parent type and ID. diff --git a/core/repo.py b/core/repo.py index 0bb6eb495..fa3c5380e 100644 --- a/core/repo.py +++ b/core/repo.py @@ -1,14 +1,93 @@ """Git repository helpers for cloning, updating, and parsing the device-type library.""" import os +import pickle from glob import glob from re import sub as re_sub +from typing import Optional from urllib.parse import urlparse from git import Repo, exc import yaml import concurrent.futures +class _RestrictedUnpickler(pickle.Unpickler): + """Unpickler that refuses to instantiate any class. + + The DTL upstream pickle files (``tests/known-*.pickle``) contain only + plain sets of (str, str) tuples and require no GLOBAL opcodes. This + subclass overrides ``find_class`` so that if a crafted/malicious pickle + were ever substituted it could not import or execute arbitrary code. + """ + + def find_class(self, module, name): + raise pickle.UnpicklingError( + f"DTL pickle safety: loading class '{module}.{name}' is not permitted. " + "The known-*.pickle files must contain only sets of string tuples." + ) + + +_PICKLE_MAX_BYTES = 10 * 1024 * 1024 # 10 MiB — DTL pickles are typically <500 KiB + + +def _vendor_slugs_from_pickle( + pickle_path: str, slugs_lower: list, slug_format, subdir_filter: "Optional[str]" = None +) -> "Optional[set]": + """Load a (model_name, vendor_dir) pickle and return the set of vendor slugs matching *slugs_lower*. + + *subdir_filter*, if given, requires the vendor_dir to contain that substring. + Returns ``None`` when the pickle is unavailable (missing or unreadable), so callers + can distinguish "no matches" (empty set) from "hint unavailable" (None). + """ + if not os.path.exists(pickle_path): + return None + try: + entries = _safe_pickle_load(pickle_path) + except Exception: + return None + result = set() + for model_name, vendor_dir in entries: + if subdir_filter and subdir_filter not in vendor_dir.replace("\\", "/"): + continue + if not any(s in model_name.casefold() for s in slugs_lower): + continue + vendor_name = vendor_dir.replace("\\", "/").split("/")[-1] + result.add(slug_format(vendor_name)) + return result + + +def _safe_abs_path(repo_root: str, relpath: str) -> "Optional[str]": + """Return the absolute path for *relpath* inside *repo_root*, or None if it escapes the root.""" + abs_path = os.path.normpath(os.path.join(repo_root, *relpath.replace("\\", "/").split("/"))) + return abs_path if abs_path.startswith(os.path.normpath(repo_root) + os.sep) else None + + +def _safe_pickle_load(path: str): + """Load a DTL upstream pickle using the restricted unpickler. + + Enforces a hard size cap before unpickling and validates the loaded object + is a set/list of (str, str) 2-tuples so malformed/oversized pickles cannot + cause resource exhaustion. Returns the loaded set on success or raises + ``ValueError`` on shape violations (callers should catch and fall back). + """ + size = os.path.getsize(path) + if size > _PICKLE_MAX_BYTES: + raise ValueError(f"Pickle file {path!r} is {size} bytes (limit {_PICKLE_MAX_BYTES}); refusing to load.") + with open(path, "rb") as fh: + data = _RestrictedUnpickler(fh).load() + if not isinstance(data, (set, list, frozenset)): + raise ValueError(f"Unexpected pickle root type {type(data).__name__!r}; expected set/list.") + for item in data: + if ( + not isinstance(item, tuple) + or len(item) != 2 + or not isinstance(item[0], str) + or not isinstance(item[1], str) + ): + raise ValueError(f"Unexpected item shape in pickle: {item!r}") + return data + + def validate_git_url(url): """Determine whether a Git remote URL is allowed (HTTPS, SSH, or file://). @@ -403,6 +482,118 @@ def get_devices(self, base_path, vendors: list = None): files.extend(glob(os.path.join(base_path, folder, f"*.{extension}"))) return files, discovered_vendors + def resolve_slug_files(self, slugs): + """Use the upstream pickle indexes to resolve YAML file paths for slug/model matches. + + The DTL repo ships three pickle files under ``tests/``: + + * ``known-slugs.pickle`` — set of ``(manufacturer_prefixed_slug, relpath)`` for + device types. ``relpath`` is relative to the repo root, e.g. + ``device-types/Nokia/7750-SR-7s.yaml``. + * ``known-modules.pickle`` — set of ``(model_name, vendor_dir)`` for module + types. Only the vendor directory is stored, not the file name. + * ``known-racks.pickle`` — same format as known-modules. + + Matching uses a **case-insensitive substring** check identical to the runtime + :meth:`parse_files` filter so that partial slug/model searches work the same way. + + Args: + slugs (list[str]): User-supplied slug/model substrings (``--slugs``). + + Returns: + dict or None: ``None`` when the device pickle is unavailable (caller falls back + to the normal glob path). Otherwise a dict with the keys: + + ``"device_files"`` + ``{vendor_slug: [abs_path, ...]}`` for devices resolved via pickle. + ``"module_vendors"`` + ``{vendor_slug}`` — set of vendor slugs that may contain matching module + types, or ``None`` when the module pickle was unavailable (caller should + fall back to full glob+parse instead of skipping). + ``"rack_vendors"`` + ``{vendor_slug}`` — same for rack types; ``None`` means unavailable. + """ + repo_root = self.get_absolute_path() + device_pickle = os.path.join(repo_root, "tests", "known-slugs.pickle") + module_pickle = os.path.join(repo_root, "tests", "known-modules.pickle") + rack_pickle = os.path.join(repo_root, "tests", "known-racks.pickle") + + if not os.path.exists(device_pickle): + return None + + slugs_lower = [s.casefold() for s in slugs] + + # --- device types -------------------------------------------------- + device_files = {} # vendor_slug -> [abs_path] + try: + known_slugs = _safe_pickle_load(device_pickle) + except Exception: + return None + + for entry_slug, relpath in known_slugs: + if not any(s in entry_slug.casefold() for s in slugs_lower): + continue + parts = relpath.replace("\\", "/").split("/") + if len(parts) < 3: + continue + vendor_name = parts[1] + vendor_slug = self.slug_format(vendor_name) + abs_path = _safe_abs_path(repo_root, relpath) + if abs_path is None: + continue + device_files.setdefault(vendor_slug, []).append(abs_path) + + # --- module types -------------------------------------------------- + module_vendors = _vendor_slugs_from_pickle(module_pickle, slugs_lower, self.slug_format) + + # --- rack types ---------------------------------------------------- + rack_vendors = _vendor_slugs_from_pickle(rack_pickle, slugs_lower, self.slug_format, subdir_filter="rack-types") + + return { + "device_files": device_files, + "module_vendors": module_vendors, + "rack_vendors": rack_vendors, + } + + def discover_vendors(self, devices_path, modules_path, racks_path): + """Discover all vendor directories across device-types/, module-types/, and rack-types/. + + Args: + devices_path (str): Path to device-types directory. + modules_path (str): Path to module-types directory. + racks_path (str): Path to rack-types directory. + + Returns: + list: Sorted list of unique vendor dictionaries with keys 'name' (str) and 'slug' (str). + Vendors are deduplicated across all three paths and the "testing" folder is excluded. + """ + vendors_dict = {} # Use dict to deduplicate by slug + + for path in [devices_path, modules_path, racks_path]: + if not os.path.exists(path): + continue + + try: + vendor_dirs = sorted(os.listdir(path)) + except OSError: + continue + + for folder in vendor_dirs: + if folder.casefold() == "testing": + continue + + full_path = os.path.join(path, folder) + if not os.path.isdir(full_path): + continue + + slug = self.slug_format(folder) + # Only add if we haven't seen this slug before + if slug not in vendors_dict: + vendors_dict[slug] = {"name": folder, "slug": slug} + + # Return sorted list by slug + return sorted(vendors_dict.values(), key=lambda v: v["slug"]) + def parse_files(self, files: list, slugs: list = None, progress=None): """Parse YAML device files into device type dicts, optionally filtering and tracking progress. diff --git a/nb-dt-import.py b/nb-dt-import.py index e644e0590..9c9e5c866 100644 --- a/nb-dt-import.py +++ b/nb-dt-import.py @@ -2,18 +2,19 @@ """Entry-point script for importing NetBox device and module types from the community library.""" from datetime import datetime -import concurrent.futures import os from argparse import ArgumentParser from contextlib import contextmanager from core import settings -from core.netbox_api import NetBox +from core.netbox_api import NetBox, _fmt_connection_error from core.log_handler import LogHandler from core.repo import DTLRepo from core.change_detector import ChangeDetector, ChangeType, IMAGE_PROPERTIES from core.graphql_client import GraphQLError from pynetbox.core.query import RequestError as NetBoxRequestError +import re +import requests import sys @@ -22,6 +23,7 @@ BarColumn, MofNCompleteColumn, Progress, + ProgressBar, ProgressColumn, SpinnerColumn, TaskProgressColumn, @@ -35,6 +37,43 @@ _PROGRESS_DESC_WIDTH = 28 # Longest: "Caching Console Server Ports" +class NoPulseBarColumn(BarColumn): + """BarColumn that never pulses — shows a static empty bar when total is unknown. + + Rich's default BarColumn sets ``pulse=True`` whenever ``task.total is None``, + which produces a scrolling rainbow-gradient animation. That causes a + continuous stream of ANSI color codes on every render frame, creating a + distracting "disco" effect on the terminal. This subclass always passes + ``pulse=False`` so the bar stays static when total is unknown. + """ + + def render(self, task): + """Render a static progress bar (no pulsing gradient). + + Rich's ProgressBar triggers pulse when ``total is None`` regardless of + the ``pulse`` flag (``should_pulse = self.pulse or self.total is None``). + When total is unknown we substitute total=1, completed=0 to get a plain + static empty bar instead of the scrolling rainbow gradient. + """ + if task.total is None: + total: float = 1.0 + completed: float = 0.0 + else: + total = max(0.0, task.total) + completed = max(0.0, task.completed) + return ProgressBar( + total=total, + completed=completed, + width=None if self.bar_width is None else max(1, self.bar_width), + pulse=False, + animation_time=task.get_time(), + style=self.style, + complete_style=self.complete_style, + finished_style=self.finished_style, + pulse_style=self.pulse_style, + ) + + class MyProgress(Progress): """Rich Progress subclass that renders each task table inside a bordered Panel.""" @@ -86,7 +125,7 @@ def get_progress_panel(show_remaining_time=False): columns = [ SpinnerColumn(), TextColumn("[progress.description]{task.description}"), - BarColumn(), + NoPulseBarColumn(), TaskProgressColumn(), MofNCompleteColumn(), TimeElapsedColumn(), @@ -97,11 +136,12 @@ def get_progress_panel(show_remaining_time=False): with MyProgress( *columns, + refresh_per_second=4, ) as progress: yield progress -def get_progress_wrapper(progress, iterable, desc=None, total=None, on_step=None): +def get_progress_wrapper(progress, iterable, desc=None, total=None, on_step=None, task_registry=None): """Wrap *iterable* with a Rich progress task if *progress* is provided, otherwise return *iterable* unchanged. Args: @@ -110,6 +150,9 @@ def get_progress_wrapper(progress, iterable, desc=None, total=None, on_step=None desc (str | None): Task description shown in the progress bar. total (int | None): Total number of items; inferred from ``len(iterable)`` if omitted. on_step (callable | None): Optional callback invoked after each item and at the end. + task_registry (dict | None): When provided, tasks are created once per description and + reused across calls — counts accumulate rather than resetting per vendor. The caller + is responsible for stopping/removing tasks at the end of the run. Returns: The original iterable if *progress* is None, otherwise a generator that advances @@ -125,7 +168,15 @@ def get_progress_wrapper(progress, iterable, desc=None, total=None, on_step=None except TypeError: total = None - task_id = progress.add_task(description, total=total) + if task_registry is not None: + # Cumulative mode: create the task once, reuse it across vendors. + # NOTE: `total` is intentionally ignored here — the final count is + # unknown at task-creation time and is resolved during finalization. + if description not in task_registry: + task_registry[description] = progress.add_task(description, total=None) + task_id = task_registry[description] + else: + task_id = progress.add_task(description, total=total) def iterator(): """Yield items from *iterable* while advancing the progress task.""" @@ -138,30 +189,18 @@ def iterator(): if on_step: on_step() finally: - if total is None: - progress.update(task_id, total=max(count, 1), completed=count) - progress.stop_task(task_id) + if task_registry is None: + # Non-cumulative: finalize and clean up this vendor's task. + if total is None: + progress.update(task_id, total=max(count, 1), completed=count) + progress.stop_task(task_id) + progress.remove_task(task_id) if on_step: on_step() return iterator() -def filter_vendors_for_parsed_types(discovered_vendors, parsed_types): - """Return only the vendors referenced in *parsed_types* and the set of their slugs. - - Args: - discovered_vendors (list[dict]): All vendors discovered in the repo (each has a "slug" key). - parsed_types (list[dict]): Parsed device-type dicts; each must have a ``manufacturer.slug`` entry. - - Returns: - tuple[list[dict], set[str]]: Filtered vendor list and the corresponding slug set. - """ - selected_vendor_slugs = {item["manufacturer"]["slug"] for item in parsed_types} - filtered_vendors = [vendor for vendor in discovered_vendors if vendor["slug"] in selected_vendor_slugs] - return filtered_vendors, selected_vendor_slugs - - def filter_new_device_types(device_types, existing_by_model, existing_by_slug): """Return device types that do not already exist in NetBox. @@ -230,14 +269,34 @@ def filter_device_types_by_change_keys(device_types, change_keys): return [device_type for device_type in device_types if device_type_key(device_type) in change_keys] -def select_device_types_for_default_mode(device_types, change_report): +def _device_types_with_images_keys(device_types): + """Return the set of change-detection keys for device types that declare images in their YAML. + + Used by ``--verify-images`` to ensure image-bearing device types are processed even + when the change detector reports them as unchanged (e.g. image exists in the DB but the + physical file is gone from the server). + + Args: + device_types (list[dict]): All parsed device-type dicts. + + Returns: + set: Keys of device types that have ``front_image`` or ``rear_image`` set to True. + """ + return {device_type_key(dt) for dt in device_types if dt.get("front_image") or dt.get("rear_image")} + + +def select_device_types_for_default_mode(device_types, change_report, verify_images=False): """Select device types to process in default (non-update) mode. Includes newly discovered device types and existing ones with missing images. + When *verify_images* is True, also includes all device types that declare images + so their physical presence can be verified. Args: device_types (list[dict]): All parsed device-type dicts. change_report (ChangeReport | None): Change detection results; if None returns []. + verify_images (bool): When True, include image-bearing device types for physical + verification even if no DB-level change is detected. Returns: list[dict]: Device types that are new or have missing images. @@ -251,17 +310,23 @@ def select_device_types_for_default_mode(device_types, change_report): for change in change_report.modified_device_types if any(property_change.property_name in IMAGE_PROPERTIES for property_change in change.property_changes) } - return filter_device_types_by_change_keys(device_types, new_keys | image_change_keys) + all_keys = new_keys | image_change_keys + if verify_images: + all_keys |= _device_types_with_images_keys(device_types) + return filter_device_types_by_change_keys(device_types, all_keys) -def select_device_types_for_update_mode(device_types, change_report): +def select_device_types_for_update_mode(device_types, change_report, verify_images=False): """Select device types to process in update (``--update``) mode. - Includes all new and modified device types. + Includes all new and modified device types. When *verify_images* is True, also + includes device types that declare images so their physical presence can be checked. Args: device_types (list[dict]): All parsed device-type dicts. change_report (ChangeReport | None): Change detection results; if None returns []. + verify_images (bool): When True, include image-bearing device types for physical + verification even if no DB-level change is detected. Returns: list[dict]: Device types that are either new or have detected changes. @@ -271,6 +336,8 @@ def select_device_types_for_update_mode(device_types, change_report): actionable_keys = {change_entry_key(change) for change in change_report.new_device_types} actionable_keys.update(change_entry_key(change) for change in change_report.modified_device_types) + if verify_images: + actionable_keys |= _device_types_with_images_keys(device_types) return filter_device_types_by_change_keys(device_types, actionable_keys) @@ -304,6 +371,11 @@ def log_run_mode(handle, args): handle.log("Mode: --update enabled; changed properties and components on existing models will be updated.") if args.remove_components: handle.log("Mode: --remove-components enabled; missing components will be removed from existing models.") + if getattr(args, "remove_unmanaged_types", False): + handle.log( + "Mode: --remove-unmanaged-types enabled; components whose entire YAML section is missing " + "will also be removed from existing models." + ) else: handle.log( "Mode: will not remove components from existing models; use --remove-components with " @@ -316,6 +388,11 @@ def log_run_mode(handle, args): ) else: handle.log("Mode: --update not set; changed properties/components will not be applied (use --update).") + if getattr(args, "verify_images", False): + handle.log( + "Mode: --verify-images enabled; images already recorded in NetBox will be verified via HTTP " + "and re-uploaded if missing or content has changed." + ) def should_only_create_new_modules(args): @@ -329,13 +406,16 @@ def _image_progress_scope(progress, device_types, total=0): Creates a progress task (if *progress* is not None and *total* > 0), assigns the advance callback to ``device_types._image_progress``, and - always resets it to ``None`` on exit — even on exception. + always resets it to ``None`` on exit — even on exception. The task is + removed from the progress display on exit so completed upload bars do not + accumulate when multiple vendors are processed in sequence. Args: progress: Rich Progress instance, or None. device_types: ``DeviceTypes`` helper whose ``_image_progress`` callback is set. total (int): Pre-counted number of images to upload. If 0, no progress bar is shown. """ + _img_task = None if progress is not None and total > 0: _img_task = progress.add_task("Uploading Images", total=total) @@ -348,6 +428,9 @@ def _adv_img(count=1): yield finally: device_types._image_progress = None + if progress is not None and _img_task is not None: + progress.stop_task(_img_task) + progress.remove_task(_img_task) def _check_env_vars(handle): @@ -381,26 +464,16 @@ def _log_import_filters(handle, args): handle.log(f"Filtering by slugs: {', '.join(args.slugs)}") -def _bg_parse_module_types(dtl_repo, module_vendor_filter, slugs): - """Discover and parse module-type YAML files; designed for background execution. - - Args: - dtl_repo (DTLRepo): Repository helper for file discovery and YAML parsing. - module_vendor_filter (list[str]): Vendor slugs used to scope file discovery. - slugs (list[str]): Device-type slug filters passed to ``parse_files``. - - Returns: - tuple[list, list, list]: ``(files, discovered_vendors, module_types)``. - *module_types* is an empty list when no files are discovered. - """ - bg_files, bg_vendors = dtl_repo.get_devices(dtl_repo.get_modules_path(), module_vendor_filter) - if not bg_files: - return [], bg_vendors, [] - bg_module_types = dtl_repo.parse_files(bg_files, slugs=slugs) - return bg_files, bg_vendors, bg_module_types - - -def _process_device_types(args, netbox, handle, progress, device_types, cache_preload_job): +def _process_device_types( + args, + netbox, + handle, + progress, + device_types, + cache_preload_job, + vendor_slug=None, + task_registry=None, +): """Process device types according to the active run mode. Handles *only_new*, *update*, and default mode device-type processing, @@ -414,6 +487,8 @@ def _process_device_types(args, netbox, handle, progress, device_types, cache_pr progress: Rich Progress instance for progress display, or None. device_types (list[dict]): Parsed device-type dicts to process. cache_preload_job: Background component-cache preload job, or None. + vendor_slug (str | None): Manufacturer slug for scoped integrity checks. + task_registry (dict | None): Shared cumulative progress task registry. Returns: The updated *cache_preload_job*: ``None`` if the job was consumed by @@ -430,33 +505,49 @@ def _process_device_types(args, netbox, handle, progress, device_types, cache_pr with _image_progress_scope(progress, netbox.device_types, total=image_total): netbox.create_device_types( new_device_types, - progress=get_progress_wrapper(progress, new_device_types, desc="Creating Device Types"), + progress=get_progress_wrapper( + progress, + new_device_types, + desc="Creating Device Types", + task_registry=task_registry, + ), only_new=True, ) else: handle.verbose_log("No new device types to create.") return cache_preload_job - # Non-only_new path: preload cache then detect changes. - if device_types: - handle.verbose_log("Caching NetBox data for comparison (concurrent API requests started during parsing)...") + # Non-only_new path: always consume the preload job if one was started. + # This is required even when device_types is empty (e.g. module-type-only vendor) + # so that _global_preload_done is set before _process_module_types runs. + if cache_preload_job is not None: + handle.verbose_log("Caching NetBox data for comparison (concurrent API requests started after parsing)...") netbox.device_types.preload_all_components( progress=progress, preload_job=cache_preload_job, + manufacturer_slug=vendor_slug, + task_registry=task_registry, ) cache_preload_job = None - else: - handle.log("No device types matched filters. Skipping NetBox cache preload.") - detector = ChangeDetector(netbox.device_types, handle) + if not device_types: + handle.verbose_log("No device types matched filters.") + + detector = ChangeDetector( + netbox.device_types, + handle, + remove_unmanaged_types=args.remove_unmanaged_types, + ) change_report = detector.detect_changes( device_types, - progress=get_progress_wrapper(progress, device_types, desc="Detecting Changes"), + progress=get_progress_wrapper(progress, device_types, desc="Detecting Changes", task_registry=task_registry), ) detector.log_change_report(change_report) if args.update: - device_types_to_process = select_device_types_for_update_mode(device_types, change_report) + device_types_to_process = select_device_types_for_update_mode( + device_types, change_report, verify_images=getattr(args, "verify_images", False) + ) if device_types_to_process: image_total = netbox.count_device_type_images(device_types_to_process) with _image_progress_scope(progress, netbox.device_types, total=image_total): @@ -466,6 +557,7 @@ def _process_device_types(args, netbox, handle, progress, device_types, cache_pr progress, device_types_to_process, desc="Processing Device Types", + task_registry=task_registry, ), only_new=False, update=True, @@ -475,13 +567,20 @@ def _process_device_types(args, netbox, handle, progress, device_types, cache_pr else: handle.verbose_log("No device type changes to process.") else: - device_types_to_process = select_device_types_for_default_mode(device_types, change_report) + device_types_to_process = select_device_types_for_default_mode( + device_types, change_report, verify_images=getattr(args, "verify_images", False) + ) if device_types_to_process: image_total = netbox.count_device_type_images(device_types_to_process) with _image_progress_scope(progress, netbox.device_types, total=image_total): netbox.create_device_types( device_types_to_process, - progress=get_progress_wrapper(progress, device_types_to_process, desc="Creating Device Types"), + progress=get_progress_wrapper( + progress, + device_types_to_process, + desc="Creating Device Types", + task_registry=task_registry, + ), only_new=True, ) else: @@ -490,56 +589,21 @@ def _process_device_types(args, netbox, handle, progress, device_types, cache_pr return cache_preload_job -def _process_module_types( - args, - netbox, - dtl_repo, - handle, - progress, - selected_vendor_slugs, - *, - module_parse_future=None, - module_parse_executor=None, -): - """Process module types, retrieving data from a background future or parsing synchronously. +def _process_module_types(args, netbox, handle, progress, module_types, task_registry=None): + """Process module types for a single vendor. Args: - args: Parsed CLI arguments; inspects ``vendors``, ``slugs``, ``only_new``, - and ``update``. + args: Parsed CLI arguments; inspects ``only_new``, ``update``, and + ``remove_components``. netbox (NetBox): NetBox API wrapper instance. - dtl_repo (DTLRepo): Repository helper for file discovery and YAML parsing; - used only when *module_parse_future* is ``None``. handle (LogHandler): Logging handler used to emit progress messages. progress: Rich Progress instance for progress display, or None. - selected_vendor_slugs (set[str]): Vendor slugs derived from parsed device - types, used to scope module discovery when ``--slugs`` is set. - module_parse_future: Background ``concurrent.futures.Future`` that returns - ``(files, vendors, module_types)``, or ``None`` for synchronous parsing. - module_parse_executor: ``ThreadPoolExecutor`` used to start the background - parse; shut down after result retrieval. Ignored when - *module_parse_future* is ``None``. + module_types (list[dict]): Pre-parsed module-type dicts for this vendor. + task_registry (dict | None): Shared cumulative progress task registry. """ - if module_parse_future is not None: - module_files, discovered_module_vendors, module_types = module_parse_future.result() - if module_parse_executor is not None: - module_parse_executor.shutdown(wait=False) - if not module_files: - module_types = [] - else: - module_vendor_filter = args.vendors - if args.slugs and not args.vendors: - module_vendor_filter = sorted(selected_vendor_slugs) - module_files, discovered_module_vendors = dtl_repo.get_devices( - dtl_repo.get_modules_path(), module_vendor_filter - ) - if not module_files: - module_types = [] - else: - module_parse_progress = get_progress_wrapper(progress, module_files, desc="Parsing Module Types") - module_types = dtl_repo.parse_files(module_files, slugs=args.slugs, progress=module_parse_progress) + if not module_types: + return - module_vendors, _ = filter_vendors_for_parsed_types(discovered_module_vendors, module_types) - handle.verbose_log(f"{len(module_vendors)} Module Vendors Found") handle.verbose_log(f"{len(module_types)} Module-Types Found") module_only_new = should_only_create_new_modules(args) @@ -566,43 +630,53 @@ def _process_module_types( if removed_in_group: pending_removal_modules += 1 pending_removal_components += len(removed_in_group) - handle.log("============================================================") - handle.log("MODULE TYPE CHANGE DETECTION") - handle.log("============================================================") - if args.only_new: - handle.log(f"New module types: {new_module_count}") - else: - module_changed_count = len(changed_property_log) - module_unchanged_count = len(module_types) - len(module_types_to_process) - # Modules with only missing image attachments — handled in default mode, so - # they are NOT included in the "modified" count and do NOT trigger the - # `--update` hint. - image_only_count = max(0, len(module_types_to_process) - new_module_count - module_changed_count) - handle.log(f"New module types: {new_module_count}") - handle.log(f"Unchanged module types: {module_unchanged_count}") - handle.log(f"Modified module types: {module_changed_count}") - if image_only_count: - handle.log(f"Image-only updates: {image_only_count}") - if module_changed_count and not args.update: - handle.log(" (Run with --update to apply changes to existing module types)") - if pending_removal_modules and not args.remove_components: - remove_hint = "--remove-components" if args.update else "--update --remove-components" - handle.log( - f" (Run with {remove_hint} to remove {pending_removal_components} stale " - f"component(s) across {pending_removal_modules} module type(s))" - ) - handle.log("------------------------------------------------------------") - netbox.log_module_type_changes(changed_property_log) + + module_changed_count = len(changed_property_log) + module_unchanged_count = len(module_types) - len(module_types_to_process) if not args.only_new else 0 + + has_module_changes = new_module_count > 0 or module_changed_count > 0 or pending_removal_modules > 0 + if has_module_changes: + handle.log("============================================================") + handle.log("MODULE TYPE CHANGE DETECTION") + handle.log("============================================================") + if args.only_new: + handle.log(f"New module types: {new_module_count}") + else: + # Modules with only missing image attachments — handled in default mode, so + # they are NOT included in the "modified" count and do NOT trigger the + # `--update` hint. + image_only_count = max(0, len(module_types_to_process) - new_module_count - module_changed_count) + handle.log(f"New module types: {new_module_count}") + handle.log(f"Unchanged module types: {module_unchanged_count}") + handle.log(f"Modified module types: {module_changed_count}") + if image_only_count: + handle.log(f"Image-only updates: {image_only_count}") + if module_changed_count and not args.update: + handle.log(" (Run with --update to apply changes to existing module types)") + if pending_removal_modules and not args.remove_components: + remove_hint = "--remove-components" if args.update else "--update --remove-components" + handle.log( + f" (Run with {remove_hint} to remove {pending_removal_components} stale " + f"component(s) across {pending_removal_modules} module type(s))" + ) + handle.log("------------------------------------------------------------") + netbox.log_module_type_changes(changed_property_log) + elif module_unchanged_count: + handle.verbose_log(f"No module type changes ({module_unchanged_count} unchanged).") if module_types_to_process: - netbox.create_manufacturers(module_vendors) module_image_total = netbox.count_module_type_images( module_types_to_process, existing_module_types, module_type_existing_images ) with _image_progress_scope(progress, netbox.device_types, total=module_image_total): netbox.create_module_types( module_types_to_process, - progress=get_progress_wrapper(progress, module_types_to_process, desc="Processing Module Types"), + progress=get_progress_wrapper( + progress, + module_types_to_process, + desc="Processing Module Types", + task_registry=task_registry, + ), only_new=module_only_new, all_module_types=existing_module_types, module_type_existing_images=module_type_existing_images, @@ -612,44 +686,26 @@ def _process_module_types( handle.verbose_log("No module type changes to process.") -def _process_rack_types(args, netbox, dtl_repo, handle, progress, selected_vendor_slugs): - """Discover, parse, and import rack types from the repository into NetBox. +def _process_rack_types(args, netbox, handle, progress, rack_types, task_registry=None): + """Process rack types for a single vendor. Soft-skips with a warning when the connected NetBox instance is older than 4.1. - Honors ``--vendors`` and ``--slugs`` filters. Args: - args: Parsed CLI arguments; inspects ``vendors``, ``slugs``, and ``only_new``. + args: Parsed CLI arguments; inspects ``only_new``. netbox (NetBox): NetBox API wrapper instance. - dtl_repo (DTLRepo): Repository helper for file discovery and YAML parsing. handle (LogHandler): Logging handler used to emit progress messages. progress: Rich Progress instance for progress display, or None. - selected_vendor_slugs (set[str]): Vendor slugs derived from parsed device - types, used to scope rack-type discovery when ``--slugs`` is set. + rack_types (list[dict]): Pre-parsed rack-type dicts for this vendor. + task_registry (dict | None): Shared cumulative progress task registry. """ - if not netbox.rack_types: - handle.log("Rack types require NetBox >= 4.1. Skipping rack type import.") - return - - racks_path = dtl_repo.get_racks_path() - if not os.path.isdir(racks_path): - handle.verbose_log("No rack-types directory found in repository. Skipping.") + if not rack_types: return - rack_vendor_filter = args.vendors - if args.slugs and not args.vendors: - rack_vendor_filter = sorted(selected_vendor_slugs) - - rack_files, discovered_rack_vendors = dtl_repo.get_devices(racks_path, rack_vendor_filter) - if not rack_files: - handle.verbose_log("No rack-type files found for the selected vendors/slugs.") + if not netbox.rack_types: + handle.log("Rack types require NetBox >= 4.1. Skipping rack type import.") return - rack_parse_progress = get_progress_wrapper(progress, rack_files, desc="Parsing Rack Types") - rack_types = dtl_repo.parse_files(rack_files, slugs=args.slugs, progress=rack_parse_progress) - - rack_vendors, _ = filter_vendors_for_parsed_types(discovered_rack_vendors, rack_types) - handle.verbose_log(f"{len(rack_vendors)} Rack Vendors Found") handle.verbose_log(f"{len(rack_types)} Rack-Types Found") all_rack_types = netbox.get_existing_rack_types() @@ -660,16 +716,23 @@ def _process_rack_types(args, netbox, dtl_repo, handle, progress, selected_vendo ) existing_count = len(rack_types) - new_count - handle.log("============================================================") - handle.log(f"New rack types: {new_count}") - handle.log(f"Existing rack types: {existing_count}") - handle.log("============================================================") + if new_count == 0: + handle.verbose_log(f"No new rack types ({existing_count} unchanged).") + else: + handle.log("============================================================") + handle.log(f"New rack types: {new_count}") + handle.log(f"Existing rack types: {existing_count}") + handle.log("============================================================") if rack_types: - netbox.create_manufacturers(rack_vendors) netbox.create_rack_types( rack_types, - progress=get_progress_wrapper(progress, rack_types, desc="Processing Rack Types"), + progress=get_progress_wrapper( + progress, + rack_types, + desc="Processing Rack Types", + task_registry=task_registry, + ), only_new=args.only_new, all_rack_types=all_rack_types, ) @@ -732,15 +795,116 @@ def _log_run_summary(handle, netbox, start_time, dtl_repo=None): handle.log("These duplicates would otherwise oscillate on every run. Please report/fix them upstream.") -def main(): - """Orchestrate importing device- and module-types from a Git repository into NetBox. +def _parse_vendor_racks(dtl_repo, racks_path, vendor_name, slugs): + """Parse rack-type YAML files for *vendor_name*, returning an empty list when *racks_path* is absent. - Parses CLI arguments, validates environment variables, clones/pulls the DTL repo, - parses YAML files, and creates manufacturers, device types, and module types in NetBox. - Reports progress and summary counters. + Args: + dtl_repo (DTLRepo): Repository helper used for file discovery and parsing. + racks_path (str): Base path for rack-type YAML files. + vendor_name (str): Vendor directory name to filter files by. + slugs (list[str]): Optional rack-type slug filter. + + Returns: + list[dict]: Parsed rack-type records (may be empty). """ - startTime = datetime.now() + if not os.path.isdir(racks_path): + return [] + rack_files, _ = dtl_repo.get_devices(racks_path, [vendor_name.casefold()]) + return dtl_repo.parse_files(rack_files, slugs=slugs) + + +def _finalize_task_registry(progress, task_registry): + """Resolve unknown totals and stop spinners for all cumulative registry tasks. + + Args: + progress: Rich Progress instance, or None. + task_registry (dict | None): Mapping of description → task ID. + """ + if not progress or not task_registry: + return + for task_id in task_registry.values(): + task = next((t for t in progress.tasks if t.id == task_id), None) + if task is None: + continue + if task.total is None: + progress.update(task_id, total=max(task.completed, 0)) + progress.stop_task(task_id) + + +def _validate_argument_combinations(parser, args): + """Apply mutual-dependency checks for CLI flags and exit via parser.error on violation.""" + if args.export_diff and (args.update or args.only_new): + parser.error("--export-diff cannot be used with --update or --only-new") + if args.export_diff and args.remove_components: + parser.error("--export-diff cannot be used with --remove-components") + if args.export_diff and getattr(args, "remove_unmanaged_types", False): + parser.error("--remove-unmanaged-types is an import-only flag and cannot be used with --export-diff") + if args.export_diff and getattr(args, "slugs", None): + parser.error("--slugs is an import-only flag and cannot be used with --export-diff") + if args.export_diff and getattr(args, "verify_images", False): + parser.error("--verify-images is an import-only flag and cannot be used with --export-diff") + if args.export_diff and getattr(args, "force_resolve_conflicts", False): + parser.error("--force-resolve-conflicts is an import-only flag and cannot be used with --export-diff") + if args.remove_components and not args.update: + parser.error("--remove-components requires --update") + if args.remove_unmanaged_types and not args.remove_components: + parser.error("--remove-unmanaged-types requires --remove-components") + if args.force_resolve_conflicts and not args.update: + parser.error("--force-resolve-conflicts requires --update") + +def _apply_slug_fast_path(dtl_repo, args, vendors_to_process, handle): + """Use upstream pickle indexes to pre-resolve files and narrow the vendor list. + + When ``args.slugs`` is set and the DTL pickle files are present, resolves + exactly which device-type files match the requested slugs and restricts + ``vendors_to_process`` to only those vendors. Returns a ``(vendors, resolved)`` + pair where *resolved* is the dict returned by :meth:`DTLRepo.resolve_slug_files` + (or ``None`` when the pickle is absent/unavailable). + """ + if not args.slugs: + return vendors_to_process, None + + slug_resolved = dtl_repo.resolve_slug_files(args.slugs) + if slug_resolved is None: + handle.verbose_log("Slug pickle unavailable; falling back to full file scan.") + return vendors_to_process, None + + matched_vendor_slugs = ( + set(slug_resolved["device_files"]) + | (slug_resolved["module_vendors"] or set()) + | (slug_resolved["rack_vendors"] or set()) + ) + if not matched_vendor_slugs: + handle.verbose_log("Slug pickle returned no matches; falling back to full file scan.") + return vendors_to_process, None + narrowed = [v for v in vendors_to_process if v["slug"] in matched_vendor_slugs] + handle.verbose_log( + f"Slug pickle resolved {sum(len(f) for f in slug_resolved['device_files'].values())} " + f"device file(s) across {len(matched_vendor_slugs)} vendor(s)." + ) + return narrowed, slug_resolved + + +def _run_export_diff(settings, handle, args): + """Run the export-diff pipeline and return.""" + from core.export import Exporter + + exporter = Exporter( + settings=settings, + handle=handle, + export_dir=args.export_diff_dir, + force_overwrite=args.force_export_overwrite, + vendor_slugs=args.vendors if args.vendors else None, + ) + with get_progress_panel(args.show_remaining_time) as progress: + if progress is not None: + handle.set_console(progress.console) + exporter.run(progress=progress) + + +def _build_argument_parser() -> ArgumentParser: + """Build and return the CLI argument parser.""" parser = ArgumentParser(description="Import Netbox Device Types", allow_abbrev=False) parser.add_argument( "--vendors", @@ -790,6 +954,16 @@ def main(): help="Remove components from NetBox that no longer exist in YAML (use with --update). " "WARNING: May affect existing device instances.", ) + parser.add_argument( + "--remove-unmanaged-types", + action="store_true", + default=False, + help=( + "Also remove components whose entire YAML section is missing (e.g. NetBox has interfaces " + "but the YAML defines no 'interfaces:' key at all). Requires --remove-components. " + "WARNING: Aggressive; will delete components on every type whose YAML omits that section." + ), + ) parser.add_argument( "--force-resolve-conflicts", action="store_true", @@ -800,22 +974,203 @@ def main(): "Only applied when no live device references the type. WARNING: Destructive." ), ) + parser.add_argument( + "--verify-images", + action="store_true", + default=False, + help=( + "Verify that images recorded in the NetBox database are physically present on the server. " + "Uses an HTTP presence check per image and a local SHA-256 cache to detect local file " + "changes (does not hash or download the remote file). Re-uploads any image that is " + "missing on the server or whose local file has changed since the last upload. " + "Useful after recreating a devcontainer (media files gone but DB intact) or " + "when local image files have been updated. NOTE: Makes an HTTP request per image — " + "avoid using this in bulk runs unless necessary." + ), + ) + parser.add_argument( + "--export-diff", + action="store_true", + default=False, + help=( + "Export device/module/rack types from NetBox that are absent from or differ vs. " + "the local repo/ directory. Writes DTL-compatible YAML files and images to the " + "export directory. Does not run the import pipeline." + ), + ) + parser.add_argument( + "--export-diff-dir", + default="extra/", + metavar="PATH", + help="Directory to write exported files to (default: extra/).", + ) + parser.add_argument( + "--force-export-overwrite", + action="store_true", + default=False, + help=( + "Overwrite files in the export directory that differ from what would be " + "generated from NetBox. Without this flag, changed files are skipped with a warning." + ), + ) + return parser + + +def _parse_vendor_types(dtl_repo, netbox, args, vendor, devices_path, modules_path, racks_path, slug_resolved): + """Parse device-type, module-type, and rack-type YAML files for a single vendor. + + Returns a 3-tuple: (parsed_device_types, parsed_module_types, parsed_rack_types). + """ + if slug_resolved is not None: + device_files = slug_resolved["device_files"].get(vendor["slug"], []) + parsed_device_types = dtl_repo.parse_files(device_files) if device_files else [] + else: + device_files, _ = dtl_repo.get_devices(devices_path, [vendor["name"].casefold()]) + parsed_device_types = dtl_repo.parse_files(device_files, slugs=args.slugs or []) + + if netbox.modules: + module_hint = slug_resolved["module_vendors"] if slug_resolved is not None else None + if module_hint is not None and vendor["slug"] not in module_hint: + parsed_module_types = [] + else: + module_files, _ = dtl_repo.get_devices(modules_path, [vendor["name"].casefold()]) + parsed_module_types = dtl_repo.parse_files(module_files, slugs=args.slugs or []) + else: + parsed_module_types = [] + + if netbox.rack_types: + rack_hint = slug_resolved["rack_vendors"] if slug_resolved is not None else None + if rack_hint is not None and vendor["slug"] not in rack_hint: + parsed_rack_types = [] + else: + parsed_rack_types = _parse_vendor_racks(dtl_repo, racks_path, vendor["name"], args.slugs or []) + else: + parsed_rack_types = [] + + return parsed_device_types, parsed_module_types, parsed_rack_types + + +def _run_vendor_loop( + dtl_repo, + netbox, + args, + handle, + vendors_to_process, + devices_path, + modules_path, + racks_path, + slug_resolved, + progress, + task_registry, + vendor_task_id, +) -> None: + """Process all selected vendors and finalize preload/progress teardown.""" + cache_preload_job = None + try: + for vendor in vendors_to_process: + parsed_device_types, parsed_module_types, parsed_rack_types = _parse_vendor_types( + dtl_repo, netbox, args, vendor, devices_path, modules_path, racks_path, slug_resolved + ) + + if not parsed_device_types and not parsed_module_types and not parsed_rack_types: + if vendor_task_id is not None: + progress.advance(vendor_task_id) + continue + + netbox.load_vendor(vendor["slug"]) + cache_preload_job = None + + if (parsed_device_types or parsed_module_types) and not args.only_new: + cache_preload_job = netbox.device_types.start_component_preload( + manufacturer_slug=vendor["slug"], + progress=progress, + task_registry=task_registry, + ) + + def _pump(): + if cache_preload_job and progress is not None: + netbox.device_types.pump_preload_progress(cache_preload_job, progress) + + handle.verbose_log(f"{len(parsed_device_types)} Device-Types Found") + _pump() + netbox.create_manufacturers([vendor]) + _pump() + + cache_preload_job = _process_device_types( + args, + netbox, + handle, + progress, + parsed_device_types, + cache_preload_job, + vendor_slug=vendor["slug"], + task_registry=task_registry, + ) + _pump() + + if netbox.modules: + _process_module_types( + args, + netbox, + handle, + progress, + parsed_module_types, + task_registry=task_registry, + ) + _pump() + + _process_rack_types( + args, + netbox, + handle, + progress, + parsed_rack_types, + task_registry=task_registry, + ) + _pump() + + if cache_preload_job: + netbox.device_types.stop_component_preload(cache_preload_job, progress=progress) + cache_preload_job = None + + if vendor_task_id is not None: + progress.advance(vendor_task_id) + finally: + if cache_preload_job: + netbox.device_types.stop_component_preload(cache_preload_job, progress=progress) + _finalize_task_registry(progress, task_registry) + handle.set_console(None) + + +def main(): + """Orchestrate importing device- and module-types from a Git repository into NetBox. + + Parses CLI arguments, validates environment variables, clones/pulls the DTL repo, + parses YAML files, and creates manufacturers, device types, and module types in NetBox. + Reports progress and summary counters. + """ + startTime = datetime.now() + + parser = _build_argument_parser() args = parser.parse_args() - if args.remove_components and not args.update: - parser.error("--remove-components requires --update") - if args.force_resolve_conflicts and not args.update: - parser.error("--force-resolve-conflicts requires --update") + _validate_argument_combinations(parser, args) # Normalize arguments - args.vendors = [v.casefold() for vendor in args.vendors for v in vendor.split(",") if v.strip()] - args.slugs = [s for slug in args.slugs for s in slug.split(",") if s.strip()] + args.vendors = [ + re.sub(r"\W+", "-", v.strip().casefold()) for vendor in args.vendors for v in vendor.split(",") if v.strip() + ] + args.slugs = [s.strip() for slug in args.slugs for s in slug.split(",") if s.strip()] handle = LogHandler(args) _check_env_vars(handle) + if args.export_diff: + _run_export_diff(settings, handle, args) + return + dtl_repo = DTLRepo(args, settings.REPO_PATH, handle) # Instantiate NetBox with all required dependencies @@ -823,89 +1178,63 @@ def main(): # For now, we will update NetBox to verify compatibility with this new setup netbox = NetBox(settings, handle) # handle passed explicitly netbox.force_resolve_conflicts = args.force_resolve_conflicts + netbox.remove_unmanaged_types = args.remove_unmanaged_types + netbox.verify_images = args.verify_images # Confirm effective run behavior right after compatibility checks. log_run_mode(handle, args) _log_import_filters(handle, args) - files, discovered_vendors = dtl_repo.get_devices(dtl_repo.get_devices_path(), args.vendors) - cache_preload_job = None - _module_parse_executor = None - _module_parse_future = None - - with get_progress_panel(args.show_remaining_time) as progress: - if progress is not None: - handle.set_console(progress.console) - try: - parse_fn = None - - def on_parse_step(): - """Invoke *parse_fn* (if set) after each parsed file, used to pump preload progress.""" - if parse_fn is not None: - parse_fn() - - parse_progress = get_progress_wrapper(progress, files, desc="Parsing Device Types", on_step=on_parse_step) - - if not args.only_new: - cache_preload_job = netbox.device_types.start_component_preload( - progress=progress, - ) - if progress is not None: - - def pump_preload(): - """Drain pending preload-progress updates from the background preload job.""" - netbox.device_types.pump_preload_progress(cache_preload_job, progress) + devices_path = dtl_repo.get_devices_path() + modules_path = dtl_repo.get_modules_path() + racks_path = dtl_repo.get_racks_path() - parse_fn = pump_preload + # Discover all vendors present in the repo across all three type directories. + all_vendors = dtl_repo.discover_vendors(devices_path, modules_path, racks_path) - device_types = dtl_repo.parse_files( - files, - slugs=args.slugs, - progress=parse_progress, + # Filter to the requested vendors when --vendor args are provided. + if args.vendors: + vendor_slug_filter = {v.lower() for v in args.vendors} + vendors_to_process = [v for v in all_vendors if v["slug"] in vendor_slug_filter] + if not vendors_to_process: + handle.log( + f"No vendors matched --vendors: {', '.join(args.vendors)}. " + f"Available: {', '.join(v['slug'] for v in all_vendors[:10])}" + f"{'...' if len(all_vendors) > 10 else ''}" ) - on_parse_step() - vendors, selected_vendor_slugs = filter_vendors_for_parsed_types(discovered_vendors, device_types) - - handle.verbose_log(f"{len(vendors)} Vendors Found") - handle.verbose_log(f"{len(device_types)} Device-Types Found") - - # Start module type file discovery and YAML parsing in a background thread - # so it overlaps with device type processing (which can take minutes). - if netbox.modules: - _module_vendor_filter = args.vendors - if args.slugs and not args.vendors: - _module_vendor_filter = sorted(selected_vendor_slugs) - _module_parse_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) - _module_parse_future = _module_parse_executor.submit( - _bg_parse_module_types, dtl_repo, _module_vendor_filter, args.slugs - ) + raise SystemExit(1) + else: + vendors_to_process = all_vendors - netbox.create_manufacturers(vendors) + vendors_to_process, slug_resolved = _apply_slug_fast_path(dtl_repo, args, vendors_to_process, handle) - cache_preload_job = _process_device_types(args, netbox, handle, progress, device_types, cache_preload_job) + if args.vendors and not vendors_to_process: + handle.log(f"No vendors matched the combination of --vendors and --slugs: {', '.join(args.vendors)}") + raise SystemExit(1) - if netbox.modules: - _process_module_types( - args, - netbox, - dtl_repo, - handle, - progress, - selected_vendor_slugs, - module_parse_future=_module_parse_future, - module_parse_executor=_module_parse_executor, - ) - _module_parse_future = None - _module_parse_executor = None - _process_rack_types(args, netbox, dtl_repo, handle, progress, selected_vendor_slugs) - finally: - if cache_preload_job: - netbox.device_types.stop_component_preload(cache_preload_job) - if _module_parse_future is not None and not _module_parse_future.done(): - _module_parse_future.cancel() - if _module_parse_executor is not None: - _module_parse_executor.shutdown(wait=False, cancel_futures=True) - handle.set_console(None) + with get_progress_panel(args.show_remaining_time) as progress: + if progress is not None: + handle.set_console(progress.console) + # Shared task registry for cumulative progress bars across all vendors. + task_registry = {} if progress is not None else None + vendor_task_id = None + if progress is not None and vendors_to_process: + _vdesc = "Vendors".ljust(_PROGRESS_DESC_WIDTH) + vendor_task_id = progress.add_task(_vdesc, total=len(vendors_to_process)) + _run_vendor_loop( + dtl_repo=dtl_repo, + netbox=netbox, + args=args, + handle=handle, + vendors_to_process=vendors_to_process, + devices_path=devices_path, + modules_path=modules_path, + racks_path=racks_path, + slug_resolved=slug_resolved, + progress=progress, + task_registry=task_registry, + vendor_task_id=vendor_task_id, + ) _log_run_summary(handle, netbox, startTime, dtl_repo=dtl_repo) @@ -932,3 +1261,9 @@ def pump_preload(): file=sys.stderr, ) raise SystemExit(1) + except requests.exceptions.ConnectionError as exc: + print( + f"[{datetime.now().strftime('%H:%M:%S')}] Error: {_fmt_connection_error(settings.NETBOX_URL, exc)}", + file=sys.stderr, + ) + raise SystemExit(1) diff --git a/tests/conftest.py b/tests/conftest.py index 5ceb82a1b..587e25ccc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,16 @@ from unittest.mock import MagicMock, patch +@pytest.fixture(autouse=True) +def reset_graphql_clamping_warned(mock_env_vars): + """Reset the module-level page-size clamping warning dedup set before each test.""" + import core.graphql_client as _gql + + _gql._CLAMPING_WARNED.clear() + yield + _gql._CLAMPING_WARNED.clear() + + @pytest.fixture(autouse=True) def mock_env_vars(): """Set mandatory environment variables to prevent settings.py from exiting.""" diff --git a/tests/test_change_detector.py b/tests/test_change_detector.py index 1e8553b9d..3535119c6 100644 --- a/tests/test_change_detector.py +++ b/tests/test_change_detector.py @@ -133,6 +133,17 @@ def test_no_removal_when_key_absent(self): changes = detector._compare_components(yaml_data, 1) assert not any(c.component_name == "eth99" for c in changes) + def test_removal_when_key_absent_with_remove_unmanaged_types(self): + """remove_unmanaged_types=True flags removals even when the YAML omits the section entirely.""" + existing_comp = MagicMock() + existing_comp.name = "eth99" + dt_instance = MagicMock() + dt_instance.cached_components = {"interface_templates": {("device", 1): {"eth99": existing_comp}}} + detector = ChangeDetector(dt_instance, MagicMock(), remove_unmanaged_types=True) + yaml_data = {} # 'interfaces' key absent + changes = detector._compare_components(yaml_data, 1) + assert any(c.component_name == "eth99" and c.change_type == ChangeType.COMPONENT_REMOVED for c in changes) + def test_component_changed_when_property_differs(self): existing_comp = MagicMock() existing_comp.type = "virtual" @@ -253,11 +264,23 @@ def _make_detector(self, verbose=False): handle.args = SimpleNamespace(verbose=verbose) return ChangeDetector(dt_instance, handle) - def test_empty_report_logs_zeros(self): + def test_empty_report_logs_nothing(self): + """An all-zero report (no new, no modified, no unchanged) should be silent.""" detector = self._make_detector() report = ChangeReport(new_device_types=[], modified_device_types=[], unchanged_count=0) detector.log_change_report(report) - detector.handle.log.assert_any_call("New device types: 0") + detector.handle.log.assert_not_called() + detector.handle.verbose_log.assert_not_called() + + def test_unchanged_only_report_logs_verbose(self): + """When only unchanged types exist, a brief verbose_log summary is emitted.""" + detector = self._make_detector() + report = ChangeReport(new_device_types=[], modified_device_types=[], unchanged_count=5) + detector.log_change_report(report) + detector.handle.log.assert_not_called() + detector.handle.verbose_log.assert_called_once() + msg = detector.handle.verbose_log.call_args[0][0] + assert "5 unchanged" in msg def test_modified_with_removals_always_logged(self): detector = self._make_detector(verbose=False) @@ -489,6 +512,29 @@ def test_multi_mapping_added_detected(self): ) assert any(c.property_name == "_mappings" for c in changes) + def test_missing_canonical_mappings_are_treated_as_unmanaged(self): + yaml_comp = { + "name": "FP1", + "_mappings": [{"rear_port": "RP1", "front_port_position": 1, "rear_port_position": 1}], + } + netbox_comp = self._make_netbox_comp(None) + + changes = self._cd()._compare_component_properties( + yaml_comp, netbox_comp, ["_mappings"], comp_type="front-ports" + ) + + assert changes == [] + + def test_missing_component_attribute_is_skipped(self): + yaml_comp = {"name": "eth0", "type": "1000base-t"} + netbox_comp = SimpleNamespace() + + changes = self._cd()._compare_component_properties( + yaml_comp, netbox_comp, ["name", "type"], comp_type="interfaces" + ) + + assert changes == [] + def test_no_mappings_key_in_yaml_skips_comparison(self): """When _mappings is absent from YAML, no comparison is done (absent != removal).""" yaml_comp = {"name": "FP1", "type": "8p8c"} # no _mappings key @@ -589,6 +635,21 @@ def test_unexpected_exception_propagates(self): # --------------------------------------------------------------------------- +class TestLogChangeReportNoModified: + """Tests for log_change_report when there are no modified types.""" + + def test_logs_zero_modified_when_only_new_types_exist(self): + handle = MagicMock() + handle.args.verbose = False + detector = ChangeDetector(MagicMock(), handle) + report = ChangeReport(new_device_types=[DeviceTypeChange("cisco", "X", "x", is_new=True)]) + + detector.log_change_report(report) + + logged = [call.args[0] for call in handle.log.call_args_list] + assert "Modified device types: 0" in logged + + class TestCompareDeviceTypePropertiesMissingAttribute: """Tests for the _MISSING sentinel guard inside _compare_device_type_properties.""" diff --git a/tests/test_export_manifest.py b/tests/test_export_manifest.py new file mode 100644 index 000000000..8ffeb7965 --- /dev/null +++ b/tests/test_export_manifest.py @@ -0,0 +1,127 @@ +"""Tests for core/export_manifest.py.""" + +import json +from core.export_manifest import load_manifest, save_manifest, is_entry_fresh, update_entry + + +class TestLoadManifest: + """Tests for load_manifest function.""" + + def test_returns_empty_manifest_when_file_missing(self, tmp_path): + m = load_manifest(tmp_path / ".export-manifest.json") + assert m == {"device-types": {}, "module-types": {}, "rack-types": {}} + + def test_returns_empty_manifest_when_corrupt(self, tmp_path): + p = tmp_path / ".export-manifest.json" + p.write_text("not-json{{{") + m = load_manifest(p) + assert m == {"device-types": {}, "module-types": {}, "rack-types": {}} + + def test_loads_existing_manifest(self, tmp_path): + p = tmp_path / ".export-manifest.json" + data = { + "device-types": {"Nokia/acme-x": {"last_updated": "2024-01-01T00:00:00Z"}}, + "module-types": {}, + "rack-types": {}, + } + p.write_text(json.dumps(data)) + m = load_manifest(p) + assert m["device-types"]["Nokia/acme-x"]["last_updated"] == "2024-01-01T00:00:00Z" + + def test_returns_empty_when_json_is_list(self, tmp_path): + """JSON array (not dict) must return the empty manifest, not raise AttributeError.""" + p = tmp_path / ".export-manifest.json" + p.write_text(json.dumps([{"device-types": {}}])) + m = load_manifest(p) + assert m == {"device-types": {}, "module-types": {}, "rack-types": {}} + + def test_returns_empty_when_json_is_string(self, tmp_path): + p = tmp_path / ".export-manifest.json" + p.write_text('"just a string"') + m = load_manifest(p) + assert m == {"device-types": {}, "module-types": {}, "rack-types": {}} + + def test_returns_empty_when_section_is_not_dict(self, tmp_path): + """A section that is not a dict (e.g. a list) must be reset to {}.""" + p = tmp_path / ".export-manifest.json" + p.write_text(json.dumps({"device-types": ["bad"], "module-types": {}, "rack-types": {}})) + m = load_manifest(p) + assert m["device-types"] == {} + assert m["module-types"] == {} + + def test_fills_missing_sections(self, tmp_path): + """Manifest missing one section must get that section initialised to {}.""" + p = tmp_path / ".export-manifest.json" + p.write_text(json.dumps({"device-types": {"Nokia/x": {"last_updated": "ts"}}})) + m = load_manifest(p) + assert m["module-types"] == {} + assert m["rack-types"] == {} + assert m["device-types"]["Nokia/x"]["last_updated"] == "ts" + + def test_non_utf8_file_returns_empty_manifest(self, tmp_path): + """UnicodeDecodeError (non-UTF-8 file) must be caught and return empty manifest.""" + p = tmp_path / ".export-manifest.json" + p.write_bytes(b"\xff\xfe{}") # BOM + non-UTF-8 bytes + m = load_manifest(p) + assert m == {"device-types": {}, "module-types": {}, "rack-types": {}} + + +class TestSaveManifest: + """Tests for save_manifest function.""" + + def test_saves_atomically(self, tmp_path): + p = tmp_path / ".export-manifest.json" + data = { + "device-types": {"Nokia/acme-x": {"last_updated": "2024-01-01T00:00:00Z"}}, + "module-types": {}, + "rack-types": {}, + } + save_manifest(p, data) + assert p.exists() + loaded = json.loads(p.read_text()) + assert loaded == data + + def test_overwrites_existing(self, tmp_path): + p = tmp_path / ".export-manifest.json" + p.write_text(json.dumps({"device-types": {"old": {}}, "module-types": {}, "rack-types": {}})) + new_data = {"device-types": {"new": {}}, "module-types": {}, "rack-types": {}} + save_manifest(p, new_data) + assert json.loads(p.read_text()) == new_data + + +class TestIsEntryFresh: + """Tests for is_entry_fresh function.""" + + def test_fresh_when_last_updated_matches(self): + manifest = { + "device-types": {"Nokia/acme-x": {"last_updated": "2024-01-01T00:00:00Z"}}, + "module-types": {}, + "rack-types": {}, + } + assert is_entry_fresh(manifest, "device-types", "Nokia/acme-x", "2024-01-01T00:00:00Z") is True + + def test_stale_when_last_updated_differs(self): + manifest = { + "device-types": {"Nokia/acme-x": {"last_updated": "2024-01-01T00:00:00Z"}}, + "module-types": {}, + "rack-types": {}, + } + assert is_entry_fresh(manifest, "device-types", "Nokia/acme-x", "2024-02-01T00:00:00Z") is False + + def test_stale_when_entry_missing(self): + manifest = {"device-types": {}, "module-types": {}, "rack-types": {}} + assert is_entry_fresh(manifest, "device-types", "Nokia/acme-x", "2024-01-01T00:00:00Z") is False + + +class TestUpdateEntry: + """Tests for update_entry function.""" + + def test_adds_new_entry(self): + manifest = {"device-types": {}, "module-types": {}, "rack-types": {}} + update_entry(manifest, "device-types", "Nokia/acme-x", "2024-01-01T00:00:00Z") + assert manifest["device-types"]["Nokia/acme-x"]["last_updated"] == "2024-01-01T00:00:00Z" + + def test_updates_existing_entry(self): + manifest = {"device-types": {"Nokia/acme-x": {"last_updated": "old"}}, "module-types": {}, "rack-types": {}} + update_entry(manifest, "device-types", "Nokia/acme-x", "2024-02-01T00:00:00Z") + assert manifest["device-types"]["Nokia/acme-x"]["last_updated"] == "2024-02-01T00:00:00Z" diff --git a/tests/test_exporter.py b/tests/test_exporter.py new file mode 100644 index 000000000..4d92450cc --- /dev/null +++ b/tests/test_exporter.py @@ -0,0 +1,1395 @@ +"""Tests for core/export.py — Exporter class.""" + +import pytest +from unittest.mock import MagicMock, patch +import yaml + +from core.export import ( + ExportItem, + Exporter, + _canon_mfr_slug, + _is_subset, + _make_filename, + _normalize_for_compare, + _repo_supersedes, + _sanitize_attachment_filename, + _yaml_equal, + _SKIP, +) + + +# ── Fixtures ────────────────────────────────────────────────────────────────── + + +def _make_settings(tmp_path): + s = MagicMock() + s.NETBOX_URL = "http://localhost:8000/" + s.NETBOX_TOKEN = "test-token" + s.IGNORE_SSL_ERRORS = False + s.REPO_PATH = str(tmp_path / "repo") + return s + + +def _make_handle(): + h = MagicMock() + h.log = MagicMock() + h.verbose = False + return h + + +def _make_mfr(name="Nokia", slug="nokia"): + m = MagicMock() + m.name = name + m.slug = slug + return m + + +def _make_dt( + id=1, + model="7750-SR-7s", + slug="nokia-7750-sr-7s", + last_updated="2024-01-01T00:00:00Z", + u_height=7, + is_full_depth=True, + part_number=None, + airflow=None, + weight=None, + weight_unit=None, + description="", + comments="", + subdevice_role=None, + front_image=None, + rear_image=None, +): + r = MagicMock() + r.id = id + r.model = model + r.slug = slug + r.last_updated = last_updated + r.u_height = u_height + r.is_full_depth = is_full_depth + r.part_number = part_number + r.airflow = airflow + r.weight = weight + r.weight_unit = weight_unit + r.description = description + r.comments = comments + r.subdevice_role = subdevice_role + r.front_image = front_image + r.rear_image = rear_image + r.manufacturer = _make_mfr() + return r + + +def _make_mt( + id=10, + model="SFP-10G", + last_updated="2024-01-01T00:00:00Z", + part_number=None, + airflow=None, + weight=None, + weight_unit=None, + description="", + comments="", +): + r = MagicMock() + r.id = id + r.model = model + r.last_updated = last_updated + r.part_number = part_number + r.airflow = airflow + r.weight = weight + r.weight_unit = weight_unit + r.description = description + r.comments = comments + r.manufacturer = _make_mfr() + return r + + +def _make_rt( + id=20, + model="Rack-42U", + slug="rack-42u", + last_updated="2024-01-01T00:00:00Z", + form_factor="4-post-cabinet", + description="", + width=19, + u_height=42, + starting_unit=1, + outer_width=None, + outer_height=None, + outer_depth=None, + outer_unit=None, + mounting_depth=None, + weight=None, + max_weight=None, + weight_unit=None, + desc_units=False, + comments="", +): + r = MagicMock() + r.id = id + r.model = model + r.slug = slug + r.last_updated = last_updated + r.form_factor = form_factor + r.description = description + r.width = width + r.u_height = u_height + r.starting_unit = starting_unit + r.outer_width = outer_width + r.outer_height = outer_height + r.outer_depth = outer_depth + r.outer_unit = outer_unit + r.mounting_depth = mounting_depth + r.weight = weight + r.max_weight = max_weight + r.weight_unit = weight_unit + r.desc_units = desc_units + r.comments = comments + r.manufacturer = _make_mfr() + return r + + +# ── Tests ───────────────────────────────────────────────────────────────────── + + +class TestMakeFilename: + """Tests for _make_filename sanitizer.""" + + def test_spaces_replaced_with_dashes(self): + assert _make_filename("10622 G2") == "10622-G2" + + def test_forward_slash_replaced(self): + assert _make_filename("PS-7220-IXR-D2/D3-AC-B2F") == "PS-7220-IXR-D2-D3-AC-B2F" + + def test_back_slash_replaced(self): + assert _make_filename("model\\sub") == "model-sub" + + def test_consecutive_dashes_collapsed(self): + assert _make_filename("A / B") == "A-B" + + def test_leading_trailing_dashes_stripped(self): + assert _make_filename("/leading-slash") == "leading-slash" + + def test_no_change_for_clean_model(self): + assert _make_filename("7750-SR-7s") == "7750-SR-7s" + + def test_case_preserved(self): + assert _make_filename("7220 IXR-D2L 25/100GE") == "7220-IXR-D2L-25-100GE" + + +class TestNormalizeForCompare: + """Tests for _normalize_for_compare function.""" + + def test_float_int_coercion(self): + result = _normalize_for_compare({"u_height": 1.0, "weight": 2.5}) + assert result["u_height"] == 1 + assert isinstance(result["u_height"], int) + assert result["weight"] == 2.5 + + def test_empty_string_becomes_none(self): + result = _normalize_for_compare({"description": ""}) + assert result["description"] is None + + def test_nested_list_normalized(self): + result = _normalize_for_compare({"interfaces": [{"positions": 1.0}]}) + assert result["interfaces"][0]["positions"] == 1 + assert isinstance(result["interfaces"][0]["positions"], int) + + def test_named_component_lists_sorted_for_comparison(self): + """Components with 'name' field compare equal regardless of order.""" + a = {"interfaces": [{"name": "c1"}, {"name": "c2"}, {"name": "c10"}]} + b = {"interfaces": [{"name": "c10"}, {"name": "c1"}, {"name": "c2"}]} + assert _normalize_for_compare(a) == _normalize_for_compare(b) + + def test_lists_without_name_field_not_reordered(self): + """Plain lists (e.g. of strings) keep their order.""" + a = _normalize_for_compare({"x": [3, 1, 2]}) + assert a["x"] == [3, 1, 2] + + +class TestYamlEqual: + """Tests for _yaml_equal helper.""" + + def test_yaml_equal_normalizes_component_order_and_numbers(self): + left = {"interfaces": [{"name": "b", "positions": 1.0}, {"name": "a", "positions": 2.0}]} + right = {"interfaces": [{"name": "a", "positions": 2}, {"name": "b", "positions": 1}]} + assert _yaml_equal(left, right) is True + + +class TestRepoSupersedes: + """Tests for _repo_supersedes / _is_subset (asymmetric containment).""" + + def test_equal_dicts(self): + repo = {"manufacturer": "Nokia", "model": "X", "u_height": 1} + nb = {"manufacturer": "Nokia", "model": "X", "u_height": 1} + assert _repo_supersedes(repo, nb) is True + + def test_repo_has_extra_field(self): + """Repo carries 'profile' that NB lacks → repo is superset → suppress.""" + repo = {"manufacturer": "Nokia", "model": "X", "profile": "psu"} + nb = {"manufacturer": "Nokia", "model": "X"} + assert _repo_supersedes(repo, nb) is True + + def test_nb_has_extra_field(self): + """NB carries 'comments' that repo lacks → NOT a superset → export.""" + repo = {"manufacturer": "Nokia", "model": "X"} + nb = {"manufacturer": "Nokia", "model": "X", "comments": "datasheet"} + assert _repo_supersedes(repo, nb) is False + + def test_value_differs(self): + repo = {"manufacturer": "Nokia", "model": "7750-SR-7s"} + nb = {"manufacturer": "Nokia", "model": "7750 SR-7s"} + assert _repo_supersedes(repo, nb) is False + + def test_named_component_match_by_name(self): + """Each NB component must be in repo (by name) with equal fields.""" + repo = { + "interfaces": [ + {"name": "eth0", "type": "1000base-t"}, + {"name": "eth1", "type": "1000base-t"}, + ] + } + nb = {"interfaces": [{"name": "eth0", "type": "1000base-t"}]} + assert _repo_supersedes(repo, nb) is True + + def test_repo_component_has_extra_field(self): + """Per-component extras in repo (e.g. 'label') don't trigger export.""" + repo = {"interfaces": [{"name": "eth0", "type": "1000base-t", "label": "WAN"}]} + nb = {"interfaces": [{"name": "eth0", "type": "1000base-t"}]} + assert _repo_supersedes(repo, nb) is True + + def test_nb_component_value_differs(self): + repo = {"interfaces": [{"name": "eth0", "type": "1000base-t"}]} + nb = {"interfaces": [{"name": "eth0", "type": "10gbase-t"}]} + assert _repo_supersedes(repo, nb) is False + + def test_nb_has_extra_component(self): + repo = {"interfaces": [{"name": "eth0", "type": "1000base-t"}]} + nb = { + "interfaces": [ + {"name": "eth0", "type": "1000base-t"}, + {"name": "eth1", "type": "1000base-t"}, + ] + } + assert _repo_supersedes(repo, nb) is False + + def test_is_subset_with_floats(self): + """Numeric normalization applies (33 vs 33.0).""" + assert _is_subset(_normalize_for_compare({"weight": 33}), _normalize_for_compare({"weight": 33.0})) is True + + def test_is_subset_returns_false_when_sub_list_sup_is_not_list(self): + assert _is_subset([1, 2], {"not": "a-list"}) is False + + def test_is_subset_returns_false_when_sub_is_dict_sup_is_not_dict(self): + """Branch: sub is dict but sup is a scalar or list.""" + assert _is_subset({"key": "val"}, "not-a-dict") is False + assert _is_subset({"key": "val"}, [1, 2]) is False + + def test_is_subset_positional_lists_require_exact_equality(self): + assert _is_subset([1, 2], [1, 2]) is True + + # ── manufacturer normalization ────────────────────────────────────────── + + def test_repo_dict_mfr_supersedes_nb_slug_string(self): + """Repo with dict-form manufacturer must compare equal to NB plain slug. + + Before the fix, {name: Nokia, slug: nokia} vs "nokia" was never a + subset because _is_subset(str, dict) falls through to str == dict. + """ + repo = {"manufacturer": {"name": "Nokia", "slug": "nokia"}, "model": "X"} + nb = {"manufacturer": "nokia", "model": "X"} + assert _repo_supersedes(repo, nb) is True + + def test_repo_capitalised_slug_dict_normalised(self): + """Capitalized slug value in repo YAML must still match NB string.""" + repo = {"manufacturer": {"slug": "Nokia"}, "model": "X"} + nb = {"manufacturer": "nokia", "model": "X"} + assert _repo_supersedes(repo, nb) is True + + def test_repo_name_only_dict_mfr_normalised(self): + """Dict with only 'name' key is slugified and matches NB slug string.""" + repo = {"manufacturer": {"name": "Cisco Systems"}, "model": "X"} + nb = {"manufacturer": "cisco-systems", "model": "X"} + assert _repo_supersedes(repo, nb) is True + + +class TestCanonMfrSlug: + """Unit tests for the _canon_mfr_slug helper.""" + + def test_plain_string_lowercased(self): + assert _canon_mfr_slug("Nokia") == "nokia" + + def test_slug_key_normalised(self): + assert _canon_mfr_slug({"slug": "Nokia"}) == "nokia" + + def test_name_key_slugified(self): + assert _canon_mfr_slug({"name": "Cisco Systems"}) == "cisco-systems" + + def test_slug_preferred_over_name(self): + assert _canon_mfr_slug({"slug": "cisco", "name": "Cisco Systems"}) == "cisco" + + def test_unknown_type_returns_empty(self): + assert _canon_mfr_slug(42) == "" + assert _canon_mfr_slug(None) == "" + + def test_empty_dict_returns_empty(self): + assert _canon_mfr_slug({}) == "" + + +class TestExporterDirWritable: + """Tests for export directory writability checks.""" + + def test_raises_when_dir_not_writable(self, tmp_path, mocker): + settings = _make_settings(tmp_path) + export_dir = tmp_path / "export" + export_dir.mkdir() + mocker.patch("os.access", return_value=False) + exporter = Exporter(settings, _make_handle(), str(export_dir), False, None) + with pytest.raises(PermissionError, match="not writable"): + exporter._verify_export_dir_writable() + + def test_creates_dir_if_missing(self, tmp_path): + settings = _make_settings(tmp_path) + export_dir = tmp_path / "export" / "new" + exporter = Exporter(settings, _make_handle(), str(export_dir), False, None) + exporter._verify_export_dir_writable() + assert export_dir.exists() + + +class TestDetermineExportSet: + """Test the three export triggers.""" + + def _setup_exporter(self, tmp_path): + settings = _make_settings(tmp_path) + (tmp_path / "repo").mkdir(parents=True) + return Exporter(settings, _make_handle(), str(tmp_path / "extra"), False, None) + + def test_absent_from_repo_triggers_export(self, tmp_path): + exporter = self._setup_exporter(tmp_path) + dt = _make_dt() + items = exporter._determine_export_set_for_device_types( + nb_records=[dt], + repo_dt_by_slug={}, + components_by_dt_id={}, + ) + assert len(items) == 1 + assert items[0].reason == "absent" + + def test_matching_yaml_not_exported(self, tmp_path): + exporter = self._setup_exporter(tmp_path) + dt = _make_dt() + from core.nb_serializer import serialize_device_type + + repo_yaml = serialize_device_type(dt, {}) + items = exporter._determine_export_set_for_device_types( + nb_records=[dt], + repo_dt_by_slug={("nokia", dt.slug): repo_yaml}, + components_by_dt_id={}, + ) + assert len(items) == 0 + + def test_differs_from_repo_triggers_export(self, tmp_path): + exporter = self._setup_exporter(tmp_path) + dt = _make_dt(u_height=7) + repo_yaml = { + "manufacturer": {"slug": "nokia"}, + "model": dt.model, + "slug": dt.slug, + "u_height": 9, + "is_full_depth": True, + } + items = exporter._determine_export_set_for_device_types( + nb_records=[dt], + repo_dt_by_slug={("nokia", dt.slug): repo_yaml}, + components_by_dt_id={}, + ) + assert len(items) == 1 + assert items[0].reason == "differs" + + def test_images_missing_locally_triggers_export(self, tmp_path): + exporter = self._setup_exporter(tmp_path) + dt = _make_dt(front_image="/media/devicetype-images/nokia-7750-sr-7s.front.png") + from core.nb_serializer import serialize_device_type + + repo_yaml = serialize_device_type(dt, {}) + # Do NOT create the image file — it's missing + items = exporter._determine_export_set_for_device_types( + nb_records=[dt], + repo_dt_by_slug={("nokia", dt.slug): repo_yaml}, + components_by_dt_id={}, + ) + assert len(items) == 1 + assert items[0].reason == "images-missing" + + def test_images_present_not_exported(self, tmp_path): + exporter = self._setup_exporter(tmp_path) + dt = _make_dt(front_image="/media/devicetype-images/nokia-7750-sr-7s.front.png") + from core.nb_serializer import serialize_device_type + + repo_yaml = serialize_device_type(dt, {}) + # Create the image file + img_dir = tmp_path / "repo" / "elevation-images" / "Nokia" + img_dir.mkdir(parents=True) + (img_dir / "nokia-7750-sr-7s.front.png").write_bytes(b"PNG") + items = exporter._determine_export_set_for_device_types( + nb_records=[dt], + repo_dt_by_slug={("nokia", dt.slug): repo_yaml}, + components_by_dt_id={}, + ) + assert len(items) == 0 + + def test_slug_collision_across_manufacturers_resolved_correctly(self, tmp_path): + """Two DTs with the same slug but different manufacturers must each match their own repo YAML.""" + exporter = self._setup_exporter(tmp_path) + + shared_slug = "shared-model-x1" + + dt_nokia = _make_dt(id=1, model="Model-X1", slug=shared_slug) + dt_nokia.manufacturer = _make_mfr(name="Nokia", slug="nokia") + + dt_acme = _make_dt(id=2, model="Model-X1", slug=shared_slug) + dt_acme.manufacturer = _make_mfr(name="Acme", slug="acme") + + from core.nb_serializer import serialize_device_type + + nokia_repo_yaml = serialize_device_type(dt_nokia, {}) + # Acme YAML differs (u_height 99) so it should trigger "differs" + acme_repo_yaml = { + "manufacturer": {"slug": "acme"}, + "model": dt_acme.model, + "slug": dt_acme.slug, + "u_height": 99, + "is_full_depth": True, + } + + repo_dt_by_slug = { + ("nokia", shared_slug): nokia_repo_yaml, + ("acme", shared_slug): acme_repo_yaml, + } + + items = exporter._determine_export_set_for_device_types( + nb_records=[dt_nokia, dt_acme], + repo_dt_by_slug=repo_dt_by_slug, + components_by_dt_id={}, + ) + + # Nokia matches exactly → no export; Acme differs → export with reason "differs" + assert len(items) == 1 + assert items[0].reason == "differs" + assert items[0].nb_record.id == dt_acme.id + + +class TestWriteYaml: + """Tests for YAML file writing with overwrite guards.""" + + def test_writes_new_file(self, tmp_path): + settings = _make_settings(tmp_path) + exporter = Exporter(settings, _make_handle(), str(tmp_path / "extra"), False, None) + dest = tmp_path / "extra" / "device-types" / "Nokia" / "test.yaml" + exporter._write_yaml(dest, {"model": "Test", "u_height": 1}) + assert dest.exists() + loaded = yaml.safe_load(dest.read_text()) + assert loaded["model"] == "Test" + + def test_overwrite_guard_blocks_changed_file(self, tmp_path): + settings = _make_settings(tmp_path) + exporter = Exporter(settings, _make_handle(), str(tmp_path / "extra"), force_overwrite=False, vendor_slugs=None) + dest = tmp_path / "extra" / "device-types" / "Nokia" / "test.yaml" + dest.parent.mkdir(parents=True) + dest.write_text(yaml.dump({"model": "Old"})) + result = exporter._write_yaml(dest, {"model": "New"}) + assert result is False # blocked + assert yaml.safe_load(dest.read_text())["model"] == "Old" + + def test_force_overwrite_allows_changed_file(self, tmp_path): + settings = _make_settings(tmp_path) + exporter = Exporter(settings, _make_handle(), str(tmp_path / "extra"), force_overwrite=True, vendor_slugs=None) + dest = tmp_path / "extra" / "device-types" / "Nokia" / "test.yaml" + dest.parent.mkdir(parents=True) + dest.write_text(yaml.dump({"model": "Old"})) + result = exporter._write_yaml(dest, {"model": "New"}) + assert result is True + assert yaml.safe_load(dest.read_text())["model"] == "New" + + +class TestManifestConsistency: + """Tests for manifest update consistency during image downloads.""" + + def test_manifest_not_updated_when_image_download_fails(self, tmp_path): + """When images fail to download, manifest entry should NOT be updated.""" + from core.export_manifest import load_manifest + + settings = _make_settings(tmp_path) + (tmp_path / "repo").mkdir(parents=True) + export_dir = tmp_path / "extra" + exporter = Exporter(settings, _make_handle(), str(export_dir), False, None) + # Patch _download_image to simulate failure (returns None) + exporter._download_image = MagicMock(return_value=None) + # Patch graphql to return a single device type + dt = _make_dt(front_image="/media/img/nokia-7750-sr-7s.front.png") + exporter.graphql.get_device_types = MagicMock(return_value=({("nokia", dt.model): dt}, {dt.slug: dt})) + exporter.graphql.get_module_types = MagicMock(return_value={}) + exporter.graphql.get_rack_types = MagicMock(return_value={}) + exporter.graphql.get_component_templates = MagicMock(return_value=[]) + exporter.run() + # Manifest should NOT have an entry for this item (images failed) + manifest = load_manifest(export_dir / ".export-manifest.json") + assert "Nokia/nokia-7750-sr-7s" not in manifest.get("device-types", {}) + + def test_manifest_updated_when_first_image_skipped_second_succeeds(self, tmp_path): + """First image returns _SKIP (already exists), second returns hash → ok=True → manifest updated.""" + from core.export_manifest import load_manifest + + settings = _make_settings(tmp_path) + (tmp_path / "repo").mkdir(parents=True) + export_dir = tmp_path / "extra" + exporter = Exporter(settings, _make_handle(), str(export_dir), False, None) + # First call: _SKIP (already exists); second call: hash string (success) + exporter._download_image = MagicMock(side_effect=[_SKIP, "abc123hash"]) + dt = _make_dt( + front_image="/media/img/nokia-7750-sr-7s.front.png", + rear_image="/media/img/nokia-7750-sr-7s.rear.png", + ) + exporter.graphql.get_device_types = MagicMock(return_value=({("nokia", dt.model): dt}, {dt.slug: dt})) + exporter.graphql.get_module_types = MagicMock(return_value={}) + exporter.graphql.get_rack_types = MagicMock(return_value={}) + exporter.graphql.get_component_templates = MagicMock(return_value=[]) + exporter.run() + manifest = load_manifest(export_dir / ".export-manifest.json") + assert "Nokia/nokia-7750-sr-7s" in manifest.get("device-types", {}) + + def test_manifest_not_updated_when_first_succeeds_second_fails(self, tmp_path): + """First image returns hash (success), second returns None (failure) → ok=False → manifest NOT updated.""" + from core.export_manifest import load_manifest + + settings = _make_settings(tmp_path) + (tmp_path / "repo").mkdir(parents=True) + export_dir = tmp_path / "extra" + exporter = Exporter(settings, _make_handle(), str(export_dir), False, None) + # First call: hash (success); second call: None (failure) + exporter._download_image = MagicMock(side_effect=["abc123hash", None]) + dt = _make_dt( + front_image="/media/img/nokia-7750-sr-7s.front.png", + rear_image="/media/img/nokia-7750-sr-7s.rear.png", + ) + exporter.graphql.get_device_types = MagicMock(return_value=({("nokia", dt.model): dt}, {dt.slug: dt})) + exporter.graphql.get_module_types = MagicMock(return_value={}) + exporter.graphql.get_rack_types = MagicMock(return_value={}) + exporter.graphql.get_component_templates = MagicMock(return_value=[]) + exporter.run() + manifest = load_manifest(export_dir / ".export-manifest.json") + assert "Nokia/nokia-7750-sr-7s" not in manifest.get("device-types", {}) + + +class TestWriteYamlEdgeCases: + """Edge cases for _write_yaml with corrupted or identical files.""" + + def test_corrupted_existing_file_force_overwrites(self, tmp_path): + settings = _make_settings(tmp_path) + exporter = Exporter(settings, _make_handle(), str(tmp_path / "extra"), force_overwrite=True, vendor_slugs=None) + dest = tmp_path / "extra" / "Nokia" / "test.yaml" + dest.parent.mkdir(parents=True) + dest.write_bytes(b"\xff\xfe invalid utf-8") # corrupted file + result = exporter._write_yaml(dest, {"model": "Test"}) + assert result is True + assert yaml.safe_load(dest.read_text())["model"] == "Test" + + def test_corrupted_existing_file_no_force_blocks(self, tmp_path): + settings = _make_settings(tmp_path) + exporter = Exporter(settings, _make_handle(), str(tmp_path / "extra"), force_overwrite=False, vendor_slugs=None) + dest = tmp_path / "extra" / "Nokia" / "test.yaml" + dest.parent.mkdir(parents=True) + dest.write_bytes(b"\xff\xfe invalid utf-8") # corrupted file + result = exporter._write_yaml(dest, {"model": "Test"}) + assert result is False # blocked — treat corrupted file as different content + + +class TestModuleTypeExport: + """Tests for module type export set determination.""" + + def test_absent_module_type_triggers_export(self, tmp_path): + settings = _make_settings(tmp_path) + (tmp_path / "repo").mkdir(parents=True) + exporter = Exporter(settings, _make_handle(), str(tmp_path / "extra"), False, None) + mt = MagicMock() + mt.id = 10 + mt.model = "SFP-10G" + mt.last_updated = "2024-01-01T00:00:00Z" + mt.part_number = None + mt.airflow = None + mt.weight = None + mt.weight_unit = None + mt.description = "" + mt.comments = "" + mt.manufacturer = _make_mfr() + items = exporter._determine_export_set_for_module_types( + nb_records=[mt], + repo_mt_by_key={}, + components_by_mt_id={}, + ) + assert len(items) == 1 + assert items[0].reason == "absent" + assert items[0].kind == "module-type" + + +class TestVendorDirSlugNormalization: + """Tests for Exporter._vendor_dirs slug-based directory matching.""" + + def _make_exporter(self, tmp_path, vendor_slugs): + settings = _make_settings(tmp_path) + return Exporter(settings, _make_handle(), str(tmp_path / "extra"), False, vendor_slugs) + + def test_single_word_dir_matches_slug(self, tmp_path): + root = tmp_path / "device-types" + (root / "Nokia").mkdir(parents=True) + exporter = self._make_exporter(tmp_path, ["nokia"]) + dirs = list(exporter._vendor_dirs(root)) + assert len(dirs) == 1 + assert dirs[0].name == "Nokia" + + def test_multi_word_dir_matches_hyphenated_slug(self, tmp_path): + """'Extreme Networks' dir must match slug 'extreme-networks'.""" + root = tmp_path / "device-types" + (root / "Extreme Networks").mkdir(parents=True) + exporter = self._make_exporter(tmp_path, ["extreme-networks"]) + dirs = list(exporter._vendor_dirs(root)) + assert len(dirs) == 1 + assert dirs[0].name == "Extreme Networks" + + def test_non_matching_vendor_excluded(self, tmp_path): + root = tmp_path / "device-types" + (root / "Nokia").mkdir(parents=True) + (root / "Juniper").mkdir(parents=True) + exporter = self._make_exporter(tmp_path, ["nokia"]) + dirs = list(exporter._vendor_dirs(root)) + assert len(dirs) == 1 + assert dirs[0].name == "Nokia" + + def test_no_filter_yields_all_dirs(self, tmp_path): + root = tmp_path / "device-types" + (root / "Nokia").mkdir(parents=True) + (root / "Extreme Networks").mkdir(parents=True) + exporter = self._make_exporter(tmp_path, None) + names = {d.name for d in exporter._vendor_dirs(root)} + assert names == {"Nokia", "Extreme Networks"} + + def test_nonexistent_root_yields_nothing(self, tmp_path): + root = tmp_path / "does-not-exist" + exporter = self._make_exporter(tmp_path, None) + assert list(exporter._vendor_dirs(root)) == [] + + +class TestExporterAdditionalCoverage: + """Additional coverage tests for Exporter internals.""" + + def _make_exporter(self, tmp_path, force_overwrite=False): + settings = _make_settings(tmp_path) + (tmp_path / "repo").mkdir(parents=True, exist_ok=True) + return Exporter(settings, _make_handle(), str(tmp_path / "extra"), force_overwrite, None) + + def test_get_module_image_details_is_cached(self, tmp_path): + exporter = self._make_exporter(tmp_path) + exporter.graphql.get_module_type_image_details = MagicMock(return_value={1: {}}) + + assert exporter._get_module_image_details() == {1: {}} + assert exporter._get_module_image_details() == {1: {}} + exporter.graphql.get_module_type_image_details.assert_called_once() + + def test_load_repo_device_types_skips_bad_yaml(self, tmp_path): + exporter = self._make_exporter(tmp_path) + vdir = tmp_path / "repo" / "device-types" / "Nokia" + vdir.mkdir(parents=True) + (vdir / "good.yaml").write_text("slug: good-slug\nmodel: Good\n", encoding="utf-8") + (vdir / "bad.yaml").write_text("foo: [\n", encoding="utf-8") + + assert exporter._load_repo_device_types() == {("nokia", "good-slug"): {"slug": "good-slug", "model": "Good"}} + + def test_load_repo_module_types_accepts_dict_and_string_manufacturers(self, tmp_path): + exporter = self._make_exporter(tmp_path) + vdir = tmp_path / "repo" / "module-types" / "Nokia" + vdir.mkdir(parents=True) + (vdir / "named.yaml").write_text("manufacturer:\n name: Nokia\nmodel: M1\n", encoding="utf-8") + (vdir / "slugged.yaml").write_text("manufacturer:\n slug: nokia\nmodel: M2\n", encoding="utf-8") + (vdir / "string.yaml").write_text("manufacturer: Nokia\nmodel: M3\n", encoding="utf-8") + (vdir / "bad.yaml").write_text("manufacturer: [\n", encoding="utf-8") + + result = exporter._load_repo_module_types() + + assert ("nokia", "M1") in result + assert ("nokia", "M2") in result + assert ("nokia", "M3") in result + + def test_load_repo_rack_types_accepts_dict_and_string_manufacturers(self, tmp_path): + exporter = self._make_exporter(tmp_path) + vdir = tmp_path / "repo" / "rack-types" / "Nokia" + vdir.mkdir(parents=True) + (vdir / "named.yaml").write_text("manufacturer:\n name: Nokia\nmodel: R1\n", encoding="utf-8") + (vdir / "slugged.yaml").write_text("manufacturer:\n slug: nokia\nmodel: R2\n", encoding="utf-8") + (vdir / "string.yaml").write_text("manufacturer: Nokia\nmodel: R3\n", encoding="utf-8") + (vdir / "bad.yaml").write_text("manufacturer: [\n", encoding="utf-8") + + result = exporter._load_repo_rack_types() + + assert ("nokia", "R1") in result + assert ("nokia", "R2") in result + assert ("nokia", "R3") in result + + def test_fetch_vendor_components_groups_device_and_module_records(self, tmp_path): + exporter = self._make_exporter(tmp_path) + dt_rec = MagicMock() + dt_rec.device_type = MagicMock(id=11) + dt_rec.module_type = None + mt_rec = MagicMock() + mt_rec.device_type = None + mt_rec.module_type = MagicMock(id=22) + + def _side_effect(endpoint_name, manufacturer_slug=None): + return [dt_rec, mt_rec] if endpoint_name == "interface_templates" else [] + + mock_client = MagicMock() + mock_client.get_component_templates.side_effect = _side_effect + with patch("core.export.NetBoxGraphQLClient", return_value=mock_client): + dt_result, mt_result = exporter._fetch_vendor_components("nokia") + + assert dt_result[11]["interface_templates"] == [dt_rec] + assert mt_result[22]["interface_templates"] == [mt_rec] + + def test_determine_export_set_for_module_types_differs_and_superseded(self, tmp_path): + from core.nb_serializer import serialize_module_type + + exporter = self._make_exporter(tmp_path) + mt = _make_mt(part_number="NEW") + + differs = exporter._determine_export_set_for_module_types( + nb_records=[mt], + repo_mt_by_key={("nokia", mt.model): {"manufacturer": "Nokia", "model": mt.model, "part_number": "OLD"}}, + components_by_mt_id={}, + ) + superseded = exporter._determine_export_set_for_module_types( + nb_records=[mt], + repo_mt_by_key={("nokia", mt.model): {**serialize_module_type(mt, {}), "profile": "extra"}}, + components_by_mt_id={}, + ) + + assert differs[0].reason == "differs" + assert superseded == [] + + def test_module_type_slug_only_manufacturer_treated_as_present(self, tmp_path): + """Regression: repo YAML with manufacturer: {slug: nokia} must not cause duplicate export. + + The loader stores keys as (mfr_slug, model); the lookup uses rec.manufacturer.slug. + Prior to the fix the loader used mfr_name ("Nokia") for {name:} entries but the + raw slug string ("nokia") for {slug:}-only entries, causing a mismatch. + """ + from core.nb_serializer import serialize_module_type + + exporter = self._make_exporter(tmp_path) + mt = _make_mt() + repo_yaml = serialize_module_type(mt, {}) + + items = exporter._determine_export_set_for_module_types( + nb_records=[mt], + repo_mt_by_key={("nokia", mt.model): repo_yaml}, + components_by_mt_id={}, + ) + + assert items == [], "module type present in repo with slug-only key must not be re-exported" + + def test_determine_export_set_for_rack_types_all_paths(self, tmp_path): + exporter = self._make_exporter(tmp_path) + absent_rt = _make_rt(model="R-ABSENT", slug="r-absent") + differs_rt = _make_rt(model="R-DIFF", slug="r-diff", u_height=42) + same_rt = _make_rt(model="R-SAME", slug="r-same") + + result = exporter._determine_export_set_for_rack_types( + nb_records=[absent_rt, differs_rt, same_rt], + repo_rt_by_key={ + ("nokia", "R-DIFF"): {"manufacturer": "Nokia", "model": "R-DIFF", "u_height": 40}, + ("nokia", "R-SAME"): { + "manufacturer": "Nokia", + "model": "R-SAME", + "slug": "r-same", + "form_factor": "4-post-cabinet", + "width": 19, + "u_height": 42, + "starting_unit": 1, + "desc_units": False, + "comments": "", + "description": "", + "profile": "extra", + }, + }, + ) + + assert [item.reason for item in result] == ["absent", "differs"] + assert [item.nb_record.model for item in result] == ["R-ABSENT", "R-DIFF"] + + def test_check_missing_images_handles_rear_only_and_none(self, tmp_path): + exporter = self._make_exporter(tmp_path) + + assert exporter._check_missing_images(None, "/rear.png", "Nokia", "rack") == "images-missing" + assert exporter._check_missing_images(None, None, "Nokia", "rack") is None + + def test_write_yaml_returns_true_for_same_content(self, tmp_path): + exporter = self._make_exporter(tmp_path, force_overwrite=False) + dest = tmp_path / "extra" / "device-types" / "Nokia" / "same.yaml" + dest.parent.mkdir(parents=True) + content = yaml.dump({"model": "Same"}, default_flow_style=False, allow_unicode=True, sort_keys=False) + dest.write_text(content, encoding="utf-8") + + assert exporter._write_yaml(dest, {"model": "Same"}) is True + assert dest.read_text(encoding="utf-8") == content + + def test_download_type_images_dispatches_module_and_rack(self, tmp_path): + exporter = self._make_exporter(tmp_path) + exporter._download_module_type_images = MagicMock(return_value=False) + exporter._download_device_type_images = MagicMock(return_value=True) + dt_item = ExportItem("device-type", _make_dt(), None, {}, "absent", "Nokia", "dt.yaml", "Nokia/dt") + mt_item = ExportItem("module-type", _make_mt(), None, {}, "absent", "Nokia", "mt.yaml", "Nokia/mt") + rt_item = ExportItem("rack-type", _make_rt(), None, {}, "absent", "Nokia", "rt.yaml", "Nokia/rt") + + assert exporter._download_type_images(dt_item) is True + assert exporter._download_type_images(mt_item) is False + assert exporter._download_type_images(rt_item) is True + + def test_download_module_type_images_handles_fetch_failure(self, tmp_path): + exporter = self._make_exporter(tmp_path) + item = ExportItem("module-type", _make_mt(id=77), None, {}, "absent", "Nokia", "mt.yaml", "Nokia/mt") + exporter._get_module_image_details = MagicMock(side_effect=RuntimeError("boom")) + + assert exporter._download_module_type_images(item) is False + assert any("Could not fetch module image details" in str(call) for call in exporter.handle.log.call_args_list) + + def test_download_module_type_images_downloads_available_attachments(self, tmp_path): + exporter = self._make_exporter(tmp_path) + item = ExportItem("module-type", _make_mt(id=55), None, {}, "absent", "Nokia", "mt.yaml", "Nokia/mt") + exporter._get_module_image_details = MagicMock( + return_value={ + 55: { + "front.png": {"url": "/img/front.png"}, + "rear.png": MagicMock(url="/img/rear.png"), + "skip.png": {"url": None}, + } + } + ) + exporter._download_image = MagicMock(side_effect=["hash", None]) + + assert exporter._download_module_type_images(item) is False + assert exporter._download_image.call_count == 2 + + def test_download_module_type_images_rejects_path_traversal(self, tmp_path): + """Attachment names with directory separators are sanitized (stripped to basename).""" + exporter = self._make_exporter(tmp_path) + item = ExportItem("module-type", _make_mt(id=99), None, {}, "absent", "Nokia", "mt.yaml", "Nokia/mt") + exporter._get_module_image_details = MagicMock( + return_value={ + 99: { + "../evil.png": {"url": "/img/evil.png"}, + } + } + ) + calls = [] + + def capture_download(url_path, dest, content_type_out=None): + calls.append(dest) + return "hash" + + exporter._download_image = capture_download + + result = exporter._download_module_type_images(item) + assert result is True + assert len(calls) == 1 + img_dir = exporter.export_dir / "module-images" / "Nokia" + # The file must be written under img_dir, not the parent + assert calls[0].parent == img_dir + assert calls[0].name == "evil.png" + + def test_download_module_type_images_derives_extension_from_url(self, tmp_path): + """When att_name has no extension, URL suffix is used to add one.""" + exporter = self._make_exporter(tmp_path) + item = ExportItem("module-type", _make_mt(id=42), None, {}, "absent", "Nokia", "mt.yaml", "Nokia/mt") + exporter._get_module_image_details = MagicMock( + return_value={ + 42: { + "front": {"url": "/media/front.jpg"}, # no extension in att_name + } + } + ) + calls = [] + + def capture_download(url_path, dest, content_type_out=None): + calls.append((url_path, dest)) + return "abc123" + + exporter._download_image = capture_download + + result = exporter._download_module_type_images(item) + assert result is True + assert len(calls) == 1 + _, dest = calls[0] + assert dest.suffix == ".jpg" + + def test_download_module_type_images_renames_from_content_type(self, tmp_path): + """When neither att_name nor URL has a known extension, Content-Type renames the file.""" + exporter = self._make_exporter(tmp_path) + item = ExportItem("module-type", _make_mt(id=11), None, {}, "absent", "Nokia", "mt.yaml", "Nokia/mt") + exporter._get_module_image_details = MagicMock( + return_value={ + 11: { + "front": {"url": "/media/front"}, # no extension anywhere + } + } + ) + img_dir = exporter.export_dir / "module-images" / "Nokia" + img_dir.mkdir(parents=True, exist_ok=True) + + def capture_download(url_path, dest, content_type_out=None): + # Write a dummy file and populate content_type_out + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_bytes(b"imgdata") + if content_type_out is not None: + content_type_out.append("image/webp") + return "deadbeef" + + exporter._download_image = capture_download + result = exporter._download_module_type_images(item) + assert result is True + # The file should have been renamed to include .webp extension + assert (img_dir / "front.webp").exists() + assert not (img_dir / "front.bin").exists() + + def test_download_module_type_images_rename_skips_overwrite_when_dest_exists(self, tmp_path): + """Rename step must NOT overwrite an existing file when force_overwrite=False.""" + exporter = self._make_exporter(tmp_path, force_overwrite=False) + item = ExportItem("module-type", _make_mt(id=22), None, {}, "absent", "Nokia", "mt.yaml", "Nokia/mt") + exporter._get_module_image_details = MagicMock( + return_value={ + 22: { + "front": {"url": "/media/front"}, # no extension → provisional dest = front.bin + } + } + ) + img_dir = exporter.export_dir / "module-images" / "Nokia" + img_dir.mkdir(parents=True, exist_ok=True) + # Pre-existing file at the final (renamed) destination + existing = img_dir / "front.webp" + existing.write_bytes(b"original") + + def capture_download(url_path, dest, content_type_out=None): + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_bytes(b"new") + if content_type_out is not None: + content_type_out.append("image/webp") + return "newhash" + + exporter._download_image = capture_download + result = exporter._download_module_type_images(item) + assert result is True + # Existing file must not have been overwritten + assert existing.read_bytes() == b"original" + # The provisional .bin file should have been cleaned up + assert not (img_dir / "front.bin").exists() + + def test_download_image_skip_existing_file(self, tmp_path): + exporter = self._make_exporter(tmp_path, force_overwrite=False) + dest = tmp_path / "extra" / "elevation-images" / "Nokia" / "existing.png" + dest.parent.mkdir(parents=True) + dest.write_bytes(b"old") + + assert exporter._download_image("/media/existing.png", dest) is _SKIP + + def test_download_image_handles_request_error(self, tmp_path): + import requests + + exporter = self._make_exporter(tmp_path) + dest = tmp_path / "extra" / "elevation-images" / "Nokia" / "error.png" + with patch("core.export.requests.get", side_effect=requests.RequestException("nope")): + assert exporter._download_image("/media/error.png", dest) is None + + def test_download_image_rejects_non_image_response(self, tmp_path): + exporter = self._make_exporter(tmp_path) + dest = tmp_path / "extra" / "elevation-images" / "Nokia" / "bad.png" + resp = MagicMock(ok=False, status_code=404, headers={"Content-Type": "application/json"}) + + with patch("core.export.requests.get", return_value=resp): + assert exporter._download_image("/media/bad.png", dest) is None + + def test_download_image_writes_file_and_returns_hash(self, tmp_path): + import hashlib + + exporter = self._make_exporter(tmp_path) + dest = tmp_path / "extra" / "elevation-images" / "Nokia" / "ok.png" + resp = MagicMock(ok=True, status_code=200, headers={"Content-Type": "image/png"}, content=b"image-bytes") + + with patch("core.export.requests.get", return_value=resp): + result = exporter._download_image("/media/ok.png", dest) + + assert result == hashlib.sha256(b"image-bytes").hexdigest() + assert dest.read_bytes() == b"image-bytes" + + def test_download_device_type_images_uses_url_extension(self, tmp_path): + """Extension must be derived from the URL, not hardcoded to .png.""" + exporter = self._make_exporter(tmp_path) + dt = _make_dt(front_image="/media/nokia-7750.front.jpg", slug="nokia-7750") + item = ExportItem("device-type", dt, None, {}, "absent", "Nokia", "nokia-7750.yaml", "Nokia/nokia-7750") + + resp = MagicMock(ok=True, status_code=200, headers={"Content-Type": "image/jpeg"}, content=b"JPEG") + with patch("core.export.requests.get", return_value=resp): + result = exporter._download_device_type_images(item) + + assert result is True + dest = tmp_path / "extra" / "elevation-images" / "Nokia" / "nokia-7750.front.jpg" + assert dest.exists(), "JPEG image must be saved with .jpg extension from URL" + + def test_download_device_type_images_falls_back_to_png_for_unknown_ext(self, tmp_path): + """Unknown URL extension falls back to .png.""" + exporter = self._make_exporter(tmp_path) + dt = _make_dt(front_image="/media/nokia-7750.front", slug="nokia-7750") + item = ExportItem("device-type", dt, None, {}, "absent", "Nokia", "nokia-7750.yaml", "Nokia/nokia-7750") + + resp = MagicMock(ok=True, status_code=200, headers={"Content-Type": "image/png"}, content=b"PNG") + with patch("core.export.requests.get", return_value=resp): + result = exporter._download_device_type_images(item) + + assert result is True + # URL has no recognised ext → falls back to .png; content-type is also png so same name + dest = tmp_path / "extra" / "elevation-images" / "Nokia" / "nokia-7750.front.png" + assert dest.exists() + + def test_run_skips_fresh_records_and_exits_when_nothing_to_export(self, tmp_path): + exporter = self._make_exporter(tmp_path) + dt = _make_dt() + mt = _make_mt() + rt = _make_rt() + exporter.graphql.get_device_types = MagicMock(return_value=({("nokia", dt.model): dt}, {dt.slug: dt})) + exporter.graphql.get_module_types = MagicMock(return_value={"nokia": {mt.model: mt}}) + exporter.graphql.get_rack_types = MagicMock(return_value={"nokia": {rt.model: rt}}) + progress = MagicMock() + progress.add_task.side_effect = [1, 2] + + with patch("core.export.is_entry_fresh", return_value=True): + exporter.run(progress=progress) + + assert progress.advance.call_args_list == [((1,),), ((2,),)] + skipped_calls = exporter.handle.verbose_log.call_args_list + assert any("Skipped 3 record(s) unchanged" in call.args[0] for call in skipped_calls) + nothing_calls = exporter.handle.log.call_args_list + assert any("Nothing to export" in call.args[0] for call in nothing_calls) + + def test_run_advances_write_task_for_skips_and_writes(self, tmp_path): + exporter = self._make_exporter(tmp_path) + dt = _make_dt(last_updated="2024-01-02T00:00:00Z") + mt = _make_mt(last_updated="2024-01-03T00:00:00Z") + exporter.graphql.get_device_types = MagicMock(return_value=({("nokia", dt.model): dt}, {dt.slug: dt})) + exporter.graphql.get_module_types = MagicMock(return_value={"nokia": {mt.model: mt}}) + exporter.graphql.get_rack_types = MagicMock(return_value={}) + exporter._fetch_vendor_components = MagicMock(return_value=({}, {})) + exporter._determine_export_set_for_device_types = MagicMock( + return_value=[ + ExportItem("device-type", dt, None, {"model": dt.model}, "absent", "Nokia", "dt.yaml", "Nokia/dt") + ] + ) + exporter._determine_export_set_for_module_types = MagicMock( + return_value=[ + ExportItem("module-type", mt, None, {"model": mt.model}, "differs", "Nokia", "mt.yaml", "Nokia/mt") + ] + ) + exporter._write_yaml = MagicMock(side_effect=[False, True]) + exporter._download_type_images = MagicMock(return_value=True) + progress = MagicMock() + progress.add_task.side_effect = [1, 2] + + with patch("core.export.is_entry_fresh", return_value=False): + exporter.run(progress=progress) + + assert progress.advance.call_args_list == [((1,),), ((2,),), ((2,),)] + assert any("Skipped (overwrite guard)" in call.args[0] for call in exporter.handle.log.call_args_list) + assert any("wrote 1 file(s), skipped 1" in call.args[0] for call in exporter.handle.log.call_args_list) + + def test_compare_vendors_to_items_skips_fresh_and_fetches_stale_vendor_once(self, tmp_path): + exporter = self._make_exporter(tmp_path) + nokia_dt = _make_dt(model="Nokia-DT") + nokia_mt = _make_mt(model="Nokia-MT") + acme_mfr = _make_mfr(name="Acme", slug="acme") + acme_fresh_dt = _make_dt(model="Acme-Fresh-DT", slug="acme-fresh-dt") + acme_fresh_dt.manufacturer = acme_mfr + acme_stale_dt = _make_dt(model="Acme-Stale-DT", slug="acme-stale-dt") + acme_stale_dt.manufacturer = acme_mfr + acme_stale_mt = _make_mt(model="Acme-Stale-MT") + acme_stale_mt.manufacturer = acme_mfr + progress = MagicMock() + progress.add_task.return_value = 41 + exporter._fetch_vendor_components = MagicMock(return_value=({"dt-components": 1}, {"mt-components": 1})) + exporter._determine_export_set_for_device_types = MagicMock(return_value=["dt-export"]) + exporter._determine_export_set_for_module_types = MagicMock(return_value=["mt-export"]) + + def _is_fresh(_manifest, kind, key, _last_updated): + return key in {"Nokia/nokia-7750-sr-7s", "Nokia/Nokia-MT", "Acme/acme-fresh-dt"} + + with patch("core.export.is_entry_fresh", side_effect=_is_fresh): + items, skipped_fresh = exporter._compare_vendors_to_items( + all_vendor_slugs=["acme", "nokia"], + dt_by_vendor={"acme": [acme_fresh_dt, acme_stale_dt], "nokia": [nokia_dt]}, + mt_by_vendor={"acme": [acme_stale_mt], "nokia": [nokia_mt]}, + manifest={}, + repo_dt_by_slug={"acme-stale-dt": {"slug": "acme-stale-dt"}}, + repo_mt_by_key={("Acme", "Acme-Stale-MT"): {"model": "Acme-Stale-MT"}}, + progress=progress, + ) + + assert items == ["dt-export", "mt-export"] + assert skipped_fresh == 3 + progress.add_task.assert_called_once_with("Comparing vendors", total=2) + assert progress.advance.call_args_list == [((41,),), ((41,),)] + exporter._fetch_vendor_components.assert_called_once_with("acme") + exporter._determine_export_set_for_device_types.assert_called_once_with( + nb_records=[acme_stale_dt], + repo_dt_by_slug={"acme-stale-dt": {"slug": "acme-stale-dt"}}, + components_by_dt_id={"dt-components": 1}, + ) + exporter._determine_export_set_for_module_types.assert_called_once_with( + nb_records=[acme_stale_mt], + repo_mt_by_key={("Acme", "Acme-Stale-MT"): {"model": "Acme-Stale-MT"}}, + components_by_mt_id={"mt-components": 1}, + ) + + def test_compare_racks_to_items_skips_fresh_and_exports_stale(self, tmp_path): + exporter = self._make_exporter(tmp_path) + fresh_rt = _make_rt(model="Rack-Fresh") + stale_rt = _make_rt(model="Rack-Stale") + progress = MagicMock() + progress.add_task.return_value = 42 + exporter._determine_export_set_for_rack_types = MagicMock(return_value=["rack-export"]) + + with patch( + "core.export.is_entry_fresh", + side_effect=lambda _manifest, _kind, key, _last_updated: key == "Nokia/Rack-Fresh", + ): + items, skipped_fresh = exporter._compare_racks_to_items( + all_rt={"nokia": {fresh_rt.model: fresh_rt, stale_rt.model: stale_rt}}, + manifest={}, + repo_rt_by_key={("Nokia", "Rack-Stale"): {"model": "Rack-Stale"}}, + progress=progress, + ) + + assert items == ["rack-export"] + assert skipped_fresh == 1 + progress.add_task.assert_called_once_with("Comparing rack types", total=2) + assert progress.advance.call_args_list == [((42,),), ((42,),)] + exporter._determine_export_set_for_rack_types.assert_called_once_with( + nb_records=[stale_rt], + repo_rt_by_key={("Nokia", "Rack-Stale"): {"model": "Rack-Stale"}}, + ) + + def test_write_export_items_logs_summary_updates_manifest_and_tracks_progress(self, tmp_path): + exporter = self._make_exporter(tmp_path) + manifest = {} + manifest_path = tmp_path / "extra" / ".export-manifest.json" + progress = MagicMock() + progress.add_task.return_value = 43 + dt_item = ExportItem("device-type", _make_dt(), None, {"model": "dt"}, "absent", "Nokia", "dt.yaml", "Nokia/dt") + mt_item = ExportItem( + "module-type", _make_mt(), None, {"model": "mt"}, "differs", "Nokia", "mt.yaml", "Nokia/mt" + ) + rt_item = ExportItem( + "rack-type", _make_rt(), None, {"model": "rt"}, "images-missing", "Nokia", "rt.yaml", "Nokia/rt" + ) + exporter._write_yaml = MagicMock(side_effect=[False, True, True]) + exporter._download_type_images = MagicMock(side_effect=[True, False]) + + with ( + patch("core.export.update_entry") as mock_update_entry, + patch("core.export.save_manifest") as mock_save_manifest, + ): + exporter._write_export_items([dt_item, mt_item, rt_item], manifest, manifest_path, progress) + + progress.add_task.assert_called_once_with("Writing exports", total=3) + assert progress.advance.call_args_list == [((43,),), ((43,),), ((43,),)] + assert ( + exporter._write_yaml.call_args_list[0].args[0] == tmp_path / "extra" / "device-types" / "Nokia" / "dt.yaml" + ) + assert ( + exporter._write_yaml.call_args_list[1].args[0] == tmp_path / "extra" / "module-types" / "Nokia" / "mt.yaml" + ) + assert exporter._write_yaml.call_args_list[2].args[0] == tmp_path / "extra" / "rack-types" / "Nokia" / "rt.yaml" + mock_update_entry.assert_called_once_with(manifest, "module-types", "Nokia/mt", mt_item.nb_record.last_updated) + mock_save_manifest.assert_called_once_with(manifest_path, manifest) + assert any( + "Will export 3 item(s) to" in call.args[0] and "1 absent, 1 differs, 1 images-missing" in call.args[0] + for call in exporter.handle.log.call_args_list + ) + assert any("Skipped (overwrite guard)" in call.args[0] for call in exporter.handle.log.call_args_list) + assert any( + "Export-diff complete: wrote 2 file(s), skipped 1" in call.args[0] + for call in exporter.handle.log.call_args_list + ) + + def test_download_image_off_host_url_sends_no_auth(self, tmp_path): + """Token must NOT be sent when the image URL resolves to a different host.""" + exporter = self._make_exporter(tmp_path) + # base_url is http://localhost:8000 (from _make_settings) + dest = tmp_path / "extra" / "images" / "off-host.png" + captured_headers = {} + + def _fake_get(url, headers=None, verify=True, timeout=30): + captured_headers.update(headers or {}) + resp = MagicMock(ok=True, status_code=200) + resp.headers = {"Content-Type": "image/png"} + resp.content = b"img" + return resp + + with patch("core.export.requests.get", side_effect=_fake_get): + exporter._download_image("https://s3.amazonaws.com/bucket/image.png", dest) + + assert "Authorization" not in captured_headers + + def test_download_image_same_host_url_sends_auth(self, tmp_path): + """Token IS sent when the image URL is on the same host as base_url.""" + exporter = self._make_exporter(tmp_path) + dest = tmp_path / "extra" / "images" / "same-host.png" + captured_headers = {} + + def _fake_get(url, headers=None, verify=True, timeout=30): + captured_headers.update(headers or {}) + resp = MagicMock(ok=True, status_code=200) + resp.headers = {"Content-Type": "image/png"} + resp.content = b"img" + return resp + + with patch("core.export.requests.get", side_effect=_fake_get): + exporter._download_image("/media/devicetype-images/image.png", dest) + + assert "Authorization" in captured_headers + assert "Token" in captured_headers["Authorization"] + + def test_write_export_items_preserves_repo_only_fields_on_differs(self, tmp_path): + """When reason='differs', repo-only top-level fields must survive the write.""" + exporter = self._make_exporter(tmp_path) + nb_serialized = {"manufacturer": "Nokia", "model": "SR-1", "u_height": 1} + repo_yaml = { + "manufacturer": "Nokia", + "model": "SR-1", + "u_height": 2, # differs → not merged (NB authoritative) + "profile": "my-custom-profile", # repo-only → must be preserved + "comments": "Internal notes", # repo-only → must be preserved + } + dt = MagicMock() + dt.last_updated = "2024-01-01T00:00:00Z" + item = ExportItem( + kind="device-type", + nb_record=dt, + repo_yaml=repo_yaml, + serialized=nb_serialized, + reason="differs", + mfr_name="Nokia", + filename="SR-1.yaml", + manifest_key="Nokia/SR-1", + ) + + written_data = {} + + def _capture_write(dest, data): + written_data.update(data) + return True + + exporter._write_yaml = _capture_write + exporter._download_type_images = MagicMock(return_value=True) + + from unittest.mock import patch as _patch + + with _patch("core.export.update_entry"), _patch("core.export.save_manifest"): + exporter._write_export_items( + [item], + {}, + tmp_path / "manifest.json", + None, + ) + + assert written_data["profile"] == "my-custom-profile" + assert written_data["comments"] == "Internal notes" + # NB authoritative field is NOT overwritten by repo value + assert written_data["u_height"] == 1 + + def test_load_repo_device_types_logs_bad_yaml_at_verbose(self, tmp_path): + """Malformed YAML must be logged at verbose level, not silently dropped.""" + exporter = self._make_exporter(tmp_path) + vdir = tmp_path / "repo" / "device-types" / "Nokia" + vdir.mkdir(parents=True) + (vdir / "bad.yaml").write_text("foo: [\n", encoding="utf-8") + + exporter._load_repo_device_types() + + verbose_calls = " ".join(str(c) for c in exporter.handle.verbose_log.call_args_list) + assert "Skipping malformed YAML" in verbose_calls or "malformed" in verbose_calls.lower() + + def test_load_repo_module_types_logs_bad_yaml_at_verbose(self, tmp_path): + """Malformed module YAML must be logged at verbose level.""" + exporter = self._make_exporter(tmp_path) + vdir = tmp_path / "repo" / "module-types" / "Nokia" + vdir.mkdir(parents=True) + (vdir / "bad.yaml").write_text("manufacturer: [\n", encoding="utf-8") + + exporter._load_repo_module_types() + + verbose_calls = " ".join(str(c) for c in exporter.handle.verbose_log.call_args_list) + assert "Skipping malformed YAML" in verbose_calls or "malformed" in verbose_calls.lower() + + def test_load_repo_rack_types_logs_bad_yaml_at_verbose(self, tmp_path): + """Malformed rack YAML must be logged at verbose level.""" + exporter = self._make_exporter(tmp_path) + vdir = tmp_path / "repo" / "rack-types" / "Nokia" + vdir.mkdir(parents=True) + (vdir / "bad.yaml").write_text("manufacturer: [\n", encoding="utf-8") + + exporter._load_repo_rack_types() + + verbose_calls = " ".join(str(c) for c in exporter.handle.verbose_log.call_args_list) + assert "Skipping malformed YAML" in verbose_calls or "malformed" in verbose_calls.lower() + + +class TestSanitizeAttachmentFilename: + """Tests for _sanitize_attachment_filename.""" + + def test_plain_name_with_known_extension_is_returned_unchanged(self): + assert _sanitize_attachment_filename("front.png", "/media/front.png", "") == "front.png" + + def test_strips_directory_components(self): + result = _sanitize_attachment_filename("../evil.png", "/media/evil.png", "") + assert "/" not in result + assert ".." not in result + assert result == "evil.png" + + def test_strips_subdirectory_prefix(self): + result = _sanitize_attachment_filename("subdir/img.jpg", "/media/img.jpg", "") + assert result == "img.jpg" + + def test_derives_extension_from_url_when_name_has_none(self): + result = _sanitize_attachment_filename("front", "/media/front.jpg", "") + assert result == "front.jpg" + + def test_derives_extension_from_content_type_when_url_has_none(self): + result = _sanitize_attachment_filename("front", "/media/front", "image/png") + assert result == "front.png" + + def test_content_type_takes_priority_over_url(self): + # content_type is non-empty → preferred over URL suffix + result = _sanitize_attachment_filename("front", "/media/front.jpg", "image/png") + assert result == "front.png" + + def test_falls_back_to_bin_when_nothing_known(self): + result = _sanitize_attachment_filename("front", "/media/front", "") + assert result == "front.bin" + + def test_empty_name_uses_attachment_prefix(self): + result = _sanitize_attachment_filename("", "/media/front.png", "") + assert result.endswith(".png") + assert len(result) > 0 + + def test_all_known_content_types_are_recognised(self): + from core.export import _CONTENT_TYPE_EXT + + for ct, ext in _CONTENT_TYPE_EXT.items(): + result = _sanitize_attachment_filename("img", "/media/img", ct) + assert result.endswith(ext), f"Expected {ext} for {ct}, got {result}" diff --git a/tests/test_graphql_client.py b/tests/test_graphql_client.py index 74c9fcf32..d2a016d7b 100644 --- a/tests/test_graphql_client.py +++ b/tests/test_graphql_client.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock import requests -from core.graphql_client import DotDict +from core.graphql_client import DotDict, NetBoxGraphQLClient def _make_paged_responses(data, list_key): @@ -391,7 +391,7 @@ def log(self, msg): assert "10" in warned_msgs[0] # requested page size in warning def test_query_all_clamping_warning_emitted_only_once(self, mock_post): - """Clamping warning is emitted at most once per client instance.""" + """Clamping warning is emitted at most once per URL+page-size combination.""" # Two separate query_all calls, each seeing clamping. def make_pages(n_pages, page_size=2): @@ -863,6 +863,122 @@ def test_falls_back_to_python_filter_on_schema_error(self, mock_post): assert result == {10: {"front"}} +# ── get_module_type_image_details tests ─────────────────────────────────── + + +class TestGetModuleTypeImageDetails: + """Tests for the get_module_type_image_details() convenience method.""" + + def _make_client(self): + from core.graphql_client import NetBoxGraphQLClient + + return NetBoxGraphQLClient("http://netbox.local", "tok") + + def test_returns_mapping_by_module_type_id(self, mocker): + """Returns dict with structure {module_type_id: {attachment_name: {"att_id": id, "url": url}}}.""" + items = [ + {"id": 1, "name": "front", "object_id": 42, "image": {"url": "http://example.com/front.jpg"}}, + {"id": 2, "name": "rear", "object_id": 42, "image": {"url": "http://example.com/rear.jpg"}}, + {"id": 3, "name": "detail", "object_id": 43, "image": {"url": "http://example.com/detail.jpg"}}, + ] + client = self._make_client() + mocker.patch.object(client, "query_all", return_value=items) + + result = client.get_module_type_image_details() + + assert result[42]["front"] == {"att_id": 1, "url": "http://example.com/front.jpg"} + assert result[42]["rear"] == {"att_id": 2, "url": "http://example.com/rear.jpg"} + assert result[43]["detail"] == {"att_id": 3, "url": "http://example.com/detail.jpg"} + + def test_fallback_on_graphql_error(self, mocker): + """When filtered query fails, falls back to fetch-all + Python filter by object_type.""" + from core.graphql_client import GraphQLError + + client = self._make_client() + error = GraphQLError("Field 'filters' not found") + + fallback_items = [ + { + "id": 1, + "name": "front", + "object_id": 42, + "image": {"url": "http://example.com/front.jpg"}, + "object_type": {"app_label": "dcim", "model": "moduletype"}, + }, + { + "id": 2, + "name": "other", + "object_id": 50, + "image": {"url": "http://example.com/other.jpg"}, + "object_type": {"app_label": "circuits", "model": "provider"}, + }, + ] + mocker.patch.object(client, "query_all", side_effect=[error, fallback_items]) + + result = client.get_module_type_image_details() + + assert result[42]["front"] == {"att_id": 1, "url": "http://example.com/front.jpg"} + assert 50 not in result + + def test_skips_items_without_name(self, mocker): + """Items with None or empty name should be skipped.""" + items = [ + {"id": 1, "name": None, "object_id": 42, "image": {"url": "http://example.com/1.jpg"}}, + {"id": 2, "name": "", "object_id": 42, "image": {"url": "http://example.com/2.jpg"}}, + {"id": 3, "name": "valid", "object_id": 42, "image": {"url": "http://example.com/3.jpg"}}, + ] + client = self._make_client() + mocker.patch.object(client, "query_all", return_value=items) + + result = client.get_module_type_image_details() + + assert 42 in result + assert "valid" in result[42] + assert None not in result[42] + assert "" not in result[42] + + def test_string_object_id_coerced_to_int(self, mocker): + """String object_id values should be coerced to integers.""" + items = [ + {"id": 1, "name": "img", "object_id": "42", "image": {"url": "http://example.com/1.jpg"}}, + ] + client = self._make_client() + mocker.patch.object(client, "query_all", return_value=items) + + result = client.get_module_type_image_details() + + assert 42 in result + assert result[42]["img"]["att_id"] == 1 + + def test_string_att_id_coerced_to_int(self, mocker): + """String attachment IDs should be coerced to integers.""" + items = [{"id": "99", "name": "img", "object_id": 10, "image": {"url": "http://x.com/1.jpg"}}] + client = self._make_client() + mocker.patch.object(client, "query_all", return_value=items) + result = client.get_module_type_image_details() + assert result[10]["img"]["att_id"] == 99 + + def test_invalid_string_att_id_stored_as_none(self, mocker): + """A non-numeric attachment ID string should store att_id=None, not crash.""" + items = [{"id": "not-valid", "name": "img", "object_id": 10, "image": {"url": "http://x.com/1.jpg"}}] + client = self._make_client() + mocker.patch.object(client, "query_all", return_value=items) + result = client.get_module_type_image_details() + assert result[10]["img"]["att_id"] is None + + def test_non_numeric_string_object_id_is_skipped(self, mocker): + """Items with a non-numeric string object_id should be skipped silently.""" + items = [ + {"id": 1, "name": "front", "object_id": "not-a-number", "image": {"url": "http://x.com/1.jpg"}}, + {"id": 2, "name": "rear", "object_id": "42", "image": {"url": "http://x.com/2.jpg"}}, + ] + client = self._make_client() + mocker.patch.object(client, "query_all", return_value=items) + result = client.get_module_type_image_details() + assert 42 in result + assert all("front" not in v for v in result.values()) + + # ── get_component_templates tests ────────────────────────────────────────── @@ -1538,3 +1654,547 @@ def test_request_exception_retries_before_exhausting(self, mock_post): with patch("core.graphql_client.time.sleep"): with pytest.raises(GraphQLError, match="timed out"): client.query("{ test }", _retries=1) + + +# ── Vendor-scoped filtering tests ───────────────────────────────────────── + + +class TestVendorScopedDeviceTypes: + """Tests for vendor-scoped filtering in get_device_types().""" + + def _make_client(self): + from core.graphql_client import NetBoxGraphQLClient + + return NetBoxGraphQLClient("http://netbox.local", "tok") + + def test_single_vendor_filter(self, mock_post): + """Test filtering by a single manufacturer slug.""" + data = { + "device_type_list": [ + { + "id": "1", + "model": "Catalyst 3850", + "slug": "catalyst-3850", + "u_height": 1.0, + "part_number": "WS-C3850-24P", + "is_full_depth": True, + "subdevice_role": None, + "airflow": "front-to-rear", + "weight": 5.4, + "weight_unit": "kg", + "description": "Test device", + "comments": "", + "front_image": None, + "rear_image": None, + "manufacturer": {"id": "10", "name": "Cisco", "slug": "cisco"}, + } + ] + } + mock_post.side_effect = _make_paged_responses(data, "device_type_list") + + client = self._make_client() + by_model, by_slug = client.get_device_types(manufacturer_slugs=["cisco"]) + + assert ("cisco", "Catalyst 3850") in by_model + assert by_model[("cisco", "Catalyst 3850")].model == "Catalyst 3850" + # Verify the filter was applied via GraphQL variable (not string interpolation) + call_payload = mock_post.call_args_list[0][1]["json"] + query = call_payload["query"] + variables = call_payload["variables"] + assert "filters: {manufacturer: {slug: {exact: $manufacturer_slug}}}" in query + assert "$manufacturer_slug: String!" in query + assert variables["manufacturer_slug"] == "cisco" + + def test_multiple_vendor_filter(self, mock_post): + """Test filtering by multiple manufacturer slugs.""" + data = { + "device_type_list": [ + { + "id": "1", + "model": "Catalyst 3850", + "slug": "catalyst-3850", + "u_height": 1.0, + "part_number": "", + "is_full_depth": True, + "subdevice_role": None, + "airflow": None, + "weight": None, + "weight_unit": None, + "description": "", + "comments": "", + "front_image": None, + "rear_image": None, + "manufacturer": {"id": "10", "name": "Cisco", "slug": "cisco"}, + }, + { + "id": "2", + "model": "EX4300", + "slug": "ex4300", + "u_height": 1.0, + "part_number": "", + "is_full_depth": False, + "subdevice_role": None, + "airflow": None, + "weight": None, + "weight_unit": None, + "description": "", + "comments": "", + "front_image": None, + "rear_image": None, + "manufacturer": {"id": "20", "name": "Juniper", "slug": "juniper"}, + }, + ] + } + mock_post.side_effect = _make_paged_responses(data, "device_type_list") + + client = self._make_client() + by_model, by_slug = client.get_device_types(manufacturer_slugs=["cisco", "juniper"]) + + assert ("cisco", "Catalyst 3850") in by_model + assert ("juniper", "EX4300") in by_model + # Verify the filter was applied via GraphQL variable (not string interpolation) + call_payload = mock_post.call_args_list[0][1]["json"] + query = call_payload["query"] + variables = call_payload["variables"] + assert "filters: {manufacturer: {slug: {in_list: $manufacturer_slugs}}}" in query + assert "$manufacturer_slugs: [String!]!" in query + assert variables["manufacturer_slugs"] == ["cisco", "juniper"] + + def test_none_manufacturer_slugs_unfiltered(self, mock_post): + """Test that manufacturer_slugs=None produces unfiltered behavior.""" + data = { + "device_type_list": [ + { + "id": "1", + "model": "Test", + "slug": "test", + "u_height": 1.0, + "part_number": "", + "is_full_depth": True, + "subdevice_role": None, + "airflow": None, + "weight": None, + "weight_unit": None, + "description": "", + "comments": "", + "front_image": None, + "rear_image": None, + "manufacturer": {"id": "1", "name": "Test", "slug": "test"}, + } + ] + } + mock_post.side_effect = _make_paged_responses(data, "device_type_list") + + client = self._make_client() + by_model, by_slug = client.get_device_types(manufacturer_slugs=None) + + # Verify no filter in the query and no manufacturer variable sent + call_payload = mock_post.call_args_list[0][1]["json"] + query = call_payload["query"] + variables = call_payload["variables"] + assert "filters:" not in query + assert "manufacturer_slug" not in variables + + def test_empty_list_raises_value_error(self): + """Passing [] for manufacturer_slugs should raise ValueError immediately.""" + client = self._make_client() + with pytest.raises(ValueError, match="manufacturer_slugs must be None or a non-empty list"): + client.get_device_types(manufacturer_slugs=[]) + + +class TestVendorScopedModuleTypes: + """Tests for vendor-scoped filtering in get_module_types().""" + + def _make_client(self): + from core.graphql_client import NetBoxGraphQLClient + + return NetBoxGraphQLClient("http://netbox.local", "tok") + + def test_single_vendor_filter(self, mock_post): + """Test filtering by a single manufacturer slug.""" + data = { + "module_type_list": [ + { + "id": "1", + "model": "C9300-NM-8X", + "part_number": "C9300-NM-8X", + "airflow": None, + "description": "", + "comments": "", + "weight": None, + "weight_unit": None, + "manufacturer": {"id": "10", "name": "Cisco", "slug": "cisco"}, + } + ] + } + mock_post.side_effect = _make_paged_responses(data, "module_type_list") + + client = self._make_client() + result = client.get_module_types(manufacturer_slugs=["cisco"]) + + assert "cisco" in result + assert "C9300-NM-8X" in result["cisco"] + # Verify the filter was applied via GraphQL variable (not string interpolation) + call_payload = mock_post.call_args_list[0][1]["json"] + query = call_payload["query"] + variables = call_payload["variables"] + assert "filters: {manufacturer: {slug: {exact: $manufacturer_slug}}}" in query + assert "$manufacturer_slug: String!" in query + assert variables["manufacturer_slug"] == "cisco" + + def test_multiple_vendor_filter(self, mock_post): + """Test filtering by multiple manufacturer slugs.""" + data = { + "module_type_list": [ + { + "id": "1", + "model": "C9300-NM-8X", + "part_number": "", + "airflow": None, + "description": "", + "comments": "", + "weight": None, + "weight_unit": None, + "manufacturer": {"id": "10", "name": "Cisco", "slug": "cisco"}, + }, + { + "id": "2", + "model": "MIC-3D-20GE-SFP", + "part_number": "", + "airflow": None, + "description": "", + "comments": "", + "weight": None, + "weight_unit": None, + "manufacturer": {"id": "20", "name": "Juniper", "slug": "juniper"}, + }, + ] + } + mock_post.side_effect = _make_paged_responses(data, "module_type_list") + + client = self._make_client() + result = client.get_module_types(manufacturer_slugs=["cisco", "juniper"]) + + assert "cisco" in result + assert "juniper" in result + # Verify the filter was applied via GraphQL variable (not string interpolation) + call_payload = mock_post.call_args_list[0][1]["json"] + query = call_payload["query"] + variables = call_payload["variables"] + assert "filters: {manufacturer: {slug: {in_list: $manufacturer_slugs}}}" in query + assert "$manufacturer_slugs: [String!]!" in query + assert variables["manufacturer_slugs"] == ["cisco", "juniper"] + + def test_empty_list_raises_value_error(self): + """Passing [] for manufacturer_slugs should raise ValueError immediately.""" + client = self._make_client() + with pytest.raises(ValueError, match="manufacturer_slugs must be None or a non-empty list"): + client.get_module_types(manufacturer_slugs=[]) + + +class TestVendorScopedRackTypes: + """Tests for vendor-scoped filtering in get_rack_types().""" + + def _make_client(self): + from core.graphql_client import NetBoxGraphQLClient + + return NetBoxGraphQLClient("http://netbox.local", "tok") + + def test_single_vendor_filter(self, mock_post): + """Test filtering by a single manufacturer slug.""" + data = { + "rack_type_list": [ + { + "id": "10", + "model": "AR1300", + "slug": "apc-ar1300", + "form_factor": "4-post-cabinet", + "width": 19, + "u_height": 42, + "starting_unit": 1, + "outer_width": None, + "outer_height": None, + "outer_depth": None, + "outer_unit": None, + "mounting_depth": None, + "weight": None, + "max_weight": None, + "weight_unit": None, + "desc_units": False, + "comments": "", + "description": "", + "manufacturer": {"id": "5", "name": "APC", "slug": "apc"}, + } + ] + } + mock_post.side_effect = _make_paged_responses(data, "rack_type_list") + + client = self._make_client() + result = client.get_rack_types(manufacturer_slugs=["apc"]) + + assert "apc" in result + assert "AR1300" in result["apc"] + # Verify the filter was applied via GraphQL variable (not string interpolation) + call_payload = mock_post.call_args_list[0][1]["json"] + query = call_payload["query"] + variables = call_payload["variables"] + assert "filters: {manufacturer: {slug: {exact: $manufacturer_slug}}}" in query + assert "$manufacturer_slug: String!" in query + assert variables["manufacturer_slug"] == "apc" + + def test_empty_list_raises_value_error(self): + """Passing [] for manufacturer_slugs should raise ValueError immediately.""" + client = self._make_client() + with pytest.raises(ValueError, match="manufacturer_slugs must be None or a non-empty list"): + client.get_rack_types(manufacturer_slugs=[]) + + +class TestVendorScopedComponentTemplates: + """Tests for vendor-scoped filtering in get_component_templates().""" + + def _make_client(self): + from core.graphql_client import NetBoxGraphQLClient + + return NetBoxGraphQLClient("http://netbox.local", "tok") + + def test_unfiltered_query(self, mock_post): + """Test that manufacturer_slug=None produces unfiltered behavior.""" + data = { + "interface_template_list": [ + { + "id": "1", + "name": "Ethernet1/0/1", + "type": "1000base-t", + "device_type": {"id": "10"}, + "module_type": None, + } + ] + } + mock_post.side_effect = _make_paged_responses(data, "interface_template_list") + + client = self._make_client() + result = client.get_component_templates("interface_templates", manufacturer_slug=None) + + assert len(result) == 1 + assert result[0].name == "Ethernet1/0/1" + # Verify no filter in the query + call_args = mock_post.call_args_list[0] + query = call_args[1]["json"]["query"] + assert "filters:" not in query + + def test_vendor_scoped_two_queries(self, mock_post): + """Test that vendor-scoped query makes two separate queries (device + module).""" + device_data = { + "interface_template_list": [ + { + "id": "1", + "name": "Ethernet1/0/1", + "type": "1000base-t", + "device_type": {"id": "10"}, + "module_type": None, + } + ] + } + module_data = { + "interface_template_list": [ + { + "id": "2", + "name": "GigabitEthernet0/0", + "type": "1000base-t", + "device_type": None, + "module_type": {"id": "20"}, + } + ] + } + # Each query needs a data page followed by an empty page + device_responses = _make_paged_responses(device_data, "interface_template_list") + module_responses = _make_paged_responses(module_data, "interface_template_list") + mock_post.side_effect = device_responses + module_responses + + client = self._make_client() + result = client.get_component_templates("interface_templates", manufacturer_slug="cisco") + + assert len(result) == 2 + assert result[0].name == "Ethernet1/0/1" + assert result[1].name == "GigabitEthernet0/0" + + # Verify both filters were applied via GraphQL variables (not string interpolation) + calls = mock_post.call_args_list + # calls[0]: device_type filter query – data page + device_payload = calls[0][1]["json"] + device_query = device_payload["query"] + device_vars = device_payload["variables"] + # calls[1]: device_type filter query – empty terminator (pagination ends) + # calls[2]: module_type filter query – data page + module_payload = calls[2][1]["json"] + module_query = module_payload["query"] + module_vars = module_payload["variables"] + + assert "filters: {device_type: {manufacturer: {slug: {exact: $manufacturer_slug}}}}" in device_query + assert device_vars["manufacturer_slug"] == "cisco" + assert "filters: {module_type: {manufacturer: {slug: {exact: $manufacturer_slug}}}}" in module_query + assert module_vars["manufacturer_slug"] == "cisco" + + def test_vendor_scoped_device_bay_single_query(self, mock_post): + """Test that device_bay_templates makes only one query (no module_type field).""" + device_data = { + "device_bay_template_list": [ + { + "id": "1", + "name": "Device Bay 1", + "label": "Bay 1", + "description": "", + "device_type": {"id": "10"}, + } + ] + } + mock_post.side_effect = _make_paged_responses(device_data, "device_bay_template_list") + + client = self._make_client() + result = client.get_component_templates("device_bay_templates", manufacturer_slug="cisco") + + assert len(result) == 1 + assert result[0].name == "Device Bay 1" + + # Verify only device_type filter was applied (no module_type query) + assert mock_post.call_count == 2 # Just data + empty page + device_payload = mock_post.call_args_list[0][1]["json"] + device_query = device_payload["query"] + device_vars = device_payload["variables"] + assert "filters: {device_type: {manufacturer: {slug: {exact: $manufacturer_slug}}}}" in device_query + assert device_vars["manufacturer_slug"] == "cisco" + + def test_empty_string_raises_value_error(self): + """Passing empty string for manufacturer_slug should raise ValueError immediately.""" + client = self._make_client() + with pytest.raises(ValueError, match="manufacturer_slug must be None or a non-empty string"): + client.get_component_templates("interface_templates", manufacturer_slug="") + + def test_vendor_scoped_with_on_page_callback(self, mock_post): + """Test that on_page callback is called for both queries.""" + device_data = { + "interface_template_list": [ + {"id": "1", "name": "eth0", "type": "1000base-t", "device_type": {"id": "10"}, "module_type": None}, + ] + } + module_data = { + "interface_template_list": [ + {"id": "2", "name": "eth1", "type": "1000base-t", "device_type": None, "module_type": {"id": "20"}}, + ] + } + + mock_post.side_effect = _make_paged_responses(device_data, "interface_template_list") + _make_paged_responses( + module_data, "interface_template_list" + ) + + on_page_calls = [] + + def on_page_callback(count): + on_page_calls.append(count) + + client = self._make_client() + result = client.get_component_templates( + "interface_templates", manufacturer_slug="cisco", on_page=on_page_callback + ) + + assert len(result) == 2 + # Callback should be called for both queries + assert len(on_page_calls) >= 2 + + +class TestLastUpdatedInQueries: + """Verify last_updated is returned by the three type-fetching methods.""" + + def _mock_graphql(self, mocker, items): + client = NetBoxGraphQLClient("http://nb/", "token") + mocker.patch.object(client, "query_all", return_value=items) + return client + + def test_get_device_types_includes_last_updated(self, mocker): + items = [ + { + "id": "1", + "model": "M", + "slug": "m", + "u_height": 1, + "part_number": None, + "is_full_depth": True, + "subdevice_role": None, + "airflow": None, + "weight": None, + "weight_unit": None, + "description": "", + "comments": "", + "front_image": None, + "rear_image": None, + "last_updated": "2024-01-15T10:00:00Z", + "manufacturer": {"id": "1", "name": "Acme", "slug": "acme"}, + } + ] + client = self._mock_graphql(mocker, items) + by_model, by_slug = client.get_device_types() + record = by_model[("acme", "M")] + assert record.last_updated == "2024-01-15T10:00:00Z" + + def test_get_module_types_includes_last_updated(self, mocker): + items = [ + { + "id": "2", + "model": "MM", + "part_number": None, + "airflow": None, + "description": "", + "comments": "", + "weight": None, + "weight_unit": None, + "last_updated": "2024-02-01T00:00:00Z", + "manufacturer": {"id": "1", "name": "Acme", "slug": "acme"}, + } + ] + client = self._mock_graphql(mocker, items) + result = client.get_module_types() + assert result["acme"]["MM"].last_updated == "2024-02-01T00:00:00Z" + + def test_get_rack_types_includes_last_updated(self, mocker): + items = [ + { + "id": "3", + "model": "RR", + "slug": "acme-rr", + "form_factor": "4-post-cabinet", + "width": 19, + "u_height": 42, + "starting_unit": 1, + "outer_width": None, + "outer_height": None, + "outer_depth": None, + "outer_unit": None, + "mounting_depth": None, + "weight": None, + "max_weight": None, + "weight_unit": None, + "desc_units": False, + "comments": "", + "description": "", + "last_updated": "2024-03-10T12:00:00Z", + "manufacturer": {"id": "1", "name": "Acme", "slug": "acme"}, + } + ] + client = self._mock_graphql(mocker, items) + result = client.get_rack_types() + assert result["acme"]["RR"].last_updated == "2024-03-10T12:00:00Z" + + +class TestErrorHierarchy: + """GraphQL exception classes must form the correct inheritance tree.""" + + def test_count_mismatch_error_is_graphql_error(self): + from core.graphql_client import GraphQLCountMismatchError, GraphQLError + + assert issubclass(GraphQLCountMismatchError, GraphQLError) + + def test_count_mismatch_error_is_caught_by_graphql_error_handler(self): + from core.graphql_client import GraphQLCountMismatchError, GraphQLError + + with pytest.raises(GraphQLError): + raise GraphQLCountMismatchError("page cap exceeded") diff --git a/tests/test_nb_dt_import.py b/tests/test_nb_dt_import.py index 153f1f8f8..ceea049d4 100644 --- a/tests/test_nb_dt_import.py +++ b/tests/test_nb_dt_import.py @@ -28,22 +28,6 @@ def nb_dt_import(): sys.modules.pop("nb_dt_import", None) -def test_filter_vendors_for_parsed_types_uses_parsed_subset(nb_dt_import): - - discovered_vendors = [ - {"name": "Cisco", "slug": "cisco"}, - {"name": "Juniper", "slug": "juniper"}, - ] - parsed_types = [ - {"manufacturer": {"slug": "juniper"}, "model": "EX4300"}, - ] - - vendors, selected_slugs = nb_dt_import.filter_vendors_for_parsed_types(discovered_vendors, parsed_types) - - assert vendors == [{"name": "Juniper", "slug": "juniper"}] - assert selected_slugs == {"juniper"} - - def test_log_run_mode_reports_default_non_update_behavior(nb_dt_import): handle = MagicMock() args = SimpleNamespace(only_new=False, update=False, remove_components=False) @@ -288,6 +272,7 @@ def test_items_per_second_column_uses_elapsed_fallback_when_finished_speed_missi # --------------------------------------------------------------------------- _NB_DT_IMPORT_PATH = str(Path(__file__).resolve().parents[1] / "nb-dt-import.py") +_PROJECT_ROOT = Path(__file__).resolve().parents[1] def _make_mock_repo(device_types=None): @@ -297,7 +282,9 @@ def _make_mock_repo(device_types=None): mock_repo.get_devices_path.return_value = "/tmp/devices" mock_repo.get_modules_path.return_value = "/tmp/modules" mock_repo.get_racks_path.return_value = "/tmp/rack-types" + mock_repo.discover_vendors.return_value = [] mock_repo.parse_files.return_value = device_types if device_types is not None else [] + mock_repo.resolve_slug_files.return_value = None # no pickle available by default return mock_repo @@ -562,7 +549,8 @@ def test_only_new_creates_new_device_types(self, nb_dt_import): patch("nb_dt_import.NetBox") as MockNetBox, ): mock_repo = _make_mock_repo(device_types=dt) - mock_repo.get_devices.return_value = (["file.yaml"], [{"slug": "cisco"}]) + mock_repo.discover_vendors.return_value = [{"name": "Cisco", "slug": "cisco"}] + mock_repo.get_devices.return_value = (["file.yaml"], []) MockRepo.return_value = mock_repo MockNetBox.return_value = _make_mock_netbox() @@ -571,7 +559,7 @@ def test_only_new_creates_new_device_types(self, nb_dt_import): MockNetBox.return_value.create_device_types.assert_called_once() def test_default_mode_no_device_types(self, nb_dt_import): - """Default mode with empty file list logs 'No device types matched'.""" + """Default mode with no discovered vendors completes without error.""" with ( patch.object(sys, "argv", ["nb-dt-import.py"]), patch("nb_dt_import.DTLRepo") as MockRepo, @@ -584,9 +572,6 @@ def test_default_mode_no_device_types(self, nb_dt_import): nb_dt_import.main() - # cache_preload_job is truthy here; stop_component_preload should be called - MockNetBox.return_value.device_types.stop_component_preload.assert_called() - def test_default_mode_with_new_device_types(self, nb_dt_import): """Default mode creates new device types when change report lists them.""" dt = [{"manufacturer": {"slug": "cisco"}, "model": "A", "slug": "a"}] @@ -600,7 +585,8 @@ def test_default_mode_with_new_device_types(self, nb_dt_import): patch("nb_dt_import.ChangeDetector") as MockDetector, ): mock_repo = _make_mock_repo(device_types=dt) - mock_repo.get_devices.return_value = (["file.yaml"], [{"slug": "cisco"}]) + mock_repo.discover_vendors.return_value = [{"name": "Cisco", "slug": "cisco"}] + mock_repo.get_devices.return_value = (["file.yaml"], []) MockRepo.return_value = mock_repo MockNetBox.return_value = _make_mock_netbox() MockDetector.return_value.detect_changes.return_value = report @@ -636,7 +622,8 @@ def test_update_mode_with_changes(self, nb_dt_import): patch("nb_dt_import.ChangeDetector") as MockDetector, ): mock_repo = _make_mock_repo(device_types=dt) - mock_repo.get_devices.return_value = (["file.yaml"], [{"slug": "cisco"}]) + mock_repo.discover_vendors.return_value = [{"name": "Cisco", "slug": "cisco"}] + mock_repo.get_devices.return_value = (["file.yaml"], []) MockRepo.return_value = mock_repo MockNetBox.return_value = _make_mock_netbox() MockDetector.return_value.detect_changes.return_value = report @@ -659,7 +646,8 @@ def test_update_with_remove_components(self, nb_dt_import): patch("nb_dt_import.ChangeDetector") as MockDetector, ): mock_repo = _make_mock_repo(device_types=dt) - mock_repo.get_devices.return_value = (["file.yaml"], [{"slug": "cisco"}]) + mock_repo.discover_vendors.return_value = [{"name": "Cisco", "slug": "cisco"}] + mock_repo.get_devices.return_value = (["file.yaml"], []) MockRepo.return_value = mock_repo MockNetBox.return_value = _make_mock_netbox() MockDetector.return_value.detect_changes.return_value = report @@ -676,6 +664,44 @@ def test_remove_components_without_update_exits_with_error(self, nb_dt_import): nb_dt_import.main() assert exc_info.value.code == 2 + def test_remove_unmanaged_types_without_remove_components_exits_with_error(self, nb_dt_import): + """--remove-unmanaged-types without --remove-components triggers parser.error (SystemExit 2).""" + with patch.object(sys, "argv", ["nb-dt-import.py", "--update", "--remove-unmanaged-types"]): + with pytest.raises(SystemExit) as exc_info: + nb_dt_import.main() + assert exc_info.value.code == 2 + + def test_update_with_remove_unmanaged_types_sets_attribute_and_detector_kwarg(self, nb_dt_import): + """--update --remove-components --remove-unmanaged-types propagates to NetBox and ChangeDetector.""" + dt = [{"manufacturer": {"slug": "cisco"}, "model": "A", "slug": "a"}] + change_entry = SimpleNamespace(manufacturer_slug="cisco", model="A", slug="a") + report = SimpleNamespace(new_device_types=[change_entry], modified_device_types=[]) + + with ( + patch.object( + sys, + "argv", + ["nb-dt-import.py", "--update", "--remove-components", "--remove-unmanaged-types"], + ), + patch("nb_dt_import.DTLRepo") as MockRepo, + patch("nb_dt_import.NetBox") as MockNetBox, + patch("nb_dt_import.ChangeDetector") as MockDetector, + ): + mock_repo = _make_mock_repo(device_types=dt) + mock_repo.discover_vendors.return_value = [{"name": "Cisco", "slug": "cisco"}] + mock_repo.get_devices.return_value = (["file.yaml"], []) + MockRepo.return_value = mock_repo + mock_nb = _make_mock_netbox() + MockNetBox.return_value = mock_nb + MockDetector.return_value.detect_changes.return_value = report + + nb_dt_import.main() + + assert mock_nb.remove_unmanaged_types is True + # ChangeDetector instantiated with remove_unmanaged_types=True + _, detector_kwargs = MockDetector.call_args + assert detector_kwargs.get("remove_unmanaged_types") is True + def test_force_resolve_conflicts_without_update_exits_with_error(self, nb_dt_import): """--force-resolve-conflicts without --update triggers parser.error (SystemExit 2).""" with patch.object(sys, "argv", ["nb-dt-import.py", "--force-resolve-conflicts"]): @@ -696,7 +722,8 @@ def test_update_with_force_resolve_conflicts(self, nb_dt_import): patch("nb_dt_import.ChangeDetector") as MockDetector, ): mock_repo = _make_mock_repo(device_types=dt) - mock_repo.get_devices.return_value = (["file.yaml"], [{"slug": "cisco"}]) + mock_repo.discover_vendors.return_value = [{"name": "Cisco", "slug": "cisco"}] + mock_repo.get_devices.return_value = (["file.yaml"], []) MockRepo.return_value = mock_repo mock_nb = _make_mock_netbox() MockNetBox.return_value = mock_nb @@ -706,6 +733,30 @@ def test_update_with_force_resolve_conflicts(self, nb_dt_import): assert mock_nb.force_resolve_conflicts is True + def test_verify_images_sets_attribute(self, nb_dt_import): + """--verify-images propagates to netbox.verify_images = True.""" + dt = [{"manufacturer": {"slug": "cisco"}, "model": "A", "slug": "a"}] + change_entry = SimpleNamespace(manufacturer_slug="cisco", model="A", slug="a") + report = SimpleNamespace(new_device_types=[change_entry], modified_device_types=[]) + + with ( + patch.object(sys, "argv", ["nb-dt-import.py", "--update", "--verify-images"]), + patch("nb_dt_import.DTLRepo") as MockRepo, + patch("nb_dt_import.NetBox") as MockNetBox, + patch("nb_dt_import.ChangeDetector") as MockDetector, + ): + mock_repo = _make_mock_repo(device_types=dt) + mock_repo.discover_vendors.return_value = [{"name": "Cisco", "slug": "cisco"}] + mock_repo.get_devices.return_value = (["file.yaml"], []) + MockRepo.return_value = mock_repo + mock_nb = _make_mock_netbox() + MockNetBox.return_value = mock_nb + MockDetector.return_value.detect_changes.return_value = report + + nb_dt_import.main() + + assert mock_nb.verify_images is True + def test_missing_env_var_triggers_system_exit(self, nb_dt_import): """A missing mandatory env var calls handle.exception which exits.""" with ( @@ -729,27 +780,33 @@ def test_vendors_and_slugs_flags_log_lines(self, nb_dt_import): patch("nb_dt_import.NetBox") as MockNetBox, patch("nb_dt_import.ChangeDetector") as MockDetector, ): - MockRepo.return_value = _make_mock_repo() + mock_repo = _make_mock_repo() + mock_repo.discover_vendors.return_value = [{"name": "Cisco", "slug": "cisco"}] + MockRepo.return_value = mock_repo MockNetBox.return_value = _make_mock_netbox() MockDetector.return_value.detect_changes.return_value = _empty_change_report() nb_dt_import.main() # should not raise - def test_modules_future_no_module_files_with_slugs(self, nb_dt_import): - """modules=True + --slugs triggers vendor filter path; empty bg result sets module_types=[].""" + def test_unknown_vendors_exits_nonzero(self, nb_dt_import): + """--vendors with no matching slug exits with code 1 instead of silently doing nothing.""" with ( - patch.object(sys, "argv", ["nb-dt-import.py", "--only-new", "--slugs", "my-slug"]), + patch.object( + sys, + "argv", + ["nb-dt-import.py", "--vendors", "nonexistent-vendor"], + ), patch("nb_dt_import.DTLRepo") as MockRepo, - patch("nb_dt_import.NetBox") as MockNetBox, + patch("nb_dt_import.NetBox"), ): - mock_nb = _make_mock_netbox(modules=True) - MockNetBox.return_value = mock_nb - MockNetBox.filter_new_module_types.return_value = [] - MockRepo.return_value = _make_mock_repo() - - nb_dt_import.main() + mock_repo = _make_mock_repo() + mock_repo.discover_vendors.return_value = [{"name": "Nokia", "slug": "nokia"}] + MockRepo.return_value = mock_repo + with pytest.raises(SystemExit) as exc_info: + nb_dt_import.main() + assert exc_info.value.code == 1 - def test_modules_future_with_module_types_to_process(self, nb_dt_import): + def test_modules_with_types_to_process(self, nb_dt_import): """modules=True + non-empty filter_actionable_module_types calls create_module_types.""" module_type = {"manufacturer": {"slug": "cisco"}, "model": "CM1", "slug": "cm1"} @@ -765,13 +822,10 @@ def test_modules_future_with_module_types_to_process(self, nb_dt_import): MockNetBox.filter_new_module_types.return_value = [] mock_repo = _make_mock_repo() - - def _get_devices_se(path, vendors): - if path == "/tmp/modules": - return (["/module.yaml"], []) - return ([], []) - - mock_repo.get_devices.side_effect = _get_devices_se + mock_repo.discover_vendors.return_value = [{"name": "Cisco", "slug": "cisco"}] + mock_repo.get_devices.return_value = ([], []) + # parse_files returns the module type for any call + mock_repo.parse_files.return_value = [module_type] MockRepo.return_value = mock_repo MockDetector.return_value.detect_changes.return_value = _empty_change_report() @@ -796,59 +850,6 @@ def test_modules_update_mode_logs_change_detection_section(self, nb_dt_import): nb_dt_import.main() # should not raise - def test_modules_no_future_else_branch_no_module_files(self, nb_dt_import): - """Else branch (no future): modules path called inline; empty result → module_types=[].""" - # Make submit() return None so _module_parse_future stays None, - # forcing the else branch in the second `if netbox.modules:` block. - mock_executor = MagicMock() - mock_executor.submit.return_value = None - - with ( - patch.object(sys, "argv", ["nb-dt-import.py", "--only-new"]), - patch("nb_dt_import.DTLRepo") as MockRepo, - patch("nb_dt_import.NetBox") as MockNetBox, - patch( - "nb_dt_import.concurrent.futures.ThreadPoolExecutor", - return_value=mock_executor, - ), - ): - mock_nb = _make_mock_netbox(modules=True) - MockNetBox.return_value = mock_nb - MockNetBox.filter_new_module_types.return_value = [] - MockRepo.return_value = _make_mock_repo() - - nb_dt_import.main() - - def test_modules_no_future_else_branch_with_module_files_and_slugs(self, nb_dt_import): - """Else branch: module files present → parse_files called; --slugs triggers vendor filter.""" - mock_executor = MagicMock() - mock_executor.submit.return_value = None - - with ( - patch.object(sys, "argv", ["nb-dt-import.py", "--only-new", "--slugs", "my-slug"]), - patch("nb_dt_import.DTLRepo") as MockRepo, - patch("nb_dt_import.NetBox") as MockNetBox, - patch( - "nb_dt_import.concurrent.futures.ThreadPoolExecutor", - return_value=mock_executor, - ), - ): - mock_nb = _make_mock_netbox(modules=True) - MockNetBox.return_value = mock_nb - MockNetBox.filter_new_module_types.return_value = [] - - mock_repo = _make_mock_repo() - - def _get_devices_se(path, vendors): - if path == "/tmp/modules": - return (["/module.yaml"], []) - return ([], []) - - mock_repo.get_devices.side_effect = _get_devices_se - MockRepo.return_value = mock_repo - - nb_dt_import.main() - def test_settings_netbox_features_modules_logs_module_count(self, nb_dt_import): """When netbox.modules is True, module_added/updated counters are logged.""" with ( @@ -895,34 +896,6 @@ def test_progress_panel_tty_sets_console_and_pumps_preload(self, nb_dt_import): # handle.set_console(progress.console) must have been called mock_handle.set_console.assert_any_call(mock_prog.console) - def test_future_cancel_and_executor_shutdown_in_finally(self, nb_dt_import): - """An exception during future.result() triggers cancel() and shutdown() in finally.""" - mock_future = MagicMock() - mock_future.done.return_value = False - mock_future.result.side_effect = RuntimeError("bg thread crash") - - mock_executor = MagicMock() - mock_executor.submit.return_value = mock_future - - with ( - patch.object(sys, "argv", ["nb-dt-import.py", "--only-new"]), - patch("nb_dt_import.DTLRepo") as MockRepo, - patch("nb_dt_import.NetBox") as MockNetBox, - patch( - "nb_dt_import.concurrent.futures.ThreadPoolExecutor", - return_value=mock_executor, - ), - ): - mock_nb = _make_mock_netbox(modules=True) - MockNetBox.return_value = mock_nb - MockRepo.return_value = _make_mock_repo() - - with pytest.raises(RuntimeError): - nb_dt_import.main() - - mock_future.cancel.assert_called_once() - mock_executor.shutdown.assert_called() - # --------------------------------------------------------------------------- # _process_rack_types @@ -932,110 +905,69 @@ def test_future_cancel_and_executor_shutdown_in_finally(self, nb_dt_import): class TestProcessRackTypes: """Tests for the _process_rack_types() helper function.""" - def _make_args(self, vendors=None, slugs=None, only_new=False): - return SimpleNamespace(vendors=vendors, slugs=slugs, only_new=only_new) + def _make_args(self, only_new=False): + return SimpleNamespace(only_new=only_new) def test_rack_types_disabled_logs_warning_and_returns(self, nb_dt_import): - """netbox.rack_types=False: warning logged, no further processing.""" + """netbox.rack_types=False with actual rack types: warning logged, no further processing.""" handle = MagicMock() netbox = MagicMock() netbox.rack_types = False - dtl_repo = MagicMock() - nb_dt_import._process_rack_types(self._make_args(), netbox, dtl_repo, handle, None, set()) + rack_type = {"manufacturer": {"slug": "apc"}, "model": "AR1300", "slug": "apc-ar1300"} + nb_dt_import._process_rack_types(self._make_args(), netbox, handle, None, [rack_type]) handle.log.assert_called_once() assert "4.1" in handle.log.call_args[0][0] - dtl_repo.get_racks_path.assert_not_called() + netbox.get_existing_rack_types.assert_not_called() - def test_rack_types_dir_not_exist_verbose_log_and_returns(self, nb_dt_import, tmp_path): - """rack_types=True but racks_path is not a directory: verbose_log + return.""" + def test_empty_rack_types_returns_early(self, nb_dt_import): + """rack_types=[]: returns immediately without any logging or API calls.""" handle = MagicMock() netbox = MagicMock() netbox.rack_types = True - dtl_repo = MagicMock() - dtl_repo.get_racks_path.return_value = str(tmp_path / "nonexistent") - nb_dt_import._process_rack_types(self._make_args(), netbox, dtl_repo, handle, None, set()) + nb_dt_import._process_rack_types(self._make_args(), netbox, handle, None, []) - handle.verbose_log.assert_called() - assert "No rack-types directory" in handle.verbose_log.call_args[0][0] - dtl_repo.get_devices.assert_not_called() + handle.log.assert_not_called() + handle.verbose_log.assert_not_called() + netbox.get_existing_rack_types.assert_not_called() + netbox.create_rack_types.assert_not_called() - def test_no_rack_files_verbose_log_and_returns(self, nb_dt_import, tmp_path): - """racks_path exists but no files discovered: verbose_log + return.""" - handle = MagicMock() - netbox = MagicMock() - netbox.rack_types = True - dtl_repo = MagicMock() - racks_dir = tmp_path / "rack-types" - racks_dir.mkdir() - dtl_repo.get_racks_path.return_value = str(racks_dir) - dtl_repo.get_devices.return_value = ([], []) - - nb_dt_import._process_rack_types(self._make_args(), netbox, dtl_repo, handle, None, set()) - - handle.verbose_log.assert_called() - assert "No rack-type files" in handle.verbose_log.call_args[0][0] - dtl_repo.parse_files.assert_not_called() - - def test_full_flow_calls_create_rack_types(self, nb_dt_import, tmp_path): - """Full flow: files found, parse_files called, create_rack_types called.""" + def test_full_flow_calls_create_rack_types(self, nb_dt_import): + """Full flow: pre-parsed rack_types provided, create_rack_types called.""" handle = MagicMock() netbox = MagicMock() netbox.rack_types = True netbox.get_existing_rack_types.return_value = {} - racks_dir = tmp_path / "rack-types" - racks_dir.mkdir() - dtl_repo = MagicMock() - dtl_repo.get_racks_path.return_value = str(racks_dir) - dtl_repo.get_devices.return_value = ( - [str(racks_dir / "apc-ar1300.yaml")], - [{"name": "APC", "slug": "apc"}], - ) rack_type = { "manufacturer": {"slug": "apc"}, "model": "AR1300", "slug": "apc-ar1300", } - dtl_repo.parse_files.return_value = [rack_type] - nb_dt_import._process_rack_types(self._make_args(), netbox, dtl_repo, handle, None, set()) + nb_dt_import._process_rack_types(self._make_args(), netbox, handle, None, [rack_type]) - dtl_repo.parse_files.assert_called_once() netbox.create_rack_types.assert_called_once() - def test_vendor_filter_from_selected_vendor_slugs_when_slugs_set(self, nb_dt_import, tmp_path): - """When args.slugs set and no args.vendors, rack_vendor_filter uses selected_vendor_slugs.""" + def test_existing_rack_type_shows_as_existing(self, nb_dt_import): + """A rack type already in NetBox is counted as existing, not new.""" handle = MagicMock() netbox = MagicMock() netbox.rack_types = True - netbox.get_existing_rack_types.return_value = {} - - racks_dir = tmp_path / "rack-types" - racks_dir.mkdir() - dtl_repo = MagicMock() - dtl_repo.get_racks_path.return_value = str(racks_dir) - - captured_vendor_filter = {} + netbox.get_existing_rack_types.return_value = {"apc": {"AR1300": object()}} - def _get_devices_se(path, vendors): - captured_vendor_filter["vendors"] = vendors - return ([], []) - - dtl_repo.get_devices.side_effect = _get_devices_se + rack_type = { + "manufacturer": {"slug": "apc"}, + "model": "AR1300", + "slug": "apc-ar1300", + } - nb_dt_import._process_rack_types( - self._make_args(vendors=None, slugs=["apc-ar1300"]), - netbox, - dtl_repo, - handle, - None, - {"apc"}, - ) + nb_dt_import._process_rack_types(self._make_args(), netbox, handle, None, [rack_type]) - assert captured_vendor_filter["vendors"] == ["apc"] + log_calls = [call.args[0] for call in handle.verbose_log.call_args_list] + assert any("No new rack types (1 unchanged)" in msg for msg in log_calls) # --------------------------------------------------------------------------- @@ -1069,6 +1001,192 @@ def test_entry_point_keyboard_interrupt_exits_130(self): assert exc_info.value.code == 130 + def test_entry_point_connection_error_exits_1(self): + """requests.ConnectionError mid-run becomes SystemExit(1) with informative message.""" + import requests as _requests + + with patch("core.repo.DTLRepo") as MockDTLRepo, patch("core.netbox_api.NetBox"): + MockDTLRepo.side_effect = _requests.exceptions.ConnectionError("Remote end closed connection") + + with patch.object(sys, "argv", ["nb-dt-import.py", "--only-new"]): + with pytest.raises(SystemExit) as exc_info: + runpy.run_path(_NB_DT_IMPORT_PATH, run_name="__main__") + + assert exc_info.value.code == 1 + + def test_entry_point_connection_error_message_references_netbox(self, capsys): + """ConnectionError prints a human-friendly message (not a raw traceback).""" + import requests as _requests + + with patch("core.repo.DTLRepo") as MockDTLRepo, patch("core.netbox_api.NetBox"): + MockDTLRepo.side_effect = _requests.exceptions.ConnectionError("Remote end closed") + + with patch.object(sys, "argv", ["nb-dt-import.py", "--only-new"]): + with pytest.raises(SystemExit): + runpy.run_path(_NB_DT_IMPORT_PATH, run_name="__main__") + + captured = capsys.readouterr() + assert "connection" in captured.err.lower() or "netbox" in captured.err.lower() + assert "Traceback" not in captured.err + + +# --------------------------------------------------------------------------- +# Per-vendor loop behaviour +# --------------------------------------------------------------------------- + + +class TestPerVendorLoop: + """Tests for the per-vendor iteration logic in main().""" + + def _run_main(self, argv, mock_repo, mock_nb, nb_dt_import_module, mock_detector=None): + with ( + patch.object(sys, "argv", argv), + patch("nb_dt_import.DTLRepo") as MockRepo, + patch("nb_dt_import.NetBox") as MockNetBox, + patch("nb_dt_import.ChangeDetector") as MockDetector, + ): + MockRepo.return_value = mock_repo + MockNetBox.return_value = mock_nb + if mock_detector is not None: + MockDetector.return_value.detect_changes.return_value = mock_detector + else: + MockDetector.return_value.detect_changes.return_value = _empty_change_report() + nb_dt_import_module.main() + + def test_vendor_flag_only_processes_matching_vendor(self, nb_dt_import): + """--vendors cisco: load_vendor called only for cisco, not for juniper.""" + cisco_dt = {"manufacturer": {"slug": "cisco"}, "model": "C1", "slug": "cisco-c1"} + juniper_dt = {"manufacturer": {"slug": "juniper"}, "model": "J1", "slug": "juniper-j1"} + + mock_nb = _make_mock_netbox() + mock_repo = _make_mock_repo() + mock_repo.discover_vendors.return_value = [ + {"name": "Cisco", "slug": "cisco"}, + {"name": "Juniper", "slug": "juniper"}, + ] + + def _get_devices_se(path, vendors): + if vendors == ["cisco"]: + return (["cisco.yaml"], []) + if vendors == ["juniper"]: + return (["juniper.yaml"], []) + return ([], []) + + def _parse_files_se(files, slugs=None, progress=None): + if "cisco.yaml" in files: + return [cisco_dt] + if "juniper.yaml" in files: + return [juniper_dt] + return [] + + mock_repo.get_devices.side_effect = _get_devices_se + mock_repo.parse_files.side_effect = _parse_files_se + + self._run_main(["nb-dt-import.py", "--vendors", "cisco"], mock_repo, mock_nb, nb_dt_import) + + slugs_loaded = [call.args[0] for call in mock_nb.load_vendor.call_args_list] + assert "cisco" in slugs_loaded + assert "juniper" not in slugs_loaded + + def test_no_vendor_flag_processes_all_vendors(self, nb_dt_import): + """Without --vendors: load_vendor called for every discovered vendor.""" + cisco_dt = {"manufacturer": {"slug": "cisco"}, "model": "C1", "slug": "cisco-c1"} + arista_dt = {"manufacturer": {"slug": "arista"}, "model": "A1", "slug": "arista-a1"} + + mock_nb = _make_mock_netbox() + mock_repo = _make_mock_repo() + mock_repo.discover_vendors.return_value = [ + {"name": "Cisco", "slug": "cisco"}, + {"name": "Arista", "slug": "arista"}, + ] + + def _get_devices_se(path, vendors): + if vendors == ["cisco"]: + return (["cisco.yaml"], []) + if vendors == ["arista"]: + return (["arista.yaml"], []) + return ([], []) + + def _parse_files_se(files, slugs=None, progress=None): + if "cisco.yaml" in files: + return [cisco_dt] + if "arista.yaml" in files: + return [arista_dt] + return [] + + mock_repo.get_devices.side_effect = _get_devices_se + mock_repo.parse_files.side_effect = _parse_files_se + + self._run_main(["nb-dt-import.py"], mock_repo, mock_nb, nb_dt_import) + + slugs_loaded = [call.args[0] for call in mock_nb.load_vendor.call_args_list] + assert "cisco" in slugs_loaded + assert "arista" in slugs_loaded + + def test_slug_filter_skips_vendor_with_no_matching_types(self, nb_dt_import): + """--slug other-slug: vendor whose parsed files don't match the slug is skipped.""" + mock_nb = _make_mock_netbox() + mock_repo = _make_mock_repo() + mock_repo.discover_vendors.return_value = [{"name": "APC", "slug": "apc"}] + mock_repo.get_devices.return_value = (["file.yaml"], []) + # parse_files returns empty regardless (slug filter stripped all matches) + mock_repo.parse_files.return_value = [] + + self._run_main(["nb-dt-import.py", "--slugs", "other-slug"], mock_repo, mock_nb, nb_dt_import) + + mock_nb.load_vendor.assert_not_called() + + def test_vendor_with_matching_slug_is_processed(self, nb_dt_import): + """Vendor whose slug matches parsed files does call load_vendor.""" + mock_nb = _make_mock_netbox() + mock_repo = _make_mock_repo() + mock_repo.discover_vendors.return_value = [{"name": "Cisco", "slug": "cisco"}] + dt = {"manufacturer": {"slug": "cisco"}, "model": "A", "slug": "cisco-a"} + mock_repo.get_devices.return_value = (["file.yaml"], []) + mock_repo.parse_files.return_value = [dt] + + self._run_main(["nb-dt-import.py"], mock_repo, mock_nb, nb_dt_import) + + mock_nb.load_vendor.assert_called_with("cisco") + + def test_module_type_only_vendor_uses_scoped_preload(self, nb_dt_import): + """Vendor with only module types (no device types) must use scoped preload. + + Regression: the preload guard used to check ``parsed_device_types`` only, + causing module-type-only vendors to fall back to the unscoped global preload. + """ + mt = {"manufacturer": {"slug": "acbel"}, "model": "M1", "slug": "acbel-m1"} + + mock_nb = _make_mock_netbox(modules=True) + mock_repo = _make_mock_repo() + mock_repo.discover_vendors.return_value = [{"name": "Acbel", "slug": "acbel"}] + + def _get_devices_se(path, vendors): + # Return module file only when querying the modules path + if "module" in path: + return (["module.yaml"], []) + return ([], []) + + def _parse_files_se(files, slugs=None, progress=None): + if "module.yaml" in files: + return [mt] + return [] + + mock_repo.get_devices.side_effect = _get_devices_se + mock_repo.parse_files.side_effect = _parse_files_se + + self._run_main(["nb-dt-import.py"], mock_repo, mock_nb, nb_dt_import) + + mock_nb.device_types.start_component_preload.assert_called_once() + call_kwargs = mock_nb.device_types.start_component_preload.call_args + assert call_kwargs.kwargs.get("manufacturer_slug") == "acbel" + + # The unscoped preload_all_components (no manufacturer_slug) must NOT be called + for call in mock_nb.device_types.preload_all_components.call_args_list: + assert call.kwargs.get("manufacturer_slug") is not None, ( + "preload_all_components called without manufacturer_slug (global fetch triggered)" + ) + # --------------------------------------------------------------------------- # _process_module_types hints and counters (lines 554-559, 572, 574-575) @@ -1115,10 +1233,9 @@ def test_pending_removal_counters_and_hints(self, nb_dt_import): nb_dt_import._process_module_types( self._make_args(only_new=False, update=False, remove_components=False), mock_nb, - mock_repo, handle, None, - set(), + [module_to_process], ) logged = [call.args[0] for call in handle.log.call_args_list] @@ -1257,3 +1374,491 @@ def test_netbox_request_error_prints_message_and_exits_1(self, capsys): assert exc_info.value.code == 1 assert "NetBox REST API request failed" in capsys.readouterr().err + + +class TestExportDiffFlags: + """Test --export-diff CLI flag parsing and mutual exclusion.""" + + def test_export_diff_flag_in_help(self): + import subprocess + + result = subprocess.run( + ["uv", "run", "--native-tls", "nb-dt-import.py", "--help"], + capture_output=True, + text=True, + cwd=_PROJECT_ROOT, + ) + assert "--export-diff" in result.stdout + assert "--export-diff-dir" in result.stdout + assert "--force-export-overwrite" in result.stdout + + def test_export_diff_mutually_exclusive_with_update(self): + import subprocess + + result = subprocess.run( + ["uv", "run", "--native-tls", "nb-dt-import.py", "--export-diff", "--update"], + capture_output=True, + text=True, + cwd=_PROJECT_ROOT, + ) + assert result.returncode == 2 + assert "--export-diff" in result.stderr + + def test_export_diff_mutually_exclusive_with_only_new(self): + import subprocess + + result = subprocess.run( + ["uv", "run", "--native-tls", "nb-dt-import.py", "--export-diff", "--only-new"], + capture_output=True, + text=True, + cwd=_PROJECT_ROOT, + ) + assert result.returncode == 2 + assert "--export-diff" in result.stderr + + def test_export_diff_mutually_exclusive_with_remove_components(self): + import subprocess + + result = subprocess.run( + ["uv", "run", "--native-tls", "nb-dt-import.py", "--export-diff", "--remove-components"], + capture_output=True, + text=True, + cwd=_PROJECT_ROOT, + ) + assert result.returncode == 2 + assert "--export-diff" in result.stderr + + +class TestDirectHelpers: + """Tests for direct helper functions and custom progress columns.""" + + def test_no_pulse_bar_column_uses_static_empty_bar_for_unknown_total(self, nb_dt_import): + column = nb_dt_import.NoPulseBarColumn() + bar = column.render(SimpleNamespace(total=None, completed=5, get_time=lambda: 0)) + + assert bar.total == 1.0 + assert bar.completed == 0.0 + + def test_parse_vendor_racks_calls_repo_when_directory_exists(self, nb_dt_import): + repo = MagicMock() + repo.get_devices.return_value = (["rack.yaml"], []) + repo.parse_files.return_value = [{"model": "Rack"}] + + with patch("nb_dt_import.os.path.isdir", return_value=True): + result = nb_dt_import._parse_vendor_racks(repo, "/racks", "nokia", ["rack"]) + + assert result == [{"model": "Rack"}] + repo.get_devices.assert_called_once_with("/racks", ["nokia"]) + repo.parse_files.assert_called_once_with(["rack.yaml"], slugs=["rack"]) + + def test_finalize_task_registry_updates_unknown_totals_and_skips_missing_tasks(self, nb_dt_import): + progress = MagicMock() + progress.tasks = [SimpleNamespace(id=1, total=None, completed=3)] + + nb_dt_import._finalize_task_registry(progress, {"seen": 1, "missing": 2}) + + progress.update.assert_called_once_with(1, total=3) + progress.stop_task.assert_called_once_with(1) + + def test_validate_argument_combinations_blocks_force_resolve_without_update(self, nb_dt_import): + parser = MagicMock() + args = SimpleNamespace( + export_diff=False, + update=False, + only_new=False, + remove_components=False, + remove_unmanaged_types=False, + force_resolve_conflicts=True, + ) + + nb_dt_import._validate_argument_combinations(parser, args) + + parser.error.assert_called_once_with("--force-resolve-conflicts requires --update") + + def test_validate_argument_combinations_blocks_remove_unmanaged_without_remove_components(self, nb_dt_import): + parser = MagicMock() + args = SimpleNamespace( + export_diff=False, + update=True, + only_new=False, + remove_components=False, + remove_unmanaged_types=True, + force_resolve_conflicts=False, + ) + + nb_dt_import._validate_argument_combinations(parser, args) + + parser.error.assert_called_once_with("--remove-unmanaged-types requires --remove-components") + + def test_validate_argument_combinations_blocks_remove_unmanaged_with_export_diff(self, nb_dt_import): + parser = MagicMock() + parser.error.side_effect = SystemExit(2) + args = SimpleNamespace( + export_diff=True, + update=False, + only_new=False, + remove_components=False, + remove_unmanaged_types=True, + force_resolve_conflicts=False, + ) + + with pytest.raises(SystemExit): + nb_dt_import._validate_argument_combinations(parser, args) + + parser.error.assert_called_once_with( + "--remove-unmanaged-types is an import-only flag and cannot be used with --export-diff" + ) + + def test_validate_argument_combinations_blocks_slugs_with_export_diff(self, nb_dt_import): + parser = MagicMock() + parser.error.side_effect = SystemExit(2) + args = SimpleNamespace( + export_diff=True, + update=False, + only_new=False, + remove_components=False, + remove_unmanaged_types=False, + force_resolve_conflicts=False, + slugs=["nokia-7750"], + verify_images=False, + ) + with pytest.raises(SystemExit): + nb_dt_import._validate_argument_combinations(parser, args) + parser.error.assert_called_once_with("--slugs is an import-only flag and cannot be used with --export-diff") + + def test_validate_argument_combinations_blocks_verify_images_with_export_diff(self, nb_dt_import): + parser = MagicMock() + parser.error.side_effect = SystemExit(2) + args = SimpleNamespace( + export_diff=True, + update=False, + only_new=False, + remove_components=False, + remove_unmanaged_types=False, + force_resolve_conflicts=False, + slugs=[], + verify_images=True, + ) + with pytest.raises(SystemExit): + nb_dt_import._validate_argument_combinations(parser, args) + parser.error.assert_called_once_with( + "--verify-images is an import-only flag and cannot be used with --export-diff" + ) + + def test_validate_argument_combinations_blocks_force_resolve_with_export_diff(self, nb_dt_import): + parser = MagicMock() + parser.error.side_effect = SystemExit(2) + args = SimpleNamespace( + export_diff=True, + update=False, + only_new=False, + remove_components=False, + remove_unmanaged_types=False, + force_resolve_conflicts=True, + slugs=[], + verify_images=False, + ) + with pytest.raises(SystemExit): + nb_dt_import._validate_argument_combinations(parser, args) + parser.error.assert_called_once_with( + "--force-resolve-conflicts is an import-only flag and cannot be used with --export-diff" + ) + + """_run_export_diff wires up Exporter with progress panel and console.""" + handle = MagicMock() + progress = MagicMock() + progress.console = object() + args = SimpleNamespace( + export_diff_dir="extra", + force_export_overwrite=True, + vendors=["nokia"], + show_remaining_time=True, + ) + + class _Ctx: + def __enter__(self): + return progress + + def __exit__(self, exc_type, exc, tb): + return False + + with ( + patch("nb_dt_import.get_progress_panel", return_value=_Ctx()), + patch("core.export.Exporter") as MockExporter, + ): + nb_dt_import._run_export_diff(nb_dt_import.settings, handle, args) + + MockExporter.assert_called_once() + handle.set_console.assert_called_once_with(progress.console) + MockExporter.return_value.run.assert_called_once_with(progress=progress) + + +class TestMainAdditionalCoverage: + """Additional main() coverage for export-diff and preload-job teardown.""" + + def test_main_returns_early_for_export_diff(self, nb_dt_import): + with ( + patch.object(sys, "argv", ["nb-dt-import.py", "--export-diff"]), + patch("nb_dt_import._run_export_diff") as mock_run_export, + patch("nb_dt_import.DTLRepo") as MockRepo, + patch("nb_dt_import.NetBox") as MockNetBox, + ): + nb_dt_import.main() + + mock_run_export.assert_called_once() + MockRepo.assert_not_called() + MockNetBox.assert_not_called() + + def test_main_uses_slug_fast_path_device_files(self, nb_dt_import): + mock_repo = _make_mock_repo() + mock_repo.discover_vendors.return_value = [{"name": "Cisco", "slug": "cisco"}] + mock_repo.resolve_slug_files.return_value = { + "device_files": {"cisco": ["resolved.yaml"]}, + "module_vendors": set(), + "rack_vendors": set(), + } + mock_repo.parse_files.side_effect = lambda files, slugs=None, progress=None: ( + [{"manufacturer": {"slug": "cisco"}, "model": "X", "slug": "x"}] if files == ["resolved.yaml"] else [] + ) + mock_nb = _make_mock_netbox() + + with ( + patch.object(sys, "argv", ["nb-dt-import.py", "--slugs", "x"]), + patch("nb_dt_import.DTLRepo", return_value=mock_repo), + patch("nb_dt_import.NetBox", return_value=mock_nb), + patch("nb_dt_import.ChangeDetector") as MockDetector, + ): + MockDetector.return_value.detect_changes.return_value = _empty_change_report() + nb_dt_import.main() + + assert any(call.args[0] == ["resolved.yaml"] for call in mock_repo.parse_files.call_args_list) + + def test_main_pumps_and_stops_preload_job(self, nb_dt_import): + mock_repo = _make_mock_repo() + mock_repo.discover_vendors.return_value = [{"name": "Cisco", "slug": "cisco"}] + + def _parse_files(files, slugs=None, progress=None): + if files == ["device.yaml"]: + return [{"manufacturer": {"slug": "cisco"}, "model": "X", "slug": "x"}] + return [] + + mock_repo.get_devices.side_effect = lambda path, vendors=None: ( + (["device.yaml"], []) if path == "/tmp/devices" else ([], []) + ) + mock_repo.parse_files.side_effect = _parse_files + mock_nb = _make_mock_netbox() + mock_nb.device_types.start_component_preload.return_value = "job-1" + progress = MagicMock() + progress.console = object() + + class _Ctx: + def __enter__(self): + return progress + + def __exit__(self, exc_type, exc, tb): + return False + + with ( + patch.object(sys, "argv", ["nb-dt-import.py"]), + patch("nb_dt_import.DTLRepo", return_value=mock_repo), + patch("nb_dt_import.NetBox", return_value=mock_nb), + patch("nb_dt_import.ChangeDetector") as MockDetector, + patch("nb_dt_import._process_device_types", return_value="job-1"), + patch("sys.stdout") as mock_stdout, + patch("nb_dt_import.get_progress_panel", return_value=_Ctx()), + ): + mock_stdout.isatty.return_value = True + MockDetector.return_value.detect_changes.return_value = _empty_change_report() + nb_dt_import.main() + + mock_nb.device_types.pump_preload_progress.assert_called() + mock_nb.device_types.stop_component_preload.assert_called_with("job-1", progress=progress) + + def test_main_stops_preload_job_in_finally_on_error(self, nb_dt_import): + mock_repo = _make_mock_repo() + mock_repo.discover_vendors.return_value = [{"name": "Cisco", "slug": "cisco"}] + mock_repo.get_devices.side_effect = lambda path, vendors=None: ( + (["device.yaml"], []) if path == "/tmp/devices" else ([], []) + ) + mock_repo.parse_files.side_effect = lambda files, slugs=None, progress=None: ( + [{"manufacturer": {"slug": "cisco"}, "model": "X", "slug": "x"}] if files == ["device.yaml"] else [] + ) + mock_nb = _make_mock_netbox() + mock_nb.device_types.start_component_preload.return_value = "job-2" + progress = MagicMock() + progress.console = object() + + class _Ctx: + def __enter__(self): + return progress + + def __exit__(self, exc_type, exc, tb): + return False + + with ( + patch.object(sys, "argv", ["nb-dt-import.py"]), + patch("nb_dt_import.DTLRepo", return_value=mock_repo), + patch("nb_dt_import.NetBox", return_value=mock_nb), + patch("nb_dt_import._process_device_types", side_effect=RuntimeError("boom")), + patch("nb_dt_import.get_progress_panel", return_value=_Ctx()), + ): + with pytest.raises(RuntimeError, match="boom"): + nb_dt_import.main() + + mock_nb.device_types.stop_component_preload.assert_called_with("job-2", progress=progress) + + def test_build_argument_parser_sets_expected_defaults_and_flags(self, nb_dt_import): + parser = nb_dt_import._build_argument_parser() + + defaults = parser.parse_args([]) + parsed = parser.parse_args( + [ + "--vendors", + "cisco", + "juniper", + "--url", + "https://example.com/repo.git", + "--slugs", + "x", + "y", + "--branch", + "feature/test", + "--verbose", + "--show-remaining-time", + "--update", + "--remove-components", + "--remove-unmanaged-types", + "--force-resolve-conflicts", + "--verify-images", + "--export-diff", + "--export-diff-dir", + "exports/", + "--force-export-overwrite", + ] + ) + + assert parser.description == "Import Netbox Device Types" + assert parser.allow_abbrev is False + assert defaults.export_diff_dir == "extra/" + assert defaults.force_export_overwrite is False + assert parsed.vendors == ["cisco", "juniper"] + assert parsed.url == "https://example.com/repo.git" + assert parsed.slugs == ["x", "y"] + assert parsed.branch == "feature/test" + assert parsed.verbose is True + assert parsed.show_remaining_time is True + assert parsed.update is True + assert parsed.remove_components is True + assert parsed.remove_unmanaged_types is True + assert parsed.force_resolve_conflicts is True + assert parsed.verify_images is True + assert parsed.export_diff is True + assert parsed.export_diff_dir == "exports/" + assert parsed.force_export_overwrite is True + + def test_run_vendor_loop_processes_slug_fast_path_and_skips_empty_vendor(self, nb_dt_import): + args = SimpleNamespace(only_new=False, slugs=["x"]) + handle = MagicMock() + progress = MagicMock() + dtl_repo = _make_mock_repo() + dtl_repo.get_devices.side_effect = lambda path, vendors=None: ( + ([f"{vendors[0]}-{path.split('/')[-1]}.yaml"], []) + if vendors and vendors[0] == "cisco" and path in {"/tmp/modules", "/tmp/rack-types"} + else ([], []) + ) + dtl_repo.parse_files.side_effect = lambda files, slugs=None, progress=None: ( + [{"manufacturer": {"slug": "cisco"}, "model": "X", "slug": "x"}] + if files == ["resolved.yaml"] + else [{"manufacturer": {"slug": "cisco"}, "model": "M", "slug": "m"}] + if files == ["cisco-modules.yaml"] + else [] + ) + netbox = _make_mock_netbox(modules=True) + netbox.device_types.start_component_preload.return_value = "job-1" + slug_resolved = { + "device_files": {"empty": [], "cisco": ["resolved.yaml"]}, + "module_vendors": {"cisco"}, + "rack_vendors": set(), + } + + with ( + patch( + "nb_dt_import._parse_vendor_racks", + side_effect=[[], [{"manufacturer": {"slug": "cisco"}, "model": "R", "slug": "r"}]], + ), + patch("nb_dt_import._process_device_types", return_value="job-1") as mock_process_device_types, + patch("nb_dt_import._process_module_types") as mock_process_module_types, + patch("nb_dt_import._process_rack_types") as mock_process_rack_types, + patch("nb_dt_import._finalize_task_registry") as mock_finalize, + ): + nb_dt_import._run_vendor_loop( + dtl_repo=dtl_repo, + netbox=netbox, + args=args, + handle=handle, + vendors_to_process=[{"slug": "empty", "name": "Empty"}, {"slug": "cisco", "name": "Cisco"}], + devices_path="/tmp/devices", + modules_path="/tmp/modules", + racks_path="/tmp/rack-types", + slug_resolved=slug_resolved, + progress=progress, + task_registry={}, + vendor_task_id=7, + ) + + netbox.load_vendor.assert_called_once_with("cisco") + netbox.device_types.start_component_preload.assert_called_once_with( + manufacturer_slug="cisco", + progress=progress, + task_registry={}, + ) + # pump is now called after preload start, after create_manufacturers, + # and after each of the three process_*_types steps → 5 calls total + assert netbox.device_types.pump_preload_progress.call_count == 5 + netbox.device_types.pump_preload_progress.assert_called_with("job-1", progress) + netbox.device_types.stop_component_preload.assert_called_once_with("job-1", progress=progress) + netbox.create_manufacturers.assert_called_once_with([{"slug": "cisco", "name": "Cisco"}]) + mock_process_device_types.assert_called_once() + mock_process_module_types.assert_called_once() + mock_process_rack_types.assert_called_once() + assert progress.advance.call_args_list == [((7,),), ((7,),)] + mock_finalize.assert_called_once_with(progress, {}) + assert any(call.args[0] == ["resolved.yaml"] for call in dtl_repo.parse_files.call_args_list) + + def test_run_vendor_loop_stops_preload_in_finally_on_error(self, nb_dt_import): + args = SimpleNamespace(only_new=False, slugs=[]) + handle = MagicMock() + progress = MagicMock() + dtl_repo = _make_mock_repo() + dtl_repo.get_devices.side_effect = lambda path, vendors=None: ( + (["device.yaml"], []) if path == "/tmp/devices" else ([], []) + ) + dtl_repo.parse_files.side_effect = lambda files, slugs=None, progress=None: ( + [{"manufacturer": {"slug": "cisco"}, "model": "X", "slug": "x"}] if files == ["device.yaml"] else [] + ) + netbox = _make_mock_netbox() + netbox.device_types.start_component_preload.return_value = "job-2" + + with ( + patch("nb_dt_import._parse_vendor_racks", return_value=[]), + patch("nb_dt_import._process_device_types", side_effect=RuntimeError("boom")), + patch("nb_dt_import._finalize_task_registry") as mock_finalize, + ): + with pytest.raises(RuntimeError, match="boom"): + nb_dt_import._run_vendor_loop( + dtl_repo=dtl_repo, + netbox=netbox, + args=args, + handle=handle, + vendors_to_process=[{"slug": "cisco", "name": "Cisco"}], + devices_path="/tmp/devices", + modules_path="/tmp/modules", + racks_path="/tmp/rack-types", + slug_resolved=None, + progress=progress, + task_registry={}, + vendor_task_id=8, + ) + + netbox.device_types.stop_component_preload.assert_called_once_with("job-2", progress=progress) + mock_finalize.assert_called_once_with(progress, {}) diff --git a/tests/test_nb_serializer.py b/tests/test_nb_serializer.py new file mode 100644 index 000000000..9250234da --- /dev/null +++ b/tests/test_nb_serializer.py @@ -0,0 +1,679 @@ +"""Tests for core/nb_serializer.py — NetBox → DTL YAML serializer.""" + +from core.nb_serializer import _coerce_numeric, serialize_device_type, serialize_module_type, serialize_rack_type + + +def _dotdict(**kw): + """Build a lightweight stub that only has attributes for the provided kwargs. + + Unlike MagicMock, accessing an attribute not in ``kw`` returns the default + supplied to ``getattr(obj, attr, default)`` rather than a truthy Mock object. + This exercises absent-field logic correctly. + """ + + class _Stub: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + def get(self, key, default=None): + return kw.get(key, default) + + return _Stub(**kw) + + +def _make_mfr(name="Acme", slug="acme"): + return _dotdict(name=name, slug=slug) + + +class TestSerializeDeviceType: + """Tests for serialize_device_type function.""" + + def test_minimal_required_fields(self): + record = _dotdict( + id=1, + model="My Switch", + slug="acme-my-switch", + manufacturer=_make_mfr(), + u_height=1, + is_full_depth=True, + part_number=None, + airflow=None, + weight=None, + weight_unit=None, + description="", + comments="", + subdevice_role=None, + front_image=None, + rear_image=None, + ) + result = serialize_device_type(record, components_by_dt_id={}) + assert result["manufacturer"] == "Acme" + assert result["model"] == "My Switch" + assert result["slug"] == "acme-my-switch" + assert result["u_height"] == 1 + assert result["is_full_depth"] is True + # None/empty fields must be absent + assert "part_number" not in result + assert "airflow" not in result + assert "description" not in result + + def test_optional_scalar_fields_included_when_set(self): + record = _dotdict( + id=1, + model="X", + slug="acme-x", + manufacturer=_make_mfr(), + u_height=2, + is_full_depth=False, + part_number="PN-123", + airflow="front-to-rear", + weight=10.5, + weight_unit="kg", + description="A switch", + comments="note", + subdevice_role=None, + front_image=None, + rear_image=None, + ) + result = serialize_device_type(record, components_by_dt_id={}) + assert result["part_number"] == "PN-123" + assert result["airflow"] == "front-to-rear" + assert result["weight"] == 10.5 + assert result["weight_unit"] == "kg" + assert result["description"] == "A switch" + assert result["comments"] == "note" + assert result["is_full_depth"] is False + + def test_image_flags_set_when_urls_present(self): + record = _dotdict( + id=1, + model="X", + slug="acme-x", + manufacturer=_make_mfr(), + u_height=1, + is_full_depth=True, + part_number=None, + airflow=None, + weight=None, + weight_unit=None, + description="", + comments="", + subdevice_role=None, + front_image="/media/devicetype-images/acme-x.front.png", + rear_image="/media/devicetype-images/acme-x.rear.png", + ) + result = serialize_device_type(record, components_by_dt_id={}) + assert result["front_image"] is True + assert result["rear_image"] is True + + def test_interfaces_serialized(self): + iface = _dotdict( + id=10, + name="eth0", + type="1000base-t", + label="", + description="", + mgmt_only=False, + enabled=True, + poe_mode=None, + poe_type=None, + rf_role=None, + device_type=_dotdict(id=1), + ) + record = _dotdict( + id=1, + model="X", + slug="acme-x", + manufacturer=_make_mfr(), + u_height=1, + is_full_depth=True, + part_number=None, + airflow=None, + weight=None, + weight_unit=None, + description="", + comments="", + subdevice_role=None, + front_image=None, + rear_image=None, + ) + components = {1: {"interface_templates": [iface]}} + result = serialize_device_type(record, components_by_dt_id=components) + assert "interfaces" in result + assert result["interfaces"][0]["name"] == "eth0" + assert result["interfaces"][0]["type"] == "1000base-t" + # defaults omitted + assert "label" not in result["interfaces"][0] + assert "mgmt_only" not in result["interfaces"][0] + assert "enabled" not in result["interfaces"][0] + + def test_interface_with_mgmt_only_true_included(self): + iface = _dotdict( + id=11, + name="mgmt0", + type="1000base-t", + label="", + description="", + mgmt_only=True, + enabled=True, + poe_mode=None, + poe_type=None, + rf_role=None, + device_type=_dotdict(id=1), + ) + record = _dotdict( + id=1, + model="X", + slug="acme-x", + manufacturer=_make_mfr(), + u_height=1, + is_full_depth=True, + part_number=None, + airflow=None, + weight=None, + weight_unit=None, + description="", + comments="", + subdevice_role=None, + front_image=None, + rear_image=None, + ) + result = serialize_device_type(record, {1: {"interface_templates": [iface]}}) + assert result["interfaces"][0]["mgmt_only"] is True + + def test_float_u_height_coerced_to_int(self): + record = _dotdict( + id=1, + model="X", + slug="acme-x", + manufacturer=_make_mfr(), + u_height=1.0, + is_full_depth=True, + part_number=None, + airflow=None, + weight=None, + weight_unit=None, + description="", + comments="", + subdevice_role=None, + front_image=None, + rear_image=None, + ) + result = serialize_device_type(record, {}) + assert result["u_height"] == 1 + assert isinstance(result["u_height"], int) + + def test_weight_as_numeric_string_coerced_to_float(self): + """NetBox returns weight as a quoted decimal string e.g. '13.60' — must become float.""" + record = _dotdict( + id=1, + model="X", + slug="acme-x", + manufacturer=_make_mfr(), + u_height=1, + is_full_depth=True, + part_number=None, + airflow=None, + weight="13.60", + weight_unit="kg", + description="", + comments="", + subdevice_role=None, + front_image=None, + rear_image=None, + ) + result = serialize_device_type(record, {}) + assert result["weight"] == 13.6 + assert isinstance(result["weight"], float) + + def test_weight_as_integer_string_coerced_to_int(self): + record = _dotdict( + id=1, + model="X", + slug="acme-x", + manufacturer=_make_mfr(), + u_height=1, + is_full_depth=True, + part_number=None, + airflow=None, + weight="14.00", + weight_unit="kg", + description="", + comments="", + subdevice_role=None, + front_image=None, + rear_image=None, + ) + result = serialize_device_type(record, {}) + assert result["weight"] == 14 + assert isinstance(result["weight"], int) + + def test_key_order(self): + record = _dotdict( + id=1, + model="X", + slug="acme-x", + manufacturer=_make_mfr(), + u_height=1, + is_full_depth=True, + part_number="PN", + airflow="front-to-rear", + weight=None, + weight_unit=None, + description="", + comments="", + subdevice_role=None, + front_image=None, + rear_image=None, + ) + result = serialize_device_type(record, {}) + keys = list(result.keys()) + assert keys.index("manufacturer") < keys.index("model") + assert keys.index("model") < keys.index("slug") + + +class TestSerializeModuleType: + """Tests for serialize_module_type function.""" + + def test_minimal_fields(self): + record = _dotdict( + id=5, + model="MyModule", + manufacturer=_make_mfr(), + part_number=None, + airflow=None, + weight=None, + weight_unit=None, + description="", + comments="", + ) + result = serialize_module_type(record, components_by_mt_id={}) + assert result["manufacturer"] == "Acme" + assert result["model"] == "MyModule" + assert "slug" not in result # module types have no slug field + assert "part_number" not in result + + def test_optional_fields(self): + record = _dotdict( + id=5, + model="MyModule", + manufacturer=_make_mfr(), + part_number="MP-1", + airflow="front-to-rear", + weight=2.5, + weight_unit="kg", + description="desc", + comments="comment", + ) + result = serialize_module_type(record, {}) + assert result["part_number"] == "MP-1" + assert result["airflow"] == "front-to-rear" + + +class TestSerializeRackType: + """Tests for serialize_rack_type function.""" + + def test_minimal_fields(self): + record = _dotdict( + id=7, + model="MyRack", + slug="acme-myrack", + manufacturer=_make_mfr(), + form_factor="4-post-cabinet", + width=19, + u_height=42, + starting_unit=1, + outer_width=None, + outer_height=None, + outer_depth=None, + outer_unit=None, + mounting_depth=None, + weight=None, + max_weight=None, + weight_unit=None, + desc_units=False, + comments="", + description="", + ) + result = serialize_rack_type(record) + assert result["manufacturer"] == "Acme" + assert result["model"] == "MyRack" + assert result["slug"] == "acme-myrack" + assert result["form_factor"] == "4-post-cabinet" + assert result["u_height"] == 42 + assert "outer_width" not in result + assert result["desc_units"] is False + + def test_desc_units_true_included(self): + record = _dotdict( + id=7, + model="R", + slug="acme-r", + manufacturer=_make_mfr(), + form_factor="4-post-cabinet", + width=19, + u_height=10, + starting_unit=1, + outer_width=None, + outer_height=None, + outer_depth=None, + outer_unit=None, + mounting_depth=None, + weight=None, + max_weight=None, + weight_unit=None, + desc_units=True, + comments="", + description="", + ) + result = serialize_rack_type(record) + assert result["desc_units"] is True + + +def test_coerce_numeric_leaves_invalid_decimal_string_unchanged(): + assert _coerce_numeric("12.3.4") == "12.3.4" + + +class TestManufacturerSerialization: + """Tests for manufacturer serialized as plain name string.""" + + def test_device_type_manufacturer_as_name_string(self): + record = _dotdict( + id=1, + model="My Switch", + slug="acme-my-switch", + manufacturer=_make_mfr("Nokia", "nokia"), + u_height=1, + is_full_depth=True, + part_number=None, + airflow=None, + weight=None, + weight_unit=None, + description="", + comments="", + subdevice_role=None, + front_image=None, + rear_image=None, + ) + result = serialize_device_type(record, components_by_dt_id={}) + assert result["manufacturer"] == "Nokia" + + def test_module_type_manufacturer_as_name_string(self): + record = _dotdict( + id=5, + model="MyModule", + manufacturer=_make_mfr("Arista", "arista"), + part_number=None, + airflow=None, + weight=None, + weight_unit=None, + description="", + comments="", + ) + result = serialize_module_type(record, components_by_mt_id={}) + assert result["manufacturer"] == "Arista" + + def test_rack_type_manufacturer_as_name_string(self): + record = _dotdict( + id=7, + model="MyRack", + slug="acme-myrack", + manufacturer=_make_mfr("Cisco", "cisco"), + form_factor="4-post-cabinet", + width=19, + u_height=42, + starting_unit=1, + outer_width=None, + outer_height=None, + outer_depth=None, + outer_unit=None, + mounting_depth=None, + weight=None, + max_weight=None, + weight_unit=None, + desc_units=False, + comments="", + description="", + ) + result = serialize_rack_type(record) + assert result["manufacturer"] == "Cisco" + + +class TestFrontPortSerialization: + """Tests for front port rear_port extraction.""" + + def test_front_port_rear_port_extracted_from_mapping(self): + from types import SimpleNamespace + + mapping = SimpleNamespace(rear_port=SimpleNamespace(name="RP1"), rear_port_position=1) + fp = SimpleNamespace(name="FP1", type="8p8c", label="", description="", color="", mappings=[mapping]) + record = _dotdict( + id=1, + model="X", + slug="acme-x", + manufacturer=_make_mfr(), + u_height=1, + is_full_depth=True, + part_number=None, + airflow=None, + weight=None, + weight_unit=None, + description="", + comments="", + subdevice_role=None, + front_image=None, + rear_image=None, + ) + components = {1: {"front_port_templates": [fp]}} + result = serialize_device_type(record, components) + assert result["front-ports"][0]["rear_port"] == "RP1" + assert "rear_port_position" not in result["front-ports"][0] + + def test_front_port_rear_port_position_included_when_gt_1(self): + from types import SimpleNamespace + + mapping = SimpleNamespace(rear_port=SimpleNamespace(name="RP1"), rear_port_position=3) + fp = SimpleNamespace(name="FP1", type="8p8c", label="", description="", color="", mappings=[mapping]) + record = _dotdict( + id=1, + model="X", + slug="acme-x", + manufacturer=_make_mfr(), + u_height=1, + is_full_depth=True, + part_number=None, + airflow=None, + weight=None, + weight_unit=None, + description="", + comments="", + subdevice_role=None, + front_image=None, + rear_image=None, + ) + components = {1: {"front_port_templates": [fp]}} + result = serialize_device_type(record, components) + assert result["front-ports"][0]["rear_port_position"] == 3 + + def test_front_port_rear_port_position_zero_omitted(self): + """Position 0 is not a valid DTL value and should be omitted.""" + from types import SimpleNamespace + + mapping = SimpleNamespace(rear_port=SimpleNamespace(name="RP1"), rear_port_position=0) + fp = SimpleNamespace(name="FP1", type="8p8c", label="", description="", color="", mappings=[mapping]) + record = _dotdict( + id=1, + model="X", + slug="acme-x", + manufacturer=_make_mfr(), + u_height=1, + is_full_depth=True, + part_number=None, + airflow=None, + weight=None, + weight_unit=None, + description="", + comments="", + subdevice_role=None, + front_image=None, + rear_image=None, + ) + components = {1: {"front_port_templates": [fp]}} + result = serialize_device_type(record, components) + assert "rear_port_position" not in result["front-ports"][0] + + def test_front_port_multiple_mappings_warns_and_uses_first(self): + """When a front port has >1 mappings a UserWarning is raised and only the first is used.""" + import warnings + from types import SimpleNamespace + + m1 = SimpleNamespace(rear_port=SimpleNamespace(name="RP1"), rear_port_position=1) + m2 = SimpleNamespace(rear_port=SimpleNamespace(name="RP2"), rear_port_position=1) + fp = SimpleNamespace(name="FP1", type="8p8c", label="", description="", color="", mappings=[m1, m2]) + record = _dotdict( + id=1, + model="X", + slug="acme-x", + manufacturer=_make_mfr(), + u_height=1, + is_full_depth=True, + part_number=None, + airflow=None, + weight=None, + weight_unit=None, + description="", + comments="", + subdevice_role=None, + front_image=None, + rear_image=None, + ) + components = {1: {"front_port_templates": [fp]}} + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + result = serialize_device_type(record, components) + assert result["front-ports"][0]["rear_port"] == "RP1" + assert len(caught) == 1 + assert issubclass(caught[0].category, UserWarning) + assert "FP1" in str(caught[0].message) + assert "2 mappings" in str(caught[0].message) + assert "issue #78" in str(caught[0].message) + + def test_front_port_legacy_rear_port_scalars(self): + """pre-4.5 NetBox: record has rear_port/rear_port_position as direct attrs (no mappings).""" + from types import SimpleNamespace + + fp = SimpleNamespace( + name="FP1", + type="8p8c", + label="", + description="", + color="", + mappings=None, + rear_port=SimpleNamespace(name="RP1"), + rear_port_position=3, + ) + record = _dotdict( + id=1, + model="X", + slug="acme-x", + manufacturer=_make_mfr(), + u_height=1, + is_full_depth=True, + part_number=None, + airflow=None, + weight=None, + weight_unit=None, + description="", + comments="", + subdevice_role=None, + front_image=None, + rear_image=None, + ) + components = {1: {"front_port_templates": [fp]}} + result = serialize_device_type(record, components) + assert result["front-ports"][0]["rear_port"] == "RP1" + assert result["front-ports"][0]["rear_port_position"] == 3 + + def test_front_port_legacy_rear_port_position_1_omitted(self): + """pre-4.5: rear_port_position == 1 should be omitted (same as mappings path).""" + from types import SimpleNamespace + + fp = SimpleNamespace( + name="FP1", + type="8p8c", + label="", + description="", + color="", + mappings=None, + rear_port=SimpleNamespace(name="RP1"), + rear_port_position=1, + ) + record = _dotdict( + id=1, + model="X", + slug="acme-x", + manufacturer=_make_mfr(), + u_height=1, + is_full_depth=True, + part_number=None, + airflow=None, + weight=None, + weight_unit=None, + description="", + comments="", + subdevice_role=None, + front_image=None, + rear_image=None, + ) + components = {1: {"front_port_templates": [fp]}} + result = serialize_device_type(record, components) + assert result["front-ports"][0]["rear_port"] == "RP1" + assert "rear_port_position" not in result["front-ports"][0] + + def test_components_sorted_by_name(self): + from types import SimpleNamespace + + iface_z = SimpleNamespace( + name="eth9", + type="1000base-t", + label="", + description="", + mgmt_only=False, + enabled=True, + poe_mode=None, + poe_type=None, + rf_role=None, + ) + iface_a = SimpleNamespace( + name="eth0", + type="1000base-t", + label="", + description="", + mgmt_only=False, + enabled=True, + poe_mode=None, + poe_type=None, + rf_role=None, + ) + record = _dotdict( + id=1, + model="X", + slug="acme-x", + manufacturer=_make_mfr(), + u_height=1, + is_full_depth=True, + part_number=None, + airflow=None, + weight=None, + weight_unit=None, + description="", + comments="", + subdevice_role=None, + front_image=None, + rear_image=None, + ) + components = {1: {"interface_templates": [iface_z, iface_a]}} + result = serialize_device_type(record, components) + names = [i["name"] for i in result["interfaces"]] + assert names == sorted(names) diff --git a/tests/test_netbox_api.py b/tests/test_netbox_api.py index 6912457d4..95ddb9c40 100644 --- a/tests/test_netbox_api.py +++ b/tests/test_netbox_api.py @@ -325,7 +325,6 @@ def test_fetch_global_endpoint_records_uses_graphql( records = dt._fetch_global_endpoint_records( "interface_templates", progress_callback=lambda endpoint, advance: updates.append((endpoint, advance)), - expected_total=3, ) assert len(records) == 3 @@ -359,7 +358,7 @@ def test_fetch_global_endpoint_records_progress_emits_live_per_page( advances_during_fetch = [] - def fake_get(endpoint_name, on_page=None): + def fake_get(endpoint_name, manufacturer_slug=None, on_page=None): # Stream pages and verify that the consumer-facing callback was invoked # before the next page is yielded. all_records = [] @@ -377,7 +376,6 @@ def progress_cb(endpoint, advance): records = dt._fetch_global_endpoint_records( "interface_templates", progress_callback=progress_cb, - expected_total=3, ) assert len(records) == 3 @@ -389,53 +387,25 @@ def progress_cb(endpoint, advance): ] -def test_fetch_global_endpoint_records_emits_rewind_on_retry( +def test_fetch_global_endpoint_records_passes_manufacturer_slug( mock_settings, mock_pynetbox, mock_graphql_requests, graphql_client, make_device_types ): - """When a count-mismatch triggers a retry, a negative-advance "rewind" must be emitted. - - Without the rewind, the next attempt's live advances would double-count on - top of the failed attempt's leftover ones. - """ + """manufacturer_slug is forwarded to get_component_templates.""" from unittest.mock import patch as _patch - from core.graphql_client import DotDict mock_nb_api = mock_pynetbox.api.return_value dt = make_device_types(nb_api=mock_nb_api) - iface1 = DotDict({"id": "1", "name": "a", "device_type": {"id": "5"}, "module_type": None}) - iface2 = DotDict({"id": "2", "name": "b", "device_type": {"id": "5"}, "module_type": None}) - - call_count = {"n": 0} - - def fake_get(endpoint_name, on_page=None): - call_count["n"] += 1 - if call_count["n"] == 1: - if on_page is not None: - on_page(1) - return [iface1] - if on_page is not None: - on_page(1) - on_page(1) - return [iface1, iface2] + received_slugs = [] - advances = [] + def fake_get(endpoint_name, manufacturer_slug=None, on_page=None): + received_slugs.append(manufacturer_slug) + return [] - def progress_cb(endpoint, advance): - advances.append(advance) - - with _patch("core.netbox_api.time.sleep"): - with _patch.object(dt.graphql, "get_component_templates", side_effect=fake_get): - records = dt._fetch_global_endpoint_records( - "interface_templates", - progress_callback=progress_cb, - expected_total=2, - ) + with _patch.object(dt.graphql, "get_component_templates", side_effect=fake_get): + dt._fetch_global_endpoint_records("interface_templates", manufacturer_slug="cisco") - assert len(records) == 2 - # Live: +1 (failed attempt page), -1 (rewind), +1 + +1 (successful retry pages). - assert advances == [1, -1, 1, 1] - assert sum(advances) == 2 + assert received_slugs == ["cisco"] def test_fetch_global_endpoint_records_progress_skipped_when_empty( @@ -456,160 +426,16 @@ def test_fetch_global_endpoint_records_progress_skipped_when_empty( fetched = dt._fetch_global_endpoint_records( "interface_templates", progress_callback=lambda endpoint, advance: updates.append((endpoint, advance)), - expected_total=0, ) assert fetched == [] assert updates == [] -def test_fetch_global_endpoint_records_retries_and_aborts_on_count_mismatch( - mock_settings, mock_pynetbox, mock_graphql_requests, graphql_client, make_device_types -): - """After _MAX_RETRIES mismatches a GraphQLCountMismatchError is raised; each attempt logs a warning.""" - from unittest.mock import patch as _patch - from core.graphql_client import GraphQLCountMismatchError - - mock_nb_api = mock_pynetbox.api.return_value - mock_graphql_requests.side_effect = _make_graphql_dispatch( - { - "device_type_list": {"data": {"device_type_list": []}}, - "interface_template_list": { - "data": { - "interface_template_list": [ - { - "id": "1", - "name": "xe-0/0/0", - "type": "10gbase-x-sfpp", - "label": "", - "mgmt_only": False, - "enabled": True, - "poe_mode": None, - "poe_type": None, - "device_type": {"id": "5"}, - "module_type": None, - } - ] - } - }, - } - ) - - dt = make_device_types(nb_api=mock_nb_api) - logged = [] - dt.handle.log = lambda msg: logged.append(msg) - - with _patch("core.netbox_api.time.sleep") as mock_sleep: - with pytest.raises(GraphQLCountMismatchError, match="interface_templates"): - dt._fetch_global_endpoint_records( - "interface_templates", - progress_callback=None, - expected_total=113259, - ) - - from core.netbox_api import _MAX_RETRIES, _RETRY_BACKOFF - - # One sleep per retry attempt - assert mock_sleep.call_count == _MAX_RETRIES - # Backoff durations must match the configured sequence — guards against a - # regression that silently changes the wait pattern. - assert [call.args[0] for call in mock_sleep.call_args_list] == list(_RETRY_BACKOFF[:_MAX_RETRIES]) - # A WARNING is logged for each retry - warnings = [m for m in logged if "WARNING" in m and "interface_templates" in m] - assert len(warnings) == _MAX_RETRIES - - -def test_fetch_global_endpoint_records_detects_mismatch_when_rest_returns_zero( - mock_settings, mock_pynetbox, mock_graphql_requests, graphql_client, make_device_types -): - """A REST count of 0 must NOT silently skip validation when GraphQL returns records. - - Regression: previously ``if expected_total and ...`` treated ``0`` as "skip - validation", meaning a real REST count of 0 paired with a GraphQL response that - leaked records would go unnoticed. The check now uses ``is not None`` so 0 is a - legitimate expected value and any mismatch (including 0 vs N>0) is flagged. - """ - from unittest.mock import patch as _patch - from core.graphql_client import GraphQLCountMismatchError - - mock_nb_api = mock_pynetbox.api.return_value - mock_graphql_requests.side_effect = _make_graphql_dispatch( - { - "device_type_list": {"data": {"device_type_list": []}}, - "interface_template_list": { - "data": { - "interface_template_list": [ - { - "id": "1", - "name": "xe-0/0/0", - "type": "10gbase-x-sfpp", - "label": "", - "mgmt_only": False, - "enabled": True, - "poe_mode": None, - "poe_type": None, - "device_type": {"id": "5"}, - "module_type": None, - } - ] - } - }, - } - ) - - dt = make_device_types(nb_api=mock_nb_api) - - with _patch("core.netbox_api.time.sleep"): - with pytest.raises(GraphQLCountMismatchError, match="interface_templates"): - dt._fetch_global_endpoint_records( - "interface_templates", - progress_callback=None, - expected_total=0, - ) - - -def test_fetch_global_endpoint_records_succeeds_on_retry( - mock_settings, mock_pynetbox, mock_graphql_requests, graphql_client, make_device_types -): - """When the first fetch is truncated but a retry returns the full count, the records are returned.""" - from unittest.mock import patch as _patch - from core.graphql_client import DotDict - - mock_nb_api = mock_pynetbox.api.return_value - dt = make_device_types(nb_api=mock_nb_api) - - iface1 = DotDict({"id": "1", "name": "xe-0/0/0", "device_type": {"id": "5"}, "module_type": None}) - iface2 = DotDict({"id": "2", "name": "xe-0/0/1", "device_type": {"id": "5"}, "module_type": None}) - - call_count = {"n": 0} - - def fake_get(endpoint_name, on_page=None): - call_count["n"] += 1 - if call_count["n"] == 1: - return [iface1] # truncated - return [iface1, iface2] # full on retry - - with _patch("core.netbox_api.time.sleep"): - with _patch.object(dt.graphql, "get_component_templates", side_effect=fake_get): - records = dt._fetch_global_endpoint_records( - "interface_templates", - progress_callback=None, - expected_total=2, - ) - - assert len(records) == 2 - assert call_count["n"] == 2 # initial attempt + 1 retry - - -def test_fetch_global_endpoint_records_progress_not_double_counted_on_retry( +def test_fetch_global_endpoint_records_returns_all_records_without_retry( mock_settings, mock_pynetbox, mock_graphql_requests, graphql_client, make_device_types ): - """A mismatched-then-retried fetch must not double-advance the progress bar. - - If page advances were published during the failing attempt and again during - the successful retry, the progress callback would receive more advances than - the final expected total. Only the successful attempt should publish. - """ + """_fetch_global_endpoint_records returns all GraphQL records; no retry logic.""" from unittest.mock import patch as _patch from core.graphql_client import DotDict @@ -621,43 +447,21 @@ def test_fetch_global_endpoint_records_progress_not_double_counted_on_retry( call_count = {"n": 0} - def fake_get(endpoint_name, on_page=None): + def fake_get(endpoint_name, manufacturer_slug=None, on_page=None): call_count["n"] += 1 - if call_count["n"] == 1: - # First attempt: emit one page advance, but return truncated list -> mismatch - if on_page is not None: - on_page(1) - return [iface1] - # Retry attempt: emit two page advances, return full list -> success - if on_page is not None: - on_page(1) - on_page(1) return [iface1, iface2] - advances = [] - - def progress_cb(endpoint, advance): - advances.append((endpoint, advance)) - - with _patch("core.netbox_api.time.sleep"): - with _patch.object(dt.graphql, "get_component_templates", side_effect=fake_get): - records = dt._fetch_global_endpoint_records( - "interface_templates", - progress_callback=progress_cb, - expected_total=2, - ) + with _patch.object(dt.graphql, "get_component_templates", side_effect=fake_get): + records = dt._fetch_global_endpoint_records("interface_templates") assert len(records) == 2 - assert call_count["n"] == 2 - # Total advances must equal expected_total (2), not 1 + 2 == 3. - total_advance = sum(n for _, n in advances) - assert total_advance == 2, f"progress callback double-counted retry advances: got {total_advance}, expected 2" + assert call_count["n"] == 1 # exactly one call, no retry -def test_fetch_global_endpoint_records_no_warning_on_count_match( +def test_fetch_global_endpoint_records_no_warning_logged( mock_settings, mock_pynetbox, mock_graphql_requests, graphql_client, make_device_types ): - """No warning is logged when GraphQL returns the expected number of records.""" + """No warnings are logged by _fetch_global_endpoint_records on a normal fetch.""" mock_nb_api = mock_pynetbox.api.return_value mock_graphql_requests.side_effect = _make_graphql_dispatch( { @@ -687,70 +491,12 @@ def test_fetch_global_endpoint_records_no_warning_on_count_match( logged = [] dt.handle.log = lambda msg: logged.append(msg) - records = dt._fetch_global_endpoint_records( - "interface_templates", - progress_callback=None, - expected_total=1, - ) + records = dt._fetch_global_endpoint_records("interface_templates") assert len(records) == 1 assert not any("WARNING" in m for m in logged) -def test_get_rest_component_count_returns_count( - mock_settings, mock_pynetbox, mock_graphql_requests, graphql_client, make_device_types -): - """_get_rest_component_count returns the integer count from pynetbox.""" - mock_nb_api = mock_pynetbox.api.return_value - mock_nb_api.dcim.interface_templates.count.return_value = 42 - - dt = make_device_types(nb_api=mock_nb_api) - assert dt._get_rest_component_count("interface_templates") == 42 - - -def test_get_rest_component_count_returns_none_on_error( - mock_settings, mock_pynetbox, mock_graphql_requests, graphql_client, make_device_types -): - """_get_rest_component_count returns None if the REST call fails.""" - mock_nb_api = mock_pynetbox.api.return_value - mock_nb_api.dcim.interface_templates.count.side_effect = Exception("connection failed") - - dt = make_device_types(nb_api=mock_nb_api) - assert dt._get_rest_component_count("interface_templates") is None - - -def test_get_endpoint_totals_fetches_rest_counts( - mock_settings, mock_pynetbox, mock_graphql_requests, graphql_client, make_device_types -): - """_get_endpoint_totals fetches actual REST counts for graphql endpoints.""" - mock_nb_api = mock_pynetbox.api.return_value - mock_nb_api.dcim.interface_templates.count.return_value = 100 - mock_nb_api.dcim.power_port_templates.count.return_value = 50 - - dt = make_device_types(nb_api=mock_nb_api) - components = [("interface_templates", "Interfaces"), ("power_port_templates", "Power Ports")] - totals = dt._get_endpoint_totals(components) - - assert totals["interface_templates"] == 100 - assert totals["power_port_templates"] == 50 - - -def test_get_endpoint_totals_tolerates_count_failure( - mock_settings, mock_pynetbox, mock_graphql_requests, graphql_client, make_device_types -): - """_get_endpoint_totals preserves None sentinel when a REST count fails.""" - mock_nb_api = mock_pynetbox.api.return_value - mock_nb_api.dcim.interface_templates.count.side_effect = Exception("timeout") - mock_nb_api.dcim.power_port_templates.count.return_value = 20 - - dt = make_device_types(nb_api=mock_nb_api) - components = [("interface_templates", "Interfaces"), ("power_port_templates", "Power Ports")] - totals = dt._get_endpoint_totals(components) - - assert totals["interface_templates"] is None - assert totals["power_port_templates"] == 20 - - def test_preload_always_global_caches_all_vendors( mock_settings, mock_pynetbox, mock_graphql_requests, graphql_client, make_device_types ): @@ -846,13 +592,11 @@ def test_start_component_preload_global_job_can_be_consumed( def test_preload_tolerates_none_endpoint_totals(mock_settings, mock_pynetbox, graphql_client, make_device_types): """``_preload_track_progress`` must not raise TypeError when a total is ``None``. - Regression: ``_get_endpoint_totals`` now returns ``None`` for endpoints whose - REST count failed (preserving the "count unavailable" sentinel). Internal - ``max(endpoint_totals.get(name, 0), ...)`` calls would raise ``TypeError`` on - ``None`` because ``dict.get`` returns the stored ``None`` instead of the default. + ``endpoint_totals`` may contain ``None`` sentinels when a count is unavailable. + Internal ``max(endpoint_totals.get(name, 0), ...)`` calls would raise ``TypeError`` + on ``None`` because ``dict.get`` returns the stored ``None`` instead of the default. The fix uses ``endpoint_totals.get(name) or 0`` to coerce ``None`` to ``0`` for - the ``max()`` comparison while still letting ``None`` flow through as - ``expected_total`` to ``_fetch_global_endpoint_records``. + the ``max()`` comparison. """ from concurrent.futures import Future @@ -4167,11 +3911,11 @@ def test_with_progress_creates_task_ids(self, mock_settings, mock_pynetbox, grap dt.stop_component_preload(preload_job) def test_exception_shuts_down_executor(self, mock_settings, mock_pynetbox, graphql_client, make_device_types): - """If _get_endpoint_totals raises, executor is shut down and exception re-raised.""" + """If _component_preload_targets raises, executor is shut down and exception re-raised.""" mock_nb_api = mock_pynetbox.api.return_value dt = make_device_types(nb_api=mock_nb_api) - with patch.object(dt, "_get_endpoint_totals", side_effect=RuntimeError("oops")): + with patch.object(dt, "_component_preload_targets", side_effect=RuntimeError("oops")): with pytest.raises(RuntimeError, match="oops"): dt.start_component_preload() @@ -4258,8 +4002,7 @@ def test_own_executor_with_progress( progress.add_task.return_value = 1 # Run only one component to keep the test fast components = [("interface_templates", "Interface Templates")] - with patch.object(dt, "_get_endpoint_totals", return_value={"interface_templates": 0}): - dt._preload_global(components, progress_wrapper=None, progress=progress) + dt._preload_global(components, progress_wrapper=None, progress=progress) progress.add_task.assert_called() def test_preload_global_no_progress_future_failure( @@ -4275,7 +4018,7 @@ def test_preload_global_no_progress_future_failure( mock_settings.handle.log.reset_mock() mock_settings.handle.verbose_log.reset_mock() - with patch.object(dt, "_get_endpoint_totals", return_value={}): + with patch.object(dt, "_fetch_global_endpoint_records", return_value=[]): preload_job = { "executor": None, "futures": {"interface_templates": broken_future}, @@ -5627,7 +5370,7 @@ def test_update_progress_called_when_records_available( mock_nb_api = mock_pynetbox.api.return_value # Mock _fetch_global_endpoint_records to call its callback - def fake_fetch(endpoint_name, progress_callback=None, expected_total=None): + def fake_fetch(endpoint_name, progress_callback=None, manufacturer_slug=None): records = [MagicMock(name="fake")] if progress_callback is not None and records: progress_callback(endpoint_name, len(records)) @@ -5638,13 +5381,10 @@ def fake_fetch(endpoint_name, progress_callback=None, expected_total=None): progress = MagicMock() progress.add_task.return_value = 1 - with patch.object( - dt, "_get_endpoint_totals", return_value={ep: 0 for ep, _ in dt._component_preload_targets()} - ): - with patch.object(dt, "_fetch_global_endpoint_records", side_effect=fake_fetch): - preload_job = dt.start_component_preload(progress=progress) - # Let the futures complete - dt.preload_all_components(preload_job=preload_job, progress=progress) + with patch.object(dt, "_fetch_global_endpoint_records", side_effect=fake_fetch): + preload_job = dt.start_component_preload(progress=progress) + # Let the futures complete + dt.preload_all_components(preload_job=preload_job, progress=progress) # update_progress was called, which put items in progress_updates queue # pump_preload_progress or preload_all_components drained them @@ -5665,7 +5405,7 @@ def test_update_progress_callback_triggered( """The update_progress closure in _preload_global is called when records exist.""" mock_nb_api = mock_pynetbox.api.return_value - def fake_fetch(endpoint_name, progress_callback=None, expected_total=None): + def fake_fetch(endpoint_name, progress_callback=None, manufacturer_slug=None): records = [MagicMock()] if progress_callback is not None and records: progress_callback(endpoint_name, len(records)) @@ -5677,9 +5417,8 @@ def fake_fetch(endpoint_name, progress_callback=None, expected_total=None): progress.add_task.return_value = 1 components = [("interface_templates", "Interface Templates")] - with patch.object(dt, "_get_endpoint_totals", return_value={"interface_templates": 0}): - with patch.object(dt, "_fetch_global_endpoint_records", side_effect=fake_fetch): - dt._preload_global(components, progress_wrapper=None, progress=progress) + with patch.object(dt, "_fetch_global_endpoint_records", side_effect=fake_fetch): + dt._preload_global(components, progress_wrapper=None, progress=progress) progress.add_task.assert_called() progress.stop_task.assert_called() @@ -6549,3 +6288,1903 @@ def test_component_reconciliation_continues_when_scalar_patch_fails( failures = nb.outcomes.failures() assert len(failures) == 1 assert "CM-Fail-Patch" in failures[0].identity + + +# --------------------------------------------------------------------------- +# Task 6: DeviceTypes.load_for_vendor and NetBox.load_vendor +# --------------------------------------------------------------------------- + + +class TestLoadForVendor: + """Tests for DeviceTypes.load_for_vendor (Task 6).""" + + def test_load_for_vendor_populates_existing_device_types( + self, mock_settings, mock_pynetbox, mock_graphql_requests, graphql_client, make_device_types + ): + """load_for_vendor populates existing_device_types and existing_device_types_by_slug.""" + from unittest.mock import patch as _patch + from core.graphql_client import DotDict + + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + + # Initially empty (deferred load) + assert dt.existing_device_types == {} + assert dt.existing_device_types_by_slug == {} + + cisco_dt = DotDict( + { + "id": "1", + "model": "Catalyst 9000", + "slug": "catalyst-9000", + "manufacturer": DotDict({"id": "10", "name": "Cisco", "slug": "cisco"}), + } + ) + by_model = {("cisco", "Catalyst 9000"): cisco_dt} + by_slug = {("cisco", "catalyst-9000"): cisco_dt} + + with _patch.object(dt.graphql, "get_device_types", return_value=(by_model, by_slug)) as mock_gdt: + dt.load_for_vendor("cisco") + mock_gdt.assert_called_once_with(manufacturer_slugs=["cisco"]) + + assert dt.existing_device_types == by_model + assert dt.existing_device_types_by_slug == by_slug + + def test_load_for_vendor_resets_state_before_fetch_on_failure( + self, mock_settings, mock_pynetbox, mock_graphql_requests, graphql_client, make_device_types + ): + """State is reset before the fetch so a raised exception leaves a clean slate.""" + from unittest.mock import patch as _patch + import pytest + + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + + dt._global_preload_done = True + dt.cached_components = {"interface_templates": {("device", 1): {}}} + + with _patch.object(dt.graphql, "get_device_types", side_effect=RuntimeError("timeout")): + with pytest.raises(RuntimeError, match="timeout"): + dt.load_for_vendor("vendor-b") + + assert dt._global_preload_done is False + assert dt.cached_components == {} + + def test_load_for_vendor_replaces_prior_data( + self, mock_settings, mock_pynetbox, mock_graphql_requests, graphql_client, make_device_types + ): + """A second call to load_for_vendor replaces data from the first call.""" + from unittest.mock import patch as _patch + from core.graphql_client import DotDict + + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + + cisco_dt = DotDict({"id": "1", "model": "M1", "slug": "m1", "manufacturer": DotDict({"slug": "cisco"})}) + juniper_dt = DotDict({"id": "2", "model": "M2", "slug": "m2", "manufacturer": DotDict({"slug": "juniper"})}) + + with _patch.object( + dt.graphql, "get_device_types", return_value=({("cisco", "M1"): cisco_dt}, {("cisco", "m1"): cisco_dt}) + ): + dt.load_for_vendor("cisco") + + assert ("cisco", "M1") in dt.existing_device_types + assert ("juniper", "M2") not in dt.existing_device_types + + with _patch.object( + dt.graphql, + "get_device_types", + return_value=({("juniper", "M2"): juniper_dt}, {("juniper", "m2"): juniper_dt}), + ): + dt.load_for_vendor("juniper") + + assert ("cisco", "M1") not in dt.existing_device_types + assert ("juniper", "M2") in dt.existing_device_types + + +class TestNetBoxLoadVendor: + """Tests for NetBox.load_vendor (Task 6).""" + + def test_load_vendor_calls_load_for_vendor_and_resets_change_detector(self, mock_settings, mock_pynetbox): + """load_vendor delegates to device_types.load_for_vendor and resets _change_detector.""" + nb = NetBox(mock_settings, mock_settings.handle) + nb.device_types = MagicMock() + + # Force a cached change detector + nb._change_detector = MagicMock() + + nb.load_vendor("cisco") + + nb.device_types.load_for_vendor.assert_called_once_with("cisco") + assert nb._change_detector is None + + def test_load_vendor_change_detector_lazily_recreated(self, mock_settings, mock_pynetbox): + """After load_vendor, accessing change_detector creates a fresh instance.""" + nb = NetBox(mock_settings, mock_settings.handle) + nb.device_types = MagicMock() + + # Force existing change detector + original_cd = MagicMock() + nb._change_detector = original_cd + + nb.load_vendor("juniper") + assert nb._change_detector is None + + # Access the property — should create a new one + _ = nb.change_detector + assert nb._change_detector is not None + assert nb._change_detector is not original_cd + + def test_load_vendor_resets_preload_state_for_iteration( + self, mock_settings, mock_pynetbox, graphql_client, make_device_types + ): + """load_vendor resets _global_preload_done and cached_components between vendors.""" + from unittest.mock import patch as _patch + + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + + # First call: simulate load for vendor-a + with _patch.object(dt.graphql, "get_device_types", return_value=({}, {})): + dt.load_for_vendor("vendor-a") + + # Simulate state left behind after processing vendor-a + dt._global_preload_done = True + dt.cached_components = {"interface_templates": {("device", 1): {}}} + + # Second call: load for vendor-b + with _patch.object(dt.graphql, "get_device_types", return_value=({}, {})): + dt.load_for_vendor("vendor-b") + + assert dt._global_preload_done is False + assert dt.cached_components == {} + + def test_load_vendor_resets_module_image_details(self, mock_settings, mock_pynetbox): + """load_vendor must clear _module_image_details so stale module entries from previous vendor don't persist.""" + nb = NetBox(mock_settings, mock_settings.handle) + nb.device_types = MagicMock() + nb._module_image_details = {99: {"front": {"url": "/media/stale.png", "att_id": 5}}} + + nb.load_vendor("new-vendor") + + assert nb._module_image_details == {} + + +# --------------------------------------------------------------------------- +# Task 7: start_component_preload with manufacturer_slug +# --------------------------------------------------------------------------- + + +class TestStartComponentPreloadManufacturerSlug: + """Tests for start_component_preload(manufacturer_slug=...) (Task 7).""" + + def test_manufacturer_slug_passed_to_fetch(self, mock_settings, mock_pynetbox, graphql_client, make_device_types): + """manufacturer_slug from start_component_preload is forwarded to _fetch_global_endpoint_records.""" + from unittest.mock import patch as _patch + + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + + received_slugs = [] + + def fake_fetch(endpoint_name, progress_callback=None, manufacturer_slug=None): + received_slugs.append(manufacturer_slug) + return [] + + with _patch.object(dt, "_fetch_global_endpoint_records", side_effect=fake_fetch): + preload_job = dt.start_component_preload(manufacturer_slug="cisco") + dt.preload_all_components(preload_job=preload_job) + + # Every endpoint should have received the vendor slug + assert all(slug == "cisco" for slug in received_slugs) + assert len(received_slugs) == len(dt._component_preload_targets()) + + def test_no_manufacturer_slug_passes_none(self, mock_settings, mock_pynetbox, graphql_client, make_device_types): + """Without manufacturer_slug, None is forwarded (global fetch).""" + from unittest.mock import patch as _patch + + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + + received_slugs = [] + + def fake_fetch(endpoint_name, progress_callback=None, manufacturer_slug=None): + received_slugs.append(manufacturer_slug) + return [] + + with _patch.object(dt, "_fetch_global_endpoint_records", side_effect=fake_fetch): + preload_job = dt.start_component_preload() + dt.preload_all_components(preload_job=preload_job) + + assert all(slug is None for slug in received_slugs) + + +# --------------------------------------------------------------------------- +# Task 8: _verify_component_cache_integrity +# --------------------------------------------------------------------------- + + +class TestVerifyComponentCacheIntegrity: + """Tests for DeviceTypes._verify_component_cache_integrity (Task 8).""" + + def test_returns_true_when_all_records_match_dt_ids( + self, mock_settings, mock_pynetbox, graphql_client, make_device_types + ): + """Returns True when cache records have device_type_ids in vendor_dt_ids.""" + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + + # Populate cache with records whose device_type id is 1 (in vendor set) + rec = MagicMock() + dt.cached_components = { + "interface_templates": {("device", 1): {"eth0": rec}}, + } + + result = dt._verify_component_cache_integrity(vendor_dt_ids={1}, vendor_mt_ids=set()) + assert result is True + # Cache should be untouched + assert ("device", 1) in dt.cached_components["interface_templates"] + + def test_returns_false_and_clears_cache_when_no_ids_match( + self, mock_settings, mock_pynetbox, graphql_client, make_device_types + ): + """Returns False and clears offending endpoint when no record belongs to vendor.""" + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + + rec = MagicMock() + # Cache has records for device_type id 99, but vendor only owns id 1 + dt.cached_components = { + "interface_templates": {("device", 99): {"eth0": rec}}, + } + + logged = [] + dt.handle.log = lambda msg: logged.append(msg) + + result = dt._verify_component_cache_integrity(vendor_dt_ids={1}, vendor_mt_ids=set()) + assert result is False + # Cache entry must be cleared + assert dt.cached_components["interface_templates"] == {} + assert any("ERROR" in m for m in logged) + + def test_empty_endpoint_cache_not_flagged(self, mock_settings, mock_pynetbox, graphql_client, make_device_types): + """Empty endpoint cache entries are skipped — no false positives.""" + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + + dt.cached_components = { + "interface_templates": {}, + } + + result = dt._verify_component_cache_integrity(vendor_dt_ids={1}, vendor_mt_ids=set()) + assert result is True + + def test_module_type_records_pass_when_mt_ids_match( + self, mock_settings, mock_pynetbox, graphql_client, make_device_types + ): + """Module-type records pass when module_type_id is in vendor_mt_ids.""" + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + + rec = MagicMock() + dt.cached_components = { + "interface_templates": {("module", 5): {"mod-iface": rec}}, + } + + # vendor_mt_ids contains 5 + result = dt._verify_component_cache_integrity(vendor_dt_ids=set(), vendor_mt_ids={5}) + assert result is True + + def test_mixed_valid_invalid_records_pass(self, mock_settings, mock_pynetbox, graphql_client, make_device_types): + """If at least one record matches, the endpoint is considered valid.""" + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + + rec = MagicMock() + dt.cached_components = { + "interface_templates": { + ("device", 1): {"eth0": rec}, # valid + ("device", 99): {"eth1": rec}, # foreign but that's OK if 1 is valid + }, + } + + result = dt._verify_component_cache_integrity(vendor_dt_ids={1}, vendor_mt_ids=set()) + assert result is True + + +# --------------------------------------------------------------------------- +# Issue 2: preload_all_components integrity check handles get_module_types errors +# --------------------------------------------------------------------------- + + +class TestPreloadAllComponentsIntegrityCheckError: + """preload_all_components skips integrity check gracefully when get_module_types raises.""" + + def test_get_module_types_raises_does_not_propagate( + self, mock_settings, mock_pynetbox, graphql_client, make_device_types + ): + """When get_module_types raises, no exception propagates and _global_preload_done is True.""" + from unittest.mock import patch as _patch + + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + + with _patch.object(dt, "_preload_global"): + with _patch.object(dt.graphql, "get_module_types", side_effect=RuntimeError("network error")): + # Should not raise + dt.preload_all_components(manufacturer_slug="cisco") + + assert dt._global_preload_done is True + # Warning was logged + dt.handle.log.assert_called() + warning_call = dt.handle.log.call_args_list[-1][0][0] + assert "WARNING" in warning_call + + +# --------------------------------------------------------------------------- +# _check_component_counts_against_rest and _rest_count_chunked +# --------------------------------------------------------------------------- + + +class TestCheckComponentCountsAgainstRest: + """Tests for DeviceTypes._check_component_counts_against_rest.""" + + def _make_dt_with_cache(self, make_device_types, mock_nb_api, cached_components): + dt = make_device_types(nb_api=mock_nb_api) + dt.cached_components = cached_components + return dt + + def test_matching_counts_does_not_raise(self, mock_settings, mock_pynetbox, graphql_client, make_device_types): + """No exception when cached count equals REST count for all endpoints.""" + mock_nb_api = mock_pynetbox.api.return_value + dt = self._make_dt_with_cache( + make_device_types, + mock_nb_api, + { + "interface_templates": {("device", 1): {"eth0": MagicMock(), "eth1": MagicMock()}}, + "power_port_templates": {("device", 1): {"pwr": MagicMock()}}, + }, + ) + + # REST always returns matching count + mock_nb_api.dcim.interface_templates.count.return_value = 2 + mock_nb_api.dcim.power_port_templates.count.return_value = 1 + + # Should not raise — all other endpoints return 0 from both REST and cache + def zero_count(**kwargs): + return 0 + + for endpoint_name, _ in dt._component_preload_targets(): + ep = getattr(mock_nb_api.dcim, endpoint_name) + ep.count.side_effect = zero_count + + mock_nb_api.dcim.interface_templates.count.side_effect = None + mock_nb_api.dcim.interface_templates.count.return_value = 2 + mock_nb_api.dcim.power_port_templates.count.side_effect = None + mock_nb_api.dcim.power_port_templates.count.return_value = 1 + + dt._check_component_counts_against_rest(vendor_dt_ids={1}, vendor_mt_ids=set()) + + def test_mismatch_raises_graphql_count_mismatch_error( + self, mock_settings, mock_pynetbox, graphql_client, make_device_types + ): + """GraphQLCountMismatchError raised when REST count differs from cached count.""" + from core.graphql_client import GraphQLCountMismatchError + + mock_nb_api = mock_pynetbox.api.return_value + dt = self._make_dt_with_cache( + make_device_types, + mock_nb_api, + { + "interface_templates": {("device", 1): {"eth0": MagicMock()}}, # 1 cached + }, + ) + + # Make all endpoints return 0 by default, then override interface_templates + for endpoint_name, _ in dt._component_preload_targets(): + getattr(mock_nb_api.dcim, endpoint_name).count.return_value = 0 + + # REST says 5 for interface_templates, cache has 1 → mismatch + mock_nb_api.dcim.interface_templates.count.return_value = 5 + + with pytest.raises(GraphQLCountMismatchError, match="interface_templates"): + dt._check_component_counts_against_rest(vendor_dt_ids={1}, vendor_mt_ids=set()) + + def test_empty_dt_ids_skips_dt_count_call(self, mock_settings, mock_pynetbox, graphql_client, make_device_types): + """When vendor_dt_ids is empty, REST count is not called for device-type path.""" + mock_nb_api = mock_pynetbox.api.return_value + dt = self._make_dt_with_cache(make_device_types, mock_nb_api, {}) + + for endpoint_name, _ in dt._component_preload_targets(): + getattr(mock_nb_api.dcim, endpoint_name).count.return_value = 0 + + dt._check_component_counts_against_rest(vendor_dt_ids=set(), vendor_mt_ids=set()) + + # count() should never be called when both ID sets are empty + for endpoint_name, _ in dt._component_preload_targets(): + getattr(mock_nb_api.dcim, endpoint_name).count.assert_not_called() + + def test_device_bay_templates_skips_module_type_path( + self, mock_settings, mock_pynetbox, graphql_client, make_device_types + ): + """device_bay_templates never queries the module-type count path.""" + from core.compat import module_type_filter_key + + mock_nb_api = mock_pynetbox.api.return_value + dt = self._make_dt_with_cache(make_device_types, mock_nb_api, {}) + + for endpoint_name, _ in dt._component_preload_targets(): + getattr(mock_nb_api.dcim, endpoint_name).count.return_value = 0 + + dt._check_component_counts_against_rest(vendor_dt_ids={1}, vendor_mt_ids={5}) + + # device_bay_templates must only be called once (device path), not twice + calls = mock_nb_api.dcim.device_bay_templates.count.call_args_list + mt_filter_key = module_type_filter_key(dt.new_filters) + mt_calls = [c for c in calls if mt_filter_key in (c.kwargs or {})] + assert mt_calls == [], "device_bay_templates should not be queried with module_type filter" + + def test_rest_only_endpoints_skipped(self, mock_settings, mock_pynetbox, graphql_client, make_device_types): + """Endpoints in REST_ONLY_ENDPOINTS are skipped (no tautological REST-vs-REST check).""" + mock_nb_api = mock_pynetbox.api.return_value + dt = self._make_dt_with_cache(make_device_types, mock_nb_api, {}) + + # Override REST_ONLY_ENDPOINTS to include interface_templates for this test + dt.REST_ONLY_ENDPOINTS = frozenset({"interface_templates"}) + + for endpoint_name, _ in dt._component_preload_targets(): + getattr(mock_nb_api.dcim, endpoint_name).count.return_value = 0 + + dt._check_component_counts_against_rest(vendor_dt_ids={1}, vendor_mt_ids=set()) + + mock_nb_api.dcim.interface_templates.count.assert_not_called() + + +class TestRestCountChunked: + """Tests for DeviceTypes._rest_count_chunked.""" + + def test_single_chunk(self, mock_settings, mock_pynetbox, graphql_client, make_device_types): + """IDs that fit in one chunk make exactly one REST call.""" + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + + ep = MagicMock() + ep.count.return_value = 5 + + result = dt._rest_count_chunked(ep, "device_type_id", [1, 2, 3]) + + assert result == 5 + ep.count.assert_called_once_with(device_type_id=[1, 2, 3]) + + def test_multiple_chunks_summed(self, mock_settings, mock_pynetbox, graphql_client, make_device_types): + """IDs exceeding chunk_size are split into multiple calls whose counts are summed.""" + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + + ep = MagicMock() + ep.count.side_effect = [10, 7] # two chunks return 10 and 7 + + ids = list(range(150)) # 150 IDs, default chunk_size=100 → 2 chunks + result = dt._rest_count_chunked(ep, "device_type_id", ids, chunk_size=100) + + assert result == 17 + assert ep.count.call_count == 2 + ep.count.assert_any_call(device_type_id=ids[:100]) + ep.count.assert_any_call(device_type_id=ids[100:]) + + def test_empty_ids_returns_zero(self, mock_settings, mock_pynetbox, graphql_client, make_device_types): + """Empty ID list returns 0 without making any REST calls.""" + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + + ep = MagicMock() + + result = dt._rest_count_chunked(ep, "device_type_id", []) + assert result == 0 + ep.count.assert_not_called() + + +class TestPreloadAllComponentsCountCheck: + """preload_all_components calls count check and propagates GraphQLCountMismatchError.""" + + def test_count_check_called_when_manufacturer_slug_set( + self, mock_settings, mock_pynetbox, graphql_client, make_device_types + ): + """_check_component_counts_against_rest is called when manufacturer_slug is given.""" + from unittest.mock import patch as _patch + + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + + with _patch.object(dt, "_preload_global"): + with _patch.object(dt.graphql, "get_module_types", return_value={}): + with _patch.object(dt, "_check_component_counts_against_rest") as mock_check: + dt.preload_all_components(manufacturer_slug="cisco") + + mock_check.assert_called_once() + + def test_count_mismatch_propagates(self, mock_settings, mock_pynetbox, graphql_client, make_device_types): + """GraphQLCountMismatchError from count check is not swallowed.""" + from unittest.mock import patch as _patch + from core.graphql_client import GraphQLCountMismatchError + + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + + with _patch.object(dt, "_preload_global"): + with _patch.object(dt.graphql, "get_module_types", return_value={}): + with _patch.object( + dt, + "_check_component_counts_against_rest", + side_effect=GraphQLCountMismatchError("mismatch"), + ): + with pytest.raises(GraphQLCountMismatchError): + dt.preload_all_components(manufacturer_slug="cisco") + + def test_count_check_not_called_without_manufacturer_slug( + self, mock_settings, mock_pynetbox, graphql_client, make_device_types + ): + """When no manufacturer_slug, count check is not called (global preload).""" + from unittest.mock import patch as _patch + + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + + with _patch.object(dt, "_preload_global"): + with _patch.object(dt, "_check_component_counts_against_rest") as mock_check: + dt.preload_all_components() + + mock_check.assert_not_called() + + +# --------------------------------------------------------------------------- +# Tests for _check_image_url and _is_image_hash_changed helpers +# --------------------------------------------------------------------------- + + +class TestCheckImageUrl: + """Unit tests for the module-level _check_image_url helper.""" + + def _image_resp(self, ok=True, content_type="image/png"): + mock_resp = MagicMock() + mock_resp.ok = ok + mock_resp.headers = {"Content-Type": content_type} + return mock_resp + + def test_returns_missing_when_response_not_ok(self): + from core.netbox_api import _check_image_url + + with patch("requests.get", return_value=self._image_resp(ok=False)): + assert _check_image_url("http://nb", "/media/front.png", False) == "missing" + + def test_returns_ok_when_response_ok_and_image_content_type(self): + """2xx with image/* Content-Type is 'ok'.""" + from core.netbox_api import _check_image_url + + with patch("requests.get", return_value=self._image_resp(ok=True, content_type="image/png")): + assert _check_image_url("http://nb", "/media/front.png", False) == "ok" + + def test_returns_missing_when_ok_but_html_content_type(self): + """2xx with text/html means a login-redirect / missing-file error page → 'missing'.""" + from core.netbox_api import _check_image_url + + with patch("requests.get", return_value=self._image_resp(ok=True, content_type="text/html; charset=utf-8")): + assert _check_image_url("http://nb", "/media/front.png", False) == "missing" + + def test_returns_missing_when_ok_but_json_content_type(self): + """2xx with application/json (API error body) → 'missing'.""" + from core.netbox_api import _check_image_url + + with patch("requests.get", return_value=self._image_resp(ok=True, content_type="application/json")): + assert _check_image_url("http://nb", "/media/front.png", False) == "missing" + + def test_returns_ok_on_network_error(self): + from core.netbox_api import _check_image_url + import requests as _req + + with patch("requests.get", side_effect=_req.RequestException("timeout")): + assert _check_image_url("http://nb", "/media/front.png", False) == "ok" + + def test_uses_full_url_when_image_url_is_absolute(self): + """If image_url_path starts with 'http', base_url is not prepended.""" + from core.netbox_api import _check_image_url + + with patch("requests.get", return_value=self._image_resp(ok=False)) as mock_get: + _check_image_url("http://nb", "http://other/media/front.png", False) + mock_get.assert_called_once_with("http://other/media/front.png", headers={}, verify=True, timeout=30) + + def test_prepends_base_url_for_relative_path(self): + """Relative image_url_path has base_url prepended before the GET.""" + from core.netbox_api import _check_image_url + + with patch("requests.get", return_value=self._image_resp(ok=True)) as mock_get: + _check_image_url("http://nb", "/media/front.png", False) + mock_get.assert_called_once_with("http://nb/media/front.png", headers={}, verify=True, timeout=30) + + def test_sends_token_auth_header_when_token_provided(self): + """Standard token → Authorization: Token .""" + from core.netbox_api import _check_image_url + + with patch("requests.get", return_value=self._image_resp(ok=True)) as mock_get: + _check_image_url("http://nb", "/media/front.png", False, token="mytoken") + mock_get.assert_called_once_with( + "http://nb/media/front.png", + headers={"Authorization": "Token mytoken"}, + verify=True, + timeout=30, + ) + + def test_sends_bearer_auth_for_nbt_token(self): + """nbt_… token → Authorization: Bearer (not Token).""" + from core.netbox_api import _check_image_url + + with patch("requests.get", return_value=self._image_resp(ok=True)) as mock_get: + _check_image_url("http://nb", "/media/front.png", False, token="nbt_abc123") + auth = mock_get.call_args.kwargs["headers"]["Authorization"] + assert auth == "Bearer nbt_abc123" + + def test_no_auth_header_for_off_host_url(self): + """When image URL is on a different host, auth must NOT be sent.""" + from core.netbox_api import _check_image_url + + with patch("requests.get", return_value=self._image_resp(ok=False)) as mock_get: + _check_image_url("http://nb", "http://cdn.other.example/img.png", False, token="secret") + assert "Authorization" not in mock_get.call_args.kwargs.get("headers", {}) + + +class TestIsImageHashChanged: + """Unit tests for the _is_image_hash_changed helper.""" + + def test_returns_false_when_no_cache_entry(self, tmp_path): + from core.netbox_api import _is_image_hash_changed + + img = tmp_path / "front.png" + img.write_bytes(b"content") + assert _is_image_hash_changed(str(img), {}) is False + + def test_returns_false_when_hash_matches(self, tmp_path): + import hashlib + from core.netbox_api import _is_image_hash_changed + + data = b"unchanged_content" + img = tmp_path / "front.png" + img.write_bytes(data) + cache = {str(img): hashlib.sha256(data).hexdigest()} + assert _is_image_hash_changed(str(img), cache) is False + + def test_returns_true_when_hash_differs(self, tmp_path): + import hashlib + from core.netbox_api import _is_image_hash_changed + + img = tmp_path / "front.png" + img.write_bytes(b"new_content") + cache = {str(img): hashlib.sha256(b"old_content").hexdigest()} + assert _is_image_hash_changed(str(img), cache) is True + + def test_returns_false_on_file_read_error(self, tmp_path): + import hashlib + from core.netbox_api import _is_image_hash_changed + + path = str(tmp_path / "missing.png") + cache = {path: hashlib.sha256(b"x").hexdigest()} + assert _is_image_hash_changed(path, cache) is False + + +# --------------------------------------------------------------------------- +# Tests for verify_images flag in _process_existing_device_type +# --------------------------------------------------------------------------- + + +class TestVerifyImagesDeviceType: + """Integration tests for --verify-images behaviour in _process_existing_device_type.""" + + def _make_nb(self, mock_settings, mock_pynetbox, graphql_client, make_device_types): + mock_nb_api = mock_pynetbox.api.return_value + dt = make_device_types(nb_api=mock_nb_api) + nb = NetBox(mock_settings, mock_settings.handle) + nb.device_types = dt + return nb + + def test_verify_images_reuploads_missing_image( + self, mock_settings, mock_pynetbox, graphql_client, make_device_types, tmp_path + ): + """When verify_images=True and image is missing on server, upload_images is called.""" + nb = self._make_nb(mock_settings, mock_pynetbox, graphql_client, make_device_types) + nb.verify_images = True + nb.device_types.upload_images = MagicMock() + + existing_dt = MagicMock() + existing_dt.id = 1 + existing_dt.model = "Router" + existing_dt.manufacturer.name = "Cisco" + existing_dt.front_image = "/media/router.front.png" + nb.device_types.existing_device_types = {("cisco", "Router"): existing_dt} + nb.device_types.existing_device_types_by_slug = {} + + dev_types_dir = tmp_path / "device-types" / "cisco" + dev_types_dir.mkdir(parents=True) + elevation_dir = tmp_path / "elevation-images" / "cisco" + elevation_dir.mkdir(parents=True) + img = elevation_dir / "router.front.png" + img.write_bytes(b"imgdata") + + device_type = { + "manufacturer": {"slug": "cisco"}, + "model": "Router", + "slug": "router", + "front_image": True, + "src": str(dev_types_dir / "router.yaml"), + } + + mock_resp = MagicMock() + mock_resp.ok = False # image missing on server → "missing" + with patch("glob.glob", return_value=[str(img)]): + with patch("requests.get", return_value=mock_resp): + nb.create_device_types([device_type]) + + nb.device_types.upload_images.assert_called_once() + + def test_verify_images_skips_ok_image( + self, mock_settings, mock_pynetbox, graphql_client, make_device_types, tmp_path + ): + """When verify_images=True and image is accessible and hash unchanged, upload is skipped.""" + import hashlib + + nb = self._make_nb(mock_settings, mock_pynetbox, graphql_client, make_device_types) + nb.verify_images = True + nb.device_types.upload_images = MagicMock() + + img_data = b"imgdata" + existing_dt = MagicMock() + existing_dt.id = 1 + existing_dt.model = "Switch" + existing_dt.manufacturer.name = "Cisco" + existing_dt.front_image = "/media/switch.front.png" + nb.device_types.existing_device_types = {("cisco", "Switch"): existing_dt} + nb.device_types.existing_device_types_by_slug = {} + + dev_types_dir = tmp_path / "device-types" / "cisco" + dev_types_dir.mkdir(parents=True) + elevation_dir = tmp_path / "elevation-images" / "cisco" + elevation_dir.mkdir(parents=True) + img = elevation_dir / "switch.front.png" + img.write_bytes(img_data) + # Pre-populate cache so hash check reports "unchanged" + nb._image_hash_cache[str(img)] = hashlib.sha256(img_data).hexdigest() + + device_type = { + "manufacturer": {"slug": "cisco"}, + "model": "Switch", + "slug": "switch", + "front_image": True, + "src": str(dev_types_dir / "switch.yaml"), + } + + mock_resp = MagicMock() + mock_resp.ok = True # image accessible on server + mock_resp.headers = {"Content-Type": "image/png"} + with patch("glob.glob", return_value=[str(img)]): + with patch("requests.get", return_value=mock_resp): + nb.create_device_types([device_type]) + + nb.device_types.upload_images.assert_not_called() + + def test_default_mode_still_skips_without_http( + self, mock_settings, mock_pynetbox, graphql_client, make_device_types, tmp_path + ): + """When verify_images=False (default), existing images are skipped without HTTP check.""" + nb = self._make_nb(mock_settings, mock_pynetbox, graphql_client, make_device_types) + # verify_images defaults to False + nb.device_types.upload_images = MagicMock() + + existing_dt = MagicMock() + existing_dt.id = 1 + existing_dt.model = "AP" + existing_dt.manufacturer.name = "Aruba" + existing_dt.front_image = "/media/ap.front.png" + nb.device_types.existing_device_types = {("aruba", "AP"): existing_dt} + nb.device_types.existing_device_types_by_slug = {} + + dev_types_dir = tmp_path / "device-types" / "aruba" + dev_types_dir.mkdir(parents=True) + elevation_dir = tmp_path / "elevation-images" / "aruba" + elevation_dir.mkdir(parents=True) + img = elevation_dir / "ap.front.png" + img.write_bytes(b"imgdata") + + device_type = { + "manufacturer": {"slug": "aruba"}, + "model": "AP", + "slug": "ap", + "front_image": True, + "src": str(dev_types_dir / "ap.yaml"), + } + + with patch("glob.glob", return_value=[str(img)]): + with patch("requests.get") as mock_get: + nb.create_device_types([device_type]) + + mock_get.assert_not_called() + nb.device_types.upload_images.assert_not_called() + + def test_verify_ok_seeds_hash_cache( + self, mock_settings, mock_pynetbox, graphql_client, make_device_types, tmp_path + ): + """When verify-images reports OK, the local image hash must be written to cache.""" + import hashlib + + nb = self._make_nb(mock_settings, mock_pynetbox, graphql_client, make_device_types) + nb.verify_images = True + nb.device_types.upload_images = MagicMock() + + img_data = b"unchanged" + existing_dt = MagicMock() + existing_dt.id = 2 + existing_dt.model = "Firewall" + existing_dt.manufacturer.name = "Palo Alto Networks" + existing_dt.front_image = "/media/fw.front.png" + nb.device_types.existing_device_types = {("palo-alto-networks", "Firewall"): existing_dt} + nb.device_types.existing_device_types_by_slug = {} + + dev_types_dir = tmp_path / "device-types" / "palo-alto-networks" + dev_types_dir.mkdir(parents=True) + elevation_dir = tmp_path / "elevation-images" / "palo-alto-networks" + elevation_dir.mkdir(parents=True) + img = elevation_dir / "palo-alto-networks-firewall.front.png" + img.write_bytes(img_data) + # Cache is empty — this is the first verify after a devcontainer rebuild + nb._image_hash_cache = {} + + device_type = { + "manufacturer": {"slug": "palo-alto-networks"}, + "model": "Firewall", + "slug": "palo-alto-networks-firewall", + "front_image": True, + "src": str(dev_types_dir / "firewall.yaml"), + } + + mock_resp = MagicMock() + mock_resp.ok = True + mock_resp.headers = {"Content-Type": "image/png"} + with patch("glob.glob", return_value=[str(img)]): + with patch("requests.get", return_value=mock_resp): + with patch("core.netbox_api._save_image_hash_cache") as mock_save: + nb.create_device_types([device_type]) + + # Hash cache must have been seeded with the local file's hash + assert str(img) in nb._image_hash_cache + assert nb._image_hash_cache[str(img)] == hashlib.sha256(img_data).hexdigest() + mock_save.assert_called() + nb.device_types.upload_images.assert_not_called() + + +class TestNetBoxImageHelperFunctions: + """Tests for standalone image helper functions in netbox_api.""" + + def test_save_image_hash_cache_ignores_open_errors(self, tmp_path): + from core.netbox_api import _save_image_hash_cache + + with patch("core.netbox_api.tempfile.mkstemp", side_effect=OSError("disk full")): + result = _save_image_hash_cache(str(tmp_path / "cache.json"), {"a": "b"}) + + assert result is False + + def test_store_image_hashes_skips_unreadable_files(self, tmp_path): + from core.netbox_api import _store_image_hashes + + good = tmp_path / "good.png" + good.write_bytes(b"img") + missing = tmp_path / "missing.png" + cache = {} + + _store_image_hashes(cache, {"good": str(good), "missing": str(missing)}) + + assert str(good) in cache + assert str(missing) not in cache + + def test_delete_image_attachment_success(self, mock_settings): + from core.netbox_api import _delete_image_attachment + + response = MagicMock() + response.raise_for_status = MagicMock() + with patch("core.netbox_api.requests.delete", return_value=response) as mock_delete: + result = _delete_image_attachment("http://nb", mock_settings.NETBOX_TOKEN, 12, False, mock_settings.handle) + assert result is True + + mock_delete.assert_called_once() + + def test_delete_image_attachment_logs_request_errors(self, mock_settings): + import requests + from core.netbox_api import _delete_image_attachment + + with patch("core.netbox_api.requests.delete", side_effect=requests.RequestException("boom")): + result = _delete_image_attachment("http://nb", mock_settings.NETBOX_TOKEN, 12, False, mock_settings.handle) + assert result is False + + assert any( + "Error deleting image attachment 12" in str(call) for call in mock_settings.handle.log.call_args_list + ) + + def test_load_module_type_properties_falls_back_on_import_error(self): + from core.netbox_api import _MODULE_TYPE_PROPERTIES_FALLBACK, _load_module_type_properties + + _load_module_type_properties.cache_clear() + with patch("core.netbox_api.load_properties_for_type", side_effect=ImportError("no settings")): + assert _load_module_type_properties() == list(_MODULE_TYPE_PROPERTIES_FALLBACK) + _load_module_type_properties.cache_clear() + + def test_fmt_connection_error_contains_url_and_hint(self): + """_fmt_connection_error returns a message with the URL and a reachability hint.""" + import requests as _requests + from core.netbox_api import _fmt_connection_error + + url = "http://netbox.example.com" + exc = _requests.exceptions.ConnectionError("Remote end closed connection without response") + msg = _fmt_connection_error(url, exc) + + assert url in msg + assert "Remote end closed" in msg + + def test_fmt_connection_error_verify_compatibility_uses_it(self, mock_settings, mock_pynetbox): + """verify_compatibility uses _fmt_connection_error for ConnectionError.""" + import requests as _requests + from unittest.mock import PropertyMock + + type(mock_pynetbox.api.return_value).version = PropertyMock( + side_effect=_requests.exceptions.ConnectionError("drop") + ) + + with pytest.raises(SystemExit) as exc_info: + NetBox(mock_settings, mock_settings.handle) + + exc_msg = str(exc_info.value.args[0]) if exc_info.value.args else "" + assert mock_settings.NETBOX_URL in exc_msg and "Connection" in exc_msg + + def test_check_image_url_logs_request_exception_when_log_fn_provided(self): + """RequestException must be logged at verbose level when log_fn is given.""" + import requests as _requests + from core.netbox_api import _check_image_url + + logged = [] + with patch("core.netbox_api.requests.get", side_effect=_requests.RequestException("timed out")): + result = _check_image_url("http://nb", "/media/img.png", False, log_fn=logged.append) + + assert result == "ok" # conservative: treat as present + assert any("timed out" in m or "Network error" in m for m in logged) + + def test_check_image_url_no_log_fn_stays_silent_on_request_exception(self): + """When no log_fn provided, RequestException is swallowed silently.""" + import requests as _requests + from core.netbox_api import _check_image_url + + with patch("core.netbox_api.requests.get", side_effect=_requests.RequestException("x")): + # Must not raise + result = _check_image_url("http://nb", "/media/img.png", False) + assert result == "ok" + + """Tests for _upload_module_type_images with verify_images=True.""" + + def _make_nb(self, mock_settings, mock_pynetbox): + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + nb.device_types.upload_image_attachment = MagicMock(return_value=True) + nb.verify_images = True + return nb + + def _make_module_files(self, tmp_path): + module_dir = tmp_path / "module-types" / "vendor" + module_dir.mkdir(parents=True) + src = module_dir / "mymodule.yaml" + src.write_text("model: X") + img_dir = tmp_path / "module-images" / "vendor" + img_dir.mkdir(parents=True) + img = img_dir / "mymodule.front.jpg" + img.write_bytes(b"img") + return src, img + + def test_missing_server_image_skips_upload_when_delete_fails(self, mock_settings, mock_pynetbox, tmp_path): + nb = self._make_nb(mock_settings, mock_pynetbox) + src, _img = self._make_module_files(tmp_path) + mt_res = MagicMock(id=10, model="X") + existing_images = {10: {"mymodule.front"}} + nb._module_image_details = {10: {"mymodule.front": {"url": "/media/front.jpg", "att_id": 7}}} + + with ( + patch("core.netbox_api._check_image_url", return_value="missing"), + patch("core.netbox_api._delete_image_attachment", return_value=False), + ): + nb._upload_module_type_images(mt_res, str(src), existing_images) + + nb.device_types.upload_image_attachment.assert_not_called() + assert "mymodule.front" in existing_images[10] + + def test_missing_server_image_skips_upload_when_att_id_invalid(self, mock_settings, mock_pynetbox, tmp_path): + """Guard: att_id is None/non-int — skip upload to avoid duplicate attachments.""" + nb = self._make_nb(mock_settings, mock_pynetbox) + src, _img = self._make_module_files(tmp_path) + mt_res = MagicMock(id=10, model="X") + existing_images = {10: {"mymodule.front"}} + nb._module_image_details = {10: {"mymodule.front": {"url": "/media/front.jpg", "att_id": None}}} + + with patch("core.netbox_api._check_image_url", return_value="missing"): + nb._upload_module_type_images(mt_res, str(src), existing_images) + + nb.device_types.upload_image_attachment.assert_not_called() + + def test_changed_hash_skips_upload_when_att_id_invalid(self, mock_settings, mock_pynetbox, tmp_path): + """Guard: att_id is None when hash changed — skip upload to avoid duplicate attachments.""" + nb = self._make_nb(mock_settings, mock_pynetbox) + src, _img = self._make_module_files(tmp_path) + mt_res = MagicMock(id=10, model="X") + existing_images = {10: {"mymodule.front"}} + nb._module_image_details = {10: {"mymodule.front": {"url": "/media/front.jpg", "att_id": None}}} + + with ( + patch("core.netbox_api._check_image_url", return_value="ok"), + patch("core.netbox_api._is_image_hash_changed", return_value=True), + ): + nb._upload_module_type_images(mt_res, str(src), existing_images) + + nb.device_types.upload_image_attachment.assert_not_called() + + def test_changed_module_image_is_deleted_and_reuploaded(self, mock_settings, mock_pynetbox, tmp_path): + nb = self._make_nb(mock_settings, mock_pynetbox) + src, _img = self._make_module_files(tmp_path) + mt_res = MagicMock(id=10, model="X") + existing_images = {10: {"mymodule.front"}} + nb._module_image_details = {10: {"mymodule.front": {"url": "/media/front.jpg", "att_id": 7}}} + + with ( + patch("core.netbox_api._check_image_url", return_value="ok"), + patch("core.netbox_api._is_image_hash_changed", return_value=True), + patch("core.netbox_api._delete_image_attachment", return_value=True), + patch("core.netbox_api._save_image_hash_cache"), + ): + nb._upload_module_type_images(mt_res, str(src), existing_images) + + nb.device_types.upload_image_attachment.assert_called_once() + assert "mymodule.front" in existing_images[10] + + def test_verified_module_image_seeds_cache_and_skips_upload(self, mock_settings, mock_pynetbox, tmp_path): + import hashlib + + nb = self._make_nb(mock_settings, mock_pynetbox) + src, img = self._make_module_files(tmp_path) + mt_res = MagicMock(id=10, model="X") + existing_images = {10: {"mymodule.front"}} + nb._module_image_details = {10: {"mymodule.front": {"url": "/media/front.jpg", "att_id": 7}}} + nb._image_hash_cache = {} + + with ( + patch("core.netbox_api._check_image_url", return_value="ok"), + patch("core.netbox_api._is_image_hash_changed", return_value=False), + patch("core.netbox_api._save_image_hash_cache") as mock_save, + ): + nb._upload_module_type_images(mt_res, str(src), existing_images) + + nb.device_types.upload_image_attachment.assert_not_called() + assert nb._image_hash_cache[str(img)] == hashlib.sha256(b"img").hexdigest() + mock_save.assert_called_once() + + def test_missing_attachment_detail_skips_upload_to_avoid_duplicates(self, mock_settings, mock_pynetbox, tmp_path): + """When detail is unavailable (verify_images=True), skip upload to avoid creating duplicates.""" + nb = self._make_nb(mock_settings, mock_pynetbox) + src, _img = self._make_module_files(tmp_path) + mt_res = MagicMock(id=10, model="X") + existing_images = {10: {"mymodule.front"}} + nb._module_image_details = {10: {}} + + with patch("core.netbox_api._save_image_hash_cache"): + nb._upload_module_type_images(mt_res, str(src), existing_images) + + nb.device_types.upload_image_attachment.assert_not_called() + assert any("detail is unavailable" in str(c) for c in mock_settings.handle.verbose_log.call_args_list) + + +class TestAdditionalNetBoxCoverage: + """Focused tests for uncovered NetBox branches.""" + + def test_load_image_hash_cache_returns_empty_dict_on_bad_json(self, tmp_path): + from core.netbox_api import _load_image_hash_cache + + cache_file = tmp_path / "image-hashes.json" + cache_file.write_text("{not-json", encoding="utf-8") + + assert _load_image_hash_cache(str(cache_file)) == {} + + def test_init_exits_on_graphql_error_from_get_manufacturers(self, mock_settings, mock_pynetbox): + from core.graphql_client import GraphQLError + + mock_pynetbox.api.return_value.version = "3.5" + + with patch.object(NetBox, "get_manufacturers", side_effect=GraphQLError("bad query")): + with pytest.raises(SystemExit, match="GraphQL error: bad query"): + NetBox(mock_settings, mock_settings.handle) + + def test_init_exits_when_device_types_initialization_fails(self, mock_settings, mock_pynetbox): + mock_pynetbox.api.return_value.version = "3.5" + + with ( + patch.object(NetBox, "get_manufacturers", return_value=[]), + patch("core.netbox_api.DeviceTypes", side_effect=RuntimeError("boom")), + ): + with pytest.raises(SystemExit, match="Error initializing device types: boom"): + NetBox(mock_settings, mock_settings.handle) + + def test_verify_compatibility_exits_on_proxy_error(self, mock_settings): + import requests + + class BadAPI: + @property + def version(self): + raise requests.exceptions.ProxyError("proxy down") + + nb = NetBox.__new__(NetBox) + nb.url = mock_settings.NETBOX_URL + nb.handle = mock_settings.handle + nb.netbox = BadAPI() + + with pytest.raises(SystemExit, match="Proxy error while connecting to NetBox"): + nb.verify_compatibility() + + def test_verify_compatibility_formats_request_error_details(self, mock_settings): + import pynetbox as real_pynb + + request = MagicMock(status_code=502, reason="Bad Gateway") + error = real_pynb.RequestError(MagicMock(status_code=502, content=b'{"detail":"proxy fail"}')) + error.base = "http://upstream" + error.req = request + error.error = "proxy fail" + + class BadAPI: + @property + def version(self): + raise error + + nb = NetBox.__new__(NetBox) + nb.url = mock_settings.NETBOX_URL + nb.handle = mock_settings.handle + nb.netbox = BadAPI() + + with pytest.raises(SystemExit) as exc_info: + nb.verify_compatibility() + + message = str(exc_info.value) + assert "http://upstream" in message + assert "HTTP 502 Bad Gateway" in message + assert "proxy fail" in message + assert "not blocked by a proxy" in message + + def test_create_manufacturers_logs_retryable_exception(self, mock_settings, mock_pynetbox): + import pynetbox as real_pynb + import requests + + mock_pynetbox.RequestError = real_pynb.RequestError + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + nb.netbox.dcim.manufacturers.create.side_effect = requests.exceptions.ConnectionError("offline") + + with patch("core.netbox_api.time.sleep"): + nb.create_manufacturers([{"name": "Cisco", "slug": "cisco"}]) + + assert any("Connection error creating manufacturers" in str(c) for c in mock_settings.handle.log.call_args_list) + + def test_try_resolve_update_logs_classifier_exception(self, mock_settings, mock_pynetbox): + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + dt = MagicMock(id=1, model="Model-1") + + with patch("core.netbox_api.classify_device_type_update_failure", side_effect=RuntimeError("boom")): + ok, resolution = nb._try_resolve_and_retry_device_type_update(dt, {}, {"slug": "m1"}, MagicMock()) + + assert ok is False + assert resolution is None + assert any( + "Failure classifier raised RuntimeError: boom" in str(c) + for c in mock_settings.handle.verbose_log.call_args_list + ) + + def test_try_resolve_update_truncates_blocker_list(self, mock_settings, mock_pynetbox): + from types import SimpleNamespace + from core.update_failure_resolver import FailureKind + + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + dt = MagicMock(id=1, model="Model-1") + resolution = SimpleNamespace( + kind=FailureKind.MANUAL_REQUIRED, + is_actionable=False, + blocking_objects=[f"obj-{i}" for i in range(11)], + description="blocked", + hint="resolve manually", + remediation_steps=[], + ) + + with patch("core.netbox_api.classify_device_type_update_failure", return_value=resolution): + ok, returned = nb._try_resolve_and_retry_device_type_update(dt, {}, {"slug": "m1"}, MagicMock()) + + assert ok is False + assert returned is resolution + assert any("… (+1 more)" in str(c) for c in mock_settings.handle.log.call_args_list) + + def test_try_resolve_update_logs_auto_resolve_failure(self, mock_settings, mock_pynetbox): + from types import SimpleNamespace + from core.update_failure_resolver import FailureKind + + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + nb.force_resolve_conflicts = True + dt = MagicMock(id=1, model="Model-1") + + def boom(): + raise RuntimeError("step failed") + + resolution = SimpleNamespace( + kind=FailureKind.SUBDEVICE_ROLE_FLIP, + is_actionable=True, + blocking_objects=[], + description="blocked", + hint="retry", + remediation_steps=[boom], + ) + + with patch("core.netbox_api.classify_device_type_update_failure", return_value=resolution): + ok, returned = nb._try_resolve_and_retry_device_type_update(dt, {}, {"slug": "m1"}, MagicMock()) + + assert ok is False + assert returned is resolution + assert any( + "Auto-resolve failed for Model-1: step failed" in str(c) for c in mock_settings.handle.log.call_args_list + ) + + def test_try_resolve_update_logs_retryable_exception_after_auto_resolve(self, mock_settings, mock_pynetbox): + import pynetbox as real_pynb + import requests + from types import SimpleNamespace + from core.update_failure_resolver import FailureKind + + mock_pynetbox.RequestError = real_pynb.RequestError + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + nb.force_resolve_conflicts = True + dt = MagicMock(id=1, model="Model-1") + nb.netbox.dcim.device_types.update.side_effect = requests.exceptions.ConnectionError("offline") + resolution = SimpleNamespace( + kind=FailureKind.SUBDEVICE_ROLE_FLIP, + is_actionable=True, + blocking_objects=[], + description="blocked", + hint="retry", + remediation_steps=[MagicMock()], + ) + + with ( + patch("core.netbox_api.classify_device_type_update_failure", return_value=resolution), + patch("core.netbox_api.time.sleep"), + ): + ok, returned = nb._try_resolve_and_retry_device_type_update(dt, {}, {"slug": "m1"}, MagicMock()) + + assert ok is False + assert returned is resolution + assert any( + "Connection error during retry after auto-resolve" in str(c) + for c in mock_settings.handle.log.call_args_list + ) + + def test_log_device_type_change_outcome_partial_success_mentions_property_failure( + self, mock_settings, mock_pynetbox + ): + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + dt = MagicMock(id=1, model="Model-1") + dt.manufacturer.name = "Cisco" + + nb._log_device_type_change_outcome( + dt, + MagicMock(), + property_attempted=True, + property_succeeded=False, + component_delta=2, + actionable_count=2, + ) + + assert any("Property PATCH failed" in str(c) for c in mock_settings.handle.verbose_log.call_args_list) + + def test_log_device_type_change_outcome_logs_cached_when_nothing_happened(self, mock_settings, mock_pynetbox): + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + dt = MagicMock(id=1, model="Model-1") + dt.manufacturer.name = "Cisco" + + nb._log_device_type_change_outcome( + dt, + MagicMock(), + property_attempted=False, + property_succeeded=False, + component_delta=0, + actionable_count=0, + ) + + assert any("Device Type Cached" in str(c) for c in mock_settings.handle.verbose_log.call_args_list) + + def test_filter_images_for_upload_keeps_changed_image(self, mock_settings, mock_pynetbox, tmp_path): + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + nb.verify_images = True + + image_path = tmp_path / "router.front.png" + image_path.write_bytes(b"img") + dt = MagicMock(model="Router", front_image="/media/router.front.png") + saved_images = {"front_image": str(image_path)} + + with ( + patch("core.netbox_api._check_image_url", return_value="ok"), + patch("core.netbox_api._is_image_hash_changed", return_value=True), + ): + nb._filter_images_for_upload(dt, saved_images) + + assert "front_image" in saved_images + assert any("content has changed" in str(c) for c in mock_settings.handle.verbose_log.call_args_list) + + def test_handle_existing_device_type_logs_retryable_property_update_error(self, mock_settings, mock_pynetbox): + import pynetbox as real_pynb + import requests + from core.change_detector import PropertyChange + + mock_pynetbox.RequestError = real_pynb.RequestError + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + nb.netbox.dcim.device_types.update.side_effect = requests.exceptions.ConnectionError("offline") + nb._log_device_type_change_outcome = MagicMock() + + dt = MagicMock(id=1, model="Router") + dt.manufacturer.name = "Cisco" + dt_change = MagicMock( + property_changes=[PropertyChange("part_number", "OLD", "NEW")], + component_changes=[], + ) + + with patch("core.netbox_api.time.sleep"): + nb._handle_existing_device_type(dt, {}, "cisco", {}, False, dt_change, False) + + assert any( + "Connection error updating device type Router" in str(c) for c in mock_settings.handle.log.call_args_list + ) + nb._log_device_type_change_outcome.assert_called_once() + + def test_create_new_device_type_logs_retryable_error(self, mock_settings, mock_pynetbox): + import pynetbox as real_pynb + import requests + + mock_pynetbox.RequestError = real_pynb.RequestError + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + nb.netbox.dcim.device_types.create.side_effect = requests.exceptions.ConnectionError("offline") + + with patch("core.netbox_api.time.sleep"): + dt, should_continue = nb._create_new_device_type( + {"manufacturer": {"slug": "cisco"}, "model": "Router"}, + "router.yaml", + ) + + assert dt is None + assert should_continue is True + assert any( + "Connection error creating device type cisco Router" in str(c) + for c in mock_settings.handle.log.call_args_list + ) + + def test_log_module_property_diffs_emits_added_changed_and_removed_components(self, mock_settings, mock_pynetbox): + from core.change_detector import ChangeType, ComponentChange, PropertyChange + + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + changes = [ + ComponentChange("interfaces", "xe-0", ChangeType.COMPONENT_ADDED), + ComponentChange( + "interfaces", + "xe-1", + ChangeType.COMPONENT_CHANGED, + [PropertyChange("label", "old", "new")], + ), + ComponentChange("interfaces", "xe-2", ChangeType.COMPONENT_REMOVED), + ] + + nb._log_module_property_diffs("cisco", "LC", [], changes) + + logs = [str(c) for c in mock_settings.handle.verbose_log.call_args_list] + assert any("+ 1 new component(s)" in c for c in logs) + assert any("~ 1 changed component(s)" in c for c in logs) + assert any("- 1 component(s) present in NetBox but absent from YAML" in c for c in logs) + + def test_filter_actionable_module_types_marks_verify_images_module_actionable(self, mock_settings, mock_pynetbox): + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + nb.verify_images = True + nb.device_types._global_preload_done = True + nb._fetch_module_type_existing_images = MagicMock(return_value={42: {"linecard.front"}}) + nb._discover_module_image_files = MagicMock(return_value=["/repo/module-images/linecard.front.png"]) + nb.change_detector._compare_components = MagicMock(return_value=[]) + + existing_module = MagicMock(id=42, model="LC") + existing_module.manufacturer.slug = "cisco" + all_module_types = {"cisco": {"LC": existing_module}} + module_type = {"manufacturer": {"slug": "cisco"}, "model": "LC", "src": "linecard.yaml"} + + actionable, _, _ = nb.filter_actionable_module_types([module_type], all_module_types, only_new=False) + + assert actionable == [module_type] + + def test_fetch_module_type_existing_images_uses_detailed_query_in_verify_mode(self, mock_settings, mock_pynetbox): + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + nb.verify_images = True + details = {7: {"linecard.front": {"att_id": 5, "url": "/media/linecard.front.jpg"}}} + nb.graphql.get_module_type_image_details = MagicMock(return_value=details) + + result = nb._fetch_module_type_existing_images() + + assert result == {7: {"linecard.front"}} + assert nb._module_image_details == details + + def test_try_update_module_type_skips_missing_netbox_fields(self, mock_settings, mock_pynetbox): + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + module_type_res = MagicMock(spec=["id", "manufacturer", "model"]) + module_type_res.id = 1 + module_type_res.manufacturer = MagicMock(name="Cisco") + module_type_res.manufacturer.name = "Cisco" + module_type_res.model = "LC" + + ok, updated = nb._try_update_module_type({"part_number": "NEW"}, module_type_res, "test.yaml") + + assert ok is True + assert updated is False + nb.netbox.dcim.module_types.update.assert_not_called() + + +class TestAdditionalModuleTypeCoverage: + """Focused tests for uncovered module-type branches.""" + + def test_apply_module_type_component_updates_records_failed_no_actionable_changes( + self, mock_settings, mock_pynetbox + ): + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + module_type_res = MagicMock(id=7, model="LC") + module_type_res.manufacturer.name = "Cisco" + nb.change_detector._compare_components = MagicMock(return_value=[]) + + nb._apply_module_type_component_updates({}, module_type_res, False, False, patch_ok=False) + + assert nb.counter["module_update_failed"] == 1 + assert nb.outcomes.failures()[0].reason == "Scalar PATCH failed; no component changes detected." + + def test_apply_module_type_component_updates_records_failed_when_no_changes_apply( + self, mock_settings, mock_pynetbox + ): + from core.change_detector import ChangeType, ComponentChange + + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + module_type_res = MagicMock(id=7, model="LC") + module_type_res.manufacturer.name = "Cisco" + nb.device_types._global_preload_done = True + nb.change_detector._compare_components = MagicMock( + return_value=[ComponentChange("interfaces", "xe-0", ChangeType.COMPONENT_REMOVED)] + ) + + nb._apply_module_type_component_updates({}, module_type_res, False, False, patch_ok=False) + + assert nb.counter["module_update_failed"] == 1 + assert nb.outcomes.failures()[0].reason == "Scalar PATCH failed; no component changes were actionable." + + def test_apply_module_type_component_updates_marks_partial_when_properties_only_succeed( + self, mock_settings, mock_pynetbox + ): + from core.change_detector import ChangeType, ComponentChange + + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + module_type_res = MagicMock(id=7, model="LC") + module_type_res.manufacturer.name = "Cisco" + nb.device_types._global_preload_done = True + nb.device_types.update_components = MagicMock() + nb.change_detector._compare_components = MagicMock( + return_value=[ComponentChange("interfaces", "xe-0", ChangeType.COMPONENT_CHANGED)] + ) + + nb._apply_module_type_component_updates({}, module_type_res, True, False, patch_ok=True) + + assert nb.counter["module_partial_update"] == 1 + + def test_apply_module_type_component_updates_marks_partial_on_partial_component_success( + self, mock_settings, mock_pynetbox + ): + from core.change_detector import ChangeType, ComponentChange + + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + module_type_res = MagicMock(id=7, model="LC") + module_type_res.manufacturer.name = "Cisco" + nb.device_types._global_preload_done = True + + def update_some(*_args, **_kwargs): + nb.counter["components_updated"] += 1 + + nb.device_types.update_components = MagicMock(side_effect=update_some) + nb.change_detector._compare_components = MagicMock( + return_value=[ + ComponentChange("interfaces", "xe-0", ChangeType.COMPONENT_CHANGED), + ComponentChange("interfaces", "xe-1", ChangeType.COMPONENT_CHANGED), + ] + ) + + nb._apply_module_type_component_updates({}, module_type_res, False, False, patch_ok=True) + + assert nb.counter["module_partial_update"] == 1 + + +class TestAdditionalDeviceTypesCoverage: + """Focused tests for uncovered DeviceTypes branches.""" + + def test_get_device_types_delegates_to_graphql(self, make_device_types): + dt = make_device_types() + dt.graphql.get_device_types = MagicMock(return_value=({("cisco", "A"): 1}, {("cisco", "a"): 1})) + + assert dt.get_device_types() == ({("cisco", "A"): 1}, {("cisco", "a"): 1}) + + def test_start_component_preload_populates_task_registry(self, mock_pynetbox, make_device_types): + dt = make_device_types(nb_api=mock_pynetbox.api.return_value) + progress = MagicMock() + progress.add_task.side_effect = range(1, 20) + task_registry = {} + + preload_job = dt.start_component_preload(progress=progress, task_registry=task_registry) + + assert task_registry + assert preload_job["task_ids"] + dt.stop_component_preload(preload_job) + + def test_start_component_preload_shuts_down_executor_when_submit_fails(self, mock_pynetbox, make_device_types): + executor = MagicMock() + executor.submit.side_effect = RuntimeError("submit failed") + dt = make_device_types(nb_api=mock_pynetbox.api.return_value) + + with patch("core.netbox_api.concurrent.futures.ThreadPoolExecutor", return_value=executor): + with pytest.raises(RuntimeError, match="submit failed"): + dt.start_component_preload() + + executor.shutdown.assert_called_once_with(wait=False, cancel_futures=True) + + def test_stop_component_preload_swallows_progress_cleanup_errors(self): + progress = MagicMock() + progress.stop_task.side_effect = RuntimeError("gone") + preload_job = {"task_ids": {"interfaces": 1}, "owns_tasks": True} + + DeviceTypes.stop_component_preload(preload_job, progress=progress) + + def test_apply_progress_updates_ignores_zero_advance(self): + progress = MagicMock() + q = queue.Queue() + q.put(("interface_templates", 0)) + + result = DeviceTypes._apply_progress_updates(q, progress, {"interface_templates": 1}) + + assert result is False + progress.update.assert_not_called() + + def test_apply_progress_updates_rewinds_negative_advance(self): + progress = MagicMock() + progress.tasks = [MagicMock(id=1, completed=5)] + q = queue.Queue() + q.put(("interface_templates", -3)) + + result = DeviceTypes._apply_progress_updates(q, progress, {"interface_templates": 1}) + + assert result is True + progress.update.assert_called_once_with(1, completed=2) + + def test_drain_pending_rewinds_negative_progress_update(self, mock_pynetbox, make_device_types): + dt = make_device_types(nb_api=mock_pynetbox.api.return_value) + future = MagicMock() + done_states = [False, True] + future.done.side_effect = lambda: done_states.pop(0) + future.result.return_value = [] + progress = MagicMock() + progress.tasks = [MagicMock(id=1, completed=4)] + updates = queue.Queue() + updates.put(("interface_templates", -2)) + pending = {"interface_templates"} + records = {} + + dt._drain_pending( + pending, + {"interface_templates": future}, + progress, + {"interface_templates": 1}, + updates, + {}, + records, + ) + + assert progress.update.call_args_list[0].kwargs == {"completed": 2} + + def test_preload_global_creates_tasks_from_registry_when_missing(self, mock_pynetbox, make_device_types): + dt = make_device_types(nb_api=mock_pynetbox.api.return_value) + progress = MagicMock() + progress.add_task.return_value = 7 + task_registry = {} + preload_job = { + "executor": None, + "futures": {}, + "progress_updates": queue.Queue(), + "endpoint_totals": {}, + "task_ids": None, + "finished_endpoints": set(), + } + + with patch.object(dt, "_preload_track_progress", return_value={"interface_templates": []}): + dt._preload_global( + [("interface_templates", "Interfaces")], + preload_job=preload_job, + progress=progress, + task_registry=task_registry, + ) + + assert task_registry["Caching Interfaces"] == 7 + + def test_preload_module_type_components_wraps_front_ports_and_deduplicates_targets( + self, mock_pynetbox, make_device_types + ): + dt = make_device_types(nb_api=mock_pynetbox.api.return_value) + item = MagicMock() + item.module_type = MagicMock(id=5) + item.name = "FP1" + item.mappings = [] + dt.netbox.dcim.front_port_templates.filter.return_value = [item] + + dt.preload_module_type_components({5}, ["front-ports", "front-ports"]) + + cached = dt.cached_components["front_port_templates"][("module", 5)]["FP1"] + assert isinstance(cached, _FrontPortRecordWithMappings) + dt.netbox.dcim.front_port_templates.filter.assert_called_once() + + def test_create_generic_logs_retryable_exception(self, mock_pynetbox, make_device_types, mock_settings): + import pynetbox as real_pynb + import requests + + mock_pynetbox.RequestError = real_pynb.RequestError + dt = make_device_types(nb_api=mock_pynetbox.api.return_value) + dt.cached_components = {"interface_templates": {("device", 1): {}}} + endpoint = MagicMock() + endpoint.create.side_effect = requests.exceptions.ConnectionError("offline") + + with patch("core.netbox_api.time.sleep"): + dt._create_generic( + [{"name": "eth0"}], + 1, + endpoint, + "Interface", + cache_name="interface_templates", + ) + + assert any("Connection error creating Interface" in str(c) for c in mock_settings.handle.log.call_args_list) + + def test_apply_updates_for_type_logs_retryable_exception(self, mock_pynetbox, make_device_types, mock_settings): + import pynetbox as real_pynb + import requests + from core.change_detector import ChangeType, ComponentChange, PropertyChange + + mock_pynetbox.RequestError = real_pynb.RequestError + dt = make_device_types(nb_api=mock_pynetbox.api.return_value) + existing = MagicMock(id=9, name="eth0") + dt.cached_components = {"interface_templates": {("device", 1): {"eth0": existing}}} + dt.netbox.dcim.interface_templates.update.side_effect = requests.exceptions.ConnectionError("offline") + changes = [ + ComponentChange( + "interfaces", + "eth0", + ChangeType.COMPONENT_CHANGED, + [PropertyChange("label", "old", "new")], + ) + ] + + with patch("core.netbox_api.time.sleep"): + dt._apply_updates_for_type("interfaces", changes, {}, 1, "device") + + assert any( + "Connection error updating interfaces (ID: 9)" in str(c) for c in mock_settings.handle.log.call_args_list + ) + + def test_remove_components_logs_retryable_exception(self, mock_pynetbox, make_device_types, mock_settings): + import pynetbox as real_pynb + import requests + from core.change_detector import ChangeType, ComponentChange + + mock_pynetbox.RequestError = real_pynb.RequestError + dt = make_device_types(nb_api=mock_pynetbox.api.return_value) + existing = MagicMock(id=9, name="eth0") + dt.cached_components = {"interface_templates": {("device", 1): {"eth0": existing}}} + dt.netbox.dcim.interface_templates.delete.side_effect = requests.exceptions.ConnectionError("offline") + changes = [ComponentChange("interfaces", "eth0", ChangeType.COMPONENT_REMOVED)] + + with patch("core.netbox_api.time.sleep"): + dt.remove_components(1, changes) + + assert any( + "Connection error removing interfaces (ID: 9)" in str(c) for c in mock_settings.handle.log.call_args_list + ) + + def test_create_interfaces_logs_retryable_bridge_update_exception( + self, mock_pynetbox, make_device_types, mock_settings + ): + import pynetbox as real_pynb + import requests + + mock_pynetbox.RequestError = real_pynb.RequestError + dt = make_device_types(nb_api=mock_pynetbox.api.return_value) + eth0 = MagicMock(id=10) + eth1 = MagicMock(id=20) + dt.cached_components = {"interface_templates": {("device", 1): {"eth0": eth0, "eth1": eth1}}} + dt.netbox.dcim.interface_templates.create.return_value = [] + dt.netbox.dcim.interface_templates.update.side_effect = requests.exceptions.ConnectionError("offline") + interfaces = [{"name": "eth0", "type": "virtual", "bridge": "eth1"}, {"name": "eth1", "type": "virtual"}] + + with patch("core.netbox_api.time.sleep"): + dt.create_interfaces(interfaces, 1, context="ctx.yaml") + + assert any("Connection error bridging interfaces" in str(c) for c in mock_settings.handle.log.call_args_list) + + def test_build_link_rear_ports_skips_empty_mappings(self, mock_pynetbox, make_device_types): + dt = make_device_types(nb_api=mock_pynetbox.api.return_value) + post_process = dt._build_link_rear_ports("device", "Front Port") + items = [{"name": "FP1", "type": "8p8c", "_mappings": []}] + + post_process(items, 1) + + assert items == [{"name": "FP1", "type": "8p8c"}] + + +class TestRemainingCoverageBranches: + """Tests for the last uncovered netbox_api branches.""" + + def test_create_rack_types_logs_retryable_update_error(self, mock_settings, mock_pynetbox): + import pynetbox as real_pynb + import requests + from core.graphql_client import DotDict + + mock_pynetbox.RequestError = real_pynb.RequestError + mock_pynetbox.api.return_value.version = "4.1" + mock_pynetbox.api.return_value.dcim.rack_types.update.side_effect = requests.exceptions.ConnectionError( + "offline" + ) + nb = NetBox(mock_settings, mock_settings.handle) + existing = DotDict({"id": 5, "model": "AR1300", "u_height": 40}) + + with patch("core.netbox_api.time.sleep"): + nb.create_rack_types( + [{"manufacturer": {"slug": "apc"}, "model": "AR1300", "u_height": 42, "src": "rack.yaml"}], + all_rack_types={"apc": {"AR1300": existing}}, + ) + + assert any( + "Connection error updating Rack Type AR1300" in str(c) for c in mock_settings.handle.log.call_args_list + ) + + def test_create_rack_types_logs_retryable_create_error(self, mock_settings, mock_pynetbox): + import pynetbox as real_pynb + import requests + + mock_pynetbox.RequestError = real_pynb.RequestError + mock_pynetbox.api.return_value.version = "4.1" + mock_pynetbox.api.return_value.dcim.rack_types.create.side_effect = requests.exceptions.ConnectionError( + "offline" + ) + nb = NetBox(mock_settings, mock_settings.handle) + + with patch("core.netbox_api.time.sleep"): + nb.create_rack_types( + [{"manufacturer": {"slug": "apc"}, "model": "AR1300", "slug": "apc-ar1300", "src": "rack.yaml"}], + all_rack_types={}, + ) + + assert any( + "Connection error creating Rack Type AR1300" in str(c) for c in mock_settings.handle.log.call_args_list + ) + + def test_upload_module_type_images_discards_missing_attachment_before_failed_upload( + self, mock_settings, mock_pynetbox, tmp_path + ): + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + nb.verify_images = True + nb.device_types.upload_image_attachment = MagicMock(return_value=False) + + module_dir = tmp_path / "module-types" / "vendor" + module_dir.mkdir(parents=True) + src = module_dir / "mymodule.yaml" + src.write_text("model: X") + image_dir = tmp_path / "module-images" / "vendor" + image_dir.mkdir(parents=True) + (image_dir / "mymodule.front.jpg").write_bytes(b"img") + + mt_res = MagicMock(id=10, model="X") + existing_images = {10: {"mymodule.front"}} + nb._module_image_details = {10: {"mymodule.front": {"url": "/media/front.jpg", "att_id": 7}}} + + with ( + patch("core.netbox_api._check_image_url", return_value="missing"), + patch("core.netbox_api._delete_image_attachment", return_value=True), + ): + nb._upload_module_type_images(mt_res, str(src), existing_images) + + assert "mymodule.front" not in existing_images[10] + + def test_upload_module_type_images_skips_changed_image_when_delete_fails( + self, mock_settings, mock_pynetbox, tmp_path + ): + mock_pynetbox.api.return_value.version = "3.5" + nb = NetBox(mock_settings, mock_settings.handle) + nb.verify_images = True + nb.device_types.upload_image_attachment = MagicMock(return_value=True) + + module_dir = tmp_path / "module-types" / "vendor" + module_dir.mkdir(parents=True) + src = module_dir / "mymodule.yaml" + src.write_text("model: X") + image_dir = tmp_path / "module-images" / "vendor" + image_dir.mkdir(parents=True) + (image_dir / "mymodule.front.jpg").write_bytes(b"img") + + mt_res = MagicMock(id=10, model="X") + existing_images = {10: {"mymodule.front"}} + nb._module_image_details = {10: {"mymodule.front": {"url": "/media/front.jpg", "att_id": 7}}} + + with ( + patch("core.netbox_api._check_image_url", return_value="ok"), + patch("core.netbox_api._is_image_hash_changed", return_value=True), + patch("core.netbox_api._delete_image_attachment", return_value=False), + ): + nb._upload_module_type_images(mt_res, str(src), existing_images) + + nb.device_types.upload_image_attachment.assert_not_called() + assert "mymodule.front" in existing_images[10] + assert any( + "skipping upload to avoid duplicates" in str(c) for c in mock_settings.handle.verbose_log.call_args_list + ) + + def test_stop_component_preload_removes_tasks_when_owned(self): + progress = MagicMock() + preload_job = {"task_ids": {"interfaces": 1}, "owns_tasks": True} + + DeviceTypes.stop_component_preload(preload_job, progress=progress) + + progress.stop_task.assert_called_once_with(1) + progress.remove_task.assert_called_once_with(1) + + def test_drain_pending_rewinds_negative_update_from_fallback_queue(self, mock_pynetbox, make_device_types): + dt = make_device_types(nb_api=mock_pynetbox.api.return_value) + future = MagicMock() + done_states = [False, True] + future.done.side_effect = lambda: done_states.pop(0) + future.result.return_value = [] + progress = MagicMock() + progress.tasks = [MagicMock(id=1, completed=4)] + updates = queue.Queue() + updates.put(("interface_templates", -2)) + records = {} + + with patch.object(dt, "_apply_progress_updates", return_value=False): + dt._drain_pending( + {"interface_templates"}, + {"interface_templates": future}, + progress, + {"interface_templates": 1}, + updates, + {}, + records, + ) + + assert progress.update.call_args_list[0].args == (1,) + assert progress.update.call_args_list[0].kwargs == {"completed": 2} diff --git a/tests/test_outcomes.py b/tests/test_outcomes.py index 7f9527d08..5551372de 100644 --- a/tests/test_outcomes.py +++ b/tests/test_outcomes.py @@ -104,3 +104,19 @@ def test_render_failure_report_includes_partials_section(): assert "Nokia/IOM-s-3.0T" in text assert "image upload failed but properties applied" in text assert "v/no-reason-partial" in text + + +def test_render_failure_report_includes_partial_blockers_and_hint(): + reg = OutcomeRegistry() + reg.record( + EntityKind.MODULE_TYPE, + "Cisco/LC1", + Outcome.PARTIAL, + blocking_objects=["front", "rear"], + hint="retry module image upload", + ) + + text = "\n".join(reg.render_failure_report()) + + assert "blocked by: front, rear" in text + assert "hint: retry module image upload" in text diff --git a/tests/test_repo.py b/tests/test_repo.py index 4a4cd993d..5733c101e 100644 --- a/tests/test_repo.py +++ b/tests/test_repo.py @@ -1,7 +1,7 @@ import pytest from unittest.mock import MagicMock, call, mock_open, patch from git import exc as git_exc -from core.repo import DTLRepo, validate_git_url, normalize_port_mappings +from core.repo import DTLRepo, _safe_pickle_load, validate_git_url, normalize_port_mappings class TestValidateGitUrl: @@ -348,6 +348,237 @@ def test_get_devices_skips_testing_folder(self): assert not any(v["name"] == "testing" for v in vendors) +class TestDiscoverVendors: + """Tests for DTLRepo.discover_vendors(): vendor discovery across multiple paths with deduplication.""" + + def _make_repo(self): + mock_args = MagicMock() + mock_args.url = "https://github.com/org/repo.git" + mock_args.branch = "master" + mock_handle = MagicMock() + with ( + patch("os.path.isdir", return_value=True), + patch("core.repo.Repo") as MockRepo, + ): + mock_git_repo = MagicMock() + mock_git_repo.remotes.origin.url = "https://github.com/org/repo.git" + ref = MagicMock() + ref.name = "origin/master" + mock_git_repo.remotes.origin.refs = [ref] + MockRepo.return_value = mock_git_repo + repo = DTLRepo(mock_args, "/tmp/repo", mock_handle) + return repo + + def test_discovers_vendors_from_single_path(self): + """Test discovery from a single existing path.""" + repo = self._make_repo() + + def mock_exists(path): + return "devices" in path + + with ( + patch("os.path.exists", side_effect=mock_exists), + patch("os.listdir", return_value=["Cisco", "Juniper"]), + patch("os.path.isdir", return_value=True), + ): + vendors = repo.discover_vendors("/devices", "/modules", "/racks") + assert len(vendors) == 2 + assert vendors[0]["name"] == "Cisco" + assert vendors[0]["slug"] == "cisco" + assert vendors[1]["name"] == "Juniper" + assert vendors[1]["slug"] == "juniper" + + def test_discovers_vendors_from_multiple_paths(self): + """Test discovery and merging from multiple paths.""" + repo = self._make_repo() + + def mock_listdir(path): + if "devices" in path: + return ["Cisco", "Juniper"] + elif "modules" in path: + return ["Arista", "Cisco"] + elif "racks" in path: + return ["Dell"] + return [] + + with ( + patch("os.path.exists", return_value=True), + patch("os.listdir", side_effect=mock_listdir), + patch("os.path.isdir", return_value=True), + ): + vendors = repo.discover_vendors("/devices", "/modules", "/racks") + + assert len(vendors) == 4 + vendor_names = [v["name"] for v in vendors] + assert "Cisco" in vendor_names + assert "Juniper" in vendor_names + assert "Arista" in vendor_names + assert "Dell" in vendor_names + + def test_deduplicates_vendors_across_paths(self): + """Test that vendors appearing in multiple paths are deduplicated.""" + repo = self._make_repo() + + def mock_listdir(path): + # Cisco appears in all three paths + if "devices" in path: + return ["Cisco", "Juniper"] + elif "modules" in path: + return ["Cisco", "Arista"] + elif "racks" in path: + return ["Cisco"] + return [] + + with ( + patch("os.path.exists", return_value=True), + patch("os.listdir", side_effect=mock_listdir), + patch("os.path.isdir", return_value=True), + ): + vendors = repo.discover_vendors("/devices", "/modules", "/racks") + + # Cisco should appear only once despite being in all three paths + cisco_vendors = [v for v in vendors if v["slug"] == "cisco"] + assert len(cisco_vendors) == 1 + assert len(vendors) == 3 # Cisco, Juniper, Arista + + def test_skips_testing_folder(self): + """Test that 'testing' folder (case-insensitive) is excluded.""" + repo = self._make_repo() + with ( + patch("os.path.exists", return_value=True), + patch("os.listdir", return_value=["Cisco", "testing", "Testing", "TESTING"]), + patch("os.path.isdir", return_value=True), + ): + vendors = repo.discover_vendors("/devices", "/modules", "/racks") + + assert len(vendors) == 1 + assert vendors[0]["name"] == "Cisco" + assert not any(v["name"].lower() == "testing" for v in vendors) + + def test_handles_nonexistent_paths(self): + """Test graceful handling of non-existent paths.""" + repo = self._make_repo() + + def mock_exists(path): + return "devices" in path # Only devices path exists + + def mock_listdir(path): + if "devices" in path: + return ["Cisco"] + return [] + + with ( + patch("os.path.exists", side_effect=mock_exists), + patch("os.listdir", side_effect=mock_listdir), + patch("os.path.isdir", return_value=True), + ): + vendors = repo.discover_vendors("/devices", "/modules", "/racks") + + assert len(vendors) == 1 + assert vendors[0]["name"] == "Cisco" + + def test_handles_all_nonexistent_paths(self): + """Test that all non-existent paths returns empty list.""" + repo = self._make_repo() + with patch("os.path.exists", return_value=False): + vendors = repo.discover_vendors("/devices", "/modules", "/racks") + assert vendors == [] + + def test_handles_os_errors_gracefully(self): + """Test graceful handling of OS errors when listing directories.""" + repo = self._make_repo() + + def mock_listdir(path): + if "devices" in path: + raise OSError("Permission denied") + elif "modules" in path: + return ["Cisco"] + return [] + + with ( + patch("os.path.exists", return_value=True), + patch("os.listdir", side_effect=mock_listdir), + patch("os.path.isdir", return_value=True), + ): + vendors = repo.discover_vendors("/devices", "/modules", "/racks") + + assert len(vendors) == 1 + assert vendors[0]["name"] == "Cisco" + + def test_skips_non_directory_entries(self): + """Test that files (non-directories) are skipped.""" + repo = self._make_repo() + + def mock_isdir(path): + # Only Cisco is a directory, README.md is a file + return "Cisco" in path + + with ( + patch("os.path.exists", return_value=True), + patch("os.listdir", return_value=["Cisco", "README.md"]), + patch("os.path.isdir", side_effect=mock_isdir), + ): + vendors = repo.discover_vendors("/devices", "/modules", "/racks") + + assert len(vendors) == 1 + assert vendors[0]["name"] == "Cisco" + + def test_returns_sorted_by_slug(self): + """Test that vendors are returned sorted by slug.""" + repo = self._make_repo() + with ( + patch("os.path.exists", return_value=True), + patch("os.listdir", return_value=["Zebra", "Arista", "Cisco", "Dell"]), + patch("os.path.isdir", return_value=True), + ): + vendors = repo.discover_vendors("/devices", "/modules", "/racks") + + slugs = [v["slug"] for v in vendors] + assert slugs == sorted(slugs) + assert slugs == ["arista", "cisco", "dell", "zebra"] + + def test_uses_slug_format_correctly(self): + """Test that slug_format is applied correctly to vendor names.""" + repo = self._make_repo() + with ( + patch("os.path.exists", return_value=True), + patch("os.listdir", return_value=["Extreme Networks", "HPE-Aruba"]), + patch("os.path.isdir", return_value=True), + ): + vendors = repo.discover_vendors("/devices", "/modules", "/racks") + + assert len(vendors) == 2 + # slug_format lowercases and replaces non-word chars with hyphens + assert any(v["slug"] == "extreme-networks" for v in vendors) + assert any(v["slug"] == "hpe-aruba" for v in vendors) + + def test_vendor_name_selection_is_deterministic_for_same_slug(self): + """Vendor name must be deterministic when multiple folders map to same slug. + + When two folder names produce the same slug, the alphabetically first + folder name must always win regardless of os.listdir() iteration order. + + Before the fix, os.listdir() was non-deterministic, so "Nokia" vs "nokia" + folders could produce different vendor names across runs. + """ + repo = self._make_repo() + + # Simulate os.listdir returning folders in reverse-alphabetical order + def mock_listdir_rev(_path): + return ["nokia", "Nokia"] # lowercase first → would win without sorting + + with ( + patch("os.path.exists", return_value=True), + patch("os.listdir", side_effect=mock_listdir_rev), + patch("os.path.isdir", return_value=True), + ): + vendors = repo.discover_vendors("/devices", "/modules", "/racks") + + # sorted("Nokia", "nokia") → "Nokia" < "nokia" (uppercase sorts first in ASCII) + assert len(vendors) == 1 + assert vendors[0]["name"] == "Nokia" # alphabetically first + + class TestParseFilesExtended: """Tests for DTLRepo.parse_files(): parallel parsing, slug filtering, error handling, and progress iteration.""" @@ -1055,3 +1286,229 @@ def test_duplicate_key_logs_warning_and_records_definition(self): assert len(repo.duplicate_definitions) == 1 assert repo.duplicate_definitions[0]["manufacturer"] == "cisco" assert repo.duplicate_definitions[0]["model"] == "X" + + +# --------------------------------------------------------------------------- +# resolve_slug_files +# --------------------------------------------------------------------------- + + +class TestRestrictedPickleLoading: + """Tests for _RestrictedUnpickler and _safe_pickle_load security guards.""" + + def test_safe_pickle_load_rejects_global_opcode(self, tmp_path): + import pickle + + class Evil: + def __reduce__(self): + return (len, ([1, 2, 3],)) + + path = tmp_path / "evil.pickle" + path.write_bytes(pickle.dumps(Evil())) + + with pytest.raises(pickle.UnpicklingError, match="not permitted"): + _safe_pickle_load(str(path)) + + +class TestResolveSlugFiles: + """Tests for the pickle-based slug fast path.""" + + def _make_repo(self): + mock_args = MagicMock() + mock_args.url = "https://github.com/org/repo.git" + mock_args.branch = "master" + mock_handle = MagicMock() + with ( + patch("os.path.isdir", return_value=True), + patch("core.repo.Repo") as MockRepo, + ): + mock_git_repo = MagicMock() + mock_git_repo.remotes.origin.url = "https://github.com/org/repo.git" + ref = MagicMock() + ref.name = "origin/master" + mock_git_repo.remotes.origin.refs = [ref] + MockRepo.return_value = mock_git_repo + repo = DTLRepo(mock_args, "/tmp/repo", mock_handle) + return repo + + def test_returns_none_when_pickle_missing(self, tmp_path): + """Returns None gracefully when the device pickle doesn't exist.""" + repo = self._make_repo() + repo.repo_path = str(tmp_path) + repo.cwd = "" + result = repo.resolve_slug_files(["nokia-7750-sr-7s"]) + assert result is None + + def test_returns_device_files_for_matching_slug(self, tmp_path): + """Pickle entry matches → file path appears in device_files under correct vendor.""" + import pickle + + # Build a minimal pickle + tests_dir = tmp_path / "tests" + tests_dir.mkdir() + known = {("nokia-7750-sr-7s", "device-types/Nokia/7750-SR-7s.yaml")} + (tests_dir / "known-slugs.pickle").write_bytes(pickle.dumps(known)) + # Create an empty module/rack pickle so those code paths don't crash + (tests_dir / "known-modules.pickle").write_bytes(pickle.dumps(set())) + (tests_dir / "known-racks.pickle").write_bytes(pickle.dumps(set())) + + repo = self._make_repo() + repo.repo_path = str(tmp_path) + repo.cwd = "" + + result = repo.resolve_slug_files(["nokia-7750-sr-7s"]) + + assert result is not None + assert "nokia" in result["device_files"] + expected_path = str(tmp_path / "device-types" / "Nokia" / "7750-SR-7s.yaml") + assert expected_path in result["device_files"]["nokia"] + + def test_substring_match(self, tmp_path): + """Partial slug matches (substring) are found.""" + import pickle + + tests_dir = tmp_path / "tests" + tests_dir.mkdir() + known = { + ("nokia-7750-sr-7s", "device-types/Nokia/7750-SR-7s.yaml"), + ("nokia-7750-sr-12e", "device-types/Nokia/7750-SR-12e.yaml"), + ("cisco-catalyst-9200", "device-types/Cisco/Catalyst-9200.yaml"), + } + (tests_dir / "known-slugs.pickle").write_bytes(pickle.dumps(known)) + (tests_dir / "known-modules.pickle").write_bytes(pickle.dumps(set())) + (tests_dir / "known-racks.pickle").write_bytes(pickle.dumps(set())) + + repo = self._make_repo() + repo.repo_path = str(tmp_path) + repo.cwd = "" + + result = repo.resolve_slug_files(["7750-sr"]) + + assert result is not None + assert "nokia" in result["device_files"] + assert len(result["device_files"]["nokia"]) == 2 + assert "cisco" not in result["device_files"] + + def test_no_match_returns_empty_dict(self, tmp_path): + """No matching slug returns empty device_files (not None).""" + import pickle + + tests_dir = tmp_path / "tests" + tests_dir.mkdir() + known = {("cisco-catalyst-9200", "device-types/Cisco/Catalyst-9200.yaml")} + (tests_dir / "known-slugs.pickle").write_bytes(pickle.dumps(known)) + (tests_dir / "known-modules.pickle").write_bytes(pickle.dumps(set())) + (tests_dir / "known-racks.pickle").write_bytes(pickle.dumps(set())) + + repo = self._make_repo() + repo.repo_path = str(tmp_path) + repo.cwd = "" + + result = repo.resolve_slug_files(["nokia-7750-sr-7s"]) + + assert result is not None + assert result["device_files"] == {} + assert result["module_vendors"] == set() + + def test_corrupted_device_pickle_returns_none(self, tmp_path): + tests_dir = tmp_path / "tests" + tests_dir.mkdir() + (tests_dir / "known-slugs.pickle").write_bytes(b"not-a-pickle") + + repo = self._make_repo() + repo.repo_path = str(tmp_path) + repo.cwd = "" + + assert repo.resolve_slug_files(["nokia"]) is None + + def test_short_device_relpath_is_skipped(self, tmp_path): + import pickle + + tests_dir = tmp_path / "tests" + tests_dir.mkdir() + (tests_dir / "known-slugs.pickle").write_bytes(pickle.dumps({("nokia-7750-sr-7s", "device-types/Nokia.yaml")})) + (tests_dir / "known-modules.pickle").write_bytes(pickle.dumps(set())) + (tests_dir / "known-racks.pickle").write_bytes(pickle.dumps(set())) + + repo = self._make_repo() + repo.repo_path = str(tmp_path) + repo.cwd = "" + + result = repo.resolve_slug_files(["7750-sr"]) + + assert result["device_files"] == {} + + def test_corrupted_module_pickle_is_ignored(self, tmp_path): + import pickle + + tests_dir = tmp_path / "tests" + tests_dir.mkdir() + (tests_dir / "known-slugs.pickle").write_bytes(pickle.dumps(set())) + (tests_dir / "known-modules.pickle").write_bytes(b"bad-module-pickle") + (tests_dir / "known-racks.pickle").write_bytes(pickle.dumps(set())) + + repo = self._make_repo() + repo.repo_path = str(tmp_path) + repo.cwd = "" + + result = repo.resolve_slug_files(["module"]) + + assert result["module_vendors"] is None + + def test_matching_module_pickle_adds_vendor_slug(self, tmp_path): + import pickle + + tests_dir = tmp_path / "tests" + tests_dir.mkdir() + (tests_dir / "known-slugs.pickle").write_bytes(pickle.dumps(set())) + (tests_dir / "known-modules.pickle").write_bytes(pickle.dumps({("7750 line card", "module-types/Nokia")})) + (tests_dir / "known-racks.pickle").write_bytes(pickle.dumps(set())) + + repo = self._make_repo() + repo.repo_path = str(tmp_path) + repo.cwd = "" + + result = repo.resolve_slug_files(["7750"]) + + assert result["module_vendors"] == {"nokia"} + + def test_corrupted_rack_pickle_is_ignored(self, tmp_path): + import pickle + + tests_dir = tmp_path / "tests" + tests_dir.mkdir() + (tests_dir / "known-slugs.pickle").write_bytes(pickle.dumps(set())) + (tests_dir / "known-modules.pickle").write_bytes(pickle.dumps(set())) + (tests_dir / "known-racks.pickle").write_bytes(b"bad-rack-pickle") + + repo = self._make_repo() + repo.repo_path = str(tmp_path) + repo.cwd = "" + + result = repo.resolve_slug_files(["rack"]) + + assert result["rack_vendors"] is None + + def test_rack_pickle_only_uses_rack_type_entries(self, tmp_path): + import pickle + + tests_dir = tmp_path / "tests" + tests_dir.mkdir() + (tests_dir / "known-slugs.pickle").write_bytes(pickle.dumps(set())) + (tests_dir / "known-modules.pickle").write_bytes(pickle.dumps(set())) + (tests_dir / "known-racks.pickle").write_bytes( + pickle.dumps( + { + ("42U Rack", "rack-types/APC"), + ("Not A Rack", "module-types/Nokia"), + } + ) + ) + + repo = self._make_repo() + repo.repo_path = str(tmp_path) + repo.cwd = "" + + result = repo.resolve_slug_files(["rack"]) + + assert result["rack_vendors"] == {"apc"} diff --git a/tests/test_schema_reader.py b/tests/test_schema_reader.py index e4255fd9e..7fa301967 100644 --- a/tests/test_schema_reader.py +++ b/tests/test_schema_reader.py @@ -50,6 +50,13 @@ def test_non_dict_property_entry_is_skipped(self, tmp_path): assert "broken_prop" not in result assert "also_broken" not in result + def test_non_object_properties_key_raises_value_error(self, tmp_path): + schema_file = tmp_path / "bad-props.json" + schema_file.write_text(json.dumps({"properties": []})) + + with pytest.raises(ValueError, match="non-object 'properties'"): + load_scalar_properties(str(schema_file)) + def test_excludes_named_properties(self, tmp_path): schema = { "properties": { diff --git a/tests/test_update_failure_resolver.py b/tests/test_update_failure_resolver.py index cb438ed18..b82088db1 100644 --- a/tests/test_update_failure_resolver.py +++ b/tests/test_update_failure_resolver.py @@ -8,6 +8,7 @@ from core.update_failure_resolver import ( FailureKind, + _extract_error_payload, classify_device_type_update_failure, ) @@ -194,6 +195,16 @@ def test_classifier_handles_bytes_payload(): assert res.kind == FailureKind.SUBDEVICE_ROLE_FLIP +class _BrokenBytes(bytes): + def decode(self, *args, **kwargs): + raise RuntimeError("decode failed") + + +def test_extract_error_payload_returns_original_bytes_on_decode_failure(): + payload = _BrokenBytes(b"bad") + assert _extract_error_payload(payload) is payload + + @pytest.mark.parametrize( ("payload", "expected_kind"), [