diff --git a/CLAUDE.md b/CLAUDE.md index ae5b906..a7897e0 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -22,7 +22,7 @@ python -m src.digitalsreeni_image_annotator.main Python 3.10+ | PyQt6 6.7+ | Ultralytics 8.3.27 (SAM 2) | NumPy | OpenCV | Shapely -**Test suite**: `tests/` (pytest + pytest-qt). 65 tests pass on PyQt6. +**Test suite**: `tests/` (pytest + pytest-qt). 94 tests pass on PyQt6. ## Documentation @@ -40,23 +40,44 @@ See [docs/README.md](docs/README.md) for full documentation index. ``` src/digitalsreeni_image_annotator/ -├── main.py # Entry point -├── annotator_window.py # ImageAnnotator - main window, project state -├── image_label.py # ImageLabel - display, mouse events, rendering -├── sam_utils.py # SAMUtils - SAM model management -├── utils.py # Utility functions -├── export_formats.py # COCO, YOLO, Pascal VOC exporters -├── import_formats.py # COCO, YOLO importers -└── [tool dialogs] # Standalone utility windows +├── main.py # Entry point +├── annotator_window.py # ImageAnnotator - thin orchestrator +├── app_settings.py # QSettings UI prefs: ui_font_pt, dark_mode (ADR-020) +├── utils.py # Utility functions (calculate_area, …) +├── __init__.py # Public API re-exports +│ +├── core/ # constants, annotation_utils, image_utils +├── controllers/ # 7 controllers (project, image, sam, dino, +│ # yolo, annotation, class) + io_controller +├── widgets/ +│ ├── image_label.py # ImageLabel canvas widget (dispatcher) +│ ├── canvas_context.py # CanvasContext read accessor (ADR-018) +│ └── tools/ # Per-tool handlers (ADR-019): rectangle, +│ # polygon, paint, eraser +├── inference/ # sam_utils.py, dino_utils.py +├── io/ # export_formats.py, import_formats.py +├── ui/ # menu_bar, sidebar, shortcuts, theme, stylesheets +└── dialogs/ # Standalone tool dialogs (statistics, + # splitter, augmenter, … 16 files) ``` ## Key Classes | Class | File | Responsibility | |-------|------|----------------| -| `ImageAnnotator` | annotator_window.py | Main window, state (`all_annotations`, `class_mapping`, etc.) | -| `ImageLabel` | image_label.py | Image display, zoom/pan, annotation interaction | -| `SAMUtils` | sam_utils.py | Load SAM models, run inference | +| `ImageAnnotator` | annotator_window.py | Thin orchestrator — holds controllers, wires signals, delegates almost everything | +| `ImageLabel` | widgets/image_label.py | Canvas display, zoom/pan, event dispatch to tool handlers | +| `CanvasContext` | widgets/canvas_context.py | Narrow read view of main-window state for ImageLabel (ADR-018) | +| `ToolHandler` (+ 4 subclasses) | widgets/tools/ | Per-tool mouse/key handling (rectangle, polygon, paint, eraser) (ADR-019) | +| `ProjectController` | controllers/project_controller.py | `.iap` save/load, auto-save, `is_loading_project` guard | +| `ImageController` | controllers/image_controller.py | TIFF/CZI loading, multi-dim slicing, image/slice switching | +| `AnnotationController` | controllers/annotation_controller.py | Annotation CRUD, sort, edit-mode, finish_polygon/rectangle | +| `ClassController` | controllers/class_controller.py | Class add/delete/rename/colour/visibility | +| `SAMController` | controllers/sam_controller.py | SAM model picker, debounce, in-flight guard (ADR-013) | +| `DINOController` | controllers/dino_controller.py | DINO single/batch detection, batch review, temp-class workflow | +| `YOLOController` | controllers/yolo_controller.py | YOLO training menu + prediction wiring | +| `SAMUtils` | inference/sam_utils.py | Load SAM models, run inference | +| `DINOUtils` | inference/dino_utils.py | Grounding-DINO model load + inference | See [Building Block View](docs/05_building_block_view.md) for detailed class documentation. @@ -68,7 +89,9 @@ See [Building Block View](docs/05_building_block_view.md) for detailed class doc 2. Set `image_label.current_tool` on click 3. Handle mouse events in `ImageLabel` (mousePressEvent, mouseMoveEvent) 4. Render in `ImageLabel.paintEvent()` -5. Call `main_window.add_annotation()` to commit +5. Commit via `self.annotationCommitted.emit(annotation_dict)` — the + orchestrator routes it to `AnnotationController.add_annotation_to_list` + (see ADR-018) ### Working with Annotations @@ -141,7 +164,7 @@ See [Runtime View](docs/06_runtime_view.md#multi-dimensional-image-loading) for | Dark mode contrast | No hardcoded `background:` / `color:` in widget `setStyleSheet(...)` | Hardcoded greys override `soft_dark_stylesheet.py` and punch bright boxes into the sidebar. Add a global rule first, then write the widget. See [No Hardcoded Colors Rule](docs/08_crosscutting_concepts.md#dark-mode--no-hardcoded-colors-rule). | | DINO review state | `image_label.temp_annotations` is a single field, **not** per-image — must be re-synced from `dino_batch_results` on every image/slice switch via `_refresh_dino_temp_for_current` | Otherwise the first image's masks bleed onto every subsequent slice during navigation. See [DINO Temp Annotations](docs/08_crosscutting_concepts.md#dino-temp-annotations--single-field-many-images). | | DINO batch over stacks | Use `_collect_dino_batch_work_items()` to flatten regular images + every loaded slice; don't iterate `self.all_images` directly | Multi-dim images appear in `all_images` as a single entry — slices live in `self.image_slices[base_name]` and were silently skipped. | -| DINO Enter/Escape during review | Application-wide `_DINOReviewEventFilter`, gated on pending temp_annotations + no modal + no text input | `QListWidget` consumes Enter for `itemActivated` before `ImageLabel.keyPressEvent` sees it. See [ADR-015](docs/09_architecture_decisions.md#adr-015-application-wide-event-filter-for-dino-review-shortcuts). | +| DINO Enter/Escape during review | Application-wide `DINOReviewEventFilter`, gated on pending temp_annotations + no modal + no text input | `QListWidget` consumes Enter for `itemActivated` before `ImageLabel.keyPressEvent` sees it. See [ADR-015](docs/09_architecture_decisions.md#adr-015-application-wide-event-filter-for-dino-review-shortcuts). | | Auto-accept dropdown | Honored by **both** `run_dino_detection_single` and `run_dino_detection_batch` | Easy to forget in the single path because the combo is labeled "batch". | | GPU model unload | `model.cpu()` → `gc.collect()` → `torch.cuda.empty_cache()` + `ipc_collect()` + `synchronize()` — full reclaim requires app restart due to per-process CUDA context | Setting refs to None alone leaves circular refs pinned and shows zero Task Manager drop. See [Releasing Model GPU Memory](docs/08_crosscutting_concepts.md#releasing-model-gpu-memory). | | Export image-path lookup | Exact-key match first, substring fallback only | `"bee.jpg" in "honeybee.jpg"` is True — substring-only matching writes the wrong file. See [Export Format Filename Matching](docs/08_crosscutting_concepts.md#export-format-filename-matching). | @@ -161,17 +184,18 @@ See [Runtime View](docs/06_runtime_view.md#multi-dimensional-image-loading) for | 6 | Commit: `feat: Description` or `fix: Description` | Clear, descriptive messages | | 7 | Push & create PR | `git push origin feature/branch` | -### Testing Checklist (Manual — No Automated Tests) +### Testing Checklist Before opening a PR, verify at minimum: -1. **Launch the app** — no import errors, main window renders -2. **Golden path** — perform the new feature's primary workflow end-to-end -3. **Edge cases** — empty state, cancel/escape, large images, missing model files -4. **Dark mode** — toggle and check rendering of new UI elements -5. **Save/load roundtrip** — if the feature touches `.iap` project files, save, close, reopen, verify state restored -6. **Adjacent features** — verify no regression in SAM, annotation tools, export formats -7. **Inference features** — if touching `sam_utils.py` or `dino_utils.py`, verify the model loads end-to-end (no silent load failure), returns masks/boxes, and the UI stays responsive during inference (timers, redraws, progress dialog cancels keep firing — see ADR-013) +1. **Smoke tests pass** — `pytest tests/integration/test_smoke.py -v`. This includes the AST-based `test_annotator_window_inline_imports_are_resolvable` which catches stale relative imports inside function bodies after any module move (see ADR-016). A launch that "looks clean" is NOT sufficient — inline imports fail only when the function is called at runtime. +2. **Launch the app** — no import errors, main window renders +3. **Golden path** — perform the new feature's primary workflow end-to-end +4. **Edge cases** — empty state, cancel/escape, large images, missing model files +5. **Dark mode** — toggle and check rendering of new UI elements +6. **Save/load roundtrip** — if the feature touches `.iap` project files, save, close, reopen, verify state restored +7. **Adjacent features** — verify no regression in SAM, annotation tools, export formats +8. **Inference features** — if touching `sam_utils.py` or `dino_utils.py`, verify the model loads end-to-end (no silent load failure), returns masks/boxes, and the UI stays responsive during inference (timers, redraws, progress dialog cancels keep firing — see ADR-013) ### arc42 Documentation Update Rules @@ -213,6 +237,8 @@ See [Risks and Technical Debt](docs/11_risks_and_technical_debt.md) for full lis | Global | Action | |--------|--------| | Ctrl+N/O/S | New/Open/Save Project | +| Ctrl+Shift+= / Ctrl+Shift+- | UI font bigger/smaller (8-24pt, persisted via QSettings) | +| Ctrl+Shift+0 | Reset UI font size | | F1 | Help | | Canvas | Action | diff --git a/README.md b/README.md index 104ab92..a541bb2 100644 --- a/README.md +++ b/README.md @@ -128,9 +128,8 @@ You should see `True` and your GPU name. For other platforms or driver combinati - To use SAM2-assisted annotation: - Select a model from the "Pick a SAM Model" dropdown. It's recommended to use smaller models like SAM2 tiny or SAM2 small. SAM2 large is not recommended as it may crash the application on systems with limited resources. - Note: When you select a model for the first time, the application needs to download it. This process may take a few seconds to a minute, depending on your internet connection speed. Subsequent uses of the same model will be faster as it will already be cached locally, in your working directory. - - Click the "SAM-Assisted" button to activate the tool. - - Draw a rectangle around objects of interest to allow SAM2 to automatically detect objects. - - Note that SAM2 provides various outputs with different scores, and only the top-scoring region will be displayed. If the desired result isn't achieved on the first try, draw again. + - Click the "SAM-box" button and draw a rectangle around an object of interest, or click the "SAM-points" button and left-click points inside the object (right-click adds negative points to exclude regions). + - SAM2 displays the top-scoring mask as a temporary prediction — press Enter to accept it or Esc to discard it. If the desired result isn't achieved on the first try, draw the box again or adjust the points. - For low-quality images where SAM2 may not auto-detect objects, manual tools may be necessary. - When SAM2 auto-detect partial objects, use polygon or paint brush tools to manually define the remaining region and use the Merge tool to combine both annotations into one. - When SAM2 over-annotates objects, extending the annotation beyond object's boundaries, use the Eraser tool to clean up the edges. diff --git a/docs/05_building_block_view.md b/docs/05_building_block_view.md index ab5359d..a957b10 100644 --- a/docs/05_building_block_view.md +++ b/docs/05_building_block_view.md @@ -28,10 +28,32 @@ ``` src/digitalsreeni_image_annotator/ ├── main.py # Entry point, initializes QApplication -├── annotator_window.py # ImageAnnotator - main window -├── image_label.py # ImageLabel - custom display widget -├── sam_utils.py # SAMUtils - SAM model management -└── utils.py # Utility functions + ├── annotator_window.py # ImageAnnotator - main window orchestrator + ├── app_settings.py # QSettings-backed UI prefs (font size, dark mode) — ADR-020 + ├── utils.py # Cross-cutting utilities + ├── core/ # Constants, annotation utils, image utils + │ ├── constants.py + │ └── annotation_utils.py + ├── widgets/ + │ ├── image_label.py # ImageLabel - canvas widget; dispatcher + │ ├── canvas_context.py # CanvasContext - narrow read view (ADR-018) + │ └── tools/ # Per-tool handlers (ADR-019) + │ ├── base.py # ToolHandler base + │ ├── rectangle_tool.py + │ ├── polygon_tool.py + │ ├── paint_tool.py + │ └── eraser_tool.py + ├── controllers/ # Project/Image/SAM/DINO/YOLO/Annotation/Class + ├── inference/ # sam_utils.py, dino_utils.py + │ ├── sam_utils.py + │ └── dino_utils.py + ├── io/ # export_formats.py, import_formats.py + │ ├── export_formats.py + │ └── import_formats.py + ├── ui/ # menu_bar, sidebar, theme, stylesheets + │ ├── default_stylesheet.py + │ └── soft_dark_stylesheet.py + └── dialogs/ # Standalone tool dialogs ``` ### ImageAnnotator (annotator_window.py) @@ -56,27 +78,51 @@ current_slice: str # Currently displayed slice - `export_annotations()`: Export to various formats - `import_annotations()`: Import from COCO/YOLO -### ImageLabel (image_label.py) +### ImageLabel (widgets/image_label.py) -**Responsibility**: Image display and annotation interaction +**Responsibility**: Canvas widget — image display, navigation +(zoom/pan), committed-annotation rendering, SAM bbox/points overlays, +DINO temp-annotation rendering, polygon edit mode (modal). Per-tool +mouse/key handling lives in `widgets/tools/*` (see ADR-019); ImageLabel +dispatches events to the active handler. **Key Attributes**: ```python -current_tool: str # Active annotation tool +current_tool: str # Active annotation tool (route via set_active_tool) zoom_factor: float # Current zoom level annotations: dict # Displayed annotations class_colors: dict # Class color mapping -temp_paint_mask: np.ndarray # Temporary paint strokes +temp_paint_mask: np.ndarray # In-progress paint stroke (owned by PaintBrushTool) +temp_eraser_mask: np.ndarray # In-progress eraser stroke (owned by EraserTool) +current_rectangle: list # In-progress rectangle (owned by RectangleTool) +current_annotation: list # In-progress polygon points (owned by PolygonTool) sam_positive_points: list # SAM positive points sam_negative_points: list # SAM negative points +editing_polygon: dict | None # Polygon being edited (modal sub-state) +_tools: dict[str, ToolHandler] # Per-tool handlers +_ctx: CanvasContext # Narrow read view of main-window state (ADR-018) ``` **Key Methods**: -- `mousePressEvent()`: Handle mouse clicks for annotation -- `mouseMoveEvent()`: Handle mouse dragging -- `paintEvent()`: Render image and annotations -- `zoom_in()`, `zoom_out()`: Zoom controls -- `start_painting()`, `start_erasing()`: Brush tools +- `mousePressEvent()` / `mouseMoveEvent()` / `mouseReleaseEvent()` / + `mouseDoubleClickEvent()`: Ctrl-modifier pan/zoom branches first, + then SAM/edit-mode branches, then dispatch to + `active_tool_handler.on_mouse_X()`. +- `keyPressEvent()`: Enter / Escape / Delete / brush-size keys. Modal + branches (DINO temp, sam_points, sam_box, editing_polygon) + consume first; otherwise routed to `handler.on_enter()` / + `on_escape()`. +- `paintEvent()`: image → committed annotations → editing polygon → + SAM overlays → all tool handlers' `paint_overlay()` → tool-size + indicator → DINO temp annotations. +- `set_active_tool(name)`: switches `current_tool` and gives the + previous handler a chance to clean up via `deactivate()`. +- `check_unsaved_changes()`: iterates handlers' `has_unsaved_state()` + and prompts the user. + +**Communication**: emits ~20 Qt signals connected to controller slots +in `ImageAnnotator._connect_image_label_signals` (ADR-018). Reads +main-window state through `CanvasContext`. ### SAMUtils (sam_utils.py) @@ -147,6 +193,33 @@ DINO's xyxy boxes feed directly into `SAMUtils.apply_sam_predictions_batch()`, which returns segmentation polygons (xywh bbox is derived from the polygon at export time — see [Cross-cutting Concepts](08_crosscutting_concepts.md)). +## Level 3: Controllers + +Seven `QObject` controllers plus an `io_controller` helper module +carve `ImageAnnotator` into single-responsibility owners that the +orchestrator delegates to. Each `QObject` controller holds `self.mw += main_window` and owns one slice of behaviour; the +`io_controller` is a thin module of UI-wrapper functions around the +pure `io/` formatters and does not need to hold state. The +orchestrator keeps pass-through methods so external call sites +(menus, signal wiring, the test harness) don't need to reach into +the controller graph. + +| Controller | Responsibility | +|------------|----------------| +| `ProjectController` | `.iap` save/load, auto-save, backup/restore, missing-image prompts, window-title sync. Owns the `is_loading_project` autosave guard (load/save round-trip safety, v0.8.12). | +| `ImageController` | Open / load / switch images and slices. TIFF + CZI loaders, the multi-dim `DimensionDialog`, the `[-ndim:]` axis-slice bug fix from the v0.9.0 era. | +| `AnnotationController` | Annotation CRUD, list sorting, highlight, edit-mode entry/exit, `finish_polygon`, `finish_rectangle`, `replace_annotations` (eraser path). Validates writes before mutating `all_annotations`. | +| `ClassController` | Class add / delete / rename / colour / visibility. `update_slice_list_colors`, `is_class_visible`. | +| `SAMController` | SAM box/points tool lifecycle, debounce timer, `_sam_inference_in_flight` re-entrancy guard (ADR-013), model picker. | +| `DINOController` | Single + batch detection, batch review navigation, temp-annotation accept/reject, custom-model browse, `DINOReviewEventFilter` ownership (ADR-015). | +| `YOLOController` | Training menu, `TrainingThread`, prediction dialog, result processing. | +| `io_controller` *(module-level functions, not a class)* | Thin UI wrappers around the pure `io/export_formats.py` and `io/import_formats.py` modules. | + +Communication: `ImageLabel` does not import controllers directly — +it emits Qt signals (ADR-018) that the orchestrator connects to +controller slots in `_connect_image_label_signals()`. + ## Level 3: Export/Import Subsystem ### Export Formats (export_formats.py) @@ -270,7 +343,9 @@ ImageAnnotator (main window) └── launches ──> Tool Dialogs (utilities) ImageLabel - ├── references ──> ImageAnnotator (callbacks) + ├── emits signals to ──> ImageAnnotator (writes; see ADR-018) + ├── reads via ──> CanvasContext (paint/eraser size, current class, + │ class_mapping, is_class_visible, scroll_area, …) └── uses ──> utils (area, bbox calculations) SAMUtils diff --git a/docs/06_runtime_view.md b/docs/06_runtime_view.md index 45c7125..42337ca 100644 --- a/docs/06_runtime_view.md +++ b/docs/06_runtime_view.md @@ -66,7 +66,7 @@ User presses Enter └─> update() to show final annotation ``` -## SAM-Assisted Annotation +## SAM-Assisted Annotation (SAM-box / SAM-points) ``` User selects SAM model diff --git a/docs/08_crosscutting_concepts.md b/docs/08_crosscutting_concepts.md index 42efbd5..7e4ea7e 100644 --- a/docs/08_crosscutting_concepts.md +++ b/docs/08_crosscutting_concepts.md @@ -221,6 +221,57 @@ which on Windows means barely-visible radio-button indicators and white-on-white headers (the dataset splitter radio buttons hit this before they were styled). +## UI Font Zoom (Low-Vision Mode) + +### Single Source of Truth: `ui_font_pt` + +All UI text size flows from one integer, `ImageAnnotator.ui_font_pt` +(8–24pt, default 10, clamped by `app_settings.clamp_font_pt`). The +Settings → Font Size presets (Small…XXL) jump to fixed values; +Ctrl+Shift+= / Ctrl+Shift+- step ±1pt; Ctrl+Shift+0 resets. Every +change goes through `theme.set_font_pt`, which clamps, re-applies the +theme, persists via QSettings and syncs the preset menu checkmarks +(no preset is checked at an in-between size). + +### Appended QSS Overrides, Not Templated Stylesheets + +`soft_dark_stylesheet.py` / `default_stylesheet.py` stay static +strings. `apply_theme_and_font` appends scaled rules *after* the +static sheet — later rules of equal specificity win in QSS — for the +body font, `.section-header` and checkbox/radio indicator sizes. The +overrides scale the legacy px values (14px header, 14px indicators, +8px radio radius, 11px/10px compact DINO panel) by `ui_font_pt / 10` +and stay in **px**, so at the default 10pt they reproduce the legacy +look exactly. Widgets that want smaller-than-body text (e.g. the DINO +threshold table / phrase panel) must not set their own `font-size` — +they get a type- or objectName-targeted rule in the appended block +instead, so "compact" still scales. Do not +hardcode `font-size` in widget `setStyleSheet(...)` calls: it overrides +the global rule and the widget stops scaling (same failure mode as the +No Hardcoded Colors rule below; the DINO sidebar captions hit this). + +### Canvas Overlay Scaling: `ui_scale` + +`apply_theme_and_font` pushes `ui_font_pt / 10.0` to +`ImageLabel.set_ui_scale`. Overlay sizes (annotation label fonts, SAM +point radii, pen widths, edit-point handles, hit-test tolerances) use +the helpers `ImageLabel._pen_w(base)` / `_overlay_font(base)`, which +multiply by `ui_scale` and divide by `zoom_factor` — UI zoom and image +zoom stay orthogonal: overlays grow with the font setting but remain +constant-size on screen across image zoom. At the default 10pt, +`ui_scale == 1.0` and rendering is pixel-identical to the legacy code. +Exception: the SAM point-marker radii are drawn under +`painter.scale(zoom)` without zoom compensation (pre-existing +behaviour) and only multiply by `ui_scale`. + +### Persistence via QSettings + +`app_settings.py` stores `ui/font_pt` and `ui/dark_mode` in +`QSettings("DigitalSreeni", "ImageAnnotator")` (registry under HKCU on +Windows). These are per-user preferences, deliberately *not* part of +the `.iap` project file. All functions take an optional `QSettings` +instance so tests inject an INI-backed temp file. + ## Thread Safety for YOLO Training ### Training Thread @@ -313,7 +364,11 @@ def generate_slice_name(filename, t, z, c, s): | Ctrl+O | Open Project | | Ctrl+S | Save Project | | Ctrl+W | Close Project | -| Ctrl+Shift+S | Annotation Statistics | +| Ctrl+Shift+S | Save Project As | +| Ctrl+Alt+S | Annotation Statistics | +| Ctrl+Shift+= (or Ctrl++) | Increase UI font size | +| Ctrl+Shift+- (or Ctrl+-) | Decrease UI font size | +| Ctrl+Shift+0 | Reset UI font size | | F1 | Help Window | ### Canvas Shortcuts @@ -360,7 +415,7 @@ Consequences this codebase has tripped over: slice_list / image_list / a button — `QListWidget` consumes Enter for itemActivated before `ImageLabel.keyPressEvent` ever sees it. Solved with an application-wide event filter - (`_DINOReviewEventFilter`) that fires only while + (`DINOReviewEventFilter`) that fires only while `temp_annotations` has DINO items and skips modal dialogs and text inputs. Setting `image_label.setFocus()` synchronously inside `_show_dino_batch_review` was not enough — Qt's focus handling @@ -444,3 +499,54 @@ if image_path is None: The substring fallback is kept for backward compatibility with old projects that may have stored normalised image names (e.g. without extension); new code should prefer the exact-key path. + +## Canvas Decoupling — Signals + CanvasContext + +`ImageLabel` (the canvas widget) does **not** hold a reference to +`ImageAnnotator`. Communication is split: + +- **Writes** (committing an annotation, requesting a SAM prediction, + asking for tools to be re-enabled, etc.) leave the widget as Qt + `pyqtSignal` emissions. The signal block at the top of `ImageLabel` + documents every outbound interaction. `ImageAnnotator` connects + each signal to the right controller slot once, in + `_connect_image_label_signals` (called at the end of + `ImageAnnotator.__init__`). +- **Reads** (`paint_brush_size`, `current_class`, `class_mapping`, + `is_class_visible`, `scroll_area`, etc.) go through a + `CanvasContext` object passed in via + `image_label.set_context(CanvasContext(self))`. + `CanvasContext` wraps the main window rather than copying state, + so updates made by controllers are visible on the next read. + +**Why both mechanisms.** Signals are inherently one-way (fire and +forget); a synchronous read like "is this class visible" needs a +return value, which signals don't provide. Trying to express reads +as request/response signals adds latency and ordering bugs. The +`CanvasContext` accessor list is small (~10 methods) and stable. + +**Rules for adding traffic in either direction**: + +- New write from canvas → orchestrator: declare a `pyqtSignal` on + `ImageLabel`, add a slot on a controller, wire it in + `_connect_image_label_signals`. Do not add a back-reference to + `ImageAnnotator`. +- New read from canvas → orchestrator: add a method on + `CanvasContext`. Do not expose `_ctx._mw` directly. + +**Synchronous-emit ordering**. Qt's default `AutoConnection` runs the +slot synchronously when the sender and receiver share a thread (true +for everything on the GUI thread). Code that emits a signal and then +reads state expected to be updated by it is correct — the slot has +already run by the time `.emit()` returns. This is load-bearing for +`accept_temp_annotations`, where `classRequested` must complete +before the subsequent class lookup. + +**Batch save signal**. Paint commits and accept-temp commits emit +`annotationCommitted` per annotation but `annotationsBatchSaved` only +once at the end. The single batch save preserves O(1) `.iap` writes +per user action; replacing it with a per-annotation save would turn +paint commits into O(N). See ADR-018. + +See ADR-018 in `09_architecture_decisions.md` for the rationale and +the full pattern. diff --git a/docs/09_architecture_decisions.md b/docs/09_architecture_decisions.md index 755f6dd..ec1306b 100644 --- a/docs/09_architecture_decisions.md +++ b/docs/09_architecture_decisions.md @@ -262,7 +262,7 @@ **Context**: ADR-011 introduced a subprocess hop for every SAM and DINO inference call to work around a PyQt5 + Torch DLL load-order conflict on Windows + Python 3.14. The workaround cost a fresh `python sam_worker.py` / `dino_worker.py` spawn per inference (~1-2 s warm latency, model reloaded from disk on every call) plus a temp-PNG marshal of the image. -Migrating the GUI from PyQt5 to PyQt6 (same PR) eliminates the DLL conflict — verified by `tools/check_pyqt6_torch_coexistence.py` importing PyQt6 → torch → transformers → ultralytics cleanly in one process on Windows+Py3.14 (the original failure case) and the Linux/macOS test matrix. +Migrating the GUI from PyQt5 to PyQt6 (same PR) was expected to eliminate the DLL conflict — initially verified by `tools/check_pyqt6_torch_coexistence.py` importing PyQt6 packages → torch cleanly. However, further testing (see [ADR-017](#adr-017-eager-torch-import-in-mainpy-before-qapplication-creation)) discovered that the conflict resurfaces when Qt's **platform plugin** is loaded before torch, which happens inside `QApplication()`. The practical workaround is to import torch eagerly before creating the QApplication. **Decision**: Run SAM and DINO inference directly inside the main Python process. Keep the model objects on the `SAMUtils` / `DINOUtils` singletons so they persist across calls. Wrap each inference in a short-lived `QThread` to keep the UI thread responsive; the public API blocks the caller via a nested `QEventLoop` so call sites in `annotator_window.py` stay synchronous-looking. @@ -297,7 +297,7 @@ Migrating the GUI from PyQt5 to PyQt6 (same PR) eliminates the DLL conflict — **Status**: Accepted **Context**: The project shipped on PyQt5 5.15+ (ADR-001) from inception. Two pressures combined to motivate a migration: -1. The PyQt5 + Torch DLL load-order conflict on Windows + Python 3.14 (ADR-011) forced an entire subprocess isolation layer (`sam_worker.py`, `dino_worker.py`, `check_worker_isolation.py`) that added ~1-2 s latency per inference. The conflict only manifests on PyQt5 — Qt6's packaging reshuffle eliminates it. +1. The PyQt5 + Torch DLL load-order conflict on Windows + Python 3.14 (ADR-011) forced an entire subprocess isolation layer. It was hypothesised that Qt6's packaging would eliminate the conflict entirely, but real-world testing (see [ADR-017](#adr-017-eager-torch-import-in-mainpy-before-qapplication-creation)) showed the conflict persists when Qt's platform plugin is loaded before torch, regardless of whether PyQt5 or PyQt6 is the binding. The migration still removes PyQt5-specific issues (XCB plugin paths, enum namespacing drift). 2. PyQt5 is in maintenance mode. PyQt6 is the actively developed line, gets new Qt6.x features, and has better Linux native integration (XCB plugin paths in particular). **Decision**: Migrate the GUI binding from PyQt5 (`>=5.15.0`) to PyQt6 (`>=6.7.0`). Land in a single PR alongside the subprocess-removal work (ADR-013), gated behind `tools/check_pyqt6_torch_coexistence.py` to confirm the DLL conflict is actually gone on Windows + Python 3.14. @@ -316,7 +316,7 @@ Migrating the GUI from PyQt5 to PyQt6 (same PR) eliminates the DLL conflict — - ✅ All `.exec_()` call sites in `src/` migrated to `.exec()` in the v0.9.0 fix-pack — the PyQt5 alias is gone from this codebase. **Verification**: -- `tools/check_pyqt6_torch_coexistence.py` imports PyQt6 → torch → torchvision → transformers → ultralytics in that order. Run before merging on the Windows + Python 3.14 target. +- `tools/check_pyqt6_torch_coexistence.py` tests both import orders. The production order (torch first, then `QApplication`) must pass. The Qt-first order is the known-failing case and is checked only to document the environment. Run before merging on the Windows + Python 3.14 target. - 65 tests pass on the new binding under `QT_QPA_PLATFORM=offscreen`. - Full app constructs and renders headlessly; snake-game easter egg validates the `QDesktopWidget` → `QGuiApplication.primaryScreen()` replacement. @@ -352,7 +352,7 @@ Three options were considered: when DINO temp_annotations are pending, and only when the focused widget is not a text input and no modal dialog is active. -**Decision**: Option 3. Implement `_DINOReviewEventFilter`, install it +**Decision**: Option 3. Implement `DINOReviewEventFilter`, install it on `QApplication.instance()` once at startup, and gate the interception on three conditions: pending DINO temp_annotations, no active modal widget, focus not on `QLineEdit`/`QTextEdit`. @@ -371,13 +371,374 @@ no active modal widget, focus not on `QLineEdit`/`QTextEdit`. multiple top-level filters. **Related**: -- Implementation: `annotator_window.py` (`_DINOReviewEventFilter` - class, `installEventFilter` call in `__init__`). +- Implementation: `DINOReviewEventFilter` class in + `controllers/dino_controller.py` (moved there in Phase 4b); + `installEventFilter` call in `ui/shortcuts.py:install_event_filters`, + invoked from `ImageAnnotator.__init__` (moved there in Phase 8). - Cross-cuts: documented in [Cross-cutting Concepts → DINO Temp Annotations](08_crosscutting_concepts.md#dino-temp-annotations--single-field-many-images). --- +## ADR-018: Decouple ImageLabel from ImageAnnotator via Signals + CanvasContext + +**Status**: Accepted (Phase 6 of the modular refactor) + +**Context**: Before Phase 6, `ImageLabel.set_main_window(main_window)` +injected the orchestrator into the canvas widget, and the widget poked +~50 sites on `main_window` directly — both reading state +(`paint_brush_size`, `class_mapping`, `current_class`, `scroll_area`, +`current_slice`, `image_file_name`) and mutating it +(`all_annotations[name] = …`, `add_class(…)`, +`update_annotation_list()`, `save_current_annotations()`, +`update_slice_list_colors()`, `schedule_sam_prediction()`, +`zoom_in()`, `enable_tools()`, etc.). The coupling made: +- ImageLabel impossible to test in isolation without a + whole-`ImageAnnotator` fixture. +- Every controller extraction (Phases 3–5) leak through `main_window` + delegation pass-throughs, because deleting them would break the + widget. +- The Phase 7 per-tool split (paint / eraser / polygon / rectangle + handler classes) impractical, because each handler would need the + same `main_window` reference and would multiply the coupling. + +Three options were considered: + +1. **Protocol / duck-typed callback object** — pass a small protocol + with the methods ImageLabel needs. Strict, type-safe, but writes + are still synchronous direct calls; the widget still knows the + exact method names on the orchestrator. +2. **Defer the fix** — leave `main_window` for one more phase, accept + the debt. Cheapest, but each subsequent refactor pays the cost. +3. **Qt signals for every write + a narrow read accessor object** — + ImageLabel emits typed signals; the orchestrator connects each to + a controller slot during `__init__`. Reads go through a + `CanvasContext` object with method-style accessors. + +**Decision**: Option 3. ImageLabel declares ~20 `pyqtSignal`s covering +annotation lifecycle, SAM, class, tool/UI state, navigation, and +batch finalisation. Reads go via a `CanvasContext` instance passed in +through `set_context(ctx)`. The previous `set_main_window` / +`self.main_window` field is removed entirely. + +The connection block lives in `ImageAnnotator._connect_image_label_signals`, +called once at the end of `__init__` after every controller exists. +`CanvasContext` wraps the main window rather than copying state, so +the source of truth stays on `ImageAnnotator` and controllers see +their writes reflected on the next read. + +**Consequences**: +- ✅ ImageLabel has zero `main_window` references; signals form the + documented public write surface at the top of the class. +- ✅ ImageLabel is now testable in isolation by connecting signals + to stub slots; no controller fixture needed. +- ✅ Phase 7 (per-tool handlers) can carve `mousePressEvent` / + `mouseMoveEvent` etc. without each handler needing the orchestrator. +- ✅ Signal connections are explicit and grep-able — searching for + `il.annotationCommitted.connect` finds the single wiring site. +- ⚠️ Two parallel mechanisms (signals for writes, `CanvasContext` for + reads) need to be kept in step. The widget's signal block and + `_connect_image_label_signals` must stay in sync; a missing + connection is a silent no-op write. +- ⚠️ Signal connections rely on Qt's default `AutoConnection` semantics, + which is synchronous within a single thread. Consumers that depend + on a write taking effect before the next read (e.g. `classRequested` + emit followed by `_ctx.class_id(name)` read) must stay on the GUI + thread. +- ⚠️ The synchronous batch-save signal (`annotationsBatchSaved`) + preserves the original O(1)-save-per-batch behaviour. Replacing it + with per-annotation save would silently turn paint commits into + O(N) saves. Future refactors must keep the batch boundary. + +**Pattern for adding a new ImageLabel → orchestrator interaction**: + +1. Add a `pyqtSignal()` to `ImageLabel`. +2. Add a slot method on a controller (or main window) with matching + signature. +3. Wire it in `_connect_image_label_signals`. +4. Replace the previous direct call site in ImageLabel with + `self..emit()`. + +**Pattern for adding a new read accessor**: + +1. Add a method on `CanvasContext` returning the value. +2. Use `self._ctx.()` at the read site in ImageLabel. + +**Related**: +- Implementation: `widgets/canvas_context.py`, + `widgets/image_label.py` (signal block lines 42–70), + `annotator_window.py:_connect_image_label_signals`. +- Cross-cuts: documented in + [Cross-cutting Concepts → Canvas Decoupling](08_crosscutting_concepts.md#canvas-decoupling--signals--canvascontext). +- Predecessor pattern: ADR-015 (DINO event filter) showed that + ImageLabel can't reliably observe global keyboard state without + help; ADR-018 generalises "explicit interaction surface, narrow + read surface" to all canvas ↔ orchestrator traffic. + +--- + +## ADR-019: Per-Tool Handler Classes inside ImageLabel + +**Status**: Accepted (Phase 7 of the modular refactor) + +**Context**: After Phase 6, `ImageLabel` no longer held a back-reference +to `ImageAnnotator`, but it still embedded four distinct annotation +tools (polygon, rectangle, paint_brush, eraser) as if/elif branches +spread across six event methods (`mousePressEvent`, `mouseMoveEvent`, +`mouseReleaseEvent`, `mouseDoubleClickEvent`, `keyPressEvent`, +`paintEvent`). Each tool also owned helper methods on the widget +(`start_painting`, `commit_paint_annotation`, `commit_eraser_changes`, +`finish_polygon`, `cancel_current_annotation`, …). Adding a new tool +meant touching all six event methods plus the widget's helper layer, +and the file had reached ~1,240 LOC. + +Three options were considered: + +1. **Keep tools as if/elif branches** — cheapest, but the widget keeps + accruing every new tool's behaviour. +2. **Per-tool widget subclass** (one `QWidget` per tool, swap on tool + change) — too heavy: tool switches would require teardown of the + pixmap, scroll context, zoom factor, and the SAM/DINO/edit-mode + sub-states that cut across tool selection. +3. **Per-tool handler classes** with a thin dispatcher on the widget. + Plain Python objects (not QObjects); the widget keeps a + `_tools: dict[str, ToolHandler]` and routes events to + `active_tool_handler`. Tools emit through the widget's existing + Phase 6 signals. + +**Decision**: Option 3. Each tool becomes a subclass of `ToolHandler` +in `widgets/tools/`. The contract: + +- Event hooks return `True` when consumed: `on_mouse_press`, + `on_mouse_move`, `on_mouse_release`, `on_double_click`, `on_enter`, + `on_escape`. +- `paint_overlay(painter)` renders in-progress state (paint mask, + eraser mask, polygon-in-progress, rectangle preview). +- `has_unsaved_state()` / `commit()` / `discard()` participate in + the widget's `check_unsaved_changes()` dialog. +- `deactivate()` runs when the user switches away from this tool; + default is no-op (matches the pre-Phase-7 "silently drop temp state + mid-stroke" behaviour). + +**Deliberate non-decision: state ownership.** Tool handlers contain +only *behaviour*; the temp-state fields (`current_rectangle`, +`current_annotation`, `temp_paint_mask`, `temp_eraser_mask`, +`drawing_polygon`, `drawing_rectangle`, `is_painting`, `is_erasing`) +remain on `ImageLabel`. Reason: `AnnotationController.finish_rectangle` +and `finish_polygon` (Phase 5a) read `mw.image_label.current_rectangle` +and `mw.image_label.current_annotation` directly. Moving the state +onto the handlers would have required a parallel controller refactor. +Handlers mutate `self.label.X` for those fields; pure-tool state +(e.g. future tool-internal counters) can live on the handler. See +the architectural-smell note below. + +**What stays on `ImageLabel` (intentional non-extraction)**: + +- Navigation (zoom, pan, offset, scaled pixmap) — cross-cutting. +- SAM bbox / points state — activates from any tool via the SAM-box / + SAM-points toggles, cuts across the main tools. +- Polygon edit mode (`editing_polygon`, `handle_editing_click`, + `handle_editing_move`, `draw_editing_polygon`) — modal state + orthogonal to tool selection; sets `current_tool = None` while + active. Promoting this to a handler would tangle the modal flow. +- DINO `temp_annotations` + `accept_temp_annotations` — + cross-cutting; already touched by ADR-015's event filter. +- `draw_tool_size_indicator` — small enough that splitting it across + paint/eraser handlers buys nothing. + +**`paintEvent` overlay pass**. Iterates **all** handlers' +`paint_overlay()`, not just the active one. Reason: pre-Phase-7 the +temp paint mask, temp eraser mask, and polygon-in-progress rendered +whenever their state was populated, regardless of `current_tool`. +Each handler's `paint_overlay` short-circuits when its state is empty, +so the iteration is cheap and the user can switch tools mid-stroke +without losing visual feedback. + +**Consequences**: +- ✅ `image_label.py` shrinks from 1,239 to ~960 LOC. Adding a new + tool now means: create one file in `widgets/tools/`, register it + in `_tools`, wire a button in `annotator_window.py`. No event-method + edits. +- ✅ Each tool can be unit-tested by instantiating the handler with + a stub `label` carrying signals and `_ctx` — no controller fixture + needed. +- ✅ Phase 6's signal contract (ADR-018) is unchanged: handlers emit + via `self.label..emit(...)`. +- ⚠️ **State leak across the widget boundary.** Handlers reach into + `self.label.X` for state. The contract drifts toward "handler is a + namespaced function bag." Mitigation: revisit if/when controllers + are updated to ask the handler (e.g. `polygon_tool.points()`) + instead of reading the widget's field. +- ⚠️ `deactivate()` is no-op by default. If you make it + `discard()` later, audit the three call sites that still write + `current_tool = None` directly (`ImageLabel.clear()`, + `ImageLabel.start_polygon_edit`, three locations in + `SAMController`) — they bypass `set_active_tool` and therefore the + hook. +- ⚠️ `check_unsaved_changes` now iterates all handlers, not just + paint/eraser. Polygon participates via `has_unsaved_state() = len > 2` + (sub-3-point polygons are silently discarded on switch — they + can't be saved anyway). + +**Pattern for adding a new mouse-driven tool**: + +1. Create `widgets/tools/foo_tool.py` with `class FooTool(ToolHandler):`. +2. Override the event hooks you need; emit via + `self.label..emit(...)` and read via `self.label._ctx.X()`. +3. Register in `ImageLabel.__init__`'s `_tools = {…, "foo": FooTool(self)}`. +4. Add a button in `ui/sidebar.py:build_sidebar` next to the existing + tool buttons, register it in `window.tool_group`, and connect + `clicked` to `window.toggle_tool`. Then add a branch in + `ImageAnnotator.toggle_tool` that calls + `self.image_label.set_active_tool("foo")` for that button (since + Phase 8 the UI building lives in `ui/sidebar.py`, not on the + orchestrator). + +**Related**: +- Implementation: `widgets/tools/base.py`, + `widgets/tools/{rectangle,polygon,paint,eraser}_tool.py`, + `widgets/image_label.py:set_active_tool`, + `widgets/image_label.py:paintEvent` overlay-iteration block. +- Predecessor: ADR-018 (Phase 6 signal decoupling) made this safe by + removing the `main_window` reference; handlers don't need an + orchestrator handle. +- Cross-cuts: documented in + [Cross-cutting Concepts → Canvas Decoupling](08_crosscutting_concepts.md#canvas-decoupling--signals--canvascontext) + (extended to describe the tool dispatcher). + +--- + +## ADR-016: Static AST Inspection of Inline Imports as Quality Gate for Refactor PRs + +**Status**: Accepted + +**Context**: During Phase 1 of the modular refactoring (2025-06-10), 25 modules were moved into `core/`, `dialogs/`, `inference/`, `io/`, `ui/`, `widgets/` subpackages. The smoke tests (`test_smoke.py`) verified that every module could be imported at top-level. All 30 smoke tests passed. However, four stale **inline imports** inside method bodies were missed: + +```python +# annotator_window.py — inside function bodies, NOT top-level +from .dino_utils import GDINO_MODEL_PATHS # moved to .inference.dino_utils +from .annotation_statistics import ... # moved to .dialogs.annotation_statistics +from .project_details import ... # moved to .dialogs.project_details +from .project_search import ... # moved to .dialogs.project_search +``` + +These imports were deferred until the specific UI action triggered the function (e.g. picking a DINO model from the dropdown). The smoke tests, which only import modules, never execute function bodies and therefore never resolved the inline `from .dino_utils` reference. The bug surfaced only in manual QA when selecting a DINO model. + +**Decision**: Add a static AST analysis test (`test_annotator_window_inline_imports_are_resolvable`) that parses `annotator_window.py`, extracts every bare relative import (`from .module`), and asserts the module still exists in the package root. The test fails with the exact line number for any stale import, preventing silent runtime-only regressions from reaching CI. + +**Rationale**: +- Top-level import rewrites are mechanical and easy to verify via module import. +- Inline imports inside method bodies are invisible to module-level import tests. +- Manual QA is the fallback for behaviour, not for mechanical import correctness. +- AST inspection is cheap (~1 ms), zero false positives for this codebase, and runs in every CI build along with smoke tests. + +**Consequences**: +- 🛑 Regression now impossible: the 30th smoke test would have failed the PR before merge. +- 🔧 No runtime cost — purely static analysis. +- ⚠️ Only covers `annotator_window.py`. If other files use the same inline-import pattern, the test should be generalized (or each file that contains inline imports gets its own AST check). In this codebase, `annotator_window.py` is the only file with significant inline imports. +- ⚠️ Doesn't catch dynamic imports (`__import__`, `importlib.import_module`), but we don't use those. + +**Related**: +- Implementation: `tests/integration/test_smoke.py` (`test_annotator_window_inline_imports_are_resolvable`). +- Cross-cuts: `CLAUDE.md` "Testing Checklist" updated to reference this test as a mandatory CI gate. + +--- + +## ADR-017: Eager Torch Import in `main.py` before `QApplication` Creation + +**Status**: Accepted + +**Context**: ADR-011 and ADR-014 both discussed a DLL load-order conflict on Windows when PyQt and PyTorch share a process. The conflict was first observed with PyQt5 (ADR-011) and later claimed to be resolved by migrating to PyQt6 (ADR-014): + +> "Qt6's packaging reshuffle eliminates it." — ADR-014 +> +> "...verified by `tools/check_pyqt6_torch_coexistence.py` importing PyQt6 → torch → transformers → ultralytics cleanly in one process..." — ADR-013 + +This claim was based on testing at the time, but it tested the **wrong order**: importing PyQt6 *packages* before torch works even in Qt5. The actual failure mode is triggered only when Qt's **native platform plugin** is loaded, which happens inside `QApplication.__init__()`, not at `import PyQt6`. The earlier verification script did not call `QApplication()`, so it never exercised the real failure path. + +Real-world testing with `torch 2.11.0+cu126 + PyQt6 6.10.2 + Python 3.14.2` on Windows 11 shows the conflict **still surfaces** when Qt's platform DLLs (e.g. `qwindows.dll`) are loaded BEFORE torch's `c10.dll`. The error is `OSError: [WinError 1114] A dynamic link library (DLL) initialization routine failed`. + +**Root cause analysis**: Qt and torch both ship native DLLs that load into the same process. On Windows the DLL load order and address-space layout matter. When Qt's platform plugin claims certain memory slots or loads conflicting CRT libraries before torch does, torch's `c10.dll` init fails. The conflict is NOT between PyQt5 and torch per se — it is between Qt platform plugins and torch, regardless of whether the binding is PyQt5 or PyQt6. + +**Decision**: Two complementary changes: + +1. In `main.py`, eagerly `import torch` (with an `ImportError` fallback) **before** importing `QApplication` and creating the app. This ensures torch's DLLs claim their slot first. +2. In `__init__.py`, replace eager toplevel imports of `annotator_window`, `image_label`, and `sam_utils` with a `__getattr__`-based lazy loader. The package init runs before `main.py` when launched via the `sreeni` console script (`digitalsreeni_image_annotator.main:main`). If `__init__.py` eagerly imports modules that transitively import PyQt6 (e.g. `annotator_window`), Qt loads first and the `import torch` in `main.py` crashes with the same WinError 1114. Lazy loading defers the Qt import until someone actually accesses `pkg.ImageAnnotator`, which only happens after the torch-first guard has run. + +**Verification**: +- `tools/check_pyqt6_torch_coexistence.py` now tests both orders: + 1. `torch` → `QApplication` (production order) — **PASS**. + 2. `QApplication` → `torch` (the claimed-safe order) — **FAIL** on Windows with torch 2.11.0. +- Exit code 0 means production order works; exit code 1 means even torch-first fails and subprocess isolation (ADR-011) must be restored. +- Smoke test `test_public_api_exports` passes: `__getattr__` correctly resolves all five public names. + +**Consequences**: +- ✅ SAM and DINO model loading works on Windows + Python 3.14 + PyQt6 without subprocess overhead. +- ✅ App startup cost is negligible — torch import adds ~0.5-1 s before the splash window appears, which is acceptable for a desktop annotation tool. +- ⚠️ `tests/integration/test_smoke.py` cannot import `main.py` because the pytest-qt test process already has Qt loaded; importing torch afterward triggers the same WinError 1114. `main.py` is therefore **excluded** from the module-import list and is validated by CLI smoke tests instead. +- ⚠️ Future Qt upgrades may change DLL packaging and make this unnecessary, but `check_pyqt6_torch_coexistence.py` will detect that automatically. +- ⚠️ Any new public name added to `__init__.py` must also be wired through `__getattr__` or it will transitively pull in PyQt6 and break the torch-first guard. + +**Related**: +- Supersedes (in spirit): ADR-014's claim that PyQt6 eliminates the conflict. +- Unblocks: ADR-013 in-process inference on the affected Windows environment. +- Implementation: `src/digitalsreeni_image_annotator/main.py`. +- Gate: `tools/check_pyqt6_torch_coexistence.py`. + +## ADR-020: App-Global UI Preferences via QSettings; Canvas Overlays Scale with `ui_font_pt` + +**Status**: Accepted + +**Context**: The low-vision accessibility feature (continuous UI font +zoom, 8–24pt) needed (a) the chosen size to survive app restarts and +(b) canvas overlay elements — annotation labels, SAM point markers, +pen widths — to grow with the setting. UI preferences were previously +reset on every launch, and the `.iap` project file was the only +persistence mechanism in the app. + +**Decision**: +1. Introduce the app's first QSettings usage + (`QSettings("DigitalSreeni", "ImageAnnotator")`, module + `app_settings.py`) for `ui/font_pt` and `ui/dark_mode`. These are + per-user preferences, so they do **not** go into the `.iap` file — + a project opened by a different user must not impose a font size. +2. A single integer `ImageAnnotator.ui_font_pt` is the source of + truth; the named presets and the step shortcuts both funnel + through `theme.set_font_pt` (clamp → apply → persist → menu sync). +3. Canvas overlay sizes derive from `ui_scale = ui_font_pt / 10.0` + (10 = the legacy default, so the default renders pixel-identical + to the pre-feature code). `ImageLabel` receives the value via a + plain setter from `apply_theme_and_font`, not via CanvasContext — + consistent with the existing direct `image_label.setFont` call, + and avoids a paint-before-context-set window. + +**Alternatives considered**: +- Storing prefs in the `.iap` file — rejected: project files are + shared artifacts; accessibility settings are personal. +- Templating the static stylesheets per font size — rejected: + appended QSS override rules (later rules win at equal specificity) + achieve the same with zero churn in the two stylesheet strings. + +**Consequences**: +- ✅ Font size and dark mode persist across restarts. +- ✅ Tests stay hermetic: every `app_settings` function accepts an + injectable `QSettings` (INI temp file) instance. +- ⚠️ Any new scalable UI metric should use `ImageLabel._pen_w` / + `_overlay_font` or the appended-override block in + `theme.apply_theme_and_font` — hardcoded px values won't follow the + setting (see "UI Font Zoom" in `08_crosscutting_concepts.md`). +- ⚠️ Deliberately-compact widgets (DINO threshold table / phrase + panel) don't hardcode their small font inline; the appended block + owns it via type/objectName selectors (`ClassThresholdTable`, + `PhraseEditorPanel …`, `#dino_phrase_hint`) so compact ≠ unscaled. + Follow that pattern for new compact widgets. +- ⚠️ Known debt: `dino_merge_dialog.py` still carries hardcoded + `font-size:Npx` tokens and a `color:#444` dark-mode contrast issue, + so it doesn't scale. Tracked, not an oversight; fix when that + dialog is next touched. + +--- + ## Decisions Under Consideration ### Consider pytest-qt for Utility Testing diff --git a/docs/11_risks_and_technical_debt.md b/docs/11_risks_and_technical_debt.md index a2df15d..0a999a8 100644 --- a/docs/11_risks_and_technical_debt.md +++ b/docs/11_risks_and_technical_debt.md @@ -85,26 +85,52 @@ ## Technical Debt -### No Automated Tests +### Low Test Coverage of Interactive Paths -**Debt Level**: High +**Debt Level**: Medium -**Description**: Zero unit tests, integration tests, or UI tests +**Description**: A pytest + pytest-qt suite of 94 tests now exists +(boot smoke, coordinate conversions, export-format round-trips, +utility functions). Coverage is ~15% by line — the gap is the +canvas event flow (mouse events → tool handler → signal emission → +controller slot) and the SAM/DINO/YOLO inference paths. **Impact**: -- High risk of regressions -- Refactoring is dangerous -- Manual testing burden -- Slow development velocity +- Phase 6/7/8 refactors had to lean on manual QA checklists for the + canvas flow because no automated test exercises it end-to-end. +- Inference paths are exercised only via the smoke boot, not under + real model loads (those would slow CI prohibitively). -**Effort to Resolve**: High (months) +**Effort to Resolve**: Medium **Priority**: Medium **Plan**: -1. Add unit tests for utility functions first (low-hanging fruit) -2. Add integration tests for export/import -3. Consider pytest-qt for critical UI flows +1. Per-tool unit tests under `widgets/tools/` — each handler can be + tested by instantiating with a stub `label` carrying signals + and a fake `CanvasContext`, then feeding `QMouseEvent`s. +2. Integration test that loads a tiny project, draws a polygon, + asserts the `.iap` round-trip restores state. +3. Mock SAMUtils / DINOUtils inference returns to exercise the + controller signal paths without needing model weights. + +--- + +### Limited Coverage — Inline Imports Not Caught by Module Tests + +**Debt Level**: Medium + +**Description**: Smoke tests verify modules import cleanly at top-level, but inline `from .module` imports inside function bodies are deferred and only fail when the function is called. Phase 1 modular refactoring moved 25 modules; four stale inline imports (`from .dino_utils`, `.annotation_statistics`, `.project_details`, `.project_search`) were missed and only surfaced in manual QA. + +**Impact**: +- Subpackage refactor PRs require functional QA paths (not just module import CI) to verify inline imports +- Silent regressions until user clicks the specific button/dialog that triggers the stale import + +**Mitigation**: +- Added AST-based static smoke test (ADR-016) that parses `annotator_window.py` and asserts every bare relative import resolves to an existing module in the package root +- The test now catches inline import drift in CI before merge + +**Future Action**: Extend the AST check to any other file that uses inline deferred imports (currently only `annotator_window.py` has them). --- @@ -159,29 +185,18 @@ return None --- -### Tight Coupling Between ImageAnnotator and ImageLabel - -**Debt Level**: Medium +### Tight Coupling Between ImageAnnotator and ImageLabel — Resolved (Phase 6) -**Description**: ImageLabel has `main_window` reference and calls methods directly +**Status**: Resolved. `ImageLabel.main_window` and `set_main_window()` +were removed; every write path is now a `pyqtSignal` emission and every +read goes through a narrow `CanvasContext` accessor. -**Examples**: -```python -# In ImageLabel -self.main_window.add_annotation(polygon) -self.main_window.update_annotation_list() -``` - -**Impact**: -- Hard to test ImageLabel independently -- Changes ripple between classes -- Circular dependency concerns - -**Effort to Resolve**: Medium (refactor to signals/slots) - -**Priority**: Low +**Pattern**: see `widgets/canvas_context.py` and +`ImageAnnotator._connect_image_label_signals`. ImageLabel emits ~20 +signals (annotation lifecycle, SAM, class, tool/UI state, navigation); +the orchestrator wires each to the matching controller slot. -**Plan**: Refactor to Qt signals for loose coupling +**ADR**: see ADR-018 in `09_architecture_decisions.md`. --- diff --git a/docs/12_glossary.md b/docs/12_glossary.md index 3cc8309..03df404 100644 --- a/docs/12_glossary.md +++ b/docs/12_glossary.md @@ -68,6 +68,21 @@ You Only Look Once - object detection format. Uses `.txt` files with normalized ### Z-Stack A series of 2D images taken at different focal depths (Z positions), used in microscopy to capture 3D structure. +### CanvasContext +Narrow read-only view of main-window state exposed to `ImageLabel`. Introduced in Phase 6 (ADR-018) to replace the old `image_label.main_window` back-reference. Method-style accessors (`paint_brush_size()`, `current_class()`, `is_class_visible(name)`, `scroll_area()`, …) so future state migrations can re-route reads without touching the widget. Constructed once in `ImageAnnotator.__init__` and passed via `image_label.set_context(ctx)`. + +### Controller +Architectural pattern used across `controllers/*`. A controller is a `QObject` subclass holding `self.mw = main_window` that owns a single responsibility cluster carved out of the old monolithic `ImageAnnotator` — project I/O, image loading, annotations, classes, SAM, DINO, or YOLO. The orchestrator delegates to the controllers via thin pass-through methods, keeping external entry points (menu actions, signal connections) stable across refactors. Seven controllers exist as of Phase 8. + +### ToolHandler +Base class for per-tool mouse / key / paint behaviour inside `ImageLabel`. Plain Python object (not a `QObject`); holds a back-reference to the widget for signal emission and `CanvasContext` reads. Subclasses (`RectangleTool`, `PolygonTool`, `PaintBrushTool`, `EraserTool`) live in `widgets/tools/` and are dispatched to by `ImageLabel.active_tool_handler`. Introduced in Phase 7 (ADR-019). + +### Tool subclasses (`RectangleTool`, `PolygonTool`, `PaintBrushTool`, `EraserTool`) +Concrete `ToolHandler` implementations, one per mouse-driven annotation tool. Each overrides the event hooks defined on the base class (`on_mouse_press`, `on_mouse_move`, `on_mouse_release`, `on_double_click`, `on_enter`, `on_escape`, `paint_overlay`, `deactivate`) and participates in the `has_unsaved_state()` / `commit()` / `discard()` contract used by the `check_unsaved_changes` dialog. + +### UI builders (`build_menu_bar`, `build_sidebar`, `build_image_area`, `build_image_list`) +Functions under `ui/` that construct widget trees at startup. Each takes the `ImageAnnotator` instance as `window`, attaches widgets as `window.X = QWidget(...)` so other modules can read them, and wires signals to `window.` delegate methods. Replaced the equivalent `setup_*` methods on `ImageAnnotator` in Phase 8. + ## Acronyms | Acronym | Full Term | @@ -110,12 +125,24 @@ A series of 2D images taken at different focal depths (Z positions), used in mic | Class | Module | Description | |-------|--------|-------------| -| `ImageAnnotator` | annotator_window.py | Main application window (QMainWindow) | -| `ImageLabel` | image_label.py | Custom QLabel for image display and interaction | -| `SAMUtils` | sam_utils.py | SAM model loading and inference | -| `DimensionDialog` | annotator_window.py | Dialog for assigning dimensions to stacks | -| `TrainingThread` | annotator_window.py | Background thread for YOLO training | -| `YOLOTrainer` | yolo_trainer.py | YOLO model training and prediction | +| `ImageAnnotator` | annotator_window.py | Thin orchestrator (QMainWindow). Holds controllers, wires signals, delegates almost everything. | +| `ImageLabel` | widgets/image_label.py | Canvas widget — display, zoom/pan, event dispatch to tool handlers. | +| `CanvasContext` | widgets/canvas_context.py | Narrow read view of main-window state for ImageLabel (ADR-018). | +| `ToolHandler` | widgets/tools/base.py | Base class for per-tool mouse/key handlers (ADR-019). | +| `RectangleTool` / `PolygonTool` / `PaintBrushTool` / `EraserTool` | widgets/tools/ | Per-tool handler subclasses. | +| `ProjectController` | controllers/project_controller.py | `.iap` save/load, auto-save, `is_loading_project` guard. | +| `ImageController` | controllers/image_controller.py | TIFF/CZI loading, multi-dim slicing, image/slice switching. | +| `AnnotationController` | controllers/annotation_controller.py | Annotation CRUD, sort, edit-mode, finish_polygon/rectangle. | +| `ClassController` | controllers/class_controller.py | Class add/delete/rename/colour/visibility. | +| `SAMController` | controllers/sam_controller.py | SAM model picker + debounce + ADR-013 re-entrancy guard. | +| `DINOController` | controllers/dino_controller.py | DINO single + batch detection, batch review, temp-class workflow. | +| `YOLOController` | controllers/yolo_controller.py | YOLO training menu + prediction wiring. | +| `SAMUtils` | inference/sam_utils.py | SAM model loading and inference. | +| `DINOUtils` | inference/dino_utils.py | Grounding-DINO model loading and inference. | +| `DimensionDialog` | controllers/image_controller.py | Dialog for assigning dimensions to multi-dim stacks. | +| `TrainingThread` | controllers/yolo_controller.py | Background thread for YOLO training. | +| `YOLOTrainer` | dialogs/yolo_trainer.py | YOLO model training and prediction dialog. | +| `DINOReviewEventFilter` | controllers/dino_controller.py | App-wide Enter/Escape filter during DINO review (ADR-015). | ## Data Structure Keys diff --git a/src/digitalsreeni_image_annotator/__init__.py b/src/digitalsreeni_image_annotator/__init__.py index 452fbaf..252927b 100644 --- a/src/digitalsreeni_image_annotator/__init__.py +++ b/src/digitalsreeni_image_annotator/__init__.py @@ -10,9 +10,35 @@ __version__ = "0.9.0" __author__ = "Dr. Sreenivas Bhattiprolu" -from .annotator_window import ImageAnnotator -from .image_label import ImageLabel -from .utils import calculate_area, calculate_bbox -from .sam_utils import SAMUtils +# Lazy loading — importing this package must NOT pull in PyQt6, because +# main.py needs to import torch BEFORE Qt loads (ADR-017). The modules +# below transitively import PyQt6. Deferring them to __getattr__ keeps +# ``import digitalsreeni_image_annotator`` cheap and Qt-free. +__all__ = [ + "ImageAnnotator", + "ImageLabel", + "calculate_area", + "calculate_bbox", + "SAMUtils", +] -__all__ = ['ImageAnnotator', 'ImageLabel', 'calculate_area', 'calculate_bbox', 'SAMUtils'] # Add 'SAMUtils' to this list \ No newline at end of file + +def __getattr__(name): + if name == "ImageAnnotator": + from .annotator_window import ImageAnnotator + return ImageAnnotator + if name == "ImageLabel": + from .widgets.image_label import ImageLabel + return ImageLabel + if name == "SAMUtils": + from .inference.sam_utils import SAMUtils + return SAMUtils + if name == "calculate_area": + from .utils import calculate_area + return calculate_area + if name == "calculate_bbox": + from .utils import calculate_bbox + return calculate_bbox + raise AttributeError( + f"module {__name__!r} has no attribute {name!r}" + ) diff --git a/src/digitalsreeni_image_annotator/annotator_window.py b/src/digitalsreeni_image_annotator/annotator_window.py index ffc94ee..aa6ad0e 100644 --- a/src/digitalsreeni_image_annotator/annotator_window.py +++ b/src/digitalsreeni_image_annotator/annotator_window.py @@ -1,194 +1,54 @@ -import copy -import json import os -import shutil -import traceback import warnings -from datetime import datetime - -import cv2 -import numpy as np -import shapely -from czifile import CziFile -from PyQt6.QtCore import QEvent, QObject, Qt, QThread, QTimer, pyqtSignal -from PyQt6.QtGui import ( - QAction, - QColor, - QFont, - QIcon, - QImage, - QKeySequence, - QPalette, - QPixmap, - QShortcut, -) + +from PyQt6.QtCore import Qt, QTimer +from PyQt6.QtGui import QPixmap from PyQt6.QtWidgets import ( - QAbstractItemView, QApplication, - QButtonGroup, - QCheckBox, - QColorDialog, - QComboBox, QDialog, - QDialogButtonBox, - QDoubleSpinBox, QFileDialog, - QGridLayout, QHBoxLayout, - QInputDialog, - QLabel, QLineEdit, - QListWidget, - QListWidgetItem, QMainWindow, QMenu, QMessageBox, - QProgressBar, - QProgressDialog, - QPushButton, - QScrollArea, - QSlider, QTextEdit, - QVBoxLayout, QWidget, ) -from shapely.geometry import MultiPolygon, Point, Polygon -from shapely.ops import unary_union -from shapely.validation import make_valid -from tifffile import TiffFile - -from .annotation_statistics import show_annotation_statistics -from .coco_json_combiner import show_coco_json_combiner -from .dino_phrase_editor import ClassThresholdTable, PhraseEditorPanel -from .dino_utils import DINOUtils -from .dataset_splitter import DatasetSplitterTool -from .default_stylesheet import default_stylesheet -from .dicom_converter import DicomConverter -from .dino_merge_dialog import show_dino_merge_dialog -from .export_formats import ( - export_coco_json, - export_labeled_images, - export_pascal_voc_bbox, - export_pascal_voc_both, - export_semantic_labels, - export_yolo_v4, - export_yolo_v5plus, -) -from .help_window import HelpWindow -from .image_augmenter import show_image_augmenter -from .image_label import ImageLabel -from .image_patcher import show_image_patcher -from .import_formats import ( - import_coco_json, - import_yolo_v4, - import_yolo_v5plus, - process_import_format, -) -from .sam_utils import InferenceBusyError, SAMUtils -from .slice_registration import SliceRegistrationTool -from .snake_game import SnakeGame -from .soft_dark_stylesheet import soft_dark_stylesheet -from .stack_interpolator import StackInterpolator -from .stack_to_slices import show_stack_to_slices -from .utils import calculate_area, calculate_bbox -from .yolo_trainer import LoadPredictionModelDialog, TrainingInfoDialog, YOLOTrainer - -warnings.filterwarnings("ignore", category=UserWarning) - - -class TrainingThread(QThread): - progress_update = pyqtSignal(str) - finished = pyqtSignal(object) - def __init__(self, yolo_trainer, epochs, imgsz): - super().__init__() - self.yolo_trainer = yolo_trainer - self.epochs = epochs - self.imgsz = imgsz +from .app_settings import load_ui_prefs +from .controllers import io_controller +from .controllers.annotation_controller import AnnotationController +from .controllers.class_controller import ClassController +from .controllers.dino_controller import DINOController +from .controllers.image_controller import ImageController +from .controllers.project_controller import ProjectController +from .controllers.sam_controller import SAMController +from .controllers.yolo_controller import YOLOController +from .core import image_utils +from .ui import theme +from .ui.menu_bar import build_menu_bar +from .ui.shortcuts import install_event_filters, install_shortcuts +from .ui.sidebar import build_image_area, build_image_list, build_sidebar +from .dialogs.annotation_statistics import show_annotation_statistics +from .dialogs.coco_json_combiner import show_coco_json_combiner +from .dialogs.dino_phrase_editor import ClassThresholdTable, PhraseEditorPanel +from .inference.dino_utils import DINOUtils, GDINO_MODEL_PATHS +from .dialogs.dataset_splitter import DatasetSplitterTool +from .dialogs.dicom_converter import DicomConverter +from .dialogs.dino_merge_dialog import show_dino_merge_dialog +from .dialogs.help_window import HelpWindow +from .dialogs.image_augmenter import show_image_augmenter +from .widgets.canvas_context import CanvasContext +from .widgets.image_label import ImageLabel +from .dialogs.image_patcher import show_image_patcher +from .inference.sam_utils import SAMUtils +from .dialogs.slice_registration import SliceRegistrationTool +from .dialogs.snake_game import SnakeGame +from .dialogs.stack_interpolator import StackInterpolator +from .dialogs.stack_to_slices import show_stack_to_slices - def run(self): - try: - results = self.yolo_trainer.train_model( - epochs=self.epochs, imgsz=self.imgsz - ) - self.finished.emit(results) - except Exception as e: - self.finished.emit(str(e)) - - -class DimensionDialog(QDialog): - def __init__(self, shape, file_name, parent=None, default_dimensions=None): - super().__init__(parent) - self.setWindowTitle("Assign Dimensions") - layout = QVBoxLayout(self) - - # Add file name label - file_name_label = QLabel(f"File: {file_name}") - file_name_label.setWordWrap(True) - layout.addWidget(file_name_label) - - # Add dimension assignment widgets - dim_widget = QWidget() - dim_layout = QGridLayout(dim_widget) - self.combos = [] - self.shape = shape - dimensions = ["T", "Z", "C", "S", "H", "W"] - for i, dim in enumerate(shape): - dim_layout.addWidget(QLabel(f"Dimension {i} (size {dim}):"), i, 0) - combo = QComboBox() - combo.addItems(dimensions) - if default_dimensions and i < len(default_dimensions): - combo.setCurrentText(default_dimensions[i]) - dim_layout.addWidget(combo, i, 1) - self.combos.append(combo) - layout.addWidget(dim_widget) - - self.button = QPushButton("OK") - self.button.clicked.connect(self.accept) - layout.addWidget(self.button) - - self.setMinimumWidth(300) - - def get_dimensions(self): - return [combo.currentText() for combo in self.combos] - - -class _DINOReviewEventFilter(QObject): - """Application-wide event filter that lets Enter / Escape accept or - reject pending DINO temp_annotations regardless of which widget has - focus. Without this, clicking a slice/image entry in a list moves - focus there and Enter is consumed by the list's itemActivated - handler before it can reach ImageLabel.keyPressEvent. - - Suppressed when a modal dialog is active or focus is on a text-input - widget so we don't break dialog default-button behaviour or - in-cell editing. - """ - - def __init__(self, main_window: "ImageAnnotator"): - super().__init__(main_window) - self.main_window = main_window - - def eventFilter(self, obj, event): - if event.type() != QEvent.Type.KeyPress: - return False - key = event.key() - if key not in (Qt.Key.Key_Return, Qt.Key.Key_Enter, Qt.Key.Key_Escape): - return False - app = QApplication.instance() - if app is None or app.activeModalWidget() is not None: - return False - focused = app.focusWidget() - if isinstance(focused, (QLineEdit, QTextEdit)): - return False - temp = self.main_window.image_label.temp_annotations - if not temp or not any(a.get("source") == "dino" for a in temp): - return False - if key in (Qt.Key.Key_Return, Qt.Key.Key_Enter): - self.main_window.accept_dino_results() - else: - self.main_window.reject_dino_results() - return True +warnings.filterwarnings("ignore", category=UserWarning) class ImageAnnotator(QMainWindow): @@ -198,23 +58,20 @@ def __init__(self): self.is_loading_project = False self.backup_project_path = None + self.project_controller = ProjectController(self) + self.image_controller = ImageController(self) + self.setWindowTitle("Image Annotator") self.setGeometry(100, 100, 1400, 800) - self.central_widget = QWidget() - self.setCentralWidget(self.central_widget) - self.layout = QHBoxLayout(self.central_widget) - - self.create_menu_bar() - - # Initialize image_label early + # Initialize image_label early — setup_ui's sidebar/image-area + # builders expect it to exist. self.image_label = ImageLabel() self.image_label.sam_box_active = False self.image_label.sam_points_active = False self.image_label.sam_positive_points = [] self.image_label.sam_negative_points = [] - self.image_label.set_main_window(self) # Initialize attributes self.current_image = None @@ -259,29 +116,34 @@ def __init__(self): # pumping inside _run_sync. See apply_sam_prediction(). self._sam_inference_in_flight = False - # Create sam_magic_wand_button - self.sam_magic_wand_button = QPushButton("Magic Wand") - self.sam_magic_wand_button.setCheckable(True) - self.sam_magic_wand_button.setEnabled(False) # Initially disable the button - - # Initialize tool group - self.tool_group = QButtonGroup(self) - self.tool_group.setExclusive(False) - - # Font size control + self.sam_controller = SAMController(self) + self.dino_controller = DINOController(self) + self.yolo_controller = YOLOController(self) + self.annotation_controller = AnnotationController(self) + self.class_controller = ClassController(self) + + # CanvasContext gives ImageLabel a narrow read view of main-window + # state. All write paths from the canvas leave as Qt signals + # connected to controllers below. + self.image_label.set_context(CanvasContext(self)) + self._connect_image_label_signals() + + # Font size control. Presets are named entry points into the + # continuous 8-24pt range; `ui_font_pt` (int) is the single + # source of truth — see theme.set_font_pt. self.font_sizes = { "Small": 8, "Medium": 10, "Large": 12, "XL": 14, "XXL": 16, - } # Also, add the options in create_menu_bar method - self.current_font_size = "Medium" + } # When adding a new option here, also add it to the Font Size submenu in ui/menu_bar.build_menu_bar. - # Dark mode control. Default on — matches the look most users - # expect from a 2025-era desktop annotation tool; toggle with - # Settings → Toggle Dark Mode (Ctrl+D). - self.dark_mode = True + # UI prefs persist app-globally via QSettings (not in the .iap + # project file). Dark mode defaults on — matches the look most + # users expect from a 2025-era desktop annotation tool; toggle + # with Settings → Toggle Dark Mode (Ctrl+D). + self.ui_font_pt, self.dark_mode = load_ui_prefs() # Default annotations sorting self.current_sort_method = "class" # Default sorting method @@ -297,691 +159,143 @@ def __init__(self): # Apply theme and font (this includes stylesheet and font size application) self.apply_theme_and_font() - # Connect sam_magic_wand_button - self.sam_magic_wand_button.clicked.connect(self.toggle_tool) - - self.class_list.itemChanged.connect(self.toggle_class_visibility) - # YOLO Trainer self.yolo_trainer = None self.setup_yolo_menu() - # F2 → Snake game (Easter egg). Registered as a global QShortcut - # so it fires regardless of which widget has focus — putting it - # in keyPressEvent didn't work because QTableWidget (DINO - # threshold table) and other focusable children consume F2 - # before it bubbles up to the main window. - self._snake_shortcut = QShortcut(QKeySequence("F2"), self) - self._snake_shortcut.setContext(Qt.ShortcutContext.ApplicationShortcut) - self._snake_shortcut.activated.connect(self.launch_snake_game) - - # Enter/Escape for DINO temp_annotations need to work even when - # focus is on slice_list / image_list / a button — none of which - # forward the key to ImageLabel.keyPressEvent. Application-wide - # event filter intercepts these keys but only when DINO results - # are pending review, and skips modal dialogs + text inputs. - self._dino_review_filter = _DINOReviewEventFilter(self) - QApplication.instance().installEventFilter(self._dino_review_filter) - - # Start in maximized mode - self.showMaximized() + install_shortcuts(self) + install_event_filters(self) # Start in maximized mode self.showMaximized() + def _connect_image_label_signals(self): + """Wire ImageLabel events to controller slots. ImageLabel does not + hold a main_window reference any more — every write path is a + Qt signal connected here.""" + il = self.image_label + ac = self.annotation_controller + cc = self.class_controller + sc = self.sam_controller + + # Annotation lifecycle + il.annotationCommitted.connect(ac.add_annotation_to_list) + il.annotationsBatchSaved.connect(self._on_annotations_batch_saved) + il.annotationsReplaced.connect(ac.replace_annotations) + il.annotationListUpdateRequested.connect(ac.update_annotation_list) + il.annotationSelected.connect(ac.select_annotation_in_list) + il.deleteSelectionRequested.connect(ac.delete_selected_annotations) + il.finishPolygonRequested.connect(ac.finish_polygon) + il.finishRectangleRequested.connect(ac.finish_rectangle) + + # Class + il.classRequested.connect(cc.add_class) + + # SAM + il.samPredictionRequested.connect(sc.schedule_sam_prediction) + il.samPredictionApplyRequested.connect(sc.apply_sam_prediction) + il.samPredictionAccepted.connect(sc.accept_sam_prediction) + il.samPointsCleared.connect(sc.cancel_sam_debounce) + + # Tool / UI state + il.enableToolsRequested.connect(self.enable_tools) + il.disableToolsRequested.connect(self.disable_tools) + il.resetToolButtonsRequested.connect(self.reset_tool_buttons) + il.toolSizeChanged.connect(self._on_tool_size_changed) + + # Navigation / info + il.zoomInRequested.connect(self.zoom_in) + il.zoomOutRequested.connect(self.zoom_out) + il.imageInfoChanged.connect(self.update_image_info) + + def _on_tool_size_changed(self, tool: str, size: int) -> None: + if tool == "paint": + self.paint_brush_size = size + elif tool == "eraser": + self.eraser_size = size + + def _on_annotations_batch_saved(self) -> None: + self.annotation_controller.save_current_annotations() + self.class_controller.update_slice_list_colors() + def setup_ui(self): - # Initialize the main layout + # Initialize the main layout. tool_group is created inside + # build_sidebar (it needs to register the tool buttons). self.central_widget = QWidget() self.setCentralWidget(self.central_widget) self.layout = QHBoxLayout(self.central_widget) - # Initialize tool group - self.tool_group = QButtonGroup(self) - self.tool_group.setExclusive(False) - - # Setup UI components - self.setup_sidebar() - self.setup_image_area() - self.setup_image_list() + build_menu_bar(self) + build_sidebar(self) + build_image_area(self) + build_image_list(self) self.setup_slice_list() self.update_ui_for_current_tool() def update_window_title(self): - base_title = "Image Annotator" - if hasattr(self, "current_project_file"): - project_name = os.path.basename(self.current_project_file) - project_name = os.path.splitext(project_name)[ - 0 - ] # Remove the file extension - self.setWindowTitle(f"{base_title} - {project_name}") - else: - self.setWindowTitle(base_title) + return self.project_controller.update_window_title() def new_project(self): - self.remove_all_temp_annotations() # Remove temp annotations from the previous project - project_file, _ = QFileDialog.getSaveFileName( - self, "Create New Project", "", "Image Annotator Project (*.iap)" - ) - if project_file: - # Ensure the file has the correct extension - if not project_file.lower().endswith(".iap"): - project_file += ".iap" - - self.current_project_file = project_file - self.current_project_dir = os.path.dirname(project_file) - - # Create the images directory - images_dir = os.path.join(self.current_project_dir, "images") - os.makedirs(images_dir, exist_ok=True) - - # Clear existing data without showing messages - self.clear_all(new_project=True, show_messages=False) - - # Prompt for initial project notes - notes, ok = QInputDialog.getMultiLineText( - self, "Project Notes", "Enter initial project notes:" - ) - if ok: - self.project_notes = notes - else: - self.project_notes = "" - - self.project_creation_date = datetime.now().isoformat() - - # Save the empty project without showing a message - self.save_project(show_message=False) - - # Keep only this message - self.show_info( - "New Project", f"New project created at {self.current_project_file}" - ) - self.initialize_yolo_trainer() - self.update_window_title() + return self.project_controller.new_project() def show_project_search(self): - from .project_search import show_project_search + from .dialogs.project_search import show_project_search show_project_search(self) def open_project(self): - print("open_project method called") # Debug print - self.remove_all_temp_annotations() # Remove temp annotations from the previous project - project_file, _ = QFileDialog.getOpenFileName( - self, "Open Project", "", "Image Annotator Project (*.iap)" - ) - print(f"Selected project file: {project_file}") # Debug print - if project_file: - try: - self.backup_project_before_open(project_file) - self.open_specific_project(project_file) - except Exception as e: - self.restore_project_from_backup() - QMessageBox.critical( - self, - "Error", - f"An error occurred while opening the project: {str(e)}\n" - f"The project file has been restored from backup.", - ) - else: - print("No project file selected") # Debug print + return self.project_controller.open_project() def backup_project_before_open(self, project_file): - """Create a backup of the project file before opening it.""" - import os - import shutil - - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - backup_dir = os.path.join(os.path.dirname(project_file), ".project_backups") - os.makedirs(backup_dir, exist_ok=True) - - self.backup_project_path = os.path.join( - backup_dir, f"{os.path.basename(project_file)}.{timestamp}.backup" - ) - shutil.copy2(project_file, self.backup_project_path) + return self.project_controller.backup_project_before_open(project_file) def restore_project_from_backup(self): - """Restore the project file from its backup if available.""" - if self.backup_project_path and os.path.exists(self.backup_project_path): - try: - shutil.copy2(self.backup_project_path, self.current_project_file) - print(f"Project restored from backup: {self.backup_project_path}") - except Exception as e: - print(f"Failed to restore from backup: {str(e)}") + return self.project_controller.restore_project_from_backup() def open_specific_project(self, project_file): - print(f"Opening specific project: {project_file}") # Debug print - if os.path.exists(project_file): - try: - self.is_loading_project = True # Set loading flag - - with open(project_file, "r") as f: - project_data = json.load(f) - - self.clear_all(show_messages=False) - self.current_project_file = project_file - self.current_project_dir = os.path.dirname(project_file) - - # Load project notes and metadata - self.project_notes = project_data.get("notes", "") - self.project_creation_date = project_data.get("creation_date", "") - self.last_modified = project_data.get("last_modified", "") - - # Parse dates - if self.project_creation_date: - self.project_creation_date = datetime.fromisoformat( - self.project_creation_date - ).strftime("%Y-%m-%d %H:%M:%S") - if self.last_modified: - self.last_modified = datetime.fromisoformat( - self.last_modified - ).strftime("%Y-%m-%d %H:%M:%S") - - # Load all data without triggering auto-saves - self.load_project_data(project_data) - - # Now save once after everything is loaded - self.is_loading_project = False # Clear loading flag - # Reveal the phrase editor if any classes exist — the - # per-class selectRow inside add_class was skipped during - # load (see add_class). Selecting row 0 is enough; the - # user can switch rows freely afterwards. - if self.dino_class_table.rowCount() > 0: - self.dino_class_table.selectRow(0) - self.save_project(show_message=False) # Save once after loading - - self.initialize_yolo_trainer() - self.update_window_title() - - print(f"Project opened successfully: {project_file}") - QMessageBox.information( - self, - "Project Opened", - f"Project opened successfully: {os.path.basename(project_file)}", - ) - - except Exception as e: - self.is_loading_project = False # Make sure to clear flag on error - raise e - else: - print(f"Project file not found: {project_file}") - QMessageBox.critical( - self, "Error", f"Project file not found: {project_file}" - ) + return self.project_controller.open_specific_project(project_file) def load_project_data(self, project_data): - """Load project data without triggering auto-saves.""" - # Load classes - self.class_mapping.clear() - self.image_label.class_colors.clear() - for class_info in project_data.get("classes", []): - self.add_class(class_info["name"], QColor(class_info["color"])) - - # Load images - self.all_images = project_data.get("images", []) - self.image_paths = project_data.get("image_paths", {}) - - # Load all annotations first - self.all_annotations.clear() - for image_info in project_data["images"]: - if image_info.get("is_multi_slice", False): - for slice_info in image_info.get("slices", []): - self.all_annotations[slice_info["name"]] = slice_info["annotations"] - else: - self.all_annotations[image_info["file_name"]] = image_info.get( - "annotations", {} - ) - - # Handle missing images - missing_images = [] - for image_info in project_data["images"]: - image_path = os.path.join( - self.current_project_dir, "images", image_info["file_name"] - ) - - if not os.path.exists(image_path): - missing_images.append(image_info["file_name"]) - continue - - # Update image_paths - self.image_paths[image_info["file_name"]] = image_path - - if image_info.get("is_multi_slice", False): - dimensions = image_info.get("dimensions", []) - shape = image_info.get("shape", []) - self.load_multi_slice_image(image_path, dimensions, shape) - else: - self.add_images_to_list([image_path]) - - # Restore DINO configuration if present. Classes were created above - # via add_class(), so the threshold table already has rows for them; - # we just push the saved values into the existing widgets. Filter - # out any keys that reference classes no longer in the project - # (hand-edited .iap, class deleted between sessions) so stale state - # doesn't get round-tripped on the next save. - dino_cfg = project_data.get("dino_config", {}) - valid_classes = set(self.class_mapping.keys()) - - phrases = dino_cfg.get("phrases", {}) - if phrases: - kept = {k: v for k, v in phrases.items() if k in valid_classes} - for orphan in phrases.keys() - kept.keys(): - print(f" Skipped saved DINO phrases for unknown class " - f"'{orphan}' — class is not in the current project.") - self.dino_phrase_panel.set_phrases(kept) - - for cls_name, thr in dino_cfg.get("thresholds", {}).items(): - ok = self.dino_class_table.set_thresholds( - cls_name, - thr.get("box", 0.25), - thr.get("txt", 0.25), - thr.get("nms", 0.50), - ) - if not ok: - print(f" Skipped saved DINO thresholds for unknown class " - f"'{cls_name}' — class is not in the current project.") - - # Update UI - self.update_ui() - - # Handle missing images if any - if missing_images: - self.handle_missing_images(missing_images) - - # Select the first image if available - if self.image_list.count() > 0: - self.image_list.setCurrentRow(0) - first_item = self.image_list.item(0) - if first_item: - self.switch_image(first_item) - - # Select the first class if available - if self.class_list.count() > 0: - self.class_list.setCurrentRow(0) - self.on_class_selected() + return self.project_controller.load_project_data(project_data) def handle_missing_images(self, missing_images): - message = "The following images have annotations but were not found in the project directory:\n\n" - message += "\n".join(missing_images[:10]) # Show first 10 missing images - if len(missing_images) > 10: - message += f"\n... and {len(missing_images) - 10} more." - message += "\n\nWould you like to locate these images now?" - - reply = QMessageBox.question( - self, - "Missing Images", - message, - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, - QMessageBox.StandardButton.Yes, - ) - - if reply == QMessageBox.StandardButton.Yes: - self.load_missing_images(missing_images) - else: - self.remove_missing_images(missing_images) + return self.project_controller.handle_missing_images(missing_images) def remove_missing_images(self, missing_images): - for image_name in missing_images: - # Remove from all_images - self.all_images = [ - img for img in self.all_images if img["file_name"] != image_name - ] - - # Remove from image_paths - self.image_paths.pop(image_name, None) - - # Remove from all_annotations - self.all_annotations.pop(image_name, None) - - # If it's a multi-slice image, remove all related slices - base_name = os.path.splitext(image_name)[0] - if base_name in self.image_slices: - for slice_name, _ in self.image_slices[base_name]: - self.all_annotations.pop(slice_name, None) - del self.image_slices[base_name] - - self.update_ui() - QMessageBox.information( - self, - "Images Removed", - f"{len(missing_images)} missing images and their annotations have been removed from the project.", - ) + return self.project_controller.remove_missing_images(missing_images) def prompt_load_missing_images(self, missing_images): - message = "The following images have annotations but were not found in the project directory:\n\n" - message += "\n".join(missing_images[:10]) # Show first 10 missing images - if len(missing_images) > 10: - message += f"\n... and {len(missing_images) - 10} more." - message += "\n\nWould you like to locate these images now?" - - reply = QMessageBox.question( - self, - "Load Missing Images", - message, - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, - QMessageBox.StandardButton.Yes, - ) - - if reply == QMessageBox.StandardButton.Yes: - self.load_missing_images(missing_images) + return self.project_controller.prompt_load_missing_images(missing_images) def load_missing_images(self, missing_images): - files, _ = QFileDialog.getOpenFileNames( - self, - "Select Missing Images", - "", - "Image Files (*.png *.jpg *.bmp *.tif *.tiff *.czi)", - ) - if files: - images_loaded = 0 - for file_path in files: - file_name = os.path.basename(file_path) - if file_name in missing_images: - dst_path = os.path.join( - self.current_project_dir, "images", file_name - ) - shutil.copy2(file_path, dst_path) - self.image_paths[file_name] = dst_path - - # Add the image to all_images if it's not already there - if not any( - img["file_name"] == file_name for img in self.all_images - ): - self.all_images.append( - { - "file_name": file_name, - "height": 0, - "width": 0, - "id": len(self.all_images) + 1, - "is_multi_slice": False, - } - ) - images_loaded += 1 - missing_images.remove(file_name) - - self.update_image_list() - if images_loaded > 0: - self.image_list.setCurrentRow(0) # Select the first image - self.switch_image(self.image_list.item(0)) # Display the first image - QMessageBox.information( - self, - "Images Loaded", - f"Successfully copied and loaded {images_loaded} out of {len(files)} selected images.", - ) - - # If there are still missing images, prompt again - if missing_images: - self.prompt_load_missing_images(missing_images) + return self.project_controller.load_missing_images(missing_images) def update_image_list(self): - self.image_list.clear() - for image_info in self.all_images: - self.image_list.addItem(image_info["file_name"]) + return self.image_controller.update_image_list() def select_class(self, index): - if 0 <= index < self.class_list.count(): - item = self.class_list.item(index) - self.class_list.setCurrentItem(item) - self.current_class = item.text() - print(f"Selected class: {self.current_class}") - else: - print("Invalid class index") + return self.class_controller.select_class(index) def close_project(self): - if hasattr(self, "current_project_file"): - reply = QMessageBox.question( - self, - "Close Project", - "Do you want to save the current project before closing?", - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No | QMessageBox.StandardButton.Cancel, - ) - - if reply == QMessageBox.StandardButton.Yes: - self.remove_all_temp_annotations() # Remove temp annotations before saving - self.save_project(show_message=False) # Save without showing a message - elif reply == QMessageBox.StandardButton.Cancel: - return # User cancelled the operation - - # Clear all data - self.clear_all(new_project=True, show_messages=False) - - # Reset project-related attributes - if hasattr(self, "current_project_file"): - del self.current_project_file - if hasattr(self, "current_project_dir"): - del self.current_project_dir - - # Update the window title - self.update_window_title() + return self.project_controller.close_project() def delete_selected_class(self): - selected_items = self.class_list.selectedItems() - if not selected_items: - QMessageBox.warning( - self, "No Selection", "Please select a class to delete." - ) - return - - class_name = selected_items[0].text() - reply = QMessageBox.question( - self, - "Delete Class", - f"Are you sure you want to delete the class '{class_name}'?", - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, - QMessageBox.StandardButton.No, - ) - if reply == QMessageBox.StandardButton.Yes: - self.delete_class( - class_name - ) # Sreeni note: Implement this method to handle class deletion + return self.class_controller.delete_selected_class() def check_missing_images(self): - missing_images = [ - img["file_name"] - for img in self.all_images - if img["file_name"] not in self.image_paths - or not os.path.exists(self.image_paths[img["file_name"]]) - ] - if missing_images: - self.prompt_load_missing_images(missing_images) + return self.project_controller.check_missing_images() def convert_to_serializable(self, obj): - if isinstance(obj, np.integer): - return int(obj) - elif isinstance(obj, np.floating): - return float(obj) - elif isinstance(obj, np.ndarray): - return obj.tolist() - elif isinstance(obj, list): - return [self.convert_to_serializable(item) for item in obj] - elif isinstance(obj, dict): - return { - key: self.convert_to_serializable(value) for key, value in obj.items() - } - else: - return obj + return image_utils.convert_to_serializable(obj) def save_project(self, show_message=True): - if not hasattr(self, "current_project_file") or not self.current_project_file: - self.current_project_file, _ = QFileDialog.getSaveFileName( - self, "Save Project", "", "Image Annotator Project (*.iap)" - ) - if not self.current_project_file: - return # User cancelled the save dialog - - self.current_project_dir = os.path.dirname(self.current_project_file) - - # Check if images are in the correct directory structure - images_dir = os.path.join(self.current_project_dir, "images") - os.makedirs(images_dir, exist_ok=True) - - images_to_copy = [] - for file_name, src_path in self.image_paths.items(): - dst_path = os.path.join(images_dir, file_name) - if os.path.abspath(src_path) != os.path.abspath(dst_path): - if not os.path.exists(dst_path): - images_to_copy.append((file_name, src_path, dst_path)) - - if images_to_copy: - reply = QMessageBox.question( - self, - "Image Directory Structure", - f"The project structure requires all images to be in an 'images' subdirectory. " - f"{len(images_to_copy)} images need to be copied to the correct location. " - f"Do you want to copy these images?", - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, - QMessageBox.StandardButton.Yes, - ) - - if reply == QMessageBox.StandardButton.Yes: - for file_name, src_path, dst_path in images_to_copy: - try: - shutil.copy2(src_path, dst_path) - self.image_paths[file_name] = dst_path - except Exception as e: - QMessageBox.warning( - self, "Copy Failed", f"Failed to copy {file_name}: {str(e)}" - ) - return - else: - QMessageBox.warning( - self, - "Save Cancelled", - "Project cannot be saved without the correct directory structure.", - ) - return - - # Prepare image data - images_data = [] - for image_info in self.all_images: - file_name = image_info["file_name"] - image_data = { - "file_name": file_name, - "width": image_info["width"], - "height": image_info["height"], - "is_multi_slice": image_info["is_multi_slice"], - } - - if image_data["is_multi_slice"]: - base_name_without_ext = os.path.splitext(file_name)[0] - image_data["slices"] = [] - for slice_name, _ in self.image_slices.get(base_name_without_ext, []): - slice_data = { - "name": slice_name, - "annotations": self.convert_to_serializable( - self.all_annotations.get(slice_name, {}) - ), - } - image_data["slices"].append(slice_data) - - image_data["dimensions"] = self.convert_to_serializable( - self.image_dimensions.get(base_name_without_ext, []) - ) - image_data["shape"] = self.convert_to_serializable( - self.image_shapes.get(base_name_without_ext, []) - ) - else: - image_data["annotations"] = {} - for class_name, annotations in self.all_annotations.get( - file_name, {} - ).items(): - image_data["annotations"][class_name] = [ - ann.copy() for ann in annotations - ] - - images_data.append(image_data) - - # Create project data - project_data = { - "classes": [ - {"name": name, "color": color.name()} - for name, color in self.image_label.class_colors.items() - ], - "images": images_data, - "image_paths": { - k: v for k, v in self.image_paths.items() if os.path.exists(v) - }, - "notes": getattr(self, "project_notes", ""), - "creation_date": getattr( - self, "project_creation_date", datetime.now().isoformat() - ), - "last_modified": datetime.now().isoformat(), - } - - # Persist DINO configuration by snapshotting the widgets that own it. - dino_cfg = { - "phrases": self.dino_phrase_panel.get_all_phrases(), - "thresholds": self.dino_class_table.get_thresholds_dict(), - } - if dino_cfg["phrases"] or dino_cfg["thresholds"]: - project_data["dino_config"] = dino_cfg - - # Save project data - with open(self.current_project_file, "w") as f: - json.dump(self.convert_to_serializable(project_data), f, indent=2) - - if show_message: - self.show_info( - "Project Saved", f"Project saved to {self.current_project_file}" - ) - - # Update the window title - self.update_window_title() - - # Update image_paths to reflect the correct locations - for file_name in self.image_paths.keys(): - self.image_paths[file_name] = os.path.join(images_dir, file_name) + return self.project_controller.save_project(show_message=show_message) def save_project_as(self): - new_project_file, _ = QFileDialog.getSaveFileName( - self, "Save Project As", "", "Image Annotator Project (*.iap)" - ) - if new_project_file: - # Ensure the file has the correct extension - if not new_project_file.lower().endswith(".iap"): - new_project_file += ".iap" - - # Store the original project file - original_project_file = getattr(self, "current_project_file", None) - - # Set the new project file as the current one - self.current_project_file = new_project_file - self.current_project_dir = os.path.dirname(new_project_file) - - # Save the project with the new name - self.save_project(show_message=False) - - # Update the window title - self.update_window_title() - - # Show a success message - QMessageBox.information( - self, "Project Saved As", f"Project saved as:\n{new_project_file}" - ) - - # If this was originally a new unsaved project, update the original project file - if original_project_file is None: - self.current_project_file = new_project_file + return self.project_controller.save_project_as() def auto_save(self): - if self.is_loading_project: - return # Skip auto-save during project loading - - if not hasattr(self, "current_project_file"): - reply = QMessageBox.question( - self, - "No Project", - "You need to save the project before auto-saving. Would you like to save now?", - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, - QMessageBox.StandardButton.Yes, - ) - if reply == QMessageBox.StandardButton.Yes: - self.save_project() - else: - return - - if hasattr(self, "current_project_file"): - self.save_project(show_message=False) - print("Project auto-saved.") + return self.project_controller.auto_save() def show_project_details(self): if not hasattr(self, "current_project_file"): @@ -990,8 +304,8 @@ def show_project_details(self): ) return - from .annotation_statistics import AnnotationStatisticsDialog - from .project_details import ProjectDetailsDialog + from .dialogs.annotation_statistics import AnnotationStatisticsDialog + from .dialogs.project_details import ProjectDetailsDialog # Generate annotation statistics stats_dialog = AnnotationStatisticsDialog(self) @@ -1010,500 +324,48 @@ def show_project_details(self): print("No changes made to project details.") def load_multi_slice_image(self, image_path, dimensions=None, shape=None): + return self.image_controller.load_multi_slice_image(image_path, dimensions, shape) - file_name = os.path.basename(image_path) - base_name = os.path.splitext(file_name)[0] - print(f"Loading multi-slice image: {image_path}") - print(f"Base name: {base_name}") - - if dimensions and shape: - print(f"Using stored dimensions: {dimensions}") - print(f"Using stored shape: {shape}") - self.image_dimensions[base_name] = dimensions - self.image_shapes[base_name] = shape - if image_path.lower().endswith((".tif", ".tiff")): - self.load_tiff(image_path, dimensions, shape) - elif image_path.lower().endswith(".czi"): - self.load_czi(image_path, dimensions, shape) - else: - print("No stored dimensions or shape, loading as new image") - if image_path.lower().endswith((".tif", ".tiff")): - self.load_tiff(image_path) - elif image_path.lower().endswith(".czi"): - self.load_czi(image_path) - - print(f"Loaded multi-slice image: {file_name}") - print(f"Dimensions: {self.image_dimensions.get(base_name, 'Not found')}") - print(f"Shape: {self.image_shapes.get(base_name, 'Not found')}") - print(f"Number of slices: {len(self.slices)}") - - if self.slices: - self.current_image = self.slices[0][1] - self.current_slice = self.slices[0][0] - - self.update_slice_list() - self.slice_list.setCurrentRow(0) - self.activate_slice(self.current_slice) - print(f"Activated first slice: {self.current_slice}") - else: - print("No slices were loaded") - self.current_image = None - self.current_slice = None - - self.update_slice_list() - self.image_label.update() - - # print(f"Loaded slices: {[slice_name for slice_name, _ in self.slices]}") - - def activate_sam_magic_wand(self): - # Uncheck all other tools - for button in self.tool_group.buttons(): - if button != self.sam_magic_wand_button: - button.setChecked(False) - - # Set the current tool - self.image_label.current_tool = "sam_magic_wand" - self.image_label.sam_magic_wand_active = True - self.image_label.setCursor(Qt.CursorShape.CrossCursor) - - # Update UI based on the current tool - self.update_ui_for_current_tool() - - # If a class is not selected, select the first one (if available) - if self.current_class is None and self.class_list.count() > 0: - self.class_list.setCurrentRow(0) - self.current_class = self.class_list.currentItem().text() - elif self.class_list.count() == 0: - QMessageBox.warning( - self, - "No Class Selected", - "Please add a class before using annotation tools.", - ) - self.sam_magic_wand_button.setChecked(False) - self.deactivate_sam_magic_wand() - - def deactivate_sam_magic_wand(self): - self.image_label.current_tool = None - self.image_label.sam_magic_wand_active = False - self.sam_magic_wand_button.setChecked(False) - self.sam_magic_wand_button.setEnabled(False) # Disable the button - self.image_label.setCursor(Qt.CursorShape.ArrowCursor) - - # Clear any SAM-related temporary data - self.image_label.sam_bbox = None - self.image_label.drawing_sam_bbox = False - self.image_label.temp_sam_prediction = None - - # Update UI based on the current tool - self.update_ui_for_current_tool() - - def toggle_sam_assisted(self): - if not self.current_sam_model: - QMessageBox.warning( - self, - "No SAM Model Selected", - "Please pick a SAM model before using the SAM-Assisted tool.", - ) - self.sam_magic_wand_button.setChecked(False) - return - - if self.sam_magic_wand_button.isChecked(): - self.activate_sam_magic_wand() - else: - self.deactivate_sam_magic_wand() - - self.image_label.clear_temp_sam_prediction() # Clear temporary prediction - - def toggle_sam_magic_wand(self): - if self.sam_magic_wand_button.isChecked(): - if self.current_class is None: - QMessageBox.warning( - self, - "No Class Selected", - "Please select a class before using SAM2 Magic Wand.", - ) - self.sam_magic_wand_button.setChecked(False) - return - self.image_label.setCursor(Qt.CursorShape.CrossCursor) - self.image_label.sam_magic_wand_active = True - else: - self.image_label.setCursor(Qt.CursorShape.ArrowCursor) - self.image_label.sam_magic_wand_active = False - self.image_label.sam_bbox = None - - self.image_label.clear_temp_sam_prediction() # Clear temporary prediction + def deactivate_sam_tools(self): + return self.sam_controller.deactivate_sam_tools() def schedule_sam_prediction(self): - """Restart the debounce timer; inference fires 1s after last click.""" - self.sam_inference_timer.stop() - self.sam_inference_timer.start(1000) + return self.sam_controller.schedule_sam_prediction() def apply_sam_prediction(self): - # Re-entry guard: if a previous SAM call is still in flight, the - # event-loop pump inside _run_sync can deliver this timer fire - # before the first call returns. Bail and rely on the user - # clicking again (which restarts the debounce) to issue a fresh - # inference with the up-to-date point set. - if self._sam_inference_in_flight: - return - self._sam_inference_in_flight = True - try: - try: - if self.image_label.current_tool == "sam_box": - if self.image_label.sam_bbox is None: - print("SAM bbox is None") - return - x1, y1, x2, y2 = self.image_label.sam_bbox - bbox = [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)] - prediction = self.sam_utils.apply_sam_prediction(self.current_image, bbox) - self.image_label.sam_bbox = None - elif self.image_label.current_tool == "sam_points": - # Always use all points! - pos_points = self.image_label.sam_positive_points - neg_points = self.image_label.sam_negative_points - print( - f"[SAM-POINTS] Predicting with {len(pos_points)} positive points: {pos_points} " - f"and {len(neg_points)} negative points: {neg_points}" - ) - if not pos_points: - print("No positive points for SAM-points") - return - prediction = self.sam_utils.apply_sam_points( - self.current_image, - pos_points, - neg_points, - ) - else: - return - except InferenceBusyError: - # Re-entry safety net from sam_utils. The call-site flag - # above should catch this first, but if a different - # caller drives inference concurrently we just skip — - # the user keeps interacting; their next click will - # restart the debounce. - return - except Exception as exc: - traceback.print_exc() - QMessageBox.critical( - self, - "SAM Error", - f"SAM inference failed:\n\n{exc}\n\n" - "See the log for details.", - ) - return - - if prediction: - temp_annotation = { - "segmentation": prediction["segmentation"], - "category_id": self.class_mapping[self.current_class], - "category_name": self.current_class, - "score": prediction["score"], - } - self.image_label.temp_sam_prediction = temp_annotation - self.image_label.update() - elif prediction is None: - QMessageBox.information( - self, - "SAM", - "No mask matches the given constraints. " - "Try adjusting the box or point positions." - ) - else: - print("Failed to generate prediction") - - # Only clear box/points for box mode, not for points mode! - if self.image_label.current_tool == "sam_box": - self.image_label.sam_bbox = None - self.image_label.update() - finally: - self._sam_inference_in_flight = False + return self.sam_controller.apply_sam_prediction() def accept_sam_prediction(self): - if self.image_label.temp_sam_prediction: - new_annotation = self.image_label.temp_sam_prediction - self.image_label.annotations.setdefault( - new_annotation["category_name"], [] - ).append(new_annotation) - self.add_annotation_to_list(new_annotation) - self.save_current_annotations() - self.update_slice_list_colors() - self.image_label.temp_sam_prediction = None - # --- Clear points after accepting - self.image_label.sam_positive_points = [] - self.image_label.sam_negative_points = [] - self.image_label.update() - print("SAM prediction accepted, points cleared, and added to annotations.") + return self.sam_controller.accept_sam_prediction() def setup_slice_list(self): - self.slice_list = QListWidget() - self.slice_list.itemClicked.connect(self.switch_slice) - self.image_list_layout.addWidget(QLabel("Slices:")) - self.image_list_layout.addWidget(self.slice_list) + return self.image_controller.setup_slice_list() def open_images(self): - file_names, _ = QFileDialog.getOpenFileNames( - self, - "Open Images", - "", - "Image Files (*.png *.jpg *.bmp *.tif *.tiff *.czi)", - ) - if file_names: - self.image_list.clear() - self.image_paths.clear() - self.all_images.clear() - self.slice_list.clear() - self.slices.clear() - self.current_stack = None - self.current_slice = None - self.add_images_to_list(file_names) + return self.image_controller.open_images() def convert_to_8bit_rgb(self, image_array): - if image_array.ndim == 2: - # Grayscale image - image_8bit = self.normalize_array(image_array) - return np.stack((image_8bit,) * 3, axis=-1) - elif image_array.ndim == 3: - if image_array.shape[2] == 3: - # Already RGB, just normalize - return self.normalize_array(image_array) - elif image_array.shape[2] > 3: - # Multi-channel image, use first three channels - rgb_array = image_array[:, :, :3] - return self.normalize_array(rgb_array) - - raise ValueError(f"Unsupported image shape: {image_array.shape}") + return image_utils.convert_to_8bit_rgb(image_array) def add_images_to_list(self, file_names): - first_added_item = None - for file_name in file_names: - base_name = os.path.basename(file_name) - if base_name not in self.image_paths: - image_info = { - "file_name": base_name, - "height": 0, - "width": 0, - "id": len(self.all_images) + 1, - "is_multi_slice": False, - } - - # Detect multi-slice images and set dimensions - if file_name.lower().endswith((".tif", ".tiff", ".czi")): - self.load_multi_slice_image(file_name) - base_name_without_ext = os.path.splitext(base_name)[0] - if ( - base_name_without_ext in self.image_slices - and self.image_slices[base_name_without_ext] - ): - first_slice_name, first_slice = self.image_slices[ - base_name_without_ext - ][0] - image_info["height"] = first_slice.height() - image_info["width"] = first_slice.width() - image_info["is_multi_slice"] = True - image_info["dimensions"] = self.image_dimensions.get( - base_name_without_ext, [] - ) - image_info["shape"] = self.image_shapes.get( - base_name_without_ext, [] - ) - else: - # For regular images - image = QImage(file_name) - image_info["height"] = image.height() - image_info["width"] = image.width() - - self.all_images.append(image_info) - item = QListWidgetItem(base_name) - self.image_list.addItem(item) - if first_added_item is None: - first_added_item = item - - # Update image_paths with the original file path - self.image_paths[base_name] = file_name - - if first_added_item: - self.image_list.setCurrentItem(first_added_item) - self.switch_image(first_added_item) - - if not self.is_loading_project: - self.auto_save() + return self.image_controller.add_images_to_list(file_names) def update_all_images(self, new_image_info): - for info in new_image_info: - if not any( - img["file_name"] == info["file_name"] for img in self.all_images - ): - self.all_images.append(info) + return self.image_controller.update_all_images(new_image_info) def closeEvent(self, event): + # check_unsaved_changes prompts and commits/discards as the + # user chooses; returns False on Cancel. if not self.image_label.check_unsaved_changes(): event.ignore() return event.accept() - if ( - self.image_label.temp_paint_mask is not None - or self.image_label.temp_eraser_mask is not None - ): - reply = QMessageBox.question( - self, - "Unsaved Changes", - "You have unsaved changes. Do you want to save them before closing?", - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No | QMessageBox.StandardButton.Cancel, - ) - if reply == QMessageBox.StandardButton.Yes: - if self.image_label.temp_paint_mask is not None: - self.image_label.commit_paint_annotation() - if self.image_label.temp_eraser_mask is not None: - self.image_label.commit_eraser_changes() - elif reply == QMessageBox.StandardButton.Cancel: - event.ignore() - return - - # Perform any other cleanup or saving operations here - event.accept() - def switch_slice(self, item): - if item is None: - return - if not self.image_label.check_unsaved_changes(): - return - - # Check for unsaved changes - if ( - self.image_label.temp_paint_mask is not None - or self.image_label.temp_eraser_mask is not None - ): - reply = QMessageBox.question( - self, - "Unsaved Changes", - "You have unsaved changes. Do you want to save them?", - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No | QMessageBox.StandardButton.Cancel, - ) - if reply == QMessageBox.StandardButton.Yes: - if self.image_label.temp_paint_mask is not None: - self.image_label.commit_paint_annotation() - if self.image_label.temp_eraser_mask is not None: - self.image_label.commit_eraser_changes() - elif reply == QMessageBox.StandardButton.Cancel: - return - else: - self.image_label.discard_paint_annotation() - self.image_label.discard_eraser_changes() - - self.save_current_annotations() - self.image_label.clear_temp_sam_prediction() - - slice_name = item.text() - for name, qimage in self.slices: - if name == slice_name: - self.current_image = qimage - self.current_slice = name - self.display_image() - self.load_image_annotations() - self.update_annotation_list() - self.clear_highlighted_annotation() - self.image_label.highlighted_annotations.clear() # Add this line - self.image_label.reset_annotation_state() - self.image_label.clear_current_annotation() - self.update_image_info() - break - - # Ensure the UI is updated - self.image_label.update() - self.update_slice_list_colors() - - # Reset zoom level to default (1.0) - self.set_zoom(1.0) - - # Sync DINO temp_annotations to the new slice (carry over masks - # from the previous slice was a reported bug). - self._refresh_dino_temp_for_current() + return self.image_controller.switch_slice(item) def switch_image(self, item): - if item is None: - return - if not self.image_label.check_unsaved_changes(): - return - - # Store the current item before checking temp annotations - current_item = self.image_list.currentItem() - - if not self.check_temp_annotations(): - # If the user chooses not to discard temp annotations, revert the selection - self.image_list.setCurrentItem(current_item) - return - - self.save_current_annotations() - self.image_label.clear_temp_sam_prediction() - self.image_label.exit_editing_mode() - - file_name = item.text() - print(f"\nSwitching to image: {file_name}") - - image_info = next( - (img for img in self.all_images if img["file_name"] == file_name), None - ) - - if image_info: - self.image_file_name = file_name - image_path = self.image_paths.get(file_name) - - if not image_path: - image_path = os.path.join(self.current_project_dir, "images", file_name) - - if image_path and os.path.exists(image_path): - if image_info.get("is_multi_slice", False): - base_name = os.path.splitext(file_name)[0] - if base_name in self.image_slices: - self.slices = self.image_slices[base_name] - if self.slices: - self.current_image = self.slices[0][1] - self.current_slice = self.slices[0][0] - self.update_slice_list() - self.activate_slice(self.current_slice) - else: - self.load_multi_slice_image( - image_path, - image_info.get("dimensions"), - image_info.get("shape"), - ) - else: - self.load_regular_image(image_path) - self.display_image() - self.clear_slice_list() - - self.load_image_annotations() - self.update_annotation_list() - self.clear_highlighted_annotation() - - self.image_label.highlighted_annotations.clear() - self.image_label.update() - self.image_label.reset_annotation_state() - self.image_label.clear_current_annotation() - self.update_image_info() - - self.adjust_zoom_to_fit() - else: - self.current_image = None - self.image_label.clear() - self.load_image_annotations() - self.update_annotation_list() - self.update_image_info() - - self.image_list.setCurrentItem(item) - self.image_label.update() - self.update_slice_list_colors() - else: - self.current_image = None - self.current_slice = None - self.image_label.clear() - self.update_image_info() - self.clear_slice_list() - - # Sync DINO temp_annotations to the new image (mask carry-over - # bug from single-image review and batch review). - self._refresh_dino_temp_for_current() + return self.image_controller.switch_image(item) def adjust_zoom_to_fit(self): if not self.current_image: @@ -1520,428 +382,51 @@ def adjust_zoom_to_fit(self): self.set_zoom(zoom_factor) def activate_current_slice(self): - if self.current_slice: - # Ensure the current slice is selected in the slice list - items = self.slice_list.findItems(self.current_slice, Qt.MatchFlag.MatchExactly) - if items: - self.slice_list.setCurrentItem(items[0]) - - # Load annotations for the current slice - self.load_image_annotations() - - # Update the image label - self.image_label.update() - - # Update the annotation list - self.update_annotation_list() + return self.image_controller.activate_current_slice() def load_image(self, image_path): - extension = os.path.splitext(image_path)[1].lower() - if extension in [".tif", ".tiff"]: - self.load_tiff(image_path) - elif extension == ".czi": - self.load_czi(image_path) - else: - self.load_regular_image(image_path) + return self.image_controller.load_image(image_path) - def load_tiff( - self, image_path, dimensions=None, shape=None, force_dimension_dialog=False - ): - print(f"Loading TIFF file: {image_path}") - axes_hint = None - with TiffFile(image_path) as tif: - print(f"TIFF tags: {tif.pages[0].tags}") - - # Try to access metadata if available - try: - metadata = tif.pages[0].tags["ImageDescription"].value - print(f"TIFF metadata: {metadata}") - except KeyError: - print("No ImageDescription metadata found") - - # Try to read axis labels from the tifffile series. ImageJ / - # OME-TIFF stores axes like "TZCYX" — we can prefill the - # dimension dialog with the right labels so the user just - # clicks OK instead of guessing per axis. Map tifffile's - # axes vocabulary (T,Z,C,S,Y,X) to the app's (T,Z,C,S,H,W). - try: - series_axes = tif.series[0].axes if tif.series else None - if series_axes: - axis_map = { - "T": "T", "Z": "Z", "C": "C", "S": "S", - "Y": "H", "X": "W", - } - mapped = [axis_map.get(a) for a in series_axes] - if all(a is not None for a in mapped): - axes_hint = mapped - print(f"TIFF series axes: {series_axes} → dimension hint: {axes_hint}") - else: - unknown = [a for a in series_axes if axis_map.get(a) is None] - print(f"TIFF series axes had unknown labels {unknown}, no hint applied") - except Exception as e: - print(f"Could not read TIFF series axes: {e}") - - # Check if it's a multi-page TIFF - if len(tif.pages) > 1: - print(f"Multi-page TIFF detected. Number of pages: {len(tif.pages)}") - # Read all pages into a 3D array - image_array = tif.asarray() - else: - print("Single-page TIFF detected.") - image_array = tif.pages[0].asarray() - - print(f"Image array shape: {image_array.shape}") - print(f"Image array dtype: {image_array.dtype}") - print(f"Image min: {image_array.min()}, max: {image_array.max()}") - - if dimensions and shape and not force_dimension_dialog: - # Use stored dimensions and shape - print(f"Using stored dimensions: {dimensions}") - print(f"Using stored shape: {shape}") - image_array = image_array.reshape(shape) - else: - # Process as before for new images or when forcing dimension dialog - print("Processing as new image or forcing dimension dialog.") - dimensions = None - - self.process_multidimensional_image( - image_array, image_path, dimensions, force_dimension_dialog, - axes_hint=axes_hint, - ) - - def load_czi( - self, image_path, dimensions=None, shape=None, force_dimension_dialog=False - ): - print(f"Loading CZI file: {image_path}") - with CziFile(image_path) as czi: - image_array = czi.asarray() - print(f"CZI array shape: {image_array.shape}") - print(f"CZI array dtype: {image_array.dtype}") - print(f"CZI array min: {image_array.min()}, max: {image_array.max()}") - - if dimensions and shape and not force_dimension_dialog: - # Use stored dimensions and shape - print(f"Using stored dimensions: {dimensions}") - print(f"Using stored shape: {shape}") - image_array = image_array.reshape(shape) - else: - # Process as before for new images or when forcing dimension dialog - print("Processing as new image or forcing dimension dialog.") - dimensions = None + def load_tiff(self, image_path, dimensions=None, shape=None, force_dimension_dialog=False): + return self.image_controller.load_tiff(image_path, dimensions, shape, force_dimension_dialog) - self.process_multidimensional_image( - image_array, image_path, dimensions, force_dimension_dialog - ) + def load_czi(self, image_path, dimensions=None, shape=None, force_dimension_dialog=False): + return self.image_controller.load_czi(image_path, dimensions, shape, force_dimension_dialog) def load_regular_image(self, image_path): - self.current_image = QImage(image_path) - self.slices = [] - self.slice_list.clear() - self.current_slice = None + return self.image_controller.load_regular_image(image_path) def process_multidimensional_image( self, image_array, image_path, dimensions=None, force_dimension_dialog=False, axes_hint=None, ): - file_name = os.path.basename(image_path) - base_name = os.path.splitext(file_name)[0] - print(f"Processing file: {file_name}") - print(f"Image array shape: {image_array.shape}") - print(f"Image array dtype: {image_array.dtype}") - - if dimensions is None or force_dimension_dialog: - if image_array.ndim > 2: - # Prefer the loader's metadata-derived hint (e.g. ImageJ - # TIFF axes='TZCYX'). Fall back to a hand-crafted default - # that covers ndim 3..6 so a user clicking OK without - # tweaking the combos gets a sensible result. The earlier - # `default_dimensions[-ndim:]` slice silently degraded for - # ndim≥5: one axis ended up unset and inherited the combo - # box's first item ("T"), producing 2560 wrong slices for - # a 5D TZCYX file. - if axes_hint and len(axes_hint) == image_array.ndim: - default_dimensions = list(axes_hint) - print(f"Applying axes hint as default dims: {default_dimensions}") - else: - if axes_hint and len(axes_hint) != image_array.ndim: - print( - f"Ignoring axes hint (length {len(axes_hint)} " - f"vs ndim {image_array.ndim})" - ) - ndim_defaults = { - 3: ["Z", "H", "W"], - 4: ["T", "Z", "H", "W"], - 5: ["T", "Z", "C", "H", "W"], - 6: ["T", "Z", "C", "S", "H", "W"], - } - # ndim ≥ 7 falls into the generic case: pad with - # "T" at the front so H / W are still the last two - # axes — that way "click OK" still produces a - # sensible 2D slice even on exotic inputs. - default_dimensions = ndim_defaults.get( - image_array.ndim, - ["T"] * max(0, image_array.ndim - 2) + ["H", "W"], - ) - - # Show a progress dialog - progress = QProgressDialog( - "Assigning dimensions...", "Cancel", 0, 100, self - ) - progress.setWindowModality(Qt.WindowModality.WindowModal) - progress.setMinimumDuration(0) - progress.setValue(10) - QApplication.processEvents() - - while True: - dialog = DimensionDialog( - image_array.shape, file_name, self, default_dimensions - ) - # Qt6 no longer shows the "?" help button by default; - # the old WindowContextHelpButtonHint clear is gone. - progress.setValue(50) - QApplication.processEvents() - if dialog.exec(): - dimensions = dialog.get_dimensions() - print(f"Assigned dimensions: {dimensions}") - if "H" in dimensions and "W" in dimensions: - self.image_dimensions[base_name] = dimensions - break - else: - QMessageBox.warning( - self, - "Invalid Dimensions", - "You must assign both H and W dimensions.", - ) - else: - progress.close() - return - progress.setValue(100) - progress.close() - else: - dimensions = ["H", "W"] - self.image_dimensions[base_name] = dimensions + return self.image_controller.process_multidimensional_image( + image_array, image_path, dimensions, force_dimension_dialog, axes_hint=axes_hint + ) - self.image_shapes[base_name] = image_array.shape - print(f"Final assigned dimensions: {self.image_dimensions[base_name]}") - print(f"Image shape: {self.image_shapes[base_name]}") + def create_slices(self, image_array, dimensions, image_path): + return self.image_controller.create_slices(image_array, dimensions, image_path) - if self.image_dimensions[base_name]: - self.create_slices( - image_array, self.image_dimensions[base_name], image_path - ) - else: - rgb_image = self.convert_to_8bit_rgb(image_array) - self.current_image = self.array_to_qimage(rgb_image) - self.slices = [] - self.slice_list.clear() - - if self.slices: - self.current_image = self.slices[0][1] - self.current_slice = self.slices[0][0] - self.slice_list.setCurrentRow(0) - self.load_image_annotations() - self.image_label.update() + def add_slice_to_list(self, slice_name): + return self.image_controller.add_slice_to_list(slice_name) - self.update_image_info() - - # Update UI - self.update_slice_list() - self.update_annotation_list() - self.image_label.update() - - def create_slices(self, image_array, dimensions, image_path): - base_name = os.path.splitext(os.path.basename(image_path))[0] - slices = [] - self.slice_list.clear() - - print(f"Creating slices for {base_name}") - print(f"Dimensions: {dimensions}") - print(f"Image array shape: {image_array.shape}") - - # Create and show progress dialog - progress = QProgressDialog("Loading slices...", "Cancel", 0, 100, self) - progress.setWindowModality(Qt.WindowModality.WindowModal) - progress.setMinimumDuration(0) # Show immediately - - # Handle 2D images - if image_array.ndim == 2: - progress.setValue(50) # Update progress - QApplication.processEvents() # Allow GUI to update - normalized_array = self.normalize_array(image_array) - qimage = self.array_to_qimage(normalized_array) - slice_name = f"{base_name}" - slices.append((slice_name, qimage)) - self.add_slice_to_list(slice_name) - else: - # For 3D or higher dimensional arrays - slice_indices = [ - i for i, dim in enumerate(dimensions) if dim not in ["H", "W"] - ] - - total_slices = np.prod([image_array.shape[i] for i in slice_indices]) - for idx, _ in enumerate( - np.ndindex(tuple(image_array.shape[i] for i in slice_indices)) - ): - if progress.wasCanceled(): - break - - full_idx = [slice(None)] * len(dimensions) - for i, val in zip(slice_indices, _): - full_idx[i] = val - - slice_array = image_array[tuple(full_idx)] - rgb_slice = self.convert_to_8bit_rgb(slice_array) - qimage = self.array_to_qimage(rgb_slice) - - slice_name = f"{base_name}_{'_'.join([f'{dimensions[i]}{val+1}' for i, val in zip(slice_indices, _)])}" - slices.append((slice_name, qimage)) - - self.add_slice_to_list(slice_name) - - # Update progress - progress_value = int((idx + 1) / total_slices * 100) - progress.setValue(progress_value) - QApplication.processEvents() # Allow GUI to update - - progress.setValue(100) # Ensure progress reaches 100% - - self.image_slices[base_name] = slices - self.slices = slices - - if slices: - self.current_image = slices[0][1] - self.current_slice = slices[0][0] - self.slice_list.setCurrentRow(0) - - self.activate_slice(self.current_slice) - - slice_info = f"Total slices: {len(slices)}" - for dim, size in zip(dimensions, image_array.shape): - if dim not in ["H", "W"]: - slice_info += f", {dim}: {size}" - self.update_image_info(additional_info=slice_info) - else: - print("No slices were created") - - print(f"Created {len(slices)} slices for {base_name}") - return slices - - def add_slice_to_list(self, slice_name): - item = QListWidgetItem(slice_name) - - if self.dark_mode: - # Dark mode - item.setBackground( - QColor(40, 40, 40) - ) # Very dark gray background for all items - if slice_name in self.all_annotations: - # Muted steel-blue + light text; the prior light-blue - # (173, 216, 230) bg + dark-gray text was painfully - # bright on a dark sidebar. - item.setForeground(QColor(235, 235, 235)) - item.setBackground(QColor(58, 95, 140)) - else: - item.setForeground(QColor(200, 200, 200)) # Light gray text - else: - # Light mode - item.setBackground( - QColor(240, 240, 240) - ) # Very light gray background for all items - if slice_name in self.all_annotations: - item.setForeground(QColor(255, 255, 255)) # White text - item.setBackground(QColor(70, 130, 180)) # Medium-dark blue background - else: - item.setForeground(QColor(0, 0, 0)) # Black text - - self.slice_list.addItem(item) - - def normalize_array(self, array): - # print(f"Normalizing array. Shape: {array.shape}, dtype: {array.dtype}") - # print(f"Array min: {array.min()}, max: {array.max()}, mean: {array.mean()}") - - array_float = array.astype(np.float32) - - if array.dtype == np.uint16: - array_normalized = (array_float - array.min()) / (array.max() - array.min()) - elif array.dtype == np.uint8: - # For 8-bit images, use a simple contrast stretching - p_low, p_high = np.percentile( - array_float, (0, 100) - ) # Change these to 1, 99 or something to stretch the contrast for visualizing 8 bit images - array_normalized = np.clip(array_float, p_low, p_high) - array_normalized = (array_normalized - p_low) / (p_high - p_low) - else: - array_normalized = (array_float - array.min()) / (array.max() - array.min()) - - # Apply gamma correction - gamma = 1.0 # Adjust this value to fine-tune brightness (> 1 for darker, < 1 for brighter) - array_normalized = np.power(array_normalized, gamma) - - return (array_normalized * 255).astype(np.uint8) + def normalize_array(self, array): + return image_utils.normalize_array(array) def adjust_contrast(self, image, low_percentile=1, high_percentile=99): - if image.dtype != np.uint8: - p_low, p_high = np.percentile(image, (low_percentile, high_percentile)) - image_adjusted = np.clip(image, p_low, p_high) - image_adjusted = (image_adjusted - p_low) / (p_high - p_low) - return (image_adjusted * 255).astype(np.uint8) - return image + return image_utils.adjust_contrast(image, low_percentile, high_percentile) def activate_slice(self, slice_name): - self.current_slice = slice_name - self.image_file_name = slice_name - self.load_image_annotations() - self.update_annotation_list() - - for name, qimage in self.slices: - if name == slice_name: - self.current_image = qimage - self.display_image() - break - - self.image_label.update() - - items = self.slice_list.findItems(slice_name, Qt.MatchFlag.MatchExactly) - if items: - self.slice_list.setCurrentItem(items[0]) + return self.image_controller.activate_slice(slice_name) def array_to_qimage(self, array): - if array.ndim == 2: - height, width = array.shape - return QImage(array.data, width, height, width, QImage.Format.Format_Grayscale8) - elif array.ndim == 3 and array.shape[2] == 3: - height, width, _ = array.shape - bytes_per_line = 3 * width - return QImage( - array.data, width, height, bytes_per_line, QImage.Format.Format_RGB888 - ) - else: - raise ValueError( - f"Unsupported array shape {array.shape} for conversion to QImage" - ) + return image_utils.array_to_qimage(array) def update_slice_list(self): - self.slice_list.clear() - for slice_name, _ in self.slices: - item = QListWidgetItem(slice_name) - if slice_name in self.all_annotations: - item.setForeground(QColor(Qt.GlobalColor.green)) - else: - item.setForeground( - QColor(Qt.GlobalColor.black) if not self.dark_mode else QColor(Qt.GlobalColor.white) - ) - self.slice_list.addItem(item) - - # Select the current slice - if self.current_slice: - items = self.slice_list.findItems(self.current_slice, Qt.MatchFlag.MatchExactly) - if items: - self.slice_list.setCurrentItem(items[0]) + return self.image_controller.update_slice_list() def clear_slice_list(self): - self.slice_list.clear() - self.slices = [] - self.current_slice = None + return self.image_controller.clear_slice_list() def reset_tool_buttons(self): for button in self.tool_group.buttons(): @@ -1998,11 +483,7 @@ def keyPressEvent(self, event): super().keyPressEvent(event) def has_visible_temp_classes(self): - for i in range(self.class_list.count()): - item = self.class_list.item(i) - if item.text().startswith("Temp-") and item.checkState() == Qt.CheckState.Checked: - return True - return False + return self.dino_controller.has_visible_temp_classes() def launch_snake_game(self): # print("Launching Snake game") @@ -2012,650 +493,43 @@ def launch_snake_game(self): self.snake_game.setFocus() def import_annotations(self): - if not self.image_label.check_unsaved_changes(): - return - print("Starting import_annotations") - import_format = self.import_format_selector.currentText() - print(f"Import format: {import_format}") - - if import_format == "COCO JSON": - file_name, _ = QFileDialog.getOpenFileName( - self, "Import COCO JSON Annotations", "", "JSON Files (*.json)" - ) - if not file_name: - print("No file selected, returning") - return - - print(f"Selected file: {file_name}") - json_dir = os.path.dirname(file_name) - images_dir = os.path.join(json_dir, "images") - imported_annotations, image_info = import_coco_json( - file_name, self.class_mapping - ) - - elif import_format in ["YOLO (v4 and earlier)", "YOLO (v5+)"]: - yaml_file, _ = QFileDialog.getOpenFileName( - self, "Select YOLO Dataset YAML", "", "YAML Files (*.yaml *.yml)" - ) - if not yaml_file: - print("No YAML file selected, returning") - return - - print(f"Selected YAML file: {yaml_file}") - try: - imported_annotations, image_info = process_import_format( - import_format, yaml_file, self.class_mapping - ) - yaml_dir = os.path.dirname(yaml_file) - if import_format == "YOLO (v4 and earlier)": - images_dir = os.path.join(yaml_dir, "train", "images") - else: # YOLO (v5+) - images_dir = os.path.join( - yaml_dir, "images", "train" - ) # Preferring train over val - except ValueError as e: - QMessageBox.warning(self, "Import Error", str(e)) - return - - else: - QMessageBox.warning( - self, - "Unsupported Format", - f"The selected format '{import_format}' is not implemented for import.", - ) - return - - print( - f"JSON/YOLO directory: {json_dir if import_format == 'COCO JSON' else os.path.dirname(yaml_file)}" - ) - print(f"Images directory: {images_dir}") - print(f"Imported annotations count: {len(imported_annotations)}") - print(f"Image info count: {len(image_info)}") - - images_loaded = 0 - images_not_found = [] - - for info in image_info.values(): - print(f"Processing image: {info['file_name']}") - image_path = os.path.join(images_dir, info["file_name"]) - - if os.path.exists(image_path): - print(f"Image found at: {image_path}") - self.image_paths[info["file_name"]] = image_path - self.all_images.append( - { - "file_name": info["file_name"], - "height": info["height"], - "width": info["width"], - "id": info["id"], - "is_multi_slice": False, - } - ) - images_loaded += 1 - else: - print(f"Image not found at: {image_path}") - images_not_found.append(info["file_name"]) - - print(f"Images loaded: {images_loaded}") - print(f"Images not found: {len(images_not_found)}") - - if images_not_found: - message = f"The following {len(images_not_found)} images were not found in the 'images' directory:\n\n" - message += "\n".join(images_not_found[:10]) - if len(images_not_found) > 10: - message += f"\n... and {len(images_not_found) - 10} more." - message += "\n\nDo you want to proceed and ignore annotations for these missing images?" - reply = QMessageBox.question( - self, - "Missing Images", - message, - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, - QMessageBox.StandardButton.No, - ) - - if reply == QMessageBox.StandardButton.No: - print("Import cancelled due to missing images") - QMessageBox.information( - self, - "Import Cancelled", - "Import cancelled. Please ensure all images are in the 'images' directory and try again.", - ) - return - - # Update annotations (only for found images) - for image_name, annotations in imported_annotations.items(): - if image_name not in self.image_paths: - continue - self.all_annotations[image_name] = {} - for category_name, category_annotations in annotations.items(): - self.all_annotations[image_name][category_name] = [] - for i, ann in enumerate(category_annotations, start=1): - new_ann = { - "segmentation": ann.get("segmentation"), - "bbox": ann.get("bbox"), - "category_id": ann["category_id"], - "category_name": category_name, - "number": i, - "type": ann.get("type", "polygon"), - } - self.all_annotations[image_name][category_name].append(new_ann) - - # Update class mapping and colors - for annotations in self.all_annotations.values(): - for category_name in annotations.keys(): - if category_name not in self.class_mapping: - new_id = len(self.class_mapping) + 1 - self.class_mapping[category_name] = new_id - self.image_label.class_colors[category_name] = QColor( - Qt.GlobalColor(new_id % 16 + 7) - ) - - print("Updating UI") - # Update UI - self.update_class_list() - self.update_image_list() - self.update_annotation_list() - - # Highlight and display the first image - if self.image_list.count() > 0: - self.image_list.setCurrentRow(0) - self.switch_image(self.image_list.item(0)) - - # Select the first class if available - if self.class_list.count() > 0: - self.class_list.setCurrentRow(0) - self.on_class_selected() - - self.image_label.update() - - message = f"Annotations have been imported successfully from {file_name if import_format == 'COCO JSON' else yaml_file}.\n" - message += f"{images_loaded} images were loaded from the 'images' directory.\n" - if images_not_found: - message += ( - f"Annotations for {len(images_not_found)} missing images were ignored." - ) - - print("Import complete, showing message") - QMessageBox.information(self, "Import Complete", message) - self.auto_save() # Auto-save after importing annotations + return io_controller.import_annotations(self) def export_annotations(self): - if not self.image_label.check_unsaved_changes(): - return - export_format = self.export_format_selector.currentText() - - supported_formats = [ - "COCO JSON", - "YOLO (v4 and earlier)", - "YOLO (v5+)", - "Labeled Images", - "Semantic Labels", - "Pascal VOC (BBox)", - "Pascal VOC (BBox + Segmentation)", - ] - - if export_format not in supported_formats: - QMessageBox.warning( - self, - "Unsupported Format", - f"The selected format '{export_format}' is not implemented.", - ) - return - - if export_format == "COCO JSON": - file_name, _ = QFileDialog.getSaveFileName( - self, "Export COCO JSON Annotations", "", "JSON Files (*.json)" - ) - else: - file_name = QFileDialog.getExistingDirectory( - self, f"Select Output Directory for {export_format} Export" - ) - - if not file_name: - return - - self.save_current_annotations() - - if export_format == "COCO JSON": - output_dir = os.path.dirname(file_name) - json_filename = os.path.basename(file_name) - json_file, images_dir = export_coco_json( - self.all_annotations, - self.class_mapping, - self.image_paths, - self.slices, - self.image_slices, - output_dir, - json_filename, - ) - message = ( - "Annotations have been exported successfully in COCO JSON format.\n" - ) - message += f"JSON file: {json_file}\nImages directory: {images_dir}" - - elif export_format == "YOLO (v4 and earlier)": - labels_dir, yaml_path = export_yolo_v4( - self.all_annotations, - self.class_mapping, - self.image_paths, - self.slices, - self.image_slices, - file_name, - ) - message = "Annotations have been exported successfully in YOLO (v4 and earlier) format.\n" - message += f"Labels: {labels_dir}\nYAML: {yaml_path}" - - elif export_format == "YOLO (v5+)": - output_dir, yaml_path = export_yolo_v5plus( - self.all_annotations, - self.class_mapping, - self.image_paths, - self.slices, - self.image_slices, - file_name, - ) - message = ( - "Annotations have been exported successfully in YOLO (v5+) format.\n" - ) - message += f"Output directory: {output_dir}\nYAML: {yaml_path}" - - elif export_format == "Labeled Images": - labeled_images_dir = export_labeled_images( - self.all_annotations, - self.class_mapping, - self.image_paths, - self.slices, - self.image_slices, - file_name, - ) - message = f"Labeled images have been exported successfully.\nLabeled Images: {labeled_images_dir}\n" - message += f"A class summary has been saved in: {os.path.join(labeled_images_dir, 'class_summary.txt')}" - - elif export_format == "Semantic Labels": - semantic_labels_dir = export_semantic_labels( - self.all_annotations, - self.class_mapping, - self.image_paths, - self.slices, - self.image_slices, - file_name, - ) - message = f"Semantic labels have been exported successfully.\nSemantic Labels: {semantic_labels_dir}\n" - message += f"A class-pixel mapping has been saved in: {os.path.join(semantic_labels_dir, 'class_pixel_mapping.txt')}" - - elif export_format == "Pascal VOC (BBox)": - voc_dir = export_pascal_voc_bbox( - self.all_annotations, - self.class_mapping, - self.image_paths, - self.slices, - self.image_slices, - file_name, - ) - message = "Annotations have been exported successfully in Pascal VOC format (BBox only).\n" - message += f"Pascal VOC Annotations: {voc_dir}" - - elif export_format == "Pascal VOC (BBox + Segmentation)": - voc_dir = export_pascal_voc_both( - self.all_annotations, - self.class_mapping, - self.image_paths, - self.slices, - self.image_slices, - file_name, - ) - message = "Annotations have been exported successfully in Pascal VOC format (BBox + Segmentation).\n" - message += f"Pascal VOC Annotations: {voc_dir}" - - QMessageBox.information(self, "Export Complete", message) + return io_controller.export_annotations(self) def save_slices(self, directory): - slices_saved = False - for image_file, image_slices in self.image_slices.items(): - for slice_name, qimage in image_slices: - if ( - slice_name in self.all_annotations - and self.all_annotations[slice_name] - ): - file_path = os.path.join(directory, f"{slice_name}.png") - qimage.save(file_path, "PNG") - slices_saved = True - - return slices_saved + return io_controller.save_slices(self, directory) def create_coco_annotation(self, ann, image_id, annotation_id): - coco_ann = { - "id": annotation_id, - "image_id": image_id, - "category_id": ann["category_id"], - "area": calculate_area(ann), - "iscrowd": 0, - } - - if "segmentation" in ann: - coco_ann["segmentation"] = [ann["segmentation"]] - coco_ann["bbox"] = calculate_bbox(ann["segmentation"]) - elif "bbox" in ann: - coco_ann["bbox"] = ann["bbox"] - - return coco_ann + return self.annotation_controller.create_coco_annotation(ann, image_id, annotation_id) def update_all_annotation_lists(self): - for image_name in self.all_annotations.keys(): - self.update_annotation_list(image_name) - self.update_annotation_list() # Update for the current image/slice + return self.annotation_controller.update_all_annotation_lists() def update_annotation_list(self, image_name=None): - self.annotation_list.clear() - current_name = image_name or self.current_slice or self.image_file_name - annotations = self.all_annotations.get(current_name, {}) - for class_name, class_annotations in annotations.items(): - if not class_name.startswith( - "Temp-" - ): # Only show non-temporary annotations - color = self.image_label.class_colors.get(class_name, QColor(Qt.GlobalColor.white)) - for annotation in class_annotations: - number = annotation.get("number", 0) - area = calculate_area(annotation) - item_text = f"{class_name} - {number:<3} Area: {area:.2f}" - item = QListWidgetItem(item_text) - item.setData(Qt.ItemDataRole.UserRole, annotation) - item.setForeground(color) - self.annotation_list.addItem(item) - - # Force the annotation list to repaint - self.annotation_list.repaint() + return self.annotation_controller.update_annotation_list(image_name) def update_slice_list_colors(self): - # Set the background color of the entire list widget - if self.dark_mode: - self.slice_list.setStyleSheet( - "QListWidget { background-color: rgb(40, 40, 40); }" - ) - else: - self.slice_list.setStyleSheet( - "QListWidget { background-color: rgb(240, 240, 240); }" - ) - - for i in range(self.slice_list.count()): - item = self.slice_list.item(i) - slice_name = item.text() - - if self.dark_mode: - # Dark mode (annotated colors match add_slice_to_list — - # muted steel-blue, light text; not the prior glaring - # light-blue bg) - if slice_name in self.all_annotations and any( - self.all_annotations[slice_name].values() - ): - item.setForeground(QColor(235, 235, 235)) - item.setBackground(QColor(58, 95, 140)) - else: - item.setForeground(QColor(200, 200, 200)) # Light gray text - item.setBackground(QColor(40, 40, 40)) # Very dark gray background - else: - # Light mode - if slice_name in self.all_annotations and any( - self.all_annotations[slice_name].values() - ): - item.setForeground(QColor(255, 255, 255)) # White text - item.setBackground( - QColor(70, 130, 180) - ) # Medium-dark blue background - else: - item.setForeground(QColor(0, 0, 0)) # Black text - item.setBackground( - QColor(240, 240, 240) - ) # Very light gray background - - # Force the list to repaint - self.slice_list.repaint() + return self.class_controller.update_slice_list_colors() def update_annotation_list_colors(self, class_name=None, color=None): - for i in range(self.annotation_list.count()): - item = self.annotation_list.item(i) - annotation = item.data(Qt.ItemDataRole.UserRole) - # Update only the item for the specific class if class_name is provided - if class_name is None or annotation["category_name"] == class_name: - item_color = ( - color - if class_name - else self.image_label.class_colors.get( - annotation["category_name"], QColor(Qt.GlobalColor.white) - ) - ) - item.setForeground(item_color) + return self.annotation_controller.update_annotation_list_colors(class_name, color) def load_image_annotations(self): - # print(f"Loading annotations for: {self.current_slice or self.image_file_name}") - self.image_label.annotations.clear() - current_name = self.current_slice or self.image_file_name - # print(f"Current name for annotations: {current_name}") - # print(f"All annotations keys: {list(self.all_annotations.keys())}") - if current_name in self.all_annotations: - self.image_label.annotations = copy.deepcopy( - self.all_annotations[current_name] - ) - # print(f"Loaded annotations: {self.image_label.annotations}") - else: - print(f"No annotations found for {current_name}") - self.image_label.update() + return self.annotation_controller.load_image_annotations() def save_current_annotations(self): - if self.current_slice: - current_name = self.current_slice - elif self.image_file_name: - current_name = self.image_file_name - else: - # print("Error: No current slice or image file name set") - return + return self.annotation_controller.save_current_annotations() - # print(f"Saving annotations for: {current_name}") - if self.image_label.annotations: - self.all_annotations[current_name] = self.image_label.annotations.copy() - # print(f"Saved {len(self.image_label.annotations)} annotations for {current_name}") - elif current_name in self.all_annotations: - del self.all_annotations[current_name] - # print(f"Removed annotations for {current_name}") - - self.update_slice_list_colors() - - # print(f"All annotations now: {self.all_annotations.keys()}") - # print(f"Current slice: {self.current_slice}") - # print(f"Current image_file_name: {self.image_file_name}") - - def setup_class_list(self): - """Set up the class list widget.""" - self.class_list = QListWidget() - self.class_list.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) - self.class_list.customContextMenuRequested.connect(self.show_class_context_menu) - self.class_list.itemClicked.connect(self.on_class_selected) - self.sidebar_layout.addWidget(QLabel("Classes:")) - self.sidebar_layout.addWidget(self.class_list) - - def setup_tool_buttons(self): - """Set up the tool buttons with grouped manual and automated tools.""" - self.tool_group = QButtonGroup(self) - self.tool_group.setExclusive(False) - - # Create a widget for manual tools - manual_tools_widget = QWidget() - manual_layout = QVBoxLayout(manual_tools_widget) - manual_layout.setSpacing(5) - - manual_label = QLabel("Manual Tools") - manual_label.setAlignment(Qt.AlignmentFlag.AlignCenter) - manual_layout.addWidget(manual_label) - - manual_buttons_layout = QHBoxLayout() - self.polygon_button = QPushButton("Polygon") - self.polygon_button.setCheckable(True) - self.rectangle_button = QPushButton("Rectangle") - self.rectangle_button.setCheckable(True) - manual_buttons_layout.addWidget(self.polygon_button) - manual_buttons_layout.addWidget(self.rectangle_button) - manual_layout.addLayout(manual_buttons_layout) - - self.tool_group.addButton(self.polygon_button) - self.tool_group.addButton(self.rectangle_button) - self.polygon_button.clicked.connect(self.toggle_tool) - self.rectangle_button.clicked.connect(self.toggle_tool) - - # Create a widget for automated tools - automated_tools_widget = QWidget() - automated_layout = QVBoxLayout(automated_tools_widget) - automated_layout.setSpacing(5) - - automated_label = QLabel("Automated Tools") - automated_label.setAlignment(Qt.AlignmentFlag.AlignCenter) - automated_layout.addWidget(automated_label) - - automated_buttons_layout = QHBoxLayout() - self.sam_magic_wand_button = QPushButton("Magic Wand") - self.sam_magic_wand_button.setCheckable(True) - automated_buttons_layout.addWidget(self.sam_magic_wand_button) - automated_layout.addLayout(automated_buttons_layout) - - self.tool_group.addButton(self.sam_magic_wand_button) - self.sam_magic_wand_button.clicked.connect(self.toggle_tool) - - # Add the grouped tools to the sidebar layout - self.sidebar_layout.addWidget(manual_tools_widget) - self.sidebar_layout.addWidget(automated_tools_widget) - - # Set a fixed size for all buttons to make them smaller - for button in [ - self.polygon_button, - self.rectangle_button, - self.load_sam2_button, - self.sam_magic_wand_button, - ]: - button.setFixedSize(100, 30) - - def setup_annotation_list(self): - """Set up the annotation list widget.""" - self.annotation_list = QListWidget() - self.annotation_list.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection) - self.annotation_list.itemSelectionChanged.connect( - self.update_highlighted_annotations - ) - - def create_menu_bar(self): - menu_bar = self.menuBar() - - # Project Menu - project_menu = menu_bar.addMenu("&Project") - - new_project_action = QAction("&New Project", self) - new_project_action.setShortcut(QKeySequence.StandardKey.New) - new_project_action.triggered.connect(self.new_project) - project_menu.addAction(new_project_action) - - open_project_action = QAction("&Open Project", self) - open_project_action.setShortcut(QKeySequence.StandardKey.Open) - open_project_action.triggered.connect(self.open_project) - project_menu.addAction(open_project_action) - - save_project_action = QAction("&Save Project", self) - save_project_action.setShortcut(QKeySequence.StandardKey.Save) - save_project_action.triggered.connect(self.save_project) - project_menu.addAction(save_project_action) - - save_project_as_action = QAction("Save Project &As...", self) - save_project_as_action.setShortcut(QKeySequence("Ctrl+Shift+S")) - save_project_as_action.triggered.connect(self.save_project_as) - project_menu.addAction(save_project_as_action) - - close_project_action = QAction("&Close Project", self) - close_project_action.setShortcut(QKeySequence("Ctrl+W")) - close_project_action.triggered.connect(self.close_project) - project_menu.addAction(close_project_action) - - project_details_action = QAction("Project &Details", self) - project_details_action.setShortcut(QKeySequence("Ctrl+I")) - project_details_action.triggered.connect(self.show_project_details) - project_menu.addAction(project_details_action) - - search_projects_action = QAction("&Search Projects", self) - search_projects_action.setShortcut(QKeySequence("Ctrl+F")) - search_projects_action.triggered.connect(self.show_project_search) - project_menu.addAction(search_projects_action) - - # Settings Menu - settings_menu = menu_bar.addMenu("&Settings") - - font_size_menu = settings_menu.addMenu("&Font Size") - for size in ["Small", "Medium", "Large", "XL", "XXL"]: - action = QAction(size, self) - action.triggered.connect(lambda checked, s=size: self.change_font_size(s)) - font_size_menu.addAction(action) - - toggle_dark_mode_action = QAction("Toggle &Dark Mode", self) - toggle_dark_mode_action.setShortcut(QKeySequence("Ctrl+D")) - toggle_dark_mode_action.triggered.connect(self.toggle_dark_mode) - settings_menu.addAction(toggle_dark_mode_action) - - # Tools Menu - tools_menu = menu_bar.addMenu("&Tools") - - annotation_stats_action = QAction("Annotation Statistics", self) - annotation_stats_action.triggered.connect(self.show_annotation_statistics) - annotation_stats_action.setShortcut(QKeySequence("Ctrl+Alt+S")) - tools_menu.addAction(annotation_stats_action) - - coco_json_combiner_action = QAction("COCO JSON Combiner", self) - coco_json_combiner_action.triggered.connect(self.show_coco_json_combiner) - tools_menu.addAction(coco_json_combiner_action) - - dataset_splitter_action = QAction("Dataset Splitter", self) - dataset_splitter_action.triggered.connect(self.open_dataset_splitter) - tools_menu.addAction(dataset_splitter_action) - - dino_merge_action = QAction("Merge COCO for Training", self) - dino_merge_action.triggered.connect(self.show_dino_merge_dialog) - tools_menu.addAction(dino_merge_action) - - stack_to_slices_action = QAction("Stack to Slices", self) - stack_to_slices_action.triggered.connect(self.show_stack_to_slices) - tools_menu.addAction(stack_to_slices_action) - - image_patcher_action = QAction("Image Patcher", self) - image_patcher_action.triggered.connect(self.show_image_patcher) - tools_menu.addAction(image_patcher_action) - - image_augmenter_action = QAction("Image Augmenter", self) - image_augmenter_action.triggered.connect(self.show_image_augmenter) - tools_menu.addAction(image_augmenter_action) - - slice_registration_action = QAction("Slice Registration", self) - slice_registration_action.triggered.connect(self.show_slice_registration) - tools_menu.addAction(slice_registration_action) - - stack_interpolator_action = QAction("Stack Interpolator", self) - stack_interpolator_action.triggered.connect(self.show_stack_interpolator) - tools_menu.addAction(stack_interpolator_action) - - dicom_converter_action = QAction("DICOM Converter", self) - dicom_converter_action.triggered.connect(self.show_dicom_converter) - tools_menu.addAction(dicom_converter_action) + def change_font_size(self, size): + theme.change_font_size(self, size) - tools_menu.addSeparator() - - unload_models_action = QAction("Unload AI Models (Free GPU Memory)", self) - unload_models_action.triggered.connect(self.unload_ai_models) - tools_menu.addAction(unload_models_action) - - # Help Menu - help_menu = menu_bar.addMenu("&Help") - - help_action = QAction("&Show Help", self) - help_action.setShortcut(QKeySequence.StandardKey.HelpContents) - help_action.triggered.connect(self.show_help) - help_menu.addAction(help_action) + def step_font_size(self, delta): + theme.step_font_pt(self, delta) - def change_font_size(self, size): - self.current_font_size = size - self.apply_theme_and_font() + def reset_font_size(self): + theme.reset_font_pt(self) def unload_ai_models(self): """Drop cached SAM/DINO model objects to free GPU/CPU memory. @@ -2685,1095 +559,87 @@ def unload_ai_models(self): "Re-select a SAM/DINO model to use AI tools again.", ) - def setup_sidebar(self): - self.sidebar = QWidget() - self.sidebar_layout = QVBoxLayout(self.sidebar) - self.layout.addWidget(self.sidebar, 1) - - # Helper function to create section headers - def create_section_header(text): - label = QLabel(text) - label.setProperty("class", "section-header") - label.setAlignment(Qt.AlignmentFlag.AlignLeft) - return label - - # Import functionality - self.import_button = QPushButton("Import Annotations with Images") - self.import_button.clicked.connect(self.import_annotations) - self.sidebar_layout.addWidget(self.import_button) - - self.import_format_selector = QComboBox() - self.import_format_selector.addItem("COCO JSON") - self.import_format_selector.addItem("YOLO (v4 and earlier)") # Modified name - self.import_format_selector.addItem("YOLO (v5+)") # New format - - self.sidebar_layout.addWidget(self.import_format_selector) - - # Add spacing - self.sidebar_layout.addSpacing(20) - - self.add_images_button = QPushButton("Add New Images") - self.add_images_button.clicked.connect(self.add_images) - self.sidebar_layout.addWidget(self.add_images_button) - - self.add_class_button = QPushButton("Add Classes") - self.add_class_button.clicked.connect(lambda: self.add_class()) - self.sidebar_layout.addWidget(self.add_class_button) - - # Class list (without the "Classes" header) - self.class_list = QListWidget() - self.class_list.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) - self.class_list.customContextMenuRequested.connect(self.show_class_context_menu) - self.class_list.itemClicked.connect(self.on_class_selected) - self.sidebar_layout.addWidget(self.class_list) - - # Annotation section - self.sidebar_layout.addWidget(create_section_header("Annotation")) - annotation_widget = QWidget() - annotation_layout = QVBoxLayout(annotation_widget) - - # Manual tools subsection - manual_widget = QWidget() - manual_layout = QVBoxLayout(manual_widget) - - button_layout_top = QHBoxLayout() - self.polygon_button = QPushButton("Polygon") - self.polygon_button.setCheckable(True) - self.rectangle_button = QPushButton("Rectangle") - self.rectangle_button.setCheckable(True) - button_layout_top.addWidget(self.polygon_button) - button_layout_top.addWidget(self.rectangle_button) - - button_layout_bottom = QHBoxLayout() - self.paint_brush_button = QPushButton("Paint Brush") - self.paint_brush_button.setCheckable(True) - self.eraser_button = QPushButton("Eraser") - self.eraser_button.setCheckable(True) - button_layout_bottom.addWidget(self.paint_brush_button) - button_layout_bottom.addWidget(self.eraser_button) - - manual_layout.addLayout(button_layout_top) - manual_layout.addLayout(button_layout_bottom) - - annotation_layout.addWidget(manual_widget) - - # SAM-Assisted tools subsection - sam_widget = QWidget() - sam_layout = QVBoxLayout(sam_widget) - - # --- Replace the old SAM-Assisted button block with this: --- - sam_buttons_layout = QHBoxLayout() - - self.sam_box_button = QPushButton("SAM-box") - self.sam_box_button.setCheckable(True) - self.sam_box_button.clicked.connect(self.toggle_sam_box) - - self.sam_points_button = QPushButton("SAM-points") - self.sam_points_button.setCheckable(True) - self.sam_points_button.clicked.connect(self.toggle_sam_points) - - sam_buttons_layout.addWidget(self.sam_box_button) - sam_buttons_layout.addWidget(self.sam_points_button) - sam_layout.addLayout(sam_buttons_layout) - # ------------------------------------------------------------ - - # Add SAM model selector - self.sam_model_selector = QComboBox() - self.sam_model_selector.addItem("Pick a SAM Model") - self.sam_model_selector.addItems(list(self.sam_utils.sam_models.keys())) - self.sam_model_selector.currentTextChanged.connect(self.change_sam_model) - sam_layout.addWidget(self.sam_model_selector) - - annotation_layout.addWidget(sam_widget) - - # --- LLM-Assisted Detection (DINO) subsection --- - dino_widget = QWidget() - dino_layout = QVBoxLayout(dino_widget) - - self.dino_model_selector = QComboBox() - self.dino_model_selector.addItem("Pick a DINO Model") - self.dino_model_selector.addItem("grounding-dino-base") - self.dino_model_selector.addItem("grounding-dino-tiny") - self.dino_model_selector.addItem("Custom / fine-tuned (browse)") - self.dino_model_selector.currentTextChanged.connect(self._on_dino_model_changed) - dino_layout.addWidget(self.dino_model_selector) - - # Custom model browse row (hidden by default) - self.dino_browse_row = QWidget() - dino_browse_layout = QHBoxLayout(self.dino_browse_row) - dino_browse_layout.setContentsMargins(0, 0, 0, 0) - self.lbl_dino_custom = QLabel("No path set") - self.lbl_dino_custom.setWordWrap(True) - self.lbl_dino_custom.setStyleSheet("font-size:10px;color:#555;") - btn_dino_browse = QPushButton("Browse") - btn_dino_browse.setFixedWidth(60) - btn_dino_browse.clicked.connect(self.browse_dino_model) - dino_browse_layout.addWidget(self.lbl_dino_custom, 1) - dino_browse_layout.addWidget(btn_dino_browse) - self.dino_browse_row.setVisible(False) - dino_layout.addWidget(self.dino_browse_row) - - self.lbl_dino_status = QLabel("No DINO model loaded") - self.lbl_dino_status.setWordWrap(True) - # No hardcoded background — let the active stylesheet (light or - # dark) provide it via QLabel rules. Hardcoded #f5f5f5 used to - # punch a bright rectangle into the dark sidebar. - self.lbl_dino_status.setStyleSheet( - "font-size:11px;padding:4px;border-radius:3px;" - "border:1px solid palette(mid);") - dino_layout.addWidget(self.lbl_dino_status) - - # Threshold table - self.dino_class_table = ClassThresholdTable() - self.dino_class_table.itemSelectionChanged.connect(self.on_dino_class_row_changed) - dino_layout.addWidget(self.dino_class_table) - - # Phrase editor - self.dino_phrase_panel = PhraseEditorPanel() - dino_layout.addWidget(self.dino_phrase_panel) - - # Detect buttons - det_btn_layout = QHBoxLayout() - self.btn_detect_single = QPushButton("Detect Current Image") - self.btn_detect_single.clicked.connect(self.run_dino_detection_single) - self.btn_detect_single.setEnabled(False) - det_btn_layout.addWidget(self.btn_detect_single) - - self.btn_detect_batch = QPushButton("Detect All Images") - self.btn_detect_batch.clicked.connect(self.run_dino_detection_batch) - self.btn_detect_batch.setEnabled(False) - det_btn_layout.addWidget(self.btn_detect_batch) - dino_layout.addLayout(det_btn_layout) - - # Batch mode - self.dino_batch_mode = QComboBox() - self.dino_batch_mode.addItem("Review before accepting") - self.dino_batch_mode.addItem("Auto-accept all detections") - dino_layout.addWidget(self.dino_batch_mode) - - annotation_layout.addWidget(dino_widget) - # --- END DINO section --- - - # Add tool group - self.tool_group = QButtonGroup(self) - self.tool_group.setExclusive(False) - self.tool_group.addButton(self.polygon_button) - self.tool_group.addButton(self.rectangle_button) - self.tool_group.addButton(self.paint_brush_button) - self.tool_group.addButton(self.eraser_button) - self.tool_group.addButton(self.sam_box_button) - self.tool_group.addButton(self.sam_points_button) - - self.polygon_button.clicked.connect(self.toggle_tool) - self.rectangle_button.clicked.connect(self.toggle_tool) - self.paint_brush_button.clicked.connect(self.toggle_tool) - self.eraser_button.clicked.connect(self.toggle_tool) - self.sam_magic_wand_button.clicked.connect(self.toggle_tool) - - # Annotations list subsection - annotation_layout.addWidget(QLabel("Annotations")) - self.annotation_list = QListWidget() - self.annotation_list.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection) - self.annotation_list.itemSelectionChanged.connect( - self.update_highlighted_annotations - ) - annotation_layout.addWidget(self.annotation_list) - - # Create a horizontal layout for the sort buttons - sort_button_layout = QHBoxLayout() - - self.sort_by_class_button = QPushButton("Sort by Class") - self.sort_by_class_button.clicked.connect(self.sort_annotations_by_class) - sort_button_layout.addWidget(self.sort_by_class_button) - - self.sort_by_area_button = QPushButton("Sort by Area") - self.sort_by_area_button.clicked.connect(self.sort_annotations_by_area) - sort_button_layout.addWidget(self.sort_by_area_button) - - # Add the sort button layout to the annotation layout - annotation_layout.addLayout(sort_button_layout) - - # Delete and Merge annotation buttons - self.delete_button = QPushButton("Delete") - self.delete_button.clicked.connect(self.delete_selected_annotations) - self.merge_button = QPushButton("Merge") - self.merge_button.clicked.connect(self.merge_annotations) - self.change_class_button = QPushButton("Change Class") - self.change_class_button.clicked.connect(self.change_annotation_class) - - # Create a horizontal layout for the other buttons - button_layout = QHBoxLayout() - button_layout.addWidget(self.delete_button) - button_layout.addWidget(self.merge_button) - button_layout.addWidget(self.change_class_button) - - # Add the button layout to the annotation layout - annotation_layout.addLayout(button_layout) - - # Add export format selector - self.export_format_selector = QComboBox() - self.export_format_selector.addItem("COCO JSON") - self.export_format_selector.addItem("YOLO (v4 and earlier)") # Modified name - self.export_format_selector.addItem("YOLO (v5+)") # New format - self.export_format_selector.addItem("Labeled Images") - self.export_format_selector.addItem("Semantic Labels") - self.export_format_selector.addItem("Pascal VOC (BBox)") - self.export_format_selector.addItem("Pascal VOC (BBox + Segmentation)") - - annotation_layout.addWidget(QLabel("Export Format:")) - annotation_layout.addWidget(self.export_format_selector) - - self.export_button = QPushButton("Export Annotations") - self.export_button.clicked.connect(self.export_annotations) - annotation_layout.addWidget(self.export_button) - - # Add the annotation widget to the sidebar - self.sidebar_layout.addWidget(annotation_widget) - def toggle_sam_box(self): - if self.sam_box_button.isChecked(): - self.sam_points_button.setChecked(False) - self.image_label.current_tool = "sam_box" - self.image_label.sam_box_active = True - self.image_label.sam_points_active = False - self.image_label.setCursor(Qt.CursorShape.CrossCursor) - else: - self.image_label.current_tool = None - self.image_label.sam_box_active = False - self.image_label.setCursor(Qt.CursorShape.ArrowCursor) - self.update_ui_for_current_tool() + return self.sam_controller.toggle_sam_box() def toggle_sam_points(self): - if self.sam_points_button.isChecked(): - self.sam_box_button.setChecked(False) - self.image_label.current_tool = "sam_points" - self.image_label.sam_points_active = True - self.image_label.sam_box_active = False - self.image_label.setCursor(Qt.CursorShape.CrossCursor) - self.image_label.sam_positive_points = [] - self.image_label.sam_negative_points = [] - else: - self.sam_inference_timer.stop() - self.image_label.current_tool = None - self.image_label.sam_points_active = False - self.image_label.setCursor(Qt.CursorShape.ArrowCursor) - self.image_label.sam_positive_points = [] - self.image_label.sam_negative_points = [] - self.update_ui_for_current_tool() + return self.sam_controller.toggle_sam_points() def sort_annotations_by_class(self): - current_name = self.current_slice or self.image_file_name - if current_name not in self.all_annotations: - QMessageBox.information( - self, - "No Annotations", - "There are no annotations to sort for this image.", - ) - return - - annotations = self.all_annotations[current_name] - sorted_annotations = [] - for class_name in sorted(annotations.keys()): - if not class_name.startswith("Temp-"): # Skip temporary classes - class_annotations = sorted( - annotations[class_name], key=lambda x: x.get("number", 0) - ) - sorted_annotations.extend(class_annotations) - - self.update_annotation_list_with_sorted(sorted_annotations) + return self.annotation_controller.sort_annotations_by_class() def sort_annotations_by_area(self): - current_name = self.current_slice or self.image_file_name - if current_name not in self.all_annotations: - QMessageBox.information( - self, - "No Annotations", - "There are no annotations to sort for this image.", - ) - return - - annotations = self.all_annotations[current_name] - sorted_annotations = [] - for class_name in annotations.keys(): - if not class_name.startswith("Temp-"): # Skip temporary classes - class_annotations = sorted( - annotations[class_name], - key=lambda x: calculate_area(x), - reverse=True, - ) - sorted_annotations.extend(class_annotations) - - self.update_annotation_list_with_sorted(sorted_annotations) + return self.annotation_controller.sort_annotations_by_area() def update_annotation_list_with_sorted(self, sorted_annotations): - self.annotation_list.clear() - for annotation in sorted_annotations: - class_name = annotation["category_name"] - if not class_name.startswith("Temp-"): # Only add non-temporary annotations - number = annotation.get("number", 0) - area = calculate_area(annotation) - item_text = f"{class_name} - {number:<3} Area: {area:.2f}" - item = QListWidgetItem(item_text) - item.setData(Qt.ItemDataRole.UserRole, annotation) - color = self.image_label.class_colors.get(class_name, QColor(Qt.GlobalColor.white)) - item.setForeground(color) - self.annotation_list.addItem(item) - - self.image_label.update() + return self.annotation_controller.update_annotation_list_with_sorted(sorted_annotations) def change_sam_model(self, model_name): - try: - self.sam_utils.change_sam_model(model_name) - except Exception as e: - QMessageBox.critical( - self, - "SAM Model Error", - f"Failed to load SAM model '{model_name}':\n\n{str(e)}\n\n" - "Check that the model weights are downloadable and that torch " - "is correctly installed for your platform / GPU." - ) - self.sam_model_selector.setCurrentIndex(0) - return - - self.current_sam_model = self.sam_utils.current_sam_model - - if model_name != "Pick a SAM Model": - # Enable the SAM Magic Wand button - self.sam_magic_wand_button.setEnabled(True) - - # Activate the SAM Magic Wand tool - self.sam_magic_wand_button.setChecked(True) - self.activate_sam_magic_wand() - - print(f"Changed SAM model to: {model_name}") - else: - # Disable and deactivate the SAM Magic Wand button - self.sam_magic_wand_button.setEnabled(False) - self.sam_magic_wand_button.setChecked(False) - self.deactivate_sam_magic_wand() - print("SAM model unset") + return self.sam_controller.change_sam_model(model_name) # --- DINO / LLM-Assisted Detection Methods --- def _resolve_dino_model_path(self, model_name: str) -> str | None: """Return the canonical local path for a preset DINO model, or None if unknown.""" - from .dino_utils import GDINO_MODEL_PATHS - # GDINO_MODEL_PATHS now returns absolute paths from models_base_dir(). - return GDINO_MODEL_PATHS.get(model_name) + return self.dino_controller._resolve_dino_model_path(model_name) def _on_dino_model_changed(self, text): - """Selection → ready state. Downloads happen lazily on first Detect.""" - self.dino_browse_row.setVisible(text == "Custom / fine-tuned (browse)") + return self.dino_controller._on_dino_model_changed(text) - if text == "Pick a DINO Model": - self.dino_model_loaded = False - self.lbl_dino_status.setText("No DINO model loaded") - self.btn_detect_single.setEnabled(False) - self.btn_detect_batch.setEnabled(False) - return - - if text == "Custom / fine-tuned (browse)": - if self.dino_custom_model_path and os.path.exists(self.dino_custom_model_path): - self.dino_model_loaded = True - self.lbl_dino_status.setText( - f"Ready: {os.path.basename(self.dino_custom_model_path)}" - ) - self.btn_detect_single.setEnabled(True) - self.btn_detect_batch.setEnabled(True) - else: - self.dino_model_loaded = False - self.lbl_dino_status.setText("Browse for a custom model folder") - self.btn_detect_single.setEnabled(False) - self.btn_detect_batch.setEnabled(False) - return - - # Standard preset (grounding-dino-base/tiny) - self.dino_model_loaded = True - self.btn_detect_single.setEnabled(True) - self.btn_detect_batch.setEnabled(True) - model_path = self._resolve_dino_model_path(text) - if model_path and os.path.exists(model_path): - self.lbl_dino_status.setText(f"Ready: {text}") - else: - self.lbl_dino_status.setText(f"{text} — will download on first detection") - - def _ensure_dino_model_downloaded(self, model_name: str) -> bool: - """If the preset model isn't on disk yet, download it. Returns success.""" - if model_name in ("Pick a DINO Model", "Custom / fine-tuned (browse)"): - return True # Custom path is validated elsewhere; no download for it. - model_path = self._resolve_dino_model_path(model_name) - if model_path and os.path.exists(model_path): - return True - - # huggingface_hub is the only way to fetch the weights. Surface the - # actionable install hint if it's missing rather than the generic - # "Could not download" message. - try: - import huggingface_hub # noqa: F401 - except ImportError: - QMessageBox.critical( - self, "Missing Dependency", - f"Cannot download {model_name}: the huggingface_hub package " - "is not installed.\n\nRun:\n pip install huggingface_hub", - ) - return False - - self.lbl_dino_status.setText(f"Downloading {model_name}...") - QApplication.processEvents() - try: - downloaded = self.dino_utils.download_model(model_name) - except Exception as e: - QMessageBox.critical(self, "Download Failed", f"{model_name}:\n{e}") - return False - if not downloaded: - QMessageBox.critical( - self, "Download Failed", - f"Could not download {model_name} from Hugging Face Hub.", - ) - return False - return True + def _ensure_dino_model_downloaded(self, model_name): + return self.dino_controller._ensure_dino_model_downloaded(model_name) def browse_dino_model(self): - path = QFileDialog.getExistingDirectory(self, "Select DINO Model Folder") - if path: - self.dino_custom_model_path = path - self.lbl_dino_custom.setText(os.path.basename(path)) - # Refresh ready state now that a path is set. - self._on_dino_model_changed(self.dino_model_selector.currentText()) + return self.dino_controller.browse_dino_model() def on_dino_class_row_changed(self): - name = self.dino_class_table.selected_class_name() - self.dino_phrase_panel.set_active_class(name) - - def _build_dino_class_configs(self) -> list[dict]: - """Build class_configs from threshold table + phrase panel.""" - configs = [] - for cfg in self.dino_class_table.get_class_configs(): - phrases = self.dino_phrase_panel.get_phrases_for(cfg["name"]) - configs.append({ - "name": cfg["name"], - "phrases": phrases, - "box_thr": cfg["box_thr"], - "txt_thr": cfg["txt_thr"], - "nms_thr": cfg["nms_thr"], - }) - return configs + return self.dino_controller.on_dino_class_row_changed() - def run_dino_detection_single(self): - if not self.dino_model_loaded: - QMessageBox.warning(self, "No DINO Model", - "Please pick a DINO model first.") - return - if not self.sam_utils.current_sam_model: - QMessageBox.warning( - self, "No SAM Model", - "DINO produces bounding boxes; SAM is needed to convert them " - "into segmentation masks. Please pick a SAM model first.", - ) - return - if not self.current_image or self.current_image.isNull(): - QMessageBox.warning(self, "No Image", - "Please load an image first.") - return + def _build_dino_class_configs(self): + return self.dino_controller._build_dino_class_configs() - model_name = self.dino_model_selector.currentText() - class_configs = self._build_dino_class_configs() - if not class_configs: - QMessageBox.warning(self, "No Classes", - "Please add at least one class with phrases.") - return + def run_dino_detection_single(self): + return self.dino_controller.run_dino_detection_single() - self.btn_detect_single.setEnabled(False) - self.btn_detect_batch.setEnabled(False) + def run_dino_detection_batch(self): + return self.dino_controller.run_dino_detection_batch() - if not self._ensure_dino_model_downloaded(model_name): - self.btn_detect_single.setEnabled(True) - self.btn_detect_batch.setEnabled(True) - return - self.lbl_dino_status.setText("Detecting...") - QApplication.processEvents() + def _collect_dino_batch_work_items(self): + return self.dino_controller._collect_dino_batch_work_items() - print(f"[DINO] detect_single: model={model_name!r} class_configs={class_configs}") - try: - results = self.dino_utils.detect( - self.current_image, class_configs, - model_name=model_name, - custom_model_path=self.dino_custom_model_path, - ) - except Exception as e: - traceback.print_exc() - QMessageBox.critical(self, "DINO Error", str(e)) - self.btn_detect_single.setEnabled(True) - self.btn_detect_batch.setEnabled(True) - self.lbl_dino_status.setText("Detection failed.") - return + def _commit_dino_results(self, image_name, dino_results, sam_results): + return self.dino_controller._commit_dino_results(image_name, dino_results, sam_results) - self.btn_detect_single.setEnabled(True) - self.btn_detect_batch.setEnabled(True) + def _store_dino_batch_results(self, image_name, dino_results, sam_results): + return self.dino_controller._store_dino_batch_results(image_name, dino_results, sam_results) - if results is None: - print("[DINO] detect_single: results=None (model resolution failure)") - self.lbl_dino_status.setText("No detections.") - return + def _show_dino_batch_review(self): + return self.dino_controller._show_dino_batch_review() - print(f"[DINO] detect_single: got {len(results)} result(s)") - if results: - for i, r in enumerate(results[:3]): - print(f"[DINO] result[{i}] class={r['class_name']!r} score={r['score']:.3f} bbox={r['bbox']}") + def _navigate_to_image_or_slice(self, name): + return self.dino_controller._navigate_to_image_or_slice(name) - if not results: - self.lbl_dino_status.setText("No detections found.") - return + def _refresh_dino_temp_for_current(self): + return self.dino_controller._refresh_dino_temp_for_current() - self.lbl_dino_status.setText(f"{len(results)} detection(s). Running SAM...") - QApplication.processEvents() + def accept_dino_results(self): + return self.dino_controller.accept_dino_results() - # Batch SAM segmentation. Wrap in try/except for the same reason - # as the DINO call above — sam_utils raises on model load - # failure / CUDA OOM / re-entry now, instead of returning None. - bboxes = [r["bbox"] for r in results] - print(f"[SAM] batch call: {len(bboxes)} bbox(es), first 3 = {bboxes[:3]}") - try: - sam_results = self.sam_utils.apply_sam_predictions_batch( - self.current_image, bboxes - ) - except Exception as e: - traceback.print_exc() - QMessageBox.critical(self, "SAM Error", str(e)) - self.lbl_dino_status.setText("SAM segmentation failed.") - return + def reject_dino_results(self): + return self.dino_controller.reject_dino_results() - if sam_results is None: - print("[SAM] batch returned None (no SAM model loaded)") - QMessageBox.warning(self, "SAM Error", - "Failed to segment detections with SAM.") - self.lbl_dino_status.setText("SAM segmentation failed.") - return + def apply_theme_and_font(self): + theme.apply_theme_and_font(self) - n_errors = sum(1 for s in sam_results if "error" in s) - n_ok = sum(1 for s in sam_results if "segmentation" in s) - print(f"[SAM] batch returned {len(sam_results)} result(s): {n_ok} ok, {n_errors} error(s)") - - # Honor the batch-mode dropdown for the single-image case too: - # "Auto-accept" means commit straight to annotations without - # showing the temp-review overlay. The dropdown name is "batch" - # historically but it controls both paths. - image_name = self.current_slice or self.image_file_name - auto_accept = ( - self.dino_batch_mode.currentText() == "Auto-accept all detections" - ) - if auto_accept: - self._commit_dino_results(image_name, results, sam_results) - n_committed = sum(1 for s in sam_results if "error" not in s) - self.image_label.temp_annotations = [] - self.image_label.update() - self.update_annotation_list() - # Refresh slice list so the freshly-annotated slice picks - # up the highlight color; review-mode's accept_dino_results - # already does this, the auto-accept path didn't. - self.update_slice_list_colors() - self.auto_save() - self.lbl_dino_status.setText( - f"Loaded: {model_name} | {n_committed} mask(s) auto-accepted" - ) - print(f"[DINO] auto-accept: committed {n_committed} mask(s) to {image_name}") - return + def toggle_dark_mode(self): + theme.toggle_dark_mode(self) - # Review mode — build temp annotations and let user accept/reject - temp_annotations = [] - for r, s in zip(results, sam_results): - if "error" in s: - print(f"[SAM] failed for {r['class_name']}: {s['error']}") - continue - temp_annotations.append({ - "segmentation": s["segmentation"], - "category_name": r["class_name"], - "score": r["score"], - "source": "dino", - "temp": True, - }) - - self.image_label.temp_annotations = temp_annotations - # Defer setFocus until after the click event chain settles — - # synchronous setFocus often loses to whatever widget is still - # processing the original click. - QTimer.singleShot(0, self.image_label.setFocus) - self.image_label.update() - self.lbl_dino_status.setText( - f"Loaded: {model_name} | {len(temp_annotations)} mask(s) ready" - ) - print(f"[DINO] detection complete: {len(results)} boxes, {len(temp_annotations)} masks attached to canvas") + def apply_stylesheet(self): + theme.apply_stylesheet(self) - def run_dino_detection_batch(self): - if not self.dino_model_loaded: - QMessageBox.warning(self, "No DINO Model", - "Please pick a DINO model first.") - return - if not self.sam_utils.current_sam_model: - QMessageBox.warning( - self, "No SAM Model", - "DINO produces bounding boxes; SAM is needed to convert them " - "into segmentation masks. Please pick a SAM model first.", - ) - return - if not self.all_images: - QMessageBox.warning(self, "No Images", - "Please load images first.") - return - - model_name = self.dino_model_selector.currentText() - class_configs = self._build_dino_class_configs() - if not class_configs: - QMessageBox.warning(self, "No Classes", - "Please add at least one class with phrases.") - return - - if not self._ensure_dino_model_downloaded(model_name): - return - - auto_accept = self.dino_batch_mode.currentText() == "Auto-accept all detections" - - # Build a flat list of (display_name, qimage) work items covering - # both regular images (loaded from disk) and multi-dim image - # slices (already QImages in memory). Slices live in - # self.image_slices[base_name], indexed by their slice_name - # (e.g. "stack_T1_Z1_C1"). The earlier implementation only - # iterated self.all_images and skipped multi-slice entries with - # a console warning, leaving slice-based projects unable to use - # Detect All. - work_items = self._collect_dino_batch_work_items() - if not work_items: - QMessageBox.information( - self, "Detect All Images", - "No images or slices available to process." - ) - return - total = len(work_items) - - progress = QProgressDialog("Running LLM Detection...", "Cancel", 0, total, self) - progress.setWindowModality(Qt.WindowModality.WindowModal) - progress.setMinimumDuration(0) - - for idx, (image_name, qimage) in enumerate(work_items): - if progress.wasCanceled(): - break - progress.setValue(idx) - QApplication.processEvents() - - try: - results = self.dino_utils.detect( - qimage, class_configs, - model_name=model_name, - custom_model_path=self.dino_custom_model_path, - ) - except Exception as e: - print(f" DINO failed for {image_name}: {e}") - continue - - if not results: - continue - - bboxes = [r["bbox"] for r in results] - try: - sam_results = self.sam_utils.apply_sam_predictions_batch(qimage, bboxes) - except Exception as e: - print(f" SAM failed for {image_name}: {e}") - continue - if sam_results is None: - continue - - if auto_accept: - self._commit_dino_results(image_name, results, sam_results) - else: - # Store for later review - self._store_dino_batch_results(image_name, results, sam_results) - - progress.setValue(total) - progress.close() - - if auto_accept: - QMessageBox.information( - self, "Batch Detection Complete", - "Detections have been saved to annotations." - ) - self.update_annotation_list() - # Multi-dim stacks commonly auto-accept across dozens of - # slices; the slice list must show which ones gained - # annotations or the user can't tell what happened. - self.update_slice_list_colors() - self.auto_save() - else: - self._show_dino_batch_review() - - def _collect_dino_batch_work_items(self): - """Return a flat ``[(name, QImage), …]`` list for batch DINO. - - Regular images are loaded from disk via PIL → QImage. Multi-dim - images contribute one entry per slice from ``self.image_slices``; - slices that haven't been materialised yet (the parent image was - never opened in this session) are skipped with a console log. - """ - from PIL import Image as PILImage - items = [] - for img_info in self.all_images: - file_name = img_info["file_name"] - if img_info.get("is_multi_slice", False): - base_name = os.path.splitext(file_name)[0] - slices = self.image_slices.get(base_name, []) - if not slices: - print(f" Skipping multi-slice image '{file_name}': " - "no slices loaded (open the image first to " - "materialise its slices).") - continue - for slice_name, qimage in slices: - items.append((slice_name, qimage)) - else: - image_path = self.image_paths.get(file_name) - if not image_path or not os.path.exists(image_path): - print(f" Skipping '{file_name}': missing image path.") - continue - try: - pil_img = PILImage.open(image_path).convert("RGB") - qimage = QImage( - pil_img.tobytes(), - pil_img.width, - pil_img.height, - pil_img.width * 3, - QImage.Format.Format_RGB888, - ) - items.append((file_name, qimage)) - except Exception as e: - print(f" Skipping '{file_name}': failed to load ({e}).") - print(f"[DINO] batch work items: {len(items)} total") - return items - - def _commit_dino_results(self, image_name, dino_results, sam_results): - """Commit DINO+SAM results to annotations for a single image. - - If image_name is the currently-displayed image, route through - image_label.annotations so the canvas reflects the change and the - next save_current_annotations() doesn't overwrite the additions. - Otherwise write directly to the project-level cache. - """ - current_image = self.current_slice or self.image_file_name - is_current = image_name == current_image - - if is_current: - target = self.image_label.annotations - else: - if image_name not in self.all_annotations: - self.all_annotations[image_name] = {} - target = self.all_annotations[image_name] - - for r, s in zip(dino_results, sam_results): - if "error" in s: - continue - class_name = r["class_name"] - # DINO only returns labels that came from class_configs (which the - # parent built from the class table), so this should never trigger. - # Skip with a warning rather than auto-creating a class mid-batch - # (which would fan out auto_save() per new class). - if class_name not in self.class_mapping: - print(f" Skipping DINO result for unknown class '{class_name}'") - continue - existing = target.get(class_name, []) - number = max((a.get("number", 0) for a in existing), default=0) + 1 - ann = { - "segmentation": s["segmentation"], - "category_id": self.class_mapping[class_name], - "category_name": class_name, - "score": r["score"], - "source": "dino", - "number": number, - } - target.setdefault(class_name, []).append(ann) - - if is_current: - # Sync image_label.annotations -> all_annotations[current] for save. - self.save_current_annotations() - self.image_label.update() - - def _store_dino_batch_results(self, image_name, dino_results, sam_results): - """Store results for batch review mode.""" - valid = [] - for r, s in zip(dino_results, sam_results): - if "error" not in s: - valid.append({ - "segmentation": s["segmentation"], - "category_name": r["class_name"], - "score": r["score"], - "source": "dino", - "temp": True, - }) - self.dino_batch_results[image_name] = valid - - def _show_dino_batch_review(self): - """Navigate to first image with batch results for review. - - If the next entry refers to an image/slice that's no longer in - the project (e.g. the source was removed between detection and - review), pop the orphan and try the next entry so the user - doesn't get stuck with un-reviewable results. - """ - if not self.dino_batch_results: - QMessageBox.information(self, "Batch Detection", - "No detections found in any image.") - return - # Drain orphans up front. Navigate to the entry: it may be a - # regular image (key in image_list) or a slice (key in some - # image_slices[base_name]). _navigate_to_image_or_slice handles - # both. After the switch, switch_image / switch_slice's tail - # call to _refresh_dino_temp_for_current copies - # dino_batch_results[first] into image_label.temp_annotations - # and defers setFocus on the canvas — nothing to repeat here. - while self.dino_batch_results: - first = next(iter(self.dino_batch_results)) - if self._navigate_to_image_or_slice(first): - return - print(f"[DINO] dropping orphan batch result for {first!r} " - "(no matching image or slice in project)") - self.dino_batch_results.pop(first, None) - # Drained all entries without a single navigable target. - QMessageBox.warning( - self, "Batch Detection", - "Detections were produced but none of them map to an image " - "or slice still in the project. Results discarded.", - ) - - def _navigate_to_image_or_slice(self, name: str) -> bool: - """Switch the UI to a regular image or a slice by name. - - Returns True if a match was found and the switch was issued. - Used by batch-review navigation, which mixes regular image - names and slice names in ``dino_batch_results``. - """ - # Regular image — match in image_list directly - for i in range(self.image_list.count()): - item = self.image_list.item(i) - if item and item.text() == name: - self.image_list.setCurrentRow(i) - self.switch_image(item) - return True - # Slice — find which multi-dim image contains it, switch to - # that parent image first, then activate the specific slice - # via slice_list. - for base_name, slices in self.image_slices.items(): - if not any(s_name == name for s_name, _ in slices): - continue - # Find the parent file in image_list. The file_name in the - # list includes the extension (e.g. "stack.tif") while - # base_name is the stem ("stack"), so match by stripping - # the extension and comparing for equality. - for i in range(self.image_list.count()): - item = self.image_list.item(i) - if not item: - continue - file_name = item.text() - if os.path.splitext(file_name)[0] == base_name: - self.image_list.setCurrentRow(i) - self.switch_image(item) - # switch_image populates slice_list. Now find the slice. - for s_i in range(self.slice_list.count()): - s_item = self.slice_list.item(s_i) - if s_item and s_item.text() == name: - self.slice_list.setCurrentRow(s_i) - self.switch_slice(s_item) - return True - break - return False - return False - - def _refresh_dino_temp_for_current(self): - """Sync ``image_label.temp_annotations`` to whatever the - currently-displayed image/slice has stored in - ``dino_batch_results``. Called from switch_slice / switch_image. - - Why this exists: ``temp_annotations`` is a single field on - ``ImageLabel``, not a per-image cache. Without this sync, masks - from the previously-viewed image bleed onto every slice the - user navigates to. During a batch review the user expects each - image to show its own pending detections; outside batch review, - switching simply discards the pending overlay. - """ - new_image = self.current_slice or self.image_file_name - pending = self.dino_batch_results.get(new_image, []) if new_image else [] - if pending: - # Re-stamp the "temp" flag in case it was stripped by a - # previous accept path; this list also feeds the paintEvent - # which expects dicts with "segmentation" + "category_name". - self.image_label.temp_annotations = list(pending) - self.lbl_dino_status.setText( - f"Review: {new_image} ({len(pending)} detection(s))" - ) - QTimer.singleShot(0, self.image_label.setFocus) - else: - if self.image_label.temp_annotations: - print("[DINO] temp annotations cleared on switch " - f"(no pending batch results for {new_image!r})") - self.image_label.temp_annotations = [] - self.image_label.update() - - def accept_dino_results(self): - """Accept current temp_annotations (called from keyPressEvent).""" - if not self.image_label.temp_annotations: - return - image_name = self.current_slice or self.image_file_name - - for ann in self.image_label.temp_annotations: - class_name = ann["category_name"] - # DINO only returns labels from class_configs (built from the - # class table), so unknown classes should never reach this point. - # Skip with a warning rather than auto-creating mid-accept. - if class_name not in self.class_mapping: - print(f" Skipping DINO result for unknown class '{class_name}'") - continue - new_ann = { - "segmentation": ann["segmentation"], - "category_id": self.class_mapping[class_name], - "category_name": class_name, - "score": ann.get("score", 0.0), - "source": "dino", - } - # Append to the live image_label dict; save_current_annotations() - # below syncs it into self.all_annotations. add_annotation_to_list - # assigns the per-class "number" used for display. - self.image_label.annotations.setdefault(class_name, []).append(new_ann) - self.add_annotation_to_list(new_ann) - - self.image_label.temp_annotations = [] - # Clear batch results if reviewing - self.dino_batch_results.pop(image_name, None) - if self.dino_batch_results: - self._show_dino_batch_review() - self.save_current_annotations() - self.update_slice_list_colors() - self.image_label.update() - self.lbl_dino_status.setText("Results accepted.") - print("DINO results accepted.") - - def reject_dino_results(self): - """Discard current temp_annotations.""" - self.image_label.temp_annotations = [] - image_name = self.current_slice or self.image_file_name - self.dino_batch_results.pop(image_name, None) - if self.dino_batch_results: - self._show_dino_batch_review() - self.image_label.update() - self.lbl_dino_status.setText("Results discarded.") - print("DINO results discarded.") - - # --- END DINO Methods --- - - def setup_font_size_selector(self): - font_size_label = QLabel("Font Size:") - self.font_size_selector = QComboBox() - self.font_size_selector.addItems(["Small", "Medium", "Large"]) - self.font_size_selector.setCurrentText("Medium") - self.font_size_selector.currentTextChanged.connect(self.on_font_size_changed) - - self.sidebar_layout.addWidget(font_size_label) - self.sidebar_layout.addWidget(self.font_size_selector) - - def on_font_size_changed(self, size): - self.current_font_size = size - self.apply_theme_and_font() - - def apply_theme_and_font(self): - font_size = self.font_sizes[self.current_font_size] - if self.dark_mode: - style = soft_dark_stylesheet - else: - style = default_stylesheet - - # Combine the theme stylesheet with font size - combined_style = f"{style}\nQWidget {{ font-size: {font_size}pt; }}" - self.setStyleSheet(combined_style) - - # Apply font size to all widgets - for widget in self.findChildren(QWidget): - font = widget.font() - font.setPointSize(font_size) - widget.setFont(font) - - self.image_label.setFont(QFont("Arial", font_size)) - self.update() - - def toggle_dark_mode(self): - self.dark_mode = not self.dark_mode - self.apply_theme_and_font() - - # Update slice list colors - self.update_slice_list_colors() - - # Update other UI elements if necessary - self.update_class_list() - self.update_annotation_list() - - # Force a repaint of the main window - self.repaint() - - def apply_stylesheet(self): - if self.dark_mode: - self.setStyleSheet(soft_dark_stylesheet) - else: - self.setStyleSheet(default_stylesheet) - - def update_ui_colors(self): - # Update colors for elements that need to retain their functionality - self.update_annotation_list_colors() - self.update_slice_list_colors() - self.image_label.update() - - def setup_image_area(self): - """Set up the main image area.""" - self.image_widget = QWidget() - self.image_layout = QVBoxLayout(self.image_widget) - self.layout.addWidget(self.image_widget, 3) - - self.scroll_area = QScrollArea() - self.scroll_area.setWidgetResizable(True) - self.scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded) - self.scroll_area.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded) - - # Use the already initialized image_label - self.image_label.setAlignment(Qt.AlignmentFlag.AlignCenter) - self.scroll_area.setWidget(self.image_label) - - self.image_layout.addWidget(self.scroll_area) - - self.zoom_slider = QSlider(Qt.Orientation.Horizontal) - self.zoom_slider.setMinimum(10) - self.zoom_slider.setMaximum(500) - self.zoom_slider.setValue(100) - self.zoom_slider.setTickPosition(QSlider.TickPosition.TicksBelow) - self.zoom_slider.setTickInterval(50) - self.zoom_slider.valueChanged.connect(self.zoom_image) - self.image_layout.addWidget(self.zoom_slider) - self.image_info_label = QLabel() - self.image_layout.addWidget(self.image_info_label) - - def setup_image_list(self): - """Set up the image list area.""" - self.image_list_widget = QWidget() - self.image_list_layout = QVBoxLayout(self.image_list_widget) - self.layout.addWidget(self.image_list_widget, 1) - - self.image_list_label = QLabel("Images:") - self.image_list_layout.addWidget(self.image_list_label) - - self.image_list = QListWidget() - self.image_list.itemClicked.connect(self.switch_image) - self.image_list.currentRowChanged.connect( - lambda row: self.switch_image(self.image_list.currentItem()) - ) - self.image_list.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) - self.image_list.customContextMenuRequested.connect(self.show_image_context_menu) - self.image_list_layout.addWidget(self.image_list) - - self.clear_all_button = QPushButton("Clear All Images and Annotations") - self.clear_all_button.clicked.connect(self.clear_all) - self.image_list_layout.addWidget(self.clear_all_button) + def update_ui_colors(self): + theme.update_ui_colors(self) ########## ### Tools ########## I love useful image processing tools :) def open_dataset_splitter(self): @@ -3830,7 +696,7 @@ def show_dicom_converter(self): # update the show_help method: def show_help(self): self.help_window = HelpWindow( - dark_mode=self.dark_mode, font_size=self.font_sizes[self.current_font_size] + dark_mode=self.dark_mode, font_size=self.ui_font_pt ) self.help_window.show_centered(self) @@ -3901,12 +767,9 @@ def clear_all(self, new_project=False, show_messages=True): self.zoom_slider.setValue(100) # Reset tools - self.image_label.current_tool = None + self.image_label.set_active_tool(None) self.polygon_button.setChecked(False) self.rectangle_button.setChecked(False) - self.sam_magic_wand_button.setChecked(False) - self.sam_magic_wand_button.setEnabled(False) # Disable the SAM-Assisted button - self.image_label.sam_magic_wand_active = False # Deactivate SAM magic wand # Reset SAM-related attributes self.image_label.sam_bbox = None @@ -3980,580 +843,40 @@ def show_image_context_menu(self, position): self.redefine_dimensions(file_name) def is_multi_dimensional(self, file_name): - return file_name.lower().endswith((".tif", ".tiff", ".czi")) + return self.image_controller.is_multi_dimensional(file_name) def predict_single_image(self, file_name): - if self.is_multi_dimensional(file_name): - return # Do nothing for multi-dimensional images - - if not self.yolo_trainer or not self.yolo_trainer.model: - QMessageBox.warning( - self, - "No Model", - "Please load a YOLO model first from the YOLO > Prediction Settings > Load Model menu.", - ) - return - - # Deactivate SAM tool before prediction - self.deactivate_sam_magic_wand() - - image_path = self.image_paths[file_name] - try: - results = self.yolo_trainer.predict(image_path) - self.process_yolo_results(results, file_name) - except Exception as e: - QMessageBox.warning( - self, - "Prediction Error", - f"An error occurred during prediction: {str(e)}\n\n" - "This might be due to a mismatch between the model and the YAML file classes. " - "Please check that the YAML file corresponds to the loaded model.", - ) + return self.yolo_controller.predict_single_image(file_name) def redefine_dimensions(self, file_name): - file_path = self.image_paths.get(file_name) - if not file_path or not file_path.lower().endswith((".tif", ".tiff", ".czi")): - return # Exit the method if it's not a TIFF or CZI file - - reply = QMessageBox.warning( - self, - "Redefine Dimensions", - "Redefining dimensions will cause all associated annotations to be lost. " - "Do you want to continue?", - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, - QMessageBox.StandardButton.No, - ) - - if reply == QMessageBox.StandardButton.Yes: - # Remove existing annotations for this file - base_name = os.path.splitext(file_name)[0] - - print(f"Removing annotations for image: {base_name}") - # print(f"Current annotations: {list(self.all_annotations.keys())}") - - # Create a list of keys to remove, using a more specific matching condition - keys_to_remove = [ - key - for key in self.all_annotations.keys() - if key == base_name - or ( - key.startswith(f"{base_name}_") - and not key.startswith(f"{base_name}_8bit") - ) - ] - - print(f"Keys to remove: {keys_to_remove}") - - # Remove the annotations - for key in keys_to_remove: - del self.all_annotations[key] - - # print(f"Annotations after removal: {list(self.all_annotations.keys())}") - - # Remove existing slices - if base_name in self.image_slices: - del self.image_slices[base_name] - - # Clear current image if it's the one being redefined - if self.image_file_name == file_name: - self.current_image = None - self.image_label.clear() - - # Reload the image with new dimension dialog - if file_path.lower().endswith((".tif", ".tiff")): - self.load_tiff(file_path, force_dimension_dialog=True) - elif file_path.lower().endswith(".czi"): - self.load_czi(file_path, force_dimension_dialog=True) - - # Update UI - self.update_slice_list() - self.update_annotation_list() - self.image_label.update() - - # print(f"Final annotations: {list(self.all_annotations.keys())}") - - QMessageBox.information( - self, - "Dimensions Redefined", - "The dimensions have been redefined and the image reloaded. " - "All previous annotations for this image have been removed.", - ) + return self.image_controller.redefine_dimensions(file_name) def remove_image(self): - current_item = self.image_list.currentItem() - if current_item: - file_name = current_item.text() - - # Remove from all data structures - self.image_list.takeItem(self.image_list.row(current_item)) - self.image_paths.pop(file_name, None) - self.all_images = [ - img for img in self.all_images if img["file_name"] != file_name - ] - - # Remove annotations - self.all_annotations.pop(file_name, None) - - # Handle multi-dimensional images - base_name = os.path.splitext(file_name)[0] - if base_name in self.image_slices: - # Remove slices - for slice_name, _ in self.image_slices[base_name]: - self.all_annotations.pop(slice_name, None) - del self.image_slices[base_name] - - # Clear slice list - self.slice_list.clear() - - # Clear current image and slice if it was the removed image - if self.image_file_name == file_name: - self.current_image = None - self.image_file_name = "" - self.current_slice = None - self.image_label.clear() - self.annotation_list.clear() - - # Switch to another image if available - if self.image_list.count() > 0: - next_item = self.image_list.item(0) - self.image_list.setCurrentItem(next_item) - self.switch_image(next_item) - else: - # No images left - self.current_image = None - self.image_file_name = "" - self.current_slice = None - self.image_label.clear() - self.annotation_list.clear() - self.slice_list.clear() - - # Update UI - self.update_ui() - self.auto_save() # Auto-save after removing an image + return self.image_controller.remove_image() def load_annotations(self): - file_name, _ = QFileDialog.getOpenFileName( - self, "Load Annotations", "", "JSON Files (*.json)" - ) - if file_name: - with open(file_name, "r") as f: - self.loaded_json = json.load(f) - - # Load categories - self.class_list.clear() - self.image_label.class_colors.clear() - self.class_mapping.clear() - for category in self.loaded_json["categories"]: - class_name = category["name"] - self.class_mapping[class_name] = category["id"] - - # Assign a color if not already assigned - if class_name not in self.image_label.class_colors: - color = QColor( - Qt.GlobalColor(len(self.image_label.class_colors) % 16 + 7) - ) - self.image_label.class_colors[class_name] = color - - # Add item to class list with color indicator - item = QListWidgetItem(class_name) - self.update_class_item_color( - item, self.image_label.class_colors[class_name] - ) - self.class_list.addItem(item) - - # Create a mapping of image IDs to file names - image_id_to_filename = { - img["id"]: img["file_name"] for img in self.loaded_json["images"] - } - - # Load image information - json_images = {img["file_name"]: img for img in self.loaded_json["images"]} - - # Update existing images and add new ones from JSON - updated_all_images = [] - for i in range(self.image_list.count()): - item = self.image_list.item(i) - file_name = item.text() - if file_name in json_images: - updated_image = self.all_images[i].copy() - updated_image.update(json_images[file_name]) - updated_all_images.append(updated_image) - del json_images[file_name] - else: - updated_all_images.append(self.all_images[i]) - - # Add remaining images from JSON - for img in json_images.values(): - updated_all_images.append(img) - self.image_list.addItem(img["file_name"]) - - self.all_images = updated_all_images - - # Load annotations - self.all_annotations.clear() - for annotation in self.loaded_json["annotations"]: - image_id = annotation["image_id"] - file_name = image_id_to_filename.get(image_id) - if file_name: - if file_name not in self.all_annotations: - self.all_annotations[file_name] = {} - - category = next( - ( - cat - for cat in self.loaded_json["categories"] - if cat["id"] == annotation["category_id"] - ), - None, - ) - if category: - category_name = category["name"] - if category_name not in self.all_annotations[file_name]: - self.all_annotations[file_name][category_name] = [] - - ann = { - "category_id": annotation["category_id"], - "category_name": category_name, - } - - if "segmentation" in annotation: - ann["segmentation"] = annotation["segmentation"][0] - ann["type"] = "polygon" - elif "bbox" in annotation: - ann["bbox"] = annotation["bbox"] - ann["type"] = "bbox" - - # Add number field if it's missing - if "number" not in ann: - ann["number"] = ( - len(self.all_annotations[file_name][category_name]) + 1 - ) - - self.all_annotations[file_name][category_name].append(ann) - - # Check for missing images - missing_images = [ - img["file_name"] - for img in self.loaded_json["images"] - if img["file_name"] not in self.image_paths - ] - if missing_images: - self.show_warning( - "Missing Images", - "The following images are missing:\n" + "\n".join(missing_images), - ) - - # Reload the current image if it exists, otherwise load the first image - if self.image_file_name and self.image_file_name in self.all_annotations: - self.switch_image( - self.image_list.findItems(self.image_file_name, Qt.MatchFlag.MatchExactly)[0] - ) - elif self.all_images: - self.switch_image(self.image_list.item(0)) - - self.image_label.highlighted_annotations = [] # Clear existing highlights - self.update_annotation_list() # This will repopulate the annotation list - self.image_label.update() # Force a redraw of the image label + return self.annotation_controller.load_annotations() def clear_highlighted_annotation(self): - self.image_label.highlighted_annotation = None - self.image_label.update() + return self.annotation_controller.clear_highlighted_annotation() def update_highlighted_annotations(self): - selected_items = self.annotation_list.selectedItems() - self.image_label.highlighted_annotations = [ - item.data(Qt.ItemDataRole.UserRole) for item in selected_items - ] - self.image_label.update() # Force a redraw of the image label - - # Enable/disable merge and change class buttons based on selection - self.merge_button.setEnabled(len(selected_items) >= 2) - self.change_class_button.setEnabled(len(selected_items) > 0) + return self.annotation_controller.update_highlighted_annotations() def renumber_annotations(self): - current_name = self.current_slice or self.image_file_name - if current_name in self.all_annotations: - for class_name, annotations in self.all_annotations[current_name].items(): - for i, ann in enumerate(annotations, start=1): - ann["number"] = i - self.update_annotation_list() + return self.annotation_controller.renumber_annotations() def delete_selected_annotations(self): - selected_items = self.annotation_list.selectedItems() - if not selected_items: - QMessageBox.warning( - self, "No Selection", "Please select an annotation to delete." - ) - return - - reply = QMessageBox.question( - self, - "Delete Annotations", - f"Are you sure you want to delete {len(selected_items)} annotation(s)?", - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, - QMessageBox.StandardButton.No, - ) - if reply == QMessageBox.StandardButton.Yes: - # Create a list of annotations to remove - annotations_to_remove = [] - for item in selected_items: - annotation = item.data(Qt.ItemDataRole.UserRole) - annotations_to_remove.append((annotation["category_name"], annotation)) - - # Remove annotations from image_label.annotations - for category_name, annotation in annotations_to_remove: - if category_name in self.image_label.annotations: - if annotation in self.image_label.annotations[category_name]: - self.image_label.annotations[category_name].remove(annotation) - - # Update all_annotations - current_name = self.current_slice or self.image_file_name - self.all_annotations[current_name] = self.image_label.annotations - - # Sort and update the annotation list based on the current sorting method - if self.current_sort_method == "area": - self.sort_annotations_by_area() - else: - self.sort_annotations_by_class() - - self.image_label.highlighted_annotations.clear() - self.image_label.update() - - # Update slice list colors - self.update_slice_list_colors() - - QMessageBox.information( - self, - "Annotations Deleted", - f"{len(selected_items)} annotation(s) have been deleted.", - ) - self.auto_save() # Auto-save after deleting annotations + return self.annotation_controller.delete_selected_annotations() def merge_annotations(self): - if self.image_label.editing_polygon is not None: - QMessageBox.warning( - self, - "Edit Mode Active", - "Please exit the annotation edit mode before merging annotations.", - ) - return - - selected_items = self.annotation_list.selectedItems() - if len(selected_items) < 2: - QMessageBox.warning( - self, - "Not Enough Annotations", - "Please select at least two annotations to merge.", - ) - return - - class_name = selected_items[0].data(Qt.ItemDataRole.UserRole)["category_name"] - if not all( - item.data(Qt.ItemDataRole.UserRole)["category_name"] == class_name - for item in selected_items - ): - QMessageBox.warning( - self, - "Mixed Classes", - "All selected annotations must be from the same class.", - ) - return - - polygons = [] - original_annotations = [] - for item in selected_items: - annotation = item.data(Qt.ItemDataRole.UserRole) - original_annotations.append(annotation) - if "segmentation" in annotation: - points = zip( - annotation["segmentation"][0::2], annotation["segmentation"][1::2] - ) - polygon = Polygon(points) - if not polygon.is_valid: - polygon = polygon.buffer(0) - polygons.append(polygon) - - def are_all_polygons_connected(polygons): - if len(polygons) < 2: - return True - - connected = set([0]) # Start with the first polygon - to_check = set(range(1, len(polygons))) - - while to_check: - newly_connected = set() - for i in connected: - for j in to_check: - if polygons[i].intersects(polygons[j]) or polygons[i].touches( - polygons[j] - ): - newly_connected.add(j) - - if not newly_connected: - return ( - False # If no new connections found, they're not all connected - ) - - connected.update(newly_connected) - to_check -= newly_connected - - return True # All polygons are connected - - if not are_all_polygons_connected(polygons): - QMessageBox.warning( - self, - "Disconnected Polygons", - "Not all selected annotations are connected. Please select only connected annotations to merge.", - ) - return - - try: - merged_polygon = unary_union(polygons) - except Exception as e: - QMessageBox.warning( - self, - "Merge Error", - f"Unable to merge the selected annotations due to an error: {str(e)}", - ) - return - - new_annotation = { - "segmentation": [], - "category_id": self.class_mapping[class_name], - "category_name": class_name, - } - - if isinstance(merged_polygon, Polygon): - new_annotation["segmentation"] = [ - coord for point in merged_polygon.exterior.coords for coord in point - ] - elif isinstance(merged_polygon, MultiPolygon): - largest_polygon = max(merged_polygon.geoms, key=lambda p: p.area) - new_annotation["segmentation"] = [ - coord for point in largest_polygon.exterior.coords for coord in point - ] - - # Ask user about keeping original annotations - msg_box = QMessageBox(self) - msg_box.setWindowTitle("Merge Annotations") - msg_box.setText("Do you want to keep the original annotations?") - msg_box.setIcon(QMessageBox.Icon.Question) - - keep_button = msg_box.addButton("Keep", QMessageBox.ButtonRole.YesRole) - delete_button = msg_box.addButton("Delete", QMessageBox.ButtonRole.NoRole) - cancel_button = msg_box.addButton("Cancel", QMessageBox.ButtonRole.RejectRole) - - msg_box.setDefaultButton(cancel_button) - msg_box.setEscapeButton(cancel_button) - - msg_box.exec() - - if msg_box.clickedButton() == cancel_button: - return - - if msg_box.clickedButton() == delete_button: - for annotation in original_annotations: - if annotation in self.image_label.annotations[class_name]: - self.image_label.annotations[class_name].remove(annotation) - - self.image_label.annotations.setdefault(class_name, []).append(new_annotation) - - current_name = self.current_slice or self.image_file_name - self.all_annotations[current_name] = self.image_label.annotations - - self.renumber_annotations() - self.update_annotation_list() - self.save_current_annotations() - self.update_slice_list_colors() - self.image_label.update() - - QMessageBox.information( - self, "Merge Complete", "Annotations have been merged successfully." - ) - self.auto_save() # Auto-save after merging annotations + return self.annotation_controller.merge_annotations() def delete_selected_image(self): - current_item = self.image_list.currentItem() - if current_item: - file_name = current_item.text() - reply = QMessageBox.question( - self, - "Delete Image", - f"Are you sure you want to delete the image '{file_name}'?\n\n" - "This will remove the image and all its associated annotations.", - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, - QMessageBox.StandardButton.No, - ) - - if reply == QMessageBox.StandardButton.Yes: - # Remove from all data structures - self.image_list.takeItem(self.image_list.row(current_item)) - self.image_paths.pop(file_name, None) - self.all_images = [ - img for img in self.all_images if img["file_name"] != file_name - ] - - # Remove annotations - self.all_annotations.pop(file_name, None) - - # Handle multi-dimensional images - base_name = os.path.splitext(file_name)[0] - if base_name in self.image_slices: - # Remove slices - for slice_name, _ in self.image_slices[base_name]: - self.all_annotations.pop(slice_name, None) - del self.image_slices[base_name] - - # Clear slice list - self.slice_list.clear() - - # Clear current image and slice if it was the removed image - if self.image_file_name == file_name: - self.current_image = None - self.image_file_name = "" - self.current_slice = None - self.image_label.clear() - self.annotation_list.clear() - - # Switch to another image if available - if self.image_list.count() > 0: - next_item = self.image_list.item(0) - self.image_list.setCurrentItem(next_item) - self.switch_image(next_item) - else: - # No images left - self.current_image = None - self.image_file_name = "" - self.current_slice = None - self.image_label.clear() - self.annotation_list.clear() - self.slice_list.clear() - - # Update UI - self.update_ui() - - QMessageBox.information( - self, "Image Deleted", f"The image '{file_name}' has been deleted." - ) + return self.image_controller.delete_selected_image() def display_image(self): - if self.current_image: - if isinstance(self.current_image, QImage): - pixmap = QPixmap.fromImage(self.current_image) - elif isinstance(self.current_image, QPixmap): - pixmap = self.current_image - else: - print(f"Unexpected image type: {type(self.current_image)}") - return - - if not pixmap.isNull(): - self.image_label.setPixmap(pixmap) - self.image_label.adjustSize() - else: - print("Error: Null pixmap") - else: - self.image_label.clear() - print("No current image to display") + return self.image_controller.display_image() def update_ui(self): self.update_image_list() @@ -4564,217 +887,22 @@ def update_ui(self): self.update_image_info() def add_class(self, class_name=None, color=None): - if not self.image_label.check_unsaved_changes(): - return - - if class_name is None: - while True: - class_name, ok = QInputDialog.getText( - self, "Add Class", "Enter class name:" - ) - if not ok: - print("Class addition cancelled") - return - if not class_name.strip(): - QMessageBox.warning( - self, - "Invalid Input", - "Please enter a class name or press Cancel.", - ) - continue - if class_name in self.class_mapping: - QMessageBox.warning( - self, - "Duplicate Class", - f"The class '{class_name}' already exists. Please choose a different name.", - ) - continue - break - else: - # For programmatic addition (e.g., from YOLO predictions) - if class_name in self.class_mapping: - print(f"Class '{class_name}' already exists. Skipping addition.") - return - - if not isinstance(class_name, str): - print( - f"Warning: class_name is not a string. Converting {class_name} to string." - ) - class_name = str(class_name) - - if color is None: - color = QColor(Qt.GlobalColor(len(self.image_label.class_colors) % 16 + 7)) - elif isinstance(color, str): - color = QColor(color) - - print(f"Adding class: {class_name}, color: {color.name()}") - - self.image_label.class_colors[class_name] = color - self.class_mapping[class_name] = len(self.class_mapping) + 1 - - try: - item = QListWidgetItem(class_name) - - # Create a color indicator - pixmap = QPixmap(16, 16) - pixmap.fill(color) - item.setIcon(QIcon(pixmap)) - - # Set visibility state - item.setData(Qt.ItemDataRole.UserRole, True) - - # Set checkbox - item.setFlags(item.flags() | Qt.ItemFlag.ItemIsUserCheckable) - item.setCheckState(Qt.CheckState.Checked) - - self.class_list.addItem(item) - - self.class_list.setCurrentItem(item) - self.current_class = class_name - print(f"Class added successfully: {class_name}") - - # Sync DINO phrase/threshold state. Select the newly added - # row so the phrase editor below the table reveals itself — - # it hides by default and only becomes visible when a row is - # selected (set_active_class). Skip the row-select during - # project load: classes are added in a loop and we don't want - # N row-selection signals firing during bulk restoration; the - # caller will select an appropriate row after load completes. - row_added = self.dino_class_table.add_class(class_name) - self.dino_phrase_panel.on_class_added(class_name) - if row_added and not self.is_loading_project: - self.dino_class_table.selectRow(self.dino_class_table.rowCount() - 1) - - if not self.is_loading_project: - self.auto_save() - except Exception as e: - print(f"Error adding class: {e}") - traceback.print_exc() + return self.class_controller.add_class(class_name, color) def update_class_item_color(self, item, color): - pixmap = QPixmap(16, 16) - pixmap.fill(color) - item.setIcon(QIcon(pixmap)) + return self.class_controller.update_class_item_color(item, color) def update_class_list(self): - self.class_list.clear() - for class_name, color in self.image_label.class_colors.items(): - item = QListWidgetItem(class_name) - - # Create a color indicator - pixmap = QPixmap(16, 16) - pixmap.fill(color) - item.setIcon(QIcon(pixmap)) - - # Store the visibility state - item.setData( - Qt.ItemDataRole.UserRole, self.image_label.class_visibility.get(class_name, True) - ) - - # Set checkbox - item.setFlags(item.flags() | Qt.ItemFlag.ItemIsUserCheckable) - item.setCheckState(Qt.CheckState.Checked if item.data(Qt.ItemDataRole.UserRole) else Qt.CheckState.Unchecked) - - self.class_list.addItem(item) - - # Re-select the current class if it exists - if self.current_class: - items = self.class_list.findItems(self.current_class, Qt.MatchFlag.MatchExactly) - if items: - self.class_list.setCurrentItem(items[0]) - elif self.class_list.count() > 0: - # If no class is selected, select the first one - self.class_list.setCurrentItem(self.class_list.item(0)) - - print(f"Updated class list with {self.class_list.count()} items") + return self.class_controller.update_class_list() def update_class_selection(self): - for i in range(self.class_list.count()): - item = self.class_list.item(i) - if item.text() == self.current_class: - item.setSelected(True) - else: - item.setSelected(False) + return self.class_controller.update_class_selection() def toggle_class_visibility(self, item): - class_name = item.text() - is_visible = item.checkState() == Qt.CheckState.Checked - self.image_label.set_class_visibility(class_name, is_visible) - item.setData(Qt.ItemDataRole.UserRole, is_visible) - self.image_label.update() + return self.class_controller.toggle_class_visibility(item) def change_annotation_class(self): - selected_items = self.annotation_list.selectedItems() - if not selected_items: - QMessageBox.warning( - self, - "No Selection", - "Please select one or more annotations to change class.", - ) - return - - class_dialog = QDialog(self) - class_dialog.setWindowTitle("Change Class") - layout = QVBoxLayout(class_dialog) - - class_combo = QComboBox() - for class_name in self.class_mapping.keys(): - class_combo.addItem(class_name) - layout.addWidget(class_combo) - - button_box = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) - button_box.accepted.connect(class_dialog.accept) - button_box.rejected.connect(class_dialog.reject) - layout.addWidget(button_box) - - if class_dialog.exec() == QDialog.DialogCode.Accepted: - new_class = class_combo.currentText() - current_name = self.current_slice or self.image_file_name - - # Get the current maximum number for the new class - max_number = max( - [ - ann.get("number", 0) - for ann in self.image_label.annotations.get(new_class, []) - ] - + [0] - ) - - for item in selected_items: - annotation = item.data(Qt.ItemDataRole.UserRole) - old_class = annotation["category_name"] - - # Remove from old class - self.image_label.annotations[old_class].remove(annotation) - if not self.image_label.annotations[old_class]: - del self.image_label.annotations[old_class] - - # Add to new class with updated number - annotation["category_name"] = new_class - annotation["category_id"] = self.class_mapping[new_class] - max_number += 1 - annotation["number"] = max_number - if new_class not in self.image_label.annotations: - self.image_label.annotations[new_class] = [] - self.image_label.annotations[new_class].append(annotation) - - # Update all_annotations - self.all_annotations[current_name] = self.image_label.annotations - - # Renumber all annotations for consistency - self.renumber_annotations() - - self.update_annotation_list() - self.image_label.update() - self.save_current_annotations() - self.update_slice_list_colors() - self.auto_save() - - QMessageBox.information( - self, - "Class Changed", - f"Selected annotations have been changed to class '{new_class}'.", - ) + return self.annotation_controller.change_annotation_class() def toggle_tool(self): if not self.image_label.check_unsaved_changes(): @@ -4782,7 +910,7 @@ def toggle_tool(self): sender = self.sender() if sender is None: - sender = self.sam_magic_wand_button + return if not self.current_class: QMessageBox.warning( @@ -4804,13 +932,6 @@ def toggle_tool(self): other_buttons = [btn for btn in self.tool_group.buttons() if btn != sender] - # Deactivate SAM if we're switching to a different tool - if ( - sender != self.sam_magic_wand_button - and self.image_label.sam_magic_wand_active - ): - self.deactivate_sam_magic_wand() - if sender.isChecked(): # Uncheck all other buttons for btn in other_buttons: @@ -4818,22 +939,17 @@ def toggle_tool(self): # Set the current tool based on the checked button if sender == self.polygon_button: - self.image_label.current_tool = "polygon" + self.image_label.set_active_tool("polygon") elif sender == self.rectangle_button: - self.image_label.current_tool = "rectangle" - elif sender == self.sam_magic_wand_button: - self.image_label.current_tool = "sam_magic_wand" - self.activate_sam_magic_wand() + self.image_label.set_active_tool("rectangle") elif sender == self.paint_brush_button: - self.image_label.current_tool = "paint_brush" + self.image_label.set_active_tool("paint_brush") self.image_label.setFocus() # Set focus on the image label elif sender == self.eraser_button: - self.image_label.current_tool = "eraser" + self.image_label.set_active_tool("eraser") self.image_label.setFocus() # Set focus on the image label else: - self.image_label.current_tool = None - if sender == self.sam_magic_wand_button: - self.deactivate_sam_magic_wand() + self.image_label.set_active_tool(None) # Update UI based on the current tool self.update_ui_for_current_tool() @@ -4860,12 +976,6 @@ def update_ui_for_current_tool(self): # Update button states self.polygon_button.setChecked(self.image_label.current_tool == "polygon") self.rectangle_button.setChecked(self.image_label.current_tool == "rectangle") - self.sam_magic_wand_button.setChecked( - self.image_label.current_tool == "sam_magic_wand" - ) - - # Enable/disable SAM button based on model availability - self.sam_magic_wand_button.setEnabled(self.current_sam_model is not None) # Disable all tools if no class is selected tools_enabled = ( @@ -4876,320 +986,44 @@ def update_ui_for_current_tool(self): button.setEnabled(tools_enabled) # Update cursor based on the current tool - if ( - self.image_label.current_tool == "sam_magic_wand" - and self.sam_magic_wand_button.isEnabled() - ): + if self.image_label.current_tool in ("sam_box", "sam_points"): self.image_label.setCursor(Qt.CursorShape.CrossCursor) else: self.image_label.setCursor(Qt.CursorShape.ArrowCursor) def on_class_selected(self, current=None, previous=None): - if not self.image_label.check_unsaved_changes(): - return - - if current is None: - current = self.class_list.currentItem() - - if current: - self.current_class = current.text() - print(f"Class selected: {self.current_class}") - - if self.current_class.startswith("Temp-"): - self.disable_annotation_tools() - else: - self.enable_annotation_tools() - else: - self.current_class = None - self.disable_annotation_tools() + return self.class_controller.on_class_selected(current, previous) def disable_annotation_tools(self): for button in self.tool_group.buttons(): button.setChecked(False) button.setEnabled(False) - self.image_label.current_tool = None + self.image_label.set_active_tool(None) def enable_annotation_tools(self): for button in self.tool_group.buttons(): button.setEnabled(True) def show_class_context_menu(self, position): - menu = QMenu() - rename_action = menu.addAction("Rename Class") - change_color_action = menu.addAction("Change Color") - delete_action = menu.addAction("Delete Class") - - item = self.class_list.itemAt(position) - if item: - action = menu.exec(self.class_list.mapToGlobal(position)) - - if action == rename_action: - self.rename_class(item) - elif action == change_color_action: - self.change_class_color(item) - elif action == delete_action: - self.delete_class(item) - else: - QMessageBox.warning( - self, "No Selection", "Please select a class to perform actions." - ) + return self.class_controller.show_class_context_menu(position) def change_class_color(self, item): - class_name = item.text() - current_color = self.image_label.class_colors.get(class_name, QColor(Qt.GlobalColor.white)) - color = QColorDialog.getColor( - current_color, self, f"Select Color for {class_name}" - ) - - if color.isValid(): - self.image_label.class_colors[class_name] = color - - # Update the color indicator - pixmap = QPixmap(16, 16) - pixmap.fill(color) - item.setIcon(QIcon(pixmap)) - - self.update_annotation_list_colors(class_name, color) - self.image_label.update() - self.auto_save() # Auto-save after changing class color + return self.class_controller.change_class_color(item) def rename_class(self, item): - old_name = item.text() - new_name, ok = QInputDialog.getText( - self, "Rename Class", "Enter new class name:", text=old_name - ) - if ok and new_name and new_name != old_name: - # Update class mapping - if old_name in self.class_mapping: - old_id = self.class_mapping[old_name] - self.class_mapping[new_name] = old_id - del self.class_mapping[old_name] - else: - print(f"Warning: Class '{old_name}' not found in class_mapping") - return - - # Update class colors - if old_name in self.image_label.class_colors: - self.image_label.class_colors[new_name] = ( - self.image_label.class_colors.pop(old_name) - ) - else: - print(f"Warning: Class '{old_name}' not found in class_colors") - return - - # Update annotations for all images and slices - for image_name, image_annotations in self.all_annotations.items(): - if old_name in image_annotations: - image_annotations[new_name] = image_annotations.pop(old_name) - for annotation in image_annotations[new_name]: - annotation["category_name"] = new_name - - # Update current image annotations - if old_name in self.image_label.annotations: - self.image_label.annotations[new_name] = ( - self.image_label.annotations.pop(old_name) - ) - for annotation in self.image_label.annotations[new_name]: - annotation["category_name"] = new_name - - # Update current class if it's the renamed one - if self.current_class == old_name: - self.current_class = new_name - - # Update annotation list for all images and slices - self.update_all_annotation_lists() - - # Update class list - item.setText(new_name) - - # Update the image label - self.image_label.update() - self.auto_save() # Auto-save after renaming a class - - print(f"Class renamed from '{old_name}' to '{new_name}'") + return self.class_controller.rename_class(item) def delete_class(self, item=None): - if item is None: - item = self.class_list.currentItem() - - if item is None: - QMessageBox.warning( - self, "No Selection", "Please select a class to delete." - ) - return - - class_name = item.text() - - # Show confirmation dialog - reply = QMessageBox.question( - self, - "Delete Class", - f"Are you sure you want to delete the class '{class_name}'?\n\n" - "This will remove all annotations associated with this class.", - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, - QMessageBox.StandardButton.No, - ) - - if reply == QMessageBox.StandardButton.Yes: - # Proceed with deletion - # Remove class color - self.image_label.class_colors.pop(class_name, None) - - # Remove class from mapping - self.class_mapping.pop(class_name, None) - - # Remove annotations for this class from all images - for image_annotations in self.all_annotations.values(): - image_annotations.pop(class_name, None) - - # Remove annotations for this class from current image - self.image_label.annotations.pop(class_name, None) - - # Sync DINO state - self.dino_class_table.remove_class(class_name) - self.dino_phrase_panel.on_class_removed(class_name) - - # Update annotation list - self.update_annotation_list() - - # Remove class from list - row = self.class_list.row(item) - self.class_list.takeItem(row) - - # Update current_class - if self.current_class == class_name: - self.current_class = None - if self.class_list.count() > 0: - self.class_list.setCurrentRow(0) - self.on_class_selected(self.class_list.item(0)) - else: - self.disable_annotation_tools() - - self.image_label.update() - - # Inform the user - QMessageBox.information( - self, "Class Deleted", f"The class '{class_name}' has been deleted." - ) - self.auto_save() # Auto-save after deleting a class - else: - # User cancelled the operation - QMessageBox.information( - self, "Deletion Cancelled", "The class deletion was cancelled." - ) + return self.class_controller.delete_class(item) def finish_polygon(self): - if ( - self.image_label.current_tool == "polygon" - and len(self.image_label.current_annotation) > 2 - ): - if self.current_class is None: - QMessageBox.warning( - self, - "No Class Selected", - "Please select a class before finishing the annotation.", - ) - return - - # Create a polygon from the current annotation - polygon = Polygon(self.image_label.current_annotation) - - # Define the image boundary as a rectangle - image_boundary = Polygon( - [ - (0, 0), - (self.current_image.width(), 0), - (self.current_image.width(), self.current_image.height()), - (0, self.current_image.height()), - ] - ) - - # Intersect the polygon with the image boundary - clipped_polygon = polygon.intersection(image_boundary) - - if clipped_polygon.is_empty: - QMessageBox.warning( - self, - "Invalid Annotation", - "The annotation is completely outside the image boundaries.", - ) - self.image_label.clear_current_annotation() - self.image_label.update() - return - - # Convert the clipped polygon to a segmentation format - if isinstance(clipped_polygon, Polygon): - segmentation = [ - coord - for point in clipped_polygon.exterior.coords - for coord in point - ] - elif isinstance(clipped_polygon, MultiPolygon): - largest_polygon = max(clipped_polygon.geoms, key=lambda p: p.area) - segmentation = [ - coord - for point in largest_polygon.exterior.coords - for coord in point - ] - else: - QMessageBox.warning( - self, "Invalid Annotation", "The annotation could not be processed." - ) - return - - new_annotation = { - "segmentation": segmentation, - "category_id": self.class_mapping[self.current_class], - "category_name": self.current_class, - } - self.image_label.annotations.setdefault(self.current_class, []).append( - new_annotation - ) - self.add_annotation_to_list(new_annotation) - self.image_label.clear_current_annotation() - self.image_label.drawing_polygon = False # Reset the drawing_polygon flag - self.image_label.reset_annotation_state() - self.image_label.update() - - # Save the current annotations - self.save_current_annotations() - - # Update the slice list colors - self.update_slice_list_colors() - self.auto_save() # Auto-save after adding a polygon annotation - - def highlight_annotation(self, item): - self.image_label.highlighted_annotation = item.data(Qt.ItemDataRole.UserRole) - self.image_label.update() + return self.annotation_controller.finish_polygon() def delete_annotation(self): - current_item = self.annotation_list.currentItem() - if current_item: - annotation = current_item.data(Qt.ItemDataRole.UserRole) - category_name = annotation["category_name"] - self.image_label.annotations[category_name].remove(annotation) - self.annotation_list.takeItem(self.annotation_list.row(current_item)) - self.image_label.highlighted_annotation = None - self.image_label.update() + return self.annotation_controller.delete_annotation() def add_annotation_to_list(self, annotation): - class_name = annotation["category_name"] - color = self.image_label.class_colors.get(class_name, QColor(Qt.GlobalColor.white)) - annotations = self.image_label.annotations.get(class_name, []) - number = max([ann.get("number", 0) for ann in annotations] + [0]) + 1 - annotation["number"] = number - area = calculate_area(annotation) - item_text = f"{class_name} - {number:<3} Area: {area:.2f}" - - item = QListWidgetItem(item_text) - item.setData(Qt.ItemDataRole.UserRole, annotation) - item.setForeground(color) - self.annotation_list.addItem(item) - - # Clear the current selection - self.annotation_list.clearSelection() - self.image_label.highlighted_annotations.clear() - self.image_label.update() + return self.annotation_controller.add_annotation_to_list(annotation) def zoom_in(self): new_zoom = min(self.image_label.zoom_factor + 0.1, 5.0) @@ -5218,683 +1052,84 @@ def enable_tools(self): self.rectangle_button.setEnabled(True) def finish_rectangle(self): - if self.image_label.current_rectangle: - x1, y1, x2, y2 = self.image_label.current_rectangle - - # Create a rectangle polygon from the annotation - rectangle = Polygon([(x1, y1), (x2, y1), (x2, y2), (x1, y2)]) - - # Define the image boundary as a rectangle - image_boundary = Polygon( - [ - (0, 0), - (self.current_image.width(), 0), - (self.current_image.width(), self.current_image.height()), - (0, self.current_image.height()), - ] - ) - - # Intersect the rectangle with the image boundary - clipped_rectangle = rectangle.intersection(image_boundary) - - if clipped_rectangle.is_empty: - QMessageBox.warning( - self, - "Invalid Annotation", - "The annotation is completely outside the image boundaries.", - ) - self.image_label.current_rectangle = None - self.image_label.update() - return - - # Convert the clipped rectangle to a segmentation format - if isinstance(clipped_rectangle, Polygon): - segmentation = [ - coord - for point in clipped_rectangle.exterior.coords - for coord in point - ] - elif isinstance(clipped_rectangle, MultiPolygon): - largest_polygon = max(clipped_rectangle.geoms, key=lambda p: p.area) - segmentation = [ - coord - for point in largest_polygon.exterior.coords - for coord in point - ] - else: - QMessageBox.warning( - self, "Invalid Annotation", "The annotation could not be processed." - ) - return - - new_annotation = { - "segmentation": segmentation, - "category_id": self.class_mapping[self.current_class], - "category_name": self.current_class, - } - self.image_label.annotations.setdefault(self.current_class, []).append( - new_annotation - ) - self.add_annotation_to_list(new_annotation) - self.image_label.start_point = None - self.image_label.end_point = None - self.image_label.current_rectangle = None - self.image_label.update() - - # Save the current annotations - self.save_current_annotations() - - # Update the slice list colors - self.update_slice_list_colors() - self.auto_save() + return self.annotation_controller.finish_rectangle() def enter_edit_mode(self, annotation): - self.editing_mode = True - self.disable_tools() - - QMessageBox.information( - self, - "Edit Mode", - "You are now in edit mode. Click and drag points to move them, Shift+Click to delete points, or click on edges to add new points.", - ) + return self.annotation_controller.enter_edit_mode(annotation) def exit_edit_mode(self): - self.editing_mode = False - self.enable_tools() - - self.image_label.editing_polygon = None - self.image_label.editing_point_index = None - self.image_label.hover_point_index = None - self.update_annotation_list() - self.image_label.update() + return self.annotation_controller.exit_edit_mode() def highlight_annotation_in_list(self, annotation): - for i in range(self.annotation_list.count()): - item = self.annotation_list.item(i) - if item.data(Qt.ItemDataRole.UserRole) == annotation: - self.annotation_list.setCurrentItem(item) - break + return self.annotation_controller.highlight_annotation_in_list(annotation) def select_annotation_in_list(self, annotation): - for i in range(self.annotation_list.count()): - item = self.annotation_list.item(i) - if item.data(Qt.ItemDataRole.UserRole) == annotation: - self.annotation_list.setCurrentItem(item) - break + return self.annotation_controller.select_annotation_in_list(annotation) ################################################################ def setup_yolo_menu(self): - yolo_menu = self.menuBar().addMenu("&YOLO (beta)") - - # Training submenu - training_submenu = yolo_menu.addMenu("Training") - - load_pretrained_action = QAction("Load Pre-trained Model", self) - load_pretrained_action.triggered.connect(self.load_yolo_model) - training_submenu.addAction(load_pretrained_action) - - prepare_data_action = QAction("Prepare YOLO Dataset", self) - prepare_data_action.triggered.connect(self.prepare_yolo_dataset) - training_submenu.addAction(prepare_data_action) - - load_yaml_action = QAction("Load Dataset YAML", self) - load_yaml_action.triggered.connect(self.load_yolo_yaml) - training_submenu.addAction(load_yaml_action) - - train_action = QAction("Train Model", self) - train_action.triggered.connect(self.show_train_dialog) - training_submenu.addAction(train_action) - - save_model_action = QAction("Save Model", self) - save_model_action.triggered.connect(self.save_yolo_model) - training_submenu.addAction(save_model_action) - - # Prediction Settings submenu - prediction_submenu = yolo_menu.addMenu("Prediction Settings") - - load_model_action = QAction("Load Model", self) - load_model_action.triggered.connect(self.load_prediction_model) - prediction_submenu.addAction(load_model_action) - - set_threshold_action = QAction("Set Confidence Threshold", self) - set_threshold_action.triggered.connect(self.set_confidence_threshold) - prediction_submenu.addAction(set_threshold_action) + return self.yolo_controller.setup_yolo_menu() def load_yolo_model(self): - if not hasattr(self, "current_project_dir"): - QMessageBox.warning( - self, "No Project", "Please open or create a project first." - ) - return - - if not self.yolo_trainer: - self.initialize_yolo_trainer() - - if self.yolo_trainer.load_model(): - QMessageBox.information( - self, "Model Loaded", "YOLO model loaded successfully." - ) - else: - QMessageBox.warning(self, "Load Cancelled", "Model loading was cancelled.") + return self.yolo_controller.load_yolo_model() def prepare_yolo_dataset(self): - if not hasattr(self, "current_project_file"): - QMessageBox.warning( - self, "No Project", "Please open or create a project first." - ) - return - - if not self.yolo_trainer: - self.initialize_yolo_trainer() - - try: - yaml_path = self.yolo_trainer.prepare_dataset() - QMessageBox.information( - self, - "Dataset Prepared", - f"YOLO dataset prepared successfully. YAML file: {yaml_path}", - ) - except Exception as e: - QMessageBox.critical( - self, - "Error", - f"An error occurred while preparing the dataset: {str(e)}", - ) + return self.yolo_controller.prepare_yolo_dataset() def load_yolo_yaml(self): - if not hasattr(self, "current_project_file"): - QMessageBox.warning( - self, "No Project", "Please open or create a project first." - ) - return - - if not self.yolo_trainer: - self.initialize_yolo_trainer() - - try: - if self.yolo_trainer.load_yaml(): - QMessageBox.information( - self, "YAML Loaded", "Dataset YAML loaded successfully." - ) - else: - QMessageBox.warning( - self, "Load Cancelled", "YAML loading was cancelled." - ) - except Exception as e: - QMessageBox.critical( - self, - "Error", - f"An error occurred while loading the YAML file: {str(e)}", - ) + return self.yolo_controller.load_yolo_yaml() def save_yolo_model(self): - if not hasattr(self, "current_project_file"): - QMessageBox.warning( - self, "No Project", "Please open or create a project first." - ) - return - - if not self.yolo_trainer or not self.yolo_trainer.model: - QMessageBox.warning( - self, "No Model", "Please train or load a YOLO model first." - ) - return - - try: - if self.yolo_trainer.save_model(): - QMessageBox.information( - self, "Model Saved", "YOLO model saved successfully." - ) - else: - QMessageBox.warning( - self, "Save Cancelled", "Model saving was cancelled." - ) - except Exception as e: - QMessageBox.critical( - self, "Error", f"An error occurred while saving the model: {str(e)}" - ) + return self.yolo_controller.save_yolo_model() def load_prediction_model(self): - if not hasattr(self, "current_project_file"): - QMessageBox.warning( - self, "No Project", "Please open or create a project first." - ) - return - - if not self.yolo_trainer: - self.initialize_yolo_trainer() - - dialog = LoadPredictionModelDialog(self) - if dialog.exec() == QDialog.DialogCode.Accepted: - model_path = dialog.model_path - yaml_path = dialog.yaml_path - if model_path and yaml_path: - try: - result, message = self.yolo_trainer.load_prediction_model( - model_path, yaml_path - ) - if result: - QMessageBox.information( - self, - "Model Loaded", - "YOLO model and YAML file loaded successfully for prediction.", - ) - if message: - QMessageBox.warning(self, "Class Mismatch Warning", message) - else: - QMessageBox.critical( - self, - "Error Loading Model", - f"Could not load the model or YAML file: {message}", - ) - except Exception as e: - QMessageBox.critical(self, "Error", f"An error occurred: {str(e)}") - else: - QMessageBox.warning( - self, - "Files Required", - "Both model and YAML files are required for prediction.", - ) + return self.yolo_controller.load_prediction_model() def show_train_dialog(self): - if not self.yolo_trainer: - QMessageBox.warning( - self, "No Project", "Please open or create a project first." - ) - return - if not self.yolo_trainer.model: - QMessageBox.warning( - self, "No Model", "Please load a pre-trained model first." - ) - return - if not self.yolo_trainer.yaml_path: - QMessageBox.warning( - self, "No Dataset", "Please prepare or load a dataset YAML first." - ) - return - - dialog = QDialog(self) - dialog.setWindowTitle("Train YOLO Model") - layout = QVBoxLayout() - - epochs_label = QLabel("Number of Epochs:") - epochs_input = QLineEdit("100") - layout.addWidget(epochs_label) - layout.addWidget(epochs_input) - - imgsz_label = QLabel("Image Size:") - imgsz_input = QLineEdit("640") - layout.addWidget(imgsz_label) - layout.addWidget(imgsz_input) - - button_box = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) - button_box.accepted.connect(dialog.accept) - button_box.rejected.connect(dialog.reject) - layout.addWidget(button_box) - - dialog.setLayout(layout) - - if dialog.exec() == QDialog.DialogCode.Accepted: - epochs = int(epochs_input.text()) - imgsz = int(imgsz_input.text()) - self.start_training(epochs, imgsz) + return self.yolo_controller.show_train_dialog() def initialize_yolo_trainer(self): - if hasattr(self, "current_project_dir"): - self.yolo_trainer = YOLOTrainer(self.current_project_dir, self) - else: - QMessageBox.warning( - self, "No Project", "Please open or create a project first." - ) + return self.yolo_controller.initialize_yolo_trainer() def start_training(self, epochs, imgsz): - if not hasattr(self, "training_dialog"): - self.training_dialog = TrainingInfoDialog(self) - self.training_dialog.show() - - self.yolo_trainer.progress_signal.connect(self.training_dialog.update_info) - self.yolo_trainer.set_progress_callback(self.training_dialog.update_info) - self.training_dialog.stop_signal.connect(self.yolo_trainer.stop_training_signal) - - self.training_thread = TrainingThread(self.yolo_trainer, epochs, imgsz) - self.training_thread.finished.connect(self.training_finished) - self.training_thread.start() + return self.yolo_controller.start_training(epochs, imgsz) def training_finished(self, results): - self.training_dialog.stop_button.setEnabled(True) - self.training_dialog.stop_button.setText("Stop Training") - self.yolo_trainer.progress_signal.disconnect(self.training_dialog.update_info) - self.training_dialog.stop_signal.disconnect( - self.yolo_trainer.stop_training_signal - ) - - if isinstance(results, str): - QMessageBox.critical( - self, "Training Error", f"An error occurred during training: {results}" - ) - else: - QMessageBox.information( - self, "Training Complete", "YOLO model training completed successfully." - ) + return self.yolo_controller.training_finished(results) def set_confidence_threshold(self): - if not hasattr(self, "current_project_file"): - QMessageBox.warning( - self, "No Project", "Please open or create a project first." - ) - return - - if not self.yolo_trainer: - self.initialize_yolo_trainer() - - current_threshold = self.yolo_trainer.conf_threshold - new_threshold, ok = QInputDialog.getDouble( - self, - "Set Confidence Threshold", - "Enter confidence threshold (0-1):", - current_threshold, - 0, - 1, - 2, - ) - if ok: - self.yolo_trainer.set_conf_threshold(new_threshold) - QMessageBox.information( - self, - "Threshold Updated", - f"Confidence threshold set to {new_threshold}", - ) + return self.yolo_controller.set_confidence_threshold() def show_predict_dialog(self): - if not self.yolo_trainer or not self.yolo_trainer.model: - QMessageBox.warning(self, "No Model", "Please load a YOLO model first.") - return - - dialog = QDialog(self) - dialog.setWindowTitle("Predict with YOLO Model") - layout = QVBoxLayout() - - image_list = QListWidget() - for image_name in self.image_paths.keys(): - image_list.addItem(image_name) - layout.addWidget(QLabel("Select images for prediction:")) - layout.addWidget(image_list) - - conf_label = QLabel("Confidence Threshold:") - conf_input = QDoubleSpinBox() - conf_input.setRange(0, 1) - conf_input.setSingleStep(0.01) - conf_input.setValue(self.yolo_trainer.conf_threshold) - layout.addWidget(conf_label) - layout.addWidget(conf_input) - - button_box = QDialogButtonBox(QDialogButtonBox.StandardButton.Cancel) - predict_button = QPushButton("Predict") - button_box.addButton(predict_button, QDialogButtonBox.ButtonRole.AcceptRole) - button_box.accepted.connect(dialog.accept) - button_box.rejected.connect(dialog.reject) - layout.addWidget(button_box) - - dialog.setLayout(layout) - - if dialog.exec() == QDialog.DialogCode.Accepted: - selected_images = [item.text() for item in image_list.selectedItems()] - conf = conf_input.value() - self.yolo_trainer.set_conf_threshold(conf) - self.run_predictions(selected_images) + return self.yolo_controller.show_predict_dialog() def run_predictions(self, selected_images): - for image_name in selected_images: - image_path = self.image_paths[image_name] - results = self.yolo_trainer.predict(image_path) - self.process_yolo_results(results, image_name) + return self.yolo_controller.run_predictions(selected_images) def process_yolo_results(self, results, image_name): - image_path = self.image_paths[image_name] - image = cv2.imread(image_path) - if image is None: - QMessageBox.warning(self, "Error", f"Failed to load image: {image_name}") - return - original_height, original_width = image.shape[:2] - - temp_annotations = {} - - try: - results, input_size, original_size = ( - results # Unpack the results, input size, and original size - ) - input_height, input_width = input_size - orig_height, orig_width = original_size - - scale_x = original_width / orig_width - scale_y = original_height / orig_height - - for result in results: - boxes = result.boxes - masks = result.masks - - if masks is None: - print(f"No masks found for {image_name}") - continue - - for mask, box in zip(masks, boxes): - try: - class_id = int(box.cls) - class_name = self.yolo_trainer.class_names[class_id] - score = float(box.conf) - - mask_array = mask.data.cpu().numpy()[0] - # Resize mask to original image size - mask_array = cv2.resize(mask_array, (orig_width, orig_height)) - contours, _ = cv2.findContours( - (mask_array > 0.5).astype(np.uint8), - cv2.RETR_EXTERNAL, - cv2.CHAIN_APPROX_SIMPLE, - ) - - if contours: - epsilon = 0.005 * cv2.arcLength(contours[0], True) - approx = cv2.approxPolyDP(contours[0], epsilon, True) - polygon = approx.flatten().tolist() - - # Scale the polygon coordinates - scaled_polygon = [] - for i in range(0, len(polygon), 2): - x = polygon[i] * scale_x - y = polygon[i + 1] * scale_y - scaled_polygon.extend([x, y]) - - temp_class_name = f"Temp-{class_name}" - if temp_class_name not in temp_annotations: - temp_annotations[temp_class_name] = [] - - temp_annotation = { - "segmentation": scaled_polygon, - "category_name": temp_class_name, - "score": score, - "temp": True, - } - temp_annotations[temp_class_name].append(temp_annotation) - except IndexError: - QMessageBox.warning( - self, - "Class Mismatch", - "There is a mismatch between the model and the YAML file classes. " - "Please check that the YAML file corresponds to the loaded model.", - ) - return - - except Exception as e: - QMessageBox.warning( - self, - "Prediction Error", - f"An error occurred during prediction: {str(e)}\n\n" - "This might be due to a mismatch between the model and the YAML file classes. " - "Please check that the YAML file corresponds to the loaded model.", - ) - return - - self.add_temp_classes(temp_annotations) - self.update_class_list() - self.image_label.update() - - if temp_annotations: - total_predictions = sum(len(anns) for anns in temp_annotations.values()) - QMessageBox.information( - self, - "Review Predictions", - f"Found {total_predictions} predictions for {len(temp_annotations)} classes.\n" - "Use class visibility checkboxes to review.\n" - "Press Enter to accept or Esc to reject visible predictions.", - ) - else: - QMessageBox.information( - self, "No Predictions", "No predictions were found for this image." - ) - - # Deactivate SAM tool - self.deactivate_sam_magic_wand() + return self.yolo_controller.process_yolo_results(results, image_name) def add_temp_classes(self, temp_annotations): - for temp_class_name, annotations in temp_annotations.items(): - if temp_class_name not in self.image_label.class_colors: - color = QColor( - Qt.GlobalColor(len(self.image_label.class_colors) % 16 + 7) - ) - self.image_label.class_colors[temp_class_name] = color - self.image_label.annotations[temp_class_name] = annotations - - self.update_class_list() + return self.dino_controller.add_temp_classes(temp_annotations) def verify_current_class(self): - if self.current_class is None or self.current_class not in self.class_mapping: - if self.class_list.count() > 0: - self.class_list.setCurrentRow(0) - self.on_class_selected(self.class_list.item(0)) - else: - self.current_class = None - self.disable_annotation_tools() + return self.dino_controller.verify_current_class() def accept_visible_temp_classes(self): - visible_temp_classes = [ - item.text() - for item in self.class_list.findItems("Temp-*", Qt.MatchFlag.MatchWildcard) - if item.checkState() == Qt.CheckState.Checked - ] - - for temp_class_name in visible_temp_classes: - permanent_class_name = temp_class_name[5:] # Remove "Temp-" prefix - if permanent_class_name not in self.image_label.annotations: - self.add_class( - permanent_class_name, self.image_label.class_colors[temp_class_name] - ) - - # Get the current maximum number for this class - current_max = max( - [ - ann.get("number", 0) - for ann in self.image_label.annotations.get( - permanent_class_name, [] - ) - ] - + [0] - ) - - for annotation in self.image_label.annotations[temp_class_name]: - current_max += 1 - annotation["category_name"] = permanent_class_name - annotation["number"] = current_max - self.image_label.annotations.setdefault( - permanent_class_name, [] - ).append(annotation) - - del self.image_label.annotations[temp_class_name] - del self.image_label.class_colors[temp_class_name] - - self.update_class_list() - current_name = self.current_slice or self.image_file_name - self.all_annotations[current_name] = self.image_label.annotations - self.update_annotation_list() - self.image_label.update() - self.save_current_annotations() - - # Select the first primary class - self.select_first_primary_class() - self.verify_current_class() - - QMessageBox.information( - self, - "Annotations Accepted", - "Temporary annotations have been accepted and added to the permanent classes.", - ) + return self.dino_controller.accept_visible_temp_classes() def select_first_primary_class(self): - for i in range(self.class_list.count()): - item = self.class_list.item(i) - if not item.text().startswith("Temp-"): - self.class_list.setCurrentItem(item) - self.on_class_selected(item) - break + return self.dino_controller.select_first_primary_class() def reject_visible_temp_classes(self): - visible_temp_classes = [ - item.text() - for item in self.class_list.findItems("Temp-*", Qt.MatchFlag.MatchWildcard) - if item.checkState() == Qt.CheckState.Checked - ] - - for temp_class_name in visible_temp_classes: - if temp_class_name in self.image_label.annotations: - del self.image_label.annotations[temp_class_name] - if temp_class_name in self.image_label.class_colors: - del self.image_label.class_colors[temp_class_name] - - self.update_class_list() - self.image_label.update() + return self.dino_controller.reject_visible_temp_classes() def is_class_visible(self, class_name): - items = self.class_list.findItems(class_name, Qt.MatchFlag.MatchExactly) - if items: - return items[0].checkState() == Qt.CheckState.Checked - return False + return self.class_controller.is_class_visible(class_name) def check_temp_annotations(self): - temp_classes = [ - class_name - for class_name in self.image_label.annotations.keys() - if class_name.startswith("Temp-") - ] - if temp_classes: - reply = QMessageBox.question( - self, - "Temporary Annotations", - "There are temporary annotations that will be discarded. Do you want to continue?", - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, - QMessageBox.StandardButton.No, - ) - if reply == QMessageBox.StandardButton.Yes: - for temp_class in temp_classes: - del self.image_label.annotations[temp_class] - del self.image_label.class_colors[temp_class] - self.update_class_list() - self.update_annotation_list() - return True - return False - return True + return self.dino_controller.check_temp_annotations() def remove_all_temp_annotations(self): - for image_name in list(self.all_annotations.keys()): - for class_name in list(self.all_annotations[image_name].keys()): - if class_name.startswith("Temp-"): - del self.all_annotations[image_name][class_name] - if not self.all_annotations[image_name]: - del self.all_annotations[image_name] - - for class_name in list(self.image_label.class_colors.keys()): - if class_name.startswith("Temp-"): - del self.image_label.class_colors[class_name] - - self.update_class_list() - self.update_annotation_list() - self.image_label.update() + return self.dino_controller.remove_all_temp_annotations() diff --git a/src/digitalsreeni_image_annotator/app_settings.py b/src/digitalsreeni_image_annotator/app_settings.py new file mode 100644 index 0000000..089eeb0 --- /dev/null +++ b/src/digitalsreeni_image_annotator/app_settings.py @@ -0,0 +1,54 @@ +"""App-global UI preferences persisted via QSettings. + +First (and so far only) QSettings usage in the app — see ADR in +docs/09_architecture_decisions.md. UI preferences (font size, dark +mode) are per-user, not per-project, so they live here rather than in +the .iap project file. On Windows this writes to the registry under +HKCU\\Software\\DigitalSreeni\\ImageAnnotator. + +All functions accept an optional QSettings instance so tests can pass +an INI-backed temp file instead of touching the real registry. +""" + +from PyQt6.QtCore import QSettings + +FONT_PT_MIN = 8 +FONT_PT_MAX = 24 +FONT_PT_DEFAULT = 10 + +_KEY_FONT_PT = "ui/font_pt" +_KEY_DARK_MODE = "ui/dark_mode" + + +def clamp_font_pt(pt) -> int: + """Coerce any stored/passed value to a usable point size. + + QSettings round-trips values as strings on some backends, and a + hand-edited registry/INI can contain garbage — fall back to the + default rather than crash at startup. + """ + try: + pt = int(pt) + except (TypeError, ValueError): + return FONT_PT_DEFAULT + return max(FONT_PT_MIN, min(FONT_PT_MAX, pt)) + + +def _settings() -> QSettings: + return QSettings("DigitalSreeni", "ImageAnnotator") + + +def load_ui_prefs(settings=None) -> tuple[int, bool]: + """Return (font_pt, dark_mode), with defaults (10, True).""" + if settings is None: + settings = _settings() + font_pt = clamp_font_pt(settings.value(_KEY_FONT_PT, FONT_PT_DEFAULT)) + dark_mode = settings.value(_KEY_DARK_MODE, True, type=bool) + return font_pt, dark_mode + + +def save_ui_prefs(font_pt, dark_mode, settings=None) -> None: + if settings is None: + settings = _settings() + settings.setValue(_KEY_FONT_PT, clamp_font_pt(font_pt)) + settings.setValue(_KEY_DARK_MODE, bool(dark_mode)) diff --git a/src/digitalsreeni_image_annotator/controllers/__init__.py b/src/digitalsreeni_image_annotator/controllers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/digitalsreeni_image_annotator/controllers/annotation_controller.py b/src/digitalsreeni_image_annotator/controllers/annotation_controller.py new file mode 100644 index 0000000..00e52e5 --- /dev/null +++ b/src/digitalsreeni_image_annotator/controllers/annotation_controller.py @@ -0,0 +1,829 @@ +"""Annotation CRUD + UI list management controller. + +Extracted from `ImageAnnotator`. Owns the annotation list widget +plumbing, the per-image annotation cache sync (`load_image_annotations` +/ `save_current_annotations`), polygon and rectangle commit paths, +merge/delete/change-class workflows, sort & renumber, the COCO-load +path, and the edit-mode lifecycle. + +This is the cluster `ImageLabel` mutates most directly (via +`main_window.add_annotation_to_list(...)` etc.). Phase 5 keeps the +delegation pattern on `ImageAnnotator`; Phase 6 will replace +`ImageLabel`'s `main_window.*` calls with Qt signals targeting these +controller methods. + +State stays on the main window: +- `all_annotations` (dict[image_name, dict[class_name, list[ann]]]) +- `image_label.annotations` (per-image working copy) +- `editing_mode`, `loaded_json`, `current_sort_method` +- All Qt widgets (`annotation_list`, `merge_button`, `change_class_button`) +""" + +import copy +import json + +from PyQt6.QtCore import Qt, QObject +from PyQt6.QtGui import QColor +from PyQt6.QtWidgets import ( + QComboBox, + QDialog, + QDialogButtonBox, + QFileDialog, + QListWidgetItem, + QMessageBox, + QVBoxLayout, +) +from shapely.geometry import MultiPolygon, Polygon +from shapely.ops import unary_union + +from ..utils import calculate_area, calculate_bbox + + +class AnnotationController(QObject): + def __init__(self, main_window): + super().__init__(main_window) + self.mw = main_window + + # --- COCO conversion helper --- + + def create_coco_annotation(self, ann, image_id, annotation_id): + coco_ann = { + "id": annotation_id, + "image_id": image_id, + "category_id": ann["category_id"], + "area": calculate_area(ann), + "iscrowd": 0, + } + + if "segmentation" in ann: + coco_ann["segmentation"] = [ann["segmentation"]] + coco_ann["bbox"] = calculate_bbox(ann["segmentation"]) + elif "bbox" in ann: + coco_ann["bbox"] = ann["bbox"] + + return coco_ann + + # --- List widget updates --- + + def update_all_annotation_lists(self): + for image_name in self.mw.all_annotations.keys(): + self.update_annotation_list(image_name) + self.update_annotation_list() + + def update_annotation_list(self, image_name=None): + self.mw.annotation_list.clear() + current_name = image_name or self.mw.current_slice or self.mw.image_file_name + annotations = self.mw.all_annotations.get(current_name, {}) + for class_name, class_annotations in annotations.items(): + if not class_name.startswith("Temp-"): + color = self.mw.image_label.class_colors.get( + class_name, QColor(Qt.GlobalColor.white) + ) + for annotation in class_annotations: + number = annotation.get("number", 0) + area = calculate_area(annotation) + item_text = f"{class_name} - {number:<3} Area: {area:.2f}" + item = QListWidgetItem(item_text) + item.setData(Qt.ItemDataRole.UserRole, annotation) + item.setForeground(color) + self.mw.annotation_list.addItem(item) + + self.mw.annotation_list.repaint() + + def update_annotation_list_colors(self, class_name=None, color=None): + for i in range(self.mw.annotation_list.count()): + item = self.mw.annotation_list.item(i) + annotation = item.data(Qt.ItemDataRole.UserRole) + if class_name is None or annotation["category_name"] == class_name: + item_color = ( + color + if class_name + else self.mw.image_label.class_colors.get( + annotation["category_name"], QColor(Qt.GlobalColor.white) + ) + ) + item.setForeground(item_color) + + def update_annotation_list_with_sorted(self, sorted_annotations): + self.mw.annotation_list.clear() + for annotation in sorted_annotations: + class_name = annotation["category_name"] + if not class_name.startswith("Temp-"): + number = annotation.get("number", 0) + area = calculate_area(annotation) + item_text = f"{class_name} - {number:<3} Area: {area:.2f}" + item = QListWidgetItem(item_text) + item.setData(Qt.ItemDataRole.UserRole, annotation) + color = self.mw.image_label.class_colors.get( + class_name, QColor(Qt.GlobalColor.white) + ) + item.setForeground(color) + self.mw.annotation_list.addItem(item) + + self.mw.image_label.update() + + # --- Per-image annotation cache sync --- + + def load_image_annotations(self): + self.mw.image_label.annotations.clear() + current_name = self.mw.current_slice or self.mw.image_file_name + if current_name in self.mw.all_annotations: + self.mw.image_label.annotations = copy.deepcopy( + self.mw.all_annotations[current_name] + ) + else: + print(f"No annotations found for {current_name}") + self.mw.image_label.update() + + def save_current_annotations(self): + if self.mw.current_slice: + current_name = self.mw.current_slice + elif self.mw.image_file_name: + current_name = self.mw.image_file_name + else: + return + + if self.mw.image_label.annotations: + self.mw.all_annotations[current_name] = ( + self.mw.image_label.annotations.copy() + ) + elif current_name in self.mw.all_annotations: + del self.mw.all_annotations[current_name] + + self.mw.update_slice_list_colors() + + def replace_annotations(self, image_key: str, annotations: dict) -> None: + """Replace the full per-class annotation dict for one image. + Used by the eraser path which has already cut polygons in + ImageLabel.annotations. Triggers list refresh, save, and slice + colour update atomically.""" + self.mw.all_annotations[image_key] = annotations + self.update_annotation_list() + self.save_current_annotations() + self.mw.class_controller.update_slice_list_colors() + + # --- Sorting --- + + def sort_annotations_by_class(self): + current_name = self.mw.current_slice or self.mw.image_file_name + if current_name not in self.mw.all_annotations: + QMessageBox.information( + self.mw, + "No Annotations", + "There are no annotations to sort for this image.", + ) + return + + annotations = self.mw.all_annotations[current_name] + sorted_annotations = [] + for class_name in sorted(annotations.keys()): + if not class_name.startswith("Temp-"): + class_annotations = sorted( + annotations[class_name], key=lambda x: x.get("number", 0) + ) + sorted_annotations.extend(class_annotations) + + self.update_annotation_list_with_sorted(sorted_annotations) + + def sort_annotations_by_area(self): + current_name = self.mw.current_slice or self.mw.image_file_name + if current_name not in self.mw.all_annotations: + QMessageBox.information( + self.mw, + "No Annotations", + "There are no annotations to sort for this image.", + ) + return + + annotations = self.mw.all_annotations[current_name] + sorted_annotations = [] + for class_name in annotations.keys(): + if not class_name.startswith("Temp-"): + class_annotations = sorted( + annotations[class_name], + key=lambda x: calculate_area(x), + reverse=True, + ) + sorted_annotations.extend(class_annotations) + + self.update_annotation_list_with_sorted(sorted_annotations) + + # --- COCO JSON load (independent of project save/load) --- + + def load_annotations(self): + file_name, _ = QFileDialog.getOpenFileName( + self.mw, "Load Annotations", "", "JSON Files (*.json)" + ) + if not file_name: + return + + with open(file_name, "r") as f: + self.mw.loaded_json = json.load(f) + + self.mw.class_list.clear() + self.mw.image_label.class_colors.clear() + self.mw.class_mapping.clear() + for category in self.mw.loaded_json["categories"]: + class_name = category["name"] + self.mw.class_mapping[class_name] = category["id"] + + if class_name not in self.mw.image_label.class_colors: + color = QColor( + Qt.GlobalColor(len(self.mw.image_label.class_colors) % 16 + 7) + ) + self.mw.image_label.class_colors[class_name] = color + + item = QListWidgetItem(class_name) + self.mw.update_class_item_color( + item, self.mw.image_label.class_colors[class_name] + ) + self.mw.class_list.addItem(item) + + image_id_to_filename = { + img["id"]: img["file_name"] for img in self.mw.loaded_json["images"] + } + + json_images = {img["file_name"]: img for img in self.mw.loaded_json["images"]} + + updated_all_images = [] + for i in range(self.mw.image_list.count()): + item = self.mw.image_list.item(i) + file_name = item.text() + if file_name in json_images: + updated_image = self.mw.all_images[i].copy() + updated_image.update(json_images[file_name]) + updated_all_images.append(updated_image) + del json_images[file_name] + else: + updated_all_images.append(self.mw.all_images[i]) + + for img in json_images.values(): + updated_all_images.append(img) + self.mw.image_list.addItem(img["file_name"]) + + self.mw.all_images = updated_all_images + + self.mw.all_annotations.clear() + for annotation in self.mw.loaded_json["annotations"]: + image_id = annotation["image_id"] + file_name = image_id_to_filename.get(image_id) + if file_name: + if file_name not in self.mw.all_annotations: + self.mw.all_annotations[file_name] = {} + + category = next( + ( + cat + for cat in self.mw.loaded_json["categories"] + if cat["id"] == annotation["category_id"] + ), + None, + ) + if category: + category_name = category["name"] + if category_name not in self.mw.all_annotations[file_name]: + self.mw.all_annotations[file_name][category_name] = [] + + ann = { + "category_id": annotation["category_id"], + "category_name": category_name, + } + + if "segmentation" in annotation: + ann["segmentation"] = annotation["segmentation"][0] + ann["type"] = "polygon" + elif "bbox" in annotation: + ann["bbox"] = annotation["bbox"] + ann["type"] = "bbox" + + if "number" not in ann: + ann["number"] = ( + len(self.mw.all_annotations[file_name][category_name]) + 1 + ) + + self.mw.all_annotations[file_name][category_name].append(ann) + + missing_images = [ + img["file_name"] + for img in self.mw.loaded_json["images"] + if img["file_name"] not in self.mw.image_paths + ] + if missing_images: + self.mw.show_warning( + "Missing Images", + "The following images are missing:\n" + "\n".join(missing_images), + ) + + if self.mw.image_file_name and self.mw.image_file_name in self.mw.all_annotations: + self.mw.switch_image( + self.mw.image_list.findItems( + self.mw.image_file_name, Qt.MatchFlag.MatchExactly + )[0] + ) + elif self.mw.all_images: + self.mw.switch_image(self.mw.image_list.item(0)) + + self.mw.image_label.highlighted_annotations = [] + self.update_annotation_list() + self.mw.image_label.update() + + # --- Highlighting / selection --- + + def clear_highlighted_annotation(self): + self.mw.image_label.highlighted_annotations.clear() + self.mw.image_label.update() + + def update_highlighted_annotations(self): + selected_items = self.mw.annotation_list.selectedItems() + self.mw.image_label.highlighted_annotations = [ + item.data(Qt.ItemDataRole.UserRole) for item in selected_items + ] + self.mw.image_label.update() + + self.mw.merge_button.setEnabled(len(selected_items) >= 2) + self.mw.change_class_button.setEnabled(len(selected_items) > 0) + + def highlight_annotation_in_list(self, annotation): + for i in range(self.mw.annotation_list.count()): + item = self.mw.annotation_list.item(i) + if item.data(Qt.ItemDataRole.UserRole) == annotation: + self.mw.annotation_list.setCurrentItem(item) + break + + def select_annotation_in_list(self, annotation): + for i in range(self.mw.annotation_list.count()): + item = self.mw.annotation_list.item(i) + if item.data(Qt.ItemDataRole.UserRole) == annotation: + self.mw.annotation_list.setCurrentItem(item) + break + + # --- Annotation numbering --- + + def renumber_annotations(self): + current_name = self.mw.current_slice or self.mw.image_file_name + if current_name in self.mw.all_annotations: + for class_name, annotations in self.mw.all_annotations[ + current_name + ].items(): + for i, ann in enumerate(annotations, start=1): + ann["number"] = i + self.update_annotation_list() + + # --- Delete / merge / change-class --- + + def delete_annotation(self): + current_item = self.mw.annotation_list.currentItem() + if current_item: + annotation = current_item.data(Qt.ItemDataRole.UserRole) + category_name = annotation["category_name"] + self.mw.image_label.annotations[category_name].remove(annotation) + self.mw.annotation_list.takeItem( + self.mw.annotation_list.row(current_item) + ) + self.mw.image_label.highlighted_annotations.clear() + self.mw.image_label.update() + + def delete_selected_annotations(self): + selected_items = self.mw.annotation_list.selectedItems() + if not selected_items: + QMessageBox.warning( + self.mw, "No Selection", "Please select an annotation to delete." + ) + return + + reply = QMessageBox.question( + self.mw, + "Delete Annotations", + f"Are you sure you want to delete {len(selected_items)} annotation(s)?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, + ) + if reply == QMessageBox.StandardButton.Yes: + annotations_to_remove = [] + for item in selected_items: + annotation = item.data(Qt.ItemDataRole.UserRole) + annotations_to_remove.append((annotation["category_name"], annotation)) + + for category_name, annotation in annotations_to_remove: + if category_name in self.mw.image_label.annotations: + if annotation in self.mw.image_label.annotations[category_name]: + self.mw.image_label.annotations[category_name].remove( + annotation + ) + + current_name = self.mw.current_slice or self.mw.image_file_name + self.mw.all_annotations[current_name] = self.mw.image_label.annotations + + if self.mw.current_sort_method == "area": + self.sort_annotations_by_area() + else: + self.sort_annotations_by_class() + + self.mw.image_label.highlighted_annotations.clear() + self.mw.image_label.update() + + self.mw.update_slice_list_colors() + + QMessageBox.information( + self.mw, + "Annotations Deleted", + f"{len(selected_items)} annotation(s) have been deleted.", + ) + self.mw.auto_save() + + def merge_annotations(self): + if self.mw.image_label.editing_polygon is not None: + QMessageBox.warning( + self.mw, + "Edit Mode Active", + "Please exit the annotation edit mode before merging annotations.", + ) + return + + selected_items = self.mw.annotation_list.selectedItems() + if len(selected_items) < 2: + QMessageBox.warning( + self.mw, + "Not Enough Annotations", + "Please select at least two annotations to merge.", + ) + return + + class_name = selected_items[0].data(Qt.ItemDataRole.UserRole)["category_name"] + if not all( + item.data(Qt.ItemDataRole.UserRole)["category_name"] == class_name + for item in selected_items + ): + QMessageBox.warning( + self.mw, + "Mixed Classes", + "All selected annotations must be from the same class.", + ) + return + + polygons = [] + original_annotations = [] + for item in selected_items: + annotation = item.data(Qt.ItemDataRole.UserRole) + original_annotations.append(annotation) + if "segmentation" in annotation: + points = zip( + annotation["segmentation"][0::2], annotation["segmentation"][1::2] + ) + polygon = Polygon(points) + if not polygon.is_valid: + polygon = polygon.buffer(0) + polygons.append(polygon) + + def are_all_polygons_connected(polygons): + if len(polygons) < 2: + return True + + connected = set([0]) + to_check = set(range(1, len(polygons))) + + while to_check: + newly_connected = set() + for i in connected: + for j in to_check: + if polygons[i].intersects(polygons[j]) or polygons[i].touches( + polygons[j] + ): + newly_connected.add(j) + + if not newly_connected: + return False + + connected.update(newly_connected) + to_check -= newly_connected + + return True + + if not are_all_polygons_connected(polygons): + QMessageBox.warning( + self.mw, + "Disconnected Polygons", + "Not all selected annotations are connected. Please select only connected annotations to merge.", + ) + return + + try: + merged_polygon = unary_union(polygons) + except Exception as e: + QMessageBox.warning( + self.mw, + "Merge Error", + f"Unable to merge the selected annotations due to an error: {str(e)}", + ) + return + + new_annotation = { + "segmentation": [], + "category_id": self.mw.class_mapping[class_name], + "category_name": class_name, + } + + if isinstance(merged_polygon, Polygon): + new_annotation["segmentation"] = [ + coord for point in merged_polygon.exterior.coords for coord in point + ] + elif isinstance(merged_polygon, MultiPolygon): + largest_polygon = max(merged_polygon.geoms, key=lambda p: p.area) + new_annotation["segmentation"] = [ + coord for point in largest_polygon.exterior.coords for coord in point + ] + + msg_box = QMessageBox(self.mw) + msg_box.setWindowTitle("Merge Annotations") + msg_box.setText("Do you want to keep the original annotations?") + msg_box.setIcon(QMessageBox.Icon.Question) + + keep_button = msg_box.addButton("Keep", QMessageBox.ButtonRole.YesRole) + delete_button = msg_box.addButton("Delete", QMessageBox.ButtonRole.NoRole) + cancel_button = msg_box.addButton("Cancel", QMessageBox.ButtonRole.RejectRole) + + msg_box.setDefaultButton(cancel_button) + msg_box.setEscapeButton(cancel_button) + + msg_box.exec() + + if msg_box.clickedButton() == cancel_button: + return + + if msg_box.clickedButton() == delete_button: + for annotation in original_annotations: + if annotation in self.mw.image_label.annotations[class_name]: + self.mw.image_label.annotations[class_name].remove(annotation) + + self.mw.image_label.annotations.setdefault(class_name, []).append(new_annotation) + + current_name = self.mw.current_slice or self.mw.image_file_name + self.mw.all_annotations[current_name] = self.mw.image_label.annotations + + self.renumber_annotations() + self.update_annotation_list() + self.save_current_annotations() + self.mw.update_slice_list_colors() + self.mw.image_label.update() + + QMessageBox.information( + self.mw, "Merge Complete", "Annotations have been merged successfully." + ) + self.mw.auto_save() + + def change_annotation_class(self): + selected_items = self.mw.annotation_list.selectedItems() + if not selected_items: + QMessageBox.warning( + self.mw, + "No Selection", + "Please select one or more annotations to change class.", + ) + return + + class_dialog = QDialog(self.mw) + class_dialog.setWindowTitle("Change Class") + layout = QVBoxLayout(class_dialog) + + class_combo = QComboBox() + for class_name in self.mw.class_mapping.keys(): + class_combo.addItem(class_name) + layout.addWidget(class_combo) + + button_box = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel + ) + button_box.accepted.connect(class_dialog.accept) + button_box.rejected.connect(class_dialog.reject) + layout.addWidget(button_box) + + if class_dialog.exec() == QDialog.DialogCode.Accepted: + new_class = class_combo.currentText() + current_name = self.mw.current_slice or self.mw.image_file_name + + max_number = max( + [ + ann.get("number", 0) + for ann in self.mw.image_label.annotations.get(new_class, []) + ] + + [0] + ) + + for item in selected_items: + annotation = item.data(Qt.ItemDataRole.UserRole) + old_class = annotation["category_name"] + + self.mw.image_label.annotations[old_class].remove(annotation) + if not self.mw.image_label.annotations[old_class]: + del self.mw.image_label.annotations[old_class] + + annotation["category_name"] = new_class + annotation["category_id"] = self.mw.class_mapping[new_class] + max_number += 1 + annotation["number"] = max_number + if new_class not in self.mw.image_label.annotations: + self.mw.image_label.annotations[new_class] = [] + self.mw.image_label.annotations[new_class].append(annotation) + + self.mw.all_annotations[current_name] = self.mw.image_label.annotations + + self.renumber_annotations() + + self.update_annotation_list() + self.mw.image_label.update() + self.save_current_annotations() + self.mw.update_slice_list_colors() + self.mw.auto_save() + + QMessageBox.information( + self.mw, + "Class Changed", + f"Selected annotations have been changed to class '{new_class}'.", + ) + + # --- Commit paths for the drawing tools --- + + def finish_polygon(self): + if ( + self.mw.image_label.current_tool == "polygon" + and len(self.mw.image_label.current_annotation) > 2 + ): + if self.mw.current_class is None: + QMessageBox.warning( + self.mw, + "No Class Selected", + "Please select a class before finishing the annotation.", + ) + return + + polygon = Polygon(self.mw.image_label.current_annotation) + + image_boundary = Polygon( + [ + (0, 0), + (self.mw.current_image.width(), 0), + (self.mw.current_image.width(), self.mw.current_image.height()), + (0, self.mw.current_image.height()), + ] + ) + + clipped_polygon = polygon.intersection(image_boundary) + + if clipped_polygon.is_empty: + QMessageBox.warning( + self.mw, + "Invalid Annotation", + "The annotation is completely outside the image boundaries.", + ) + self.mw.image_label.clear_current_annotation() + self.mw.image_label.update() + return + + if isinstance(clipped_polygon, Polygon): + segmentation = [ + coord + for point in clipped_polygon.exterior.coords + for coord in point + ] + elif isinstance(clipped_polygon, MultiPolygon): + largest_polygon = max(clipped_polygon.geoms, key=lambda p: p.area) + segmentation = [ + coord + for point in largest_polygon.exterior.coords + for coord in point + ] + else: + QMessageBox.warning( + self.mw, + "Invalid Annotation", + "The annotation could not be processed.", + ) + return + + new_annotation = { + "segmentation": segmentation, + "category_id": self.mw.class_mapping[self.mw.current_class], + "category_name": self.mw.current_class, + } + self.mw.image_label.annotations.setdefault( + self.mw.current_class, [] + ).append(new_annotation) + self.add_annotation_to_list(new_annotation) + self.mw.image_label.clear_current_annotation() + self.mw.image_label.drawing_polygon = False + self.mw.image_label.reset_annotation_state() + self.mw.image_label.update() + + self.save_current_annotations() + + self.mw.update_slice_list_colors() + self.mw.auto_save() + + def finish_rectangle(self): + if self.mw.image_label.current_rectangle: + x1, y1, x2, y2 = self.mw.image_label.current_rectangle + + rectangle = Polygon([(x1, y1), (x2, y1), (x2, y2), (x1, y2)]) + + image_boundary = Polygon( + [ + (0, 0), + (self.mw.current_image.width(), 0), + (self.mw.current_image.width(), self.mw.current_image.height()), + (0, self.mw.current_image.height()), + ] + ) + + clipped_rectangle = rectangle.intersection(image_boundary) + + if clipped_rectangle.is_empty: + QMessageBox.warning( + self.mw, + "Invalid Annotation", + "The annotation is completely outside the image boundaries.", + ) + self.mw.image_label.current_rectangle = None + self.mw.image_label.update() + return + + if isinstance(clipped_rectangle, Polygon): + segmentation = [ + coord + for point in clipped_rectangle.exterior.coords + for coord in point + ] + elif isinstance(clipped_rectangle, MultiPolygon): + largest_polygon = max(clipped_rectangle.geoms, key=lambda p: p.area) + segmentation = [ + coord + for point in largest_polygon.exterior.coords + for coord in point + ] + else: + QMessageBox.warning( + self.mw, + "Invalid Annotation", + "The annotation could not be processed.", + ) + return + + new_annotation = { + "segmentation": segmentation, + "category_id": self.mw.class_mapping[self.mw.current_class], + "category_name": self.mw.current_class, + } + self.mw.image_label.annotations.setdefault( + self.mw.current_class, [] + ).append(new_annotation) + self.add_annotation_to_list(new_annotation) + self.mw.image_label.start_point = None + self.mw.image_label.end_point = None + self.mw.image_label.current_rectangle = None + self.mw.image_label.update() + + self.save_current_annotations() + + self.mw.update_slice_list_colors() + self.mw.auto_save() + + def add_annotation_to_list(self, annotation): + class_name = annotation["category_name"] + color = self.mw.image_label.class_colors.get( + class_name, QColor(Qt.GlobalColor.white) + ) + annotations = self.mw.image_label.annotations.get(class_name, []) + number = max([ann.get("number", 0) for ann in annotations] + [0]) + 1 + annotation["number"] = number + area = calculate_area(annotation) + item_text = f"{class_name} - {number:<3} Area: {area:.2f}" + + item = QListWidgetItem(item_text) + item.setData(Qt.ItemDataRole.UserRole, annotation) + item.setForeground(color) + self.mw.annotation_list.addItem(item) + + self.mw.annotation_list.clearSelection() + self.mw.image_label.highlighted_annotations.clear() + self.mw.image_label.update() + + # --- Edit mode --- + + def enter_edit_mode(self, annotation): + self.mw.editing_mode = True + self.mw.disable_tools() + + QMessageBox.information( + self.mw, + "Edit Mode", + "You are now in edit mode. Click and drag points to move them, Shift+Click to delete points, or click on edges to add new points.", + ) + + def exit_edit_mode(self): + self.mw.editing_mode = False + self.mw.enable_tools() + + self.mw.image_label.editing_polygon = None + self.mw.image_label.editing_point_index = None + self.mw.image_label.hover_point_index = None + self.update_annotation_list() + self.mw.image_label.update() diff --git a/src/digitalsreeni_image_annotator/controllers/class_controller.py b/src/digitalsreeni_image_annotator/controllers/class_controller.py new file mode 100644 index 0000000..d45b2c4 --- /dev/null +++ b/src/digitalsreeni_image_annotator/controllers/class_controller.py @@ -0,0 +1,422 @@ +"""Class management controller (add / delete / rename / colour / +visibility) plus the slice-list colouring driven by per-slice +annotations. + +Extracted from `ImageAnnotator`. Owns the class list widget plumbing, +context menu, programmatic and interactive class addition (with DINO +phrase-panel + threshold-table sync), and the slice-list colouring +that highlights annotated slices. + +State stays on the main window (consistent with prior phases): +- `class_mapping` (dict[name, id]) +- `image_label.class_colors`, `image_label.class_visibility` +- `current_class` +- `class_list`, `slice_list` widgets +- DINO widgets (`dino_class_table`, `dino_phrase_panel`) +""" + +import traceback + +from PyQt6.QtCore import Qt, QObject +from PyQt6.QtGui import QColor, QIcon, QPixmap +from PyQt6.QtWidgets import ( + QColorDialog, + QInputDialog, + QListWidgetItem, + QMenu, + QMessageBox, +) + + +class ClassController(QObject): + def __init__(self, main_window): + super().__init__(main_window) + self.mw = main_window + + def select_class(self, index): + if 0 <= index < self.mw.class_list.count(): + item = self.mw.class_list.item(index) + self.mw.class_list.setCurrentItem(item) + self.mw.current_class = item.text() + print(f"Selected class: {self.mw.current_class}") + else: + print("Invalid class index") + + def delete_selected_class(self): + selected_items = self.mw.class_list.selectedItems() + if not selected_items: + QMessageBox.warning( + self.mw, "No Selection", "Please select a class to delete." + ) + return + + class_name = selected_items[0].text() + reply = QMessageBox.question( + self.mw, + "Delete Class", + f"Are you sure you want to delete the class '{class_name}'?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, + ) + if reply == QMessageBox.StandardButton.Yes: + self.delete_class(class_name) + + def update_slice_list_colors(self): + if self.mw.dark_mode: + self.mw.slice_list.setStyleSheet( + "QListWidget { background-color: rgb(40, 40, 40); }" + ) + else: + self.mw.slice_list.setStyleSheet( + "QListWidget { background-color: rgb(240, 240, 240); }" + ) + + for i in range(self.mw.slice_list.count()): + item = self.mw.slice_list.item(i) + slice_name = item.text() + + if self.mw.dark_mode: + if slice_name in self.mw.all_annotations and any( + self.mw.all_annotations[slice_name].values() + ): + item.setForeground(QColor(235, 235, 235)) + item.setBackground(QColor(58, 95, 140)) + else: + item.setForeground(QColor(200, 200, 200)) + item.setBackground(QColor(40, 40, 40)) + else: + if slice_name in self.mw.all_annotations and any( + self.mw.all_annotations[slice_name].values() + ): + item.setForeground(QColor(255, 255, 255)) + item.setBackground(QColor(70, 130, 180)) + else: + item.setForeground(QColor(0, 0, 0)) + item.setBackground(QColor(240, 240, 240)) + + self.mw.slice_list.repaint() + + def add_class(self, class_name=None, color=None): + if not self.mw.image_label.check_unsaved_changes(): + return + + if class_name is None: + while True: + class_name, ok = QInputDialog.getText( + self.mw, "Add Class", "Enter class name:" + ) + if not ok: + print("Class addition cancelled") + return + if not class_name.strip(): + QMessageBox.warning( + self.mw, + "Invalid Input", + "Please enter a class name or press Cancel.", + ) + continue + if class_name in self.mw.class_mapping: + QMessageBox.warning( + self.mw, + "Duplicate Class", + f"The class '{class_name}' already exists. Please choose a different name.", + ) + continue + break + else: + if class_name in self.mw.class_mapping: + print(f"Class '{class_name}' already exists. Skipping addition.") + return + + if not isinstance(class_name, str): + print( + f"Warning: class_name is not a string. Converting {class_name} to string." + ) + class_name = str(class_name) + + if color is None: + color = QColor( + Qt.GlobalColor(len(self.mw.image_label.class_colors) % 16 + 7) + ) + elif isinstance(color, str): + color = QColor(color) + + print(f"Adding class: {class_name}, color: {color.name()}") + + self.mw.image_label.class_colors[class_name] = color + self.mw.class_mapping[class_name] = len(self.mw.class_mapping) + 1 + + try: + item = QListWidgetItem(class_name) + + pixmap = QPixmap(16, 16) + pixmap.fill(color) + item.setIcon(QIcon(pixmap)) + + item.setData(Qt.ItemDataRole.UserRole, True) + + item.setFlags(item.flags() | Qt.ItemFlag.ItemIsUserCheckable) + item.setCheckState(Qt.CheckState.Checked) + + self.mw.class_list.addItem(item) + + self.mw.class_list.setCurrentItem(item) + self.mw.current_class = class_name + print(f"Class added successfully: {class_name}") + + # DINO phrase/threshold sync. Skip the row-select during + # project load (classes are added in a loop and we don't + # want N row-selection signals firing during bulk restoration). + row_added = self.mw.dino_class_table.add_class(class_name) + self.mw.dino_phrase_panel.on_class_added(class_name) + if row_added and not self.mw.is_loading_project: + self.mw.dino_class_table.selectRow( + self.mw.dino_class_table.rowCount() - 1 + ) + + if not self.mw.is_loading_project: + self.mw.auto_save() + except Exception as e: + print(f"Error adding class: {e}") + traceback.print_exc() + + def update_class_item_color(self, item, color): + pixmap = QPixmap(16, 16) + pixmap.fill(color) + item.setIcon(QIcon(pixmap)) + + def update_class_list(self): + self.mw.class_list.clear() + for class_name, color in self.mw.image_label.class_colors.items(): + item = QListWidgetItem(class_name) + + pixmap = QPixmap(16, 16) + pixmap.fill(color) + item.setIcon(QIcon(pixmap)) + + item.setData( + Qt.ItemDataRole.UserRole, + self.mw.image_label.class_visibility.get(class_name, True), + ) + + item.setFlags(item.flags() | Qt.ItemFlag.ItemIsUserCheckable) + item.setCheckState( + Qt.CheckState.Checked + if item.data(Qt.ItemDataRole.UserRole) + else Qt.CheckState.Unchecked + ) + + self.mw.class_list.addItem(item) + + if self.mw.current_class: + items = self.mw.class_list.findItems( + self.mw.current_class, Qt.MatchFlag.MatchExactly + ) + if items: + self.mw.class_list.setCurrentItem(items[0]) + elif self.mw.class_list.count() > 0: + self.mw.class_list.setCurrentItem(self.mw.class_list.item(0)) + + print(f"Updated class list with {self.mw.class_list.count()} items") + + def update_class_selection(self): + for i in range(self.mw.class_list.count()): + item = self.mw.class_list.item(i) + if item.text() == self.mw.current_class: + item.setSelected(True) + else: + item.setSelected(False) + + def toggle_class_visibility(self, item): + class_name = item.text() + is_visible = item.checkState() == Qt.CheckState.Checked + self.mw.image_label.set_class_visibility(class_name, is_visible) + item.setData(Qt.ItemDataRole.UserRole, is_visible) + self.mw.image_label.update() + + def on_class_selected(self, current=None, previous=None): + if not self.mw.image_label.check_unsaved_changes(): + return + + if current is None: + current = self.mw.class_list.currentItem() + + if current: + self.mw.current_class = current.text() + print(f"Class selected: {self.mw.current_class}") + + if self.mw.current_class.startswith("Temp-"): + self.mw.disable_annotation_tools() + else: + self.mw.enable_annotation_tools() + else: + self.mw.current_class = None + self.mw.disable_annotation_tools() + + def show_class_context_menu(self, position): + menu = QMenu() + rename_action = menu.addAction("Rename Class") + change_color_action = menu.addAction("Change Color") + delete_action = menu.addAction("Delete Class") + + item = self.mw.class_list.itemAt(position) + if item: + action = menu.exec(self.mw.class_list.mapToGlobal(position)) + + if action == rename_action: + self.rename_class(item) + elif action == change_color_action: + self.change_class_color(item) + elif action == delete_action: + self.delete_class(item) + else: + QMessageBox.warning( + self.mw, + "No Selection", + "Please select a class to perform actions.", + ) + + def change_class_color(self, item): + class_name = item.text() + current_color = self.mw.image_label.class_colors.get( + class_name, QColor(Qt.GlobalColor.white) + ) + color = QColorDialog.getColor( + current_color, self.mw, f"Select Color for {class_name}" + ) + + if color.isValid(): + self.mw.image_label.class_colors[class_name] = color + + pixmap = QPixmap(16, 16) + pixmap.fill(color) + item.setIcon(QIcon(pixmap)) + + self.mw.update_annotation_list_colors(class_name, color) + self.mw.image_label.update() + self.mw.auto_save() + + def rename_class(self, item): + old_name = item.text() + new_name, ok = QInputDialog.getText( + self.mw, "Rename Class", "Enter new class name:", text=old_name + ) + if ok and new_name and new_name != old_name: + if old_name in self.mw.class_mapping: + old_id = self.mw.class_mapping[old_name] + self.mw.class_mapping[new_name] = old_id + del self.mw.class_mapping[old_name] + else: + print(f"Warning: Class '{old_name}' not found in class_mapping") + return + + if old_name in self.mw.image_label.class_colors: + self.mw.image_label.class_colors[new_name] = ( + self.mw.image_label.class_colors.pop(old_name) + ) + else: + print(f"Warning: Class '{old_name}' not found in class_colors") + return + + for image_name, image_annotations in self.mw.all_annotations.items(): + if old_name in image_annotations: + image_annotations[new_name] = image_annotations.pop(old_name) + for annotation in image_annotations[new_name]: + annotation["category_name"] = new_name + + if old_name in self.mw.image_label.annotations: + self.mw.image_label.annotations[new_name] = ( + self.mw.image_label.annotations.pop(old_name) + ) + for annotation in self.mw.image_label.annotations[new_name]: + annotation["category_name"] = new_name + + if self.mw.current_class == old_name: + self.mw.current_class = new_name + + self.mw.update_all_annotation_lists() + + item.setText(new_name) + + self.mw.image_label.update() + self.mw.auto_save() + + print(f"Class renamed from '{old_name}' to '{new_name}'") + + def delete_class(self, item=None): + if item is None: + item = self.mw.class_list.currentItem() + + if item is None: + QMessageBox.warning( + self.mw, "No Selection", "Please select a class to delete." + ) + return + + # delete_selected_class calls self.delete_class(class_name) with a + # string instead of a QListWidgetItem — handle both. The + # show_class_context_menu / Delete key path passes a QListWidgetItem, + # while delete_selected_class passes the class name string. + if isinstance(item, str): + class_name = item + row_items = self.mw.class_list.findItems(class_name, Qt.MatchFlag.MatchExactly) + list_item = row_items[0] if row_items else None + else: + class_name = item.text() + list_item = item + + reply = QMessageBox.question( + self.mw, + "Delete Class", + f"Are you sure you want to delete the class '{class_name}'?\n\n" + "This will remove all annotations associated with this class.", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, + ) + + if reply == QMessageBox.StandardButton.Yes: + self.mw.image_label.class_colors.pop(class_name, None) + self.mw.class_mapping.pop(class_name, None) + + for image_annotations in self.mw.all_annotations.values(): + image_annotations.pop(class_name, None) + + self.mw.image_label.annotations.pop(class_name, None) + + self.mw.dino_class_table.remove_class(class_name) + self.mw.dino_phrase_panel.on_class_removed(class_name) + + self.mw.update_annotation_list() + + if list_item is not None: + row = self.mw.class_list.row(list_item) + self.mw.class_list.takeItem(row) + + if self.mw.current_class == class_name: + self.mw.current_class = None + if self.mw.class_list.count() > 0: + self.mw.class_list.setCurrentRow(0) + self.on_class_selected(self.mw.class_list.item(0)) + else: + self.mw.disable_annotation_tools() + + self.mw.image_label.update() + + QMessageBox.information( + self.mw, + "Class Deleted", + f"The class '{class_name}' has been deleted.", + ) + self.mw.auto_save() + else: + QMessageBox.information( + self.mw, + "Deletion Cancelled", + "The class deletion was cancelled.", + ) + + def is_class_visible(self, class_name): + items = self.mw.class_list.findItems(class_name, Qt.MatchFlag.MatchExactly) + if items: + return items[0].checkState() == Qt.CheckState.Checked + return False diff --git a/src/digitalsreeni_image_annotator/controllers/dino_controller.py b/src/digitalsreeni_image_annotator/controllers/dino_controller.py new file mode 100644 index 0000000..b21edd8 --- /dev/null +++ b/src/digitalsreeni_image_annotator/controllers/dino_controller.py @@ -0,0 +1,821 @@ +"""DINO (LLM-assisted detection) coordination controller. + +Extracted from `ImageAnnotator`. Owns: + +- DINO model picker plumbing (preset / custom-path resolution, + on-demand HuggingFace Hub download) +- Single-image and batch detection workflows (DINO produces bboxes → + SAM refines to masks) +- Temp-annotation review state: accept / reject pending DINO results, + navigate batch review across mixed regular-images + multi-dim slices +- The application-wide `DINOReviewEventFilter` that lets Enter / + Escape accept-or-reject pending DINO masks regardless of which + widget has focus + +State (`dino_utils`, `dino_model_loaded`, `dino_custom_model_path`, +`dino_batch_results`) stays on the main window in this phase — same +deferral as prior controllers. Widgets that own DINO configuration +(`dino_phrase_panel`, `dino_class_table`, `dino_model_selector`, +`dino_batch_mode`, `lbl_dino_status`, `btn_detect_*`, `dino_browse_row`, +`lbl_dino_custom`) also stay on the main window. + +The temp-annotation review machinery (Temp-* class handling) lives +here too — it was originally a separate workflow for YOLO predictions +but is now shared with DINO and most easily co-located. +""" + +import os +import traceback + +from PyQt6.QtCore import QEvent, QObject, Qt, QTimer +from PyQt6.QtGui import QColor, QImage +from PyQt6.QtWidgets import ( + QApplication, + QFileDialog, + QLineEdit, + QMessageBox, + QProgressDialog, + QTextEdit, +) + + +class DINOReviewEventFilter(QObject): + """Application-wide event filter that lets Enter / Escape accept or + reject pending DINO temp_annotations regardless of which widget has + focus. Without this, clicking a slice/image entry in a list moves + focus there and Enter is consumed by the list's itemActivated + handler before it can reach ImageLabel.keyPressEvent. + + Suppressed when a modal dialog is active or focus is on a text-input + widget so we don't break dialog default-button behaviour or + in-cell editing. + """ + + def __init__(self, main_window): + super().__init__(main_window) + self.main_window = main_window + + def eventFilter(self, obj, event): + if event.type() != QEvent.Type.KeyPress: + return False + key = event.key() + if key not in (Qt.Key.Key_Return, Qt.Key.Key_Enter, Qt.Key.Key_Escape): + return False + app = QApplication.instance() + if app is None or app.activeModalWidget() is not None: + return False + focused = app.focusWidget() + if isinstance(focused, (QLineEdit, QTextEdit)): + return False + temp = self.main_window.image_label.temp_annotations + if not temp or not any(a.get("source") == "dino" for a in temp): + return False + if key in (Qt.Key.Key_Return, Qt.Key.Key_Enter): + self.main_window.accept_dino_results() + else: + self.main_window.reject_dino_results() + return True + + +class DINOController(QObject): + def __init__(self, main_window): + super().__init__(main_window) + self.mw = main_window + + # --- Model picker plumbing --- + + def _resolve_dino_model_path(self, model_name): + """Return the canonical local path for a preset DINO model, or None if unknown.""" + from ..inference.dino_utils import GDINO_MODEL_PATHS + return GDINO_MODEL_PATHS.get(model_name) + + def _on_dino_model_changed(self, text): + """Selection → ready state. Downloads happen lazily on first Detect.""" + self.mw.dino_browse_row.setVisible(text == "Custom / fine-tuned (browse)") + + if text == "Pick a DINO Model": + self.mw.dino_model_loaded = False + self.mw.lbl_dino_status.setText("No DINO model loaded") + self.mw.btn_detect_single.setEnabled(False) + self.mw.btn_detect_batch.setEnabled(False) + return + + if text == "Custom / fine-tuned (browse)": + if ( + self.mw.dino_custom_model_path + and os.path.exists(self.mw.dino_custom_model_path) + ): + self.mw.dino_model_loaded = True + self.mw.lbl_dino_status.setText( + f"Ready: {os.path.basename(self.mw.dino_custom_model_path)}" + ) + self.mw.btn_detect_single.setEnabled(True) + self.mw.btn_detect_batch.setEnabled(True) + else: + self.mw.dino_model_loaded = False + self.mw.lbl_dino_status.setText("Browse for a custom model folder") + self.mw.btn_detect_single.setEnabled(False) + self.mw.btn_detect_batch.setEnabled(False) + return + + self.mw.dino_model_loaded = True + self.mw.btn_detect_single.setEnabled(True) + self.mw.btn_detect_batch.setEnabled(True) + model_path = self._resolve_dino_model_path(text) + if model_path and os.path.exists(model_path): + self.mw.lbl_dino_status.setText(f"Ready: {text}") + else: + self.mw.lbl_dino_status.setText(f"{text} — will download on first detection") + + def _ensure_dino_model_downloaded(self, model_name): + """If the preset model isn't on disk yet, download it. Returns success.""" + if model_name in ("Pick a DINO Model", "Custom / fine-tuned (browse)"): + return True + model_path = self._resolve_dino_model_path(model_name) + if model_path and os.path.exists(model_path): + return True + + try: + import huggingface_hub # noqa: F401 + except ImportError: + QMessageBox.critical( + self.mw, "Missing Dependency", + f"Cannot download {model_name}: the huggingface_hub package " + "is not installed.\n\nRun:\n pip install huggingface_hub", + ) + return False + + self.mw.lbl_dino_status.setText(f"Downloading {model_name}...") + QApplication.processEvents() + try: + downloaded = self.mw.dino_utils.download_model(model_name) + except Exception as e: + QMessageBox.critical(self.mw, "Download Failed", f"{model_name}:\n{e}") + return False + if not downloaded: + QMessageBox.critical( + self.mw, "Download Failed", + f"Could not download {model_name} from Hugging Face Hub.", + ) + return False + return True + + def browse_dino_model(self): + path = QFileDialog.getExistingDirectory(self.mw, "Select DINO Model Folder") + if path: + self.mw.dino_custom_model_path = path + self.mw.lbl_dino_custom.setText(os.path.basename(path)) + self._on_dino_model_changed(self.mw.dino_model_selector.currentText()) + + def on_dino_class_row_changed(self): + name = self.mw.dino_class_table.selected_class_name() + self.mw.dino_phrase_panel.set_active_class(name) + + def _build_dino_class_configs(self): + """Build class_configs from threshold table + phrase panel.""" + configs = [] + for cfg in self.mw.dino_class_table.get_class_configs(): + phrases = self.mw.dino_phrase_panel.get_phrases_for(cfg["name"]) + configs.append({ + "name": cfg["name"], + "phrases": phrases, + "box_thr": cfg["box_thr"], + "txt_thr": cfg["txt_thr"], + "nms_thr": cfg["nms_thr"], + }) + return configs + + # --- Detection workflows --- + + def run_dino_detection_single(self): + if not self.mw.dino_model_loaded: + QMessageBox.warning(self.mw, "No DINO Model", + "Please pick a DINO model first.") + return + if not self.mw.sam_utils.current_sam_model: + QMessageBox.warning( + self.mw, "No SAM Model", + "DINO produces bounding boxes; SAM is needed to convert them " + "into segmentation masks. Please pick a SAM model first.", + ) + return + if not self.mw.current_image or self.mw.current_image.isNull(): + QMessageBox.warning(self.mw, "No Image", + "Please load an image first.") + return + + model_name = self.mw.dino_model_selector.currentText() + class_configs = self._build_dino_class_configs() + if not class_configs: + QMessageBox.warning(self.mw, "No Classes", + "Please add at least one class with phrases.") + return + + self.mw.btn_detect_single.setEnabled(False) + self.mw.btn_detect_batch.setEnabled(False) + + # Clear any stale temp annotations before starting detection so an + # accept from a previous run doesn't bleed into the results handler. + self.mw.image_label.temp_annotations = [] + + if not self._ensure_dino_model_downloaded(model_name): + self.mw.btn_detect_single.setEnabled(True) + self.mw.btn_detect_batch.setEnabled(True) + return + + self.mw.lbl_dino_status.setText("Detecting...") + QApplication.processEvents() + + print(f"[DINO] detect_single: model={model_name!r} class_configs={class_configs}") + try: + results = self.mw.dino_utils.detect( + self.mw.current_image, class_configs, + model_name=model_name, + custom_model_path=self.mw.dino_custom_model_path, + ) + except Exception as e: + traceback.print_exc() + QMessageBox.critical(self.mw, "DINO Error", str(e)) + self.mw.btn_detect_single.setEnabled(True) + self.mw.btn_detect_batch.setEnabled(True) + self.mw.lbl_dino_status.setText("Detection failed.") + return + + self.mw.btn_detect_single.setEnabled(True) + self.mw.btn_detect_batch.setEnabled(True) + + if results is None: + print("[DINO] detect_single: results=None (model resolution failure)") + self.mw.lbl_dino_status.setText("No detections.") + return + + print(f"[DINO] detect_single: got {len(results)} result(s)") + if results: + for i, r in enumerate(results[:3]): + print(f"[DINO] result[{i}] class={r['class_name']!r} score={r['score']:.3f} bbox={r['bbox']}") + + if not results: + self.mw.lbl_dino_status.setText("No detections found.") + return + + self.mw.lbl_dino_status.setText(f"{len(results)} detection(s). Running SAM...") + QApplication.processEvents() + + bboxes = [r["bbox"] for r in results] + print(f"[SAM] batch call: {len(bboxes)} bbox(es), first 3 = {bboxes[:3]}") + try: + sam_results = self.mw.sam_utils.apply_sam_predictions_batch( + self.mw.current_image, bboxes + ) + except Exception as e: + traceback.print_exc() + QMessageBox.critical(self.mw, "SAM Error", str(e)) + self.mw.lbl_dino_status.setText("SAM segmentation failed.") + return + + if sam_results is None: + print("[SAM] batch returned None (no SAM model loaded)") + QMessageBox.warning(self.mw, "SAM Error", + "Failed to segment detections with SAM.") + self.mw.lbl_dino_status.setText("SAM segmentation failed.") + return + + n_errors = sum(1 for s in sam_results if "error" in s) + n_ok = sum(1 for s in sam_results if "segmentation" in s) + print(f"[SAM] batch returned {len(sam_results)} result(s): {n_ok} ok, {n_errors} error(s)") + + # Honor the batch-mode dropdown for the single-image case too: + # "Auto-accept" means commit straight to annotations without + # showing the temp-review overlay. The dropdown name is "batch" + # historically but it controls both paths. + image_name = self.mw.current_slice or self.mw.image_file_name + auto_accept = ( + self.mw.dino_batch_mode.currentText() == "Auto-accept all detections" + ) + if auto_accept: + print(f"[DINO] detect_single: auto_accept=True, committing {len(results)} result(s)") + try: + self._commit_dino_results(image_name, results, sam_results) + except Exception as e: + print(f"[DINO] _commit_dino_results failed: {e}") + traceback.print_exc() + n_committed = sum(1 for s in sam_results if "error" not in s) + self.mw.image_label.temp_annotations = [] + self.mw.image_label.update() + self.mw.update_annotation_list() + # Refresh slice list so the freshly-annotated slice picks + # up the highlight color; review-mode's accept_dino_results + # already does this, the auto-accept path didn't. + self.mw.update_slice_list_colors() + self.mw.auto_save() + self.mw.lbl_dino_status.setText( + f"Loaded: {model_name} | {n_committed} mask(s) auto-accepted" + ) + print(f"[DINO] auto-accept: committed {n_committed} mask(s) to {image_name}") + return + + # Review mode + temp_annotations = [] + for r, s in zip(results, sam_results): + if "error" in s: + print(f"[SAM] failed for {r['class_name']}: {s['error']}") + continue + temp_annotations.append({ + "segmentation": s["segmentation"], + "category_name": r["class_name"], + "score": r["score"], + "source": "dino", + "temp": True, + }) + + self.mw.image_label.temp_annotations = temp_annotations + QTimer.singleShot(0, self.mw.image_label.setFocus) + self.mw.image_label.update() + self.mw.lbl_dino_status.setText( + f"Loaded: {model_name} | {len(temp_annotations)} mask(s) ready" + ) + print(f"[DINO] detection complete: {len(results)} boxes, {len(temp_annotations)} masks attached to canvas") + + def run_dino_detection_batch(self): + if not self.mw.dino_model_loaded: + QMessageBox.warning(self.mw, "No DINO Model", + "Please pick a DINO model first.") + return + if not self.mw.sam_utils.current_sam_model: + QMessageBox.warning( + self.mw, "No SAM Model", + "DINO produces bounding boxes; SAM is needed to convert them " + "into segmentation masks. Please pick a SAM model first.", + ) + return + if not self.mw.all_images: + QMessageBox.warning(self.mw, "No Images", + "Please load images first.") + return + + model_name = self.mw.dino_model_selector.currentText() + class_configs = self._build_dino_class_configs() + if not class_configs: + QMessageBox.warning(self.mw, "No Classes", + "Please add at least one class with phrases.") + return + + # Prevent stale temp annotations from a prior single-image review from + # confusing the batch results handler or the DINOReviewEventFilter. + self.mw.image_label.temp_annotations = [] + + if not self._ensure_dino_model_downloaded(model_name): + return + + auto_accept = ( + self.mw.dino_batch_mode.currentText() == "Auto-accept all detections" + ) + print(f"[DINO] detect_batch: auto_accept={auto_accept}") + + # Build a flat list of (display_name, qimage) work items covering + # both regular images (loaded from disk) and multi-dim image + # slices (already QImages in memory). Slices live in + # self.mw.image_slices[base_name], indexed by their slice_name + # (e.g. "stack_T1_Z1_C1"). The earlier implementation only + # iterated self.all_images and skipped multi-slice entries with + # a console warning, leaving slice-based projects unable to use + # Detect All. + work_items = self._collect_dino_batch_work_items() + if not work_items: + QMessageBox.information( + self.mw, "Detect All Images", + "No images or slices available to process." + ) + return + total = len(work_items) + + progress = QProgressDialog("Running LLM Detection...", "Cancel", 0, total, self.mw) + progress.setWindowModality(Qt.WindowModality.WindowModal) + progress.setMinimumDuration(0) + + for idx, (image_name, qimage) in enumerate(work_items): + if progress.wasCanceled(): + break + progress.setValue(idx) + QApplication.processEvents() + + try: + results = self.mw.dino_utils.detect( + qimage, class_configs, + model_name=model_name, + custom_model_path=self.mw.dino_custom_model_path, + ) + except Exception as e: + print(f" DINO failed for {image_name}: {e}") + continue + + if not results: + continue + + bboxes = [r["bbox"] for r in results] + try: + sam_results = self.mw.sam_utils.apply_sam_predictions_batch( + qimage, bboxes + ) + except Exception as e: + print(f" SAM failed for {image_name}: {e}") + continue + if sam_results is None: + continue + + if auto_accept: + self._commit_dino_results(image_name, results, sam_results) + else: + self._store_dino_batch_results(image_name, results, sam_results) + + progress.setValue(total) + progress.close() + + if auto_accept: + QMessageBox.information( + self.mw, "Batch Detection Complete", + "Detections have been saved to annotations." + ) + self.mw.update_annotation_list() + self.mw.update_slice_list_colors() + self.mw.auto_save() + else: + self._show_dino_batch_review() + + def _collect_dino_batch_work_items(self): + """Return a flat ``[(name, QImage), …]`` list for batch DINO. + + Regular images are loaded from disk via PIL → QImage. Multi-dim + images contribute one entry per slice from ``self.mw.image_slices``; + slices that haven't been materialised yet (the parent image was + never opened in this session) are skipped with a console log. + """ + from PIL import Image as PILImage + items = [] + for img_info in self.mw.all_images: + file_name = img_info["file_name"] + if img_info.get("is_multi_slice", False): + base_name = os.path.splitext(file_name)[0] + slices = self.mw.image_slices.get(base_name, []) + if not slices: + print(f" Skipping multi-slice image '{file_name}': " + "no slices loaded (open the image first to " + "materialise its slices).") + continue + for slice_name, qimage in slices: + items.append((slice_name, qimage)) + else: + image_path = self.mw.image_paths.get(file_name) + if not image_path or not os.path.exists(image_path): + print(f" Skipping '{file_name}': missing image path.") + continue + try: + pil_img = PILImage.open(image_path).convert("RGB") + qimage = QImage( + pil_img.tobytes(), + pil_img.width, + pil_img.height, + pil_img.width * 3, + QImage.Format.Format_RGB888, + ) + items.append((file_name, qimage)) + except Exception as e: + print(f" Skipping '{file_name}': failed to load ({e}).") + print(f"[DINO] batch work items: {len(items)} total") + return items + + def _commit_dino_results(self, image_name, dino_results, sam_results): + """Commit DINO+SAM results to annotations for a single image. + + If image_name is the currently-displayed image, route through + image_label.annotations so the canvas reflects the change and the + next save_current_annotations() doesn't overwrite the additions. + Otherwise write directly to the project-level cache. + """ + current_image = self.mw.current_slice or self.mw.image_file_name + is_current = image_name == current_image + + if is_current: + target = self.mw.image_label.annotations + else: + if image_name not in self.mw.all_annotations: + self.mw.all_annotations[image_name] = {} + target = self.mw.all_annotations[image_name] + + for r, s in zip(dino_results, sam_results): + if "error" in s: + continue + class_name = r["class_name"] + if class_name not in self.mw.class_mapping: + print(f" Skipping DINO result for unknown class '{class_name}'") + continue + existing = target.get(class_name, []) + number = max((a.get("number", 0) for a in existing), default=0) + 1 + ann = { + "segmentation": s["segmentation"], + "category_id": self.mw.class_mapping[class_name], + "category_name": class_name, + "score": r["score"], + "source": "dino", + "number": number, + } + target.setdefault(class_name, []).append(ann) + + if is_current: + self.mw.save_current_annotations() + self.mw.image_label.update() + + def _store_dino_batch_results(self, image_name, dino_results, sam_results): + """Store results for batch review mode.""" + valid = [] + for r, s in zip(dino_results, sam_results): + if "error" not in s: + valid.append({ + "segmentation": s["segmentation"], + "category_name": r["class_name"], + "score": r["score"], + "source": "dino", + "temp": True, + }) + self.mw.dino_batch_results[image_name] = valid + + def _show_dino_batch_review(self): + """Navigate to first image with batch results for review. + + If the next entry refers to an image/slice that's no longer in + the project (e.g. the source was removed between detection and + review), pop the orphan and try the next entry so the user + doesn't get stuck with un-reviewable results. + """ + if not self.mw.dino_batch_results: + QMessageBox.information(self.mw, "Batch Detection", + "No detections found in any image.") + return + while self.mw.dino_batch_results: + first = next(iter(self.mw.dino_batch_results)) + if self._navigate_to_image_or_slice(first): + return + print(f"[DINO] dropping orphan batch result for {first!r} " + "(no matching image or slice in project)") + self.mw.dino_batch_results.pop(first, None) + QMessageBox.warning( + self.mw, "Batch Detection", + "Detections were produced but none of them map to an image " + "or slice still in the project. Results discarded.", + ) + + def _navigate_to_image_or_slice(self, name): + """Switch the UI to a regular image or a slice by name. + + Returns True if a match was found and the switch was issued. + Used by batch-review navigation, which mixes regular image + names and slice names in ``dino_batch_results``. + """ + for i in range(self.mw.image_list.count()): + item = self.mw.image_list.item(i) + if item and item.text() == name: + self.mw.image_list.setCurrentRow(i) + self.mw.switch_image(item) + return True + for base_name, slices in self.mw.image_slices.items(): + if not any(s_name == name for s_name, _ in slices): + continue + for i in range(self.mw.image_list.count()): + item = self.mw.image_list.item(i) + if not item: + continue + file_name = item.text() + if os.path.splitext(file_name)[0] == base_name: + self.mw.image_list.setCurrentRow(i) + self.mw.switch_image(item) + for s_i in range(self.mw.slice_list.count()): + s_item = self.mw.slice_list.item(s_i) + if s_item and s_item.text() == name: + self.mw.slice_list.setCurrentRow(s_i) + self.mw.switch_slice(s_item) + return True + break + return False + return False + + def _refresh_dino_temp_for_current(self): + """Sync ``image_label.temp_annotations`` to whatever the + currently-displayed image/slice has stored in + ``dino_batch_results``. Called from switch_slice / switch_image. + + Why this exists: ``temp_annotations`` is a single field on + ``ImageLabel``, not a per-image cache. Without this sync, masks + from the previously-viewed image bleed onto every slice the + user navigates to. + """ + new_image = self.mw.current_slice or self.mw.image_file_name + pending = self.mw.dino_batch_results.get(new_image, []) if new_image else [] + if pending: + self.mw.image_label.temp_annotations = list(pending) + self.mw.lbl_dino_status.setText( + f"Review: {new_image} ({len(pending)} detection(s))" + ) + QTimer.singleShot(0, self.mw.image_label.setFocus) + else: + if self.mw.image_label.temp_annotations: + print("[DINO] temp annotations cleared on switch " + f"(no pending batch results for {new_image!r})") + self.mw.image_label.temp_annotations = [] + self.mw.image_label.update() + + def accept_dino_results(self): + """Accept current temp_annotations (called from keyPressEvent).""" + if not self.mw.image_label.temp_annotations: + return + image_name = self.mw.current_slice or self.mw.image_file_name + + for ann in self.mw.image_label.temp_annotations: + class_name = ann["category_name"] + if class_name not in self.mw.class_mapping: + print(f" Skipping DINO result for unknown class '{class_name}'") + continue + new_ann = { + "segmentation": ann["segmentation"], + "category_id": self.mw.class_mapping[class_name], + "category_name": class_name, + "score": ann.get("score", 0.0), + "source": "dino", + } + self.mw.image_label.annotations.setdefault(class_name, []).append(new_ann) + self.mw.add_annotation_to_list(new_ann) + + self.mw.image_label.temp_annotations = [] + self.mw.dino_batch_results.pop(image_name, None) + if self.mw.dino_batch_results: + self._show_dino_batch_review() + self.mw.save_current_annotations() + self.mw.update_slice_list_colors() + self.mw.image_label.update() + self.mw.lbl_dino_status.setText("Results accepted.") + print("DINO results accepted.") + + def reject_dino_results(self): + """Discard current temp_annotations.""" + self.mw.image_label.temp_annotations = [] + image_name = self.mw.current_slice or self.mw.image_file_name + self.mw.dino_batch_results.pop(image_name, None) + if self.mw.dino_batch_results: + self._show_dino_batch_review() + self.mw.image_label.update() + self.mw.lbl_dino_status.setText("Results discarded.") + print("DINO results discarded.") + + # --- Temp-class review workflow (shared with YOLO predictions) --- + + def has_visible_temp_classes(self): + for i in range(self.mw.class_list.count()): + item = self.mw.class_list.item(i) + if ( + item.text().startswith("Temp-") + and item.checkState() == Qt.CheckState.Checked + ): + return True + return False + + def add_temp_classes(self, temp_annotations): + for temp_class_name, annotations in temp_annotations.items(): + if temp_class_name not in self.mw.image_label.class_colors: + color = QColor( + Qt.GlobalColor(len(self.mw.image_label.class_colors) % 16 + 7) + ) + self.mw.image_label.class_colors[temp_class_name] = color + self.mw.image_label.annotations[temp_class_name] = annotations + + self.mw.update_class_list() + + def verify_current_class(self): + if ( + self.mw.current_class is None + or self.mw.current_class not in self.mw.class_mapping + ): + if self.mw.class_list.count() > 0: + self.mw.class_list.setCurrentRow(0) + self.mw.on_class_selected(self.mw.class_list.item(0)) + else: + self.mw.current_class = None + self.mw.disable_annotation_tools() + + def accept_visible_temp_classes(self): + visible_temp_classes = [ + item.text() + for item in self.mw.class_list.findItems( + "Temp-*", Qt.MatchFlag.MatchWildcard + ) + if item.checkState() == Qt.CheckState.Checked + ] + + for temp_class_name in visible_temp_classes: + permanent_class_name = temp_class_name[5:] + if permanent_class_name not in self.mw.image_label.annotations: + self.mw.add_class( + permanent_class_name, + self.mw.image_label.class_colors[temp_class_name], + ) + + current_max = max( + [ + ann.get("number", 0) + for ann in self.mw.image_label.annotations.get( + permanent_class_name, [] + ) + ] + + [0] + ) + + for annotation in self.mw.image_label.annotations[temp_class_name]: + current_max += 1 + annotation["category_name"] = permanent_class_name + annotation["number"] = current_max + self.mw.image_label.annotations.setdefault( + permanent_class_name, [] + ).append(annotation) + + del self.mw.image_label.annotations[temp_class_name] + del self.mw.image_label.class_colors[temp_class_name] + + self.mw.update_class_list() + current_name = self.mw.current_slice or self.mw.image_file_name + self.mw.all_annotations[current_name] = self.mw.image_label.annotations + self.mw.update_annotation_list() + self.mw.image_label.update() + self.mw.save_current_annotations() + + self.select_first_primary_class() + self.verify_current_class() + + QMessageBox.information( + self.mw, + "Annotations Accepted", + "Temporary annotations have been accepted and added to the permanent classes.", + ) + + def select_first_primary_class(self): + for i in range(self.mw.class_list.count()): + item = self.mw.class_list.item(i) + if not item.text().startswith("Temp-"): + self.mw.class_list.setCurrentItem(item) + self.mw.on_class_selected(item) + break + + def reject_visible_temp_classes(self): + visible_temp_classes = [ + item.text() + for item in self.mw.class_list.findItems( + "Temp-*", Qt.MatchFlag.MatchWildcard + ) + if item.checkState() == Qt.CheckState.Checked + ] + + for temp_class_name in visible_temp_classes: + if temp_class_name in self.mw.image_label.annotations: + del self.mw.image_label.annotations[temp_class_name] + if temp_class_name in self.mw.image_label.class_colors: + del self.mw.image_label.class_colors[temp_class_name] + + self.mw.update_class_list() + self.mw.image_label.update() + + def check_temp_annotations(self): + temp_classes = [ + class_name + for class_name in self.mw.image_label.annotations.keys() + if class_name.startswith("Temp-") + ] + if temp_classes: + reply = QMessageBox.question( + self.mw, + "Temporary Annotations", + "There are temporary annotations that will be discarded. Do you want to continue?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, + ) + if reply == QMessageBox.StandardButton.Yes: + for temp_class in temp_classes: + del self.mw.image_label.annotations[temp_class] + del self.mw.image_label.class_colors[temp_class] + self.mw.update_class_list() + self.mw.update_annotation_list() + return True + return False + return True + + def remove_all_temp_annotations(self): + for image_name in list(self.mw.all_annotations.keys()): + for class_name in list(self.mw.all_annotations[image_name].keys()): + if class_name.startswith("Temp-"): + del self.mw.all_annotations[image_name][class_name] + if not self.mw.all_annotations[image_name]: + del self.mw.all_annotations[image_name] + + for class_name in list(self.mw.image_label.class_colors.keys()): + if class_name.startswith("Temp-"): + del self.mw.image_label.class_colors[class_name] + + self.mw.update_class_list() + self.mw.update_annotation_list() + self.mw.image_label.update() diff --git a/src/digitalsreeni_image_annotator/controllers/image_controller.py b/src/digitalsreeni_image_annotator/controllers/image_controller.py new file mode 100644 index 0000000..1d30557 --- /dev/null +++ b/src/digitalsreeni_image_annotator/controllers/image_controller.py @@ -0,0 +1,845 @@ +"""Image / multi-dimensional slice loading and navigation controller. + +Extracted from `ImageAnnotator` to give image I/O its own home. Owns: + +- Loading from disk (PNG/JPG, TIFF, CZI) +- Multi-dimensional image handling: dimension assignment dialog, + per-axis slicing, slice list population +- Image / slice navigation (switch_image, switch_slice, activate_slice) +- Display and per-image lifecycle (remove_image, delete_selected_image, + redefine_dimensions) + +State (`current_image`, `current_slice`, `slices`, `image_paths`, +`image_slices`, `image_dimensions`, `image_shapes`, `all_images`, +`image_file_name`, etc.) still lives on the main window and is read here +via `self.mw`. A future phase may migrate ownership of selected +attributes to the controller — for now this is pure method relocation. + +The `DimensionDialog` widget lives here too — it is only used by +`process_multidimensional_image`. +""" + +import os + +import numpy as np +from czifile import CziFile +from PyQt6.QtCore import Qt, QObject +from PyQt6.QtGui import QColor, QImage, QPixmap +from PyQt6.QtWidgets import ( + QApplication, + QComboBox, + QDialog, + QFileDialog, + QGridLayout, + QLabel, + QListWidget, + QListWidgetItem, + QMessageBox, + QProgressDialog, + QPushButton, + QVBoxLayout, + QWidget, +) +from tifffile import TiffFile + +from ..core import image_utils + + +class DimensionDialog(QDialog): + def __init__(self, shape, file_name, parent=None, default_dimensions=None): + super().__init__(parent) + self.setWindowTitle("Assign Dimensions") + layout = QVBoxLayout(self) + + file_name_label = QLabel(f"File: {file_name}") + file_name_label.setWordWrap(True) + layout.addWidget(file_name_label) + + dim_widget = QWidget() + dim_layout = QGridLayout(dim_widget) + self.combos = [] + self.shape = shape + dimensions = ["T", "Z", "C", "S", "H", "W"] + for i, dim in enumerate(shape): + dim_layout.addWidget(QLabel(f"Dimension {i} (size {dim}):"), i, 0) + combo = QComboBox() + combo.addItems(dimensions) + if default_dimensions and i < len(default_dimensions): + combo.setCurrentText(default_dimensions[i]) + dim_layout.addWidget(combo, i, 1) + self.combos.append(combo) + layout.addWidget(dim_widget) + + self.button = QPushButton("OK") + self.button.clicked.connect(self.accept) + layout.addWidget(self.button) + + self.setMinimumWidth(300) + + def get_dimensions(self): + return [combo.currentText() for combo in self.combos] + + +class ImageController(QObject): + def __init__(self, main_window): + super().__init__(main_window) + self.mw = main_window + + def update_image_list(self): + self.mw.image_list.clear() + for image_info in self.mw.all_images: + self.mw.image_list.addItem(image_info["file_name"]) + + def setup_slice_list(self): + self.mw.slice_list = QListWidget() + self.mw.slice_list.itemClicked.connect(self.switch_slice) + self.mw.image_list_layout.addWidget(QLabel("Slices:")) + self.mw.image_list_layout.addWidget(self.mw.slice_list) + + def open_images(self): + file_names, _ = QFileDialog.getOpenFileNames( + self.mw, + "Open Images", + "", + "Image Files (*.png *.jpg *.bmp *.tif *.tiff *.czi)", + ) + if file_names: + self.mw.image_list.clear() + self.mw.image_paths.clear() + self.mw.all_images.clear() + self.mw.slice_list.clear() + self.mw.slices.clear() + self.mw.current_stack = None + self.mw.current_slice = None + self.add_images_to_list(file_names) + + def add_images_to_list(self, file_names): + first_added_item = None + for file_name in file_names: + base_name = os.path.basename(file_name) + if base_name not in self.mw.image_paths: + image_info = { + "file_name": base_name, + "height": 0, + "width": 0, + "id": len(self.mw.all_images) + 1, + "is_multi_slice": False, + } + + if file_name.lower().endswith((".tif", ".tiff", ".czi")): + self.load_multi_slice_image(file_name) + base_name_without_ext = os.path.splitext(base_name)[0] + if ( + base_name_without_ext in self.mw.image_slices + and self.mw.image_slices[base_name_without_ext] + ): + first_slice_name, first_slice = self.mw.image_slices[ + base_name_without_ext + ][0] + image_info["height"] = first_slice.height() + image_info["width"] = first_slice.width() + image_info["is_multi_slice"] = True + image_info["dimensions"] = self.mw.image_dimensions.get( + base_name_without_ext, [] + ) + image_info["shape"] = self.mw.image_shapes.get( + base_name_without_ext, [] + ) + else: + image = QImage(file_name) + image_info["height"] = image.height() + image_info["width"] = image.width() + + self.mw.all_images.append(image_info) + item = QListWidgetItem(base_name) + self.mw.image_list.addItem(item) + if first_added_item is None: + first_added_item = item + + self.mw.image_paths[base_name] = file_name + + if first_added_item: + self.mw.image_list.setCurrentItem(first_added_item) + self.switch_image(first_added_item) + + if not self.mw.is_loading_project: + self.mw.auto_save() + + def update_all_images(self, new_image_info): + for info in new_image_info: + if not any( + img["file_name"] == info["file_name"] for img in self.mw.all_images + ): + self.mw.all_images.append(info) + + def switch_slice(self, item): + if item is None: + return + # check_unsaved_changes prompts the user and commits/discards + # all dirty tool handlers; returns False on Cancel. + if not self.mw.image_label.check_unsaved_changes(): + return + + self.mw.save_current_annotations() + self.mw.image_label.clear_temp_sam_prediction() + + slice_name = item.text() + for name, qimage in self.mw.slices: + if name == slice_name: + self.mw.current_image = qimage + self.mw.current_slice = name + self.display_image() + self.mw.load_image_annotations() + self.mw.update_annotation_list() + self.mw.clear_highlighted_annotation() + self.mw.image_label.reset_annotation_state() + self.mw.image_label.clear_current_annotation() + self.mw.update_image_info() + break + + self.mw.image_label.update() + self.mw.update_slice_list_colors() + + self.mw.set_zoom(1.0) + + self.mw._refresh_dino_temp_for_current() + + def switch_image(self, item): + if item is None: + return + if not self.mw.image_label.check_unsaved_changes(): + return + + current_item = self.mw.image_list.currentItem() + + if not self.mw.check_temp_annotations(): + self.mw.image_list.setCurrentItem(current_item) + return + + self.mw.save_current_annotations() + self.mw.image_label.clear_temp_sam_prediction() + self.mw.image_label.exit_editing_mode() + + file_name = item.text() + print(f"\nSwitching to image: {file_name}") + + image_info = next( + (img for img in self.mw.all_images if img["file_name"] == file_name), None + ) + + if image_info: + self.mw.image_file_name = file_name + image_path = self.mw.image_paths.get(file_name) + + if not image_path: + image_path = os.path.join( + self.mw.current_project_dir, "images", file_name + ) + + if image_path and os.path.exists(image_path): + if image_info.get("is_multi_slice", False): + base_name = os.path.splitext(file_name)[0] + if base_name in self.mw.image_slices: + self.mw.slices = self.mw.image_slices[base_name] + if self.mw.slices: + self.mw.current_image = self.mw.slices[0][1] + self.mw.current_slice = self.mw.slices[0][0] + self.update_slice_list() + self.activate_slice(self.mw.current_slice) + else: + self.load_multi_slice_image( + image_path, + image_info.get("dimensions"), + image_info.get("shape"), + ) + else: + self.load_regular_image(image_path) + self.display_image() + self.clear_slice_list() + + self.mw.load_image_annotations() + self.mw.update_annotation_list() + self.mw.clear_highlighted_annotation() + self.mw.image_label.update() + self.mw.image_label.reset_annotation_state() + self.mw.image_label.clear_current_annotation() + self.mw.update_image_info() + + self.mw.adjust_zoom_to_fit() + else: + self.mw.current_image = None + self.mw.image_label.clear() + self.mw.load_image_annotations() + self.mw.update_annotation_list() + self.mw.update_image_info() + + self.mw.image_list.setCurrentItem(item) + self.mw.image_label.update() + self.mw.update_slice_list_colors() + else: + self.mw.current_image = None + self.mw.current_slice = None + self.mw.image_label.clear() + self.mw.update_image_info() + self.clear_slice_list() + + self.mw._refresh_dino_temp_for_current() + + def activate_current_slice(self): + if self.mw.current_slice: + items = self.mw.slice_list.findItems( + self.mw.current_slice, Qt.MatchFlag.MatchExactly + ) + if items: + self.mw.slice_list.setCurrentItem(items[0]) + + self.mw.load_image_annotations() + self.mw.image_label.update() + self.mw.update_annotation_list() + + def load_image(self, image_path): + extension = os.path.splitext(image_path)[1].lower() + if extension in [".tif", ".tiff"]: + self.load_tiff(image_path) + elif extension == ".czi": + self.load_czi(image_path) + else: + self.load_regular_image(image_path) + + def load_tiff( + self, image_path, dimensions=None, shape=None, force_dimension_dialog=False + ): + print(f"Loading TIFF file: {image_path}") + axes_hint = None + with TiffFile(image_path) as tif: + print(f"TIFF tags: {tif.pages[0].tags}") + + try: + metadata = tif.pages[0].tags["ImageDescription"].value + print(f"TIFF metadata: {metadata}") + except KeyError: + print("No ImageDescription metadata found") + + try: + series_axes = tif.series[0].axes if tif.series else None + if series_axes: + axis_map = { + "T": "T", "Z": "Z", "C": "C", "S": "S", + "Y": "H", "X": "W", + } + mapped = [axis_map.get(a) for a in series_axes] + if all(a is not None for a in mapped): + axes_hint = mapped + print(f"TIFF series axes: {series_axes} → dimension hint: {axes_hint}") + else: + unknown = [a for a in series_axes if axis_map.get(a) is None] + print(f"TIFF series axes had unknown labels {unknown}, no hint applied") + except Exception as e: + print(f"Could not read TIFF series axes: {e}") + + if len(tif.pages) > 1: + print(f"Multi-page TIFF detected. Number of pages: {len(tif.pages)}") + image_array = tif.asarray() + else: + print("Single-page TIFF detected.") + image_array = tif.pages[0].asarray() + + print(f"Image array shape: {image_array.shape}") + print(f"Image array dtype: {image_array.dtype}") + print(f"Image min: {image_array.min()}, max: {image_array.max()}") + + if dimensions and shape and not force_dimension_dialog: + print(f"Using stored dimensions: {dimensions}") + print(f"Using stored shape: {shape}") + image_array = image_array.reshape(shape) + else: + print("Processing as new image or forcing dimension dialog.") + dimensions = None + + self.process_multidimensional_image( + image_array, image_path, dimensions, force_dimension_dialog, + axes_hint=axes_hint, + ) + + def load_czi( + self, image_path, dimensions=None, shape=None, force_dimension_dialog=False + ): + print(f"Loading CZI file: {image_path}") + with CziFile(image_path) as czi: + image_array = czi.asarray() + print(f"CZI array shape: {image_array.shape}") + print(f"CZI array dtype: {image_array.dtype}") + print(f"CZI array min: {image_array.min()}, max: {image_array.max()}") + + if dimensions and shape and not force_dimension_dialog: + print(f"Using stored dimensions: {dimensions}") + print(f"Using stored shape: {shape}") + image_array = image_array.reshape(shape) + else: + print("Processing as new image or forcing dimension dialog.") + dimensions = None + + self.process_multidimensional_image( + image_array, image_path, dimensions, force_dimension_dialog + ) + + def load_regular_image(self, image_path): + self.mw.current_image = QImage(image_path) + self.mw.slices = [] + self.mw.slice_list.clear() + self.mw.current_slice = None + + def load_multi_slice_image(self, image_path, dimensions=None, shape=None): + file_name = os.path.basename(image_path) + base_name = os.path.splitext(file_name)[0] + print(f"Loading multi-slice image: {image_path}") + print(f"Base name: {base_name}") + + if dimensions and shape: + print(f"Using stored dimensions: {dimensions}") + print(f"Using stored shape: {shape}") + self.mw.image_dimensions[base_name] = dimensions + self.mw.image_shapes[base_name] = shape + if image_path.lower().endswith((".tif", ".tiff")): + self.load_tiff(image_path, dimensions, shape) + elif image_path.lower().endswith(".czi"): + self.load_czi(image_path, dimensions, shape) + else: + print("No stored dimensions or shape, loading as new image") + if image_path.lower().endswith((".tif", ".tiff")): + self.load_tiff(image_path) + elif image_path.lower().endswith(".czi"): + self.load_czi(image_path) + + print(f"Loaded multi-slice image: {file_name}") + print(f"Dimensions: {self.mw.image_dimensions.get(base_name, 'Not found')}") + print(f"Shape: {self.mw.image_shapes.get(base_name, 'Not found')}") + print(f"Number of slices: {len(self.mw.slices)}") + + if self.mw.slices: + self.mw.current_image = self.mw.slices[0][1] + self.mw.current_slice = self.mw.slices[0][0] + + self.update_slice_list() + self.mw.slice_list.setCurrentRow(0) + self.activate_slice(self.mw.current_slice) + print(f"Activated first slice: {self.mw.current_slice}") + else: + print("No slices were loaded") + self.mw.current_image = None + self.mw.current_slice = None + + self.update_slice_list() + self.mw.image_label.update() + + def process_multidimensional_image( + self, image_array, image_path, dimensions=None, + force_dimension_dialog=False, axes_hint=None, + ): + file_name = os.path.basename(image_path) + base_name = os.path.splitext(file_name)[0] + print(f"Processing file: {file_name}") + print(f"Image array shape: {image_array.shape}") + print(f"Image array dtype: {image_array.dtype}") + + if dimensions is None or force_dimension_dialog: + if image_array.ndim > 2: + # ndim≥5 had a `[-ndim:]` slice bug that produced 2560 wrong + # slices on a 5D TZCYX file — see arc42. + if axes_hint and len(axes_hint) == image_array.ndim: + default_dimensions = list(axes_hint) + print(f"Applying axes hint as default dims: {default_dimensions}") + else: + if axes_hint and len(axes_hint) != image_array.ndim: + print( + f"Ignoring axes hint (length {len(axes_hint)} " + f"vs ndim {image_array.ndim})" + ) + ndim_defaults = { + 3: ["Z", "H", "W"], + 4: ["T", "Z", "H", "W"], + 5: ["T", "Z", "C", "H", "W"], + 6: ["T", "Z", "C", "S", "H", "W"], + } + default_dimensions = ndim_defaults.get( + image_array.ndim, + ["T"] * max(0, image_array.ndim - 2) + ["H", "W"], + ) + + progress = QProgressDialog( + "Assigning dimensions...", "Cancel", 0, 100, self.mw + ) + progress.setWindowModality(Qt.WindowModality.WindowModal) + progress.setMinimumDuration(0) + progress.setValue(10) + QApplication.processEvents() + + while True: + dialog = DimensionDialog( + image_array.shape, file_name, self.mw, default_dimensions + ) + progress.setValue(50) + QApplication.processEvents() + if dialog.exec(): + dimensions = dialog.get_dimensions() + print(f"Assigned dimensions: {dimensions}") + if "H" in dimensions and "W" in dimensions: + self.mw.image_dimensions[base_name] = dimensions + break + else: + QMessageBox.warning( + self.mw, + "Invalid Dimensions", + "You must assign both H and W dimensions.", + ) + else: + progress.close() + return + progress.setValue(100) + progress.close() + else: + dimensions = ["H", "W"] + self.mw.image_dimensions[base_name] = dimensions + + self.mw.image_shapes[base_name] = image_array.shape + print(f"Final assigned dimensions: {self.mw.image_dimensions[base_name]}") + print(f"Image shape: {self.mw.image_shapes[base_name]}") + + if self.mw.image_dimensions[base_name]: + self.create_slices( + image_array, self.mw.image_dimensions[base_name], image_path + ) + else: + rgb_image = image_utils.convert_to_8bit_rgb(image_array) + self.mw.current_image = image_utils.array_to_qimage(rgb_image) + self.mw.slices = [] + self.mw.slice_list.clear() + + if self.mw.slices: + self.mw.current_image = self.mw.slices[0][1] + self.mw.current_slice = self.mw.slices[0][0] + self.mw.slice_list.setCurrentRow(0) + self.mw.load_image_annotations() + self.mw.image_label.update() + + self.mw.update_image_info() + + self.update_slice_list() + self.mw.update_annotation_list() + self.mw.image_label.update() + + def create_slices(self, image_array, dimensions, image_path): + base_name = os.path.splitext(os.path.basename(image_path))[0] + slices = [] + self.mw.slice_list.clear() + + print(f"Creating slices for {base_name}") + print(f"Dimensions: {dimensions}") + print(f"Image array shape: {image_array.shape}") + + progress = QProgressDialog("Loading slices...", "Cancel", 0, 100, self.mw) + progress.setWindowModality(Qt.WindowModality.WindowModal) + progress.setMinimumDuration(0) + + if image_array.ndim == 2: + progress.setValue(50) + QApplication.processEvents() + normalized_array = image_utils.normalize_array(image_array) + qimage = image_utils.array_to_qimage(normalized_array) + slice_name = f"{base_name}" + slices.append((slice_name, qimage)) + self.add_slice_to_list(slice_name) + else: + slice_indices = [ + i for i, dim in enumerate(dimensions) if dim not in ["H", "W"] + ] + + total_slices = np.prod([image_array.shape[i] for i in slice_indices]) + for idx, _ in enumerate( + np.ndindex(tuple(image_array.shape[i] for i in slice_indices)) + ): + if progress.wasCanceled(): + break + + full_idx = [slice(None)] * len(dimensions) + for i, val in zip(slice_indices, _): + full_idx[i] = val + + slice_array = image_array[tuple(full_idx)] + rgb_slice = image_utils.convert_to_8bit_rgb(slice_array) + qimage = image_utils.array_to_qimage(rgb_slice) + + slice_name = f"{base_name}_{'_'.join([f'{dimensions[i]}{val+1}' for i, val in zip(slice_indices, _)])}" + slices.append((slice_name, qimage)) + + self.add_slice_to_list(slice_name) + + progress_value = int((idx + 1) / total_slices * 100) + progress.setValue(progress_value) + QApplication.processEvents() + + progress.setValue(100) + + self.mw.image_slices[base_name] = slices + self.mw.slices = slices + + if slices: + self.mw.current_image = slices[0][1] + self.mw.current_slice = slices[0][0] + self.mw.slice_list.setCurrentRow(0) + + self.activate_slice(self.mw.current_slice) + + slice_info = f"Total slices: {len(slices)}" + for dim, size in zip(dimensions, image_array.shape): + if dim not in ["H", "W"]: + slice_info += f", {dim}: {size}" + self.mw.update_image_info(additional_info=slice_info) + else: + print("No slices were created") + + print(f"Created {len(slices)} slices for {base_name}") + return slices + + def add_slice_to_list(self, slice_name): + item = QListWidgetItem(slice_name) + + if self.mw.dark_mode: + item.setBackground(QColor(40, 40, 40)) + if slice_name in self.mw.all_annotations: + item.setForeground(QColor(235, 235, 235)) + item.setBackground(QColor(58, 95, 140)) + else: + item.setForeground(QColor(200, 200, 200)) + else: + item.setBackground(QColor(240, 240, 240)) + if slice_name in self.mw.all_annotations: + item.setForeground(QColor(255, 255, 255)) + item.setBackground(QColor(70, 130, 180)) + else: + item.setForeground(QColor(0, 0, 0)) + + self.mw.slice_list.addItem(item) + + def activate_slice(self, slice_name): + self.mw.current_slice = slice_name + self.mw.image_file_name = slice_name + self.mw.load_image_annotations() + self.mw.update_annotation_list() + + for name, qimage in self.mw.slices: + if name == slice_name: + self.mw.current_image = qimage + self.display_image() + break + + self.mw.image_label.update() + + items = self.mw.slice_list.findItems(slice_name, Qt.MatchFlag.MatchExactly) + if items: + self.mw.slice_list.setCurrentItem(items[0]) + + def update_slice_list(self): + self.mw.slice_list.clear() + for slice_name, _ in self.mw.slices: + item = QListWidgetItem(slice_name) + if slice_name in self.mw.all_annotations: + item.setForeground(QColor(Qt.GlobalColor.green)) + else: + item.setForeground( + QColor(Qt.GlobalColor.black) + if not self.mw.dark_mode + else QColor(Qt.GlobalColor.white) + ) + self.mw.slice_list.addItem(item) + + if self.mw.current_slice: + items = self.mw.slice_list.findItems( + self.mw.current_slice, Qt.MatchFlag.MatchExactly + ) + if items: + self.mw.slice_list.setCurrentItem(items[0]) + + def clear_slice_list(self): + self.mw.slice_list.clear() + self.mw.slices = [] + self.mw.current_slice = None + + def is_multi_dimensional(self, file_name): + return file_name.lower().endswith((".tif", ".tiff", ".czi")) + + def redefine_dimensions(self, file_name): + file_path = self.mw.image_paths.get(file_name) + if not file_path or not file_path.lower().endswith((".tif", ".tiff", ".czi")): + return + + reply = QMessageBox.warning( + self.mw, + "Redefine Dimensions", + "Redefining dimensions will cause all associated annotations to be lost. " + "Do you want to continue?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, + ) + + if reply == QMessageBox.StandardButton.Yes: + base_name = os.path.splitext(file_name)[0] + + print(f"Removing annotations for image: {base_name}") + + keys_to_remove = [ + key + for key in self.mw.all_annotations.keys() + if key == base_name + or ( + key.startswith(f"{base_name}_") + and not key.startswith(f"{base_name}_8bit") + ) + ] + + print(f"Keys to remove: {keys_to_remove}") + + for key in keys_to_remove: + del self.mw.all_annotations[key] + + if base_name in self.mw.image_slices: + del self.mw.image_slices[base_name] + + if self.mw.image_file_name == file_name: + self.mw.current_image = None + self.mw.image_label.clear() + + if file_path.lower().endswith((".tif", ".tiff")): + self.load_tiff(file_path, force_dimension_dialog=True) + elif file_path.lower().endswith(".czi"): + self.load_czi(file_path, force_dimension_dialog=True) + + self.update_slice_list() + self.mw.update_annotation_list() + self.mw.image_label.update() + + QMessageBox.information( + self.mw, + "Dimensions Redefined", + "The dimensions have been redefined and the image reloaded. " + "All previous annotations for this image have been removed.", + ) + + def remove_image(self): + current_item = self.mw.image_list.currentItem() + if current_item: + file_name = current_item.text() + + self.mw.image_list.takeItem(self.mw.image_list.row(current_item)) + self.mw.image_paths.pop(file_name, None) + self.mw.all_images = [ + img for img in self.mw.all_images if img["file_name"] != file_name + ] + + self.mw.all_annotations.pop(file_name, None) + + base_name = os.path.splitext(file_name)[0] + if base_name in self.mw.image_slices: + for slice_name, _ in self.mw.image_slices[base_name]: + self.mw.all_annotations.pop(slice_name, None) + del self.mw.image_slices[base_name] + + self.mw.slice_list.clear() + + if self.mw.image_file_name == file_name: + self.mw.current_image = None + self.mw.image_file_name = "" + self.mw.current_slice = None + self.mw.image_label.clear() + self.mw.annotation_list.clear() + + if self.mw.image_list.count() > 0: + next_item = self.mw.image_list.item(0) + self.mw.image_list.setCurrentItem(next_item) + self.switch_image(next_item) + else: + self.mw.current_image = None + self.mw.image_file_name = "" + self.mw.current_slice = None + self.mw.image_label.clear() + self.mw.annotation_list.clear() + self.mw.slice_list.clear() + + self.mw.update_ui() + self.mw.auto_save() + + def delete_selected_image(self): + current_item = self.mw.image_list.currentItem() + if current_item: + file_name = current_item.text() + reply = QMessageBox.question( + self.mw, + "Delete Image", + f"Are you sure you want to delete the image '{file_name}'?\n\n" + "This will remove the image and all its associated annotations.", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, + ) + + if reply == QMessageBox.StandardButton.Yes: + self.mw.image_list.takeItem(self.mw.image_list.row(current_item)) + self.mw.image_paths.pop(file_name, None) + self.mw.all_images = [ + img for img in self.mw.all_images if img["file_name"] != file_name + ] + + self.mw.all_annotations.pop(file_name, None) + + base_name = os.path.splitext(file_name)[0] + if base_name in self.mw.image_slices: + for slice_name, _ in self.mw.image_slices[base_name]: + self.mw.all_annotations.pop(slice_name, None) + del self.mw.image_slices[base_name] + + self.mw.slice_list.clear() + + if self.mw.image_file_name == file_name: + self.mw.current_image = None + self.mw.image_file_name = "" + self.mw.current_slice = None + self.mw.image_label.clear() + self.mw.annotation_list.clear() + + if self.mw.image_list.count() > 0: + next_item = self.mw.image_list.item(0) + self.mw.image_list.setCurrentItem(next_item) + self.switch_image(next_item) + else: + self.mw.current_image = None + self.mw.image_file_name = "" + self.mw.current_slice = None + self.mw.image_label.clear() + self.mw.annotation_list.clear() + self.mw.slice_list.clear() + + self.mw.update_ui() + + QMessageBox.information( + self.mw, + "Image Deleted", + f"The image '{file_name}' has been deleted.", + ) + + def display_image(self): + if self.mw.current_image: + if isinstance(self.mw.current_image, QImage): + pixmap = QPixmap.fromImage(self.mw.current_image) + elif isinstance(self.mw.current_image, QPixmap): + pixmap = self.mw.current_image + else: + print(f"Unexpected image type: {type(self.mw.current_image)}") + return + + if not pixmap.isNull(): + self.mw.image_label.setPixmap(pixmap) + self.mw.image_label.adjustSize() + else: + print("Error: Null pixmap") + else: + self.mw.image_label.clear() + print("No current image to display") diff --git a/src/digitalsreeni_image_annotator/controllers/io_controller.py b/src/digitalsreeni_image_annotator/controllers/io_controller.py new file mode 100644 index 0000000..f3de559 --- /dev/null +++ b/src/digitalsreeni_image_annotator/controllers/io_controller.py @@ -0,0 +1,338 @@ +"""Import / export / save-slices orchestration extracted from `ImageAnnotator`. + +The actual format readers and writers live in `io.import_formats` and +`io.export_formats` and are pure functions parameterised on annotation +state. The wrappers here are the UI glue: file dialogs, state mutation +on the main window, status message boxes, auto-save trigger. + +Functions take the main window as the first argument so call sites +inside `annotator_window.py` delegate trivially. +""" + +import os + +from PyQt6.QtCore import Qt +from PyQt6.QtGui import QColor +from PyQt6.QtWidgets import QFileDialog, QMessageBox + +from ..io.export_formats import ( + export_coco_json, + export_labeled_images, + export_pascal_voc_bbox, + export_pascal_voc_both, + export_semantic_labels, + export_yolo_v4, + export_yolo_v5plus, +) +from ..io.import_formats import import_coco_json, process_import_format + + +def import_annotations(mw): + if not mw.image_label.check_unsaved_changes(): + return + print("Starting import_annotations") + import_format = mw.import_format_selector.currentText() + print(f"Import format: {import_format}") + + if import_format == "COCO JSON": + file_name, _ = QFileDialog.getOpenFileName( + mw, "Import COCO JSON Annotations", "", "JSON Files (*.json)" + ) + if not file_name: + print("No file selected, returning") + return + + print(f"Selected file: {file_name}") + json_dir = os.path.dirname(file_name) + images_dir = os.path.join(json_dir, "images") + imported_annotations, image_info = import_coco_json(file_name, mw.class_mapping) + + elif import_format in ["YOLO (v4 and earlier)", "YOLO (v5+)"]: + yaml_file, _ = QFileDialog.getOpenFileName( + mw, "Select YOLO Dataset YAML", "", "YAML Files (*.yaml *.yml)" + ) + if not yaml_file: + print("No YAML file selected, returning") + return + + print(f"Selected YAML file: {yaml_file}") + try: + imported_annotations, image_info = process_import_format( + import_format, yaml_file, mw.class_mapping + ) + yaml_dir = os.path.dirname(yaml_file) + if import_format == "YOLO (v4 and earlier)": + images_dir = os.path.join(yaml_dir, "train", "images") + else: + images_dir = os.path.join(yaml_dir, "images", "train") + except ValueError as e: + QMessageBox.warning(mw, "Import Error", str(e)) + return + + else: + QMessageBox.warning( + mw, + "Unsupported Format", + f"The selected format '{import_format}' is not implemented for import.", + ) + return + + print( + f"JSON/YOLO directory: {json_dir if import_format == 'COCO JSON' else os.path.dirname(yaml_file)}" + ) + print(f"Images directory: {images_dir}") + print(f"Imported annotations count: {len(imported_annotations)}") + print(f"Image info count: {len(image_info)}") + + images_loaded = 0 + images_not_found = [] + + for info in image_info.values(): + print(f"Processing image: {info['file_name']}") + image_path = os.path.join(images_dir, info["file_name"]) + + if os.path.exists(image_path): + print(f"Image found at: {image_path}") + mw.image_paths[info["file_name"]] = image_path + mw.all_images.append( + { + "file_name": info["file_name"], + "height": info["height"], + "width": info["width"], + "id": info["id"], + "is_multi_slice": False, + } + ) + images_loaded += 1 + else: + print(f"Image not found at: {image_path}") + images_not_found.append(info["file_name"]) + + print(f"Images loaded: {images_loaded}") + print(f"Images not found: {len(images_not_found)}") + + if images_not_found: + message = f"The following {len(images_not_found)} images were not found in the 'images' directory:\n\n" + message += "\n".join(images_not_found[:10]) + if len(images_not_found) > 10: + message += f"\n... and {len(images_not_found) - 10} more." + message += "\n\nDo you want to proceed and ignore annotations for these missing images?" + reply = QMessageBox.question( + mw, + "Missing Images", + message, + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, + ) + + if reply == QMessageBox.StandardButton.No: + print("Import cancelled due to missing images") + QMessageBox.information( + mw, + "Import Cancelled", + "Import cancelled. Please ensure all images are in the 'images' directory and try again.", + ) + return + + for image_name, annotations in imported_annotations.items(): + if image_name not in mw.image_paths: + continue + mw.all_annotations[image_name] = {} + for category_name, category_annotations in annotations.items(): + mw.all_annotations[image_name][category_name] = [] + for i, ann in enumerate(category_annotations, start=1): + new_ann = { + "segmentation": ann.get("segmentation"), + "bbox": ann.get("bbox"), + "category_id": ann["category_id"], + "category_name": category_name, + "number": i, + "type": ann.get("type", "polygon"), + } + mw.all_annotations[image_name][category_name].append(new_ann) + + for annotations in mw.all_annotations.values(): + for category_name in annotations.keys(): + if category_name not in mw.class_mapping: + new_id = len(mw.class_mapping) + 1 + mw.class_mapping[category_name] = new_id + mw.image_label.class_colors[category_name] = QColor( + Qt.GlobalColor(new_id % 16 + 7) + ) + + print("Updating UI") + mw.update_class_list() + mw.update_image_list() + mw.update_annotation_list() + + if mw.image_list.count() > 0: + mw.image_list.setCurrentRow(0) + mw.switch_image(mw.image_list.item(0)) + + if mw.class_list.count() > 0: + mw.class_list.setCurrentRow(0) + mw.on_class_selected() + + mw.image_label.update() + + message = ( + f"Annotations have been imported successfully from " + f"{file_name if import_format == 'COCO JSON' else yaml_file}.\n" + ) + message += f"{images_loaded} images were loaded from the 'images' directory.\n" + if images_not_found: + message += f"Annotations for {len(images_not_found)} missing images were ignored." + + print("Import complete, showing message") + QMessageBox.information(mw, "Import Complete", message) + mw.auto_save() + + +def export_annotations(mw): + if not mw.image_label.check_unsaved_changes(): + return + export_format = mw.export_format_selector.currentText() + + supported_formats = [ + "COCO JSON", + "YOLO (v4 and earlier)", + "YOLO (v5+)", + "Labeled Images", + "Semantic Labels", + "Pascal VOC (BBox)", + "Pascal VOC (BBox + Segmentation)", + ] + + if export_format not in supported_formats: + QMessageBox.warning( + mw, + "Unsupported Format", + f"The selected format '{export_format}' is not implemented.", + ) + return + + if export_format == "COCO JSON": + file_name, _ = QFileDialog.getSaveFileName( + mw, "Export COCO JSON Annotations", "", "JSON Files (*.json)" + ) + else: + file_name = QFileDialog.getExistingDirectory( + mw, f"Select Output Directory for {export_format} Export" + ) + + if not file_name: + return + + mw.save_current_annotations() + + if export_format == "COCO JSON": + output_dir = os.path.dirname(file_name) + json_filename = os.path.basename(file_name) + json_file, images_dir = export_coco_json( + mw.all_annotations, + mw.class_mapping, + mw.image_paths, + mw.slices, + mw.image_slices, + output_dir, + json_filename, + ) + message = "Annotations have been exported successfully in COCO JSON format.\n" + message += f"JSON file: {json_file}\nImages directory: {images_dir}" + + elif export_format == "YOLO (v4 and earlier)": + labels_dir, yaml_path = export_yolo_v4( + mw.all_annotations, + mw.class_mapping, + mw.image_paths, + mw.slices, + mw.image_slices, + file_name, + ) + message = "Annotations have been exported successfully in YOLO (v4 and earlier) format.\n" + message += f"Labels: {labels_dir}\nYAML: {yaml_path}" + + elif export_format == "YOLO (v5+)": + output_dir, yaml_path = export_yolo_v5plus( + mw.all_annotations, + mw.class_mapping, + mw.image_paths, + mw.slices, + mw.image_slices, + file_name, + ) + message = "Annotations have been exported successfully in YOLO (v5+) format.\n" + message += f"Output directory: {output_dir}\nYAML: {yaml_path}" + + elif export_format == "Labeled Images": + labeled_images_dir = export_labeled_images( + mw.all_annotations, + mw.class_mapping, + mw.image_paths, + mw.slices, + mw.image_slices, + file_name, + ) + message = ( + f"Labeled images have been exported successfully.\n" + f"Labeled Images: {labeled_images_dir}\n" + ) + message += ( + f"A class summary has been saved in: " + f"{os.path.join(labeled_images_dir, 'class_summary.txt')}" + ) + + elif export_format == "Semantic Labels": + semantic_labels_dir = export_semantic_labels( + mw.all_annotations, + mw.class_mapping, + mw.image_paths, + mw.slices, + mw.image_slices, + file_name, + ) + message = ( + f"Semantic labels have been exported successfully.\n" + f"Semantic Labels: {semantic_labels_dir}\n" + ) + message += ( + f"A class-pixel mapping has been saved in: " + f"{os.path.join(semantic_labels_dir, 'class_pixel_mapping.txt')}" + ) + + elif export_format == "Pascal VOC (BBox)": + voc_dir = export_pascal_voc_bbox( + mw.all_annotations, + mw.class_mapping, + mw.image_paths, + mw.slices, + mw.image_slices, + file_name, + ) + message = "Annotations have been exported successfully in Pascal VOC format (BBox only).\n" + message += f"Pascal VOC Annotations: {voc_dir}" + + elif export_format == "Pascal VOC (BBox + Segmentation)": + voc_dir = export_pascal_voc_both( + mw.all_annotations, + mw.class_mapping, + mw.image_paths, + mw.slices, + mw.image_slices, + file_name, + ) + message = "Annotations have been exported successfully in Pascal VOC format (BBox + Segmentation).\n" + message += f"Pascal VOC Annotations: {voc_dir}" + + QMessageBox.information(mw, "Export Complete", message) + + +def save_slices(mw, directory): + slices_saved = False + for image_file, image_slices in mw.image_slices.items(): + for slice_name, qimage in image_slices: + if slice_name in mw.all_annotations and mw.all_annotations[slice_name]: + file_path = os.path.join(directory, f"{slice_name}.png") + qimage.save(file_path, "PNG") + slices_saved = True + return slices_saved diff --git a/src/digitalsreeni_image_annotator/controllers/project_controller.py b/src/digitalsreeni_image_annotator/controllers/project_controller.py new file mode 100644 index 0000000..70e92c3 --- /dev/null +++ b/src/digitalsreeni_image_annotator/controllers/project_controller.py @@ -0,0 +1,548 @@ +"""Project lifecycle controller. + +Extracted from `ImageAnnotator` to give project I/O a single home: +creating, opening, saving, auto-saving, and handling missing images for +`.iap` project files. + +State (`is_loading_project`, `backup_project_path`, `current_project_file`, +`current_project_dir`, `project_notes`, etc.) currently still lives on +the main window and is read here via `self.mw`. A future phase may +migrate ownership of those attributes to the controller — for now this +extraction is purely method relocation. +""" + +import json +import os +import shutil +from datetime import datetime + +from PyQt6.QtCore import QObject +from PyQt6.QtGui import QColor +from PyQt6.QtWidgets import QFileDialog, QInputDialog, QMessageBox + +from ..core import image_utils + + +class ProjectController(QObject): + def __init__(self, main_window): + super().__init__(main_window) + self.mw = main_window + + def update_window_title(self): + base_title = "Image Annotator" + if hasattr(self.mw, "current_project_file"): + project_name = os.path.basename(self.mw.current_project_file) + project_name = os.path.splitext(project_name)[0] + self.mw.setWindowTitle(f"{base_title} - {project_name}") + else: + self.mw.setWindowTitle(base_title) + + def new_project(self): + self.mw.remove_all_temp_annotations() + project_file, _ = QFileDialog.getSaveFileName( + self.mw, "Create New Project", "", "Image Annotator Project (*.iap)" + ) + if project_file: + if not project_file.lower().endswith(".iap"): + project_file += ".iap" + + self.mw.current_project_file = project_file + self.mw.current_project_dir = os.path.dirname(project_file) + + images_dir = os.path.join(self.mw.current_project_dir, "images") + os.makedirs(images_dir, exist_ok=True) + + self.mw.clear_all(new_project=True, show_messages=False) + + notes, ok = QInputDialog.getMultiLineText( + self.mw, "Project Notes", "Enter initial project notes:" + ) + self.mw.project_notes = notes if ok else "" + self.mw.project_creation_date = datetime.now().isoformat() + + self.save_project(show_message=False) + + self.mw.show_info( + "New Project", f"New project created at {self.mw.current_project_file}" + ) + self.mw.initialize_yolo_trainer() + self.update_window_title() + + def open_project(self): + print("open_project method called") + self.mw.remove_all_temp_annotations() + project_file, _ = QFileDialog.getOpenFileName( + self.mw, "Open Project", "", "Image Annotator Project (*.iap)" + ) + print(f"Selected project file: {project_file}") + if project_file: + try: + self.backup_project_before_open(project_file) + self.open_specific_project(project_file) + except Exception as e: + self.restore_project_from_backup() + QMessageBox.critical( + self.mw, + "Error", + f"An error occurred while opening the project: {str(e)}\n" + f"The project file has been restored from backup.", + ) + else: + print("No project file selected") + + def backup_project_before_open(self, project_file): + """Create a backup of the project file before opening it.""" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_dir = os.path.join(os.path.dirname(project_file), ".project_backups") + os.makedirs(backup_dir, exist_ok=True) + + self.mw.backup_project_path = os.path.join( + backup_dir, f"{os.path.basename(project_file)}.{timestamp}.backup" + ) + shutil.copy2(project_file, self.mw.backup_project_path) + + def restore_project_from_backup(self): + """Restore the project file from its backup if available.""" + if self.mw.backup_project_path and os.path.exists(self.mw.backup_project_path): + try: + shutil.copy2(self.mw.backup_project_path, self.mw.current_project_file) + print(f"Project restored from backup: {self.mw.backup_project_path}") + except Exception as e: + print(f"Failed to restore from backup: {str(e)}") + + def open_specific_project(self, project_file): + print(f"Opening specific project: {project_file}") + if os.path.exists(project_file): + try: + self.mw.is_loading_project = True + + with open(project_file, "r") as f: + project_data = json.load(f) + + self.mw.clear_all(show_messages=False) + self.mw.current_project_file = project_file + self.mw.current_project_dir = os.path.dirname(project_file) + + self.mw.project_notes = project_data.get("notes", "") + self.mw.project_creation_date = project_data.get("creation_date", "") + self.mw.last_modified = project_data.get("last_modified", "") + + if self.mw.project_creation_date: + self.mw.project_creation_date = datetime.fromisoformat( + self.mw.project_creation_date + ).strftime("%Y-%m-%d %H:%M:%S") + if self.mw.last_modified: + self.mw.last_modified = datetime.fromisoformat( + self.mw.last_modified + ).strftime("%Y-%m-%d %H:%M:%S") + + self.load_project_data(project_data) + + self.mw.is_loading_project = False + if self.mw.dino_class_table.rowCount() > 0: + self.mw.dino_class_table.selectRow(0) + self.save_project(show_message=False) + + self.mw.initialize_yolo_trainer() + self.update_window_title() + + print(f"Project opened successfully: {project_file}") + QMessageBox.information( + self.mw, + "Project Opened", + f"Project opened successfully: {os.path.basename(project_file)}", + ) + + except Exception as e: + self.mw.is_loading_project = False + raise e + else: + print(f"Project file not found: {project_file}") + QMessageBox.critical( + self.mw, "Error", f"Project file not found: {project_file}" + ) + + def load_project_data(self, project_data): + """Load project data without triggering auto-saves.""" + self.mw.class_mapping.clear() + self.mw.image_label.class_colors.clear() + for class_info in project_data.get("classes", []): + self.mw.add_class(class_info["name"], QColor(class_info["color"])) + + self.mw.all_images = project_data.get("images", []) + self.mw.image_paths = project_data.get("image_paths", {}) + + self.mw.all_annotations.clear() + for image_info in project_data["images"]: + if image_info.get("is_multi_slice", False): + for slice_info in image_info.get("slices", []): + self.mw.all_annotations[slice_info["name"]] = slice_info["annotations"] + else: + self.mw.all_annotations[image_info["file_name"]] = image_info.get( + "annotations", {} + ) + + missing_images = [] + for image_info in project_data["images"]: + image_path = os.path.join( + self.mw.current_project_dir, "images", image_info["file_name"] + ) + + if not os.path.exists(image_path): + missing_images.append(image_info["file_name"]) + continue + + self.mw.image_paths[image_info["file_name"]] = image_path + + if image_info.get("is_multi_slice", False): + dimensions = image_info.get("dimensions", []) + shape = image_info.get("shape", []) + self.mw.load_multi_slice_image(image_path, dimensions, shape) + else: + self.mw.add_images_to_list([image_path]) + + dino_cfg = project_data.get("dino_config", {}) + valid_classes = set(self.mw.class_mapping.keys()) + + phrases = dino_cfg.get("phrases", {}) + if phrases: + kept = {k: v for k, v in phrases.items() if k in valid_classes} + for orphan in phrases.keys() - kept.keys(): + print(f" Skipped saved DINO phrases for unknown class " + f"'{orphan}' — class is not in the current project.") + self.mw.dino_phrase_panel.set_phrases(kept) + + for cls_name, thr in dino_cfg.get("thresholds", {}).items(): + ok = self.mw.dino_class_table.set_thresholds( + cls_name, + thr.get("box", 0.25), + thr.get("txt", 0.25), + thr.get("nms", 0.50), + ) + if not ok: + print(f" Skipped saved DINO thresholds for unknown class " + f"'{cls_name}' — class is not in the current project.") + + self.mw.update_ui() + + if missing_images: + self.handle_missing_images(missing_images) + + if self.mw.image_list.count() > 0: + self.mw.image_list.setCurrentRow(0) + first_item = self.mw.image_list.item(0) + if first_item: + self.mw.switch_image(first_item) + + if self.mw.class_list.count() > 0: + self.mw.class_list.setCurrentRow(0) + self.mw.on_class_selected() + + def handle_missing_images(self, missing_images): + message = "The following images have annotations but were not found in the project directory:\n\n" + message += "\n".join(missing_images[:10]) + if len(missing_images) > 10: + message += f"\n... and {len(missing_images) - 10} more." + message += "\n\nWould you like to locate these images now?" + + reply = QMessageBox.question( + self.mw, + "Missing Images", + message, + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.Yes, + ) + + if reply == QMessageBox.StandardButton.Yes: + self.load_missing_images(missing_images) + else: + self.remove_missing_images(missing_images) + + def remove_missing_images(self, missing_images): + for image_name in missing_images: + self.mw.all_images = [ + img for img in self.mw.all_images if img["file_name"] != image_name + ] + self.mw.image_paths.pop(image_name, None) + self.mw.all_annotations.pop(image_name, None) + + base_name = os.path.splitext(image_name)[0] + if base_name in self.mw.image_slices: + for slice_name, _ in self.mw.image_slices[base_name]: + self.mw.all_annotations.pop(slice_name, None) + del self.mw.image_slices[base_name] + + self.mw.update_ui() + QMessageBox.information( + self.mw, + "Images Removed", + f"{len(missing_images)} missing images and their annotations have been removed from the project.", + ) + + def prompt_load_missing_images(self, missing_images): + message = "The following images have annotations but were not found in the project directory:\n\n" + message += "\n".join(missing_images[:10]) + if len(missing_images) > 10: + message += f"\n... and {len(missing_images) - 10} more." + message += "\n\nWould you like to locate these images now?" + + reply = QMessageBox.question( + self.mw, + "Load Missing Images", + message, + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.Yes, + ) + + if reply == QMessageBox.StandardButton.Yes: + self.load_missing_images(missing_images) + + def load_missing_images(self, missing_images): + files, _ = QFileDialog.getOpenFileNames( + self.mw, + "Select Missing Images", + "", + "Image Files (*.png *.jpg *.bmp *.tif *.tiff *.czi)", + ) + if files: + images_loaded = 0 + for file_path in files: + file_name = os.path.basename(file_path) + if file_name in missing_images: + dst_path = os.path.join( + self.mw.current_project_dir, "images", file_name + ) + shutil.copy2(file_path, dst_path) + self.mw.image_paths[file_name] = dst_path + + if not any( + img["file_name"] == file_name for img in self.mw.all_images + ): + self.mw.all_images.append( + { + "file_name": file_name, + "height": 0, + "width": 0, + "id": len(self.mw.all_images) + 1, + "is_multi_slice": False, + } + ) + images_loaded += 1 + missing_images.remove(file_name) + + self.mw.update_image_list() + if images_loaded > 0: + self.mw.image_list.setCurrentRow(0) + self.mw.switch_image(self.mw.image_list.item(0)) + QMessageBox.information( + self.mw, + "Images Loaded", + f"Successfully copied and loaded {images_loaded} out of {len(files)} selected images.", + ) + + if missing_images: + self.prompt_load_missing_images(missing_images) + + def check_missing_images(self): + missing_images = [ + img["file_name"] + for img in self.mw.all_images + if img["file_name"] not in self.mw.image_paths + or not os.path.exists(self.mw.image_paths[img["file_name"]]) + ] + if missing_images: + self.prompt_load_missing_images(missing_images) + + def close_project(self): + if hasattr(self.mw, "current_project_file"): + reply = QMessageBox.question( + self.mw, + "Close Project", + "Do you want to save the current project before closing?", + QMessageBox.StandardButton.Yes + | QMessageBox.StandardButton.No + | QMessageBox.StandardButton.Cancel, + ) + + if reply == QMessageBox.StandardButton.Yes: + self.mw.remove_all_temp_annotations() + self.save_project(show_message=False) + elif reply == QMessageBox.StandardButton.Cancel: + return + + self.mw.clear_all(new_project=True, show_messages=False) + + if hasattr(self.mw, "current_project_file"): + del self.mw.current_project_file + if hasattr(self.mw, "current_project_dir"): + del self.mw.current_project_dir + + self.update_window_title() + + def save_project(self, show_message=True): + if not hasattr(self.mw, "current_project_file") or not self.mw.current_project_file: + self.mw.current_project_file, _ = QFileDialog.getSaveFileName( + self.mw, "Save Project", "", "Image Annotator Project (*.iap)" + ) + if not self.mw.current_project_file: + return + + self.mw.current_project_dir = os.path.dirname(self.mw.current_project_file) + + images_dir = os.path.join(self.mw.current_project_dir, "images") + os.makedirs(images_dir, exist_ok=True) + + images_to_copy = [] + for file_name, src_path in self.mw.image_paths.items(): + dst_path = os.path.join(images_dir, file_name) + if os.path.abspath(src_path) != os.path.abspath(dst_path): + if not os.path.exists(dst_path): + images_to_copy.append((file_name, src_path, dst_path)) + + if images_to_copy: + reply = QMessageBox.question( + self.mw, + "Image Directory Structure", + f"The project structure requires all images to be in an 'images' subdirectory. " + f"{len(images_to_copy)} images need to be copied to the correct location. " + f"Do you want to copy these images?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.Yes, + ) + + if reply == QMessageBox.StandardButton.Yes: + for file_name, src_path, dst_path in images_to_copy: + try: + shutil.copy2(src_path, dst_path) + self.mw.image_paths[file_name] = dst_path + except Exception as e: + QMessageBox.warning( + self.mw, "Copy Failed", f"Failed to copy {file_name}: {str(e)}" + ) + return + else: + QMessageBox.warning( + self.mw, + "Save Cancelled", + "Project cannot be saved without the correct directory structure.", + ) + return + + images_data = [] + for image_info in self.mw.all_images: + file_name = image_info["file_name"] + image_data = { + "file_name": file_name, + "width": image_info["width"], + "height": image_info["height"], + "is_multi_slice": image_info["is_multi_slice"], + } + + if image_data["is_multi_slice"]: + base_name_without_ext = os.path.splitext(file_name)[0] + image_data["slices"] = [] + for slice_name, _ in self.mw.image_slices.get(base_name_without_ext, []): + slice_data = { + "name": slice_name, + "annotations": image_utils.convert_to_serializable( + self.mw.all_annotations.get(slice_name, {}) + ), + } + image_data["slices"].append(slice_data) + + image_data["dimensions"] = image_utils.convert_to_serializable( + self.mw.image_dimensions.get(base_name_without_ext, []) + ) + image_data["shape"] = image_utils.convert_to_serializable( + self.mw.image_shapes.get(base_name_without_ext, []) + ) + else: + image_data["annotations"] = {} + for class_name, annotations in self.mw.all_annotations.get( + file_name, {} + ).items(): + image_data["annotations"][class_name] = [ + ann.copy() for ann in annotations + ] + + images_data.append(image_data) + + project_data = { + "classes": [ + {"name": name, "color": color.name()} + for name, color in self.mw.image_label.class_colors.items() + ], + "images": images_data, + "image_paths": { + k: v for k, v in self.mw.image_paths.items() if os.path.exists(v) + }, + "notes": getattr(self.mw, "project_notes", ""), + "creation_date": getattr( + self.mw, "project_creation_date", datetime.now().isoformat() + ), + "last_modified": datetime.now().isoformat(), + } + + dino_cfg = { + "phrases": self.mw.dino_phrase_panel.get_all_phrases(), + "thresholds": self.mw.dino_class_table.get_thresholds_dict(), + } + if dino_cfg["phrases"] or dino_cfg["thresholds"]: + project_data["dino_config"] = dino_cfg + + with open(self.mw.current_project_file, "w") as f: + json.dump(image_utils.convert_to_serializable(project_data), f, indent=2) + + if show_message: + self.mw.show_info( + "Project Saved", f"Project saved to {self.mw.current_project_file}" + ) + + self.update_window_title() + + for file_name in self.mw.image_paths.keys(): + self.mw.image_paths[file_name] = os.path.join(images_dir, file_name) + + def save_project_as(self): + new_project_file, _ = QFileDialog.getSaveFileName( + self.mw, "Save Project As", "", "Image Annotator Project (*.iap)" + ) + if new_project_file: + if not new_project_file.lower().endswith(".iap"): + new_project_file += ".iap" + + original_project_file = getattr(self.mw, "current_project_file", None) + + self.mw.current_project_file = new_project_file + self.mw.current_project_dir = os.path.dirname(new_project_file) + + self.save_project(show_message=False) + self.update_window_title() + + QMessageBox.information( + self.mw, "Project Saved As", f"Project saved as:\n{new_project_file}" + ) + + if original_project_file is None: + self.mw.current_project_file = new_project_file + + def auto_save(self): + if self.mw.is_loading_project: + return + + if not hasattr(self.mw, "current_project_file"): + reply = QMessageBox.question( + self.mw, + "No Project", + "You need to save the project before auto-saving. Would you like to save now?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.Yes, + ) + if reply == QMessageBox.StandardButton.Yes: + self.save_project() + else: + return + + if hasattr(self.mw, "current_project_file"): + self.save_project(show_message=False) + print("Project auto-saved.") diff --git a/src/digitalsreeni_image_annotator/controllers/sam_controller.py b/src/digitalsreeni_image_annotator/controllers/sam_controller.py new file mode 100644 index 0000000..096fa4e --- /dev/null +++ b/src/digitalsreeni_image_annotator/controllers/sam_controller.py @@ -0,0 +1,219 @@ +"""SAM (Segment Anything) coordination controller. + +Extracted from `ImageAnnotator`. Owns the SAM tool lifecycle (box, +points), the debounce timer state machine, ADR-013's in-flight +re-entrancy guard, and the model picker dropdown plumbing. + +State (`sam_utils`, `sam_inference_timer`, `_sam_inference_in_flight`, +`current_sam_model`) stays on the main window in this phase for the +same reason ProjectController / ImageController state stays there: +external callers (image_label.py, clear_all, the sidebar button +enabling logic) read these attributes directly via `main_window.X`. A +future phase may migrate ownership. + +ADR-013 invariants preserved verbatim: +- `_sam_inference_in_flight` flag set BEFORE calling + `sam_utils.apply_sam_*`, cleared in `finally`. +- `InferenceBusyError` (raised by `sam_utils._run_sync` when the worker + thread is already running) is swallowed silently — the next user + click restarts the debounce. +- `change_sam_model` blocks via `_run_sync` event-loop pump; UI stays + responsive. +""" + +import traceback + +from PyQt6.QtCore import Qt, QObject +from PyQt6.QtWidgets import QMessageBox + +from ..inference.sam_utils import InferenceBusyError + + +class SAMController(QObject): + def __init__(self, main_window): + super().__init__(main_window) + self.mw = main_window + + def deactivate_sam_tools(self): + """Turn off SAM box / points and clear any pending SAM state. + + Called before YOLO predictions overlay their own temp results + and when the SAM model is unset, so a stale bbox / point set / + temp prediction can't linger into the next workflow.""" + self.mw.sam_inference_timer.stop() + self.mw.sam_box_button.setChecked(False) + self.mw.sam_points_button.setChecked(False) + + image_label = self.mw.image_label + if image_label.current_tool in ("sam_box", "sam_points"): + image_label.current_tool = None + image_label.sam_box_active = False + image_label.sam_points_active = False + image_label.sam_bbox = None + image_label.drawing_sam_bbox = False + image_label.sam_positive_points = [] + image_label.sam_negative_points = [] + image_label.temp_sam_prediction = None + image_label.setCursor(Qt.CursorShape.ArrowCursor) + + self.mw.update_ui_for_current_tool() + + def schedule_sam_prediction(self): + """Restart the debounce timer; inference fires 1s after last click.""" + self.mw.sam_inference_timer.stop() + self.mw.sam_inference_timer.start(1000) + + def cancel_sam_debounce(self): + """Stop the SAM debounce timer so a queued inference doesn't + fire. Does NOT abort an in-flight inference; that case is + handled by the _sam_inference_in_flight guard (ADR-013). + Triggered by Escape in ImageLabel while sam_points is active.""" + self.mw.sam_inference_timer.stop() + + def apply_sam_prediction(self): + # Re-entry guard (ADR-013): the event-loop pump inside _run_sync + # can deliver this timer fire before the first call returns. + # Bail and rely on the user clicking again (which restarts the + # debounce) to issue a fresh inference with the up-to-date + # point set. + if self.mw._sam_inference_in_flight: + return + self.mw._sam_inference_in_flight = True + try: + try: + if self.mw.image_label.current_tool == "sam_box": + if self.mw.image_label.sam_bbox is None: + print("SAM bbox is None") + return + x1, y1, x2, y2 = self.mw.image_label.sam_bbox + bbox = [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)] + prediction = self.mw.sam_utils.apply_sam_prediction( + self.mw.current_image, bbox + ) + self.mw.image_label.sam_bbox = None + elif self.mw.image_label.current_tool == "sam_points": + pos_points = self.mw.image_label.sam_positive_points + neg_points = self.mw.image_label.sam_negative_points + print( + f"[SAM-POINTS] Predicting with {len(pos_points)} positive points: {pos_points} " + f"and {len(neg_points)} negative points: {neg_points}" + ) + if not pos_points: + print("No positive points for SAM-points") + return + prediction = self.mw.sam_utils.apply_sam_points( + self.mw.current_image, + pos_points, + neg_points, + ) + else: + return + except InferenceBusyError: + # Re-entry safety net from sam_utils itself. The + # call-site flag above should catch this first, but if + # a different caller drives inference concurrently we + # skip — the user keeps interacting; their next click + # will restart the debounce. + return + except Exception as exc: + traceback.print_exc() + QMessageBox.critical( + self.mw, + "SAM Error", + f"SAM inference failed:\n\n{exc}\n\n" + "See the log for details.", + ) + return + + if prediction: + temp_annotation = { + "segmentation": prediction["segmentation"], + "category_id": self.mw.class_mapping[self.mw.current_class], + "category_name": self.mw.current_class, + "score": prediction["score"], + } + self.mw.image_label.temp_sam_prediction = temp_annotation + self.mw.image_label.update() + elif prediction is None: + QMessageBox.information( + self.mw, + "SAM", + "No mask matches the given constraints. " + "Try adjusting the box or point positions." + ) + else: + print("Failed to generate prediction") + + if self.mw.image_label.current_tool == "sam_box": + self.mw.image_label.sam_bbox = None + self.mw.image_label.update() + finally: + self.mw._sam_inference_in_flight = False + + def accept_sam_prediction(self): + if self.mw.image_label.temp_sam_prediction: + new_annotation = self.mw.image_label.temp_sam_prediction + self.mw.image_label.annotations.setdefault( + new_annotation["category_name"], [] + ).append(new_annotation) + self.mw.add_annotation_to_list(new_annotation) + self.mw.save_current_annotations() + self.mw.update_slice_list_colors() + self.mw.image_label.temp_sam_prediction = None + self.mw.image_label.sam_positive_points = [] + self.mw.image_label.sam_negative_points = [] + self.mw.image_label.update() + print("SAM prediction accepted, points cleared, and added to annotations.") + + def toggle_sam_box(self): + if self.mw.sam_box_button.isChecked(): + self.mw.sam_points_button.setChecked(False) + self.mw.image_label.current_tool = "sam_box" + self.mw.image_label.sam_box_active = True + self.mw.image_label.sam_points_active = False + self.mw.image_label.setCursor(Qt.CursorShape.CrossCursor) + else: + self.mw.image_label.current_tool = None + self.mw.image_label.sam_box_active = False + self.mw.image_label.setCursor(Qt.CursorShape.ArrowCursor) + self.mw.update_ui_for_current_tool() + + def toggle_sam_points(self): + if self.mw.sam_points_button.isChecked(): + self.mw.sam_box_button.setChecked(False) + self.mw.image_label.current_tool = "sam_points" + self.mw.image_label.sam_points_active = True + self.mw.image_label.sam_box_active = False + self.mw.image_label.setCursor(Qt.CursorShape.CrossCursor) + self.mw.image_label.sam_positive_points = [] + self.mw.image_label.sam_negative_points = [] + else: + self.mw.sam_inference_timer.stop() + self.mw.image_label.current_tool = None + self.mw.image_label.sam_points_active = False + self.mw.image_label.setCursor(Qt.CursorShape.ArrowCursor) + self.mw.image_label.sam_positive_points = [] + self.mw.image_label.sam_negative_points = [] + self.mw.update_ui_for_current_tool() + + def change_sam_model(self, model_name): + try: + self.mw.sam_utils.change_sam_model(model_name) + except Exception as e: + QMessageBox.critical( + self.mw, + "SAM Model Error", + f"Failed to load SAM model '{model_name}':\n\n{str(e)}\n\n" + "Check that the model weights are downloadable and that torch " + "is correctly installed for your platform / GPU." + ) + self.mw.sam_model_selector.setCurrentIndex(0) + return + + self.mw.current_sam_model = self.mw.sam_utils.current_sam_model + + if model_name != "Pick a SAM Model": + print(f"Changed SAM model to: {model_name}") + else: + self.deactivate_sam_tools() + print("SAM model unset") diff --git a/src/digitalsreeni_image_annotator/controllers/yolo_controller.py b/src/digitalsreeni_image_annotator/controllers/yolo_controller.py new file mode 100644 index 0000000..802c096 --- /dev/null +++ b/src/digitalsreeni_image_annotator/controllers/yolo_controller.py @@ -0,0 +1,539 @@ +"""YOLO training / prediction coordination controller. + +Extracted from `ImageAnnotator`. Owns: + +- The YOLO menu (Training submenu + Prediction Settings submenu) +- Pre-trained model loading and dataset preparation +- Training: dialog wiring, the `TrainingThread` worker, progress + callback chain, finish handler +- Prediction: model loading via `LoadPredictionModelDialog`, the + confidence-threshold dialog, single-image and multi-image prediction +- Result post-processing (`process_yolo_results`) that converts YOLO + output into temp annotations for the user to review + +State (`yolo_trainer`, `training_thread`, `training_dialog`) stays on +the main window — the menu actions and signal connections are +addressed from elsewhere as `main_window.X`, and `training_dialog` is +referenced via `hasattr(self, "training_dialog")` to lazily initialize. +""" + +import cv2 +import numpy as np +from PyQt6.QtCore import QObject, QThread, pyqtSignal +from PyQt6.QtGui import QAction +from PyQt6.QtWidgets import ( + QDialog, + QDialogButtonBox, + QDoubleSpinBox, + QInputDialog, + QLabel, + QLineEdit, + QListWidget, + QMessageBox, + QPushButton, + QVBoxLayout, +) + +from ..dialogs.yolo_trainer import ( + LoadPredictionModelDialog, + TrainingInfoDialog, + YOLOTrainer, +) + + +class TrainingThread(QThread): + progress_update = pyqtSignal(str) + finished = pyqtSignal(object) + + def __init__(self, yolo_trainer, epochs, imgsz): + super().__init__() + self.yolo_trainer = yolo_trainer + self.epochs = epochs + self.imgsz = imgsz + + def run(self): + try: + results = self.yolo_trainer.train_model( + epochs=self.epochs, imgsz=self.imgsz + ) + self.finished.emit(results) + except Exception as e: + self.finished.emit(str(e)) + + +class YOLOController(QObject): + def __init__(self, main_window): + super().__init__(main_window) + self.mw = main_window + + def setup_yolo_menu(self): + yolo_menu = self.mw.menuBar().addMenu("&YOLO (beta)") + + training_submenu = yolo_menu.addMenu("Training") + + load_pretrained_action = QAction("Load Pre-trained Model", self.mw) + load_pretrained_action.triggered.connect(self.load_yolo_model) + training_submenu.addAction(load_pretrained_action) + + prepare_data_action = QAction("Prepare YOLO Dataset", self.mw) + prepare_data_action.triggered.connect(self.prepare_yolo_dataset) + training_submenu.addAction(prepare_data_action) + + load_yaml_action = QAction("Load Dataset YAML", self.mw) + load_yaml_action.triggered.connect(self.load_yolo_yaml) + training_submenu.addAction(load_yaml_action) + + train_action = QAction("Train Model", self.mw) + train_action.triggered.connect(self.show_train_dialog) + training_submenu.addAction(train_action) + + save_model_action = QAction("Save Model", self.mw) + save_model_action.triggered.connect(self.save_yolo_model) + training_submenu.addAction(save_model_action) + + prediction_submenu = yolo_menu.addMenu("Prediction Settings") + + load_model_action = QAction("Load Model", self.mw) + load_model_action.triggered.connect(self.load_prediction_model) + prediction_submenu.addAction(load_model_action) + + set_threshold_action = QAction("Set Confidence Threshold", self.mw) + set_threshold_action.triggered.connect(self.set_confidence_threshold) + prediction_submenu.addAction(set_threshold_action) + + def initialize_yolo_trainer(self): + if hasattr(self.mw, "current_project_dir"): + self.mw.yolo_trainer = YOLOTrainer(self.mw.current_project_dir, self.mw) + else: + QMessageBox.warning( + self.mw, "No Project", "Please open or create a project first." + ) + + def load_yolo_model(self): + if not hasattr(self.mw, "current_project_dir"): + QMessageBox.warning( + self.mw, "No Project", "Please open or create a project first." + ) + return + + if not self.mw.yolo_trainer: + self.initialize_yolo_trainer() + + if self.mw.yolo_trainer.load_model(): + QMessageBox.information( + self.mw, "Model Loaded", "YOLO model loaded successfully." + ) + else: + QMessageBox.warning( + self.mw, "Load Cancelled", "Model loading was cancelled." + ) + + def prepare_yolo_dataset(self): + if not hasattr(self.mw, "current_project_file"): + QMessageBox.warning( + self.mw, "No Project", "Please open or create a project first." + ) + return + + if not self.mw.yolo_trainer: + self.initialize_yolo_trainer() + + try: + yaml_path = self.mw.yolo_trainer.prepare_dataset() + QMessageBox.information( + self.mw, + "Dataset Prepared", + f"YOLO dataset prepared successfully. YAML file: {yaml_path}", + ) + except Exception as e: + QMessageBox.critical( + self.mw, + "Error", + f"An error occurred while preparing the dataset: {str(e)}", + ) + + def load_yolo_yaml(self): + if not hasattr(self.mw, "current_project_file"): + QMessageBox.warning( + self.mw, "No Project", "Please open or create a project first." + ) + return + + if not self.mw.yolo_trainer: + self.initialize_yolo_trainer() + + try: + if self.mw.yolo_trainer.load_yaml(): + QMessageBox.information( + self.mw, "YAML Loaded", "Dataset YAML loaded successfully." + ) + else: + QMessageBox.warning( + self.mw, "Load Cancelled", "YAML loading was cancelled." + ) + except Exception as e: + QMessageBox.critical( + self.mw, + "Error", + f"An error occurred while loading the YAML file: {str(e)}", + ) + + def save_yolo_model(self): + if not hasattr(self.mw, "current_project_file"): + QMessageBox.warning( + self.mw, "No Project", "Please open or create a project first." + ) + return + + if not self.mw.yolo_trainer or not self.mw.yolo_trainer.model: + QMessageBox.warning( + self.mw, "No Model", "Please train or load a YOLO model first." + ) + return + + try: + if self.mw.yolo_trainer.save_model(): + QMessageBox.information( + self.mw, "Model Saved", "YOLO model saved successfully." + ) + else: + QMessageBox.warning( + self.mw, "Save Cancelled", "Model saving was cancelled." + ) + except Exception as e: + QMessageBox.critical( + self.mw, "Error", f"An error occurred while saving the model: {str(e)}" + ) + + def load_prediction_model(self): + if not hasattr(self.mw, "current_project_file"): + QMessageBox.warning( + self.mw, "No Project", "Please open or create a project first." + ) + return + + if not self.mw.yolo_trainer: + self.initialize_yolo_trainer() + + dialog = LoadPredictionModelDialog(self.mw) + if dialog.exec() == QDialog.DialogCode.Accepted: + model_path = dialog.model_path + yaml_path = dialog.yaml_path + if model_path and yaml_path: + try: + result, message = self.mw.yolo_trainer.load_prediction_model( + model_path, yaml_path + ) + if result: + QMessageBox.information( + self.mw, + "Model Loaded", + "YOLO model and YAML file loaded successfully for prediction.", + ) + if message: + QMessageBox.warning( + self.mw, "Class Mismatch Warning", message + ) + else: + QMessageBox.critical( + self.mw, + "Error Loading Model", + f"Could not load the model or YAML file: {message}", + ) + except Exception as e: + QMessageBox.critical( + self.mw, "Error", f"An error occurred: {str(e)}" + ) + else: + QMessageBox.warning( + self.mw, + "Files Required", + "Both model and YAML files are required for prediction.", + ) + + def show_train_dialog(self): + if not self.mw.yolo_trainer: + QMessageBox.warning( + self.mw, "No Project", "Please open or create a project first." + ) + return + if not self.mw.yolo_trainer.model: + QMessageBox.warning( + self.mw, "No Model", "Please load a pre-trained model first." + ) + return + if not self.mw.yolo_trainer.yaml_path: + QMessageBox.warning( + self.mw, "No Dataset", "Please prepare or load a dataset YAML first." + ) + return + + dialog = QDialog(self.mw) + dialog.setWindowTitle("Train YOLO Model") + layout = QVBoxLayout() + + epochs_label = QLabel("Number of Epochs:") + epochs_input = QLineEdit("100") + layout.addWidget(epochs_label) + layout.addWidget(epochs_input) + + imgsz_label = QLabel("Image Size:") + imgsz_input = QLineEdit("640") + layout.addWidget(imgsz_label) + layout.addWidget(imgsz_input) + + button_box = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel + ) + button_box.accepted.connect(dialog.accept) + button_box.rejected.connect(dialog.reject) + layout.addWidget(button_box) + + dialog.setLayout(layout) + + if dialog.exec() == QDialog.DialogCode.Accepted: + epochs = int(epochs_input.text()) + imgsz = int(imgsz_input.text()) + self.start_training(epochs, imgsz) + + def start_training(self, epochs, imgsz): + if not hasattr(self.mw, "training_dialog"): + self.mw.training_dialog = TrainingInfoDialog(self.mw) + self.mw.training_dialog.show() + + self.mw.yolo_trainer.progress_signal.connect( + self.mw.training_dialog.update_info + ) + self.mw.yolo_trainer.set_progress_callback(self.mw.training_dialog.update_info) + self.mw.training_dialog.stop_signal.connect( + self.mw.yolo_trainer.stop_training_signal + ) + + self.mw.training_thread = TrainingThread(self.mw.yolo_trainer, epochs, imgsz) + self.mw.training_thread.finished.connect(self.training_finished) + self.mw.training_thread.start() + + def training_finished(self, results): + self.mw.training_dialog.stop_button.setEnabled(True) + self.mw.training_dialog.stop_button.setText("Stop Training") + self.mw.yolo_trainer.progress_signal.disconnect( + self.mw.training_dialog.update_info + ) + self.mw.training_dialog.stop_signal.disconnect( + self.mw.yolo_trainer.stop_training_signal + ) + + if isinstance(results, str): + QMessageBox.critical( + self.mw, + "Training Error", + f"An error occurred during training: {results}", + ) + else: + QMessageBox.information( + self.mw, + "Training Complete", + "YOLO model training completed successfully.", + ) + + def set_confidence_threshold(self): + if not hasattr(self.mw, "current_project_file"): + QMessageBox.warning( + self.mw, "No Project", "Please open or create a project first." + ) + return + + if not self.mw.yolo_trainer: + self.initialize_yolo_trainer() + + current_threshold = self.mw.yolo_trainer.conf_threshold + new_threshold, ok = QInputDialog.getDouble( + self.mw, + "Set Confidence Threshold", + "Enter confidence threshold (0-1):", + current_threshold, + 0, + 1, + 2, + ) + if ok: + self.mw.yolo_trainer.set_conf_threshold(new_threshold) + QMessageBox.information( + self.mw, + "Threshold Updated", + f"Confidence threshold set to {new_threshold}", + ) + + def show_predict_dialog(self): + if not self.mw.yolo_trainer or not self.mw.yolo_trainer.model: + QMessageBox.warning(self.mw, "No Model", "Please load a YOLO model first.") + return + + dialog = QDialog(self.mw) + dialog.setWindowTitle("Predict with YOLO Model") + layout = QVBoxLayout() + + image_list = QListWidget() + for image_name in self.mw.image_paths.keys(): + image_list.addItem(image_name) + layout.addWidget(QLabel("Select images for prediction:")) + layout.addWidget(image_list) + + conf_label = QLabel("Confidence Threshold:") + conf_input = QDoubleSpinBox() + conf_input.setRange(0, 1) + conf_input.setSingleStep(0.01) + conf_input.setValue(self.mw.yolo_trainer.conf_threshold) + layout.addWidget(conf_label) + layout.addWidget(conf_input) + + button_box = QDialogButtonBox(QDialogButtonBox.StandardButton.Cancel) + predict_button = QPushButton("Predict") + button_box.addButton(predict_button, QDialogButtonBox.ButtonRole.AcceptRole) + button_box.accepted.connect(dialog.accept) + button_box.rejected.connect(dialog.reject) + layout.addWidget(button_box) + + dialog.setLayout(layout) + + if dialog.exec() == QDialog.DialogCode.Accepted: + selected_images = [item.text() for item in image_list.selectedItems()] + conf = conf_input.value() + self.mw.yolo_trainer.set_conf_threshold(conf) + self.run_predictions(selected_images) + + def run_predictions(self, selected_images): + for image_name in selected_images: + image_path = self.mw.image_paths[image_name] + results = self.mw.yolo_trainer.predict(image_path) + self.process_yolo_results(results, image_name) + + def predict_single_image(self, file_name): + if self.mw.is_multi_dimensional(file_name): + return + + if not self.mw.yolo_trainer or not self.mw.yolo_trainer.model: + QMessageBox.warning( + self.mw, + "No Model", + "Please load a YOLO model first from the YOLO > Prediction Settings > Load Model menu.", + ) + return + + self.mw.deactivate_sam_tools() + + image_path = self.mw.image_paths[file_name] + try: + results = self.mw.yolo_trainer.predict(image_path) + self.process_yolo_results(results, file_name) + except Exception as e: + QMessageBox.warning( + self.mw, + "Prediction Error", + f"An error occurred during prediction: {str(e)}\n\n" + "This might be due to a mismatch between the model and the YAML file classes. " + "Please check that the YAML file corresponds to the loaded model.", + ) + + def process_yolo_results(self, results, image_name): + image_path = self.mw.image_paths[image_name] + image = cv2.imread(image_path) + if image is None: + QMessageBox.warning(self.mw, "Error", f"Failed to load image: {image_name}") + return + original_height, original_width = image.shape[:2] + + temp_annotations = {} + + try: + results, input_size, original_size = results + input_height, input_width = input_size + orig_height, orig_width = original_size + + scale_x = original_width / orig_width + scale_y = original_height / orig_height + + for result in results: + boxes = result.boxes + masks = result.masks + + if masks is None: + print(f"No masks found for {image_name}") + continue + + for mask, box in zip(masks, boxes): + try: + class_id = int(box.cls) + class_name = self.mw.yolo_trainer.class_names[class_id] + score = float(box.conf) + + mask_array = mask.data.cpu().numpy()[0] + mask_array = cv2.resize(mask_array, (orig_width, orig_height)) + contours, _ = cv2.findContours( + (mask_array > 0.5).astype(np.uint8), + cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_SIMPLE, + ) + + if contours: + epsilon = 0.005 * cv2.arcLength(contours[0], True) + approx = cv2.approxPolyDP(contours[0], epsilon, True) + polygon = approx.flatten().tolist() + + scaled_polygon = [] + for i in range(0, len(polygon), 2): + x = polygon[i] * scale_x + y = polygon[i + 1] * scale_y + scaled_polygon.extend([x, y]) + + temp_class_name = f"Temp-{class_name}" + if temp_class_name not in temp_annotations: + temp_annotations[temp_class_name] = [] + + temp_annotation = { + "segmentation": scaled_polygon, + "category_name": temp_class_name, + "score": score, + "temp": True, + } + temp_annotations[temp_class_name].append(temp_annotation) + except IndexError: + QMessageBox.warning( + self.mw, + "Class Mismatch", + "There is a mismatch between the model and the YAML file classes. " + "Please check that the YAML file corresponds to the loaded model.", + ) + return + + except Exception as e: + QMessageBox.warning( + self.mw, + "Prediction Error", + f"An error occurred during prediction: {str(e)}\n\n" + "This might be due to a mismatch between the model and the YAML file classes. " + "Please check that the YAML file corresponds to the loaded model.", + ) + return + + self.mw.add_temp_classes(temp_annotations) + self.mw.update_class_list() + self.mw.image_label.update() + + if temp_annotations: + total_predictions = sum(len(anns) for anns in temp_annotations.values()) + QMessageBox.information( + self.mw, + "Review Predictions", + f"Found {total_predictions} predictions for {len(temp_annotations)} classes.\n" + "Use class visibility checkboxes to review.\n" + "Press Enter to accept or Esc to reject visible predictions.", + ) + else: + QMessageBox.information( + self.mw, + "No Predictions", + "No predictions were found for this image.", + ) + + self.mw.deactivate_sam_tools() diff --git a/src/digitalsreeni_image_annotator/core/__init__.py b/src/digitalsreeni_image_annotator/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/digitalsreeni_image_annotator/annotation_utils.py b/src/digitalsreeni_image_annotator/core/annotation_utils.py similarity index 100% rename from src/digitalsreeni_image_annotator/annotation_utils.py rename to src/digitalsreeni_image_annotator/core/annotation_utils.py diff --git a/src/digitalsreeni_image_annotator/constants.py b/src/digitalsreeni_image_annotator/core/constants.py similarity index 100% rename from src/digitalsreeni_image_annotator/constants.py rename to src/digitalsreeni_image_annotator/core/constants.py diff --git a/src/digitalsreeni_image_annotator/core/image_utils.py b/src/digitalsreeni_image_annotator/core/image_utils.py new file mode 100644 index 0000000..f3e2a98 --- /dev/null +++ b/src/digitalsreeni_image_annotator/core/image_utils.py @@ -0,0 +1,74 @@ +"""Pure image / array helpers extracted from `ImageAnnotator`. + +These are deliberately free of any Qt main-window dependency so they can +be unit-tested in isolation and reused by controllers added in later +refactor phases. +""" + +import numpy as np +from PyQt6.QtGui import QImage + + +def convert_to_serializable(obj): + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, list): + return [convert_to_serializable(item) for item in obj] + if isinstance(obj, dict): + return {key: convert_to_serializable(value) for key, value in obj.items()} + return obj + + +def normalize_array(array): + array_float = array.astype(np.float32) + + if array.dtype == np.uint16: + array_normalized = (array_float - array.min()) / (array.max() - array.min()) + elif array.dtype == np.uint8: + p_low, p_high = np.percentile(array_float, (0, 100)) + array_normalized = np.clip(array_float, p_low, p_high) + array_normalized = (array_normalized - p_low) / (p_high - p_low) + else: + array_normalized = (array_float - array.min()) / (array.max() - array.min()) + + gamma = 1.0 + array_normalized = np.power(array_normalized, gamma) + + return (array_normalized * 255).astype(np.uint8) + + +def adjust_contrast(image, low_percentile=1, high_percentile=99): + if image.dtype != np.uint8: + p_low, p_high = np.percentile(image, (low_percentile, high_percentile)) + image_adjusted = np.clip(image, p_low, p_high) + image_adjusted = (image_adjusted - p_low) / (p_high - p_low) + return (image_adjusted * 255).astype(np.uint8) + return image + + +def convert_to_8bit_rgb(image_array): + if image_array.ndim == 2: + image_8bit = normalize_array(image_array) + return np.stack((image_8bit,) * 3, axis=-1) + if image_array.ndim == 3: + if image_array.shape[2] == 3: + return normalize_array(image_array) + if image_array.shape[2] > 3: + rgb_array = image_array[:, :, :3] + return normalize_array(rgb_array) + raise ValueError(f"Unsupported image shape: {image_array.shape}") + + +def array_to_qimage(array): + if array.ndim == 2: + height, width = array.shape + return QImage(array.data, width, height, width, QImage.Format.Format_Grayscale8) + if array.ndim == 3 and array.shape[2] == 3: + height, width, _ = array.shape + bytes_per_line = 3 * width + return QImage(array.data, width, height, bytes_per_line, QImage.Format.Format_RGB888) + raise ValueError(f"Unsupported array shape {array.shape} for conversion to QImage") diff --git a/src/digitalsreeni_image_annotator/dialogs/__init__.py b/src/digitalsreeni_image_annotator/dialogs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/digitalsreeni_image_annotator/annotation_statistics.py b/src/digitalsreeni_image_annotator/dialogs/annotation_statistics.py similarity index 100% rename from src/digitalsreeni_image_annotator/annotation_statistics.py rename to src/digitalsreeni_image_annotator/dialogs/annotation_statistics.py diff --git a/src/digitalsreeni_image_annotator/coco_json_combiner.py b/src/digitalsreeni_image_annotator/dialogs/coco_json_combiner.py similarity index 100% rename from src/digitalsreeni_image_annotator/coco_json_combiner.py rename to src/digitalsreeni_image_annotator/dialogs/coco_json_combiner.py diff --git a/src/digitalsreeni_image_annotator/dataset_splitter.py b/src/digitalsreeni_image_annotator/dialogs/dataset_splitter.py similarity index 91% rename from src/digitalsreeni_image_annotator/dataset_splitter.py rename to src/digitalsreeni_image_annotator/dialogs/dataset_splitter.py index 2a1d849..df2169f 100644 --- a/src/digitalsreeni_image_annotator/dataset_splitter.py +++ b/src/digitalsreeni_image_annotator/dialogs/dataset_splitter.py @@ -130,10 +130,13 @@ def split_dataset(self): QMessageBox.warning(self, "Error", "Percentages must add up to 100%.") return - if self.images_only_radio.isChecked(): - self.split_images_only() - else: - self.split_images_and_annotations() + try: + if self.images_only_radio.isChecked(): + self.split_images_only() + else: + self.split_images_and_annotations() + except Exception as e: + QMessageBox.critical(self, "Error", f"Dataset split failed:\n{e}") def split_images_only(self): image_files = [f for f in os.listdir(self.input_directory) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tif', '.tiff'))] @@ -161,6 +164,25 @@ def split_images_and_annotations(self): coco_data = json.load(f) image_files = [img['file_name'] for img in coco_data['images']] + + # The JSON lists filenames; nothing guarantees they exist in the + # chosen input directory. A partial split would silently produce a + # broken dataset, so refuse to start if anything is missing. + missing = [f for f in image_files + if not os.path.exists(os.path.join(self.input_directory, f))] + if missing: + preview = "\n".join(missing[:10]) + if len(missing) > 10: + preview += f"\n... and {len(missing) - 10} more" + QMessageBox.warning( + self, "Images Not Found", + f"{len(missing)} of {len(image_files)} image(s) listed in the " + f"COCO JSON were not found in the selected input directory:\n\n" + f"{preview}\n\n" + "Please select the directory that contains these images." + ) + return + random.shuffle(image_files) train_split = int(len(image_files) * self.train_percent.value() / 100) diff --git a/src/digitalsreeni_image_annotator/dicom_converter.py b/src/digitalsreeni_image_annotator/dialogs/dicom_converter.py similarity index 100% rename from src/digitalsreeni_image_annotator/dicom_converter.py rename to src/digitalsreeni_image_annotator/dialogs/dicom_converter.py diff --git a/src/digitalsreeni_image_annotator/dino_merge_dialog.py b/src/digitalsreeni_image_annotator/dialogs/dino_merge_dialog.py similarity index 100% rename from src/digitalsreeni_image_annotator/dino_merge_dialog.py rename to src/digitalsreeni_image_annotator/dialogs/dino_merge_dialog.py diff --git a/src/digitalsreeni_image_annotator/dino_phrase_editor.py b/src/digitalsreeni_image_annotator/dialogs/dino_phrase_editor.py similarity index 91% rename from src/digitalsreeni_image_annotator/dino_phrase_editor.py rename to src/digitalsreeni_image_annotator/dialogs/dino_phrase_editor.py index 73e931d..107fd09 100644 --- a/src/digitalsreeni_image_annotator/dino_phrase_editor.py +++ b/src/digitalsreeni_image_annotator/dialogs/dino_phrase_editor.py @@ -60,14 +60,20 @@ def __init__(self, parent=None): self.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows) self.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers) self.verticalHeader().setVisible(False) + # Rows track their content height so cell text isn't clipped + # when the UI font zoom enlarges the compact panel font. + self.verticalHeader().setSectionResizeMode( + QHeaderView.ResizeMode.ResizeToContents) self.setMaximumHeight(160) # No hardcoded background colors — pick them up from the active # stylesheet so the table integrates with both light and dark # mode. The earlier "background: #e0e0e0" produced a bright bar # across the top of the panel in dark mode. + # No font-size either: the compact size is set (and scaled with + # ui_font_pt) by the appended overrides in ui/theme.py. self.setStyleSheet( - "QTableWidget { font-size: 11px; }" - "QHeaderView::section { font-size: 11px; font-weight: bold; padding: 2px; }" + "QHeaderView::section { font-weight: bold; " + " padding: 2px; background-color: palette(mid); color: palette(text); }" ) def _make_spin(self, value=0.25): @@ -77,7 +83,6 @@ def _make_spin(self, value=0.25): sp.setDecimals(2) sp.setValue(value) sp.setFrame(True) - sp.setStyleSheet("font-size: 11px;") return sp def add_class(self, name: str) -> bool: @@ -92,7 +97,6 @@ def add_class(self, name: str) -> bool: self.setCellWidget(row, _COL_BOX, self._make_spin(DEFAULT_BOX_THR)) self.setCellWidget(row, _COL_TXT, self._make_spin(DEFAULT_TXT_THR)) self.setCellWidget(row, _COL_NMS, self._make_spin(DEFAULT_NMS_THR)) - self.setRowHeight(row, 26) return True def remove_class(self, name: str) -> bool: @@ -172,34 +176,36 @@ def __init__(self, parent=None): layout.setContentsMargins(0, 4, 0, 0) layout.setSpacing(3) + # Compact font sizes for this panel come from the appended + # overrides in ui/theme.py so they scale with the UI font zoom. + # Note: *every* QLabel/QListWidget/QPushButton in this panel is + # compact by design — theme.py targets them by type. A new + # label added here will get the compact size, not body size. self.lbl_title = QLabel("Phrases for: ---") - self.lbl_title.setStyleSheet( - "font-size: 11px; font-weight: bold; color: #333;") + self.lbl_title.setStyleSheet("font-weight: bold;") layout.addWidget(self.lbl_title) hint = QLabel( "DINO uses all phrases below for this class.\n" "First phrase (class name) cannot be removed.") hint.setWordWrap(True) - hint.setStyleSheet("font-size: 10px; color: #777; font-style: italic;") + hint.setObjectName("dino_phrase_hint") # theme.py font-size rule + hint.setStyleSheet("font-style: italic;") layout.addWidget(hint) self.phrase_list = QListWidget() self.phrase_list.setMaximumHeight(90) - self.phrase_list.setStyleSheet("font-size: 11px;") self.phrase_list.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) self.phrase_list.customContextMenuRequested.connect(self._show_phrase_context_menu) layout.addWidget(self.phrase_list) btn_row = QHBoxLayout() self.btn_add_phrase = QPushButton("Add Phrase") - self.btn_add_phrase.setStyleSheet( - "QPushButton{font-size:11px;padding:3px 6px;}") + self.btn_add_phrase.setStyleSheet("QPushButton{padding:3px 6px;}") self.btn_add_phrase.clicked.connect(self._add_phrase) self.btn_rem_phrase = QPushButton("Remove Selected") - self.btn_rem_phrase.setStyleSheet( - "QPushButton{font-size:11px;padding:3px 6px;}") + self.btn_rem_phrase.setStyleSheet("QPushButton{padding:3px 6px;}") self.btn_rem_phrase.clicked.connect(self._remove_phrase) btn_row.addWidget(self.btn_add_phrase) diff --git a/src/digitalsreeni_image_annotator/help_window.py b/src/digitalsreeni_image_annotator/dialogs/help_window.py similarity index 86% rename from src/digitalsreeni_image_annotator/help_window.py rename to src/digitalsreeni_image_annotator/dialogs/help_window.py index 24a0b01..e5c5f01 100644 --- a/src/digitalsreeni_image_annotator/help_window.py +++ b/src/digitalsreeni_image_annotator/dialogs/help_window.py @@ -1,7 +1,7 @@ from PyQt6.QtWidgets import QDialog, QVBoxLayout, QTextBrowser from PyQt6.QtCore import Qt -from .soft_dark_stylesheet import soft_dark_stylesheet -from .default_stylesheet import default_stylesheet +from ..ui.soft_dark_stylesheet import soft_dark_stylesheet +from ..ui.default_stylesheet import default_stylesheet class HelpWindow(QDialog): def __init__(self, dark_mode=False, font_size=10): @@ -15,11 +15,8 @@ def __init__(self, dark_mode=False, font_size=10): layout.addWidget(self.text_browser) self.setLayout(layout) - if dark_mode: - self.setStyleSheet(soft_dark_stylesheet) - else: - self.setStyleSheet(default_stylesheet) - + self._base_stylesheet = soft_dark_stylesheet if dark_mode else default_stylesheet + self.font_size = font_size self.apply_font_size() self.load_help_content() @@ -30,7 +27,11 @@ def show_centered(self, parent): self.show() def apply_font_size(self): - self.setStyleSheet(f"QWidget {{ font-size: {self.font_size}pt; }}") + # Append the font rule to the theme stylesheet — replacing the + # whole sheet here used to wipe the dark/light theme. + self.setStyleSheet( + f"{self._base_stylesheet}\nQWidget {{ font-size: {self.font_size}pt; }}" + ) font = self.text_browser.font() font.setPointSize(self.font_size) self.text_browser.setFont(font) @@ -87,19 +88,18 @@ def load_help_content(self):

Annotation Process

  1. Select a Class: Choose the class you want to annotate from the class list.
  2. -
  3. Choose a Tool: Select either the Polygon Tool, Rectangle Tool, or SAM-Assisted tool.
  4. +
  5. Choose a Tool: Select the Polygon Tool, Rectangle Tool, or one of the SAM-assisted tools (SAM-box / SAM-points).
  6. Create Annotation:
    • For Polygon Tool: Click around the object to define its boundary. Press Enter or click "Finish Polygon" when done.
    • For Rectangle Tool: Click and drag to create a bounding box.
    • -
    • For SAM-Assisted tool: +
    • For SAM-assisted tools (SAM-box / SAM-points):
      1. Select a SAM model from the "Pick a SAM Model" dropdown. It's recommended to use smaller models like SAM2 tiny or SAM2 small for better performance.
      2. Note: When you select a model for the first time, the application needs to download it. This process may take a few seconds to a minute, depending on your internet connection speed. Subsequent uses of the same model will be faster as it will already be cached locally, in your working directory.
      3. -
      4. Click the "SAM-Assisted" button to activate the tool.
      5. -
      6. Draw a rectangle around objects of interest to allow SAM2 to automatically detect objects.
      7. -
      8. SAM2 will provide various outputs with different scores, and only the top-scoring region will be displayed.
      9. -
      10. If the desired result isn't achieved on the first try, draw again.
      11. +
      12. Click the "SAM-box" button and draw a rectangle around an object of interest, or click the "SAM-points" button and left-click points inside the object (right-click adds negative points to exclude regions).
      13. +
      14. SAM2 will display the top-scoring mask as a temporary prediction. Press Enter to accept it or Esc to discard it.
      15. +
      16. If the desired result isn't achieved on the first try, draw the box again or adjust the points.
      17. For low-quality images where SAM2 may not auto-detect objects, manual tools may be necessary.
    • @@ -145,6 +145,9 @@ def load_help_content(self):
    • Ctrl + Shift + S: Open Annotation Statistics
    • F1: Open this help window
    • Ctrl + Wheel: Zoom in/out
    • +
    • Ctrl + Shift + = (or Ctrl + +): Increase application font size
    • +
    • Ctrl + Shift + - (or Ctrl + -): Decrease application font size
    • +
    • Ctrl + Shift + 0: Reset application font size
    • Esc: Cancel current annotation, exit edit mode, or exit SAM-assisted annotation
    • Enter: Finish current annotation, exit edit mode, or accept SAM-generated mask
    • Up/Down Arrow Keys: Navigate through slices in multi-dimensional images
    • diff --git a/src/digitalsreeni_image_annotator/image_augmenter.py b/src/digitalsreeni_image_annotator/dialogs/image_augmenter.py similarity index 100% rename from src/digitalsreeni_image_annotator/image_augmenter.py rename to src/digitalsreeni_image_annotator/dialogs/image_augmenter.py diff --git a/src/digitalsreeni_image_annotator/image_patcher.py b/src/digitalsreeni_image_annotator/dialogs/image_patcher.py similarity index 100% rename from src/digitalsreeni_image_annotator/image_patcher.py rename to src/digitalsreeni_image_annotator/dialogs/image_patcher.py diff --git a/src/digitalsreeni_image_annotator/project_details.py b/src/digitalsreeni_image_annotator/dialogs/project_details.py similarity index 100% rename from src/digitalsreeni_image_annotator/project_details.py rename to src/digitalsreeni_image_annotator/dialogs/project_details.py diff --git a/src/digitalsreeni_image_annotator/project_search.py b/src/digitalsreeni_image_annotator/dialogs/project_search.py similarity index 100% rename from src/digitalsreeni_image_annotator/project_search.py rename to src/digitalsreeni_image_annotator/dialogs/project_search.py diff --git a/src/digitalsreeni_image_annotator/slice_registration.py b/src/digitalsreeni_image_annotator/dialogs/slice_registration.py similarity index 100% rename from src/digitalsreeni_image_annotator/slice_registration.py rename to src/digitalsreeni_image_annotator/dialogs/slice_registration.py diff --git a/src/digitalsreeni_image_annotator/snake_game.py b/src/digitalsreeni_image_annotator/dialogs/snake_game.py similarity index 100% rename from src/digitalsreeni_image_annotator/snake_game.py rename to src/digitalsreeni_image_annotator/dialogs/snake_game.py diff --git a/src/digitalsreeni_image_annotator/stack_interpolator.py b/src/digitalsreeni_image_annotator/dialogs/stack_interpolator.py similarity index 100% rename from src/digitalsreeni_image_annotator/stack_interpolator.py rename to src/digitalsreeni_image_annotator/dialogs/stack_interpolator.py diff --git a/src/digitalsreeni_image_annotator/stack_to_slices.py b/src/digitalsreeni_image_annotator/dialogs/stack_to_slices.py similarity index 100% rename from src/digitalsreeni_image_annotator/stack_to_slices.py rename to src/digitalsreeni_image_annotator/dialogs/stack_to_slices.py diff --git a/src/digitalsreeni_image_annotator/yolo_trainer.py b/src/digitalsreeni_image_annotator/dialogs/yolo_trainer.py similarity index 99% rename from src/digitalsreeni_image_annotator/yolo_trainer.py rename to src/digitalsreeni_image_annotator/dialogs/yolo_trainer.py index 2232908..07536fe 100644 --- a/src/digitalsreeni_image_annotator/yolo_trainer.py +++ b/src/digitalsreeni_image_annotator/dialogs/yolo_trainer.py @@ -5,7 +5,7 @@ import yaml import numpy as np from pathlib import Path -from .export_formats import export_yolo_v5plus +from ..io.export_formats import export_yolo_v5plus from collections import deque diff --git a/src/digitalsreeni_image_annotator/inference/__init__.py b/src/digitalsreeni_image_annotator/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/digitalsreeni_image_annotator/dino_utils.py b/src/digitalsreeni_image_annotator/inference/dino_utils.py similarity index 99% rename from src/digitalsreeni_image_annotator/dino_utils.py rename to src/digitalsreeni_image_annotator/inference/dino_utils.py index b6663f3..f1cea6a 100644 --- a/src/digitalsreeni_image_annotator/dino_utils.py +++ b/src/digitalsreeni_image_annotator/inference/dino_utils.py @@ -25,7 +25,7 @@ from PyQt6.QtGui import QImage from .sam_utils import _qimage_to_numpy, _run_sync -from .utils import models_base_dir +from ..utils import models_base_dir GDINO_MODEL_NAMES = [ diff --git a/src/digitalsreeni_image_annotator/sam_utils.py b/src/digitalsreeni_image_annotator/inference/sam_utils.py similarity index 99% rename from src/digitalsreeni_image_annotator/sam_utils.py rename to src/digitalsreeni_image_annotator/inference/sam_utils.py index 0ef90dc..badd6df 100644 --- a/src/digitalsreeni_image_annotator/sam_utils.py +++ b/src/digitalsreeni_image_annotator/inference/sam_utils.py @@ -37,7 +37,7 @@ from PyQt6.QtCore import QEventLoop, QObject, QThread, pyqtSignal from PyQt6.QtGui import QImage -from .utils import models_base_dir +from ..utils import models_base_dir MODEL_NAMES = [ diff --git a/src/digitalsreeni_image_annotator/io/__init__.py b/src/digitalsreeni_image_annotator/io/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/digitalsreeni_image_annotator/export_formats.py b/src/digitalsreeni_image_annotator/io/export_formats.py similarity index 99% rename from src/digitalsreeni_image_annotator/export_formats.py rename to src/digitalsreeni_image_annotator/io/export_formats.py index fd98c11..ae181aa 100644 --- a/src/digitalsreeni_image_annotator/export_formats.py +++ b/src/digitalsreeni_image_annotator/io/export_formats.py @@ -1,6 +1,6 @@ import json from PyQt6.QtGui import QImage -from .utils import calculate_area, calculate_bbox +from ..utils import calculate_area, calculate_bbox import yaml import os import shutil diff --git a/src/digitalsreeni_image_annotator/import_formats.py b/src/digitalsreeni_image_annotator/io/import_formats.py similarity index 100% rename from src/digitalsreeni_image_annotator/import_formats.py rename to src/digitalsreeni_image_annotator/io/import_formats.py diff --git a/src/digitalsreeni_image_annotator/main.py b/src/digitalsreeni_image_annotator/main.py index 83b187f..e38606d 100644 --- a/src/digitalsreeni_image_annotator/main.py +++ b/src/digitalsreeni_image_annotator/main.py @@ -9,6 +9,23 @@ import sys import os + +# ── Windows DLL load-order workaround (torch → Qt, not Qt → torch) +# +# On Windows + Python 3.14, importing torch *after* PyQt has loaded +# its native platform DLLs (qwindows.dll via QtCore/Gui/Widgets) +# triggers WinError 1114 when torch's c10.dll initialises. This +# was historically blamed on PyQt5 (ADR-011) and thought fixed in +# PyQt6 (ADR-014). Real-world testing with torch 2.11.0 + PyQt6 +# 6.10.2 shows the conflict still surfaces. The workaround is +# cheap and harmless: import torch eagerly before QApplication is +# created so torch's DLLs claim their slot first. +# See ADR-017. +try: + import torch # noqa: F401 +except ImportError: + pass # torch may not be installed; lazy fallback in sam_utils/dino_utils + from PyQt6.QtWidgets import QApplication from .annotator_window import ImageAnnotator diff --git a/src/digitalsreeni_image_annotator/ui/__init__.py b/src/digitalsreeni_image_annotator/ui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/digitalsreeni_image_annotator/default_stylesheet.py b/src/digitalsreeni_image_annotator/ui/default_stylesheet.py similarity index 100% rename from src/digitalsreeni_image_annotator/default_stylesheet.py rename to src/digitalsreeni_image_annotator/ui/default_stylesheet.py diff --git a/src/digitalsreeni_image_annotator/ui/menu_bar.py b/src/digitalsreeni_image_annotator/ui/menu_bar.py new file mode 100644 index 0000000..f3752c1 --- /dev/null +++ b/src/digitalsreeni_image_annotator/ui/menu_bar.py @@ -0,0 +1,155 @@ +"""Build the application menu bar. + +Moved verbatim from ImageAnnotator.create_menu_bar (Phase 8). Every +action's `triggered` connects to a method on `window` (the +ImageAnnotator instance) — many of those are thin delegates to +controllers, but the menu doesn't need to know that. +""" + +from PyQt6.QtGui import QAction, QKeySequence + +from . import theme + + +def build_menu_bar(window): + menu_bar = window.menuBar() + + # Project Menu + project_menu = menu_bar.addMenu("&Project") + + new_project_action = QAction("&New Project", window) + new_project_action.setShortcut(QKeySequence.StandardKey.New) + new_project_action.triggered.connect(window.new_project) + project_menu.addAction(new_project_action) + + open_project_action = QAction("&Open Project", window) + open_project_action.setShortcut(QKeySequence.StandardKey.Open) + open_project_action.triggered.connect(window.open_project) + project_menu.addAction(open_project_action) + + save_project_action = QAction("&Save Project", window) + save_project_action.setShortcut(QKeySequence.StandardKey.Save) + save_project_action.triggered.connect(window.save_project) + project_menu.addAction(save_project_action) + + save_project_as_action = QAction("Save Project &As...", window) + save_project_as_action.setShortcut(QKeySequence("Ctrl+Shift+S")) + save_project_as_action.triggered.connect(window.save_project_as) + project_menu.addAction(save_project_as_action) + + close_project_action = QAction("&Close Project", window) + close_project_action.setShortcut(QKeySequence("Ctrl+W")) + close_project_action.triggered.connect(window.close_project) + project_menu.addAction(close_project_action) + + project_details_action = QAction("Project &Details", window) + project_details_action.setShortcut(QKeySequence("Ctrl+I")) + project_details_action.triggered.connect(window.show_project_details) + project_menu.addAction(project_details_action) + + search_projects_action = QAction("&Search Projects", window) + search_projects_action.setShortcut(QKeySequence("Ctrl+F")) + search_projects_action.triggered.connect(window.show_project_search) + project_menu.addAction(search_projects_action) + + # Settings Menu + settings_menu = menu_bar.addMenu("&Settings") + + font_size_menu = settings_menu.addMenu("&Font Size") + window._font_preset_actions = {} + for size in ["Small", "Medium", "Large", "XL", "XXL"]: + action = QAction(size, window) + action.setCheckable(True) + action.triggered.connect(lambda checked, s=size: window.change_font_size(s)) + font_size_menu.addAction(action) + window._font_preset_actions[size] = action + # Show the persisted size as checked from the first frame (no + # preset is checked when ui_font_pt sits between preset values). + theme.sync_font_menu(window) + + font_size_menu.addSeparator() + + # Continuous UI zoom for low-vision users — steps ui_font_pt ±1pt + # within 8-24. Secondary Ctrl++ / Ctrl+- sequences cover keypads + # and layouts where Ctrl+Shift+= is awkward. + increase_font_action = QAction("&Increase Font Size", window) + increase_font_action.setShortcuts( + [QKeySequence("Ctrl+Shift+="), QKeySequence("Ctrl++")] + ) + increase_font_action.triggered.connect(lambda: window.step_font_size(1)) + font_size_menu.addAction(increase_font_action) + + decrease_font_action = QAction("&Decrease Font Size", window) + decrease_font_action.setShortcuts( + [QKeySequence("Ctrl+Shift+-"), QKeySequence("Ctrl+-")] + ) + decrease_font_action.triggered.connect(lambda: window.step_font_size(-1)) + font_size_menu.addAction(decrease_font_action) + + reset_font_action = QAction("&Reset Font Size", window) + reset_font_action.setShortcut(QKeySequence("Ctrl+Shift+0")) + reset_font_action.triggered.connect(window.reset_font_size) + font_size_menu.addAction(reset_font_action) + + toggle_dark_mode_action = QAction("Toggle &Dark Mode", window) + toggle_dark_mode_action.setShortcut(QKeySequence("Ctrl+D")) + toggle_dark_mode_action.triggered.connect(window.toggle_dark_mode) + settings_menu.addAction(toggle_dark_mode_action) + + # Tools Menu + tools_menu = menu_bar.addMenu("&Tools") + + annotation_stats_action = QAction("Annotation Statistics", window) + annotation_stats_action.triggered.connect(window.show_annotation_statistics) + annotation_stats_action.setShortcut(QKeySequence("Ctrl+Alt+S")) + tools_menu.addAction(annotation_stats_action) + + coco_json_combiner_action = QAction("COCO JSON Combiner", window) + coco_json_combiner_action.triggered.connect(window.show_coco_json_combiner) + tools_menu.addAction(coco_json_combiner_action) + + dataset_splitter_action = QAction("Dataset Splitter", window) + dataset_splitter_action.triggered.connect(window.open_dataset_splitter) + tools_menu.addAction(dataset_splitter_action) + + dino_merge_action = QAction("Merge COCO for Training", window) + dino_merge_action.triggered.connect(window.show_dino_merge_dialog) + tools_menu.addAction(dino_merge_action) + + stack_to_slices_action = QAction("Stack to Slices", window) + stack_to_slices_action.triggered.connect(window.show_stack_to_slices) + tools_menu.addAction(stack_to_slices_action) + + image_patcher_action = QAction("Image Patcher", window) + image_patcher_action.triggered.connect(window.show_image_patcher) + tools_menu.addAction(image_patcher_action) + + image_augmenter_action = QAction("Image Augmenter", window) + image_augmenter_action.triggered.connect(window.show_image_augmenter) + tools_menu.addAction(image_augmenter_action) + + slice_registration_action = QAction("Slice Registration", window) + slice_registration_action.triggered.connect(window.show_slice_registration) + tools_menu.addAction(slice_registration_action) + + stack_interpolator_action = QAction("Stack Interpolator", window) + stack_interpolator_action.triggered.connect(window.show_stack_interpolator) + tools_menu.addAction(stack_interpolator_action) + + dicom_converter_action = QAction("DICOM Converter", window) + dicom_converter_action.triggered.connect(window.show_dicom_converter) + tools_menu.addAction(dicom_converter_action) + + tools_menu.addSeparator() + + unload_models_action = QAction("Unload AI Models (Free GPU Memory)", window) + unload_models_action.triggered.connect(window.unload_ai_models) + tools_menu.addAction(unload_models_action) + + # Help Menu + help_menu = menu_bar.addMenu("&Help") + + help_action = QAction("&Show Help", window) + help_action.setShortcut(QKeySequence.StandardKey.HelpContents) + help_action.triggered.connect(window.show_help) + help_menu.addAction(help_action) diff --git a/src/digitalsreeni_image_annotator/ui/shortcuts.py b/src/digitalsreeni_image_annotator/ui/shortcuts.py new file mode 100644 index 0000000..bdd2409 --- /dev/null +++ b/src/digitalsreeni_image_annotator/ui/shortcuts.py @@ -0,0 +1,34 @@ +"""Global shortcuts and application-wide event filters for ImageAnnotator. + +Both pieces were inline init blocks in ImageAnnotator.__init__ before +Phase 8; factored out here for symmetry with the other ui/ builders +and so the orchestrator stays focused on wiring. +""" + +from PyQt6.QtCore import Qt +from PyQt6.QtGui import QKeySequence, QShortcut +from PyQt6.QtWidgets import QApplication + +from ..controllers.dino_controller import DINOReviewEventFilter + + +def install_shortcuts(window): + """Register global keyboard shortcuts. Currently just F2 → Snake + game. Registered as a QShortcut with ApplicationShortcut context + so it fires regardless of which widget has focus — putting it in + keyPressEvent didn't work because QTableWidget (DINO threshold + table) and other focusable children consume F2 before it bubbles + up to the main window.""" + window._snake_shortcut = QShortcut(QKeySequence("F2"), window) + window._snake_shortcut.setContext(Qt.ShortcutContext.ApplicationShortcut) + window._snake_shortcut.activated.connect(window.launch_snake_game) + + +def install_event_filters(window): + """Install application-wide event filters. Currently just the DINO + review filter — Enter/Escape for DINO temp_annotations need to + work even when focus is on slice_list / image_list / a button, + none of which forward the key to ImageLabel.keyPressEvent. See + ADR-015.""" + window._dino_review_filter = DINOReviewEventFilter(window) + QApplication.instance().installEventFilter(window._dino_review_filter) diff --git a/src/digitalsreeni_image_annotator/ui/sidebar.py b/src/digitalsreeni_image_annotator/ui/sidebar.py new file mode 100644 index 0000000..487e48e --- /dev/null +++ b/src/digitalsreeni_image_annotator/ui/sidebar.py @@ -0,0 +1,333 @@ +"""Build the left sidebar, central image area, and right image list. + +Moved verbatim from ImageAnnotator (Phase 8). Each builder takes +`window` (the ImageAnnotator instance), attaches widgets as +`window.X = ...` for the references read by other modules, and +connects signals to `window.` (the delegate methods on +ImageAnnotator which forward to controllers). +""" + +from PyQt6.QtCore import Qt +from PyQt6.QtWidgets import ( + QAbstractItemView, + QButtonGroup, + QComboBox, + QHBoxLayout, + QLabel, + QListWidget, + QPushButton, + QScrollArea, + QSlider, + QVBoxLayout, + QWidget, +) + +from ..dialogs.dino_phrase_editor import ClassThresholdTable, PhraseEditorPanel + + +def _section_header(text): + label = QLabel(text) + label.setProperty("class", "section-header") + label.setAlignment(Qt.AlignmentFlag.AlignLeft) + return label + + +def build_sidebar(window): + window.sidebar = QWidget() + window.sidebar_layout = QVBoxLayout(window.sidebar) + window.layout.addWidget(window.sidebar, 1) + + # Import functionality + window.import_button = QPushButton("Import Annotations with Images") + window.import_button.clicked.connect(window.import_annotations) + window.sidebar_layout.addWidget(window.import_button) + + window.import_format_selector = QComboBox() + window.import_format_selector.addItem("COCO JSON") + window.import_format_selector.addItem("YOLO (v4 and earlier)") + window.import_format_selector.addItem("YOLO (v5+)") + window.sidebar_layout.addWidget(window.import_format_selector) + + # Add spacing + window.sidebar_layout.addSpacing(20) + + window.add_images_button = QPushButton("Add New Images") + window.add_images_button.clicked.connect(window.add_images) + window.sidebar_layout.addWidget(window.add_images_button) + + window.add_class_button = QPushButton("Add Classes") + window.add_class_button.clicked.connect(lambda: window.add_class()) + window.sidebar_layout.addWidget(window.add_class_button) + + # Class list (without the "Classes" header) + window.class_list = QListWidget() + window.class_list.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) + window.class_list.customContextMenuRequested.connect(window.show_class_context_menu) + window.class_list.itemClicked.connect(window.on_class_selected) + # itemChanged fires when a class's checkbox is toggled; routes to + # visibility toggling on the class controller. Previously wired in + # ImageAnnotator.__init__ post-setup_ui; moved here to live next + # to the widget construction. + window.class_list.itemChanged.connect(window.toggle_class_visibility) + window.sidebar_layout.addWidget(window.class_list) + + # Annotation section + window.sidebar_layout.addWidget(_section_header("Annotation")) + annotation_widget = QWidget() + annotation_layout = QVBoxLayout(annotation_widget) + + # Manual tools subsection + manual_widget = QWidget() + manual_layout = QVBoxLayout(manual_widget) + + button_layout_top = QHBoxLayout() + window.polygon_button = QPushButton("Polygon") + window.polygon_button.setCheckable(True) + window.rectangle_button = QPushButton("Rectangle") + window.rectangle_button.setCheckable(True) + button_layout_top.addWidget(window.polygon_button) + button_layout_top.addWidget(window.rectangle_button) + + button_layout_bottom = QHBoxLayout() + window.paint_brush_button = QPushButton("Paint Brush") + window.paint_brush_button.setCheckable(True) + window.eraser_button = QPushButton("Eraser") + window.eraser_button.setCheckable(True) + button_layout_bottom.addWidget(window.paint_brush_button) + button_layout_bottom.addWidget(window.eraser_button) + + manual_layout.addLayout(button_layout_top) + manual_layout.addLayout(button_layout_bottom) + + annotation_layout.addWidget(manual_widget) + + # SAM-Assisted tools subsection + sam_widget = QWidget() + sam_layout = QVBoxLayout(sam_widget) + + sam_buttons_layout = QHBoxLayout() + + window.sam_box_button = QPushButton("SAM-box") + window.sam_box_button.setCheckable(True) + window.sam_box_button.clicked.connect(window.toggle_sam_box) + + window.sam_points_button = QPushButton("SAM-points") + window.sam_points_button.setCheckable(True) + window.sam_points_button.clicked.connect(window.toggle_sam_points) + + sam_buttons_layout.addWidget(window.sam_box_button) + sam_buttons_layout.addWidget(window.sam_points_button) + sam_layout.addLayout(sam_buttons_layout) + + # SAM model selector + window.sam_model_selector = QComboBox() + window.sam_model_selector.addItem("Pick a SAM Model") + window.sam_model_selector.addItems(list(window.sam_utils.sam_models.keys())) + window.sam_model_selector.currentTextChanged.connect(window.change_sam_model) + sam_layout.addWidget(window.sam_model_selector) + + annotation_layout.addWidget(sam_widget) + + # --- LLM-Assisted Detection (DINO) subsection --- + dino_widget = QWidget() + dino_layout = QVBoxLayout(dino_widget) + + window.dino_model_selector = QComboBox() + window.dino_model_selector.addItem("Pick a DINO Model") + window.dino_model_selector.addItem("grounding-dino-base") + window.dino_model_selector.addItem("grounding-dino-tiny") + window.dino_model_selector.addItem("Custom / fine-tuned (browse)") + window.dino_model_selector.currentTextChanged.connect(window._on_dino_model_changed) + dino_layout.addWidget(window.dino_model_selector) + + # Custom model browse row (hidden by default) + window.dino_browse_row = QWidget() + dino_browse_layout = QHBoxLayout(window.dino_browse_row) + dino_browse_layout.setContentsMargins(0, 0, 0, 0) + window.lbl_dino_custom = QLabel("No path set") + window.lbl_dino_custom.setWordWrap(True) + # Use palette(text) so the colour follows the active stylesheet + # (light or dark) — hardcoded #555 used to render unreadable on + # dark mode. See "No Hardcoded Colors Rule" in CLAUDE.md. + # No font-size here — the caption follows the global ui_font_pt rule. + window.lbl_dino_custom.setStyleSheet("color:palette(text);") + btn_dino_browse = QPushButton("Browse") + # No fixed width — a 60px cap clipped the caption at large UI font + # sizes (low-vision zoom); sizeHint tracks the active font. + btn_dino_browse.clicked.connect(window.browse_dino_model) + dino_browse_layout.addWidget(window.lbl_dino_custom, 1) + dino_browse_layout.addWidget(btn_dino_browse) + window.dino_browse_row.setVisible(False) + dino_layout.addWidget(window.dino_browse_row) + + window.lbl_dino_status = QLabel("No DINO model loaded") + window.lbl_dino_status.setWordWrap(True) + # No hardcoded background — let the active stylesheet (light or + # dark) provide it via QLabel rules. Hardcoded #f5f5f5 used to + # punch a bright rectangle into the dark sidebar. + window.lbl_dino_status.setStyleSheet( + "padding:4px;border-radius:3px;" + "border:1px solid palette(mid);" + ) + dino_layout.addWidget(window.lbl_dino_status) + + # Threshold table + window.dino_class_table = ClassThresholdTable() + window.dino_class_table.itemSelectionChanged.connect( + window.on_dino_class_row_changed + ) + dino_layout.addWidget(window.dino_class_table) + + # Phrase editor + window.dino_phrase_panel = PhraseEditorPanel() + dino_layout.addWidget(window.dino_phrase_panel) + + # Detect buttons + det_btn_layout = QHBoxLayout() + window.btn_detect_single = QPushButton("Detect Current Image") + window.btn_detect_single.clicked.connect(window.run_dino_detection_single) + window.btn_detect_single.setEnabled(False) + det_btn_layout.addWidget(window.btn_detect_single) + + window.btn_detect_batch = QPushButton("Detect All Images") + window.btn_detect_batch.clicked.connect(window.run_dino_detection_batch) + window.btn_detect_batch.setEnabled(False) + det_btn_layout.addWidget(window.btn_detect_batch) + dino_layout.addLayout(det_btn_layout) + + # Batch mode + window.dino_batch_mode = QComboBox() + window.dino_batch_mode.addItem("Review before accepting") + window.dino_batch_mode.addItem("Auto-accept all detections") + dino_layout.addWidget(window.dino_batch_mode) + + annotation_layout.addWidget(dino_widget) + # --- END DINO section --- + + # Tool group — must include all checkable tool buttons so + # update_ui_for_current_tool / enable_tools / disable_tools can + # iterate. + window.tool_group = QButtonGroup(window) + window.tool_group.setExclusive(False) + window.tool_group.addButton(window.polygon_button) + window.tool_group.addButton(window.rectangle_button) + window.tool_group.addButton(window.paint_brush_button) + window.tool_group.addButton(window.eraser_button) + window.tool_group.addButton(window.sam_box_button) + window.tool_group.addButton(window.sam_points_button) + + window.polygon_button.clicked.connect(window.toggle_tool) + window.rectangle_button.clicked.connect(window.toggle_tool) + window.paint_brush_button.clicked.connect(window.toggle_tool) + window.eraser_button.clicked.connect(window.toggle_tool) + + # Annotations list subsection + annotation_layout.addWidget(QLabel("Annotations")) + window.annotation_list = QListWidget() + window.annotation_list.setSelectionMode( + QAbstractItemView.SelectionMode.ExtendedSelection + ) + window.annotation_list.itemSelectionChanged.connect( + window.update_highlighted_annotations + ) + annotation_layout.addWidget(window.annotation_list) + + # Sort buttons + sort_button_layout = QHBoxLayout() + window.sort_by_class_button = QPushButton("Sort by Class") + window.sort_by_class_button.clicked.connect(window.sort_annotations_by_class) + sort_button_layout.addWidget(window.sort_by_class_button) + + window.sort_by_area_button = QPushButton("Sort by Area") + window.sort_by_area_button.clicked.connect(window.sort_annotations_by_area) + sort_button_layout.addWidget(window.sort_by_area_button) + + annotation_layout.addLayout(sort_button_layout) + + # Delete / Merge / Change Class buttons + window.delete_button = QPushButton("Delete") + window.delete_button.clicked.connect(window.delete_selected_annotations) + window.merge_button = QPushButton("Merge") + window.merge_button.clicked.connect(window.merge_annotations) + window.change_class_button = QPushButton("Change Class") + window.change_class_button.clicked.connect(window.change_annotation_class) + + button_layout = QHBoxLayout() + button_layout.addWidget(window.delete_button) + button_layout.addWidget(window.merge_button) + button_layout.addWidget(window.change_class_button) + annotation_layout.addLayout(button_layout) + + # Export format selector + window.export_format_selector = QComboBox() + window.export_format_selector.addItem("COCO JSON") + window.export_format_selector.addItem("YOLO (v4 and earlier)") + window.export_format_selector.addItem("YOLO (v5+)") + window.export_format_selector.addItem("Labeled Images") + window.export_format_selector.addItem("Semantic Labels") + window.export_format_selector.addItem("Pascal VOC (BBox)") + window.export_format_selector.addItem("Pascal VOC (BBox + Segmentation)") + + annotation_layout.addWidget(QLabel("Export Format:")) + annotation_layout.addWidget(window.export_format_selector) + + window.export_button = QPushButton("Export Annotations") + window.export_button.clicked.connect(window.export_annotations) + annotation_layout.addWidget(window.export_button) + + window.sidebar_layout.addWidget(annotation_widget) + + +def build_image_area(window): + window.image_widget = QWidget() + window.image_layout = QVBoxLayout(window.image_widget) + window.layout.addWidget(window.image_widget, 3) + + window.scroll_area = QScrollArea() + window.scroll_area.setWidgetResizable(True) + window.scroll_area.setHorizontalScrollBarPolicy( + Qt.ScrollBarPolicy.ScrollBarAsNeeded + ) + window.scroll_area.setVerticalScrollBarPolicy( + Qt.ScrollBarPolicy.ScrollBarAsNeeded + ) + + # Use the already initialized image_label + window.image_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + window.scroll_area.setWidget(window.image_label) + + window.image_layout.addWidget(window.scroll_area) + + window.zoom_slider = QSlider(Qt.Orientation.Horizontal) + window.zoom_slider.setMinimum(10) + window.zoom_slider.setMaximum(500) + window.zoom_slider.setValue(100) + window.zoom_slider.setTickPosition(QSlider.TickPosition.TicksBelow) + window.zoom_slider.setTickInterval(50) + window.zoom_slider.valueChanged.connect(window.zoom_image) + window.image_layout.addWidget(window.zoom_slider) + + window.image_info_label = QLabel() + window.image_layout.addWidget(window.image_info_label) + + +def build_image_list(window): + window.image_list_widget = QWidget() + window.image_list_layout = QVBoxLayout(window.image_list_widget) + window.layout.addWidget(window.image_list_widget, 1) + + window.image_list_label = QLabel("Images:") + window.image_list_layout.addWidget(window.image_list_label) + + window.image_list = QListWidget() + window.image_list.itemClicked.connect(window.switch_image) + window.image_list.currentRowChanged.connect( + lambda row: window.switch_image(window.image_list.currentItem()) + ) + window.image_list.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) + window.image_list.customContextMenuRequested.connect(window.show_image_context_menu) + window.image_list_layout.addWidget(window.image_list) + + window.clear_all_button = QPushButton("Clear All Images and Annotations") + window.clear_all_button.clicked.connect(window.clear_all) + window.image_list_layout.addWidget(window.clear_all_button) diff --git a/src/digitalsreeni_image_annotator/soft_dark_stylesheet.py b/src/digitalsreeni_image_annotator/ui/soft_dark_stylesheet.py similarity index 100% rename from src/digitalsreeni_image_annotator/soft_dark_stylesheet.py rename to src/digitalsreeni_image_annotator/ui/soft_dark_stylesheet.py diff --git a/src/digitalsreeni_image_annotator/ui/theme.py b/src/digitalsreeni_image_annotator/ui/theme.py new file mode 100644 index 0000000..3751b14 --- /dev/null +++ b/src/digitalsreeni_image_annotator/ui/theme.py @@ -0,0 +1,122 @@ +"""Theme + font size application, extracted from `ImageAnnotator`. + +The functions here take the main window as their first argument; they +read state directly off it (`dark_mode`, `ui_font_pt`, etc.) and write +to its widgets. Kept as plain functions rather than a controller class +because they are stateless and the call sites are sparse. + +`mw.ui_font_pt` (int, 8-24, default 10) is the single source of truth +for UI text size. The Settings → Font Size presets jump to fixed +values; Ctrl+Shift+= / Ctrl+Shift+- step it ±1pt. Every change goes +through `set_font_pt`, which clamps, re-applies the theme, persists +via QSettings and syncs the preset menu checkmarks. +""" + +from PyQt6.QtGui import QFont +from PyQt6.QtWidgets import QWidget + +from ..app_settings import FONT_PT_DEFAULT, clamp_font_pt, save_ui_prefs +from .default_stylesheet import default_stylesheet +from .soft_dark_stylesheet import soft_dark_stylesheet + +# Legacy px values from the static stylesheets at the 10pt default. +# The overrides scale these by ui_font_pt / 10 and stay in px so the +# default renders pixel-identical to the pre-zoom stylesheets. +_HEADER_PX_AT_DEFAULT = 14 # QLabel.section-header font-size +_INDICATOR_PX_AT_DEFAULT = 14 # checkbox / radio indicator width+height +_RADIO_RADIUS_PX_AT_DEFAULT = 8 # radio indicator border-radius +# DINO sidebar panel uses deliberately compact text (smaller than body). +# The widgets carry no inline font-size — these rules own it so the +# compact look is preserved but still scales with ui_font_pt. +_DINO_COMPACT_PX_AT_DEFAULT = 11 # threshold table, phrase list, buttons +_DINO_HINT_PX_AT_DEFAULT = 10 # italic hint under "Phrases for:" + + +def apply_theme_and_font(mw): + font_size = mw.ui_font_pt + style = soft_dark_stylesheet if mw.dark_mode else default_stylesheet + # Appended rules win over the static sheet (same specificity, + # later in cascade) — this is how the hardcoded px sizes in the + # stylesheets are made to follow ui_font_pt without templating + # the static strings. + scale = font_size / FONT_PT_DEFAULT + header_px = round(_HEADER_PX_AT_DEFAULT * scale) + indicator_px = round(_INDICATOR_PX_AT_DEFAULT * scale) + radio_radius_px = round(_RADIO_RADIUS_PX_AT_DEFAULT * scale) + dino_px = round(_DINO_COMPACT_PX_AT_DEFAULT * scale) + dino_hint_px = round(_DINO_HINT_PX_AT_DEFAULT * scale) + combined_style = ( + f"{style}\n" + f"QWidget {{ font-size: {font_size}pt; }}\n" + f"QLabel.section-header {{ font-size: {header_px}px; }}\n" + f"QCheckBox::indicator, QRadioButton::indicator {{" + f" width: {indicator_px}px; height: {indicator_px}px; }}\n" + f"QRadioButton::indicator {{ border-radius: {radio_radius_px}px; }}\n" + f"ClassThresholdTable, ClassThresholdTable QDoubleSpinBox," + f" ClassThresholdTable QHeaderView::section," + f" PhraseEditorPanel QLabel, PhraseEditorPanel QListWidget," + f" PhraseEditorPanel QPushButton {{ font-size: {dino_px}px; }}\n" + f"QLabel#dino_phrase_hint {{ font-size: {dino_hint_px}px; }}" + ) + mw.setStyleSheet(combined_style) + + for widget in mw.findChildren(QWidget): + font = widget.font() + font.setPointSize(font_size) + widget.setFont(font) + + mw.image_label.setFont(QFont("Arial", font_size)) + mw.image_label.set_ui_scale(font_size / FONT_PT_DEFAULT) + mw.update() + + +def set_font_pt(mw, pt): + mw.ui_font_pt = clamp_font_pt(pt) + apply_theme_and_font(mw) + save_ui_prefs(mw.ui_font_pt, mw.dark_mode) + sync_font_menu(mw) + + +def step_font_pt(mw, delta): + set_font_pt(mw, mw.ui_font_pt + delta) + + +def reset_font_pt(mw): + set_font_pt(mw, FONT_PT_DEFAULT) + + +def change_font_size(mw, size): + """Preset entry point — `size` is a name from `mw.font_sizes`.""" + set_font_pt(mw, mw.font_sizes[size]) + + +def sync_font_menu(mw): + """Check the preset action matching ui_font_pt, uncheck the rest. + + No preset is checked when the user stepped to an in-between size. + """ + actions = getattr(mw, "_font_preset_actions", None) + if not actions: + return + for name, action in actions.items(): + action.setChecked(mw.font_sizes[name] == mw.ui_font_pt) + + +def toggle_dark_mode(mw): + mw.dark_mode = not mw.dark_mode + apply_theme_and_font(mw) + save_ui_prefs(mw.ui_font_pt, mw.dark_mode) + mw.update_slice_list_colors() + mw.update_class_list() + mw.update_annotation_list() + mw.repaint() + + +def apply_stylesheet(mw): + mw.setStyleSheet(soft_dark_stylesheet if mw.dark_mode else default_stylesheet) + + +def update_ui_colors(mw): + mw.update_annotation_list_colors() + mw.update_slice_list_colors() + mw.image_label.update() diff --git a/src/digitalsreeni_image_annotator/widgets/__init__.py b/src/digitalsreeni_image_annotator/widgets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/digitalsreeni_image_annotator/widgets/canvas_context.py b/src/digitalsreeni_image_annotator/widgets/canvas_context.py new file mode 100644 index 0000000..cfc440d --- /dev/null +++ b/src/digitalsreeni_image_annotator/widgets/canvas_context.py @@ -0,0 +1,48 @@ +""" +CanvasContext — narrow read-only view of main-window state used by +ImageLabel during rendering and event handling. + +The orchestrator (ImageAnnotator) constructs one CanvasContext and +passes it to ImageLabel via set_context(). ImageLabel reads state +through the accessors here; writes go out as Qt signals connected to +controllers in ImageAnnotator.__init__. + +All accessors are methods (not attributes) so the source of truth +stays on ImageAnnotator and future refactors that move state to a +controller can re-route the accessor without changing ImageLabel. +""" + + +class CanvasContext: + def __init__(self, main_window): + self._mw = main_window + + def paint_brush_size(self) -> int: + return self._mw.paint_brush_size + + def eraser_size(self) -> int: + return self._mw.eraser_size + + def current_class(self): + return self._mw.current_class + + def class_id(self, name: str) -> int: + return self._mw.class_mapping[name] + + def class_mapping(self) -> dict: + return self._mw.class_mapping + + def is_class_visible(self, name: str) -> bool: + return self._mw.class_controller.is_class_visible(name) + + def current_image_key(self): + return self._mw.current_slice or self._mw.image_file_name + + def all_annotations(self) -> dict: + return self._mw.all_annotations + + def scroll_area(self): + return self._mw.scroll_area + + def dialog_parent(self): + return self._mw diff --git a/src/digitalsreeni_image_annotator/image_label.py b/src/digitalsreeni_image_annotator/widgets/image_label.py similarity index 58% rename from src/digitalsreeni_image_annotator/image_label.py rename to src/digitalsreeni_image_annotator/widgets/image_label.py index e22accf..ab89d8e 100644 --- a/src/digitalsreeni_image_annotator/image_label.py +++ b/src/digitalsreeni_image_annotator/widgets/image_label.py @@ -11,10 +11,8 @@ import os import warnings -import cv2 -import numpy as np from PIL import Image -from PyQt6.QtCore import QPoint, QPointF, QRectF, QSize, Qt +from PyQt6.QtCore import QPoint, QPointF, QRectF, QSize, Qt, pyqtSignal from PyQt6.QtGui import ( QBrush, QColor, @@ -25,11 +23,12 @@ QPainter, QPen, QPixmap, - QPolygon, QPolygonF, QWheelEvent, ) -from PyQt6.QtWidgets import QApplication, QLabel, QMessageBox +from PyQt6.QtWidgets import QLabel, QMessageBox + +from .tools import EraserTool, PaintBrushTool, PolygonTool, RectangleTool warnings.filterwarnings("ignore", category=UserWarning) @@ -39,6 +38,36 @@ class ImageLabel(QLabel): A custom QLabel for displaying images and handling annotations. """ + # Annotation lifecycle + annotationCommitted = pyqtSignal(dict) # paint / accept-temp per-annotation add + annotationsBatchSaved = pyqtSignal() # batch finalizer: save + slice-color refresh + annotationsReplaced = pyqtSignal(str, dict) # eraser path: (image_key, per-class dict) + annotationListUpdateRequested = pyqtSignal() # editing-mode exit refresh + annotationSelected = pyqtSignal(object) # double-click selection + deleteSelectionRequested = pyqtSignal() + finishPolygonRequested = pyqtSignal() + finishRectangleRequested = pyqtSignal() + + # Class + classRequested = pyqtSignal(str) # accept-temp path needs a new class + + # SAM + samPredictionRequested = pyqtSignal() # debounced (mouse press) + samPredictionApplyRequested = pyqtSignal() # post-debounce (mouse release) + samPredictionAccepted = pyqtSignal() # Enter on temp prediction + samPointsCleared = pyqtSignal() # Escape during sam_points: stop timer + + # Tool / UI state + enableToolsRequested = pyqtSignal() + disableToolsRequested = pyqtSignal() + resetToolButtonsRequested = pyqtSignal() + toolSizeChanged = pyqtSignal(str, int) # ("paint" | "eraser", new_size) + + # Navigation / info + zoomInRequested = pyqtSignal() + zoomOutRequested = pyqtSignal() + imageInfoChanged = pyqtSignal() + def __init__(self, parent=None): super().__init__(parent) self.annotations = {} @@ -46,6 +75,11 @@ def __init__(self, parent=None): self.temp_point = None self.current_tool = None self.zoom_factor = 1.0 + # Low-vision UI zoom: ui_font_pt / 10 (legacy default), set by + # theme.apply_theme_and_font. Multiplies overlay sizes (label + # fonts, marker radii, pen widths) — orthogonal to zoom_factor, + # which keeps them constant-size on screen across image zoom. + self.ui_scale = 1.0 self.class_colors = {} self.class_visibility = {} self.start_point = None @@ -56,7 +90,7 @@ def __init__(self, parent=None): self.original_pixmap = None self.scaled_pixmap = None self.pan_start_pos = None - self.main_window = None + self._ctx = None self.offset_x = 0 self.offset_y = 0 self.drawing_polygon = False @@ -70,8 +104,6 @@ def __init__(self, parent=None): self.image_path = None self.dark_mode = False - self.paint_mask = None - self.eraser_mask = None self.temp_paint_mask = None self.is_painting = False self.temp_eraser_mask = None @@ -79,7 +111,6 @@ def __init__(self, parent=None): self.cursor_pos = None # SAM - self.sam_magic_wand_active = False self.sam_bbox = None self.drawing_sam_bbox = False self.temp_sam_prediction = None @@ -91,8 +122,46 @@ def __init__(self, parent=None): self.sam_positive_points = [] self.sam_negative_points = [] - def set_main_window(self, main_window): - self.main_window = main_window + # Per-tool handlers (Phase 7). Each owns its event-handling + # behaviour; state fields used by controllers (current_rectangle, + # current_annotation, temp_paint_mask, …) stay on the widget. + self._tools = { + "polygon": PolygonTool(self), + "rectangle": RectangleTool(self), + "paint_brush": PaintBrushTool(self), + "eraser": EraserTool(self), + } + + def set_context(self, ctx): + self._ctx = ctx + + def set_ui_scale(self, scale): + self.ui_scale = scale + self.update() + + def _pen_w(self, base): + """Overlay pen width: ui-scaled, zoom-compensated (constant on screen).""" + return base * self.ui_scale / self.zoom_factor + + def _overlay_font(self, base=12): + """Overlay label font: ui-scaled, zoom-compensated (constant on screen).""" + return QFont("Arial", max(1, int(base * self.ui_scale / self.zoom_factor))) + + @property + def active_tool_handler(self): + return self._tools.get(self.current_tool) + + def set_active_tool(self, tool_name): + """Called by ImageAnnotator when the user switches tools. Gives + the previous handler a chance to clean up (default no-op + preserves the existing 'drop temp state silently' behaviour; + explicit commit/discard goes through Enter/Escape or the + check_unsaved_changes dialog).""" + prev = self.active_tool_handler + new = self._tools.get(tool_name) + if prev is not None and prev is not new: + prev.deactivate() + self.current_tool = tool_name def set_dark_mode(self, is_dark): self.dark_mode = is_dark @@ -122,8 +191,7 @@ def detect_bit_depth(self): else: self.bit_depth = img.bits - if self.main_window: - self.main_window.update_image_info() + self.imageInfoChanged.emit() def update_scaled_pixmap(self): if self.original_pixmap and not self.original_pixmap.isNull(): @@ -163,143 +231,9 @@ def resizeEvent(self, event): super().resizeEvent(event) self.update_offset() - def start_painting(self, pos): - if self.temp_paint_mask is None: - self.temp_paint_mask = np.zeros( - (self.original_pixmap.height(), self.original_pixmap.width()), - dtype=np.uint8, - ) - self.is_painting = True - self.continue_painting(pos) - - def continue_painting(self, pos): - if not self.is_painting: - return - brush_size = self.main_window.paint_brush_size - cv2.circle( - self.temp_paint_mask, (int(pos[0]), int(pos[1])), brush_size, 255, -1 - ) - self.update() - - def finish_painting(self): - if not self.is_painting: - return - self.is_painting = False - # Don't commit the annotation yet, just keep the temp_paint_mask - - def commit_paint_annotation(self): - if self.temp_paint_mask is not None and self.main_window.current_class: - class_name = self.main_window.current_class - contours, _ = cv2.findContours( - self.temp_paint_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE - ) - for contour in contours: - if cv2.contourArea(contour) > 10: # Minimum area threshold - segmentation = contour.flatten().tolist() - new_annotation = { - "segmentation": segmentation, - "category_id": self.main_window.class_mapping[class_name], - "category_name": class_name, - } - self.annotations.setdefault(class_name, []).append(new_annotation) - self.main_window.add_annotation_to_list(new_annotation) - self.temp_paint_mask = None - self.main_window.save_current_annotations() - self.main_window.update_slice_list_colors() - self.update() - - def discard_paint_annotation(self): - self.temp_paint_mask = None - self.update() - - def start_erasing(self, pos): - if self.temp_eraser_mask is None: - self.temp_eraser_mask = np.zeros( - (self.original_pixmap.height(), self.original_pixmap.width()), - dtype=np.uint8, - ) - self.is_erasing = True - self.continue_erasing(pos) - - def continue_erasing(self, pos): - if not self.is_erasing: - return - eraser_size = self.main_window.eraser_size - cv2.circle( - self.temp_eraser_mask, (int(pos[0]), int(pos[1])), eraser_size, 255, -1 - ) - self.update() - - def finish_erasing(self): - if not self.is_erasing: - return - self.is_erasing = False - # Don't commit the eraser changes yet, just keep the temp_eraser_mask - - def commit_eraser_changes(self): - if self.temp_eraser_mask is not None: - eraser_mask = self.temp_eraser_mask.astype(bool) - current_name = ( - self.main_window.current_slice or self.main_window.image_file_name - ) - annotations_changed = False - - for class_name, annotations in self.annotations.items(): - updated_annotations = [] - max_number = max([ann.get("number", 0) for ann in annotations] + [0]) - for annotation in annotations: - if "segmentation" in annotation: - points = ( - np.array(annotation["segmentation"]) - .reshape(-1, 2) - .astype(int) - ) - mask = np.zeros_like(self.temp_eraser_mask) - cv2.fillPoly(mask, [points], 255) - mask = mask.astype(bool) - mask[eraser_mask] = False - contours, _ = cv2.findContours( - mask.astype(np.uint8), - cv2.RETR_EXTERNAL, - cv2.CHAIN_APPROX_SIMPLE, - ) - for i, contour in enumerate(contours): - if cv2.contourArea(contour) > 10: # Minimum area threshold - new_segmentation = contour.flatten().tolist() - new_annotation = annotation.copy() - new_annotation["segmentation"] = new_segmentation - if i == 0: - new_annotation["number"] = annotation.get( - "number", max_number + 1 - ) - else: - max_number += 1 - new_annotation["number"] = max_number - updated_annotations.append(new_annotation) - if len(contours) > 1: - annotations_changed = True - else: - updated_annotations.append(annotation) - self.annotations[class_name] = updated_annotations - - self.temp_eraser_mask = None - - # Update the all_annotations dictionary in the main window - self.main_window.all_annotations[current_name] = self.annotations - - # Call update_annotation_list directly - self.main_window.update_annotation_list() - - self.main_window.save_current_annotations() - self.main_window.update_slice_list_colors() - self.update() - - # print(f"Eraser changes committed. Annotations changed: {annotations_changed}") - # print(f"Current annotations: {self.annotations}") - - def discard_eraser_changes(self): - self.temp_eraser_mask = None - self.update() + # Paint, eraser, polygon, and rectangle behaviour lives in + # widgets/tools/*; this widget dispatches events to the active + # handler (see set_active_tool / active_tool_handler). def paintEvent(self, event): super().paintEvent(event) @@ -310,37 +244,38 @@ def paintEvent(self, event): painter.drawPixmap( int(self.offset_x), int(self.offset_y), self.scaled_pixmap ) - # Draw annotations + # Draw committed annotations self.draw_annotations(painter) - # Draw other elements + # Polygon edit mode is modal; runs orthogonal to tool selection if self.editing_polygon: self.draw_editing_polygon(painter) - if self.drawing_rectangle and self.current_rectangle: - self.draw_current_rectangle(painter) - if self.sam_magic_wand_active and self.sam_bbox: - self.draw_sam_bbox(painter) - # --- Draw for SAM-box mode --- + # SAM overlays (cross-cutting; not part of the tool handlers) if self.sam_box_active and self.sam_bbox: self.draw_sam_bbox(painter) - # --- Draw for SAM-points mode --- if self.sam_points_active: painter.save() painter.translate(self.offset_x, self.offset_y) painter.scale(self.zoom_factor, self.zoom_factor) + # Radii intentionally NOT zoom-compensated — the dots + # grow with image zoom (pre-existing behaviour). + dot_r = 4 * self.ui_scale for pt in self.sam_positive_points: - painter.setPen(QPen(Qt.GlobalColor.green, 6 / self.zoom_factor, Qt.PenStyle.SolidLine)) + painter.setPen(QPen(Qt.GlobalColor.green, self._pen_w(6), Qt.PenStyle.SolidLine)) painter.setBrush(QBrush(Qt.GlobalColor.green)) - painter.drawEllipse(QPointF(pt[0], pt[1]), 4, 4) + painter.drawEllipse(QPointF(pt[0], pt[1]), dot_r, dot_r) for pt in self.sam_negative_points: - painter.setPen(QPen(Qt.GlobalColor.red, 6 / self.zoom_factor, Qt.PenStyle.SolidLine)) + painter.setPen(QPen(Qt.GlobalColor.red, self._pen_w(6), Qt.PenStyle.SolidLine)) painter.setBrush(QBrush(Qt.GlobalColor.red)) - painter.drawEllipse(QPointF(pt[0], pt[1]), 4, 4) + painter.drawEllipse(QPointF(pt[0], pt[1]), dot_r, dot_r) painter.restore() - # Draw temporary paint mask - if self.temp_paint_mask is not None: - self.draw_temp_paint_mask(painter) - if self.temp_eraser_mask is not None: - self.draw_temp_eraser_mask(painter) + # In-progress overlays from every tool that has state to + # render (paint mask, eraser mask, polygon-in-progress, + # rectangle preview). Pre-Phase-7 these drew whenever + # their state field was populated regardless of the active + # tool; iterating all handlers preserves that — switching + # tools mid-stroke does not hide an unsaved mark. + for handler in self._tools.values(): + handler.paint_overlay(painter) self.draw_tool_size_indicator(painter) if self.temp_annotations: self.draw_temp_annotations(painter) @@ -353,7 +288,7 @@ def draw_temp_annotations(self, painter): for annotation in self.temp_annotations: color = QColor(255, 165, 0, 128) # Semi-transparent orange - painter.setPen(QPen(color, 2 / self.zoom_factor, Qt.PenStyle.DashLine)) + painter.setPen(QPen(color, self._pen_w(2), Qt.PenStyle.DashLine)) painter.setBrush(QBrush(color)) # Prefer segmentation polygon over bbox when both are present @@ -373,7 +308,7 @@ def draw_temp_annotations(self, painter): painter.drawRect(QRectF(x, y, w, h)) # Draw label and score - painter.setFont(QFont("Arial", int(12 / self.zoom_factor))) + painter.setFont(self._overlay_font()) label = f"{annotation['category_name']} {annotation['score']:.2f}" if points is not None: centroid = self.calculate_centroid(points) @@ -390,8 +325,8 @@ def accept_temp_annotations(self): class_name = annotation["category_name"] # Check if the class exists, if not, add it - if class_name not in self.main_window.class_mapping: - self.main_window.add_class(class_name) + if class_name not in self._ctx.class_mapping(): + self.classRequested.emit(class_name) if class_name not in self.annotations: self.annotations[class_name] = [] @@ -401,57 +336,16 @@ def accept_temp_annotations(self): "score" ] # Remove the score as it's not needed in the final annotation self.annotations[class_name].append(annotation) - self.main_window.add_annotation_to_list(annotation) + self.annotationCommitted.emit(annotation) self.temp_annotations.clear() - self.main_window.save_current_annotations() - self.main_window.update_slice_list_colors() + self.annotationsBatchSaved.emit() self.update() def discard_temp_annotations(self): self.temp_annotations.clear() self.update() - def draw_temp_paint_mask(self, painter): - if self.temp_paint_mask is not None: - painter.save() - painter.translate(self.offset_x, self.offset_y) - painter.scale(self.zoom_factor, self.zoom_factor) - - mask_image = QImage( - self.temp_paint_mask.data, - self.temp_paint_mask.shape[1], - self.temp_paint_mask.shape[0], - self.temp_paint_mask.shape[1], - QImage.Format.Format_Grayscale8, - ) - mask_pixmap = QPixmap.fromImage(mask_image) - painter.setOpacity(0.5) - painter.drawPixmap(0, 0, mask_pixmap) - painter.setOpacity(1.0) - - painter.restore() - - def draw_temp_eraser_mask(self, painter): - if self.temp_eraser_mask is not None: - painter.save() - painter.translate(self.offset_x, self.offset_y) - painter.scale(self.zoom_factor, self.zoom_factor) - - mask_image = QImage( - self.temp_eraser_mask.data, - self.temp_eraser_mask.shape[1], - self.temp_eraser_mask.shape[0], - self.temp_eraser_mask.shape[1], - QImage.Format.Format_Grayscale8, - ) - mask_pixmap = QPixmap.fromImage(mask_image) - painter.setOpacity(0.5) - painter.drawPixmap(0, 0, mask_pixmap) - painter.setOpacity(1.0) - - painter.restore() - def draw_tool_size_indicator(self, painter): if self.current_tool in ["paint_brush", "eraser"] and hasattr( self, "cursor_pos" @@ -461,10 +355,10 @@ def draw_tool_size_indicator(self, painter): painter.scale(self.zoom_factor, self.zoom_factor) if self.current_tool == "paint_brush": - size = self.main_window.paint_brush_size + size = self._ctx.paint_brush_size() color = QColor(255, 0, 0, 128) # Semi-transparent red else: # eraser - size = self.main_window.eraser_size + size = self._ctx.eraser_size() color = QColor(0, 0, 255, 128) # Semi-transparent blue # Draw filled circle with lower opacity @@ -477,7 +371,7 @@ def draw_tool_size_indicator(self, painter): # Draw circle outline with full opacity painter.setOpacity(1.0) - painter.setPen(QPen(color.darker(150), 1 / self.zoom_factor, Qt.PenStyle.SolidLine)) + painter.setPen(QPen(color.darker(150), self._pen_w(1), Qt.PenStyle.SolidLine)) painter.setBrush(Qt.BrushStyle.NoBrush) painter.drawEllipse( QPointF(self.cursor_pos[0], self.cursor_pos[1]), size, size @@ -487,7 +381,9 @@ def draw_tool_size_indicator(self, painter): # Reset the transform to ensure text is drawn at screen coordinates painter.resetTransform() font = QFont() - font.setPointSize(10) + # Screen-space text (transform was reset above): scale with + # the UI font setting only, not with image zoom. + font.setPointSize(max(1, int(10 * self.ui_scale))) painter.setFont(font) painter.setPen(QPen(Qt.GlobalColor.black)) # Use black color for better visibility @@ -508,47 +404,11 @@ def draw_tool_size_indicator(self, painter): painter.restore() - def draw_paint_mask(self, painter): - if self.paint_mask is not None: - mask_image = QImage( - self.paint_mask.data, - self.paint_mask.shape[1], - self.paint_mask.shape[0], - self.paint_mask.shape[1], - QImage.Format.Format_Grayscale8, - ) - mask_pixmap = QPixmap.fromImage(mask_image) - painter.setOpacity(0.5) - painter.drawPixmap( - self.offset_x, - self.offset_y, - mask_pixmap.scaled(self.scaled_pixmap.size()), - ) - painter.setOpacity(1.0) - - def draw_eraser_mask(self, painter): - if self.eraser_mask is not None: - mask_image = QImage( - self.eraser_mask.data, - self.eraser_mask.shape[1], - self.eraser_mask.shape[0], - self.eraser_mask.shape[1], - QImage.Format.Format_Grayscale8, - ) - mask_pixmap = QPixmap.fromImage(mask_image) - painter.setOpacity(0.5) - painter.drawPixmap( - self.offset_x, - self.offset_y, - mask_pixmap.scaled(self.scaled_pixmap.size()), - ) - painter.setOpacity(1.0) - def draw_sam_bbox(self, painter): painter.save() painter.translate(self.offset_x, self.offset_y) painter.scale(self.zoom_factor, self.zoom_factor) - painter.setPen(QPen(Qt.GlobalColor.red, 2 / self.zoom_factor, Qt.PenStyle.SolidLine)) + painter.setPen(QPen(Qt.GlobalColor.red, self._pen_w(2), Qt.PenStyle.SolidLine)) x1, y1, x2, y2 = self.sam_bbox painter.drawRect(QRectF(min(x1, x2), min(y1, y2), abs(x2 - x1), abs(y2 - y1))) painter.restore() @@ -558,26 +418,26 @@ def clear_temp_sam_prediction(self): self.update() def check_unsaved_changes(self): - if self.temp_paint_mask is not None or self.temp_eraser_mask is not None: - reply = QMessageBox.question( - self.main_window, - "Unsaved Changes", - "You have unsaved changes. Do you want to save them?", - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No | QMessageBox.StandardButton.Cancel, - ) - if reply == QMessageBox.StandardButton.Yes: - if self.temp_paint_mask is not None: - self.commit_paint_annotation() - if self.temp_eraser_mask is not None: - self.commit_eraser_changes() - return True - elif reply == QMessageBox.StandardButton.No: - self.discard_paint_annotation() - self.discard_eraser_changes() - return True - else: # Cancel - return False - return True # No unsaved changes + dirty = [t for t in self._tools.values() if t.has_unsaved_state()] + if not dirty: + return True + reply = QMessageBox.question( + self._ctx.dialog_parent(), + "Unsaved Changes", + "You have unsaved changes. Do you want to save them?", + QMessageBox.StandardButton.Yes + | QMessageBox.StandardButton.No + | QMessageBox.StandardButton.Cancel, + ) + if reply == QMessageBox.StandardButton.Yes: + for t in dirty: + t.commit() + return True + if reply == QMessageBox.StandardButton.No: + for t in dirty: + t.discard() + return True + return False # Cancel def clear(self): super().clear() @@ -611,7 +471,7 @@ def draw_annotations(self, painter): painter.scale(self.zoom_factor, self.zoom_factor) for class_name, class_annotations in self.annotations.items(): - if not self.main_window.is_class_visible(class_name): + if not self._ctx.is_class_visible(class_name): continue color = self.class_colors.get(class_name, QColor(Qt.GlobalColor.white)) @@ -626,7 +486,7 @@ def draw_annotations(self, painter): fill_color.setAlphaF(self.fill_opacity) text_color = Qt.GlobalColor.white if self.dark_mode else Qt.GlobalColor.black - painter.setPen(QPen(border_color, 2 / self.zoom_factor, Qt.PenStyle.SolidLine)) + painter.setPen(QPen(border_color, self._pen_w(2), Qt.PenStyle.SolidLine)) painter.setBrush(QBrush(fill_color)) if "segmentation" in annotation: @@ -652,11 +512,9 @@ def draw_annotations(self, painter): if points: centroid = self.calculate_centroid(points) if centroid: - painter.setFont( - QFont("Arial", int(12 / self.zoom_factor)) - ) + painter.setFont(self._overlay_font()) painter.setPen( - QPen(text_color, 2 / self.zoom_factor, Qt.PenStyle.SolidLine) + QPen(text_color, self._pen_w(2), Qt.PenStyle.SolidLine) ) painter.drawText( centroid, @@ -666,28 +524,18 @@ def draw_annotations(self, painter): elif "bbox" in annotation: x, y, width, height = annotation["bbox"] painter.drawRect(QRectF(x, y, width, height)) - painter.setPen(QPen(text_color, 2 / self.zoom_factor, Qt.PenStyle.SolidLine)) + painter.setPen(QPen(text_color, self._pen_w(2), Qt.PenStyle.SolidLine)) painter.drawText( QPointF(x, y), f"{class_name} {annotation.get('number', '')}" ) - if self.current_annotation: - painter.setPen(QPen(Qt.GlobalColor.red, 2 / self.zoom_factor, Qt.PenStyle.SolidLine)) - points = [QPointF(float(x), float(y)) for x, y in self.current_annotation] - if len(points) > 1: - painter.drawPolyline(QPolygonF(points)) - for point in points: - painter.drawEllipse(point, 5 / self.zoom_factor, 5 / self.zoom_factor) - if self.temp_point: - painter.drawLine( - points[-1], - QPointF(float(self.temp_point[0]), float(self.temp_point[1])), - ) + # Polygon-in-progress is rendered by PolygonTool.paint_overlay + # (paintEvent calls active_tool_handler.paint_overlay). # Draw temporary SAM prediction if self.temp_sam_prediction: temp_color = QColor(255, 165, 0, 128) # Semi-transparent orange - painter.setPen(QPen(temp_color, 2 / self.zoom_factor, Qt.PenStyle.DashLine)) + painter.setPen(QPen(temp_color, self._pen_w(2), Qt.PenStyle.DashLine)) painter.setBrush(QBrush(temp_color)) segmentation = self.temp_sam_prediction["segmentation"] @@ -699,37 +547,13 @@ def draw_annotations(self, painter): painter.drawPolygon(QPolygonF(points)) centroid = self.calculate_centroid(points) if centroid: - painter.setFont(QFont("Arial", int(12 / self.zoom_factor))) + painter.setFont(self._overlay_font()) painter.drawText( centroid, f"SAM: {self.temp_sam_prediction['score']:.2f}" ) painter.restore() - def draw_current_rectangle(self, painter): - """Draw the current rectangle being created.""" - if not self.current_rectangle: - return - - painter.save() - painter.translate(self.offset_x, self.offset_y) - painter.scale(self.zoom_factor, self.zoom_factor) - - x1, y1, x2, y2 = self.current_rectangle - color = self.class_colors.get(self.main_window.current_class, QColor(Qt.GlobalColor.red)) - painter.setPen(QPen(color, 2 / self.zoom_factor, Qt.PenStyle.SolidLine)) - painter.drawRect(QRectF(float(x1), float(y1), float(x2 - x1), float(y2 - y1))) - - painter.restore() - - def get_rectangle_from_points(self): - """Get rectangle coordinates from start and end points.""" - if not self.start_point or not self.end_point: - return None - x1, y1 = self.start_point - x2, y2 = self.end_point - return [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)] - def draw_editing_polygon(self, painter): """Draw the polygon being edited.""" painter.save() @@ -749,7 +573,7 @@ def draw_editing_polygon(self, painter): fill_color = QColor(color) fill_color.setAlphaF(self.fill_opacity) - painter.setPen(QPen(color, 2 / self.zoom_factor, Qt.PenStyle.SolidLine)) + painter.setPen(QPen(color, self._pen_w(2), Qt.PenStyle.SolidLine)) painter.setBrush(QBrush(fill_color)) painter.drawPolygon(QPolygonF(points)) # Changed QPolygon to QPolygonF - Sreeni @@ -758,7 +582,8 @@ def draw_editing_polygon(self, painter): painter.setBrush(QColor(255, 0, 0)) else: painter.setBrush(QColor(0, 255, 0)) - painter.drawEllipse(point, 5 / self.zoom_factor, 5 / self.zoom_factor) + r = 5 * self.ui_scale / self.zoom_factor + painter.drawEllipse(point, r, r) painter.restore() @@ -789,16 +614,17 @@ def wheelEvent(self, event: QWheelEvent): img_x = (cursor_widget_pos.x() - self.offset_x) / self.zoom_factor img_y = (cursor_widget_pos.y() - self.offset_y) / self.zoom_factor - scrollbar_h = self.main_window.scroll_area.horizontalScrollBar() - scrollbar_v = self.main_window.scroll_area.verticalScrollBar() + scroll_area = self._ctx.scroll_area() + scrollbar_h = scroll_area.horizontalScrollBar() + scrollbar_v = scroll_area.verticalScrollBar() old_scroll_h = scrollbar_h.value() old_scroll_v = scrollbar_v.value() delta = event.angleDelta().y() if delta > 0: - self.main_window.zoom_in() + self.zoomInRequested.emit() else: - self.main_window.zoom_out() + self.zoomOutRequested.emit() # Compute the post-zoom offset analytically from the # viewport size and the new scaled-pixmap size. Reading @@ -807,7 +633,7 @@ def wheelEvent(self, event: QWheelEvent): # widget hasn't shrunk yet when update_offset ran. self.width() # is stale → offset_x is wrong → cursor drifts. The viewport # width is always current. - viewport = self.main_window.scroll_area.viewport() + viewport = scroll_area.viewport() new_scaled_w = self.scaled_pixmap.width() new_scaled_h = self.scaled_pixmap.height() new_offset_x = max(0, (viewport.width() - new_scaled_w) / 2) @@ -835,41 +661,31 @@ def mousePressEvent(self, event: QMouseEvent): return pos = self.get_image_coordinates(event.position()) + + # SAM points has priority over the rest (it accepts both + # mouse buttons and short-circuits the tool dispatch). if self.current_tool == "sam_points" and self.sam_points_active: if event.button() == Qt.MouseButton.LeftButton: self.sam_positive_points.append(pos) self.update() - self.main_window.schedule_sam_prediction() + self.samPredictionRequested.emit() return elif event.button() == Qt.MouseButton.RightButton: self.sam_negative_points.append(pos) self.update() - self.main_window.schedule_sam_prediction() + self.samPredictionRequested.emit() return if event.button() == Qt.MouseButton.LeftButton: if self.current_tool == "sam_box" and self.sam_box_active: self.sam_bbox = [pos[0], pos[1], pos[0], pos[1]] self.drawing_sam_bbox = True - elif self.sam_magic_wand_active: - self.sam_bbox = [pos[0], pos[1], pos[0], pos[1]] - self.drawing_sam_bbox = True elif self.editing_polygon: self.handle_editing_click(pos, event) - elif self.current_tool == "polygon": - if not self.drawing_polygon: - self.drawing_polygon = True - self.current_annotation = [] - self.current_annotation.append(pos) - elif self.current_tool == "rectangle": - self.start_point = pos - self.end_point = pos - self.drawing_rectangle = True - self.current_rectangle = None - elif self.current_tool == "paint_brush": - self.start_painting(pos) - elif self.current_tool == "eraser": - self.start_erasing(pos) + else: + handler = self.active_tool_handler + if handler is not None: + handler.on_mouse_press(event, pos) self.update() def mouseMoveEvent(self, event: QMouseEvent): @@ -880,8 +696,9 @@ def mouseMoveEvent(self, event: QMouseEvent): if self.pan_start_pos: cur = event.globalPosition() delta = cur - self.pan_start_pos - scrollbar_h = self.main_window.scroll_area.horizontalScrollBar() - scrollbar_v = self.main_window.scroll_area.verticalScrollBar() + scroll_area = self._ctx.scroll_area() + scrollbar_h = scroll_area.horizontalScrollBar() + scrollbar_v = scroll_area.verticalScrollBar() scrollbar_h.setValue(scrollbar_h.value() - int(delta.x())) scrollbar_v.setValue(scrollbar_v.value() - int(delta.y())) self.pan_start_pos = cur @@ -897,24 +714,12 @@ def mouseMoveEvent(self, event: QMouseEvent): ): self.sam_bbox[2] = pos[0] self.sam_bbox[3] = pos[1] - elif ( - self.sam_magic_wand_active - and self.drawing_sam_bbox - and self.sam_bbox is not None - ): - self.sam_bbox[2] = pos[0] - self.sam_bbox[3] = pos[1] elif self.editing_polygon: self.handle_editing_move(pos) - elif self.current_tool == "polygon" and self.current_annotation: - self.temp_point = pos - elif self.current_tool == "rectangle" and self.drawing_rectangle: - self.end_point = pos - self.current_rectangle = self.get_rectangle_from_points() - elif self.current_tool == "paint_brush" and event.buttons() == Qt.MouseButton.LeftButton: - self.continue_painting(pos) - elif self.current_tool == "eraser" and event.buttons() == Qt.MouseButton.LeftButton: - self.continue_erasing(pos) + else: + handler = self.active_tool_handler + if handler is not None: + handler.on_mouse_move(event, pos) self.update() def mouseReleaseEvent(self, event: QMouseEvent): @@ -935,26 +740,13 @@ def mouseReleaseEvent(self, event: QMouseEvent): self.sam_bbox[2] = pos[0] self.sam_bbox[3] = pos[1] self.drawing_sam_bbox = False - self.main_window.apply_sam_prediction() - elif ( - self.sam_magic_wand_active - and self.drawing_sam_bbox - and self.sam_bbox is not None - ): - self.sam_bbox[2] = pos[0] - self.sam_bbox[3] = pos[1] - self.drawing_sam_bbox = False - self.main_window.apply_sam_prediction() + self.samPredictionApplyRequested.emit() elif self.editing_polygon: self.editing_point_index = None - elif self.current_tool == "rectangle" and self.drawing_rectangle: - self.drawing_rectangle = False - if self.current_rectangle: - self.main_window.finish_rectangle() - elif self.current_tool == "paint_brush": - self.finish_painting() - elif self.current_tool == "eraser": - self.finish_erasing() + else: + handler = self.active_tool_handler + if handler is not None: + handler.on_mouse_release(event, pos) self.update() def mouseDoubleClickEvent(self, event): @@ -962,13 +754,18 @@ def mouseDoubleClickEvent(self, event): return pos = self.get_image_coordinates(event.position()) if event.button() == Qt.MouseButton.LeftButton: - if self.drawing_polygon and len(self.current_annotation) > 2: - self.finish_polygon() - else: + # Polygon handler can consume the double-click to finish + # the polygon. If it doesn't (no in-progress polygon), fall + # through to polygon-edit mode. + handler = self.active_tool_handler + consumed = False + if handler is not None: + consumed = handler.on_double_click(event, pos) + if not consumed: self.clear_current_annotation() annotation = self.start_polygon_edit(pos) if annotation: - self.main_window.select_annotation_in_list(annotation) + self.annotationSelected.emit(annotation) self.update() def get_image_coordinates(self, pos): @@ -981,100 +778,75 @@ def get_image_coordinates(self, pos): def keyPressEvent(self, event: QKeyEvent): if event.key() == Qt.Key.Key_Return or event.key() == Qt.Key.Key_Enter: # DINO temp_annotations are accepted via the application-wide - # _DINOReviewEventFilter (see ADR-015) so Enter works regardless + # DINOReviewEventFilter (see ADR-015) so Enter works regardless # of focus. The branch below only catches non-DINO temp state # (legacy YOLO model-prediction review path). if self.temp_annotations: self.accept_temp_annotations() elif self.temp_sam_prediction: - self.main_window.accept_sam_prediction() + self.samPredictionAccepted.emit() elif self.editing_polygon: self.editing_polygon = None self.editing_point_index = None self.hover_point_index = None - self.main_window.enable_tools() - self.main_window.update_annotation_list() - elif self.current_tool == "polygon" and self.drawing_polygon: - self.finish_polygon() - elif self.current_tool == "paint_brush": - self.commit_paint_annotation() - elif self.current_tool == "eraser": - self.commit_eraser_changes() + self.enableToolsRequested.emit() + self.annotationListUpdateRequested.emit() else: - self.finish_current_annotation() + handler = self.active_tool_handler + if handler is not None: + handler.on_enter() elif event.key() == Qt.Key.Key_Escape: if self.sam_points_active: - self.main_window.sam_inference_timer.stop() + self.samPointsCleared.emit() self.sam_positive_points = [] self.sam_negative_points = [] self.clear_temp_sam_prediction() self.update() # DINO temp_annotations are rejected via the application-wide - # _DINOReviewEventFilter (see ADR-015). Branch below catches + # DINOReviewEventFilter (see ADR-015). Branch below catches # non-DINO temp state only. elif self.temp_annotations: self.discard_temp_annotations() - elif self.sam_magic_wand_active: + elif self.sam_box_active: self.sam_bbox = None self.clear_temp_sam_prediction() elif self.editing_polygon: self.editing_polygon = None self.editing_point_index = None self.hover_point_index = None - self.main_window.enable_tools() - elif self.current_tool == "paint_brush": - self.discard_paint_annotation() - elif self.current_tool == "eraser": - self.discard_eraser_changes() + self.enableToolsRequested.emit() else: - self.cancel_current_annotation() + handler = self.active_tool_handler + if handler is not None: + handler.on_escape() elif event.key() == Qt.Key.Key_Delete: if self.editing_polygon: - self.main_window.delete_selected_annotations() + self.deleteSelectionRequested.emit() self.editing_polygon = None self.editing_point_index = None self.hover_point_index = None - self.main_window.enable_tools() + self.enableToolsRequested.emit() self.update() elif event.key() == Qt.Key.Key_Minus: if self.current_tool == "paint_brush": - self.main_window.paint_brush_size = max( - 1, self.main_window.paint_brush_size - 1 - ) - print(f"Paint brush size: {self.main_window.paint_brush_size}") + new_size = max(1, self._ctx.paint_brush_size() - 1) + self.toolSizeChanged.emit("paint", new_size) + print(f"Paint brush size: {new_size}") elif self.current_tool == "eraser": - self.main_window.eraser_size = max(1, self.main_window.eraser_size - 1) - print(f"Eraser size: {self.main_window.eraser_size}") - elif event.key() == Qt.Key.Key_Equal: + new_size = max(1, self._ctx.eraser_size() - 1) + self.toolSizeChanged.emit("eraser", new_size) + print(f"Eraser size: {new_size}") + elif event.key() in (Qt.Key.Key_Equal, Qt.Key.Key_Plus): if self.current_tool == "paint_brush": - self.main_window.paint_brush_size += 1 - print(f"Paint brush size: {self.main_window.paint_brush_size}") + new_size = self._ctx.paint_brush_size() + 1 + self.toolSizeChanged.emit("paint", new_size) + print(f"Paint brush size: {new_size}") elif self.current_tool == "eraser": - self.main_window.eraser_size += 1 - print(f"Eraser size: {self.main_window.eraser_size}") + new_size = self._ctx.eraser_size() + 1 + self.toolSizeChanged.emit("eraser", new_size) + print(f"Eraser size: {new_size}") self.update() - def cancel_current_annotation(self): - """Cancel the current annotation being created.""" - if self.current_tool == "polygon" and self.current_annotation: - self.current_annotation = [] - self.temp_point = None - self.drawing_polygon = False - self.update() - - def finish_current_annotation(self): - """Finish the current annotation being created.""" - if self.current_tool == "polygon" and len(self.current_annotation) > 2: - if self.main_window: - self.main_window.finish_polygon() - - def finish_polygon(self): - """Finish the current polygon annotation.""" - if self.drawing_polygon and len(self.current_annotation) > 2: - self.drawing_polygon = False - if self.main_window: - self.main_window.finish_polygon() - def start_polygon_edit(self, pos): for class_name, annotations in self.annotations.items(): for annotation in annotations: @@ -1089,8 +861,8 @@ def start_polygon_edit(self, pos): if self.point_in_polygon(pos, points): self.editing_polygon = annotation self.current_tool = None - self.main_window.disable_tools() - self.main_window.reset_tool_buttons() + self.disableToolsRequested.emit() + self.resetToolButtonsRequested.emit() return annotation return None @@ -1104,7 +876,7 @@ def handle_editing_click(self, pos, event): ) ] for i, point in enumerate(points): - if self.distance(pos, point) < 10 / self.zoom_factor: + if self.distance(pos, point) < 10 * self.ui_scale / self.zoom_factor: if event.modifiers() & Qt.KeyboardModifier.ShiftModifier: # Delete point del self.editing_polygon["segmentation"][i * 2 : i * 2 + 2] @@ -1133,7 +905,7 @@ def handle_editing_move(self, pos): ] self.hover_point_index = None for i, point in enumerate(points): - if self.distance(pos, point) < 10 / self.zoom_factor: + if self.distance(pos, point) < 10 * self.ui_scale / self.zoom_factor: self.hover_point_index = i break if self.editing_point_index is not None: diff --git a/src/digitalsreeni_image_annotator/widgets/tools/__init__.py b/src/digitalsreeni_image_annotator/widgets/tools/__init__.py new file mode 100644 index 0000000..ef27150 --- /dev/null +++ b/src/digitalsreeni_image_annotator/widgets/tools/__init__.py @@ -0,0 +1,13 @@ +from .base import ToolHandler +from .eraser_tool import EraserTool +from .paint_tool import PaintBrushTool +from .polygon_tool import PolygonTool +from .rectangle_tool import RectangleTool + +__all__ = [ + "ToolHandler", + "EraserTool", + "PaintBrushTool", + "PolygonTool", + "RectangleTool", +] diff --git a/src/digitalsreeni_image_annotator/widgets/tools/base.py b/src/digitalsreeni_image_annotator/widgets/tools/base.py new file mode 100644 index 0000000..fe1ec4f --- /dev/null +++ b/src/digitalsreeni_image_annotator/widgets/tools/base.py @@ -0,0 +1,71 @@ +""" +Base class for per-tool mouse / key event handlers in ImageLabel. + +Each handler owns its tool-specific temp state. ImageLabel keeps a +dispatcher that routes events to the active handler. Handlers emit +back through the ImageLabel's Phase 6 signals (see ADR-018) — they +never call into the orchestrator directly. + +Plain Python objects, not QObjects: no need for their own signals, +no parent-child memory model to worry about, and unit tests can +instantiate them without a Qt event loop. +""" + + +class ToolHandler: + def __init__(self, label): + # Back-reference to the ImageLabel. Used to: + # - emit signals (self.label.annotationCommitted.emit(...), …) + # - read state via the CanvasContext (self.label._ctx.X()) + # - write to ImageLabel.annotations (paint/eraser commit paths) + # - trigger a repaint (self.label.update()) + self.label = label + + # --- Mouse hooks. Each returns True if the event was consumed. --- + + def on_mouse_press(self, event, img_pt) -> bool: + return False + + def on_mouse_move(self, event, img_pt) -> bool: + return False + + def on_mouse_release(self, event, img_pt) -> bool: + return False + + def on_double_click(self, event, img_pt) -> bool: + return False + + # --- Key hooks. ImageLabel routes Enter/Escape here only after the + # higher-priority modal branches (DINO temp, sam_points, sam_box, + # editing polygon) have had their turn. --- + + def on_enter(self) -> bool: + return False + + def on_escape(self) -> bool: + return False + + # --- Painter overlay drawn after committed annotations but before + # the size indicator. Tools render their in-progress state here. --- + + def paint_overlay(self, painter) -> None: + return + + # --- Lifecycle. Called when the user switches away from this tool. + # Default is no-op (matches the existing "drop state silently" + # behaviour); commit/discard must be explicit via Enter / Escape. --- + + def deactivate(self) -> None: + return + + # --- Unsaved-state reporting. ImageLabel.check_unsaved_changes + # iterates handlers to decide whether to prompt the user. --- + + def has_unsaved_state(self) -> bool: + return False + + def commit(self) -> None: + return + + def discard(self) -> None: + return diff --git a/src/digitalsreeni_image_annotator/widgets/tools/eraser_tool.py b/src/digitalsreeni_image_annotator/widgets/tools/eraser_tool.py new file mode 100644 index 0000000..a6ddadd --- /dev/null +++ b/src/digitalsreeni_image_annotator/widgets/tools/eraser_tool.py @@ -0,0 +1,155 @@ +"""EraserTool — circular strokes mask out existing polygons; Enter commits.""" + +import cv2 +import numpy as np +from PyQt6.QtCore import Qt +from PyQt6.QtGui import QImage, QPixmap + +from .base import ToolHandler + + +class EraserTool(ToolHandler): + """Mutates ImageLabel's `temp_eraser_mask` and `is_erasing`. The + commit path (OpenCV polygon clipping) is moved byte-for-byte + from the pre-Phase-7 ImageLabel.commit_eraser_changes — do not + refactor here.""" + + def on_mouse_press(self, event, img_pt) -> bool: + if event.button() != Qt.MouseButton.LeftButton: + return False + self._start(img_pt) + return True + + def on_mouse_move(self, event, img_pt) -> bool: + if event.buttons() != Qt.MouseButton.LeftButton: + return False + if not self.label.is_erasing: + return False + self._continue(img_pt) + return True + + def on_mouse_release(self, event, img_pt) -> bool: + if event.button() != Qt.MouseButton.LeftButton: + return False + if not self.label.is_erasing: + return False + self.label.is_erasing = False + # Don't commit the eraser changes yet; Enter or image-switch + # finalises. + return True + + def on_enter(self) -> bool: + if self.label.temp_eraser_mask is None: + return False + self.commit() + return True + + def on_escape(self) -> bool: + if self.label.temp_eraser_mask is None: + return False + self.discard() + return True + + def paint_overlay(self, painter) -> None: + mask = self.label.temp_eraser_mask + if mask is None: + return + painter.save() + painter.translate(self.label.offset_x, self.label.offset_y) + painter.scale(self.label.zoom_factor, self.label.zoom_factor) + + mask_copy = mask.copy() + mask_image = QImage( + mask_copy.data, + mask_copy.shape[1], + mask_copy.shape[0], + mask_copy.shape[1], + QImage.Format.Format_Grayscale8, + ) + mask_pixmap = QPixmap.fromImage(mask_image) + painter.setOpacity(0.5) + painter.drawPixmap(0, 0, mask_pixmap) + painter.setOpacity(1.0) + painter.restore() + + def has_unsaved_state(self) -> bool: + return self.label.temp_eraser_mask is not None + + def commit(self) -> None: + if self.label.temp_eraser_mask is None: + return + eraser_mask = self.label.temp_eraser_mask.astype(bool) + current_name = self.label._ctx.current_image_key() + + for class_name, annotations in self.label.annotations.items(): + updated_annotations = [] + max_number = max([ann.get("number", 0) for ann in annotations] + [0]) + for annotation in annotations: + if "segmentation" in annotation: + points = ( + np.array(annotation["segmentation"]) + .reshape(-1, 2) + .astype(int) + ) + mask = np.zeros_like(self.label.temp_eraser_mask) + cv2.fillPoly(mask, [points], 255) + mask = mask.astype(bool) + mask[eraser_mask] = False + contours, _ = cv2.findContours( + mask.astype(np.uint8), + cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_SIMPLE, + ) + for i, contour in enumerate(contours): + if cv2.contourArea(contour) > 10: # Minimum area threshold + new_segmentation = contour.flatten().tolist() + new_annotation = annotation.copy() + new_annotation["segmentation"] = new_segmentation + if i == 0: + new_annotation["number"] = annotation.get( + "number", max_number + 1 + ) + else: + max_number += 1 + new_annotation["number"] = max_number + updated_annotations.append(new_annotation) + else: + updated_annotations.append(annotation) + self.label.annotations[class_name] = updated_annotations + + self.label.temp_eraser_mask = None + # AnnotationController.replace_annotations writes into + # all_annotations and triggers save + slice-color refresh. + self.label.annotationsReplaced.emit(current_name, self.label.annotations) + self.label.update() + + def discard(self) -> None: + self.label.temp_eraser_mask = None + self.label.update() + + # --- internals --- + + def _start(self, pos): + if self.label.temp_eraser_mask is None: + self.label.temp_eraser_mask = np.zeros( + ( + self.label.original_pixmap.height(), + self.label.original_pixmap.width(), + ), + dtype=np.uint8, + ) + self.label.is_erasing = True + self._continue(pos) + + def _continue(self, pos): + if not self.label.is_erasing: + return + eraser_size = self.label._ctx.eraser_size() + cv2.circle( + self.label.temp_eraser_mask, + (int(pos[0]), int(pos[1])), + eraser_size, + 255, + -1, + ) + self.label.update() diff --git a/src/digitalsreeni_image_annotator/widgets/tools/paint_tool.py b/src/digitalsreeni_image_annotator/widgets/tools/paint_tool.py new file mode 100644 index 0000000..6a43ec6 --- /dev/null +++ b/src/digitalsreeni_image_annotator/widgets/tools/paint_tool.py @@ -0,0 +1,129 @@ +"""PaintBrushTool — circular brush strokes into a temp mask; Enter commits.""" + +import cv2 +import numpy as np +from PyQt6.QtCore import Qt +from PyQt6.QtGui import QImage, QPixmap + +from .base import ToolHandler + + +class PaintBrushTool(ToolHandler): + """Mutates ImageLabel's `temp_paint_mask` and `is_painting` so + other code paths (notably `check_unsaved_changes` callers and + paint-mask rendering) see the same state they did pre-Phase-7.""" + + def on_mouse_press(self, event, img_pt) -> bool: + if event.button() != Qt.MouseButton.LeftButton: + return False + self._start(img_pt) + return True + + def on_mouse_move(self, event, img_pt) -> bool: + if event.buttons() != Qt.MouseButton.LeftButton: + return False + if not self.label.is_painting: + return False + self._continue(img_pt) + return True + + def on_mouse_release(self, event, img_pt) -> bool: + if event.button() != Qt.MouseButton.LeftButton: + return False + if not self.label.is_painting: + return False + self.label.is_painting = False + # Don't commit the annotation yet; Enter / image-switch dialog + # finalises. + return True + + def on_enter(self) -> bool: + if self.label.temp_paint_mask is None: + return False + self.commit() + return True + + def on_escape(self) -> bool: + if self.label.temp_paint_mask is None: + return False + self.discard() + return True + + def paint_overlay(self, painter) -> None: + mask = self.label.temp_paint_mask + if mask is None: + return + painter.save() + painter.translate(self.label.offset_x, self.label.offset_y) + painter.scale(self.label.zoom_factor, self.label.zoom_factor) + + mask_copy = mask.copy() + mask_image = QImage( + mask_copy.data, + mask_copy.shape[1], + mask_copy.shape[0], + mask_copy.shape[1], + QImage.Format.Format_Grayscale8, + ) + mask_pixmap = QPixmap.fromImage(mask_image) + painter.setOpacity(0.5) + painter.drawPixmap(0, 0, mask_pixmap) + painter.setOpacity(1.0) + painter.restore() + + def has_unsaved_state(self) -> bool: + return self.label.temp_paint_mask is not None + + def commit(self) -> None: + if self.label.temp_paint_mask is None or not self.label._ctx.current_class(): + return + class_name = self.label._ctx.current_class() + contours, _ = cv2.findContours( + self.label.temp_paint_mask, + cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_SIMPLE, + ) + for contour in contours: + if cv2.contourArea(contour) > 10: # Minimum area threshold + segmentation = contour.flatten().tolist() + new_annotation = { + "segmentation": segmentation, + "category_id": self.label._ctx.class_id(class_name), + "category_name": class_name, + } + self.label.annotations.setdefault(class_name, []).append(new_annotation) + self.label.annotationCommitted.emit(new_annotation) + self.label.temp_paint_mask = None + self.label.annotationsBatchSaved.emit() + self.label.update() + + def discard(self) -> None: + self.label.temp_paint_mask = None + self.label.update() + + # --- internals --- + + def _start(self, pos): + if self.label.temp_paint_mask is None: + self.label.temp_paint_mask = np.zeros( + ( + self.label.original_pixmap.height(), + self.label.original_pixmap.width(), + ), + dtype=np.uint8, + ) + self.label.is_painting = True + self._continue(pos) + + def _continue(self, pos): + if not self.label.is_painting: + return + brush_size = self.label._ctx.paint_brush_size() + cv2.circle( + self.label.temp_paint_mask, + (int(pos[0]), int(pos[1])), + brush_size, + 255, + -1, + ) + self.label.update() diff --git a/src/digitalsreeni_image_annotator/widgets/tools/polygon_tool.py b/src/digitalsreeni_image_annotator/widgets/tools/polygon_tool.py new file mode 100644 index 0000000..1aa91de --- /dev/null +++ b/src/digitalsreeni_image_annotator/widgets/tools/polygon_tool.py @@ -0,0 +1,89 @@ +"""PolygonTool — click to add vertices, double-click / Enter to finish.""" + +from PyQt6.QtCore import QPointF, Qt +from PyQt6.QtGui import QPen, QPolygonF + +from .base import ToolHandler + + +class PolygonTool(ToolHandler): + """Mutates ImageLabel state fields (`current_annotation`, + `temp_point`, `drawing_polygon`) directly — those fields are + still read by AnnotationController.finish_polygon and stay on + the widget for now.""" + + def on_mouse_press(self, event, img_pt) -> bool: + if event.button() != Qt.MouseButton.LeftButton: + return False + if not self.label.drawing_polygon: + self.label.drawing_polygon = True + self.label.current_annotation = [] + self.label.current_annotation.append(img_pt) + return True + + def on_mouse_move(self, event, img_pt) -> bool: + if not self.label.current_annotation: + return False + self.label.temp_point = img_pt + return True + + def on_double_click(self, event, img_pt) -> bool: + if event.button() != Qt.MouseButton.LeftButton: + return False + if self.label.drawing_polygon and len(self.label.current_annotation) > 2: + self.label.drawing_polygon = False + self.label.finishPolygonRequested.emit() + return True + return False + + def on_enter(self) -> bool: + if self.label.drawing_polygon and len(self.label.current_annotation) > 2: + self.label.drawing_polygon = False + self.label.finishPolygonRequested.emit() + return True + return False + + def on_escape(self) -> bool: + if self.label.current_annotation: + self.discard() + return True + return False + + def paint_overlay(self, painter) -> None: + if not self.label.current_annotation: + return + painter.save() + painter.translate(self.label.offset_x, self.label.offset_y) + painter.scale(self.label.zoom_factor, self.label.zoom_factor) + + zf = self.label.zoom_factor + painter.setPen(QPen(Qt.GlobalColor.red, 2 / zf, Qt.PenStyle.SolidLine)) + points = [QPointF(float(x), float(y)) for x, y in self.label.current_annotation] + if len(points) > 1: + painter.drawPolyline(QPolygonF(points)) + for point in points: + painter.drawEllipse(point, 5 / zf, 5 / zf) + if self.label.temp_point: + painter.drawLine( + points[-1], + QPointF(float(self.label.temp_point[0]), float(self.label.temp_point[1])), + ) + painter.restore() + + def has_unsaved_state(self) -> bool: + # Only report unsaved if the polygon is actually finishable + # (3+ points). 1- or 2-point polygons can't be saved; they're + # silently dropped on tool switch via discard(), matching the + # pre-Phase-7 behaviour where commit_paint_annotation / + # commit_eraser_changes were the only "save?" prompts. + return self.label.drawing_polygon and len(self.label.current_annotation) > 2 + + def commit(self) -> None: + if self.has_unsaved_state(): + self.label.drawing_polygon = False + self.label.finishPolygonRequested.emit() + + def discard(self) -> None: + self.label.current_annotation = [] + self.label.temp_point = None + self.label.drawing_polygon = False diff --git a/src/digitalsreeni_image_annotator/widgets/tools/rectangle_tool.py b/src/digitalsreeni_image_annotator/widgets/tools/rectangle_tool.py new file mode 100644 index 0000000..9e60ed5 --- /dev/null +++ b/src/digitalsreeni_image_annotator/widgets/tools/rectangle_tool.py @@ -0,0 +1,90 @@ +"""RectangleTool — drag a bbox; release commits via finishRectangleRequested.""" + +from PyQt6.QtCore import QRectF, Qt +from PyQt6.QtGui import QColor, QPen + +from .base import ToolHandler + + +class RectangleTool(ToolHandler): + """Mutates ImageLabel state fields (`start_point`, `end_point`, + `current_rectangle`, `drawing_rectangle`) directly. Those fields + stay on the widget because AnnotationController.finish_rectangle + reads `mw.image_label.current_rectangle`. Moving them onto the + tool would require a parallel controller refactor; out of scope + for Phase 7.""" + + def on_mouse_press(self, event, img_pt) -> bool: + if event.button() != Qt.MouseButton.LeftButton: + return False + self.label.start_point = img_pt + self.label.end_point = img_pt + self.label.drawing_rectangle = True + self.label.current_rectangle = None + return True + + def on_mouse_move(self, event, img_pt) -> bool: + if not self.label.drawing_rectangle: + return False + self.label.end_point = img_pt + self.label.current_rectangle = self._rect_from_points() + return True + + def on_mouse_release(self, event, img_pt) -> bool: + if event.button() != Qt.MouseButton.LeftButton: + return False + if not self.label.drawing_rectangle: + return False + self.label.drawing_rectangle = False + if self.label.current_rectangle: + self.label.finishRectangleRequested.emit() + return True + + def _rect_from_points(self): + s = self.label.start_point + e = self.label.end_point + if not s or not e: + return None + x1, y1 = s + x2, y2 = e + return [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)] + + def paint_overlay(self, painter) -> None: + if not (self.label.drawing_rectangle and self.label.current_rectangle): + return + painter.save() + painter.translate(self.label.offset_x, self.label.offset_y) + painter.scale(self.label.zoom_factor, self.label.zoom_factor) + + x1, y1, x2, y2 = self.label.current_rectangle + color = self.label.class_colors.get( + self.label._ctx.current_class(), QColor(Qt.GlobalColor.red) + ) + painter.setPen( + QPen(color, 2 / self.label.zoom_factor, Qt.PenStyle.SolidLine) + ) + painter.drawRect( + QRectF(float(x1), float(y1), float(x2 - x1), float(y2 - y1)) + ) + painter.restore() + + def has_unsaved_state(self) -> bool: + # A rectangle commits automatically on mouse release (emits + # finishRectangleRequested) — there's no "draft" rectangle the + # user might want to save later. The only way to have lingering + # state is mid-drag; in that case discard() clears it on tool + # switch / image switch via check_unsaved_changes. + return self.label.drawing_rectangle + + def discard(self) -> None: + self.label.start_point = None + self.label.end_point = None + self.label.current_rectangle = None + self.label.drawing_rectangle = False + + def commit(self) -> None: + # Mid-drag rectangle isn't finishable (no mouse-release signal + # was emitted yet). Treat "Yes save" as discard for consistency + # with the dialog's intent — the user clicked Yes meaning "keep + # what I drew" but there's nothing complete to keep. + self.discard() diff --git a/tests/integration/test_export_formats.py b/tests/integration/test_export_formats.py index 654e5aa..5cc3975 100644 --- a/tests/integration/test_export_formats.py +++ b/tests/integration/test_export_formats.py @@ -11,7 +11,7 @@ import shutil from pathlib import Path from PyQt6.QtGui import QImage -from src.digitalsreeni_image_annotator.export_formats import ( +from src.digitalsreeni_image_annotator.io.export_formats import ( export_coco_json, export_yolo_v5plus, export_pascal_voc_bbox, diff --git a/tests/integration/test_smoke.py b/tests/integration/test_smoke.py new file mode 100644 index 0000000..5d103c3 --- /dev/null +++ b/tests/integration/test_smoke.py @@ -0,0 +1,123 @@ +""" +Smoke tests: verify the package's public API surface and every internal +module can be imported. + +These tests are the safety net for the modular refactoring. They catch +the most common refactor regressions (renames, missing re-exports, broken +intra-package imports) without needing a real image, a SAM model, or a +running Qt event loop. + +If a module gets moved into a subpackage, the corresponding line below +must be updated. That is the test's whole point. +""" + +import importlib + +import pytest + + +def test_public_api_exports(): + """The five names documented in __init__.py must remain importable.""" + import digitalsreeni_image_annotator as pkg + + assert pkg.__version__ == "0.9.0" + assert hasattr(pkg, "ImageAnnotator") + assert hasattr(pkg, "ImageLabel") + assert hasattr(pkg, "SAMUtils") + assert hasattr(pkg, "calculate_area") + assert hasattr(pkg, "calculate_bbox") + + +INTERNAL_MODULES = [ + # Core + # NOTE: 'main' is deliberately omitted — it eagerly imports torch before + # QApplication is created (ADR-017). In a pytest-qt process Qt is already + # loaded; importing torch afterward triggers WinError 1114. main is the + # entry point and is validated by the CLI smoke tests instead. + "digitalsreeni_image_annotator.annotator_window", + "digitalsreeni_image_annotator.utils", + # Widgets + "digitalsreeni_image_annotator.widgets.image_label", + # Inference + "digitalsreeni_image_annotator.inference.sam_utils", + "digitalsreeni_image_annotator.inference.dino_utils", + # I/O + "digitalsreeni_image_annotator.io.export_formats", + "digitalsreeni_image_annotator.io.import_formats", + # Core helpers + "digitalsreeni_image_annotator.core.constants", + "digitalsreeni_image_annotator.core.annotation_utils", + # UI + "digitalsreeni_image_annotator.ui.default_stylesheet", + "digitalsreeni_image_annotator.ui.soft_dark_stylesheet", + # Dialogs + "digitalsreeni_image_annotator.dialogs.annotation_statistics", + "digitalsreeni_image_annotator.dialogs.coco_json_combiner", + "digitalsreeni_image_annotator.dialogs.dataset_splitter", + "digitalsreeni_image_annotator.dialogs.dicom_converter", + "digitalsreeni_image_annotator.dialogs.dino_merge_dialog", + "digitalsreeni_image_annotator.dialogs.dino_phrase_editor", + "digitalsreeni_image_annotator.dialogs.help_window", + "digitalsreeni_image_annotator.dialogs.image_augmenter", + "digitalsreeni_image_annotator.dialogs.image_patcher", + "digitalsreeni_image_annotator.dialogs.project_details", + "digitalsreeni_image_annotator.dialogs.project_search", + "digitalsreeni_image_annotator.dialogs.slice_registration", + "digitalsreeni_image_annotator.dialogs.snake_game", + "digitalsreeni_image_annotator.dialogs.stack_interpolator", + "digitalsreeni_image_annotator.dialogs.stack_to_slices", + "digitalsreeni_image_annotator.dialogs.yolo_trainer", +] + + +@pytest.mark.parametrize("module_name", INTERNAL_MODULES) +def test_internal_module_imports(module_name): + """Every internal module must import without raising.""" + importlib.import_module(module_name) + + +def test_annotator_window_inline_imports_are_resolvable(): + """Parse annotator_window.py AST, verify every bare relative import + (from .module) resolves to a file still in the package root. + + This catches stale inline imports inside function bodies that are + invisible to test_internal_module_imports because Python defers + execution until the function is called. Phase 1 moved 25 modules + into subpackages; four inline imports in annotator_window.py were + missed and only surfaced at runtime (e.g. from .dino_utils import + GDINO_MODEL_PATHS which needed to be .inference.dino_utils). + """ + import ast + import pathlib + + # Package root — modules that stayed at root live here + pkg_dir = ( + pathlib.Path(__file__).parents[2] + / "src" + / "digitalsreeni_image_annotator" + ) + source = (pkg_dir / "annotator_window.py").read_text(encoding="utf-8") + tree = ast.parse(source) + + bad = [] + for node in ast.walk(tree): + if not isinstance(node, ast.ImportFrom): + continue + # Only bare relative imports at level 1 (from .module) + if node.level != 1 or not node.module: + continue + module = node.module + # Proper subpackage imports are fine (e.g. .dialogs.foo) + dots = module.split(".") + if dots[0] in ("controllers", "dialogs", "inference", "io", "ui", "widgets", "core"): + continue + # Root-level modules that stayed behind: utils, annotator_window, main + root_py = pkg_dir / f"{module}.py" + root_pkg = pkg_dir / module / "__init__.py" + if not root_py.exists() and not root_pkg.exists(): + bad.append((node.lineno, f"from .{module} import ...")) + + assert not bad, ( + f"Stale inline imports in annotator_window.py at lines: {bad}. " + f"The module was likely moved into a subpackage; update the import path." + ) diff --git a/tests/ui/test_font_zoom.py b/tests/ui/test_font_zoom.py new file mode 100644 index 0000000..e1d60dd --- /dev/null +++ b/tests/ui/test_font_zoom.py @@ -0,0 +1,151 @@ +"""Tests for the continuous UI font zoom (low-vision mode). + +Uses a minimal QMainWindow stub carrying exactly the state +theme.set_font_pt touches, instead of constructing the full +ImageAnnotator (which would dominate suite runtime). save_ui_prefs is +patched out so tests never write the real per-user settings. +""" + +import pytest +from PyQt6.QtGui import QAction +from PyQt6.QtWidgets import QMainWindow + +from digitalsreeni_image_annotator.app_settings import ( + FONT_PT_DEFAULT, + FONT_PT_MAX, + FONT_PT_MIN, +) +from digitalsreeni_image_annotator.ui import theme +from digitalsreeni_image_annotator.widgets.image_label import ImageLabel + + +class _StubWindow(QMainWindow): + """Just enough ImageAnnotator surface for theme.set_font_pt.""" + + def __init__(self): + super().__init__() + self.font_sizes = {"Small": 8, "Medium": 10, "Large": 12, "XL": 14, "XXL": 16} + self.ui_font_pt = FONT_PT_DEFAULT + self.dark_mode = True + self.image_label = ImageLabel() + self._font_preset_actions = {} + for name in self.font_sizes: + action = QAction(name, self) + action.setCheckable(True) + self._font_preset_actions[name] = action + + +@pytest.fixture +def window(qt_application, monkeypatch): + saved = [] + monkeypatch.setattr( + theme, "save_ui_prefs", lambda pt, dark, settings=None: saved.append((pt, dark)) + ) + w = _StubWindow() + w._saved_prefs = saved + yield w + w.image_label.deleteLater() + w.deleteLater() + + +def test_step_up_increments_and_scales_canvas(window): + theme.step_font_pt(window, 1) + assert window.ui_font_pt == FONT_PT_DEFAULT + 1 + assert window.image_label.ui_scale == pytest.approx( + (FONT_PT_DEFAULT + 1) / FONT_PT_DEFAULT + ) + + +def test_step_clamps_at_bounds(window): + theme.set_font_pt(window, FONT_PT_MAX) + theme.step_font_pt(window, 1) + assert window.ui_font_pt == FONT_PT_MAX + + theme.set_font_pt(window, FONT_PT_MIN) + theme.step_font_pt(window, -1) + assert window.ui_font_pt == FONT_PT_MIN + + +def test_reset_returns_to_default(window): + theme.set_font_pt(window, 20) + theme.reset_font_pt(window) + assert window.ui_font_pt == FONT_PT_DEFAULT + assert window.image_label.ui_scale == pytest.approx(1.0) + + +def test_preset_entry_point_sets_value(window): + theme.change_font_size(window, "XXL") + assert window.ui_font_pt == 16 + + +def test_preset_checkmark_follows_value(window): + theme.change_font_size(window, "Large") + assert window._font_preset_actions["Large"].isChecked() + assert not window._font_preset_actions["Medium"].isChecked() + + # Stepping to an in-between size unchecks every preset. + theme.step_font_pt(window, 1) # 13pt — between Large and XL + assert not any(a.isChecked() for a in window._font_preset_actions.values()) + + +def test_every_change_is_persisted(window): + theme.set_font_pt(window, 12) + theme.step_font_pt(window, 1) + assert window._saved_prefs == [(12, True), (13, True)] + + +def test_default_scale_renders_identical_to_legacy(window): + """At the default 10pt, ui_scale is exactly 1.0 — overlay rendering + must be pixel-identical to the pre-feature code paths.""" + theme.set_font_pt(window, FONT_PT_DEFAULT) + assert window.image_label.ui_scale == 1.0 + assert window.image_label._pen_w(2) == pytest.approx(2.0) + + +def test_stylesheet_contains_scaled_overrides(window): + theme.set_font_pt(window, 20) + sheet = window.styleSheet() + assert "QWidget { font-size: 20pt; }" in sheet + # Header/indicator overrides scale the legacy px values by 2x at 20pt. + assert "QLabel.section-header { font-size: 28px; }" in sheet + assert "width: 28px; height: 28px;" in sheet + assert "QRadioButton::indicator { border-radius: 16px; }" in sheet + # Compact DINO panel scales too (11px / 10px legacy at default). + assert "PhraseEditorPanel QPushButton { font-size: 22px; }" in sheet + assert "QLabel#dino_phrase_hint { font-size: 20px; }" in sheet + + +def test_dino_panel_resolves_compact_scaled_font(window, qt_application): + """The compact DINO-panel rules must win at the *rendered-font* + level — i.e. the QSS px rules beat the findChildren setFont(pt) + loop in apply_theme_and_font. String presence in the stylesheet + (tested above) is not enough; this guards the precedence.""" + from PyQt6.QtWidgets import QLabel + + from digitalsreeni_image_annotator.dialogs.dino_phrase_editor import ( + PhraseEditorPanel, + ) + + panel = PhraseEditorPanel() + window.setCentralWidget(panel) + theme.set_font_pt(window, 20) # scale 2.0 -> compact 22px, hint 20px + qt_application.processEvents() + + panel.btn_add_phrase.ensurePolished() + assert panel.btn_add_phrase.font().pixelSize() == 22 + hint = panel.findChild(QLabel, "dino_phrase_hint") + hint.ensurePolished() + assert hint.font().pixelSize() == 20 + + +def test_default_stylesheet_overrides_match_legacy_px(window): + """At the 10pt default the appended overrides must reproduce the + static stylesheets' values exactly (14px header, 14px indicators, + 8px radio radius) — the zoom feature must be invisible until used.""" + theme.set_font_pt(window, FONT_PT_DEFAULT) + sheet = window.styleSheet() + assert "QLabel.section-header { font-size: 14px; }" in sheet + assert "width: 14px; height: 14px;" in sheet + assert "QRadioButton::indicator { border-radius: 8px; }" in sheet + assert "PhraseEditorPanel QPushButton { font-size: 11px; }" in sheet + assert "QLabel#dino_phrase_hint { font-size: 10px; }" in sheet diff --git a/tests/unit/test_app_settings.py b/tests/unit/test_app_settings.py new file mode 100644 index 0000000..c0f4e3b --- /dev/null +++ b/tests/unit/test_app_settings.py @@ -0,0 +1,61 @@ +"""Unit tests for app_settings (UI preference persistence). + +QSettings is exercised against an INI file in tmp_path so the tests +never touch the real per-user registry/config. +""" + +import pytest +from PyQt6.QtCore import QSettings + +from digitalsreeni_image_annotator.app_settings import ( + FONT_PT_DEFAULT, + FONT_PT_MAX, + FONT_PT_MIN, + clamp_font_pt, + load_ui_prefs, + save_ui_prefs, +) + + +class TestClampFontPt: + def test_in_range_passes_through(self): + assert clamp_font_pt(12) == 12 + + def test_below_min_clamps(self): + assert clamp_font_pt(3) == FONT_PT_MIN + + def test_above_max_clamps(self): + assert clamp_font_pt(99) == FONT_PT_MAX + + def test_numeric_string_is_coerced(self): + # QSettings INI backend round-trips ints as strings. + assert clamp_font_pt("14") == 14 + + def test_garbage_falls_back_to_default(self): + assert clamp_font_pt("huge") == FONT_PT_DEFAULT + + def test_none_falls_back_to_default(self): + assert clamp_font_pt(None) == FONT_PT_DEFAULT + + +@pytest.fixture +def ini_settings(tmp_path): + return QSettings(str(tmp_path / "prefs.ini"), QSettings.Format.IniFormat) + + +class TestUiPrefsRoundtrip: + def test_defaults_from_empty_settings(self, ini_settings): + assert load_ui_prefs(ini_settings) == (FONT_PT_DEFAULT, True) + + def test_roundtrip(self, ini_settings): + save_ui_prefs(18, False, ini_settings) + ini_settings.sync() + assert load_ui_prefs(ini_settings) == (18, False) + + def test_save_clamps_out_of_range(self, ini_settings): + save_ui_prefs(100, True, ini_settings) + assert load_ui_prefs(ini_settings) == (FONT_PT_MAX, True) + + def test_load_clamps_corrupt_value(self, ini_settings): + ini_settings.setValue("ui/font_pt", "not-a-number") + assert load_ui_prefs(ini_settings)[0] == FONT_PT_DEFAULT diff --git a/tests/unit/test_conversions.py b/tests/unit/test_conversions.py index e10ba87..ae85c32 100644 --- a/tests/unit/test_conversions.py +++ b/tests/unit/test_conversions.py @@ -5,20 +5,14 @@ """ import pytest -import sys -import os -import importlib.util from PyQt6.QtCore import QPoint, QSize from PyQt6.QtGui import QPixmap -# Import image_label module directly by file path to avoid torch dependency issues -image_label_path = os.path.join(os.path.dirname(__file__), '..', '..', 'src', 'digitalsreeni_image_annotator', 'image_label.py') -spec = importlib.util.spec_from_file_location("image_label", image_label_path) -image_label = importlib.util.module_from_spec(spec) -sys.modules['digitalsreeni_image_annotator.image_label'] = image_label -spec.loader.exec_module(image_label) - -ImageLabel = image_label.ImageLabel +# Phase 7 introduced widgets/tools/* as a subpackage, so image_label.py +# now uses relative imports and can no longer be loaded via spec_from_file_location. +# The widgets/__init__.py is empty and doesn't pull in torch, so a normal import +# is safe here. +from src.digitalsreeni_image_annotator.widgets.image_label import ImageLabel @pytest.fixture diff --git a/tools/check_pyqt6_torch_coexistence.py b/tools/check_pyqt6_torch_coexistence.py index c630eeb..122eb0e 100644 --- a/tools/check_pyqt6_torch_coexistence.py +++ b/tools/check_pyqt6_torch_coexistence.py @@ -4,33 +4,32 @@ Why this exists --------------- The historical ADR-011 documented that on Windows + Python 3.14, -importing PyQt5 first and then loading PyTorch triggers -``WinError 1114`` (DLL load-order conflict between Qt's and Torch's -native deps). That motivated the now-deleted subprocess isolation -layer (sam_worker.py, dino_worker.py, check_worker_isolation.py). - -Migrating to PyQt6 *should* eliminate the conflict — Qt6 reshuffled -its DLL packaging — but that is a hypothesis. This script is the -mechanical check. Run it before deleting any worker code. - -The crucial bit: ``import PyQt6.QtCore`` alone does NOT load Qt's -native platform plugin (qwindows.dll on Windows, libqxcb on Linux). -The plugin is loaded lazily by ``QApplication.__init__``. That's -where the WinError 1114 actually triggers. So this script -constructs a ``QApplication`` after importing both PyQt6 and torch -to exercise the real interaction. +importing PyQt first and then loading PyTorch triggers +``WinError 1114`` (DLL load-order conflict). It was thought that +migrating to PyQt6 (ADR-014) eliminated the conflict, but +real-world testing with torch 2.11.0 + PyQt6 6.10.2 shows the +conflict still surfaces when Qt DLLs are loaded BEFORE torch. +The workaround is simple and confirmed: import torch eagerly +before QApplication is created so torch claims its DLL slot first. +See ADR-017. + +The crucial bit: plain ``import PyQt6`` does NOT load Qt's native +platform plugin (qwindows.dll on Windows, libqxcb on Linux). The +plugin is loaded lazily by ``QApplication.__init__``. So this +script tests BOTH orders to document the real failure mode and +confirm safe order. Usage ----- python tools/check_pyqt6_torch_coexistence.py -Run it especially on Windows + Python 3.14. Exit code 0 means the -combination loads cleanly *and* QApplication constructs without -crashing; exit code 1 means at least one stage failed. +Exit code 0 means torch-first works (production order). +Exit code 1 means torch-first also fails → return to subprocess. """ from __future__ import annotations +import multiprocessing import platform import sys import traceback @@ -67,37 +66,127 @@ def _construct_qapplication(): return app +def _check_torch_then_qt() -> bool: + """ + Production import order: torch first, then QApplication. + This is what main.py does (see ADR-017). + """ + ok = True + ok &= _try("(torch-first) torch", lambda: __import__("torch")) + ok &= _try("(torch-first) torchvision", lambda: __import__("torchvision")) + ok &= _try("(torch-first) transformers", lambda: __import__("transformers")) + ok &= _try("(torch-first) ultralytics", lambda: __import__("ultralytics")) + ok &= _try( + "(torch-first) QApplication construct (loads Qt platform plugin)", + _construct_qapplication, + ) + return ok + + +def _check_qt_then_torch() -> bool: + """ + The import order that ADR-014 thought was fixed. On some torch + versions this still fails (WinError 1114). We check it so we + can warn if the 'safe' environment regressed. + """ + ok = True + ok &= _try("(qt-first) PyQt6.QtCore", lambda: __import__("PyQt6.QtCore", fromlist=["QtCore"])) + ok &= _try("(qt-first) PyQt6.QtWidgets", lambda: __import__("PyQt6.QtWidgets", fromlist=["QtWidgets"])) + ok &= _try("(qt-first) PyQt6.QtGui", lambda: __import__("PyQt6.QtGui", fromlist=["QtGui"])) + # Force platform plugin load BEFORE torch — this is where the + # failure appeared. + ok &= _try( + "(qt-first) QApplication construct (loads Qt platform plugin)", + _construct_qapplication, + ) + if not ok: + print("[qt-first] Qt failed to load — can't test torch-after-Qt.") + return False + ok &= _try("(qt-first) torch", lambda: __import__("torch")) + ok &= _try("(qt-first) torchvision", lambda: __import__("torchvision")) + ok &= _try("(qt-first) ultralytics", lambda: __import__("ultralytics")) + return ok + + def main() -> int: print(f"Python: {sys.version}") print(f"Platform: {platform.platform()}") print(f"Machine: {platform.machine()}") print("-" * 60) - # Order matters: PyQt first, then Torch, then Transformers. - # This is the exact order the running app loads them in - # (annotator_window imports PyQt at startup; torch is pulled - # in by ultralytics/transformers when the user picks a model). + safe_ok = _check_torch_then_qt() + print("-" * 60) + + # We run qt-first check in a FRESH process because the preceding + # torch-first test may have already loaded DLLs that would mask + # the issue. + print("\nChecking Qt-first order in a fresh subprocess...") + import subprocess as sp + + worker_src = """ +import ast, sys, traceback + + +def _try(label, fn): + print(f"[{label}] running ...", flush=True) + try: + result = fn() + except BaseException: + print(f"[{label}] FAILED:") + traceback.print_exc() + return False + print(f"[{label}] OK", flush=True) + return True + + +def _construct_qapplication(): + import os + os.environ.setdefault("QT_QPA_PLATFORM", "offscreen") + from PyQt6.QtWidgets import QApplication + app = QApplication.instance() or QApplication(sys.argv) + return app + + +def _check_qt_then_torch(): ok = True - ok &= _try("PyQt6.QtCore", lambda: __import__("PyQt6.QtCore", fromlist=["QtCore"])) - ok &= _try("PyQt6.QtWidgets", lambda: __import__("PyQt6.QtWidgets", fromlist=["QtWidgets"])) - ok &= _try("PyQt6.QtGui", lambda: __import__("PyQt6.QtGui", fromlist=["QtGui"])) - ok &= _try("torch", lambda: __import__("torch")) - ok &= _try("torchvision", lambda: __import__("torchvision")) - ok &= _try("transformers", lambda: __import__("transformers")) - ok &= _try("ultralytics", lambda: __import__("ultralytics")) - # THIS is the real test — load the Qt platform plugin AFTER torch - # is in the address space. Pure import_module above does not load - # the platform plugin, so a green result without this step would - # be a false positive. - ok &= _try("QApplication construct (loads Qt platform plugin)", _construct_qapplication) + ok &= _try("(qt-first) PyQt6.QtCore", lambda: __import__("PyQt6.QtCore", fromlist=["QtCore"])) + ok &= _try("(qt-first) PyQt6.QtWidgets", lambda: __import__("PyQt6.QtWidgets", fromlist=["QtWidgets"])) + ok &= _try("(qt-first) PyQt6.QtGui", lambda: __import__("PyQt6.QtGui", fromlist=["QtGui"])) + ok &= _try("(qt-first) QApplication", _construct_qapplication) + if not ok: + print("[qt-first] Qt failed — can't test torch-after-Qt.") + return False + ok &= _try("(qt-first) torch", lambda: __import__("torch")) + ok &= _try("(qt-first) torchvision", lambda: __import__("torchvision")) + ok &= _try("(qt-first) ultralytics", lambda: __import__("ultralytics")) + return ok + +ok = _check_qt_then_torch() +print("OK" if ok else "FAIL") +""" + proc = sp.run( + [sys.executable, "-c", worker_src], + capture_output=True, + text=True, + timeout=120, + ) + print(proc.stdout, end="") + if proc.stderr: + print(proc.stderr, end="", file=sys.stderr) + qt_first_ok = proc.stdout.strip().endswith("OK") and proc.returncode == 0 print("-" * 60) - if ok: - print("RESULT: PyQt6 + Torch coexist cleanly, QApplication constructs.") - print(" Subprocess removal unblocked.") - return 0 - print("RESULT: at least one stage failed. Investigate before merging.") - return 1 + if not safe_ok: + print("RESULT: torch-first order FAILED.") + print(" Return to subprocess isolation (ADR-011).") + return 1 + if not qt_first_ok: + print("RESULT: torch-first OK. Qt-first FAILED (known with some versions).") + print(" Keep main.py eager torch import (ADR-017).") + else: + print("RESULT: both orders clean. Qt packaging has fixed the conflict.") + print(" Consider removing main.py eager torch import if confirmed stable.") + return 0 if __name__ == "__main__":