diff --git a/.claude/skills/temp-roadmap/SKILL.md b/.claude/skills/temp-roadmap/SKILL.md new file mode 100644 index 0000000..c9078e9 --- /dev/null +++ b/.claude/skills/temp-roadmap/SKILL.md @@ -0,0 +1,66 @@ +--- +name: temp-roadmap +description: Create a temporary, self-removing roadmap/backlog section in CLAUDE.md from a set of identified work items (issue triage, audit findings, review feedback). Use when the user asks to "track these items", "write this list into CLAUDE.md as a plan", "create a backlog/roadmap in CLAUDE.md", or wants a work list that cleans itself up as PRs land. +--- + +# Temporary Self-Removing Roadmap in CLAUDE.md + +Turn a set of identified work items into a CLAUDE.md section that shrinks with +every PR and disappears entirely when the work is done — so CLAUDE.md returns +to its clean state without manual housekeeping. + +## When to use + +- After issue triage, a code audit, or a review produced a concrete list of + work items that will be resolved across several future PRs. +- NOT for single-PR task tracking (use the todo list) and NOT for permanent + guidance (write a normal CLAUDE.md section or arc42 doc instead). + +## Structure to write into CLAUDE.md + +Insert a section near the top of CLAUDE.md (after the docs index, before +project structure): + +```markdown +## Backlog (TEMPORARY section — self-deleting) + +**Deletion hook:** When a PR resolves one of the items below, DELETE its row +from this table **in the same PR**. When the last row is gone, delete this +entire section so CLAUDE.md returns to its clean state. Never let a finished +item linger here. + +Items validated ; source: . + +| Item | Size | Task | +|------|------|------| +| #123 | quick win | One-line actionable description with file hints (`path/file.py`, function name, reporter-verified fix if any) | +| #124 | medium | ... | +| #125 | blocked | ... — name what it is blocked on so it can be re-checked | +``` + +## Rules + +1. **One line per item.** Concise but self-sufficient: a future session must be + able to start the item from the row alone — include file paths, function + names, and known pitfalls ("do NOT use setSortingEnabled — currentRowChanged + fires switch_image"). +2. **Classify every item**: `quick win` / `medium` / `large` / `blocked`. + For `blocked`, state the blocker. +3. **Reference the source** (issue number, ticket, review comment) so context + can be recovered. +4. **Date the validation** — rows describe code state at a point in time; + re-verify before implementing if the date is old. +5. **The deletion hook is mandatory** and must appear verbatim-in-spirit at the + top of the section. It is the whole point: the section is a consumable, not + documentation. +6. Mark items currently being worked on with *(in progress on this branch)* so + parallel sessions don't double-pick them. + +## Per-PR workflow (executing the hook) + +1. Pick item(s), implement on a feature branch. +2. In the same PR: delete the resolved row(s) from the table. +3. If the table is now empty: delete the entire section, including the heading + and the deletion-hook paragraph. +4. The PR diff thus always shows both the fix and the backlog shrinking — + reviewers can see progress without a separate tracker. diff --git a/.gitignore b/.gitignore index 1b85dae..70272cb 100644 --- a/.gitignore +++ b/.gitignore @@ -9,10 +9,12 @@ models/ *.iap # Local tooling — .claude is mostly local state, but the senior-reviewer -# agent definition under .claude/agents/ is tracked (referenced by -# CLAUDE.md as a mandatory quality gate), so un-ignore that subtree. +# agent definition under .claude/agents/ and project skills under +# .claude/skills/ are tracked (referenced by CLAUDE.md / used as shared +# workflow tooling), so un-ignore those subtrees. .claude/* !.claude/agents/ +!.claude/skills/ .venv # Python cache and build artifacts diff --git a/CLAUDE.md b/CLAUDE.md index a7897e0..e9f13f4 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -36,6 +36,24 @@ For detailed architecture and design information, see **[docs/](docs/)**: See [docs/README.md](docs/README.md) for full documentation index. +## Upstream Issue Backlog (TEMPORARY section — self-deleting) + +**Deletion hook:** When a PR resolves one of the items below, DELETE its row +from this table **in the same PR**. When the last row is gone, delete this +entire section so CLAUDE.md returns to its clean state. Never let a finished +item linger here. + +Issue numbers refer to https://github.com/bnsreenu/digitalsreeni-image-annotator/issues +(validated 2026-06-12; already-fixed issues have close-request comments posted, not listed here). + +| Issue | Size | Task | +|-------|------|------| +| #32 + #36 | medium | Annotations can extend outside image bounds (manual edit + Image Augmenter) → clamp coords on commit; clip augmented polygons to image rect (shapely intersection). Silently poisons training data | +| #40 | medium | True bbox editing (move whole box, drag edges, keep rectangularity). Currently bbox annotations aren't editable at all — `start_polygon_edit` only matches `"segmentation"` | +| #63 | blocked | SAM 3 support — blocked on Ultralytics shipping SAM 3; re-check their releases before attempting | +| #35 | large | Keypoint annotation tool | +| #24 | large | Magic-wand-style point add/remove mask refinement (partially covered by SAM point prompts) | + ## Project Structure ``` @@ -47,14 +65,17 @@ src/digitalsreeni_image_annotator/ ├── __init__.py # Public API re-exports │ ├── core/ # constants, annotation_utils, image_utils -├── controllers/ # 7 controllers (project, image, sam, dino, -│ # yolo, annotation, class) + io_controller +├── controllers/ # 8 controllers (project, image, sam, +│ # sam_train, dino, yolo, annotation, +│ # class) + io_controller ├── widgets/ │ ├── image_label.py # ImageLabel canvas widget (dispatcher) │ ├── canvas_context.py # CanvasContext read accessor (ADR-018) │ └── tools/ # Per-tool handlers (ADR-019): rectangle, │ # polygon, paint, eraser ├── inference/ # sam_utils.py, dino_utils.py +├── training/ # SAM fine-tuning (ADR-021): sam_trainer.py +│ # (SAMFineTuner), sam_dataset.py ├── io/ # export_formats.py, import_formats.py ├── ui/ # menu_bar, sidebar, shortcuts, theme, stylesheets └── dialogs/ # Standalone tool dialogs (statistics, @@ -76,8 +97,10 @@ src/digitalsreeni_image_annotator/ | `SAMController` | controllers/sam_controller.py | SAM model picker, debounce, in-flight guard (ADR-013) | | `DINOController` | controllers/dino_controller.py | DINO single/batch detection, batch review, temp-class workflow | | `YOLOController` | controllers/yolo_controller.py | YOLO training menu + prediction wiring | -| `SAMUtils` | inference/sam_utils.py | Load SAM models, run inference | +| `SAMUtils` | inference/sam_utils.py | Load SAM models (built-in + fine-tuned), run inference | | `DINOUtils` | inference/dino_utils.py | Grounding-DINO model load + inference | +| `SAMFineTuner` | training/sam_trainer.py | Fine-tune SAM 2 decoder/encoder via custom loop over Ultralytics SAM2Model (ADR-021) | +| `SAMTrainController` | controllers/sam_train_controller.py | SAM fine-tune menu, GPU gate, training thread, selector registration | See [Building Block View](docs/05_building_block_view.md) for detailed class documentation. @@ -169,6 +192,8 @@ See [Runtime View](docs/06_runtime_view.md#multi-dimensional-image-loading) for | GPU model unload | `model.cpu()` → `gc.collect()` → `torch.cuda.empty_cache()` + `ipc_collect()` + `synchronize()` — full reclaim requires app restart due to per-process CUDA context | Setting refs to None alone leaves circular refs pinned and shows zero Task Manager drop. See [Releasing Model GPU Memory](docs/08_crosscutting_concepts.md#releasing-model-gpu-memory). | | Export image-path lookup | Exact-key match first, substring fallback only | `"bee.jpg" in "honeybee.jpg"` is True — substring-only matching writes the wrong file. See [Export Format Filename Matching](docs/08_crosscutting_concepts.md#export-format-filename-matching). | | F2 / global shortcuts | Use `QShortcut` with `Qt.ShortcutContext.ApplicationShortcut`, not `keyPressEvent` | `QTableWidget` consumes F2 for in-cell edit before it bubbles up. | +| Canvas ↔ list selection sync | Canvas selection (idle-mode click/Shift/rubber-band) drives the annotation list via `apply_canvas_selection`; mirror the list with `blockSignals(True/False)` and match annotations by **value-equality**, never identity | PyQt round-trips `UserRole` dicts as copies and `image_label.annotations` is a deepcopy, so identity is never stable; un-blocked `setSelected` recurses through `update_highlighted_annotations`. Multi-select uses **Shift** (Ctrl stays pan). See [ADR-022](docs/09_architecture_decisions.md#adr-022-canvas-mask-selection-unified-with-the-annotation-list). | +| Selection rendering | Don't recolour a selected mask. Keep its class colour; draw a class-colour-independent overlay (dashed `_SELECTION_COLOR` blue bounding box + bright handle squares at corners/edge-midpoints, OGP-style) in a final pass — `_draw_selection_overlay`. Handles are visual-only (resize = #40). Default class colours come from `core/constants.py::default_class_color` (red last, muted) | Red selection was invisible on a red-class mask; a thin dashed outline alone was too faint; the handles carry the visibility. See ADR-022 amendment. | ## Development Workflow @@ -245,6 +270,10 @@ See [Risks and Technical Debt](docs/11_risks_and_technical_debt.md) for full lis |--------|--------| | Ctrl+Wheel | Zoom | | Ctrl+Drag | Pan | +| Click / Shift+Click (no tool) | Select / toggle mask | +| Drag / Shift+Drag (no tool) | Rubber-band select / add | +| Double-click | Vertex-edit mode | +| Delete | Delete selected mask(s) | | Enter | Finish/Accept | | Esc | Cancel | | Up/Down | Navigate slices | diff --git a/README.md b/README.md index a541bb2..2a80607 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,14 @@ python -c "import torch; print(torch.cuda.is_available(), torch.cuda.get_device_ You should see `True` and your GPU name. For other platforms or driver combinations, use the official selector at . +#### Older NVIDIA GPUs (Pascal / Maxwell) + +PyTorch ≥ 2.8 wheels no longer include kernels for GPUs older than Volta (compute capability < 7.0), e.g. the GTX 10xx series (sm_61). On such cards the app detects the mismatch, warns once, and automatically runs inference on the CPU instead of crashing with `CUDA error: no kernel image is available`. To keep using the GPU, install an older PyTorch that still supports it: + +```bash +pip install torch==2.4.1 torchvision==0.19.1 --index-url https://download.pytorch.org/whl/cu121 +``` + ## Usage 1. Run the DigitalSreeni Image Annotator application: diff --git a/docs/05_building_block_view.md b/docs/05_building_block_view.md index a957b10..9ce3771 100644 --- a/docs/05_building_block_view.md +++ b/docs/05_building_block_view.md @@ -33,7 +33,8 @@ src/digitalsreeni_image_annotator/ ├── utils.py # Cross-cutting utilities ├── core/ # Constants, annotation utils, image utils │ ├── constants.py - │ └── annotation_utils.py + │ ├── annotation_utils.py + │ └── torch_utils.py # Shared torch device resolution + CPU fallback (#57) ├── widgets/ │ ├── image_label.py # ImageLabel - canvas widget; dispatcher │ ├── canvas_context.py # CanvasContext - narrow read view (ADR-018) @@ -155,6 +156,23 @@ The earlier subprocess approach is documented as [ADR-011](09_architecture_decisions.md#adr-011-run-torch-based-workers-in-isolated-subprocesses) (Superseded). +### SAM Fine-Tuning Subsystem (`training/`) + +Lets users fine-tune SAM 2 / 2.1 on their own annotations, since +Ultralytics ships no SAM trainer (ADR-021). Distinct from `inference/` +because it is *training*, not inference. + +| Module | Responsibility | +|--------|----------------| +| `training/sam_trainer.py` | `SAMFineTuner` — custom decoder (optionally encoder) fine-tuning loop reusing `SAM2Predictor.get_im_features` / `prompt_inference` under autograd, focal+dice loss, AdamW, checkpoint save+reload-verify. Also geometry helpers (`polygon_to_mask`, `mask_to_xyxy`, `mask_to_point`), `make_custom_filename`, `list_custom_models`, and the `SampleGroup` lazy-rasterising dataset item. | +| `training/sam_dataset.py` | `build_groups_from_project` (live `all_annotations`) and `build_groups_from_folder` (prepared dataset) → `list[SampleGroup]`, mirroring `export_yolo_v5plus` image resolution. | +| `io/export_formats.py::export_sam_dataset` | Writes `images/` + `manifest.json` (authoritative bbox/segmentation specs) for an inspectable, re-trainable on-disk dataset. | + +Fine-tuned checkpoints save as `{"model": state_dict}` and reload +through the unchanged `SAM(path)` inference path; `SAMUtils` gains a +`custom_models` registry so they appear in the SAM selector alongside +the eight built-ins. + ### DINO Subsystem (Grounding DINO + SAM pipeline) LLM-assisted detection: the user gives free-form text phrases per class, @@ -195,7 +213,7 @@ export time — see [Cross-cutting Concepts](08_crosscutting_concepts.md)). ## Level 3: Controllers -Seven `QObject` controllers plus an `io_controller` helper module +Eight `QObject` controllers plus an `io_controller` helper module carve `ImageAnnotator` into single-responsibility owners that the orchestrator delegates to. Each `QObject` controller holds `self.mw = main_window` and owns one slice of behaviour; the @@ -208,12 +226,13 @@ the controller graph. | Controller | Responsibility | |------------|----------------| | `ProjectController` | `.iap` save/load, auto-save, backup/restore, missing-image prompts, window-title sync. Owns the `is_loading_project` autosave guard (load/save round-trip safety, v0.8.12). | -| `ImageController` | Open / load / switch images and slices. TIFF + CZI loaders, the multi-dim `DimensionDialog`, the `[-ndim:]` axis-slice bug fix from the v0.9.0 era. | +| `ImageController` | Open / load / switch images and slices. TIFF + CZI loaders (with `imagecodecs` codec-error handling — #56), the multi-dim `DimensionDialog`, the `[-ndim:]` axis-slice bug fix from the v0.9.0 era. Image-list annotation-status filter (`image_has_annotations`, `apply_image_filter` — #27) and alphabetical sort (`sort_image_list` — #60). | | `AnnotationController` | Annotation CRUD, list sorting, highlight, edit-mode entry/exit, `finish_polygon`, `finish_rectangle`, `replace_annotations` (eraser path). Validates writes before mutating `all_annotations`. | | `ClassController` | Class add / delete / rename / colour / visibility. `update_slice_list_colors`, `is_class_visible`. | | `SAMController` | SAM box/points tool lifecycle, debounce timer, `_sam_inference_in_flight` re-entrancy guard (ADR-013), model picker. | | `DINOController` | Single + batch detection, batch review navigation, temp-annotation accept/reject, custom-model browse, `DINOReviewEventFilter` ownership (ADR-015). | | `YOLOController` | Training menu, `TrainingThread`, prediction dialog, result processing. | +| `SAMTrainController` | SAM fine-tuning menu, GPU gate, `SAMTrainingThread`, config dialog, registers fine-tuned checkpoints into the SAM selector (ADR-021). | | `io_controller` *(module-level functions, not a class)* | Thin UI wrappers around the pure `io/export_formats.py` and `io/import_formats.py` modules. | Communication: `ImageLabel` does not import controllers directly — diff --git a/docs/06_runtime_view.md b/docs/06_runtime_view.md index 42337ca..7c3c2a0 100644 --- a/docs/06_runtime_view.md +++ b/docs/06_runtime_view.md @@ -66,6 +66,37 @@ User presses Enter └─> update() to show final annotation ``` +## Mask Selection & Deletion on the Canvas (issue #75) + +Active only when no drawing/SAM tool is selected (`ImageLabel._is_select_mode()`). +Double-click still enters vertex-edit; Ctrl+drag still pans. + +``` +User clicks / drags on image (no tool active) + │ + ├─> ImageLabel mouse press/move/release + │ ├─> click → annotation_at(pos) (smallest mask, seg or bbox) + │ ├─> click empty → [] (clears selection) + │ ├─> drag → annotations_in_rect(rect) (rubber band, bounds-intersect) + │ └─> Shift → toggle (click) / add (drag) + │ + ├─> emit canvasSelectionChanged(annotations, mode) mode = replace|add|toggle + │ + └─> AnnotationController.apply_canvas_selection() + ├─> compute new set from highlighted_annotations + annotations per mode + ├─> image_label.highlighted_annotations = new (blue selection overlay) + ├─> mirror onto annotation_list (blockSignals while selecting) + └─> enable Merge (≥2) / Change Class (≥1) + +User presses Delete (canvas focused) + │ + ├─> ImageLabel.keyPressEvent → deleteSelectionRequested + └─> AnnotationController.delete_selected_annotations() (confirm → remove → re-sort → autosave) +``` + +The canvas and the list share one selection (matched by dict value-equality), so +Delete/Merge/Change-Class behave the same from either surface. See ADR-022. + ## SAM-Assisted Annotation (SAM-box / SAM-points) ``` @@ -370,3 +401,38 @@ User clicks "Export" > "YOLO v8/v11" │ └─> Return ``` + +## SAM Fine-Tuning (annotate → train → use) + +See [ADR-021](09_architecture_decisions.md#adr-021-sam-fine-tuning-via-a-custom-loop-over-the-ultralytics-sam2-module). + +``` +User: SAM Fine-Tune (beta) > Train on Current Project… + │ + ├─> build_groups_from_project(all_annotations, image_paths, slices, image_slices) + │ polygons/bboxes → SampleGroup(image_loader, specs) (masks rasterised lazily) + │ + ├─> _gpu_gate(): resolve_torch_device(); if "cpu" → warn + let user back out + │ + ├─> SAMTrainConfigDialog: base model, epochs, lr, batch, prompt (bbox/point), + │ "also fine-tune image encoder?" + │ + ├─> deactivate_sam_tools() + lock SAM inference UI (tools, selector, menu) + │ trainer loads its OWN SAM instance; locking avoids a 2nd model on the same CUDA context + │ + └─> SAMTrainingThread → SAMFineTuner.train(...) + │ build predictor (one warmup predict), pin device, apply freeze policy + │ for each epoch / image / instance: + │ set_image → get_im_features (no_grad when encoder frozen) + │ prompt_inference(bbox|point) under enable_grad → mask logits + │ focal+dice loss → backward → AdamW step (every batch_size instances) + │ progress_signal → TrainingInfoDialog (Stop supported) + │ save {"model": state_dict} as _.pt → reload-verify via SAM() + │ + └─> training_finished: register in SAMUtils.custom_models, + add "★ " to the SAM selector and select it + → SAM-box / SAM-points now use the fine-tuned model + +Offline variant: "Prepare SAM Dataset…" → export_sam_dataset (images/ + manifest.json), +then "Train from Dataset Folder…" → build_groups_from_folder → same training path. +``` diff --git a/docs/08_crosscutting_concepts.md b/docs/08_crosscutting_concepts.md index 7e4ea7e..79b65d6 100644 --- a/docs/08_crosscutting_concepts.md +++ b/docs/08_crosscutting_concepts.md @@ -178,6 +178,33 @@ nothing visible" in v0.9.0 manual testing. | SAM 2 base | ~150MB | Medium-High | Slow | ⚠️ Use with caution | | SAM 2 large | ~400MB | High | Very Slow | ❌ Not recommended (crashes on limited resources) | +### Device Selection & Compute-Capability Fallback + +`torch.cuda.is_available()` is **not** sufficient to decide on GPU +inference: it returns True for any visible CUDA device even when the +installed torch wheels contain no kernels for its compute capability +(torch ≥ 2.8 wheels ship sm_70+ only, so a Pascal GTX 1050 / sm_61 +passes the check but every kernel launch fails with +`CUDA error: no kernel image is available for execution on the device` +— upstream issue #57). + +All inference paths therefore resolve their device through +`core/torch_utils.resolve_torch_device()`: + +- compares `torch.cuda.get_device_capability(0)` against the minimum + `sm_*` in `torch.cuda.get_arch_list()`; on mismatch returns + `("cpu", warning)` instead of `"cuda"`, +- caches the decision process-wide (SAM, DINO and YOLO share it), +- prints the warning once; `maybe_warn_cpu_fallback(parent)` shows it + as a one-time `QMessageBox` from the SAM model picker and the DINO + detect entry points. + +SAM passes `device=` explicitly on every predict call; DINO's +`_resolve_device()` delegates to the helper (the `DINO_DEVICE` env +override still wins); the YOLO trainer passes `device=` to +`model.train()` and prediction. Never call bare +`torch.cuda.is_available()` to pick a device in new code. + ## Dark Mode Support ### Stylesheet Switching @@ -377,6 +404,11 @@ def generate_slice_name(filename, t, z, c, s): |----------|--------| | Ctrl+Wheel | Zoom In/Out | | Ctrl+Drag | Pan | +| Click (no tool) | Select mask under cursor | +| Shift+Click (no tool) | Toggle mask in selection | +| Drag (no tool) | Rubber-band box-select; Shift+Drag adds | +| Delete | Delete selected mask(s) | +| Double-click | Enter vertex-edit mode | | Esc | Cancel Current Annotation | | Enter | Finish/Accept Annotation | | Up/Down | Navigate Slices (multi-dimensional) | @@ -500,6 +532,68 @@ The substring fallback is kept for backward compatibility with old projects that may have stored normalised image names (e.g. without extension); new code should prefer the exact-key path. +## Image List Filter — Hide Rows, Never Remove Them + +The image list can be filtered by annotation status (combo above the +list; upstream issue #27, `ImageController.apply_image_filter`). The +filter uses `setRowHidden(i, True)` and must **never** remove items: + +- `DINOController._navigate_to_image_or_slice` and the COCO importer + iterate `image_list` rows by index; removing rows would shift + indices under them. +- Removing the current item fires `currentRowChanged`, which is wired + to `switch_image` — a filter change could silently switch the + displayed image. Hiding fires nothing. + +A non-matching row is hidden **even when it is the current selection**. +`setRowHidden` does not change `current_image` or fire +`currentRowChanged`, so the canvas keeps showing the worked-on image +while its row leaves the list — e.g. the current image gains its first +annotation under the "Without annotations" filter and disappears from +the list, but stays on screen until the user navigates away. Keyboard +nav skips hidden rows. (Guaranteed by +`test_hiding_current_row_keeps_canvas_and_fires_no_switch`.) + +Re-apply runs from `ClassController.update_slice_list_colors()`. The +contract: every annotation-mutation site either calls that method +directly **or** emits `annotationsBatchSaved` (whose handler +`_on_annotations_batch_saved` calls it). All `annotationCommitted` +emitters follow up with `annotationsBatchSaved` +(image_label.py / paint_tool.py), so both commit paths are covered. +New mutation paths must keep one of those two routes — don't add +bespoke `apply_image_filter()` call sites. + +## Image List Sorting — Rebuild, Don't `setSortingEnabled` + +The image list is kept alphabetical (upstream issue #60, +`ImageController.sort_image_list`). Two constraints shape the +implementation: + +- `currentRowChanged` is wired to `switch_image`, so `setSortingEnabled(True)` + is forbidden — a live re-sort would reorder rows and fire spurious + image switches. +- COCO import (and other positional lookups) assume `all_images[i]` + matches `image_list.item(i)`. So the model and the view are sorted + **together**: `all_images` is sorted, then the list is cleared and + repopulated from it with `blockSignals(True)` around the rebuild, and + the prior (or newly added) selection is restored explicitly. The #27 + filter is re-applied at the end of the rebuild. + +`update_image_list` routes through `sort_image_list`; `add_images_to_list` +calls it with the first added file selected. It is skipped per-image +during project load (the list is rebuilt once via `update_ui`) to avoid +an O(n²) re-sort. + +## TIFF Compression Codecs + +Reading an LZW- (or otherwise) compressed TIFF requires the optional +`imagecodecs` package; without it `tifffile` raises `ValueError` mid-read +(upstream issue #56). `imagecodecs` is now a hard dependency, but +`ImageController.add_images_to_list` also catches the codec `ValueError` +(`_is_missing_codec_error`) and shows an actionable "pip install +imagecodecs" dialog, skipping the file instead of crashing or leaving a +half-added entry. Non-codec `ValueError`s still propagate. + ## Canvas Decoupling — Signals + CanvasContext `ImageLabel` (the canvas widget) does **not** hold a reference to @@ -550,3 +644,55 @@ paint commits into O(N). See ADR-018. See ADR-018 in `09_architecture_decisions.md` for the rationale and the full pattern. + +## Canvas Selection ↔ List Selection + +When no drawing tool is active (`ImageLabel._is_select_mode()`), the canvas +behaves like a pointer: a single click selects the smallest mask under the +cursor, a drag draws a rubber band that box-selects, and **Shift** toggles / +adds. This is wired so there is **one** selection shared by the canvas overlay +(`highlighted_annotations`, blue selection outline + handles) and the bottom-left +annotation list — so +`Delete` / `Merge` / `Change Class` (which read `annotation_list.selectedItems()`) +work identically whether you selected on the image or in the list. See ADR-022. + +Flow: `ImageLabel` emits `canvasSelectionChanged(annotations, mode)` (mode = +`replace` | `add` | `toggle`) → `AnnotationController.apply_canvas_selection` +computes the new set, assigns `image_label.highlighted_annotations`, and mirrors +it onto the list. + +Two non-obvious rules make this correct: + +- **Match by value-equality, never identity.** `image_label.annotations` is a + `deepcopy` of `all_annotations`, and PyQt round-trips dicts stored in a list + item's `UserRole` as *copies* — so the "same" annotation has different object + identity on the canvas, in `all_annotations`, and in a list item. Every + selection comparison therefore uses dict `==` (`a == b`), the same convention + as `select_annotation_in_list`, `delete_selected_annotations`, and the + `annotation in highlighted_annotations` test in `draw_annotations`. A + consequence: two value-equal duplicate masks select together — accepted, and + pre-existing. +- **Block list signals while mirroring.** `apply_canvas_selection` wraps the + programmatic list selection in `annotation_list.blockSignals(True/False)`. + Without it, `setSelected` fires `itemSelectionChanged` → + `update_highlighted_annotations`, which would overwrite the freshly-computed + set with the list items' own objects (and clobber a `toggle`). + +**Ctrl is reserved for pan.** Multi-select uses **Shift**, not Ctrl, because +Ctrl+drag is the pan gesture (whose reference-frame handling is deliberately +delicate — see [Pan + Zoom Reference Frames](#pan--zoom-reference-frames)). +Leaving Ctrl untouched keeps that gesture intact. + +**Selection is drawn independent of class colour.** A selected mask is *not* +recoloured (the first version turned it red, which vanished on a red-class mask, +and red was the default first class colour). Instead `draw_annotations` keeps the +class colour and, in a final pass on top of every mask, draws a dashed +selection-blue **bounding-box marquee plus bright handle squares** at the 4 +corners + 4 edge midpoints (`_SELECTION_COLOR`, `_draw_selection_overlay`) — +modelled on the sibling open-garden-planner app's CAD selection. The handles +carry the visibility (a single thin dashed outline was too faint); they are +visual markers only — resize-by-handle is a separate feature (upstream #40). This +never collides with any class colour. Relatedly, the default class palette +(`core/constants.py`) was reordered so red is last and the fill opacity lowered +to keep the image legible — see the No Hardcoded Colors Rule for the broader +"don't fight the theme/colours" theme. diff --git a/docs/09_architecture_decisions.md b/docs/09_architecture_decisions.md index ec1306b..a58ba6b 100644 --- a/docs/09_architecture_decisions.md +++ b/docs/09_architecture_decisions.md @@ -739,6 +739,150 @@ persistence mechanism in the app. --- +## ADR-021: SAM Fine-Tuning via a Custom Loop over the Ultralytics SAM2 Module + +**Status**: Accepted + +**Context**: Users annotating domain-specific imagery (microscopy, +medical, materials) get generic SAM masks that need heavy correction. +We want to let them fine-tune SAM 2 / 2.1 on their own annotations and +reuse the result in the existing SAM-box / SAM-points workflow +(upstream issue bnsreenu#73). + +The obvious approach — mirror the YOLO trainer's `model.train(...)` — +**does not work**: Ultralytics registers only a *predictor* for SAM's +`segment` task (`SAM.task_map`), so `SAM(...).train()` raises +`NotImplementedError` (verified on ultralytics 8.4.51). + +**Decision**: Fine-tune with a custom PyTorch loop that **reuses +Ultralytics' own forward path**. `SAM(...).model` is a plain +`SAM2Model` `nn.Module`; its `SAM2Predictor` exposes the forward in +reusable pieces — `get_im_features` (image encoder) and +`prompt_inference` / `_inference_features` (prompt encoder + mask +decoder). These are *not* wrapped in `inference_mode` unless reached +via the public `__call__`, so calling them directly under +`torch.enable_grad()` yields differentiable mask logits. The engine +(`training/sam_trainer.py`) adds focal+dice loss (≈20:1) + AdamW + +backward. Default freeze policy: train only `sam_mask_decoder` +(image + prompt encoders frozen); an optional flag also unfreezes the +image encoder. + +Checkpoints are saved as `{"model": state_dict}` — the exact shape +Ultralytics' `_load_checkpoint` reads (it rebuilds the architecture +from the filename suffix and `load_state_dict`s the nested `model` +key). Consequently a fine-tuned file **must keep its base token in the +name** (e.g. `myrun_sam2_t.pt`), enforced by `make_custom_filename`; +`build_sam` selects the architecture by `ckpt.endswith(token)`. Every +save is round-trip-verified by reloading through `SAM(out_path)` and +running one forward — failing loudly rather than producing a file that +won't reload (cf. facebookresearch/sam2#337 key-mismatch failures). + +**Alternatives considered**: +- *facebookresearch/sam2 training code* — rejected: heavy extra + dependency overlapping Ultralytics' bundled SAM2, and its checkpoints + need state-dict conversion to reload into our `SAM()` inference path. +- *Export dataset + train externally* — rejected as the default (less + "integrated"), though `Prepare SAM Dataset` + folder training give a + similar offline path for users who want it. + +**Consequences**: +- ✅ No new runtime dependency; fine-tuned models drop straight into + the existing SAM selector and inference path. +- ✅ Exposure to Ultralytics internals is confined to a few + already-exercised predictor methods, guarded by + `test_sam_finetuning.py::TestUltralyticsAPI` (fails on an upgrade + that renames them). +- ⚠️ The trainer loads its **own** `SAM` instance on its `QThread` + (it does not touch `SAMUtils._model`), and must **not** use + `sam_utils._run_sync` (its re-entry guard is GUI-thread-local). The + real hazard is two SAM models (resident inference + training) on one + CUDA context, so `SAMTrainController` locks the SAM inference UI + (tools + model selector + the fine-tune menu) for the duration — + re-enabled in `training_finished` on both the success and error + paths. +- ⚠️ Decoder fine-tuning is realistically GPU-only; a CPU-only box is + hard-warned before a run (`resolve_torch_device`), and the device is + pinned so an incompatible GPU is honoured as CPU instead of crashing. +- ⚠️ Encoder features are recomputed per epoch (bounded memory) rather + than cached across epochs; revisit if large datasets need the speedup. +- ⚠️ **Loss must use the inference coordinate frame.** SAM2 letterboxes + the image (`LetterBox(1024, center=False)`, pad bottom/right) and + inference maps masks back with `ops.scale_masks(..., padding=False)`, + which crops that padding before upsampling. The training loss therefore + runs the decoder logits through the *same* `ops.scale_masks` before + comparing to the GT mask — a naive `F.interpolate` over the full + low-res mask bakes the padding into the target and the decoder learns + masks shifted by the pad (a downward shift on non-square images, caught + only during GUI testing because the e2e tests used square images). The + landscape regression test (`test_landscape_no_mask_shift`) and the + `ops.scale_masks` API-drift guard protect this. + +--- + +## ADR-022: Canvas Mask Selection Unified with the Annotation List + +**Status**: Accepted (issue bnsreenu#75) + +**Context**: Selecting an existing annotation was only possible through the +bottom-left annotation list (already `ExtendedSelection`) or by *double*-clicking +a mask on the canvas — which immediately enters vertex-edit mode. There was no +single-click select, no box/multi-select on the image, and canvas `Delete` worked +only while in vertex-edit mode. Issue #75 asked for single-click select (without +entering edit), rubber-band box select, modifier multi-select, and multi-delete — +all directly on the canvas. + +**Decision**: Add an **idle-mode selection layer** to `ImageLabel` and route it +through the *existing* annotation-list selection so delete/merge/change-class are +reused unchanged: + +- **Idle activation.** Selection is live only in `_is_select_mode()` — no drawing + tool, not editing, not SAM, no temp review. Picking any tool restores drawing. + No new tool button (matches the user's "a single click should select" ask). +- **Gestures.** Plain click selects the smallest mask under the cursor (covers + segmentation *and* bbox); click on empty space clears; drag draws a rubber band + and selects every annotation whose bounds intersect it; **Shift** makes a click + toggle and a drag additive. Double-click is unchanged (still vertex edit). +- **Ctrl stays pan.** Ctrl+drag pan (with its carefully tuned reference frame) is + left untouched; multi-select uses Shift instead of Ctrl. +- **One selection, two surfaces.** The canvas emits + `canvasSelectionChanged(annotations, mode)`; `AnnotationController.apply_canvas_selection` + computes the new set (replace/add/toggle), sets `image_label.highlighted_annotations`, + and **mirrors it onto the list** with signals blocked. `Delete` on the canvas + reuses `delete_selected_annotations` (which reads the list selection). + +**Consequences**: +- Delete / Merge / Change-Class need no new logic — they already operate on the + list selection, which the canvas now drives. +- ⚠️ Matching between the canvas and list relies on **dict value-equality**, like + the rest of the selection code (`image_label.annotations` is a deepcopy of + `all_annotations`, and PyQt round-trips `UserRole` dicts as copies, so identity + is never stable). Value-equal duplicate masks would select together — a + pre-existing, accepted limitation. See the crosscutting "Canvas selection ↔ + list selection" section. +- ⚠️ The list mirror must block `itemSelectionChanged` while selecting, or it + recurses back through `update_highlighted_annotations` and overwrites the set. + +**Selection is rendered class-colour-independent (amendment).** The first cut +drew the selected mask in solid **red** — invisible on a red-class mask, and the +default palette assigned red as the *first* class colour. Selection is now an +overlay drawn in a final pass on top of every mask, independent of class colour +and modelled on the sibling open-garden-planner app's CAD selection: a dashed +selection-blue **bounding-box marquee** (`_SELECTION_COLOR = QColor(0, 120, 215, +220)`) plus bright opaque-blue **handle squares** at the 4 corners + 4 edge +midpoints, white-cased and fixed on-screen size (`_draw_selection_overlay` in +`widgets/image_label.py`). The handles are what make selection unmistakable +regardless of mask colour (a single thin dashed outline was too faint; an earlier +marching-ants + marquee was too busy). The handles are visual markers only — +resizing via handles is a separate feature (upstream #40). The mask keeps its +class colour; the rubber-band rect uses the same blue dashed style. Separately, +the default class palette +(`core/constants.py::DEFAULT_CLASS_COLORS` / `default_class_color`) was reordered +so red is **last** (no fresh project starts on red) and muted, and the default +fill opacity dropped to `0.2` (`DEFAULT_FILL_OPACITY`) so masks don't bury the +image. Existing projects keep their persisted class colours. + +--- + ## Decisions Under Consideration ### Consider pytest-qt for Utility Testing diff --git a/docs/12_glossary.md b/docs/12_glossary.md index 03df404..21ebb5d 100644 --- a/docs/12_glossary.md +++ b/docs/12_glossary.md @@ -20,6 +20,15 @@ Carl Zeiss Image file format for multi-dimensional microscopy images. Contains m ### DINO / Grounding DINO "DINO" in this codebase refers specifically to **Grounding DINO** (IDEA-Research, 2023) — an open-set object detector that takes a natural-language phrase ("drone", "wing of an aircraft") and returns bounding boxes for matching regions of an image. Not to be confused with the self-supervised vision-only DINOv1/v2 backbones (similar name, different model). Models live under `models/grounding-dino-base/` and `models/grounding-dino-tiny/`. +### Fine-Tuning (SAM) +Continuing training of a pre-trained SAM 2 / 2.1 model on the user's own annotations so the assisted tools work better on their imagery. Because Ultralytics ships no SAM trainer, the app uses a custom loop over the Ultralytics `SAM2Model` (see [ADR-021](09_architecture_decisions.md#adr-021-sam-fine-tuning-via-a-custom-loop-over-the-ultralytics-sam2-module)). **Decoder-only** (default) trains just the mask decoder, freezing the image and prompt encoders — fast, low-VRAM, robust on modest data; optionally the image encoder is also unfrozen for heavily domain-shifted data. + +### Focal + Dice Loss +The mask-supervision loss used during SAM fine-tuning: a focal term (down-weights easy pixels, emphasises hard ones) plus a dice term (region overlap), combined ≈20:1. Standard across the SAM fine-tuning literature. + +### Mask Decoder +The lightweight SAM head that turns image embeddings + prompt embeddings into mask logits. The default fine-tuning target (`sam_mask_decoder`, ~4.2M params for the tiny model) since it is small and adapts quickly. + ### Multi-dimensional Image An image with more than 2 dimensions, typically from microscopy. Dimensions include T (time), Z (depth), C (channel), S (scene), H (height), W (width). @@ -47,6 +56,16 @@ Segment Anything Model - Meta's foundation model for image segmentation. Version ### SAM Point Mode Annotation mode where user clicks positive points (inside object) and negative points (outside object) to guide SAM segmentation. +### Select Mode (Canvas) +The idle canvas state (no drawing/SAM tool active, not editing, no temp review) in +which clicks and drags select existing masks instead of drawing. Single-click +selects, Shift toggles/adds, drag box-selects; double-click still enters vertex +edit. See ADR-022. + +### Rubber-Band Selection +A dashed selection rectangle dragged on the canvas in Select Mode; every annotation +whose bounds intersect it is selected. Shift+drag adds to the current selection. + ### Semantic Labels Single-channel image where each pixel value represents the class ID. Used for semantic segmentation training. diff --git a/setup.py b/setup.py index 04c20f4..a7c05e7 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ "numpy>=2.0.0", # pip resolves 2.4+ on Py3.14, 2.2.x on Py3.10 (last 3.10-compatible) "Pillow>=10.0.0", "tifffile>=2023.0.0", + "imagecodecs>=2023.1.23", # tifffile needs it for LZW/compressed TIFF (#56) "czifile>=2019.7.2", "opencv-python>=4.8.0", "pyyaml>=6.0.0", diff --git a/src/digitalsreeni_image_annotator/annotator_window.py b/src/digitalsreeni_image_annotator/annotator_window.py index aa6ad0e..817438c 100644 --- a/src/digitalsreeni_image_annotator/annotator_window.py +++ b/src/digitalsreeni_image_annotator/annotator_window.py @@ -24,6 +24,7 @@ from .controllers.image_controller import ImageController from .controllers.project_controller import ProjectController from .controllers.sam_controller import SAMController +from .controllers.sam_train_controller import SAMTrainController from .controllers.yolo_controller import YOLOController from .core import image_utils from .ui import theme @@ -117,6 +118,7 @@ def __init__(self): self._sam_inference_in_flight = False self.sam_controller = SAMController(self) + self.sam_train_controller = SAMTrainController(self) self.dino_controller = DINOController(self) self.yolo_controller = YOLOController(self) self.annotation_controller = AnnotationController(self) @@ -163,6 +165,11 @@ def __init__(self): self.yolo_trainer = None self.setup_yolo_menu() + # SAM fine-tuning menu + register any previously fine-tuned models so + # they appear in the SAM model selector (built during setup_ui above). + self.sam_train_controller.setup_sam_train_menu() + self.sam_train_controller.refresh_model_selector() + install_shortcuts(self) install_event_filters(self) @@ -184,6 +191,7 @@ def _connect_image_label_signals(self): il.annotationsReplaced.connect(ac.replace_annotations) il.annotationListUpdateRequested.connect(ac.update_annotation_list) il.annotationSelected.connect(ac.select_annotation_in_list) + il.canvasSelectionChanged.connect(ac.apply_canvas_selection) il.deleteSelectionRequested.connect(ac.delete_selected_annotations) il.finishPolygonRequested.connect(ac.finish_polygon) il.finishRectangleRequested.connect(ac.finish_rectangle) @@ -273,6 +281,9 @@ def load_missing_images(self, missing_images): def update_image_list(self): return self.image_controller.update_image_list() + def apply_image_filter(self): + return self.image_controller.apply_image_filter() + def select_class(self, index): return self.class_controller.select_class(index) diff --git a/src/digitalsreeni_image_annotator/controllers/annotation_controller.py b/src/digitalsreeni_image_annotator/controllers/annotation_controller.py index 00e52e5..99e0c46 100644 --- a/src/digitalsreeni_image_annotator/controllers/annotation_controller.py +++ b/src/digitalsreeni_image_annotator/controllers/annotation_controller.py @@ -36,6 +36,7 @@ from shapely.geometry import MultiPolygon, Polygon from shapely.ops import unary_union +from ..core.constants import default_class_color from ..utils import calculate_area, calculate_bbox @@ -217,7 +218,7 @@ def load_annotations(self): if not file_name: return - with open(file_name, "r") as f: + with open(file_name, "r", encoding='utf-8') as f: self.mw.loaded_json = json.load(f) self.mw.class_list.clear() @@ -229,7 +230,7 @@ def load_annotations(self): if class_name not in self.mw.image_label.class_colors: color = QColor( - Qt.GlobalColor(len(self.mw.image_label.class_colors) % 16 + 7) + default_class_color(len(self.mw.image_label.class_colors)) ) self.mw.image_label.class_colors[class_name] = color @@ -259,9 +260,11 @@ def load_annotations(self): for img in json_images.values(): updated_all_images.append(img) - self.mw.image_list.addItem(img["file_name"]) self.mw.all_images = updated_all_images + # Rebuild the list in sorted order (issue #60). The reconciliation + # loop above already consumed the pre-sort row/index alignment. + self.mw.update_image_list() self.mw.all_annotations.clear() for annotation in self.mw.loaded_json["annotations"]: @@ -331,6 +334,9 @@ def load_annotations(self): def clear_highlighted_annotation(self): self.mw.image_label.highlighted_annotations.clear() + # Selection is gone — Merge/Change Class must follow, or they linger + # enabled against an empty list selection after an image/slice switch. + self._sync_selection_buttons(0) self.mw.image_label.update() def update_highlighted_annotations(self): @@ -339,9 +345,59 @@ def update_highlighted_annotations(self): item.data(Qt.ItemDataRole.UserRole) for item in selected_items ] self.mw.image_label.update() + self._sync_selection_buttons(len(selected_items)) + + def _sync_selection_buttons(self, count): + """Merge needs ≥2 annotations; Change Class needs ≥1. Shared by the + list-driven and canvas-driven (issue #75) selection paths.""" + self.mw.merge_button.setEnabled(count >= 2) + self.mw.change_class_button.setEnabled(count > 0) + + def apply_canvas_selection(self, annotations, mode): + """Apply a selection change that originated on the canvas (issue #75) + and mirror it onto the annotation list so Delete / Merge / Change + Class operate on the same set. Matching uses dict value-equality, + consistent with the rest of the selection code. + + ``mode`` is one of ``"replace"``, ``"add"``, ``"toggle"``. + """ + current = list(self.mw.image_label.highlighted_annotations) + + def contains(seq, ann): + return any(a == ann for a in seq) + + if mode == "replace": + new = list(annotations) + elif mode == "add": + new = current + [a for a in annotations if not contains(current, a)] + elif mode == "toggle": + new = list(current) + for a in annotations: + match = next((x for x in new if x == a), None) + if match is not None: + new.remove(match) + else: + new.append(a) + else: + return - self.mw.merge_button.setEnabled(len(selected_items) >= 2) - self.mw.change_class_button.setEnabled(len(selected_items) > 0) + self.mw.image_label.highlighted_annotations = new + + # Mirror onto the list widget. Block signals so the programmatic + # selection doesn't retrigger itemSelectionChanged → + # update_highlighted_annotations, which would overwrite `new` with + # the list items' own (all_annotations) object identities. + lst = self.mw.annotation_list + lst.blockSignals(True) + lst.clearSelection() + for i in range(lst.count()): + item = lst.item(i) + if contains(new, item.data(Qt.ItemDataRole.UserRole)): + item.setSelected(True) + lst.blockSignals(False) + + self._sync_selection_buttons(len(new)) + self.mw.image_label.update() def highlight_annotation_in_list(self, annotation): for i in range(self.mw.annotation_list.count()): @@ -420,6 +476,8 @@ def delete_selected_annotations(self): self.sort_annotations_by_class() self.mw.image_label.highlighted_annotations.clear() + # Selection is now empty — Merge/Change Class must follow. + self._sync_selection_buttons(0) self.mw.image_label.update() self.mw.update_slice_list_colors() @@ -538,7 +596,7 @@ def are_all_polygons_connected(polygons): msg_box.setText("Do you want to keep the original annotations?") msg_box.setIcon(QMessageBox.Icon.Question) - keep_button = msg_box.addButton("Keep", QMessageBox.ButtonRole.YesRole) + msg_box.addButton("Keep", QMessageBox.ButtonRole.YesRole) delete_button = msg_box.addButton("Delete", QMessageBox.ButtonRole.NoRole) cancel_button = msg_box.addButton("Cancel", QMessageBox.ButtonRole.RejectRole) diff --git a/src/digitalsreeni_image_annotator/controllers/class_controller.py b/src/digitalsreeni_image_annotator/controllers/class_controller.py index d45b2c4..dcd08c3 100644 --- a/src/digitalsreeni_image_annotator/controllers/class_controller.py +++ b/src/digitalsreeni_image_annotator/controllers/class_controller.py @@ -27,6 +27,8 @@ QMessageBox, ) +from ..core.constants import default_class_color + class ClassController(QObject): def __init__(self, main_window): @@ -96,6 +98,13 @@ def update_slice_list_colors(self): self.mw.slice_list.repaint() + # Re-apply hook for the image-list annotation filter. Contract: + # every annotation-mutation site either calls this method directly + # or emits annotationsBatchSaved, whose handler + # (_on_annotations_batch_saved) calls it. New mutation paths must + # keep one of those two routes. + self.mw.image_controller.apply_image_filter() + def add_class(self, class_name=None, color=None): if not self.mw.image_label.check_unsaved_changes(): return @@ -136,7 +145,7 @@ def add_class(self, class_name=None, color=None): if color is None: color = QColor( - Qt.GlobalColor(len(self.mw.image_label.class_colors) % 16 + 7) + default_class_color(len(self.mw.image_label.class_colors)) ) elif isinstance(color, str): color = QColor(color) diff --git a/src/digitalsreeni_image_annotator/controllers/dino_controller.py b/src/digitalsreeni_image_annotator/controllers/dino_controller.py index b21edd8..76ff0ff 100644 --- a/src/digitalsreeni_image_annotator/controllers/dino_controller.py +++ b/src/digitalsreeni_image_annotator/controllers/dino_controller.py @@ -38,6 +38,8 @@ QTextEdit, ) +from ..core.constants import default_class_color + class DINOReviewEventFilter(QObject): """Application-wide event filter that lets Enter / Escape accept or @@ -211,6 +213,9 @@ def run_dino_detection_single(self): "Please add at least one class with phrases.") return + from ..core.torch_utils import maybe_warn_cpu_fallback + maybe_warn_cpu_fallback(self.mw) + self.mw.btn_detect_single.setEnabled(False) self.mw.btn_detect_batch.setEnabled(False) @@ -360,6 +365,9 @@ def run_dino_detection_batch(self): "Please add at least one class with phrases.") return + from ..core.torch_utils import maybe_warn_cpu_fallback + maybe_warn_cpu_fallback(self.mw) + # Prevent stale temp annotations from a prior single-image review from # confusing the batch results handler or the DINOReviewEventFilter. self.mw.image_label.temp_annotations = [] @@ -681,7 +689,7 @@ def add_temp_classes(self, temp_annotations): for temp_class_name, annotations in temp_annotations.items(): if temp_class_name not in self.mw.image_label.class_colors: color = QColor( - Qt.GlobalColor(len(self.mw.image_label.class_colors) % 16 + 7) + default_class_color(len(self.mw.image_label.class_colors)) ) self.mw.image_label.class_colors[temp_class_name] = color self.mw.image_label.annotations[temp_class_name] = annotations diff --git a/src/digitalsreeni_image_annotator/controllers/image_controller.py b/src/digitalsreeni_image_annotator/controllers/image_controller.py index 1d30557..e1d426f 100644 --- a/src/digitalsreeni_image_annotator/controllers/image_controller.py +++ b/src/digitalsreeni_image_annotator/controllers/image_controller.py @@ -86,9 +86,121 @@ def __init__(self, main_window): self.mw = main_window def update_image_list(self): + # Rebuild (and sort) the list, preserving the current selection + # without switching images. + self.sort_image_list() + + def sort_image_list(self, select_name=None, do_switch=False): + """Populate image_list in alphabetical order (upstream issue #60). + + Sorts the model (`all_images`) and the view together so the + `all_images[i]` ↔ `image_list.item(i)` positional invariant holds + (relied on by COCO import reconciliation). `setSortingEnabled` is + deliberately NOT used: `currentRowChanged` is wired to + `switch_image`, so a live re-sort would fire spurious image + switches. We rebuild with signals blocked instead, then re-select + explicitly. + + select_name: file to select after the rebuild (defaults to the + previously-current item). do_switch: call switch_image once for + the selected item (used when adding new images). + """ + current = None + if self.mw.image_list.currentItem() is not None: + current = self.mw.image_list.currentItem().text() + + self.mw.all_images.sort( + key=lambda info: ( + info.get("file_name", "").casefold(), + info.get("file_name", ""), + ) + ) + + self.mw.image_list.blockSignals(True) self.mw.image_list.clear() - for image_info in self.mw.all_images: - self.mw.image_list.addItem(image_info["file_name"]) + for info in self.mw.all_images: + self.mw.image_list.addItem(info["file_name"]) + self.mw.image_list.blockSignals(False) + + self.apply_image_filter() + + target = select_name if select_name is not None else current + if target is not None: + items = self.mw.image_list.findItems( + target, Qt.MatchFlag.MatchExactly + ) + if items: + self.mw.image_list.blockSignals(True) + self.mw.image_list.setCurrentItem(items[0]) + self.mw.image_list.blockSignals(False) + if do_switch: + self.switch_image(items[0]) + + def image_has_annotations(self, image_info): + """True if the image (or, for multi-dim images, any of its slices) + has at least one annotation.""" + + def _non_empty(by_class): + return bool(by_class) and any(by_class.values()) + + file_name = image_info["file_name"] + if _non_empty(self.mw.all_annotations.get(file_name, {})): + return True + + if image_info.get("is_multi_slice", False): + base_name = os.path.splitext(file_name)[0] + slices = self.mw.image_slices.get(base_name) + if slices: + return any( + _non_empty(self.mw.all_annotations.get(slice_name, {})) + for slice_name, _ in slices + ) + # Slices not extracted yet (e.g. load cancelled) — slice keys + # are f"{base_name}_T1_Z5_..." so a "{base_name}_" prefix match + # is exact enough; a bare substring match would not be. Caveat: + # this also matches "{base_name}_8bit" artifact keys, which + # redefine_dimensions deliberately excludes — acceptable here + # since an _8bit key with annotations still means "this image + # has annotations". + prefix = base_name + "_" + return any( + key.startswith(prefix) and _non_empty(by_class) + for key, by_class in self.mw.all_annotations.items() + ) + + return False + + def apply_image_filter(self): + """Hide image-list rows that don't match the annotation-status + filter (upstream issue #27). + + Rows are hidden via setRowHidden, never removed: other code + (DINO batch navigation, COCO import) iterates the list by row + index, and hiding fires no currentRowChanged so it cannot + trigger a spurious switch_image. + + A non-matching row is hidden even when it is the current + selection — hiding does not change `current_image`, so the canvas + keeps showing the worked-on image while its row leaves the list + (e.g. the image just gained its first annotation under the + "Without annotations" filter). Keyboard nav skips hidden rows. + """ + combo = getattr(self.mw, "image_filter_combo", None) + if combo is None: + return + mode = combo.currentIndex() # 0 = all, 1 = without, 2 = with + if mode == 0: + # Default case runs on every update_slice_list_colors — + # keep it a plain unhide pass with no annotation scans. + for i in range(self.mw.image_list.count()): + self.mw.image_list.setRowHidden(i, False) + return + infos = {info["file_name"]: info for info in self.mw.all_images} + for i in range(self.mw.image_list.count()): + info = infos.get(self.mw.image_list.item(i).text()) + annotated = bool(info) and self.image_has_annotations(info) + hide = annotated if mode == 1 else not annotated + self.mw.image_list.setRowHidden(i, hide) def setup_slice_list(self): self.mw.slice_list = QListWidget() @@ -114,7 +226,7 @@ def open_images(self): self.add_images_to_list(file_names) def add_images_to_list(self, file_names): - first_added_item = None + first_added_name = None for file_name in file_names: base_name = os.path.basename(file_name) if base_name not in self.mw.image_paths: @@ -127,7 +239,25 @@ def add_images_to_list(self, file_names): } if file_name.lower().endswith((".tif", ".tiff", ".czi")): - self.load_multi_slice_image(file_name) + try: + self.load_multi_slice_image(file_name) + except ValueError as e: + # LZW/compressed TIFFs need the optional imagecodecs + # package; without it tifffile raises ValueError and + # the app used to crash (#56). Skip the file with an + # actionable message instead of a half-added entry. + if self._is_missing_codec_error(e): + QMessageBox.critical( + self.mw, + "Cannot open TIFF", + f"'{base_name}' uses a compression that requires " + "the 'imagecodecs' package, which is not " + "installed.\n\nInstall it with:\n" + " pip install imagecodecs\n\n" + "then reopen the image.", + ) + continue + raise base_name_without_ext = os.path.splitext(base_name)[0] if ( base_name_without_ext in self.mw.image_slices @@ -151,20 +281,34 @@ def add_images_to_list(self, file_names): image_info["width"] = image.width() self.mw.all_images.append(image_info) - item = QListWidgetItem(base_name) - self.mw.image_list.addItem(item) - if first_added_item is None: - first_added_item = item + if first_added_name is None: + first_added_name = base_name self.mw.image_paths[base_name] = file_name - if first_added_item: - self.mw.image_list.setCurrentItem(first_added_item) - self.switch_image(first_added_item) + # Rebuild the list in sorted order and select/switch to the first + # newly added image. Skipped during project load (the list is + # rebuilt once via update_ui afterwards, and load picks row 0) to + # avoid an O(n^2) re-sort per image. + if first_added_name is not None and not self.mw.is_loading_project: + self.sort_image_list(select_name=first_added_name, do_switch=True) + else: + self.apply_image_filter() if not self.mw.is_loading_project: self.mw.auto_save() + @staticmethod + def _is_missing_codec_error(exc): + """True if a tifffile read failed because the imagecodecs package + is unavailable for the TIFF's compression — e.g. LZW (#56). + + Matches only the reliable 'imagecodecs' token: tifffile names the + package in every such message. A broader 'compression' match would + silently swallow unrelated ValueErrors behind a misleading dialog. + """ + return "imagecodecs" in str(exc).lower() + def update_all_images(self, new_image_info): for info in new_image_info: if not any( diff --git a/src/digitalsreeni_image_annotator/controllers/io_controller.py b/src/digitalsreeni_image_annotator/controllers/io_controller.py index f3de559..1fcbdff 100644 --- a/src/digitalsreeni_image_annotator/controllers/io_controller.py +++ b/src/digitalsreeni_image_annotator/controllers/io_controller.py @@ -11,10 +11,11 @@ import os -from PyQt6.QtCore import Qt from PyQt6.QtGui import QColor from PyQt6.QtWidgets import QFileDialog, QMessageBox +from ..core.constants import default_class_color + from ..io.export_formats import ( export_coco_json, export_labeled_images, @@ -157,7 +158,7 @@ def import_annotations(mw): new_id = len(mw.class_mapping) + 1 mw.class_mapping[category_name] = new_id mw.image_label.class_colors[category_name] = QColor( - Qt.GlobalColor(new_id % 16 + 7) + default_class_color(new_id - 1) ) print("Updating UI") diff --git a/src/digitalsreeni_image_annotator/controllers/project_controller.py b/src/digitalsreeni_image_annotator/controllers/project_controller.py index 70e92c3..18521e0 100644 --- a/src/digitalsreeni_image_annotator/controllers/project_controller.py +++ b/src/digitalsreeni_image_annotator/controllers/project_controller.py @@ -116,7 +116,7 @@ def open_specific_project(self, project_file): try: self.mw.is_loading_project = True - with open(project_file, "r") as f: + with open(project_file, "r", encoding='utf-8') as f: project_data = json.load(f) self.mw.clear_all(show_messages=False) @@ -146,12 +146,10 @@ def open_specific_project(self, project_file): self.mw.initialize_yolo_trainer() self.update_window_title() + # No success dialog — the loaded canvas + updated window title + # already make a successful open obvious; a modal just adds a + # click. Errors below still surface as dialogs. print(f"Project opened successfully: {project_file}") - QMessageBox.information( - self.mw, - "Project Opened", - f"Project opened successfully: {os.path.basename(project_file)}", - ) except Exception as e: self.mw.is_loading_project = False @@ -490,7 +488,7 @@ def save_project(self, show_message=True): if dino_cfg["phrases"] or dino_cfg["thresholds"]: project_data["dino_config"] = dino_cfg - with open(self.mw.current_project_file, "w") as f: + with open(self.mw.current_project_file, "w", encoding='utf-8') as f: json.dump(image_utils.convert_to_serializable(project_data), f, indent=2) if show_message: diff --git a/src/digitalsreeni_image_annotator/controllers/sam_controller.py b/src/digitalsreeni_image_annotator/controllers/sam_controller.py index 096fa4e..b4c9e9b 100644 --- a/src/digitalsreeni_image_annotator/controllers/sam_controller.py +++ b/src/digitalsreeni_image_annotator/controllers/sam_controller.py @@ -214,6 +214,10 @@ def change_sam_model(self, model_name): if model_name != "Pick a SAM Model": print(f"Changed SAM model to: {model_name}") + # One-time dialog if CUDA exists but the torch wheels can't + # run it (e.g. Pascal sm_61 on torch>=2.8) — upstream #57. + from ..core.torch_utils import maybe_warn_cpu_fallback + maybe_warn_cpu_fallback(self.mw) else: self.deactivate_sam_tools() print("SAM model unset") diff --git a/src/digitalsreeni_image_annotator/controllers/sam_train_controller.py b/src/digitalsreeni_image_annotator/controllers/sam_train_controller.py new file mode 100644 index 0000000..7ac6d05 --- /dev/null +++ b/src/digitalsreeni_image_annotator/controllers/sam_train_controller.py @@ -0,0 +1,273 @@ +"""SAM 2 fine-tuning coordination controller. + +Mirrors ``YOLOController``'s shape (menu → validate → config dialog → +``QThread`` worker → finished handler) but drives the Ultralytics-native +:class:`~..training.sam_trainer.SAMFineTuner` instead of ``model.train`` — +Ultralytics has no SAM trainer (see the SAM fine-tuning ADR). + +Key differences from YOLO training: +- **GPU gate**: decoder fine-tuning is realistically GPU-only, so a CPU-only + box is hard-warned before a run starts. +- **Re-entrancy**: the trainer loads its own SAM instance on a worker thread + (separate from the resident inference model), so this is not a one-model + race. But two SAM models on one CUDA context is, so the SAM inference UI + (tools + model selector + this menu) is locked for the duration and the + trainer never goes through ``sam_utils._run_sync``. +- On success the fine-tuned checkpoint is registered into the SAM model + selector so it's immediately usable for annotation. +""" + +import os + +from PyQt6.QtCore import QObject, QThread, pyqtSignal +from PyQt6.QtGui import QAction +from PyQt6.QtWidgets import QFileDialog, QMessageBox + +from ..dialogs.sam_trainer_dialog import SAMTrainConfigDialog +from ..dialogs.yolo_trainer import TrainingInfoDialog +from ..training.sam_trainer import ( + SAMFineTuner, + list_custom_models, + make_custom_filename, +) + + +class SAMTrainingThread(QThread): + """Runs a fine-tuning job off the GUI thread. Emits the result dict on + success or the exception's string on failure (same contract as YOLO's + ``TrainingThread``).""" + + finished = pyqtSignal(object) + + def __init__(self, trainer: SAMFineTuner, base_model, groups, config): + super().__init__() + self.trainer = trainer + self.base_model = base_model + self.groups = groups + self.config = config + + def run(self): + try: + result = self.trainer.train(self.base_model, self.groups, **self.config) + self.finished.emit(result) + except Exception as e: # surfaced to the GUI thread by training_finished + import traceback + traceback.print_exc() + self.finished.emit(str(e)) + + +class SAMTrainController(QObject): + def __init__(self, main_window): + super().__init__(main_window) + self.mw = main_window + + # -- menu ---------------------------------------------------------------- + + def setup_sam_train_menu(self): + menu = self.mw.menuBar().addMenu("SAM &Fine-Tune (beta)") + self._menu = menu + + train_project = QAction("Train on Current Project…", self.mw) + train_project.triggered.connect(self.train_on_project) + menu.addAction(train_project) + + prepare = QAction("Prepare SAM Dataset…", self.mw) + prepare.triggered.connect(self.prepare_dataset) + menu.addAction(prepare) + + train_folder = QAction("Train from Dataset Folder…", self.mw) + train_folder.triggered.connect(self.train_from_folder) + menu.addAction(train_folder) + + menu.addSeparator() + refresh = QAction("Refresh Fine-Tuned Model List", self.mw) + refresh.triggered.connect(self.refresh_model_selector) + menu.addAction(refresh) + + # -- entry points -------------------------------------------------------- + + def train_on_project(self): + from ..training.sam_dataset import build_groups_from_project + + groups = build_groups_from_project( + self.mw.all_annotations, + self.mw.image_paths, + self.mw.slices, + self.mw.image_slices, + ) + if not groups: + QMessageBox.warning( + self.mw, "No Training Data", + "No usable annotations found. Annotate some objects (polygons " + "or boxes) first.", + ) + return + self._launch(groups) + + def prepare_dataset(self): + from ..io.export_formats import export_sam_dataset + + out_dir = QFileDialog.getExistingDirectory(self.mw, "Choose dataset output folder") + if not out_dir: + return + try: + _, manifest = export_sam_dataset( + self.mw.all_annotations, + self.mw.class_mapping, + self.mw.image_paths, + self.mw.slices, + self.mw.image_slices, + out_dir, + ) + except Exception as e: + QMessageBox.critical(self.mw, "Export Failed", str(e)) + return + QMessageBox.information( + self.mw, "Dataset Prepared", + f"SAM dataset written to:\n{manifest}\n\nUse 'Train from Dataset " + f"Folder…' to fine-tune on it.", + ) + + def train_from_folder(self): + from ..training.sam_dataset import build_groups_from_folder + + folder = QFileDialog.getExistingDirectory(self.mw, "Choose prepared SAM dataset folder") + if not folder: + return + try: + groups = build_groups_from_folder(folder) + except Exception as e: + QMessageBox.critical(self.mw, "Load Failed", str(e)) + return + if not groups: + QMessageBox.warning(self.mw, "Empty Dataset", "No usable entries in that folder.") + return + self._launch(groups) + + # -- run ----------------------------------------------------------------- + + def _launch(self, groups): + if hasattr(self.mw, "sam_training_thread") and self.mw.sam_training_thread is not None \ + and self.mw.sam_training_thread.isRunning(): + QMessageBox.information(self.mw, "Training Busy", "A fine-tuning run is already in progress.") + return + if not self._gpu_gate(): + return + + dialog = SAMTrainConfigDialog(self.mw) + if dialog.exec() != SAMTrainConfigDialog.DialogCode.Accepted: + return + cfg = dialog.get_config() + base_model = cfg.pop("base_model") + out_name = cfg.pop("out_name") + cfg["out_path"] = make_custom_filename(base_model, out_name) + + # The trainer loads its OWN SAM instance on a worker thread. The + # resident inference model is separate, but it stays reachable from the + # GUI — running inference (its own CUDA work) alongside training on the + # same device/context invites OOM and contention. Deactivate the tools + # and lock the SAM inference UI for the duration so no concurrent + # inference or model-swap can be triggered. + self.mw.sam_controller.deactivate_sam_tools() + self._set_sam_ui_locked(True) + + # Everything from here to start() must restore the UI if it raises — + # otherwise training_finished (the only other unlock site) never fires + # and the SAM tools stay disabled until app restart. + try: + self.mw.sam_finetuner = SAMFineTuner() + if not hasattr(self.mw, "sam_training_dialog"): + self.mw.sam_training_dialog = TrainingInfoDialog(self.mw) + self.mw.sam_training_dialog.setWindowTitle("SAM Fine-Tuning Progress") + # The dialog is reused across runs (same TrainingInfoDialog class as + # YOLO, but a separate instance) — clear the previous run's log so a + # new run doesn't append under stale output. + self.mw.sam_training_dialog.info_text.clear() + self.mw.sam_training_dialog.stop_button.setEnabled(True) + self.mw.sam_training_dialog.stop_button.setText("Stop Training") + self.mw.sam_training_dialog.show() + + self.mw.sam_finetuner.progress_signal.connect(self.mw.sam_training_dialog.update_info) + self.mw.sam_training_dialog.stop_signal.connect(self.mw.sam_finetuner.stop_training_signal) + + self.mw.sam_training_thread = SAMTrainingThread( + self.mw.sam_finetuner, base_model, groups, cfg + ) + self.mw.sam_training_thread.finished.connect(self.training_finished) + self.mw.sam_training_thread.start() + except Exception as e: + self._set_sam_ui_locked(False) + QMessageBox.critical(self.mw, "Could Not Start Training", str(e)) + + def _gpu_gate(self) -> bool: + """Warn (and let the user back out) when no usable GPU is present.""" + from ..core.torch_utils import resolve_torch_device + + device, _ = resolve_torch_device() + if device == "cuda": + return True + choice = QMessageBox.warning( + self.mw, "No GPU — training will be very slow", + "No usable CUDA GPU was detected. Fine-tuning SAM on CPU is " + "impractically slow (minutes per image). Continue anyway with a " + "small run?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, + ) + return choice == QMessageBox.StandardButton.Yes + + def _set_sam_ui_locked(self, locked: bool): + """Disable/enable SAM inference controls + the fine-tune menu so no + concurrent inference, model swap, or second training run can start + while a run is in flight.""" + for attr in ("sam_box_button", "sam_points_button", "sam_model_selector"): + widget = getattr(self.mw, attr, None) + if widget is not None: + widget.setEnabled(not locked) + if getattr(self, "_menu", None) is not None: + self._menu.setEnabled(not locked) + + def training_finished(self, result): + self._set_sam_ui_locked(False) + dlg = self.mw.sam_training_dialog + # Training is over — DISABLE Stop so a post-completion click can't strand + # the button on "Stopping…" forever (it only un-sticks when a run + # finishes, and none is running). _launch re-enables it for the next run. + dlg.stop_button.setEnabled(False) + dlg.stop_button.setText("Stop Training") + try: + self.mw.sam_finetuner.progress_signal.disconnect(dlg.update_info) + dlg.stop_signal.disconnect(self.mw.sam_finetuner.stop_training_signal) + except TypeError: + pass # already disconnected + + if isinstance(result, str): + QMessageBox.critical(self.mw, "Fine-Tuning Error", f"An error occurred:\n{result}") + return + + self.refresh_model_selector() + out_path = result.get("out_path") + display = f"★ {os.path.splitext(os.path.basename(out_path))[0]}" if out_path else None + if display and hasattr(self.mw, "sam_model_selector"): + idx = self.mw.sam_model_selector.findText(display) + if idx >= 0: + self.mw.sam_model_selector.setCurrentIndex(idx) + QMessageBox.information( + self.mw, "Fine-Tuning Complete", + f"Saved and verified:\n{out_path}\n\nSelected it in the SAM model " + f"dropdown — use SAM-box / SAM-points to try it.", + ) + + # -- selector ------------------------------------------------------------ + + def refresh_model_selector(self): + """Register fine-tuned checkpoints and (re)add them to the SAM dropdown.""" + customs = list_custom_models() + self.mw.sam_utils.register_custom_models(customs) + selector = getattr(self.mw, "sam_model_selector", None) + if selector is None: + return + existing = {selector.itemText(i) for i in range(selector.count())} + for display in customs: + if display not in existing: + selector.addItem(display) diff --git a/src/digitalsreeni_image_annotator/core/constants.py b/src/digitalsreeni_image_annotator/core/constants.py index 32dcf68..6a01753 100644 --- a/src/digitalsreeni_image_annotator/core/constants.py +++ b/src/digitalsreeni_image_annotator/core/constants.py @@ -21,5 +21,32 @@ DEFAULT_ZOOM = 100 # Annotation settings -DEFAULT_FILL_OPACITY = 0.3 +# Mask fill alpha — kept low so the underlying image stays legible through +# overlapping masks (the border still carries the class colour). +DEFAULT_FILL_OPACITY = 0.2 + +# Default class colour palette (tab10-style, moderately muted so masks don't +# overpower the image). Red is intentionally LAST so a fresh project's first +# class isn't red — selection highlighting is class-colour-independent, but +# starting on red was needlessly harsh. Hex strings keep this module Qt-free. +DEFAULT_CLASS_COLORS = [ + "#1F77B4", # blue + "#FF7F0E", # orange + "#2CA02C", # green + "#9467BD", # purple + "#17BECF", # cyan + "#BCBD22", # olive + "#E377C2", # pink + "#8C564B", # brown + "#7F7F7F", # gray + "#D62728", # red (last) +] + + +def default_class_color(index: int) -> str: + """Hex colour for the index-th class, cycling through DEFAULT_CLASS_COLORS. + + Callers wrap the result in ``QColor(...)`` (kept out of this module so the + core stays Qt-free).""" + return DEFAULT_CLASS_COLORS[index % len(DEFAULT_CLASS_COLORS)] diff --git a/src/digitalsreeni_image_annotator/core/torch_utils.py b/src/digitalsreeni_image_annotator/core/torch_utils.py new file mode 100644 index 0000000..6acaa2f --- /dev/null +++ b/src/digitalsreeni_image_annotator/core/torch_utils.py @@ -0,0 +1,95 @@ +""" +Shared torch device selection (upstream issue #57). + +torch.cuda.is_available() returns True for any visible CUDA device, even +when the installed torch wheels contain no kernels for its compute +capability (e.g. torch >= 2.8 wheels ship sm_70+ only, so a Pascal +GTX 1050 / sm_61 passes the availability check but every kernel launch +dies with "CUDA error: no kernel image is available for execution on +the device"). This module detects that mismatch up front and falls back +to CPU with an actionable warning instead of a cryptic crash mid-inference. +""" + +_cached_result = None + + +def resolve_torch_device(): + """Return ``(device, warning)``. + + ``device`` is ``"cuda"`` or ``"cpu"``; ``warning`` is None or a + human-readable explanation of why CUDA was rejected (printed once on + first call). The result is cached for the process lifetime so SAM, + DINO and YOLO all share one decision. + """ + global _cached_result + if _cached_result is None: + _cached_result = _resolve() + if _cached_result[1]: + print(f"[torch] {_cached_result[1]}") + return _cached_result + + +def _resolve(): + try: + import torch + except Exception: + return ("cpu", None) + + try: + if not torch.cuda.is_available(): + return ("cpu", None) + + # Deliberately device-0-only: on a mixed multi-GPU box this may + # force CPU even though a later index is supported; the app never + # selects a non-default CUDA device anywhere, so index 0 is what + # inference would actually run on. + major, minor = torch.cuda.get_device_capability(0) + device_sm = major * 10 + minor + compiled_sms = _parse_arch_list(torch.cuda.get_arch_list()) + + if compiled_sms and device_sm < min(compiled_sms): + gpu = torch.cuda.get_device_name(0) + return ( + "cpu", + f"GPU '{gpu}' (compute capability sm_{device_sm}) is not " + f"supported by the installed PyTorch build (compiled for " + f"sm_{min(compiled_sms)}+). Falling back to CPU. For GPU " + f"inference install an older PyTorch with support for this " + f"card, e.g.:\n" + f" pip install torch==2.4.1 torchvision==0.19.1 " + f"--index-url https://download.pytorch.org/whl/cu121", + ) + return ("cuda", None) + except Exception as e: + # Any probing failure: prefer a working CPU path over a crash. + return ("cpu", f"Could not verify CUDA compatibility ({e}); " + f"falling back to CPU.") + + +_warning_shown = False + + +def maybe_warn_cpu_fallback(parent=None): + """Show the CUDA-incompatibility warning as a dialog, once per session. + + No-op when the device resolved cleanly (CUDA usable, or no GPU at + all — running on CPU without a discrete GPU is expected and needs + no dialog). + """ + global _warning_shown + _, warning = resolve_torch_device() + if warning is None or _warning_shown: + return + _warning_shown = True + from PyQt6.QtWidgets import QMessageBox + QMessageBox.warning(parent, "GPU not usable — running on CPU", warning) + + +def _parse_arch_list(arch_list): + """Extract numeric sm values from e.g. ["sm_70", "sm_80", "compute_90"].""" + sms = [] + for arch in arch_list: + prefix, _, num = arch.rpartition("_") + if prefix.endswith("sm") and num.isdigit(): + sms.append(int(num)) + return sms diff --git a/src/digitalsreeni_image_annotator/dialogs/coco_json_combiner.py b/src/digitalsreeni_image_annotator/dialogs/coco_json_combiner.py index c2eb5c6..c67ff70 100644 --- a/src/digitalsreeni_image_annotator/dialogs/coco_json_combiner.py +++ b/src/digitalsreeni_image_annotator/dialogs/coco_json_combiner.py @@ -63,7 +63,7 @@ def combine_json_files(self): try: for file_path in self.json_files: - with open(file_path, 'r') as f: + with open(file_path, 'r', encoding='utf-8') as f: data = json.load(f) # Combine categories @@ -98,7 +98,7 @@ def combine_json_files(self): output_file, _ = QFileDialog.getSaveFileName(self, "Save Combined JSON", "", "JSON Files (*.json)") if output_file: - with open(output_file, 'w') as f: + with open(output_file, 'w', encoding='utf-8') as f: json.dump(combined_data, f, indent=2) QMessageBox.information(self, "Success", f"Combined JSON saved to {output_file}") diff --git a/src/digitalsreeni_image_annotator/dialogs/dataset_splitter.py b/src/digitalsreeni_image_annotator/dialogs/dataset_splitter.py index df2169f..c9e09b7 100644 --- a/src/digitalsreeni_image_annotator/dialogs/dataset_splitter.py +++ b/src/digitalsreeni_image_annotator/dialogs/dataset_splitter.py @@ -160,7 +160,7 @@ def split_images_only(self): QMessageBox.information(self, "Success", "Dataset split successfully!") def split_images_and_annotations(self): - with open(self.json_file, 'r') as f: + with open(self.json_file, 'r', encoding='utf-8') as f: coco_data = json.load(f) image_files = [img['file_name'] for img in coco_data['images']] @@ -246,7 +246,7 @@ def save_coco_annotations(self, data, subset): subset_dir = os.path.join(self.output_directory, subset) os.makedirs(subset_dir, exist_ok=True) output_file = os.path.join(subset_dir, f"{subset}_annotations.json") - with open(output_file, 'w') as f: + with open(output_file, 'w', encoding='utf-8') as f: json.dump(data, f, indent=2) def split_yolo_format(self, coco_data, train_images, val_images, test_images): @@ -289,7 +289,7 @@ def split_yolo_format(self, coco_data, train_images, val_images, test_images): # Create YOLO format labels label_file = os.path.join(labels_dir, os.path.splitext(image_file)[0] + ".txt") - with open(label_file, "w") as f: + with open(label_file, "w", encoding='utf-8') as f: for ann in annotations: # Convert COCO class id to YOLO class id yolo_class = categories[ann["category_id"]] @@ -310,7 +310,7 @@ def split_yolo_format(self, coco_data, train_images, val_images, test_images): } yaml_data.update(yaml_paths) # Add only paths for non-empty splits - with open(os.path.join(self.output_directory, 'data.yaml'), 'w') as f: + with open(os.path.join(self.output_directory, 'data.yaml'), 'w', encoding='utf-8') as f: yaml.dump(yaml_data, f, default_flow_style=False) QMessageBox.information(self, "Success", "Dataset and YOLO annotations split successfully!") diff --git a/src/digitalsreeni_image_annotator/dialogs/dicom_converter.py b/src/digitalsreeni_image_annotator/dialogs/dicom_converter.py index b53bc60..2aa374c 100644 --- a/src/digitalsreeni_image_annotator/dialogs/dicom_converter.py +++ b/src/digitalsreeni_image_annotator/dialogs/dicom_converter.py @@ -3,8 +3,7 @@ import numpy as np from datetime import datetime from PyQt6.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, - QLabel, QProgressDialog, QRadioButton, QButtonGroup, - QMessageBox, QApplication, QGroupBox) + QLabel, QProgressDialog, QRadioButton, QMessageBox, QApplication, QGroupBox) from PyQt6.QtCore import Qt import pydicom from pydicom.pixel_data_handlers.util import apply_voi_lut @@ -230,7 +229,7 @@ def convert_dicom(self): metadata_file = os.path.join(self.output_directory, os.path.splitext(os.path.basename(self.input_file))[0] + "_metadata.json") - with open(metadata_file, 'w') as f: + with open(metadata_file, 'w', encoding='utf-8') as f: json.dump(series_metadata, f, indent=2) # Get physical sizes from metadata diff --git a/src/digitalsreeni_image_annotator/dialogs/dino_merge_dialog.py b/src/digitalsreeni_image_annotator/dialogs/dino_merge_dialog.py index 44bb6ac..0b4b8dc 100644 --- a/src/digitalsreeni_image_annotator/dialogs/dino_merge_dialog.py +++ b/src/digitalsreeni_image_annotator/dialogs/dino_merge_dialog.py @@ -6,7 +6,6 @@ import json import math -import os import random import traceback from collections import defaultdict @@ -180,7 +179,7 @@ def _run(self): # Load and validate records = [] for path in coco_files: - with open(path) as f: + with open(path, encoding='utf-8') as f: data = json.load(f) if not data.get("images") or not data.get("annotations"): self._log_msg(f" [skip] {path.name}: empty.") @@ -295,9 +294,9 @@ def _build_coco(imgs): train_data = _build_coco(train_imgs) val_data = _build_coco(val_imgs) - with open(out_path / "train.json", "w") as f: + with open(out_path / "train.json", "w", encoding='utf-8') as f: json.dump(train_data, f, indent=2) - with open(out_path / "val.json", "w") as f: + with open(out_path / "val.json", "w", encoding='utf-8') as f: json.dump(val_data, f, indent=2) self._log_msg(f"Train images: {len(train_imgs)}, annotations: {len(train_data['annotations'])}") diff --git a/src/digitalsreeni_image_annotator/dialogs/dino_phrase_editor.py b/src/digitalsreeni_image_annotator/dialogs/dino_phrase_editor.py index 107fd09..fe4fda6 100644 --- a/src/digitalsreeni_image_annotator/dialogs/dino_phrase_editor.py +++ b/src/digitalsreeni_image_annotator/dialogs/dino_phrase_editor.py @@ -65,15 +65,17 @@ def __init__(self, parent=None): self.verticalHeader().setSectionResizeMode( QHeaderView.ResizeMode.ResizeToContents) self.setMaximumHeight(160) - # No hardcoded background colors — pick them up from the active - # stylesheet so the table integrates with both light and dark - # mode. The earlier "background: #e0e0e0" produced a bright bar - # across the top of the panel in dark mode. - # No font-size either: the compact size is set (and scaled with - # ui_font_pt) by the appended overrides in ui/theme.py. + # Structural only — no colours here. Header background/text come from + # the active stylesheet's QHeaderView::section rule (light: + # default_stylesheet, dark: soft_dark_stylesheet), so the header + # matches both themes. The earlier inline "background: palette(mid); + # color: palette(text)" resolved against the *OS* palette (dark on + # some boxes) and painted the header black in the app's light mode — + # see the No Hardcoded Colors Rule. No font-size either: the compact + # size is set (and scaled with ui_font_pt) by the overrides in + # ui/theme.py. self.setStyleSheet( - "QHeaderView::section { font-weight: bold; " - " padding: 2px; background-color: palette(mid); color: palette(text); }" + "QHeaderView::section { font-weight: bold; padding: 2px; }" ) def _make_spin(self, value=0.25): diff --git a/src/digitalsreeni_image_annotator/dialogs/image_augmenter.py b/src/digitalsreeni_image_annotator/dialogs/image_augmenter.py index e78c3b6..e160d64 100644 --- a/src/digitalsreeni_image_annotator/dialogs/image_augmenter.py +++ b/src/digitalsreeni_image_annotator/dialogs/image_augmenter.py @@ -169,7 +169,7 @@ def select_coco_json(self): self.coco_file, _ = QFileDialog.getOpenFileName(self, "Select COCO JSON File", "", "JSON Files (*.json)") if self.coco_file: self.coco_label.setText(f"COCO JSON File: {os.path.basename(self.coco_file)}") - with open(self.coco_file, 'r') as f: + with open(self.coco_file, 'r', encoding='utf-8') as f: self.coco_data = json.load(f) self.coco_check.setChecked(True) # Automatically check the box when a file is loaded @@ -270,7 +270,7 @@ def start_augmentation(self): if self.coco_check.isChecked(): output_coco_path = os.path.join(self.output_dir, "augmented_annotations.json") - with open(output_coco_path, 'w') as f: + with open(output_coco_path, 'w', encoding='utf-8') as f: json.dump(augmented_coco_data, f, indent=2) QMessageBox.information(self, "Augmentation Complete", "Image and annotation augmentation has been completed successfully.") diff --git a/src/digitalsreeni_image_annotator/dialogs/project_search.py b/src/digitalsreeni_image_annotator/dialogs/project_search.py index c19c7e7..b6d7064 100644 --- a/src/digitalsreeni_image_annotator/dialogs/project_search.py +++ b/src/digitalsreeni_image_annotator/dialogs/project_search.py @@ -1,7 +1,7 @@ from PyQt6.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QLineEdit, QPushButton, - QDateEdit, QLabel, QListWidget, QDialogButtonBox, QFormLayout, + QDateEdit, QListWidget, QDialogButtonBox, QFormLayout, QFileDialog, QMessageBox) -from PyQt6.QtCore import Qt, QDate +from PyQt6.QtCore import QDate import os import json from datetime import datetime @@ -83,7 +83,7 @@ def perform_search(self): if filename.endswith('.iap'): project_path = os.path.join(root, filename) try: - with open(project_path, 'r') as f: + with open(project_path, 'r', encoding='utf-8') as f: project_data = json.load(f) if self.project_matches(project_data, query, start_date, end_date): diff --git a/src/digitalsreeni_image_annotator/dialogs/sam_trainer_dialog.py b/src/digitalsreeni_image_annotator/dialogs/sam_trainer_dialog.py new file mode 100644 index 0000000..9041dae --- /dev/null +++ b/src/digitalsreeni_image_annotator/dialogs/sam_trainer_dialog.py @@ -0,0 +1,100 @@ +"""Config dialog for SAM 2 fine-tuning. + +Collects the hyperparameters for ``training.sam_trainer.SAMFineTuner.train``. +Progress is shown via the shared ``TrainingInfoDialog`` (reused from the YOLO +trainer) — this module only owns the *config* dialog. +""" + +from PyQt6.QtWidgets import ( + QCheckBox, + QComboBox, + QDialog, + QDialogButtonBox, + QDoubleSpinBox, + QFormLayout, + QLabel, + QLineEdit, + QSpinBox, + QVBoxLayout, +) + +from ..inference.sam_utils import MODEL_NAMES + + +class SAMTrainConfigDialog(QDialog): + """Modal config for a fine-tuning run. Read results via :meth:`get_config`.""" + + def __init__(self, parent=None, *, default_name="my_finetune"): + super().__init__(parent) + self.setWindowTitle("Fine-Tune SAM Model") + layout = QVBoxLayout(self) + form = QFormLayout() + + self.base_model = QComboBox() + self.base_model.addItems(MODEL_NAMES) + # tiny/small are the realistic choices for desktop fine-tuning. + form.addRow("Base model:", self.base_model) + + self.out_name = QLineEdit(default_name) + form.addRow("Save as:", self.out_name) + + self.epochs = QSpinBox() + self.epochs.setRange(1, 1000) + self.epochs.setValue(10) + form.addRow("Epochs:", self.epochs) + + self.lr = QDoubleSpinBox() + self.lr.setDecimals(6) + self.lr.setRange(1e-6, 1e-1) + self.lr.setSingleStep(1e-5) + self.lr.setValue(1e-4) + form.addRow("Learning rate:", self.lr) + + self.batch_size = QSpinBox() + self.batch_size.setRange(1, 64) + self.batch_size.setValue(2) + self.batch_size.setToolTip( + "Gradient-accumulation count — the optimizer steps every N images " + "(all of an image's objects are backpropagated together)." + ) + form.addRow("Batch size:", self.batch_size) + + self.prompt_type = QComboBox() + self.prompt_type.addItems(["bbox", "point"]) + self.prompt_type.setToolTip("Prompt derived from each ground-truth mask during training.") + form.addRow("Train prompt:", self.prompt_type) + + self.train_encoder = QCheckBox("Also fine-tune image encoder (slower, needs more VRAM/data)") + form.addRow("", self.train_encoder) + + layout.addLayout(form) + + note = QLabel( + "Decoder-only (default) is fast and robust on modest data. " + "Fine-tuning the image encoder can help on heavily domain-shifted " + "imagery but needs more GPU memory and labels." + ) + note.setWordWrap(True) + # No inline color — inside a QDialog `palette(text)` resolves against the + # stale OS palette and renders near-white in light mode. Let the global + # QLabel stylesheet rule provide a theme-correct colour (No Hardcoded + # Colors Rule, docs/08_crosscutting_concepts.md). + layout.addWidget(note) + + buttons = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel + ) + buttons.accepted.connect(self.accept) + buttons.rejected.connect(self.reject) + layout.addWidget(buttons) + + def get_config(self) -> dict: + return { + "base_model": self.base_model.currentText(), + "out_name": self.out_name.text().strip() or "my_finetune", + "epochs": self.epochs.value(), + "lr": self.lr.value(), + "batch_size": self.batch_size.value(), + "prompt_type": self.prompt_type.currentText(), + "freeze_image_encoder": not self.train_encoder.isChecked(), + } diff --git a/src/digitalsreeni_image_annotator/dialogs/yolo_trainer.py b/src/digitalsreeni_image_annotator/dialogs/yolo_trainer.py index 07536fe..e4b0974 100644 --- a/src/digitalsreeni_image_annotator/dialogs/yolo_trainer.py +++ b/src/digitalsreeni_image_annotator/dialogs/yolo_trainer.py @@ -1,7 +1,7 @@ import os from PyQt6.QtWidgets import QFileDialog, QMessageBox from PyQt6.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QPushButton, - QLineEdit, QLabel, QFileDialog, QDialogButtonBox) + QLineEdit, QLabel, QDialogButtonBox) import yaml import numpy as np from pathlib import Path @@ -11,8 +11,8 @@ from collections import deque -from PyQt6.QtWidgets import QDialog, QVBoxLayout, QTextEdit, QPushButton -from PyQt6.QtCore import Qt, pyqtSignal, QObject +from PyQt6.QtWidgets import QTextEdit +from PyQt6.QtCore import pyqtSignal, QObject class TrainingInfoDialog(QDialog): stop_signal = pyqtSignal() @@ -140,7 +140,7 @@ def prepare_dataset(self): ) yaml_path = Path(yaml_path) - with yaml_path.open('r') as f: + with yaml_path.open('r', encoding='utf-8') as f: yaml_content = yaml.safe_load(f) # Update paths for new YOLO v5+ structure @@ -148,7 +148,7 @@ def prepare_dataset(self): yaml_content['val'] = 'images/val' # Changed from train/images yaml_content['test'] = '../test/images' - with yaml_path.open('w') as f: + with yaml_path.open('w', encoding='utf-8') as f: yaml.dump(yaml_content, f, default_flow_style=False) self.yaml_path = str(yaml_path) @@ -158,7 +158,7 @@ def load_yaml(self, yaml_path=None): if yaml_path is None: yaml_path, _ = QFileDialog.getOpenFileName(self.main_window, "Select YOLO Dataset YAML", "", "YAML Files (*.yaml *.yml)") if yaml_path and os.path.exists(yaml_path): - with open(yaml_path, 'r') as f: + with open(yaml_path, 'r', encoding='utf-8') as f: try: yaml_data = yaml.safe_load(f) print(f"Loaded YAML contents: {yaml_data}") @@ -175,7 +175,7 @@ def load_yaml(self, yaml_path=None): self.yaml_path = yaml_path # Write the updated YAML back to the file - with open(yaml_path, 'w') as f: + with open(yaml_path, 'w', encoding='utf-8') as f: yaml.dump(yaml_data, f, default_flow_style=False) return True @@ -218,7 +218,7 @@ def train_model(self, epochs=100, imgsz=640): print(f"Training with YAML: {yaml_path}") print(f"YAML directory: {yaml_dir}") - with yaml_path.open('r') as f: + with yaml_path.open('r', encoding='utf-8') as f: yaml_content = yaml.safe_load(f) print(f"YAML content: {yaml_content}") @@ -237,13 +237,15 @@ def train_model(self, epochs=100, imgsz=640): # Write updated YAML with adjusted paths temp_yaml_path = yaml_dir / 'temp_train.yaml' - with temp_yaml_path.open('w') as f: + with temp_yaml_path.open('w', encoding='utf-8') as f: yaml.dump(yaml_content, f, default_flow_style=False) print(f"Training with updated YAML: {temp_yaml_path}") print(f"Updated YAML content: {yaml_content}") - results = self.model.train(data=str(temp_yaml_path), epochs=epochs, imgsz=imgsz) + from ..core.torch_utils import resolve_torch_device + device, _ = resolve_torch_device() + results = self.model.train(data=str(temp_yaml_path), epochs=epochs, imgsz=imgsz, device=device) return results finally: # Clear the callback @@ -256,7 +258,7 @@ def verify_dataset_structure(self): yaml_path = Path(self.yaml_path) yaml_dir = yaml_path.parent - with yaml_path.open('r') as f: + with yaml_path.open('r', encoding='utf-8') as f: yaml_content = yaml.safe_load(f) # Use paths from YAML content @@ -277,9 +279,9 @@ def verify_dataset_structure(self): missing_dirs.append(f"Validation labels directory: {val_labels_dir}") if missing_dirs: - raise FileNotFoundError(f"The following directories were not found:\n" + "\n".join(missing_dirs)) + raise FileNotFoundError("The following directories were not found:\n" + "\n".join(missing_dirs)) - print(f"Dataset structure verified:") + print("Dataset structure verified:") print(f"Train images: {train_images_dir}") print(f"Train labels: {train_labels_dir}") print(f"Val images: {val_images_dir}") @@ -288,7 +290,7 @@ def verify_dataset_structure(self): def check_ultralytics_settings(self): settings_path = Path.home() / ".config" / "Ultralytics" / "settings.yaml" if settings_path.exists(): - with settings_path.open('r') as f: + with settings_path.open('r', encoding='utf-8') as f: settings = yaml.safe_load(f) print(f"Ultralytics settings: {settings}") else: @@ -330,7 +332,7 @@ def on_epoch_end(self, trainer): info = f"Epoch {epoch}/{total_epochs}, Loss: {loss:.4f}" self.epoch_info.append(info) - display_text = f"Current Progress:\n" + "\n".join(self.epoch_info) + display_text = "Current Progress:\n" + "\n".join(self.epoch_info) if self.progress_callback: self.progress_callback(display_text) @@ -340,7 +342,7 @@ def save_model(self): raise ValueError("No model to save. Please train a model first.") save_path, _ = QFileDialog.getSaveFileName(self.main_window, "Save YOLO Model", "", "YOLO Model (*.pt)") if save_path: - self.model.export(save_path) + self.model.save(save_path) return True return False @@ -349,7 +351,7 @@ def load_prediction_model(self, model_path, yaml_path): try: self.model = YOLO(model_path) - with open(yaml_path, 'r') as f: + with open(yaml_path, 'r', encoding='utf-8') as f: self.prediction_yaml = yaml.safe_load(f) if 'names' not in self.prediction_yaml: @@ -375,12 +377,10 @@ def load_prediction_model(self, model_path, yaml_path): def predict(self, input_data): if self.model is None: raise ValueError("No model loaded. Please load a model first.") - if isinstance(input_data, str): - # It's a file path - results = self.model(input_data, task='segment', conf=self.conf_threshold, save=False, show=False) - elif isinstance(input_data, np.ndarray): - # It's a numpy array - results = self.model(input_data, task='segment', conf=self.conf_threshold, save=False, show=False) + from ..core.torch_utils import resolve_torch_device + device, _ = resolve_torch_device() + if isinstance(input_data, (str, np.ndarray)): + results = self.model(input_data, task='segment', conf=self.conf_threshold, save=False, show=False, device=device) else: raise ValueError("Invalid input type. Expected file path or numpy array.") diff --git a/src/digitalsreeni_image_annotator/inference/dino_utils.py b/src/digitalsreeni_image_annotator/inference/dino_utils.py index f1cea6a..f8ad372 100644 --- a/src/digitalsreeni_image_annotator/inference/dino_utils.py +++ b/src/digitalsreeni_image_annotator/inference/dino_utils.py @@ -73,15 +73,15 @@ def __init__(self): # ── model lifecycle ─────────────────────────────────────────────── def _resolve_device(self) -> str: - """Pick CUDA if available; honour DINO_DEVICE env override.""" + """Pick CUDA if available and usable; honour DINO_DEVICE override.""" env = os.environ.get("DINO_DEVICE") if env: return env - try: - import torch - return "cuda" if torch.cuda.is_available() else "cpu" - except Exception: - return "cpu" + # Shared helper also rejects GPUs whose compute capability the + # installed torch wheels can't run (upstream issue #57). + from ..core.torch_utils import resolve_torch_device + device, _ = resolve_torch_device() + return device def _load_model_blocking(self, model_path: str) -> None: """Load (cache) the Grounding DINO model for ``model_path``.""" diff --git a/src/digitalsreeni_image_annotator/inference/sam_utils.py b/src/digitalsreeni_image_annotator/inference/sam_utils.py index badd6df..a584d99 100644 --- a/src/digitalsreeni_image_annotator/inference/sam_utils.py +++ b/src/digitalsreeni_image_annotator/inference/sam_utils.py @@ -302,6 +302,22 @@ def __init__(self): self.current_sam_model: str | None = None self._model = None # ultralytics.SAM instance once loaded self._loaded_model_file: str | None = None + self._device: str | None = None # resolved at model load + # Fine-tuned checkpoints: {display_name: checkpoint_path}. Populated + # from training.sam_trainer.list_custom_models() so they load through + # the same SAM(path) path as the built-ins (see SAM fine-tuning ADR). + self.custom_models: dict[str, str] = {} + + def register_custom_models(self, mapping: dict) -> None: + """Merge fine-tuned model entries so they become selectable.""" + self.custom_models.update(mapping or {}) + + def _resolve_model_file(self, model_name: str) -> str: + if model_name in MODEL_NAMES: + return os.path.join(SAM_MODELS_DIR, MODEL_FILES[model_name]) + if model_name in self.custom_models: + return self.custom_models[model_name] + raise ValueError(f"Unknown SAM model: {model_name}") # ── model lifecycle ──────────────────────────────────────────────── @@ -313,37 +329,39 @@ def change_sam_model(self, model_name: str) -> None: print("SAM model unset") return - if model_name not in MODEL_NAMES: - raise ValueError(f"Unknown SAM model: {model_name}") + # Resolve to a checkpoint path first (built-in name or fine-tuned + # model) so an unknown name fails before we touch a worker thread. + model_file = self._resolve_model_file(model_name) # Load on a worker thread to avoid stalling the UI on the # ~1-3 s torch model-load. Behaves synchronously to callers and # re-raises any load-time exception (network, corrupt weights, # CUDA OOM) — only flip `current_sam_model` AFTER success so # callers don't see a stale name on failure. - _run_sync(self._load_model_blocking, model_name) + _run_sync(self._load_model_blocking, model_file) self.current_sam_model = model_name self.model_changed.emit(model_name) print(f"SAM model loaded: {model_name}") - def _load_model_blocking(self, model_name: str) -> None: + def _load_model_blocking(self, model_file: str) -> None: # Lazy import keeps app startup fast for users who never use SAM. from ultralytics import SAM + from ..core.torch_utils import resolve_torch_device + + self._device, _ = resolve_torch_device() self._log_device() - model_file = os.path.join(SAM_MODELS_DIR, MODEL_FILES[model_name]) os.makedirs(os.path.dirname(model_file), exist_ok=True) self._model = SAM(model_file) self._loaded_model_file = model_file - @staticmethod - def _log_device() -> None: + def _log_device(self) -> None: try: import torch - if torch.cuda.is_available(): + if self._device == "cuda": dev = torch.cuda.get_device_name(0) print(f"[SAM] Using CUDA: {torch.version.cuda} — {dev}") else: - print("[SAM] No GPU available, running on CPU") + print("[SAM] Running on CPU") except Exception: pass @@ -406,7 +424,9 @@ def apply_sam_points(self, image: QImage, positive_points, negative_points): def _sam_points_blocking(self, image_np, positive_points, negative_points): all_points = [positive_points + negative_points] all_labels = [([1] * len(positive_points)) + ([0] * len(negative_points))] - results = self._model(image_np, points=all_points, labels=all_labels) + results = self._model( + image_np, points=all_points, labels=all_labels, device=self._device + ) masks = results[0].masks.data.cpu().numpy() confidences = results[0].boxes.conf.cpu().numpy() @@ -438,7 +458,7 @@ def apply_sam_prediction(self, image: QImage, bbox): ) def _sam_bbox_blocking(self, image_np, bbox): - results = self._model(image_np, bboxes=[bbox]) + results = self._model(image_np, bboxes=[bbox], device=self._device) res = results[0] if not (hasattr(res, "masks") and res.masks is not None): return None @@ -481,7 +501,7 @@ def apply_sam_predictions_batch(self, image: QImage, bboxes: list): ) def _sam_batch_blocking(self, image_np, bboxes): - results = self._model(image_np, bboxes=bboxes) + results = self._model(image_np, bboxes=bboxes, device=self._device) res = results[0] if not (hasattr(res, "masks") and res.masks is not None): # Build a fresh dict per bbox so callers can mutate one diff --git a/src/digitalsreeni_image_annotator/io/export_formats.py b/src/digitalsreeni_image_annotator/io/export_formats.py index ae181aa..6eff3d0 100644 --- a/src/digitalsreeni_image_annotator/io/export_formats.py +++ b/src/digitalsreeni_image_annotator/io/export_formats.py @@ -19,7 +19,7 @@ def convert_to_coco(all_annotations, class_mapping, image_paths, slices, image_s with tempfile.TemporaryDirectory() as temp_dir: json_file_path, images_dir = export_coco_json(all_annotations, class_mapping, image_paths, slices, image_slices, temp_dir) - with open(json_file_path, 'r') as f: + with open(json_file_path, 'r', encoding='utf-8') as f: coco_data = json.load(f) return coco_data, images_dir @@ -118,7 +118,7 @@ def export_coco_json(all_annotations, class_mapping, image_paths, slices, image_ # Save COCO JSON file json_file_path = os.path.join(output_dir, json_filename) - with open(json_file_path, 'w') as f: + with open(json_file_path, 'w', encoding='utf-8') as f: json.dump(coco_format, f, indent=2) return json_file_path, images_dir @@ -205,7 +205,7 @@ def export_yolo_v4(all_annotations, class_mapping, image_paths, slices, image_sl # Write YOLO format annotation label_file = os.path.splitext(file_name_img)[0] + '.txt' - with open(os.path.join(labels_dir, label_file), 'w') as f: + with open(os.path.join(labels_dir, label_file), 'w', encoding='utf-8') as f: for class_name, class_annotations in annotations.items(): if class_name not in class_to_index: print(f"[YOLO v4] warning: class {class_name!r} not in class_mapping, skipped") @@ -236,7 +236,7 @@ def export_yolo_v4(all_annotations, class_mapping, image_paths, slices, image_sl # Save YAML file in the output directory yaml_path = os.path.join(output_dir, 'data.yaml') - with open(yaml_path, 'w') as f: + with open(yaml_path, 'w', encoding='utf-8') as f: yaml.dump(yaml_data, f, default_flow_style=False) return train_dir, yaml_path @@ -280,7 +280,7 @@ def export_yolo_v5plus(all_annotations, class_mapping, image_paths, slices, imag print(f"[YOLO v5+] image={image_name!r} annotation-classes={list(annotations.keys()) if annotations else '(none)'}") # Skip if there are no annotations for this image/slice if not annotations: - print(f"[YOLO v5+] skipping: no annotations") + print("[YOLO v5+] skipping: no annotations") continue # For simplicity, we'll put all data in the train directory @@ -334,7 +334,7 @@ def export_yolo_v5plus(all_annotations, class_mapping, image_paths, slices, imag label_file = os.path.splitext(file_name_img)[0] + '.txt' label_path = os.path.join(labels_dir, label_file) ann_lines = 0 - with open(label_path, 'w') as f: + with open(label_path, 'w', encoding='utf-8') as f: for class_name, class_annotations in annotations.items(): if class_name not in class_to_index: print(f"[YOLO v5+] warning: class {class_name!r} not in class_mapping, skipped") @@ -372,13 +372,85 @@ def export_yolo_v5plus(all_annotations, class_mapping, image_paths, slices, imag # Save YAML file in the output directory yaml_path = os.path.join(output_dir, 'data.yaml') - with open(yaml_path, 'w') as f: + with open(yaml_path, 'w', encoding='utf-8') as f: yaml.dump(yaml_data, f, default_flow_style=False) return output_dir, yaml_path +def export_sam_dataset(all_annotations, class_mapping, image_paths, slices, image_slices, output_dir): + """Export a SAM fine-tuning dataset: ``images/`` + ``manifest.json``. + + The manifest is the authoritative training source — per-instance ``bbox``/ + ``segmentation`` specs are rasterised to masks deterministically at train + time (see ``training.sam_dataset``), so no separate mask PNGs are written. + Image resolution mirrors ``export_yolo_v5plus`` (slices via ``slices`` / + ``image_slices``; regular images via ``image_paths``; TIFF/CZI skipped). + + Returns ``(output_dir, manifest_path)``. + """ + images_dir = os.path.join(output_dir, 'images') + os.makedirs(images_dir, exist_ok=True) + slice_map = {slice_name: qimage for slice_name, qimage in slices} + + manifest = {"classes": list(class_mapping.keys()), "images": []} + for image_name, annotations in all_annotations.items(): + if not annotations: + continue + + # Resolve + save the image (same branching as export_yolo_v5plus). + if image_name in slice_map or ('_' in image_name and '.' not in image_name): + qimage = slice_map.get(image_name) + if qimage is None: + for stack_slices in image_slices.values(): + qimage = next((s[1] for s in stack_slices if s[0] == image_name), None) + if qimage is not None: + break + if qimage is None: + continue + # basename guards against a separator in an image/slice key + # escaping images/ during write. + file_name_img = f"{os.path.basename(image_name)}.png" + save_path = os.path.join(images_dir, file_name_img) + if not os.path.exists(save_path): + qimage.save(save_path) + else: + image_path = image_paths.get(image_name) + if image_path is None: + image_path = next( + (path for name, path in image_paths.items() if image_name in name), + None, + ) + if not image_path: + continue + if image_path.lower().endswith(('.tif', '.tiff', '.czi')): + continue + file_name_img = os.path.basename(image_name) + dst_path = os.path.join(images_dir, file_name_img) + if not os.path.exists(dst_path): + shutil.copy2(image_path, dst_path) + + instances = [] + for class_name, class_annotations in annotations.items(): + for ann in class_annotations: + if ann.get('segmentation'): + instances.append({"class": class_name, "segmentation": ann['segmentation']}) + elif ann.get('bbox'): + instances.append({"class": class_name, "bbox": ann['bbox']}) + if instances: + manifest["images"].append({ + "image": os.path.join('images', file_name_img), + "instances": instances, + }) + + manifest_path = os.path.join(output_dir, 'manifest.json') + with open(manifest_path, 'w', encoding='utf-8') as f: + json.dump(manifest, f, indent=2) + print(f"[SAM dataset] wrote {len(manifest['images'])} image entries -> {manifest_path}") + return output_dir, manifest_path + + def export_labeled_images(all_annotations, class_mapping, image_paths, slices, image_slices, output_dir): # Create output directories images_dir = os.path.join(output_dir, 'images') @@ -477,7 +549,7 @@ def export_labeled_images(all_annotations, class_mapping, image_paths, slices, i # Create summary text file summary_path = os.path.join(labeled_images_dir, 'class_summary.txt') - with open(summary_path, 'w') as f: + with open(summary_path, 'w', encoding='utf-8') as f: f.write("Classes (folder names):\n") for class_name, files in class_summary.items(): if files: # Only include classes that have annotations @@ -575,7 +647,7 @@ def export_semantic_labels(all_annotations, class_mapping, image_paths, slices, # Create class mapping text file mapping_path = os.path.join(segmented_images_dir, 'class_pixel_mapping.txt') - with open(mapping_path, 'w') as f: + with open(mapping_path, 'w', encoding='utf-8') as f: f.write("Pixel Value : Class Name\n") for class_name, pixel_value in class_to_pixel.items(): f.write(f"{pixel_value} : {class_name}\n") @@ -680,7 +752,7 @@ def export_pascal_voc_bbox(all_annotations, class_mapping, image_paths, slices, # Save the XML file xml_str = minidom.parseString(ET.tostring(root)).toprettyxml(indent=" ") xml_filename = os.path.splitext(file_name_img)[0] + '.xml' - with open(os.path.join(annotations_dir, xml_filename), 'w') as f: + with open(os.path.join(annotations_dir, xml_filename), 'w', encoding='utf-8') as f: f.write(xml_str) return output_dir @@ -798,7 +870,7 @@ def export_pascal_voc_both(all_annotations, class_mapping, image_paths, slices, # Save the XML file xml_str = minidom.parseString(ET.tostring(root)).toprettyxml(indent=" ") xml_filename = os.path.splitext(file_name_img)[0] + '.xml' - with open(os.path.join(annotations_dir, xml_filename), 'w') as f: + with open(os.path.join(annotations_dir, xml_filename), 'w', encoding='utf-8') as f: f.write(xml_str) return output_dir \ No newline at end of file diff --git a/src/digitalsreeni_image_annotator/io/import_formats.py b/src/digitalsreeni_image_annotator/io/import_formats.py index 6de3669..650fa0a 100644 --- a/src/digitalsreeni_image_annotator/io/import_formats.py +++ b/src/digitalsreeni_image_annotator/io/import_formats.py @@ -4,17 +4,12 @@ import yaml from PIL import Image -from PyQt6.QtCore import QRectF -from PyQt6.QtGui import QColor -from PyQt6.QtWidgets import QMessageBox, QFileDialog - -import os -import json from PyQt6.QtWidgets import QMessageBox + def import_coco_json(file_path, class_mapping): try: - with open(file_path, 'r') as f: + with open(file_path, 'r', encoding='utf-8') as f: coco_data = json.load(f) # Validate required fields @@ -127,7 +122,7 @@ def import_yolo_v4(yaml_file_path, class_mapping): directory_path = os.path.dirname(yaml_file_path) - with open(yaml_file_path, 'r') as f: + with open(yaml_file_path, 'r', encoding='utf-8') as f: yaml_data = yaml.safe_load(f) class_names = yaml_data.get('names', []) @@ -184,7 +179,7 @@ def import_yolo_v4(yaml_file_path, class_mapping): imported_annotations[img_file] = {} label_path = os.path.join(labels_dir, label_file) - with open(label_path, 'r') as f: + with open(label_path, 'r', encoding='utf-8') as f: lines = f.readlines() for line in lines: @@ -267,7 +262,7 @@ def import_yolo_v5plus(yaml_file_path, class_mapping): root_dir = os.path.dirname(yaml_file_path) - with open(yaml_file_path, 'r') as f: + with open(yaml_file_path, 'r', encoding='utf-8') as f: yaml_data = yaml.safe_load(f) class_names = yaml_data.get('names', []) @@ -320,7 +315,7 @@ def import_yolo_v5plus(yaml_file_path, class_mapping): imported_annotations[img_file] = {} label_path = os.path.join(labels_dir, label_file) - with open(label_path, 'r') as f: + with open(label_path, 'r', encoding='utf-8') as f: lines = f.readlines() for line in lines: diff --git a/src/digitalsreeni_image_annotator/training/__init__.py b/src/digitalsreeni_image_annotator/training/__init__.py new file mode 100644 index 0000000..975f438 --- /dev/null +++ b/src/digitalsreeni_image_annotator/training/__init__.py @@ -0,0 +1 @@ +"""SAM 2 / 2.1 fine-tuning: training engine and dataset builders.""" diff --git a/src/digitalsreeni_image_annotator/training/sam_dataset.py b/src/digitalsreeni_image_annotator/training/sam_dataset.py new file mode 100644 index 0000000..a8fcf94 --- /dev/null +++ b/src/digitalsreeni_image_annotator/training/sam_dataset.py @@ -0,0 +1,107 @@ +"""Build SAM fine-tuning :class:`SampleGroup`s from either the live project +annotations or a prepared on-disk dataset folder. + +The project path mirrors the image-resolution logic in +``io.export_formats.export_yolo_v5plus`` (slice lookup via ``slices`` / +``image_slices``; regular images via ``image_paths`` with exact-then-substring +match; TIFF/CZI source files skipped in favour of their extracted slices), so a +dataset that exports cleanly to YOLO also trains cleanly here. +""" + +from __future__ import annotations + +import json +import os + +from PyQt6.QtGui import QImage + +from .sam_trainer import SampleGroup +from ..inference.sam_utils import _qimage_to_numpy + + +def _specs_for(annotations) -> list: + """Flatten ``{class: [ann, ...]}`` into raw instance specs the + :class:`SampleGroup` rasterises lazily.""" + specs = [] + for _class_name, class_annotations in (annotations or {}).items(): + for ann in class_annotations: + if ann.get("segmentation"): + specs.append({"segmentation": ann["segmentation"]}) + elif ann.get("bbox"): + specs.append({"bbox": ann["bbox"]}) + return specs + + +def build_groups_from_project(all_annotations, image_paths, slices, image_slices): + """Live project annotations → ``list[SampleGroup]``. + + Images load lazily (one at a time during training) to bound memory; in-RAM + slice QImages are reused directly. + """ + slice_map = {name: qimage for name, qimage in slices} + groups = [] + + for image_name, image_annotations in all_annotations.items(): + specs = _specs_for(image_annotations) + if not specs: + continue + + if image_name in slice_map or ("_" in image_name and "." not in image_name): + qimage = slice_map.get(image_name) + if qimage is None: + for stack_slices in image_slices.values(): + qimage = next((s[1] for s in stack_slices if s[0] == image_name), None) + if qimage is not None: + break + if qimage is None: + print(f"[SAM dataset] skip slice {image_name!r}: no image data") + continue + # Convert the in-memory slice QImage to numpy HERE, on the GUI + # thread. The array is later consumed by the training worker + # thread; reading constBits() of a live, GUI-shared QImage from + # another thread is exactly what _qimage_to_numpy warns against, + # so we hand the worker a thread-owned copy instead of a lambda + # that defers the buffer read onto the worker. + arr = _qimage_to_numpy(qimage) + groups.append(SampleGroup(lambda a=arr: a, specs)) + continue + + image_path = image_paths.get(image_name) + if image_path is None: + image_path = next( + (p for name, p in image_paths.items() if image_name in name), None + ) + if not image_path: + print(f"[SAM dataset] skip {image_name!r}: no image_paths entry") + continue + if image_path.lower().endswith((".tif", ".tiff", ".czi")): + print(f"[SAM dataset] skip TIFF/CZI source {image_name!r} (use slices)") + continue + groups.append(SampleGroup(lambda p=image_path: _qimage_to_numpy(QImage(p)), specs)) + + return groups + + +# ── prepared folder ────────────────────────────────────────────────────────── + +def build_groups_from_folder(folder: str): + """Read a folder produced by ``export_sam_dataset`` → ``list[SampleGroup]``. + + Expects ``/manifest.json`` with entries + ``{"image": "images/x.png", "instances": [{"bbox": [...]}|{"segmentation": [...]}]}``. + """ + manifest_path = os.path.join(folder, "manifest.json") + if not os.path.exists(manifest_path): + raise FileNotFoundError(f"No manifest.json in {folder}") + with open(manifest_path, "r", encoding="utf-8") as f: + manifest = json.load(f) + + groups = [] + for entry in manifest.get("images", []): + img_rel = entry["image"] + img_path = os.path.join(folder, img_rel) + specs = entry.get("instances", []) + if not specs or not os.path.exists(img_path): + continue + groups.append(SampleGroup(lambda p=img_path: _qimage_to_numpy(QImage(p)), specs)) + return groups diff --git a/src/digitalsreeni_image_annotator/training/sam_trainer.py b/src/digitalsreeni_image_annotator/training/sam_trainer.py new file mode 100644 index 0000000..6c63d60 --- /dev/null +++ b/src/digitalsreeni_image_annotator/training/sam_trainer.py @@ -0,0 +1,443 @@ +"""SAM 2 / 2.1 fine-tuning engine — Ultralytics-native custom training loop. + +Why this exists +--------------- +Ultralytics has **no SAM trainer**: ``SAM.task_map`` registers only a +*predictor* for the ``segment`` task, so ``SAM(...).train()`` raises +``NotImplementedError`` (verified on ultralytics 8.4.51). We therefore +cannot mirror the YOLO ``model.train(data=yaml, ...)`` path. + +What we do instead +------------------ +``SAM(...).model`` is a plain ``nn.Module`` (``SAM2Model``) exposing +``image_encoder`` / ``sam_prompt_encoder`` / ``sam_mask_decoder``, and the +Ultralytics ``SAM2Predictor`` already implements the forward path in reusable +pieces — ``get_im_features`` (image encoder) and ``prompt_inference`` / +``_inference_features`` (prompt encoder + mask decoder). We call those methods +**under autograd** (they are not wrapped in ``inference_mode`` unless reached +via the public ``__call__``), add a focal+dice loss and an AdamW step, and +save a checkpoint that reloads through the existing ``SAM(path)`` inference +path. No extra dependency, and the result drops straight into the app's SAM +model selector. + +This keeps our exposure confined to a thin adapter over already-exercised +predictor methods. The spike that validated the mechanic lives in the PR for +issue bnsreenu#73. + +Threading +--------- +``train()`` is **blocking and CPU/GPU-bound**; the controller runs it on a +dedicated ``QThread`` (never the GUI thread). Unlike SAM inference it does +*not* go through ``sam_utils._run_sync`` — that helper's re-entry guard is +GUI-thread-local. + +The trainer loads its **own** ``SAM`` instance (see ``_build_predictor``); it +never touches ``SAMUtils._model``, so this is not "two threads driving one +model". The hazard it *does* create is two SAM models (the resident inference +one and this training one) competing for the same GPU/CUDA context. The +controller therefore locks the SAM inference UI (tools + model selector + the +fine-tune menu) for the duration so no concurrent inference or model swap can +be triggered while a run is in flight. +""" + +from __future__ import annotations + +import os +import random + +import cv2 +import numpy as np +from PyQt6.QtCore import QObject, pyqtSignal + +from ..inference.sam_utils import MODEL_FILES, MODEL_NAMES, SAM_MODELS_DIR + +# Fine-tuned checkpoints live alongside the base weights, namespaced so they +# never collide with Ultralytics' auto-downloaded base files. +SAM_CUSTOM_DIR = os.path.join(SAM_MODELS_DIR, "custom") + + +def make_custom_filename(base_model: str, name: str) -> str: + """Build a fine-tuned checkpoint path under ``SAM_CUSTOM_DIR``. + + Ultralytics' ``build_sam`` selects the architecture by ``ckpt.endswith(token)`` + where ``token`` is the base file name (e.g. ``sam2_t.pt``). A fine-tuned file + therefore **must** keep that suffix or ``SAM(path)`` raises "not a supported + SAM model". We sanitise the user label and append ``_``. + """ + token = MODEL_FILES.get(base_model, os.path.basename(base_model)) + safe = "".join(c if c.isalnum() or c in "-_" else "_" for c in name).strip("_") + safe = safe or "finetuned" + return os.path.join(SAM_CUSTOM_DIR, f"{safe}_{token}") + + +def list_custom_models() -> dict: + """``{display_name: path}`` for fine-tuned checkpoints, for the SAM selector.""" + out = {} + if os.path.isdir(SAM_CUSTOM_DIR): + for fn in sorted(os.listdir(SAM_CUSTOM_DIR)): + if fn.endswith(".pt"): + out[f"★ {os.path.splitext(fn)[0]}"] = os.path.join(SAM_CUSTOM_DIR, fn) + return out + + +# ── geometry: annotation → (mask, prompt) ─────────────────────────────────── + +def polygon_to_mask(segmentation, height: int, width: int) -> np.ndarray: + """Flat ``[x1,y1,x2,y2,...]`` polygon → bool mask (inverse of + ``sam_utils._mask_to_polygon``).""" + pts = np.array(segmentation, dtype=np.float32).reshape(-1, 2) + mask = np.zeros((height, width), dtype=np.uint8) + cv2.fillPoly(mask, [np.round(pts).astype(np.int32)], 1) + return mask.astype(bool) + + +def bbox_to_mask(bbox, height: int, width: int) -> np.ndarray: + """``[x, y, w, h]`` → bool mask.""" + x, y, w, h = bbox + mask = np.zeros((height, width), dtype=np.uint8) + cv2.rectangle(mask, (int(x), int(y)), (int(x + w), int(y + h)), 1, thickness=-1) + return mask.astype(bool) + + +def mask_to_xyxy(mask: np.ndarray): + """Tight ``[x1, y1, x2, y2]`` bounding box of a bool mask, or None if empty.""" + ys, xs = np.where(mask) + if xs.size == 0: + return None + return [float(xs.min()), float(ys.min()), float(xs.max()), float(ys.max())] + + +def mask_to_point(mask: np.ndarray): + """A single foreground point well inside the mask. + + Uses the distance transform's argmax so the point sits near the medial + axis rather than on an edge — a more stable positive prompt than a random + interior pixel. + """ + m = mask.astype(np.uint8) + if m.sum() == 0: + return None + dist = cv2.distanceTransform(m, cv2.DIST_L2, 3) + y, x = np.unravel_index(int(np.argmax(dist)), dist.shape) + return [float(x), float(y)] + + +# ── loss ──────────────────────────────────────────────────────────────────── + +def _focal_dice_loss(logits, target, focal_weight: float = 20.0): + """SAM's mask supervision: focal + dice, ≈20:1 (focal:dice). + + ``logits`` and ``target`` are ``(1, 1, H, W)`` on the same device; target + is float {0,1}. Matches the recipe used across the SAM fine-tuning + literature (focal for hard pixels, dice for region overlap). + """ + import torch + import torch.nn.functional as F + + prob = torch.sigmoid(logits) + # Focal (binary), gamma=2. + bce = F.binary_cross_entropy_with_logits(logits, target, reduction="none") + p_t = prob * target + (1 - prob) * (1 - target) + focal = (bce * (1 - p_t).pow(2)).mean() + # Dice. + inter = (prob * target).sum() + dice = 1 - (2 * inter + 1) / (prob.sum() + target.sum() + 1) + return focal_weight * focal + dice + + +# ── dataset ────────────────────────────────────────────────────────────────── + +class SampleGroup: + """One image plus the specs for its instances, loaded lazily. + + ``image_loader`` returns an RGB ``uint8`` array (matching what + ``sam_utils._qimage_to_numpy`` feeds inference — channel-order consistency + between train and predict matters more than absolute order). ``specs`` are + raw annotations (``{"segmentation": [...]}`` or ``{"bbox": [x,y,w,h]}``) + rasterised to bool masks **at load time** using the actual image size, so + masks are never held in RAM between epochs and always match the image. + """ + + def __init__(self, image_loader, specs): + self._image_loader = image_loader + self.specs = specs + self.n_instances = len(specs) + + def load(self): + """Return ``(image_rgb, [{"mask": bool HxW}, ...])``.""" + image = self._image_loader() + h, w = image.shape[:2] + instances = [] + for spec in self.specs: + if spec.get("segmentation"): + mask = polygon_to_mask(spec["segmentation"], h, w) + elif spec.get("bbox"): + mask = bbox_to_mask(spec["bbox"], h, w) + else: + continue + if mask.any(): + instances.append({"mask": mask}) + return image, instances + + +# ── engine ─────────────────────────────────────────────────────────────────── + +class SAMFineTuner(QObject): + """Fine-tunes a SAM 2 mask decoder (optionally image encoder) on + user instances. Mirrors ``YOLOTrainer``'s signal/stop surface so the + controller and progress dialog wiring is identical.""" + + progress_signal = pyqtSignal(str) + + def __init__(self): + super().__init__() + self.stop_training = False + + def stop_training_signal(self): + self.stop_training = True + self.progress_signal.emit("Stopping after current step…") + + # -- model setup --------------------------------------------------------- + + def _build_predictor(self, base_model): + """Return a ready ``SAM2Predictor`` for ``base_model`` (a registry name + like ``"SAM 2 tiny"`` or a path to a ``.pt``). + + Forces predictor creation with one throwaway predict so ``set_image`` / + ``prompt_inference`` are usable. Pins the device via + ``resolve_torch_device`` so an incompatible GPU (which Ultralytics would + otherwise pick blindly and crash on) is honoured as CPU — the same + device decision SAM/DINO/YOLO inference already share.""" + from ultralytics import SAM + + from ..core.torch_utils import resolve_torch_device + + if base_model in MODEL_NAMES: + model_file = os.path.join(SAM_MODELS_DIR, MODEL_FILES[base_model]) + else: + model_file = base_model + if not os.path.exists(model_file): + raise FileNotFoundError(f"Base SAM weights not found: {model_file}") + + device, _ = resolve_torch_device() + model = SAM(model_file) + warm = (np.random.rand(64, 64, 3) * 255).astype(np.uint8) + # device= forces the warmup (and predictor model placement) onto the + # resolved device rather than Ultralytics' default cuda-if-present. + model(warm, bboxes=[[8, 8, 56, 56]], device=device, verbose=False) + return model, model.predictor, model_file + + @staticmethod + def _device_label(device) -> str: + """``cuda:0 (NVIDIA GeForce RTX 4070)`` — make it obvious the GPU is in + use (a bare ``cuda:0`` reads to users like the GPU wasn't detected).""" + try: + import torch + if str(device).startswith("cuda"): + idx = device.index if getattr(device, "index", None) is not None else 0 + return f"{device} ({torch.cuda.get_device_name(idx)})" + except Exception: + pass + return str(device) + + @staticmethod + def _apply_freeze(net, freeze_image_encoder: bool): + net.eval() + for p in net.parameters(): + p.requires_grad_(False) + for p in net.sam_mask_decoder.parameters(): + p.requires_grad_(True) + if not freeze_image_encoder: + for p in net.image_encoder.parameters(): + p.requires_grad_(True) + net.image_encoder.train() + trainable = [p for p in net.parameters() if p.requires_grad] + n = sum(p.numel() for p in trainable) + return trainable, n + + # -- training ------------------------------------------------------------ + + def train( + self, + base_model, + groups, + *, + epochs: int = 10, + lr: float = 1e-4, + batch_size: int = 1, + freeze_image_encoder: bool = True, + prompt_type: str = "bbox", + out_path: str, + ) -> dict: + """Fine-tune and save. Returns a small result dict; raises on failure. + + ``groups`` is an iterable of :class:`SampleGroup`. ``batch_size`` is a + gradient-accumulation count over **images** — all of an image's objects + are backpropagated together (one backward per image), so the optimizer + steps every ``batch_size`` images. + """ + import torch + from ultralytics.utils import ops + + groups = list(groups) + if not groups: + raise ValueError("No annotated instances to train on.") + if prompt_type not in ("bbox", "point"): + raise ValueError(f"Unknown prompt_type: {prompt_type}") + + model, pred, base_file = self._build_predictor(base_model) + net = pred.model + device = next(net.parameters()).device + trainable, n_trainable = self._apply_freeze(net, freeze_image_encoder) + self.progress_signal.emit( + f"Base: {os.path.basename(base_file)} | device: {self._device_label(device)} | " + f"images: {len(groups)} | trainable params: {n_trainable:,} | " + f"encoder {'TRAINED' if not freeze_image_encoder else 'frozen'}" + ) + + optimizer = torch.optim.AdamW(trainable, lr=lr, weight_decay=0.1) + self.stop_training = False + total_instances = sum(g.n_instances for g in groups) + if total_instances == 0: + raise ValueError("No annotated instances to train on.") + + for epoch in range(1, epochs + 1): + if self.stop_training: + break + random.shuffle(groups) + epoch_loss, seen, accum = 0.0, 0, 0 + optimizer.zero_grad() + + for group in groups: + if self.stop_training: + break + image, instances = group.load() + h, w = image.shape[:2] + im_t = self._set_image(pred, image, freeze_image_encoder) + + # Accumulate all of THIS image's instance losses and backward + # ONCE per image. When the image encoder is trainable, every + # instance's decoder graph hangs off the one shared encoder + # feature graph; a per-instance backward() would free that + # shared graph and make the next instance raise "backward + # through the graph a second time". One backward per image + # keeps the shared graph alive for exactly one pass. + inst_losses = [] + for inst in instances: + bbox = mask_to_xyxy(inst["mask"]) + if bbox is None: + continue + prompt = self._prompt_kwargs(prompt_type, inst["mask"], bbox) + with torch.enable_grad(): + pm, _ = pred.prompt_inference( + im_t, multimask_output=False, **prompt + ) + # Map the low-res logits back to the original image with the + # SAME transform inference uses (SAM2Predictor.postprocess → + # ops.scale_masks(padding=False)). SAM2 letterboxes the image + # (resize-min-ratio + pad bottom/right) and scale_masks crops + # that padding before upsampling. A naive interpolate over the + # full low-res mask instead bakes the padding region into the + # target, so the decoder learns masks shifted by the pad — the + # downward shift seen on non-square images (issue #73 testing). + logits = ops.scale_masks( + pm[:1].unsqueeze(0).float(), (h, w), padding=False + ) + target = torch.from_numpy(inst["mask"].astype(np.float32))[None, None].to(device) + inst_losses.append(_focal_dice_loss(logits, target)) + + del image # bound memory: encoder features recomputed per epoch + if not inst_losses: + continue + + # batch_size = number of IMAGES to accumulate before an + # optimizer step (gradient accumulation over images). + image_loss = torch.stack(inst_losses).mean() / max(1, batch_size) + image_loss.backward() + epoch_loss += image_loss.detach().item() * max(1, batch_size) + seen += 1 + accum += 1 + if accum >= batch_size: + optimizer.step() + optimizer.zero_grad() + accum = 0 + + if accum > 0: # flush partial accumulation + optimizer.step() + optimizer.zero_grad() + avg = epoch_loss / max(1, seen) + self.progress_signal.emit(f"Epoch {epoch}/{epochs} loss={avg:.4f}") + + result = self._save_and_verify(net, base_file, out_path) + result.update(stopped=self.stop_training, instances=total_instances) + self.progress_signal.emit( + f"Saved fine-tuned model: {out_path}" if not self.stop_training + else f"Stopped early — saved current state to {out_path}" + ) + return result + + def _set_image(self, pred, image: np.ndarray, freeze_image_encoder: bool): + """Preprocess ``image`` and compute encoder features into the predictor. + + Sets ``pred.batch`` explicitly so ``_prepare_prompts`` maps bbox/point + prompts from this image's original pixel size into model coordinates — + a stale ``batch`` from a different-sized image would mis-scale prompts. + """ + import torch + + pred.setup_source(image) + im_t = None + for batch in pred.dataset: + im_t = pred.preprocess(batch[1]) + break + # (paths, im0s, ...) — prompt_inference reads self.batch[1][0].shape for orig H,W. + pred.batch = (None, [image], None) + if freeze_image_encoder: + with torch.no_grad(): + pred.features = pred.get_im_features(im_t) + else: + pred.features = pred.get_im_features(im_t) + return im_t + + @staticmethod + def _prompt_kwargs(prompt_type: str, mask: np.ndarray, bbox): + if prompt_type == "bbox": + return {"bboxes": [bbox]} + pt = mask_to_point(mask) + return {"points": [pt], "labels": [1]} + + # -- checkpoint ---------------------------------------------------------- + + def _save_and_verify(self, net, base_file: str, out_path: str) -> dict: + """Save fine-tuned weights as ``{"model": state_dict}`` and prove they + reload through ``SAM(out_path)``. + + Ultralytics' ``_load_checkpoint`` reads only the nested ``"model"`` key + (a tensor dict) and rebuilds the architecture from the filename suffix, + so a pure state_dict is all we need — no need to ``torch.load`` (and + unpickle) the base file. Because ``net`` is the same ``SAM2Model`` class + Ultralytics instantiates, the keys match exactly, sidestepping the + "Unexpected key(s)" reload failures (facebookresearch/sam2#337). + """ + import torch + from ultralytics import SAM + + token = os.path.basename(base_file) + if not os.path.basename(out_path).endswith(token): + raise ValueError( + f"Fine-tuned checkpoint name must end with '{token}' so " + f"Ultralytics can pick the right architecture; got " + f"'{os.path.basename(out_path)}'. Use make_custom_filename()." + ) + + os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True) + net_cpu_state = {k: v.detach().cpu() for k, v in net.state_dict().items()} + torch.save({"model": net_cpu_state}, out_path) + + # Round-trip verification: load + one forward. Failing here is loud by design. + verify = SAM(out_path) + verify( + (np.random.rand(64, 64, 3) * 255).astype(np.uint8), + bboxes=[[8, 8, 56, 56]], verbose=False, + ) + return {"out_path": out_path, "verified": True} diff --git a/src/digitalsreeni_image_annotator/ui/default_stylesheet.py b/src/digitalsreeni_image_annotator/ui/default_stylesheet.py index 13005a6..23efc93 100644 --- a/src/digitalsreeni_image_annotator/ui/default_stylesheet.py +++ b/src/digitalsreeni_image_annotator/ui/default_stylesheet.py @@ -45,6 +45,14 @@ } +QHeaderView::section { + background-color: #E0E0E0; + color: #333333; + border: 1px solid #CCCCCC; + padding: 4px; +} + + QLabel { color: #333333; } diff --git a/src/digitalsreeni_image_annotator/ui/sidebar.py b/src/digitalsreeni_image_annotator/ui/sidebar.py index 487e48e..5c85c2c 100644 --- a/src/digitalsreeni_image_annotator/ui/sidebar.py +++ b/src/digitalsreeni_image_annotator/ui/sidebar.py @@ -146,11 +146,11 @@ def build_sidebar(window): dino_browse_layout.setContentsMargins(0, 0, 0, 0) window.lbl_dino_custom = QLabel("No path set") window.lbl_dino_custom.setWordWrap(True) - # Use palette(text) so the colour follows the active stylesheet - # (light or dark) — hardcoded #555 used to render unreadable on - # dark mode. See "No Hardcoded Colors Rule" in CLAUDE.md. - # No font-size here — the caption follows the global ui_font_pt rule. - window.lbl_dino_custom.setStyleSheet("color:palette(text);") + # No inline colour: `palette(text)` resolves to a near-white role in light + # mode (rendering this caption unreadable). The global QLabel stylesheet + # rule gives a theme-correct colour in both modes — leave the property out + # so the global sheet wins. See "No Hardcoded Colors Rule" in CLAUDE.md. + # No font-size here either — the caption follows the global ui_font_pt rule. btn_dino_browse = QPushButton("Browse") # No fixed width — a 60px cap clipped the caption at large UI font # sizes (low-vision zoom); sizeHint tracks the active font. @@ -319,6 +319,20 @@ def build_image_list(window): window.image_list_label = QLabel("Images:") window.image_list_layout.addWidget(window.image_list_label) + # Annotation-status filter (upstream issue #27). Index order matters: + # ImageController.apply_image_filter maps 0=all, 1=without, 2=with. + window.image_filter_combo = QComboBox() + window.image_filter_combo.addItem("All images") + window.image_filter_combo.addItem("Without annotations") + window.image_filter_combo.addItem("With annotations") + window.image_filter_combo.setToolTip( + "Filter the image list by annotation status" + ) + window.image_filter_combo.currentIndexChanged.connect( + lambda _index: window.apply_image_filter() + ) + window.image_list_layout.addWidget(window.image_filter_combo) + window.image_list = QListWidget() window.image_list.itemClicked.connect(window.switch_image) window.image_list.currentRowChanged.connect( diff --git a/src/digitalsreeni_image_annotator/widgets/canvas_context.py b/src/digitalsreeni_image_annotator/widgets/canvas_context.py index cfc440d..5a95496 100644 --- a/src/digitalsreeni_image_annotator/widgets/canvas_context.py +++ b/src/digitalsreeni_image_annotator/widgets/canvas_context.py @@ -38,6 +38,12 @@ def is_class_visible(self, name: str) -> bool: def current_image_key(self): return self._mw.current_slice or self._mw.image_file_name + def has_annotation_selection(self) -> bool: + # The annotation list is the source of truth for what a Delete acts + # on; the canvas Delete must read it (not the possibly-stale red + # highlight) so it can't fire on an empty list selection. See ADR-022. + return bool(self._mw.annotation_list.selectedItems()) + def all_annotations(self) -> dict: return self._mw.all_annotations diff --git a/src/digitalsreeni_image_annotator/widgets/image_label.py b/src/digitalsreeni_image_annotator/widgets/image_label.py index ab89d8e..67d9856 100644 --- a/src/digitalsreeni_image_annotator/widgets/image_label.py +++ b/src/digitalsreeni_image_annotator/widgets/image_label.py @@ -29,6 +29,8 @@ from PyQt6.QtWidgets import QLabel, QMessageBox from .tools import EraserTool, PaintBrushTool, PolygonTool, RectangleTool +from ..core.constants import DEFAULT_FILL_OPACITY +from ..utils import calculate_area warnings.filterwarnings("ignore", category=UserWarning) @@ -44,6 +46,7 @@ class ImageLabel(QLabel): annotationsReplaced = pyqtSignal(str, dict) # eraser path: (image_key, per-class dict) annotationListUpdateRequested = pyqtSignal() # editing-mode exit refresh annotationSelected = pyqtSignal(object) # double-click selection + canvasSelectionChanged = pyqtSignal(object, str) # (list[annotation], mode); mode: replace|add|toggle deleteSelectionRequested = pyqtSignal() finishPolygonRequested = pyqtSignal() finishRectangleRequested = pyqtSignal() @@ -68,6 +71,12 @@ class ImageLabel(QLabel): zoomOutRequested = pyqtSignal() imageInfoChanged = pyqtSignal() + # Selection highlight: a semi-transparent selection-blue, drawn as the + # dashed bounding-box marquee (handles use the opaque variant). Class- + # colour-independent, so it never vanishes the way the old red-on-red + # highlight did. + _SELECTION_COLOR = QColor(0, 120, 215, 220) + def __init__(self, parent=None): super().__init__(parent) self.annotations = {} @@ -85,6 +94,11 @@ def __init__(self, parent=None): self.start_point = None self.end_point = None self.highlighted_annotations = [] + # Idle-mode mask selection (issue #75): drag a rubber band to box-select. + # selection_rect is (x0, y0, x1, y1) in image coords while a drag is live. + self.selection_origin = None + self.selecting = False + self.selection_rect = None self.setMouseTracking(True) self.setFocusPolicy(Qt.FocusPolicy.StrongFocus) self.original_pixmap = None @@ -97,7 +111,7 @@ def __init__(self, parent=None): self.editing_polygon = None self.editing_point_index = None self.hover_point_index = None - self.fill_opacity = 0.3 + self.fill_opacity = DEFAULT_FILL_OPACITY self.drawing_rectangle = False self.current_rectangle = None self.bit_depth = None @@ -221,6 +235,11 @@ def reset_annotation_state(self): self.temp_point = None self.start_point = None self.end_point = None + # Drop any in-progress rubber-band so a stale rect can't render on + # the next image/slice (switch_image/switch_slice call through here). + self.selection_origin = None + self.selecting = False + self.selection_rect = None def clear_current_annotation(self): """Clear the current annotation.""" @@ -249,6 +268,9 @@ def paintEvent(self, event): # Polygon edit mode is modal; runs orthogonal to tool selection if self.editing_polygon: self.draw_editing_polygon(painter) + # Idle-mode rubber-band selection rectangle (issue #75) + if self.selection_rect is not None: + self.draw_selection_rect(painter) # SAM overlays (cross-cutting; not part of the tool handlers) if self.sam_box_active and self.sam_bbox: self.draw_sam_bbox(painter) @@ -413,6 +435,23 @@ def draw_sam_bbox(self, painter): painter.drawRect(QRectF(min(x1, x2), min(y1, y2), abs(x2 - x1), abs(y2 - y1))) painter.restore() + def draw_selection_rect(self, painter): + """Draw the idle-mode rubber-band selection rectangle (issue #75). + + A single dashed selection-blue rect with a faint fill — same restrained + style as the selection outline (not red, which clashes with class colours).""" + painter.save() + painter.translate(self.offset_x, self.offset_y) + painter.scale(self.zoom_factor, self.zoom_factor) + x0, y0, x1, y1 = self.selection_rect + rect = QRectF(min(x0, x1), min(y0, y1), abs(x1 - x0), abs(y1 - y0)) + fill = QColor(self._SELECTION_COLOR) + fill.setAlphaF(0.10) + painter.setBrush(QBrush(fill)) + painter.setPen(QPen(self._SELECTION_COLOR, self._pen_w(1), Qt.PenStyle.DashLine)) + painter.drawRect(rect) + painter.restore() + def clear_temp_sam_prediction(self): self.temp_sam_prediction = None self.update() @@ -448,6 +487,9 @@ def clear(self): self.start_point = None self.end_point = None self.highlighted_annotations.clear() + self.selection_origin = None + self.selecting = False + self.selection_rect = None self.original_pixmap = None self.scaled_pixmap = None self.editing_polygon = None @@ -476,13 +518,12 @@ def draw_annotations(self, painter): color = self.class_colors.get(class_name, QColor(Qt.GlobalColor.white)) for annotation in class_annotations: - if annotation in self.highlighted_annotations: - border_color = Qt.GlobalColor.red - fill_color = QColor(Qt.GlobalColor.red) - else: - border_color = color - fill_color = QColor(color) - + # Selection no longer recolours the mask (it used to turn red, + # which was invisible on a red-class mask). The mask always + # keeps its class colour; selection is drawn as a + # class-colour-independent overlay in a final pass below. + border_color = color + fill_color = QColor(color) fill_color.setAlphaF(self.fill_opacity) text_color = Qt.GlobalColor.white if self.dark_mode else Qt.GlobalColor.black @@ -552,8 +593,47 @@ def draw_annotations(self, painter): centroid, f"SAM: {self.temp_sam_prediction['score']:.2f}" ) + # Selection overlay — drawn LAST so it sits on top of every mask's + # fill, and in a class-colour-independent style so it's recognisable + # regardless of the selected mask's colour (issue #75 follow-up). + for annotation in self.highlighted_annotations: + self._draw_selection_overlay(painter, annotation) + painter.restore() + def _draw_selection_overlay(self, painter, annotation): + """Mark a selected annotation the way the sibling open-garden-planner + app does: a dashed selection-blue bounding box plus bright square + handles at the 4 corners and 4 edge midpoints. Class-colour-independent + and clearly visible regardless of the mask's own colour.""" + if not self._is_class_pickable(annotation.get("category_name")): + return # don't draw selection chrome over a hidden mask + bb = self._annotation_bbox(annotation) + if bb is None: + return + x0, y0, x1, y1 = bb + rect = QRectF(x0, y0, x1 - x0, y1 - y0) + + # Dashed bounding-box marquee. + painter.setBrush(Qt.BrushStyle.NoBrush) + painter.setPen(QPen(self._SELECTION_COLOR, self._pen_w(1.5), Qt.PenStyle.DashLine)) + painter.drawRect(rect) + + # Handle squares — opaque blue with a white casing so they read on any + # background; fixed on-screen size (zoom-compensated). Visual selection + # markers; resizing via handles is a separate feature (upstream #40). + cx, cy = (x0 + x1) / 2.0, (y0 + y1) / 2.0 + handles = [ + (x0, y0), (cx, y0), (x1, y0), + (x0, cy), (x1, cy), + (x0, y1), (cx, y1), (x1, y1), + ] + half = 4 * self.ui_scale / self.zoom_factor + painter.setPen(QPen(Qt.GlobalColor.white, self._pen_w(1), Qt.PenStyle.SolidLine)) + painter.setBrush(QBrush(QColor(0, 120, 215))) + for hx, hy in handles: + painter.drawRect(QRectF(hx - half, hy - half, 2 * half, 2 * half)) + def draw_editing_polygon(self, painter): """Draw the polygon being edited.""" painter.save() @@ -682,6 +762,13 @@ def mousePressEvent(self, event: QMouseEvent): self.drawing_sam_bbox = True elif self.editing_polygon: self.handle_editing_click(pos, event) + elif self._is_select_mode(): + # Idle-mode mask selection (issue #75): remember the press as + # the potential rubber-band origin; a click vs. drag is + # decided on move/release. + self.selection_origin = pos + self.selecting = False + self.selection_rect = None else: handler = self.active_tool_handler if handler is not None: @@ -716,6 +803,10 @@ def mouseMoveEvent(self, event: QMouseEvent): self.sam_bbox[3] = pos[1] elif self.editing_polygon: self.handle_editing_move(pos) + elif self._is_select_mode() and self.selection_origin is not None and ( + event.buttons() & Qt.MouseButton.LeftButton + ): + self._update_selection_drag(pos) else: handler = self.active_tool_handler if handler is not None: @@ -743,6 +834,8 @@ def mouseReleaseEvent(self, event: QMouseEvent): self.samPredictionApplyRequested.emit() elif self.editing_polygon: self.editing_point_index = None + elif self._is_select_mode() and self.selection_origin is not None: + self._finish_selection(pos, event) else: handler = self.active_tool_handler if handler is not None: @@ -815,6 +908,13 @@ def keyPressEvent(self, event: QKeyEvent): self.editing_point_index = None self.hover_point_index = None self.enableToolsRequested.emit() + elif self._is_select_mode() and ( + self.selecting or self.selection_origin is not None + ): + # Cancel an in-progress rubber band (selection unchanged). + self.selection_origin = None + self.selecting = False + self.selection_rect = None else: handler = self.active_tool_handler if handler is not None: @@ -827,6 +927,18 @@ def keyPressEvent(self, event: QKeyEvent): self.hover_point_index = None self.enableToolsRequested.emit() self.update() + elif ( + self._is_select_mode() + and self._ctx is not None + and self._ctx.has_annotation_selection() + ): + # Idle-mode canvas selection: delete the selected masks. + # Gate on the annotation-list selection (the controller's + # source of truth) rather than the red highlight, which a + # list rebuild (e.g. a sort) can leave stale — otherwise a + # canvas Delete would pop a spurious "nothing selected" + # warning. See ADR-022. + self.deleteSelectionRequested.emit() elif event.key() == Qt.Key.Key_Minus: if self.current_tool == "paint_brush": new_size = max(1, self._ctx.paint_brush_size() - 1) @@ -847,7 +959,127 @@ def keyPressEvent(self, event: QKeyEvent): print(f"Eraser size: {new_size}") self.update() + # --- Idle-mode mask selection (issue #75) --- + + def _is_select_mode(self): + """True when the canvas is idle (no drawing/SAM tool, not editing, + no temp review) — the only state where bare clicks/drags select + existing masks instead of drawing.""" + return ( + self.current_tool is None + and not self.editing_polygon + and not self.sam_box_active + and not self.sam_points_active + and not self.temp_annotations + and not self.temp_sam_prediction + ) + + @staticmethod + def _annotation_contains(annotation, pos): + """Hit-test a single annotation (segmentation polygon or bbox).""" + if "segmentation" in annotation: + seg = annotation["segmentation"] + points = [QPoint(int(x), int(y)) for x, y in zip(seg[0::2], seg[1::2])] + return len(points) >= 3 and ImageLabel.point_in_polygon(pos, points) + if "bbox" in annotation: + x, y, w, h = annotation["bbox"] + return x <= pos[0] <= x + w and y <= pos[1] <= y + h + return False + + @staticmethod + def _annotation_bbox(annotation): + """Axis-aligned bounds (x0, y0, x1, y1) of an annotation, or None.""" + if "segmentation" in annotation: + seg = annotation["segmentation"] + xs, ys = seg[0::2], seg[1::2] + if not xs or not ys: + return None + return (min(xs), min(ys), max(xs), max(ys)) + if "bbox" in annotation: + x, y, w, h = annotation["bbox"] + return (x, y, x + w, y + h) + return None + + def _is_class_pickable(self, class_name): + # No context (e.g. unit tests) → everything is pickable. + return self._ctx is None or self._ctx.is_class_visible(class_name) + + def annotation_at(self, pos): + """Smallest-area annotation containing pos, or None. Covers both + segmentation and bbox annotations and skips hidden classes. Smallest + wins so a mask nested inside another stays reachable (cf. + start_polygon_edit / upstream #33).""" + best = None + best_area = None + for class_name, annotations in self.annotations.items(): + if not self._is_class_pickable(class_name): + continue + for annotation in annotations: + if self._annotation_contains(annotation, pos): + area = calculate_area(annotation) + if best is None or area < best_area: + best = annotation + best_area = area + return best + + def annotations_in_rect(self, rect): + """All annotations whose bounds intersect the rubber-band rect. + rect is (x0, y0, x1, y1) in image coords (any corner order).""" + x0, y0, x1, y1 = rect + rx0, rx1 = min(x0, x1), max(x0, x1) + ry0, ry1 = min(y0, y1), max(y0, y1) + result = [] + for class_name, annotations in self.annotations.items(): + if not self._is_class_pickable(class_name): + continue + for annotation in annotations: + bb = self._annotation_bbox(annotation) + if bb is None: + continue + ax0, ay0, ax1, ay1 = bb + if ax0 <= rx1 and ax1 >= rx0 and ay0 <= ry1 and ay1 >= ry0: + result.append(annotation) + return result + + def _update_selection_drag(self, pos): + """Grow the rubber band once the drag clears the click threshold.""" + if self.selection_origin is None: + return + if not self.selecting: + threshold = 3.0 / max(self.zoom_factor, 1e-6) + if self.distance(pos, self.selection_origin) < threshold: + return + self.selecting = True + ox, oy = self.selection_origin + self.selection_rect = (ox, oy, pos[0], pos[1]) + + def _finish_selection(self, pos, event): + """Resolve a press→release in select mode into a selection change. + Shift makes it additive (drag) / toggling (click).""" + additive = bool(event.modifiers() & Qt.KeyboardModifier.ShiftModifier) + if self.selecting and self.selection_rect is not None: + anns = self.annotations_in_rect(self.selection_rect) + self.canvasSelectionChanged.emit(anns, "add" if additive else "replace") + else: + ann = self.annotation_at(pos) + if additive: + if ann is not None: # Shift+click on empty space keeps the selection + self.canvasSelectionChanged.emit([ann], "toggle") + else: + self.canvasSelectionChanged.emit( + [ann] if ann is not None else [], "replace" + ) + self.selection_origin = None + self.selecting = False + self.selection_rect = None + def start_polygon_edit(self, pos): + # Among all polygons containing the click, edit the smallest by + # area so an annotation fully nested inside another is reachable + # (upstream issue #33) instead of always grabbing the first/outer + # match. + best = None + best_area = None for class_name, annotations in self.annotations.items(): for annotation in annotations: if "segmentation" in annotation: @@ -859,11 +1091,16 @@ def start_polygon_edit(self, pos): ) ] if self.point_in_polygon(pos, points): - self.editing_polygon = annotation - self.current_tool = None - self.disableToolsRequested.emit() - self.resetToolButtonsRequested.emit() - return annotation + area = calculate_area(annotation) + if best is None or area < best_area: + best = annotation + best_area = area + if best is not None: + self.editing_polygon = best + self.current_tool = None + self.disableToolsRequested.emit() + self.resetToolButtonsRequested.emit() + return best return None def handle_editing_click(self, pos, event): diff --git a/tests/integration/test_canvas_selection_controller.py b/tests/integration/test_canvas_selection_controller.py new file mode 100644 index 0000000..db336a9 --- /dev/null +++ b/tests/integration/test_canvas_selection_controller.py @@ -0,0 +1,167 @@ +"""Canvas-selection ↔ annotation-list integration tests (bnsreenu #75). + +apply_canvas_selection must update image_label.highlighted_annotations, +mirror that onto the annotation list selection (so Delete / Merge / +Change-Class operate on the same set), and toggle the merge / change-class +buttons. The canvas Delete path then reuses delete_selected_annotations. + +One real offscreen ImageAnnotator; no model weights, no worker thread. +""" + +import copy + +import pytest +from PyQt6.QtCore import Qt +from PyQt6.QtWidgets import QMessageBox + + +@pytest.fixture +def window(qt_application): + from digitalsreeni_image_annotator.annotator_window import ImageAnnotator + + w = ImageAnnotator() + yield w + w.deleteLater() + + +def _square(x0, y0, side, number): + return { + "segmentation": [x0, y0, x0 + side, y0, x0 + side, y0 + side, x0, y0 + side], + "category_name": "cell", + "number": number, + } + + +def _seed(window, anns): + window.image_file_name = "img.png" + window.current_slice = None + window.all_annotations = {"img.png": {"cell": list(anns)}} + window.image_label.annotations = copy.deepcopy(window.all_annotations["img.png"]) + window.update_annotation_list() + + +def _selected_data(window): + return [ + item.data(Qt.ItemDataRole.UserRole) + for item in window.annotation_list.selectedItems() + ] + + +def test_replace_selects_one(window): + a1, a2, a3 = _square(0, 0, 10, 1), _square(50, 0, 10, 2), _square(100, 0, 10, 3) + _seed(window, [a1, a2, a3]) + + window.annotation_controller.apply_canvas_selection([a1], "replace") + + assert window.image_label.highlighted_annotations == [a1] + assert _selected_data(window) == [a1] + assert not window.merge_button.isEnabled() # needs ≥2 + assert window.change_class_button.isEnabled() # needs ≥1 + + +def test_add_then_toggle(window): + a1, a2, a3 = _square(0, 0, 10, 1), _square(50, 0, 10, 2), _square(100, 0, 10, 3) + _seed(window, [a1, a2, a3]) + ac = window.annotation_controller + + ac.apply_canvas_selection([a1], "replace") + ac.apply_canvas_selection([a2], "add") + assert window.image_label.highlighted_annotations == [a1, a2] + # The list mirrors by value-equality (PyQt round-trips UserRole dicts as + # copies), so compare by membership, not identity. + sel = _selected_data(window) + assert len(sel) == 2 and a1 in sel and a2 in sel + assert window.merge_button.isEnabled() + + # Toggling a1 off leaves only a2. + ac.apply_canvas_selection([a1], "toggle") + assert window.image_label.highlighted_annotations == [a2] + assert _selected_data(window) == [a2] + assert not window.merge_button.isEnabled() + + +def test_replace_empty_clears(window): + a1, a2 = _square(0, 0, 10, 1), _square(50, 0, 10, 2) + _seed(window, [a1, a2]) + ac = window.annotation_controller + + ac.apply_canvas_selection([a1, a2], "replace") + assert window.merge_button.isEnabled() + + ac.apply_canvas_selection([], "replace") + assert window.image_label.highlighted_annotations == [] + assert _selected_data(window) == [] + assert not window.merge_button.isEnabled() + assert not window.change_class_button.isEnabled() + + +def test_canvas_delete_removes_selected_set(window, monkeypatch): + a1, a2, a3 = _square(0, 0, 10, 1), _square(50, 0, 10, 2), _square(100, 0, 10, 3) + _seed(window, [a1, a2, a3]) + + monkeypatch.setattr( + QMessageBox, "question", + staticmethod(lambda *a, **k: QMessageBox.StandardButton.Yes), + ) + monkeypatch.setattr(QMessageBox, "information", staticmethod(lambda *a, **k: None)) + monkeypatch.setattr(window, "auto_save", lambda: None) + + window.annotation_controller.apply_canvas_selection([a1, a2], "replace") + window.delete_selected_annotations() + + remaining = window.image_label.annotations.get("cell", []) + assert a3 in remaining + assert a1 not in remaining and a2 not in remaining + assert window.all_annotations["img.png"]["cell"] == remaining + assert window.image_label.highlighted_annotations == [] + + +def test_canvas_delete_gated_on_list_selection(window, monkeypatch): + """Canvas Delete fires only when the annotation list actually has a + selection — not when only the red highlight is (stale) populated, e.g. + after a sort rebuilds the list. Otherwise it pops a spurious warning.""" + from PyQt6.QtCore import QEvent + from PyQt6.QtGui import QKeyEvent + + monkeypatch.setattr( + QMessageBox, "question", + staticmethod(lambda *a, **k: QMessageBox.StandardButton.Yes), + ) + monkeypatch.setattr(QMessageBox, "information", staticmethod(lambda *a, **k: None)) + monkeypatch.setattr(window, "auto_save", lambda: None) + + a1, a2, a3 = _square(0, 0, 10, 1), _square(50, 0, 10, 2), _square(100, 0, 10, 3) + _seed(window, [a1, a2, a3]) + il = window.image_label + il.current_tool = None + + def press_delete(): + il.keyPressEvent( + QKeyEvent( + QEvent.Type.KeyPress, + Qt.Key.Key_Delete, + Qt.KeyboardModifier.NoModifier, + ) + ) + + emitted = [] + il.deleteSelectionRequested.connect(lambda: emitted.append(True)) + + # In sync (canvas selection mirrored to the list) → Delete fires. + window.annotation_controller.apply_canvas_selection([a1], "replace") + assert il._ctx.has_annotation_selection() + press_delete() + assert emitted == [True] + assert a1 not in il.annotations.get("cell", []) + + # Construct the divergence the gate guards against: a stale red highlight + # with no list selection. The canvas Delete keys off the list (the + # controller's source of truth), so it must NOT fire here — no spurious + # "nothing selected" warning. + emitted.clear() + il.highlighted_annotations = [a2] + window.annotation_list.clearSelection() + assert il.highlighted_annotations # highlight is stale + assert not il._ctx.has_annotation_selection() # but the list isn't selected + press_delete() + assert emitted == [] diff --git a/tests/integration/test_image_filter_wiring.py b/tests/integration/test_image_filter_wiring.py new file mode 100644 index 0000000..e0fc2a7 --- /dev/null +++ b/tests/integration/test_image_filter_wiring.py @@ -0,0 +1,87 @@ +""" +Integration test for the image-filter re-apply wiring (upstream #27). + +The unit tests call apply_image_filter() directly; this test goes through +the real mutation path instead: ImageLabel.annotationsBatchSaved → +ImageAnnotator._on_annotations_batch_saved → save_current_annotations → +ClassController.update_slice_list_colors → apply_image_filter. It is the +test that fails if someone refactors the slice-color path and silently +detaches the filter from annotation mutations. + +Constructs one full ImageAnnotator (offscreen) — deliberately, despite +the runtime cost, because the coupling under test spans window, signals +and three controllers. +""" + +import pytest + + +FILTER_WITHOUT = 1 # combo index: "Without annotations" +FILTER_WITH = 2 # combo index: "With annotations" + + +@pytest.fixture +def window(qt_application): + from digitalsreeni_image_annotator.annotator_window import ImageAnnotator + + w = ImageAnnotator() + yield w + w.deleteLater() + + +def test_annotation_commit_path_reapplies_filter(window): + # Two regular images, neither annotated yet. currentRow stays -1 so + # the "never hide current row" exemption doesn't mask the assertion. + for name in ("a.png", "b.png"): + window.all_images.append({"file_name": name, "is_multi_slice": False}) + window.image_list.addItem(name) + + window.image_filter_combo.setCurrentIndex(FILTER_WITH) + assert window.image_list.isRowHidden(0) + assert window.image_list.isRowHidden(1) + + # Simulate finishing an annotation on a.png the way the canvas does: + # ImageLabel holds the in-progress annotations and emits the batch + # finalizer signal. No direct apply_image_filter / all_annotations + # manipulation here — that's the point of the test. + window.image_file_name = "a.png" + window.image_label.annotations = { + "cell": [{"segmentation": [0, 0, 10, 0, 10, 10], "category_name": "cell"}] + } + window.image_label.annotationsBatchSaved.emit() + + assert window.all_annotations["a.png"] # save path ran + assert not window.image_list.isRowHidden(0) # a.png now annotated + assert window.image_list.isRowHidden(1) # b.png still hidden + + +def test_hiding_current_row_keeps_canvas_and_fires_no_switch(window): + # Hiding the currently selected (non-matching) row must not change + # the displayed image or fire switch_image — the canvas stays on the + # worked-on image while its row leaves the list. + for name in ("annot.png", "plain.png"): + window.all_images.append({"file_name": name, "is_multi_slice": False}) + window.image_list.addItem(name) + window.all_annotations["annot.png"] = { + "cell": [{"segmentation": [0, 0, 1, 0, 1, 1], "category_name": "cell"}] + } + + # Pure counter (does NOT delegate to the real switch_image, which + # needs a loaded project): isolates whether the *filter* fires a + # switch. setCurrentRow below legitimately fires it once via + # currentRowChanged — that is product behavior and is cleared away. + calls = [] + window.switch_image = lambda item: calls.append(item) + + window.image_list.setCurrentRow(0) # select the annotated image + sentinel = object() + window.current_image = sentinel + calls.clear() + + # "Without annotations" must hide row 0 even though it is current. + window.image_filter_combo.setCurrentIndex(FILTER_WITHOUT) + + assert window.image_list.isRowHidden(0) # current row hidden + assert not window.image_list.isRowHidden(1) + assert window.current_image is sentinel # canvas unchanged + assert calls == [] # hiding the current row fired no switch_image diff --git a/tests/integration/test_sam_finetuning.py b/tests/integration/test_sam_finetuning.py new file mode 100644 index 0000000..787120b --- /dev/null +++ b/tests/integration/test_sam_finetuning.py @@ -0,0 +1,378 @@ +"""Tests for SAM 2 fine-tuning (issue bnsreenu#73). + +Most tests are CI-friendly: pure geometry/loss/dataset helpers that need no +model weights or GPU. The full train→save→reload round trip is gated behind +``SAM_TRAIN_E2E=1`` and the presence of cached weights, since it needs the +~75 MB ``sam2_t.pt`` and is realistically GPU-only. +""" + +import json +import os +import shutil +import tempfile + +import numpy as np +import pytest +from PyQt6.QtGui import QImage + +from src.digitalsreeni_image_annotator.training.sam_trainer import ( + SampleGroup, + bbox_to_mask, + list_custom_models, + make_custom_filename, + mask_to_point, + mask_to_xyxy, + polygon_to_mask, +) + + +@pytest.fixture +def temp_dir(): + d = tempfile.mkdtemp() + yield d + shutil.rmtree(d, ignore_errors=True) + + +# ── geometry helpers ───────────────────────────────────────────────────────── + +class TestGeometry: + def test_polygon_to_mask_fills_interior(self): + mask = polygon_to_mask([10, 10, 40, 10, 40, 40, 10, 40], 50, 50) + assert mask.dtype == bool and mask.shape == (50, 50) + assert mask[25, 25] and not mask[5, 5] + + def test_bbox_to_mask(self): + mask = bbox_to_mask([10, 10, 20, 20], 50, 50) # x,y,w,h + assert mask[15, 15] and not mask[45, 45] + + def test_mask_to_xyxy_tight(self): + mask = np.zeros((50, 50), bool) + mask[10:31, 5:26] = True + x1, y1, x2, y2 = mask_to_xyxy(mask) + assert (x1, y1, x2, y2) == (5, 10, 25, 30) + + def test_mask_to_xyxy_empty_is_none(self): + assert mask_to_xyxy(np.zeros((10, 10), bool)) is None + + def test_mask_to_point_inside(self): + mask = np.zeros((60, 60), bool) + mask[20:41, 20:41] = True + x, y = mask_to_point(mask) + assert mask[int(y), int(x)] # point lands inside the object + + def test_polygon_roundtrips_through_xyxy(self): + mask = polygon_to_mask([10, 10, 40, 10, 40, 40, 10, 40], 50, 50) + x1, y1, x2, y2 = mask_to_xyxy(mask) + assert x1 >= 9 and y1 >= 9 and x2 <= 41 and y2 <= 41 + + +# ── checkpoint naming (architecture-selection invariant) ───────────────────── + +class TestCustomNaming: + @pytest.mark.parametrize("base,token", [ + ("SAM 2 tiny", "sam2_t.pt"), + ("SAM 2.1 base", "sam2.1_b.pt"), + ("SAM 2.1 large", "sam2.1_l.pt"), + ]) + def test_filename_keeps_base_token(self, base, token): + # Ultralytics build_sam selects the architecture by ckpt.endswith(token); + # a fine-tuned name MUST keep that suffix or SAM(path) can't load it. + path = make_custom_filename(base, "my cool run!!") + assert os.path.basename(path).endswith(token) + + def test_filename_sanitises_label(self): + path = make_custom_filename("SAM 2 tiny", "a/b c:*?") + name = os.path.basename(path) + assert "/" not in name and ":" not in name and "*" not in name + + def test_filename_handles_empty_label(self): + path = make_custom_filename("SAM 2 tiny", "") + assert os.path.basename(path) == "finetuned_sam2_t.pt" + + def test_list_custom_models_returns_dict(self): + assert isinstance(list_custom_models(), dict) + + +# ── SampleGroup lazy rasterisation ─────────────────────────────────────────── + +class TestSampleGroup: + def test_load_rasterises_specs(self): + img = np.zeros((50, 50, 3), np.uint8) + specs = [ + {"segmentation": [10, 10, 40, 10, 40, 40, 10, 40]}, + {"bbox": [5, 5, 10, 10]}, + ] + group = SampleGroup(lambda: img.copy(), specs) + assert group.n_instances == 2 + out_img, instances = group.load() + assert out_img.shape == (50, 50, 3) + assert len(instances) == 2 + assert all(inst["mask"].dtype == bool for inst in instances) + + def test_load_drops_empty_masks(self): + img = np.zeros((50, 50, 3), np.uint8) + # Box entirely outside the image → no in-bounds pixels → dropped. + group = SampleGroup(lambda: img.copy(), [{"bbox": [200, 200, 10, 10]}]) + _, instances = group.load() + assert instances == [] + + +# ── loss ───────────────────────────────────────────────────────────────────── + +class TestLoss: + def test_focal_dice_lower_when_correct(self): + torch = pytest.importorskip("torch") + from src.digitalsreeni_image_annotator.training.sam_trainer import _focal_dice_loss + + target = torch.zeros(1, 1, 16, 16) + target[..., 4:12, 4:12] = 1.0 + good = (target * 12) - 6 # confident-correct logits + bad = (1 - target) * 12 - 6 # confident-wrong logits + l_good = _focal_dice_loss(good, target) + l_bad = _focal_dice_loss(bad, target) + assert torch.isfinite(l_good) and torch.isfinite(l_bad) + assert float(l_good) < float(l_bad) + + +# ── dataset producers ──────────────────────────────────────────────────────── + +class TestDatasetProducers: + def test_export_and_reload_folder_roundtrip(self, temp_dir): + from src.digitalsreeni_image_annotator.io.export_formats import export_sam_dataset + from src.digitalsreeni_image_annotator.training.sam_dataset import ( + build_groups_from_folder, + ) + + img = QImage(60, 60, QImage.Format.Format_RGB32) + img.fill(0xFF202020) + img_path = os.path.join(temp_dir, "img1.png") + img.save(img_path) + + all_annotations = { + "img1.png": {"cell": [{"segmentation": [10, 10, 40, 10, 40, 40, 10, 40]}]} + } + out_dir = os.path.join(temp_dir, "dataset") + _, manifest_path = export_sam_dataset( + all_annotations, {"cell": 1}, {"img1.png": img_path}, + slices=[], image_slices={}, output_dir=out_dir, + ) + assert os.path.exists(manifest_path) + with open(manifest_path) as f: + manifest = json.load(f) + assert len(manifest["images"]) == 1 + + groups = build_groups_from_folder(out_dir) + assert len(groups) == 1 + _, instances = groups[0].load() + assert len(instances) == 1 + + def test_build_from_project_regular_image(self, temp_dir): + from src.digitalsreeni_image_annotator.training.sam_dataset import ( + build_groups_from_project, + ) + + img = QImage(60, 60, QImage.Format.Format_RGB32) + img.fill(0xFF808080) + img_path = os.path.join(temp_dir, "img1.png") + img.save(img_path) + + groups = build_groups_from_project( + {"img1.png": {"cell": [{"bbox": [10, 10, 20, 20]}]}}, + {"img1.png": img_path}, slices=[], image_slices={}, + ) + assert len(groups) == 1 + image, instances = groups[0].load() + assert image.shape[:2] == (60, 60) and len(instances) == 1 + + def test_build_from_project_skips_unresolvable(self): + from src.digitalsreeni_image_annotator.training.sam_dataset import ( + build_groups_from_project, + ) + + groups = build_groups_from_project( + {"ghost.png": {"cell": [{"bbox": [1, 1, 5, 5]}]}}, + {}, slices=[], image_slices={}, + ) + assert groups == [] + + +# ── ultralytics API-drift guard ────────────────────────────────────────────── + +class TestUltralyticsAPI: + def test_sam2predictor_exposes_forward_methods(self): + """The native training loop reuses these SAM2Predictor methods; if an + Ultralytics upgrade removes/renames them, fail loudly here rather than + mid-training.""" + pytest.importorskip("ultralytics") + from ultralytics.models.sam.predict import SAM2Predictor + + for name in ("get_im_features", "prompt_inference", "_inference_features", + "_prepare_prompts", "set_image"): + assert hasattr(SAM2Predictor, name), f"SAM2Predictor.{name} missing" + + def test_ops_scale_masks_padding_false_geometry(self): + """The training loss maps logits back with the same letterbox-aware + transform inference uses (ops.scale_masks, padding=False). Guard the + *behavior*, not just the name: with padding=False the crop is + top-left-anchored, so foreground in the bottom padding band of a + non-square target is dropped while top content survives. A semantic + change in Ultralytics (not just a rename) trips this fast test rather + than only the GPU-gated e2e.""" + torch = pytest.importorskip("torch") + pytest.importorskip("ultralytics") + from ultralytics.utils import ops + + assert hasattr(ops, "scale_masks") + # Target 600x1000 (landscape) -> ~bottom 40% of a 256-tall mask is padding. + top = torch.zeros(1, 1, 256, 256) + top[..., 0:40, :] = 1.0 + bottom = torch.zeros(1, 1, 256, 256) + bottom[..., 215:256, :] = 1.0 + out_top = ops.scale_masks(top, (600, 1000), padding=False) + out_bottom = ops.scale_masks(bottom, (600, 1000), padding=False) + assert out_top.sum() > 0, "top content must survive the crop" + assert out_bottom.sum() == 0, "bottom padding band must be cropped (padding=False)" + + +# ── opt-in end-to-end (needs weights; realistically GPU) ───────────────────── + +@pytest.mark.skipif( + os.environ.get("SAM_TRAIN_E2E") != "1", + reason="set SAM_TRAIN_E2E=1 to run the full train→save→reload test (needs weights)", +) +def test_end_to_end_train_save_reload(temp_dir): + pytest.importorskip("ultralytics") + from src.digitalsreeni_image_annotator.training.sam_trainer import ( + SAM_MODELS_DIR, + SAMFineTuner, + ) + + if not os.path.exists(os.path.join(SAM_MODELS_DIR, "sam2_t.pt")): + pytest.skip("sam2_t.pt not cached") + + def make(seed): + rng = np.random.RandomState(seed) + img = (rng.rand(120, 120, 3) * 255).astype(np.uint8) + cy, cx = rng.randint(40, 80, 2) + yy, xx = np.ogrid[:120, :120] + mask = (yy - cy) ** 2 + (xx - cx) ** 2 < 25 ** 2 + img[mask] = [220, 30, 30] + x1, y1, x2, y2 = mask_to_xyxy(mask) + return SampleGroup( + lambda im=img: im.copy(), + [{"bbox": [x1, y1, x2 - x1, y2 - y1]}], + ) + + out = os.path.join(temp_dir, "e2e_sam2_t.pt") + res = SAMFineTuner().train( + "SAM 2 tiny", [make(i) for i in range(2)], + epochs=1, lr=1e-4, batch_size=1, freeze_image_encoder=True, + prompt_type="bbox", out_path=out, + ) + assert res["verified"] and os.path.exists(out) + + from ultralytics import SAM + SAM(out)((np.random.rand(120, 120, 3) * 255).astype(np.uint8), + bboxes=[[40, 40, 80, 80]], verbose=False) + + +@pytest.mark.skipif( + os.environ.get("SAM_TRAIN_E2E") != "1", + reason="set SAM_TRAIN_E2E=1 to run the encoder-path multi-instance test", +) +def test_encoder_path_multi_instance_per_image(temp_dir): + """Regression: encoder fine-tuning (freeze_image_encoder=False) on images + with >1 instance used to crash with 'backward through the graph a second + time' because all instances shared the one encoder feature graph. The + engine now backprops once per image.""" + pytest.importorskip("ultralytics") + from src.digitalsreeni_image_annotator.training.sam_trainer import ( + SAM_MODELS_DIR, + SAMFineTuner, + ) + + if not os.path.exists(os.path.join(SAM_MODELS_DIR, "sam2_t.pt")): + pytest.skip("sam2_t.pt not cached") + + def two_instance_image(seed): + rng = np.random.RandomState(seed) + img = (rng.rand(140, 140, 3) * 255).astype(np.uint8) + specs = [] + for cy, cx in [(40, 40), (95, 95)]: + yy, xx = np.ogrid[:140, :140] + m = (yy - cy) ** 2 + (xx - cx) ** 2 < 18 ** 2 + img[m] = [30, 220, 30] + x1, y1, x2, y2 = mask_to_xyxy(m) + specs.append({"bbox": [x1, y1, x2 - x1, y2 - y1]}) + return SampleGroup(lambda im=img: im.copy(), specs) + + out = os.path.join(temp_dir, "enc_sam2_t.pt") + res = SAMFineTuner().train( + "SAM 2 tiny", [two_instance_image(i) for i in range(2)], + epochs=1, lr=1e-5, batch_size=1, freeze_image_encoder=False, + prompt_type="bbox", out_path=out, + ) + assert res["verified"] and res["instances"] == 4 + + +@pytest.mark.skipif( + os.environ.get("SAM_TRAIN_E2E") != "1", + reason="set SAM_TRAIN_E2E=1 to run the landscape mask-shift regression", +) +def test_landscape_no_mask_shift(temp_dir): + """Regression for the downward mask shift (issue #73 GUI testing). + + SAM2 letterboxes (pad bottom/right) and inference crops the padding via + ops.scale_masks(padding=False). The training loss must use the SAME + transform; a naive interpolate baked the padding into the target and the + decoder learned masks shifted down. This must use a NON-square image (square + images have zero padding and so never exposed the bug). + """ + pytest.importorskip("ultralytics") + import cv2 + from ultralytics import SAM + + from src.digitalsreeni_image_annotator.training.sam_trainer import ( + SAM_MODELS_DIR, + SAMFineTuner, + ) + + if not os.path.exists(os.path.join(SAM_MODELS_DIR, "sam2_t.pt")): + pytest.skip("sam2_t.pt not cached") + + H, W = 600, 1000 # landscape -> SAM2 pads the bottom + + def disk(seed, cy, cx): + rng = np.random.RandomState(seed) + img = (rng.rand(H, W, 3) * 50).astype(np.uint8) + yy, xx = np.ogrid[:H, :W] + m = (yy - cy) ** 2 + (xx - cx) ** 2 < 60 ** 2 + img[m] = [230, 40, 40] + return img, m + + groups = [] + for i in range(4): + img, m = disk(i, 460 + (i % 2) * 40, 300 + i * 120) + cnts, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + groups.append(SampleGroup(lambda im=img: im.copy(), [{"segmentation": cnts[0].flatten().tolist()}])) + + out = os.path.join(temp_dir, "landscape_sam2_t.pt") + SAMFineTuner().train( + "SAM 2 tiny", groups, epochs=8, lr=1e-4, batch_size=2, + freeze_image_encoder=True, prompt_type="bbox", out_path=out, + ) + + # Evaluate on a fresh image, object near the bottom (worst case for the shift). + img, m = disk(99, 480, 520) + x1, y1, x2, y2 = mask_to_xyxy(m) + gt_cy = float(np.where(m)[0].mean()) + + r = SAM(out)(img, bboxes=[[x1, y1, x2, y2]], verbose=False) + pm = r[0].masks.data.cpu().numpy()[0].astype(bool) + ys, xs = np.where(pm) + assert xs.size > 0, "fine-tuned model produced an empty mask" + pred_cy = ys.mean() + iou = (pm & m).sum() / (pm | m).sum() + + assert abs(pred_cy - gt_cy) < 25, f"vertical shift {pred_cy - gt_cy:.1f}px (regression)" + assert iou > 0.7, f"IoU {iou:.3f} too low" diff --git a/tests/integration/test_sam_train_controller.py b/tests/integration/test_sam_train_controller.py new file mode 100644 index 0000000..5ce5cbd --- /dev/null +++ b/tests/integration/test_sam_train_controller.py @@ -0,0 +1,74 @@ +"""SAMTrainController UI-locking tests (issue bnsreenu#73). + +The lock/unlock of the SAM inference UI during a fine-tuning run is exactly the +kind of state machine that regresses silently. These build one real +ImageAnnotator (offscreen) and exercise the controller directly — no model +weights, no worker thread. +""" + +import pytest + + +@pytest.fixture +def window(qt_application): + from digitalsreeni_image_annotator.annotator_window import ImageAnnotator + + w = ImageAnnotator() + yield w + w.deleteLater() + + +def test_set_sam_ui_locked_toggles_widgets(window): + # The helper's contract is the toggle itself; the buttons' construction-time + # enabled state depends on whether an image is loaded, so assert relative + # to an explicit unlocked baseline rather than the initial state. + c = window.sam_train_controller + widgets = [window.sam_box_button, window.sam_points_button, window.sam_model_selector] + + c._set_sam_ui_locked(True) + assert not any(w.isEnabled() for w in widgets) + assert not c._menu.isEnabled() + + c._set_sam_ui_locked(False) + assert all(w.isEnabled() for w in widgets) + assert c._menu.isEnabled() + + +def test_launch_unlocks_ui_when_setup_raises(window, monkeypatch): + """If anything between locking and thread.start() raises, the SAM UI must + be restored — otherwise the tools stay dead until app restart.""" + from PyQt6.QtWidgets import QMessageBox + + import digitalsreeni_image_annotator.controllers.sam_train_controller as mod + from digitalsreeni_image_annotator.dialogs.sam_trainer_dialog import ( + SAMTrainConfigDialog, + ) + + c = window.sam_train_controller + monkeypatch.setattr(c, "_gpu_gate", lambda: True) + monkeypatch.setattr( + SAMTrainConfigDialog, "exec", + lambda self: SAMTrainConfigDialog.DialogCode.Accepted, + ) + monkeypatch.setattr( + SAMTrainConfigDialog, "get_config", + lambda self: { + "base_model": "SAM 2 tiny", "out_name": "t", "epochs": 1, + "lr": 1e-4, "batch_size": 1, "prompt_type": "bbox", + "freeze_image_encoder": True, + }, + ) + + class Boom: + def __init__(self, *a, **k): + raise RuntimeError("setup failed on purpose") + + monkeypatch.setattr(mod, "SAMFineTuner", Boom) + monkeypatch.setattr(QMessageBox, "critical", staticmethod(lambda *a, **k: None)) + + c._launch([object()]) # group content irrelevant — fails before use + + assert window.sam_box_button.isEnabled() + assert window.sam_points_button.isEnabled() + assert window.sam_model_selector.isEnabled() + assert c._menu.isEnabled() diff --git a/tests/integration/test_smoke.py b/tests/integration/test_smoke.py index 5d103c3..ba360df 100644 --- a/tests/integration/test_smoke.py +++ b/tests/integration/test_smoke.py @@ -47,6 +47,7 @@ def test_public_api_exports(): # Core helpers "digitalsreeni_image_annotator.core.constants", "digitalsreeni_image_annotator.core.annotation_utils", + "digitalsreeni_image_annotator.core.torch_utils", # UI "digitalsreeni_image_annotator.ui.default_stylesheet", "digitalsreeni_image_annotator.ui.soft_dark_stylesheet", diff --git a/tests/unit/test_canvas_selection.py b/tests/unit/test_canvas_selection.py new file mode 100644 index 0000000..c0453bc --- /dev/null +++ b/tests/unit/test_canvas_selection.py @@ -0,0 +1,244 @@ +""" +Unit tests for idle-mode canvas mask selection (bnsreenu issue #75). + +Covers the pure hit-testing helpers on ImageLabel (annotation_at / +annotations_in_rect) and the press→release gesture resolution +(_finish_selection) that emits canvasSelectionChanged. No main window, +no model — just the widget. +""" + +import pytest + +from PyQt6.QtCore import Qt +from PyQt6.QtGui import QColor, QImage, QPainter, QPixmap + +from src.digitalsreeni_image_annotator.widgets.image_label import ImageLabel + + +@pytest.fixture +def label(qtbot): + lbl = ImageLabel(None) + qtbot.addWidget(lbl) + return lbl + + +def _square(x0, y0, side, name): + return { + "segmentation": [x0, y0, x0 + side, y0, x0 + side, y0 + side, x0, y0 + side], + "category_name": name, + } + + +def _bbox(x, y, w, h, name): + return {"bbox": [x, y, w, h], "category_name": name} + + +class _FakeEvent: + """Minimal stand-in for QMouseEvent — only modifiers() is read.""" + + def __init__(self, shift=False): + self._shift = shift + + def modifiers(self): + if self._shift: + return Qt.KeyboardModifier.ShiftModifier + return Qt.KeyboardModifier.NoModifier + + +class _FakeCtx: + def __init__(self, hidden=()): + self._hidden = set(hidden) + + def is_class_visible(self, name): + return name not in self._hidden + + +# --- annotation_at --------------------------------------------------------- + +def test_annotation_at_smallest_of_nested(label): + outer = _square(0, 0, 100, "outer") # area 10000 + inner = _square(40, 40, 20, "inner") # area 400, fully inside outer + label.annotations = {"cell": [outer, inner]} + assert label.annotation_at((50, 50)) is inner # inside both → smallest + assert label.annotation_at((10, 10)) is outer # inside outer only + assert label.annotation_at((500, 500)) is None # empty space + + +def test_annotation_at_hits_bbox(label): + box = _bbox(0, 0, 100, 100, "box") + label.annotations = {"cell": [box]} + assert label.annotation_at((50, 50)) is box + assert label.annotation_at((150, 150)) is None + + +def test_annotation_at_smallest_across_seg_and_bbox(label): + big_box = _bbox(0, 0, 200, 200, "box") # area 40000 + small_seg = _square(40, 40, 20, "seg") # area 400 + label.annotations = {"cell": [big_box, small_seg]} + assert label.annotation_at((45, 45)) is small_seg + + +def test_annotation_at_skips_hidden_class(label): + visible = _square(0, 0, 100, "visible") + hidden = _square(40, 40, 20, "hidden") # smaller, but its class is hidden + label.annotations = {"visible": [visible], "hidden": [hidden]} + label.set_context(_FakeCtx(hidden={"hidden"})) + # Without the visibility guard, the smaller hidden square would win. + assert label.annotation_at((50, 50)) is visible + + +# --- annotations_in_rect --------------------------------------------------- + +def test_annotations_in_rect_returns_intersecting(label): + a = _square(0, 0, 20, "a") + b = _square(50, 50, 20, "b") + far = _square(500, 500, 20, "far") + bbox_in = _bbox(10, 10, 30, 30, "c") + label.annotations = {"cell": [a, b, far, bbox_in]} + hit = label.annotations_in_rect((0, 0, 80, 80)) + assert a in hit and b in hit and bbox_in in hit + assert far not in hit + + +def test_annotations_in_rect_any_corner_order(label): + a = _square(0, 0, 20, "a") + label.annotations = {"cell": [a]} + # Rect given bottom-right → top-left must still match. + assert a in label.annotations_in_rect((80, 80, 0, 0)) + + +def test_annotations_in_rect_skips_hidden(label): + a = _square(0, 0, 20, "a") + label.annotations = {"hidden": [a]} + label.set_context(_FakeCtx(hidden={"hidden"})) + assert label.annotations_in_rect((0, 0, 80, 80)) == [] + + +# --- _finish_selection gesture resolution ---------------------------------- + +@pytest.fixture +def captured(label): + events = [] + label.canvasSelectionChanged.connect(lambda anns, mode: events.append((anns, mode))) + return events + + +def test_click_on_mask_emits_replace(label, captured): + inner = _square(40, 40, 20, "inner") + label.annotations = {"cell": [inner]} + label.selection_origin = (45, 45) + label.selecting = False + label._finish_selection((45, 45), _FakeEvent(shift=False)) + assert captured == [([inner], "replace")] + # gesture state reset + assert label.selection_origin is None and label.selection_rect is None + + +def test_click_empty_emits_replace_empty(label, captured): + label.annotations = {"cell": [_square(40, 40, 20, "inner")]} + label.selection_origin = (300, 300) + label._finish_selection((300, 300), _FakeEvent(shift=False)) + assert captured == [([], "replace")] + + +def test_shift_click_mask_emits_toggle(label, captured): + inner = _square(40, 40, 20, "inner") + label.annotations = {"cell": [inner]} + label.selection_origin = (45, 45) + label._finish_selection((45, 45), _FakeEvent(shift=True)) + assert captured == [([inner], "toggle")] + + +def test_shift_click_empty_emits_nothing(label, captured): + label.annotations = {"cell": [_square(40, 40, 20, "inner")]} + label.selection_origin = (300, 300) + label._finish_selection((300, 300), _FakeEvent(shift=True)) + assert captured == [] + + +def test_drag_emits_add_when_shift(label, captured): + a = _square(0, 0, 20, "a") + b = _square(50, 50, 20, "b") + label.annotations = {"cell": [a, b]} + label.selection_origin = (0, 0) + label.selecting = True + label.selection_rect = (0, 0, 80, 80) + label._finish_selection((80, 80), _FakeEvent(shift=True)) + anns, mode = captured[-1] + assert mode == "add" + assert a in anns and b in anns + + +def test_drag_emits_replace_without_shift(label, captured): + a = _square(0, 0, 20, "a") + label.annotations = {"cell": [a]} + label.selection_origin = (0, 0) + label.selecting = True + label.selection_rect = (0, 0, 80, 80) + label._finish_selection((80, 80), _FakeEvent(shift=False)) + assert captured[-1][1] == "replace" + assert a in captured[-1][0] + + +def test_escape_cancels_rubber_band(label): + from PyQt6.QtCore import QEvent + from PyQt6.QtGui import QKeyEvent + + label.current_tool = None + label.selection_origin = (0, 0) + label.selecting = True + label.selection_rect = (0, 0, 50, 50) + label.keyPressEvent( + QKeyEvent(QEvent.Type.KeyPress, Qt.Key.Key_Escape, Qt.KeyboardModifier.NoModifier) + ) + assert label.selection_rect is None + assert label.selecting is False + assert label.selection_origin is None + + +# --- selection rendering: overlay, not recolor ---------------------------- + +def _setup_canvas(label, class_color): + label.set_context(_FakeCtx()) + px = QPixmap(100, 100) + px.fill(QColor("white")) + label.original_pixmap = px + label.class_colors = {"cell": QColor(class_color)} + + +def _render_center(label): + img = QImage(100, 100, QImage.Format.Format_RGB32) + img.fill(QColor("white")) + p = QPainter(img) + label.draw_annotations(p) + p.end() + return img.pixelColor(50, 50) + + +def test_selection_does_not_recolor_fill(label): + # A mask's interior fill must look identical selected vs. unselected — the + # old code turned it red, which was invisible on a red-class mask. + _setup_canvas(label, "#1F77B4") + mask = {"segmentation": [20, 20, 80, 20, 80, 80, 20, 80], "category_name": "cell"} + label.annotations = {"cell": [mask]} + + label.highlighted_annotations = [] + unselected = _render_center(label) + label.highlighted_annotations = [mask] + selected = _render_center(label) + assert selected == unselected # selection adds an overlay, never recolors + + +def test_selection_overlay_runs_for_seg_and_bbox(label): + # Red class = the worst case; outline + marquee must still render fine. + _setup_canvas(label, "#D62728") + seg = {"segmentation": [10, 10, 40, 10, 40, 40, 10, 40], "category_name": "cell"} + box = {"bbox": [50, 50, 30, 30], "category_name": "cell"} + label.annotations = {"cell": [seg, box]} + label.highlighted_annotations = [seg, box] + + img = QImage(100, 100, QImage.Format.Format_RGB32) + img.fill(QColor("white")) + p = QPainter(img) + label.draw_annotations(p) # must not raise + p.end() diff --git a/tests/unit/test_class_colors.py b/tests/unit/test_class_colors.py new file mode 100644 index 0000000..a326551 --- /dev/null +++ b/tests/unit/test_class_colors.py @@ -0,0 +1,35 @@ +"""Default class-colour palette (issue #75 follow-up). + +Red must no longer be the first auto-assigned class colour (it collided with +the selection highlight); it stays in the palette but at the back. +""" + +from src.digitalsreeni_image_annotator.core.constants import ( + DEFAULT_CLASS_COLORS, + default_class_color, +) + +_RED = "#D62728" + + +def test_first_default_color_is_not_red(): + assert default_class_color(0).upper() != _RED + assert default_class_color(0) == DEFAULT_CLASS_COLORS[0] + + +def test_palette_cycles_modulo_length(): + n = len(DEFAULT_CLASS_COLORS) + assert default_class_color(n) == default_class_color(0) + assert default_class_color(n + 3) == default_class_color(3) + + +def test_red_present_but_last(): + upper = [c.upper() for c in DEFAULT_CLASS_COLORS] + assert _RED in upper + assert upper[-1] == _RED + + +def test_all_entries_distinct_and_valid_hex(): + assert len(set(DEFAULT_CLASS_COLORS)) == len(DEFAULT_CLASS_COLORS) + for c in DEFAULT_CLASS_COLORS: + assert c.startswith("#") and len(c) == 7 diff --git a/tests/unit/test_image_filter.py b/tests/unit/test_image_filter.py new file mode 100644 index 0000000..e5948b3 --- /dev/null +++ b/tests/unit/test_image_filter.py @@ -0,0 +1,149 @@ +""" +Unit tests for the image-list annotation-status filter (upstream issue #27). + +Covers ImageController.image_has_annotations and apply_image_filter. +""" + +import pytest +from PyQt6.QtWidgets import QWidget, QListWidget, QComboBox + +from src.digitalsreeni_image_annotator.controllers.image_controller import ( + ImageController, +) + + +FILTER_ALL = 0 +FILTER_WITHOUT = 1 +FILTER_WITH = 2 + + +class FakeMainWindow(QWidget): + """Minimal stand-in for ImageAnnotator with the state the filter reads.""" + + +@pytest.fixture +def mw(qtbot): + window = FakeMainWindow() + qtbot.addWidget(window) + window.image_list = QListWidget(window) + window.image_filter_combo = QComboBox(window) + window.image_filter_combo.addItems( + ["All images", "Without annotations", "With annotations"] + ) + window.all_images = [] + window.all_annotations = {} + window.image_slices = {} + window.image_controller = ImageController(window) + return window + + +def _add_image(mw, file_name, is_multi_slice=False): + mw.all_images.append( + {"file_name": file_name, "is_multi_slice": is_multi_slice} + ) + mw.image_list.addItem(file_name) + + +class TestImageHasAnnotations: + def test_regular_image_without_annotations(self, mw): + _add_image(mw, "plain.png") + assert not mw.image_controller.image_has_annotations(mw.all_images[0]) + + def test_regular_image_with_empty_class_lists(self, mw): + _add_image(mw, "plain.png") + mw.all_annotations["plain.png"] = {"cell": []} + assert not mw.image_controller.image_has_annotations(mw.all_images[0]) + + def test_regular_image_with_annotations(self, mw): + _add_image(mw, "plain.png") + mw.all_annotations["plain.png"] = { + "cell": [{"segmentation": [0, 0, 1, 0, 1, 1]}] + } + assert mw.image_controller.image_has_annotations(mw.all_images[0]) + + def test_multi_slice_with_annotated_slice(self, mw): + _add_image(mw, "stack.tif", is_multi_slice=True) + mw.image_slices["stack"] = [("stack_T1_Z1", None), ("stack_T1_Z2", None)] + mw.all_annotations["stack_T1_Z2"] = { + "cell": [{"segmentation": [0, 0, 1, 0, 1, 1]}] + } + assert mw.image_controller.image_has_annotations(mw.all_images[0]) + + def test_multi_slice_without_annotated_slices(self, mw): + _add_image(mw, "stack.tif", is_multi_slice=True) + mw.image_slices["stack"] = [("stack_T1_Z1", None)] + mw.all_annotations["stack_T1_Z1"] = {"cell": []} + assert not mw.image_controller.image_has_annotations(mw.all_images[0]) + + def test_multi_slice_prefix_fallback_when_slices_not_loaded(self, mw): + # Project annotations exist under slice keys, but the slices were + # never extracted (e.g. load cancelled) — prefix fallback applies. + _add_image(mw, "stack.tif", is_multi_slice=True) + mw.all_annotations["stack_T1_Z5_C1"] = { + "cell": [{"segmentation": [0, 0, 1, 0, 1, 1]}] + } + assert mw.image_controller.image_has_annotations(mw.all_images[0]) + + def test_multi_slice_no_substring_false_positive(self, mw): + # "bee" must not match keys of "honeybee" (and vice versa). + _add_image(mw, "bee.tif", is_multi_slice=True) + mw.all_annotations["honeybee_T1_Z1"] = { + "cell": [{"segmentation": [0, 0, 1, 0, 1, 1]}] + } + assert not mw.image_controller.image_has_annotations(mw.all_images[0]) + + +class TestApplyImageFilter: + @pytest.fixture + def populated(self, mw): + _add_image(mw, "annotated.png") + _add_image(mw, "empty.png") + mw.all_annotations["annotated.png"] = { + "cell": [{"segmentation": [0, 0, 1, 0, 1, 1]}] + } + # Select the annotated image so the "never hide current" rule is + # exercised by a dedicated test, not by accident here. + mw.image_list.setCurrentRow(-1) + return mw + + def _hidden(self, mw): + return [ + mw.image_list.isRowHidden(i) for i in range(mw.image_list.count()) + ] + + def test_all_images_shows_everything(self, populated): + populated.image_filter_combo.setCurrentIndex(FILTER_ALL) + populated.image_controller.apply_image_filter() + assert self._hidden(populated) == [False, False] + + def test_without_annotations_hides_annotated(self, populated): + populated.image_filter_combo.setCurrentIndex(FILTER_WITHOUT) + populated.image_controller.apply_image_filter() + assert self._hidden(populated) == [True, False] + + def test_with_annotations_hides_unannotated(self, populated): + populated.image_filter_combo.setCurrentIndex(FILTER_WITH) + populated.image_controller.apply_image_filter() + assert self._hidden(populated) == [False, True] + + def test_current_row_is_hidden_when_not_matching(self, populated): + # The current row is not exempt: a non-matching selected image + # leaves the list (the canvas keeps showing it — see the wiring + # test for the switch_image / canvas-unchanged guarantee). + populated.image_list.setCurrentRow(0) # annotated.png + populated.image_filter_combo.setCurrentIndex(FILTER_WITHOUT) + populated.image_controller.apply_image_filter() + assert self._hidden(populated) == [True, False] + + def test_no_combo_is_a_noop(self, mw): + _add_image(mw, "plain.png") + del mw.image_filter_combo + mw.image_controller.apply_image_filter() # must not raise + assert not mw.image_list.isRowHidden(0) + + def test_switching_back_to_all_unhides(self, populated): + populated.image_filter_combo.setCurrentIndex(FILTER_WITH) + populated.image_controller.apply_image_filter() + populated.image_filter_combo.setCurrentIndex(FILTER_ALL) + populated.image_controller.apply_image_filter() + assert self._hidden(populated) == [False, False] diff --git a/tests/unit/test_image_list_sort.py b/tests/unit/test_image_list_sort.py new file mode 100644 index 0000000..3f93c5f --- /dev/null +++ b/tests/unit/test_image_list_sort.py @@ -0,0 +1,113 @@ +""" +Unit tests for alphabetical image-list sorting (upstream issue #60). + +sort_image_list must order the list case-insensitively, keep the +all_images model aligned with the list rows (positional invariant used +by COCO import), and never fire a spurious switch_image. +""" + +import pytest +from PyQt6.QtWidgets import QWidget, QListWidget, QComboBox + +from src.digitalsreeni_image_annotator.controllers.image_controller import ( + ImageController, +) + + +class FakeMainWindow(QWidget): + pass + + +@pytest.fixture +def mw(qtbot): + window = FakeMainWindow() + qtbot.addWidget(window) + window.image_list = QListWidget(window) + window.image_filter_combo = QComboBox(window) + window.image_filter_combo.addItems( + ["All images", "Without annotations", "With annotations"] + ) + window.all_images = [] + window.all_annotations = {} + window.image_slices = {} + window.image_paths = {} + window.is_loading_project = False + window.auto_save = lambda: None + window.image_controller = ImageController(window) + return window + + +def _populate(mw, names): + # Populate out of order, mimicking the model+view pairing that + # add_images_to_list produces before a sort. + for n in names: + mw.all_images.append({"file_name": n, "is_multi_slice": False}) + mw.image_list.addItem(n) + + +def _list_texts(mw): + return [mw.image_list.item(i).text() for i in range(mw.image_list.count())] + + +def test_sorts_alphabetically(mw): + _populate(mw, ["banana.png", "apple.png", "cherry.png"]) + mw.image_controller.sort_image_list() + assert _list_texts(mw) == ["apple.png", "banana.png", "cherry.png"] + + +def test_sort_is_case_insensitive(mw): + _populate(mw, ["Zebra.png", "apple.png", "Banana.png"]) + mw.image_controller.sort_image_list() + assert _list_texts(mw) == ["apple.png", "Banana.png", "Zebra.png"] + + +def test_model_and_view_stay_aligned(mw): + _populate(mw, ["d.png", "a.png", "c.png", "b.png"]) + mw.image_controller.sort_image_list() + assert _list_texts(mw) == [info["file_name"] for info in mw.all_images] + + +def test_sort_fires_no_switch_image(mw): + _populate(mw, ["b.png", "a.png"]) + calls = [] + mw.switch_image = lambda item: calls.append(item) + mw.image_controller.switch_image = lambda item: calls.append(item) + mw.image_list.setCurrentRow(0) + calls.clear() + mw.image_controller.sort_image_list() # no select_name, do_switch=False + assert calls == [] + + +def test_selection_preserved_across_sort(mw): + _populate(mw, ["b.png", "a.png", "c.png"]) + mw.image_list.setCurrentRow(0) # b.png + mw.image_controller.sort_image_list() + assert mw.image_list.currentItem().text() == "b.png" + + +def test_select_name_and_switch(mw): + _populate(mw, ["b.png", "a.png"]) + calls = [] + mw.switch_image = lambda item: calls.append(item.text()) + mw.image_controller.switch_image = lambda item: calls.append(item.text()) + mw.image_controller.sort_image_list(select_name="a.png", do_switch=True) + assert mw.image_list.currentItem().text() == "a.png" + assert calls == ["a.png"] + + +def test_project_load_path_ends_sorted(mw): + # Contract: during project load add_images_to_list does NOT sort per + # image (avoids O(n^2)); the list is rebuilt once afterwards via the + # update_ui -> update_image_list call. This guards a refactor of that + # call from silently leaving the post-load list unsorted. + mw.is_loading_project = True + mw.image_controller.add_images_to_list(["c.png", "a.png", "b.png"]) + # Not populated/sorted yet while loading. + assert mw.image_list.count() == 0 + + mw.is_loading_project = False + mw.image_controller.update_image_list() # what update_ui triggers + + texts = [mw.image_list.item(i).text() for i in range(mw.image_list.count())] + assert texts == ["a.png", "b.png", "c.png"] + assert texts == [info["file_name"] for info in mw.all_images] diff --git a/tests/unit/test_polygon_edit.py b/tests/unit/test_polygon_edit.py new file mode 100644 index 0000000..29822e9 --- /dev/null +++ b/tests/unit/test_polygon_edit.py @@ -0,0 +1,71 @@ +""" +Unit tests for nested-polygon edit selection (upstream issue #33). + +start_polygon_edit must enter edit mode on the *smallest* polygon that +contains the click, so an annotation fully nested inside another is +reachable instead of always grabbing the outer one. +""" + +import pytest + +from src.digitalsreeni_image_annotator.widgets.image_label import ImageLabel + + +@pytest.fixture +def label(qtbot): + lbl = ImageLabel(None) + qtbot.addWidget(lbl) + return lbl + + +def _square(x0, y0, side, name): + return { + "segmentation": [x0, y0, x0 + side, y0, x0 + side, y0 + side, x0, y0 + side], + "category_name": name, + } + + +@pytest.fixture +def nested(label): + outer = _square(0, 0, 100, "outer") # area 10000 + inner = _square(40, 40, 20, "inner") # area 400, fully inside outer + # Insert outer first so the old "return first match" behavior would + # have returned outer — the test fails unless smallest-area wins. + label.annotations = {"cell": [outer, inner]} + return label, outer, inner + + +def test_click_in_nested_region_selects_inner(nested): + label, outer, inner = nested + result = label.start_polygon_edit((50, 50)) # inside both + assert result is inner + assert label.editing_polygon is inner + + +def test_click_only_in_outer_selects_outer(nested): + label, outer, inner = nested + result = label.start_polygon_edit((10, 10)) # inside outer only + assert result is outer + assert label.editing_polygon is outer + + +def test_click_outside_all_returns_none(nested): + label, outer, inner = nested + label.editing_polygon = None + result = label.start_polygon_edit((500, 500)) + assert result is None + + +def test_insertion_order_does_not_matter(label): + # Inner listed first: result must still be the smallest, not the first. + outer = _square(0, 0, 100, "outer") + inner = _square(40, 40, 20, "inner") + label.annotations = {"cell": [inner, outer]} + assert label.start_polygon_edit((50, 50)) is inner + + +def test_bbox_only_annotation_is_ignored(label): + # start_polygon_edit only handles "segmentation"; bbox editing is #40. + bbox_ann = {"bbox": [0, 0, 100, 100], "category_name": "box"} + label.annotations = {"cell": [bbox_ann]} + assert label.start_polygon_edit((50, 50)) is None diff --git a/tests/unit/test_tiff_codec.py b/tests/unit/test_tiff_codec.py new file mode 100644 index 0000000..a5333b1 --- /dev/null +++ b/tests/unit/test_tiff_codec.py @@ -0,0 +1,86 @@ +""" +Unit tests for graceful handling of compressed TIFFs missing imagecodecs +(upstream issue #56). + +An LZW TIFF read raises ValueError when imagecodecs is absent; the app +must skip the file with a dialog instead of crashing, and must not leave +a half-added entry. +""" + +import pytest +from PyQt6.QtWidgets import QWidget, QListWidget, QComboBox + +import src.digitalsreeni_image_annotator.controllers.image_controller as ic_module +from src.digitalsreeni_image_annotator.controllers.image_controller import ( + ImageController, +) + +LZW_ERROR = " requires the 'imagecodecs' package" + + +class FakeMainWindow(QWidget): + pass + + +@pytest.fixture +def mw(qtbot): + window = FakeMainWindow() + qtbot.addWidget(window) + window.image_list = QListWidget(window) + window.image_filter_combo = QComboBox(window) + window.image_filter_combo.addItems( + ["All images", "Without annotations", "With annotations"] + ) + window.all_images = [] + window.all_annotations = {} + window.image_slices = {} + window.image_paths = {} + window.is_loading_project = False + window.auto_save = lambda: None + window.image_controller = ImageController(window) + return window + + +def test_lzw_tiff_without_codec_is_skipped_with_dialog(mw, monkeypatch): + dialogs = [] + monkeypatch.setattr( + ic_module.QMessageBox, "critical", + lambda *a, **k: dialogs.append(a), + ) + + def raise_lzw(_path): + raise ValueError(LZW_ERROR) + + mw.image_controller.load_multi_slice_image = raise_lzw + + # Must not raise. + mw.image_controller.add_images_to_list(["C:/data/scan.tif"]) + + assert len(dialogs) == 1 # one critical dialog shown + assert mw.all_images == [] # no half-added entry + assert mw.image_list.count() == 0 + assert "scan.tif" not in mw.image_paths + + +def test_is_missing_codec_error_matches(): + assert ImageController._is_missing_codec_error(ValueError(LZW_ERROR)) + assert ImageController._is_missing_codec_error( + ValueError("requires the 'imagecodecs' package") + ) + # A bare "compression" mention must NOT match — that would swallow + # unrelated errors behind a misleading "install imagecodecs" dialog. + assert not ImageController._is_missing_codec_error( + ValueError("unsupported compression scheme") + ) + + +def test_unrelated_value_error_is_reraised(mw, monkeypatch): + monkeypatch.setattr(ic_module.QMessageBox, "critical", lambda *a, **k: None) + + def raise_other(_path): + raise ValueError("corrupt dimension metadata") + + mw.image_controller.load_multi_slice_image = raise_other + + with pytest.raises(ValueError, match="corrupt dimension metadata"): + mw.image_controller.add_images_to_list(["C:/data/scan.tif"]) diff --git a/tests/unit/test_torch_utils.py b/tests/unit/test_torch_utils.py new file mode 100644 index 0000000..9f717d7 --- /dev/null +++ b/tests/unit/test_torch_utils.py @@ -0,0 +1,128 @@ +""" +Unit tests for core.torch_utils device resolution (upstream issue #57). + +A fake `torch` module is injected into sys.modules so the tests run the +same on machines with and without CUDA. +""" + +import sys +import types + +import pytest + +from src.digitalsreeni_image_annotator.core import torch_utils + + +@pytest.fixture(autouse=True) +def reset_cache(): + torch_utils._cached_result = None + torch_utils._warning_shown = False + yield + torch_utils._cached_result = None + torch_utils._warning_shown = False + + +@pytest.fixture +def fake_torch(monkeypatch): + """Install a configurable fake torch module; returns its cuda namespace.""" + cuda = types.SimpleNamespace( + is_available=lambda: False, + get_device_capability=lambda idx=0: (8, 6), + get_arch_list=lambda: ["sm_70", "sm_80", "sm_90", "compute_90"], + get_device_name=lambda idx=0: "Fake GPU", + ) + module = types.ModuleType("torch") + module.cuda = cuda + monkeypatch.setitem(sys.modules, "torch", module) + return cuda + + +def test_no_cuda_returns_cpu_without_warning(fake_torch): + assert torch_utils.resolve_torch_device() == ("cpu", None) + + +def test_supported_gpu_returns_cuda(fake_torch): + fake_torch.is_available = lambda: True + assert torch_utils.resolve_torch_device() == ("cuda", None) + + +def test_unsupported_pascal_gpu_falls_back_to_cpu(fake_torch): + fake_torch.is_available = lambda: True + fake_torch.get_device_capability = lambda idx=0: (6, 1) # GTX 1050 + device, warning = torch_utils.resolve_torch_device() + assert device == "cpu" + assert "sm_61" in warning + assert "sm_70" in warning + assert "Fake GPU" in warning + + +def test_oldest_supported_capability_is_accepted(fake_torch): + fake_torch.is_available = lambda: True + fake_torch.get_device_capability = lambda idx=0: (7, 0) + assert torch_utils.resolve_torch_device() == ("cuda", None) + + +def test_empty_arch_list_keeps_cuda(fake_torch): + # Defensive: if torch reports no compiled arches, don't second-guess it. + fake_torch.is_available = lambda: True + fake_torch.get_arch_list = lambda: [] + assert torch_utils.resolve_torch_device() == ("cuda", None) + + +def test_probe_failure_falls_back_to_cpu(fake_torch): + fake_torch.is_available = lambda: True + + def boom(idx=0): + raise RuntimeError("driver mismatch") + + fake_torch.get_device_capability = boom + device, warning = torch_utils.resolve_torch_device() + assert device == "cpu" + assert "driver mismatch" in warning + + +def test_missing_torch_returns_cpu(monkeypatch): + monkeypatch.setitem(sys.modules, "torch", None) # import torch → fails + assert torch_utils.resolve_torch_device() == ("cpu", None) + + +def test_result_is_cached(fake_torch): + fake_torch.is_available = lambda: True + assert torch_utils.resolve_torch_device() == ("cuda", None) + # Changing the fake afterwards must not change the cached decision. + fake_torch.is_available = lambda: False + assert torch_utils.resolve_torch_device() == ("cuda", None) + + +def test_parse_arch_list(): + assert torch_utils._parse_arch_list( + ["sm_70", "sm_80", "compute_90", "garbage", "sm_xx"] + ) == [70, 80] + + +def test_warning_dialog_shown_once(fake_torch, monkeypatch, qt_application): + fake_torch.is_available = lambda: True + fake_torch.get_device_capability = lambda idx=0: (6, 1) + + calls = [] + from PyQt6.QtWidgets import QMessageBox + monkeypatch.setattr( + QMessageBox, "warning", lambda *a, **k: calls.append(a) + ) + + torch_utils.maybe_warn_cpu_fallback(None) + torch_utils.maybe_warn_cpu_fallback(None) + assert len(calls) == 1 + + +def test_no_warning_dialog_when_device_ok(fake_torch, monkeypatch, qt_application): + fake_torch.is_available = lambda: True + + calls = [] + from PyQt6.QtWidgets import QMessageBox + monkeypatch.setattr( + QMessageBox, "warning", lambda *a, **k: calls.append(a) + ) + + torch_utils.maybe_warn_cpu_fallback(None) + assert calls == [] diff --git a/tests/unit/test_yaml_encoding.py b/tests/unit/test_yaml_encoding.py new file mode 100644 index 0000000..a66ddce --- /dev/null +++ b/tests/unit/test_yaml_encoding.py @@ -0,0 +1,33 @@ +""" +Regression guard for UTF-8 file encoding (upstream issue #44). + +Reading a COCO JSON that contains non-ASCII category names must succeed +regardless of the platform's default code page. Before the fix, open() +without encoding used cp1252 on Windows and crashed on these bytes. The +test writes genuine non-ASCII bytes (ensure_ascii=False) and asserts the +unicode survives the round-trip through import_coco_json's open(). +""" + +import json + +from src.digitalsreeni_image_annotator.io.import_formats import import_coco_json + +UNICODE_CLASS = "Zellkörper-Ü-中" # German + a CJK char + + +def test_import_coco_json_preserves_unicode_class(tmp_path): + coco = { + "images": [{"id": 1, "file_name": "img.png", "width": 10, "height": 10}], + "categories": [{"id": 1, "name": UNICODE_CLASS}], + "annotations": [ + {"id": 1, "image_id": 1, "category_id": 1, "bbox": [0, 0, 5, 5]} + ], + } + json_path = tmp_path / "annotations.json" + with open(json_path, "w", encoding="utf-8") as f: + json.dump(coco, f, ensure_ascii=False) + + imported_annotations, image_info = import_coco_json(str(json_path), {}) + + assert UNICODE_CLASS in imported_annotations["img.png"] + assert image_info[1]["file_name"] == "img.png"