arnavgarg233/2.5CNN

2.5 CNN: Leveraging 2D CNNs to Pretrain 3D Models in Low-Data Regimes for COVID-19 Diagnosis

Python 3.10+ · License: MIT · Status: published · Journal: Electronics (MDPI) · DOI: 10.3390/electronics14132571

Official code for the 2.5 CNN pipeline: a two-stage approach that first trains a 2D CNN on individual CT slices, then reuses the learned feature extractor inside a lightweight 3D classifier built on stacked slice embeddings. The design combines slice-level representation learning with the volumetric context essential for medical image interpretation, while directly addressing the 3D label scarcity that plagues clinical CT datasets.

Headline results

MosMed dataset (1130 chest CT scans, 5-class COVID-19 severity classification with severe class imbalance):

| Metric | Value |
|---|---|
| Weighted accuracy | 94.73 % |
| Unweighted accuracy | 95.35 % |
| Comparison | Surpasses both purely 2D and purely 3D pipelines trained on the same data |

The model also remains robust under the additional class-imbalance stress of fine-grained severity stratification — see the published paper for per-task and per-severity breakdowns.
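The two accuracy figures above can differ noticeably under class imbalance. The sketch below shows one common convention, which is an assumption here rather than the paper's stated definition: "weighted" accuracy is the overall fraction of correct predictions (each class counts in proportion to its support), and "unweighted" accuracy is the mean of per-class recalls (each class counts equally). The labels are toy values, not MosMed data.

```python
from collections import Counter

def weighted_accuracy(y_true, y_pred):
    """Overall fraction correct; classes contribute in proportion to their support."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)

def unweighted_accuracy(y_true, y_pred):
    """Mean of per-class recalls; every class counts equally regardless of size."""
    support = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return sum(hits[c] / support[c] for c in support) / len(support)

# Toy imbalanced labels (made up for illustration only).
y_true = [0, 0, 0, 0, 0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 0, 0, 0, 0, 1, 0, 2, 2]

wa = weighted_accuracy(y_true, y_pred)    # 9 of 10 correct
ua = unweighted_accuracy(y_true, y_pred)  # mean of recalls 1.0, 0.5, 1.0
```

A single error in a small class barely moves the first metric but lowers the second, which is why both are worth reporting on an imbalanced dataset.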

Publication

| Field | Details |
|---|---|
| Title | 2.5 CNN: Leveraging 2D CNNs to Pretrain 3D Models in Low-Data Regimes for COVID-19 Diagnosis |
| Journal | Electronics (MDPI), 2025, 14 (13), 2571 |
| DOI | 10.3390/electronics14132571 |
| Article landing page | MDPI (https://www.mdpi.com/2079-9292/14/13/2571) |
| PDF | Direct PDF (linked from the MDPI article page) |
| Article license | Open access, CC BY 4.0 |
| Code license | MIT (see LICENSE) |
| Authors | Arnav Garg · Aksh Garg · Dominique Duncan (corresponding) |

BibTeX:

@Article{electronics14132571,
  AUTHOR  = {Garg, Arnav and Garg, Aksh and Duncan, Dominique},
  TITLE   = {2.5 CNN: Leveraging 2D CNNs to Pretrain 3D Models in Low-Data Regimes for COVID-19 Diagnosis},
  JOURNAL = {Electronics},
  VOLUME  = {14},
  YEAR    = {2025},
  NUMBER  = {13},
  ARTICLE-NUMBER = {2571},
  URL     = {https://www.mdpi.com/2079-9292/14/13/2571},
  ISSN    = {2079-9292},
  DOI     = {10.3390/electronics14132571}
}

Plain citation: Garg, A.; Garg, A.; Duncan, D. 2.5 CNN: Leveraging 2D CNNs to Pretrain 3D Models in Low-Data Regimes for COVID-19 Diagnosis. Electronics 2025, 14 (13), 2571. https://doi.org/10.3390/electronics14132571

Highlights

  • Two-stage 2.5D pipeline — 2D CNN feature extractor on CT slices, then a lightweight 3D classifier on stacked slice embeddings.
  • Addresses 3D label scarcity by expanding effective training-set size at the slice level before the volumetric stage.
  • Outperforms both pure-2D and pure-3D baselines on the same MosMed split (94.73 % weighted / 95.35 % unweighted accuracy).
  • Robust to severity-level class imbalance across the 5 MosMed CT-0 → CT-4 categories.
  • Theoretical framing connects the 2.5D approach to multi-instance learning and analyses its computational savings versus naïve 3D training in low-data regimes.

Method overview

        ┌─────────────────────────────────────────────────┐
        │  Stage 1 — 2D CNN on individual CT slices       │
        │  • Expands the effective training set           │
        │  • Learns slice-level pneumonia / GGO patterns  │
        └────────────────────┬────────────────────────────┘
                             │  feature extractor (transferred)
                             ▼
        ┌─────────────────────────────────────────────────┐
        │  Stage 2 — Stack slice embeddings → 3D head     │
        │  • Lightweight volumetric classifier            │
        │  • Captures inter-slice volumetric context      │
        └────────────────────┬────────────────────────────┘
                             ▼
                 5-class severity prediction (CT-0 ... CT-4)

The 2.5D framing preserves the data-efficiency of slice-level 2D learning while recovering the volumetric context that pure-2D methods discard. The paper situates this design within the multi-instance learning literature and provides a complexity comparison versus naïve full-3D training under fixed labelled-volume budgets.
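As a rough illustration of the two stages, here is a minimal PyTorch sketch. It is not the code in Models/models_half.py: `SliceEncoder` and `VolumeHead` are hypothetical names, the tiny CNN stands in for the repository's ResNet variants, and the head is assumed to be a 1D convolution along the slice axis over stacked embeddings, which may differ from the paper's head.

```python
import torch
import torch.nn as nn

class SliceEncoder(nn.Module):
    """Tiny stand-in for the Stage-1 2D CNN (hypothetical, not the repo's ResNet)."""
    def __init__(self, embed_dim=32, num_classes=5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(16, embed_dim), nn.ReLU(),
        )
        self.classifier = nn.Linear(embed_dim, num_classes)  # Stage-1 slice labels only

    def forward(self, x):  # x: (N, 1, H, W)
        return self.classifier(self.features(x))

class VolumeHead(nn.Module):
    """Stage 2: frozen slice encoder plus a light head over stacked embeddings."""
    def __init__(self, encoder, embed_dim=32, num_classes=5):
        super().__init__()
        self.encoder = encoder.features          # reuse the feature extractor only
        for p in self.encoder.parameters():
            p.requires_grad = False              # keep the 2D backbone frozen
        self.head = nn.Sequential(
            nn.Conv1d(embed_dim, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveMaxPool1d(1), nn.Flatten(),
            nn.Linear(32, num_classes),
        )

    def forward(self, vol):  # vol: (B, D, H, W), D = slices per volume
        b, d, h, w = vol.shape
        with torch.no_grad():
            emb = self.encoder(vol.reshape(b * d, 1, h, w))  # (B*D, E)
        emb = emb.reshape(b, d, -1).transpose(1, 2)          # (B, E, D)
        return self.head(emb)                                # (B, num_classes)

encoder = SliceEncoder()            # would be trained on single slices first
model = VolumeHead(encoder)
logits = model(torch.randn(2, 16, 64, 64))  # two random 16-slice volumes
```

Only the head's parameters receive gradients in this sketch, so Stage 2 trains a small number of weights per labelled volume.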

Dataset

| Property | Details |
|---|---|
| Source | MosMed: chest CT volumes acquired during the early COVID-19 pandemic |
| Volumes | 1130 CT scans |
| Task | Multi-class severity classification (CT-0 normal, CT-1 mild, CT-2 moderate, CT-3 severe, CT-4 critical) |
| Class balance | Heavily imbalanced; addressed via weighted sampling and weighted accuracy reporting |
| Slice format | Per-volume axial slices, served as .npy arrays for the 2D stage |

Architectures

| Model family | Where it lives | Role |
|---|---|---|
| 2D CNN (slice-level) | Models/models2d.py | Stage 1: slice classifier and feature extractor |
| 3D CNN (volumetric) | Models/models3d.py | Pure-3D baseline for ablation |
| 2.5D hybrid | Models/models_half.py, Models/half_using_3d.py | Stage 2: pretrained-2D backbone + 3D head |
| Generic backbones | Models/models.py | Shared model utilities, ResNet variants |

Repository structure

2.5CNN/
├── Configs/                  # YAML training configs (per task / split)
│   ├── config.yaml           # Default 5-class severity
│   ├── 1v3.yaml              # CT-1 vs CT-3 binary
│   ├── first_two.yaml        # CT-0 vs CT-1 binary (early severity)
│   ├── first_three.yaml      # 3-class (CT-0/1/2) coarse severity
│   ├── resnet_no_dropout_multiclass.yaml
│   └── custom.yaml           # Override starting point
├── Data/
│   ├── dataloader.py         # 2D / 3D dataset and split logic
│   ├── transformations.py    # Medical-image augmentations (intensity, geometry)
│   └── weighted_sampler.py   # Class-weighted sampling for imbalanced batches
├── Models/
│   ├── models2d.py           # Stage-1 2D CNNs (ResNet variants, etc.)
│   ├── models3d.py           # Pure-3D baselines
│   ├── models_half.py        # 2.5D hybrid (paper headline)
│   ├── half_using_3d.py      # Alternate 2.5D head
│   ├── models.py             # Shared backbones / utilities
│   └── models_half.ipynb     # Notebook walk-through of the 2.5D head
├── Scripts/
│   ├── store_data_in_files.py      # Volumes → slice .npy arrays
│   ├── store_slices.py             # Slice extraction + caching
│   ├── train_val_save.py           # Train / val split utilities
│   ├── compare_2d_and_3d_data.py   # Sanity check: 2D vs 3D inputs
│   ├── vis_scans.py / vis_slices.py
│   ├── slider_analysis.py          # Threshold / decision-boundary sweep
│   └── "Severity Assessment With Preloaded Files.py"
├── Utils/
│   ├── utils.py              # General helpers
│   ├── clock.py / clock_2d.py  # Timing instrumentation
│   └── syscheck.py           # CUDA / MPS / system sanity check
├── multiplexer/              # GPU job scheduler for batch experiments
│   ├── scheduler.py
│   ├── config_generator.py
│   ├── launch_gpu.sh
│   ├── default.yaml
│   ├── job_queue.txt / enqueue.txt / job_id.txt
│   └── launch_configs/       # Per-job YAML overrides
├── src/
│   ├── launch.py             # Main training entry point
│   └── evaluate.py           # Evaluation entry point
├── LICENSE                   # MIT
└── README.md

Installation

Tested with Python 3.10+, PyTorch ≥ 2.0, MONAI ≥ 1.2. GPU recommended for the 3D stage; the 2D stage runs on CPU/MPS at modest speed.

git clone https://github.com/arnavgarg233/2.5CNN.git
cd 2.5CNN

# Create env (conda shown; venv or uv work equivalently):
conda create -n 25cnn python=3.10 -y
conda activate 25cnn

# Install PyTorch matching your CUDA / MPS / CPU build, then:
pip install monai wandb pyyaml pillow numpy

Optional — Weights & Biases logging:

wandb login

Quick start

Train (default 5-class severity)

python src/launch.py --config Configs/config.yaml

Override hyperparameters from CLI

python src/launch.py --config Configs/custom.yaml \
    --batch_size 32 --learning_rate 1e-3
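For readers wondering how CLI flags can take precedence over YAML values, the following is a minimal sketch of one way to merge them. It is an assumption about the general pattern, not a description of src/launch.py; `apply_overrides` and the values in `base` are hypothetical, and the config is a plain dict standing in for a loaded YAML file.

```python
import argparse

def apply_overrides(config, argv):
    """Return a copy of config with any CLI-supplied values replacing the defaults."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int)
    parser.add_argument("--learning_rate", type=float)
    args = parser.parse_args(argv)
    merged = dict(config)  # leave the loaded config untouched
    # Flags that were not passed stay None and are skipped.
    merged.update({k: v for k, v in vars(args).items() if v is not None})
    return merged

# Stand-in for values that would come from a YAML file (made up).
base = {"batch_size": 16, "learning_rate": 1e-4, "epochs": 50}
merged = apply_overrides(base, ["--batch_size", "32", "--learning_rate", "1e-3"])
```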

Train on a binary subtask (CT-1 vs CT-3)

python src/launch.py --config Configs/1v3.yaml

Evaluate a trained checkpoint

python src/evaluate.py --config Configs/config.yaml --checkpoint <path-to-.pt>

Batch experiments via the GPU scheduler

cd multiplexer
python config_generator.py        # generate sweep configs from default.yaml
python scheduler.py               # consume the job queue, dispatch to GPUs

Data preparation

MosMed volumes are converted into per-slice .npy arrays, organized by class:

data/
├── train/
│   ├── class_0/                # CT-0 (normal)
│   ├── class_1/                # CT-1 (mild)
│   ├── class_2/                # CT-2 (moderate)
│   ├── class_3/                # CT-3 (severe)
│   └── class_4/                # CT-4 (critical)
└── val/
    ├── class_0/
    ├── class_1/
    ├── class_2/
    ├── class_3/
    └── class_4/
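A small sketch of how this layout could be indexed, assuming slice files sit directly inside each class_k directory as .npy files (the exact file naming is not specified here). `index_slices` is a hypothetical helper, not a function from Data/dataloader.py.

```python
from pathlib import Path

def index_slices(root, split):
    """Return {class_index: sorted list of .npy paths} for one split."""
    index = {}
    for class_dir in sorted((Path(root) / split).glob("class_*")):
        label = int(class_dir.name.split("_")[1])  # "class_3" -> 3
        index[label] = sorted(class_dir.glob("*.npy"))
    return index

# Example call, assuming the tree above exists on disk:
# train_index = index_slices("data", "train")
# print({label: len(paths) for label, paths in train_index.items()})
```

Printing the per-class counts is a quick way to confirm the class imbalance before training.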

Use the helper scripts under Scripts/ to convert raw NIfTI / DICOM volumes into the expected layout:

python Scripts/store_data_in_files.py --src /path/to/MosMed --dst data/
python Scripts/store_slices.py        --src data/ --dst data_slices/

Training details (paper-aligned)

| Aspect | Setting |
|---|---|
| Optimizer | Adam |
| Learning rate | Configurable per YAML (typically 1e-4 to 1e-3, cosine or step decay) |
| Loss | Weighted cross-entropy (class-frequency weights) |
| Class imbalance | Data/weighted_sampler.py: class-frequency-weighted sampling per batch |
| Augmentation | MONAI / custom intensity + geometric transforms in Data/transformations.py |
| Stage 1 input | 2D axial CT slices |
| Stage 2 input | Stacked slice embeddings from frozen 2D backbone |
| Validation strategy | Held-out subject-level split (no patient leakage between train and val) |
| Logging | Weights & Biases (optional) |
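To make the class-frequency weighting concrete, here is a standard-library sketch of inverse-frequency weights and weighted index sampling. The exact formula used in Data/weighted_sampler.py is not stated here, so the "balanced" normalization n / (k * count) is an assumption, and both function names are hypothetical. In PyTorch, weights like these would typically be passed to `torch.utils.data.WeightedRandomSampler` (per sample) or to the `weight` argument of `nn.CrossEntropyLoss` (per class).

```python
import random
from collections import Counter

def class_weights(labels):
    """Inverse-frequency weights, scaled so they average to 1 over all samples."""
    counts = Counter(labels)
    n, k = len(labels), len(counts)
    return {c: n / (k * counts[c]) for c in counts}

def sample_balanced(labels, num_samples, seed=0):
    """Draw indices with probability proportional to each sample's class weight."""
    w = class_weights(labels)
    rng = random.Random(seed)
    return rng.choices(range(len(labels)), weights=[w[y] for y in labels], k=num_samples)

labels = [0] * 8 + [1] * 2             # toy 80/20 imbalance
weights = class_weights(labels)        # {0: 0.625, 1: 2.5}
drawn = sample_balanced(labels, 2000)  # indices; classes appear roughly 50/50
```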

Exact hyperparameters per task are checked in under Configs/. The multiplexer/ directory contains the experiment-sweep scaffolding used to produce the paper's ablation tables.
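The subject-level validation strategy listed in the table matters because slices from one scan are highly correlated. The sketch below shows the general idea under the assumption that every slice carries a subject or study identifier; `subject_level_split` is a hypothetical helper, not the logic in Data/dataloader.py or Scripts/train_val_save.py.

```python
import random

def subject_level_split(samples, val_fraction=0.2, seed=0):
    """samples: list of (subject_id, item). All items of a subject land in one split."""
    subjects = sorted({subject_id for subject_id, _ in samples})
    random.Random(seed).shuffle(subjects)
    n_val = max(1, round(len(subjects) * val_fraction))
    val_ids = set(subjects[:n_val])
    train = [(s, x) for s, x in samples if s not in val_ids]
    val = [(s, x) for s, x in samples if s in val_ids]
    return train, val

# Toy data: 10 studies with 4 slices each (identifiers are made up).
samples = [(f"study_{i // 4}", i) for i in range(40)]
train, val = subject_level_split(samples)
```

Splitting on subjects first, and only then expanding to slices, is what prevents slices of the same patient from appearing on both sides.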

Evaluation

python src/evaluate.py --config Configs/config.yaml \
    --checkpoint outputs/checkpoints/best.pt

Reports weighted / unweighted accuracy, per-class precision / recall / F1, confusion matrices, and (optionally) ROC / PR curves. See the paper for the comparison tables across 2D-only, 3D-only, and 2.5D variants and across the binary / 3-class / 5-class subtasks.
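For reference, the per-class metrics listed above can all be derived from a confusion matrix. This is a generic standard-library sketch with toy labels, not the implementation in src/evaluate.py; both function names are hypothetical.

```python
def confusion_matrix(y_true, y_pred, num_classes):
    """cm[t][p] counts samples with true class t predicted as class p."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm

def per_class_report(cm):
    """Precision, recall and F1 per class; undefined ratios are reported as 0.0."""
    report = []
    for c in range(len(cm)):
        tp = cm[c][c]
        predicted = sum(row[c] for row in cm)  # column sum
        actual = sum(cm[c])                    # row sum
        precision = tp / predicted if predicted else 0.0
        recall = tp / actual if actual else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        report.append({"precision": precision, "recall": recall, "f1": f1})
    return report

# Toy labels (made up for illustration only).
y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 1, 2]
cm = confusion_matrix(y_true, y_pred, num_classes=3)
report = per_class_report(cm)
```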

Computational footprint

The 2.5D pipeline is intentionally lighter than full 3D training in the same data regime:

  • Stage 1 runs on commodity GPUs (or Apple Silicon MPS / CPU at slower speed) since 2D slices are inexpensive.
  • Stage 2 is a small 3D head over precomputed embeddings — orders of magnitude cheaper than training a full 3D CNN from scratch.
  • This is a practical advantage when the labelled-volume budget is small, which is precisely the clinical-CT regime motivating the design.
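A back-of-the-envelope comparison helps show where part of the saving comes from. The sketch below counts parameters and multiply-accumulates for a single convolution layer applied slice by slice versus volumetrically. The channel counts and tensor sizes are illustrative assumptions, not figures from the paper, and the calculation covers only one layer; it does not reproduce the paper's complexity analysis.

```python
def conv_params(c_in, c_out, kernel, dims):
    """Weights plus biases of one convolution with a square or cubic kernel."""
    return c_in * c_out * kernel ** dims + c_out

def conv_macs(c_in, c_out, kernel, out_shape):
    """Multiply-accumulates for one conv layer; out_shape is the spatial output size."""
    positions = 1
    for n in out_shape:
        positions *= n
    return positions * c_in * c_out * kernel ** len(out_shape)

# Illustrative layer: 64 -> 64 channels, 3-wide kernel, 32 slices of 128 x 128.
p2d = conv_params(64, 64, 3, dims=2)                    # 36928
p3d = conv_params(64, 64, 3, dims=3)                    # 110656
macs_2d_stack = 32 * conv_macs(64, 64, 3, (128, 128))   # 2D conv on each slice
macs_3d = conv_macs(64, 64, 3, (32, 128, 128))          # one 3D conv on the volume
ratio = macs_3d / macs_2d_stack                         # 3.0 for a 3-wide kernel
```

The per-layer factor is modest on its own. The larger saving in Stage 2 comes from the head operating on precomputed per-slice embeddings instead of full-resolution voxels.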

License

This repository is released under the MIT License — see LICENSE. The article itself is open-access under CC BY 4.0 (see the MDPI page for the full legal text).

Copyright © 2025 Arnav Garg, Aksh Garg, Dominique Duncan.

Contact

Arnav Garg (first author): arnavgarg888@gmail.com

For questions related to the published article, please contact the corresponding author (Dominique Duncan) listed on the MDPI page.

About

2.5 CNN: leverage 2D CNNs to pretrain 3D models for low-data COVID-19 CT diagnosis. Official code for Electronics (MDPI) 2025, 14(13):2571.
