A customized fork of milesial/Pytorch-UNet for plant semantic segmentation, featuring an integrated Gradio Web UI for training control, real-time monitoring, inference testing, and dataset exploration — all from the browser.
- 🌿 Plant Segmentation — Pre-configured for 2-class (background + plant) segmentation
- 🖥️ Gradio Web Interface — Train, predict, and explore data via a browser-based dashboard, no CLI needed
- ⚡ AMP & GPU Optimized — Mixed precision training with CUDA acceleration out of the box
- 📊 Real-time Monitoring — Live training logs streamed to the web UI
- 📈 Training Visualization — Auto-generates loss and Dice score curves after training
- 🎯 Interactive Inference — Upload an image, select a checkpoint, and visualize segmentation results instantly
Pytorch-UNet/
├── app.py # Gradio Web UI entry point
├── train.py # Training script (CLI + subprocess backend)
├── predict.py # CLI prediction script
├── evaluate.py # Model evaluation (Dice score)
├── plot_training.py # Training metrics visualization (Loss / Dice curves)
├── start.bat # One-click launcher for Windows
├── hubconf.py # PyTorch Hub integration
├── Dockerfile # Docker support
├── requirements.txt # Python dependencies
├── unet/
│ ├── unet_model.py # U-Net model definition
│ └── unet_parts.py # Building blocks (DoubleConv, Down, Up, etc.)
├── utils/
│ ├── data_loading.py # Dataset classes (BasicDataset, CarvanaDataset)
│ ├── dice_score.py # Dice loss & metric
│ └── utils.py # Helper utilities
├── data/
│ ├── imgs/ # Input images
│ └── masks/ # Ground truth masks
├── doc/ # Documentation & training curve plots
└── checkpoints/ # Saved model weights (.pth)
git clone https://github.com/momo325/Pytorch-UNet.git
cd Pytorch-UNet
# Create virtual environment
python -m venv .venv
# Activate (Windows)
.venv\Scripts\activate
# Activate (Linux/macOS)
source .venv/bin/activate
# Install dependencies (PyTorch must be installed separately)
pip install -r requirements.txt
Note: Install PyTorch with CUDA support before running pip install -r requirements.txt.
Place your dataset in the following structure:
data/
├── imgs/ # RGB input images
│ ├── 001.png
│ ├── 002.png
│ └── ...
└── masks/ # Corresponding segmentation masks
    ├── 001.png
    ├── 002.png
    └── ...
- Each image and its mask must share the same filename (stem)
- Masks should use distinct pixel values for each class (e.g., [0,0,0] for background, [120,200,20] for plant)
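The filename rule above can be checked before training starts. The following is a minimal sketch, not part of the repository: the helper name `find_unpaired` is hypothetical, and it assumes only the `data/imgs` and `data/masks` layout described above.

```python
from pathlib import Path


def find_unpaired(img_dir, mask_dir):
    """Return (image stems with no mask, mask stems with no image)."""
    # Compare by stem so that 001.jpg and 001.png still count as a pair.
    img_stems = {p.stem for p in Path(img_dir).iterdir() if p.is_file()}
    mask_stems = {p.stem for p in Path(mask_dir).iterdir() if p.is_file()}
    return sorted(img_stems - mask_stems), sorted(mask_stems - img_stems)
```

Calling `find_unpaired("data/imgs", "data/masks")` from the project root returns two lists; both should be empty when the dataset is laid out correctly.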
Option A: One-click (Windows)
Double-click start.bat in the project root. It automatically activates the virtual environment and opens the browser.
Option B: Command line
python app.py
Then open http://127.0.0.1:7860 in your browser.
The Gradio interface provides three tabs:
| Tab | Function |
|---|---|
| ⚙️ Training | Configure hyperparameters, start/stop training, view real-time console logs |
| 🖼️ Inference | Upload an image, select a checkpoint, adjust threshold, visualize segmentation output |
| 📂 Dataset Explorer | Randomly sample image-mask pairs to verify data alignment |
Web training workflow:
- Set hyperparameters (Epochs, Batch Size, Learning Rate, Scale, AMP toggle) in the Training tab
- Click "Start Training" — the right panel shows live training logs with progress bars, loss values, and Dice scores
- After training, model weights are automatically saved to checkpoints/
- Switch to the Inference tab, select the newly generated checkpoint, and test on your images
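For scripted use outside the web UI, the "newly generated checkpoint" can also be located programmatically. This is a rough sketch under one loud assumption: that checkpoints follow the upstream naming pattern `checkpoint_epoch<N>.pth`. This fork may name them differently, and the helper `latest_checkpoint` is hypothetical.

```python
import re
from pathlib import Path

# Assumed naming pattern (from upstream); adjust it if this fork differs.
_EPOCH_RE = re.compile(r"checkpoint_epoch(\d+)\.pth$")


def latest_checkpoint(ckpt_dir="checkpoints"):
    """Return the .pth path with the highest epoch number, or None if none match."""
    best = None
    for path in Path(ckpt_dir).glob("*.pth"):
        match = _EPOCH_RE.search(path.name)
        if match is None:
            continue  # ignore files that do not follow the assumed pattern
        epoch = int(match.group(1))
        if best is None or epoch > best[0]:
            best = (epoch, path)
    return best[1] if best else None
```

Comparing the parsed epoch number rather than the filename avoids the lexicographic trap where `epoch10` sorts before `epoch2`.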
python train.py --epochs 5 --batch-size 4 --learning-rate 1e-5 --scale 0.125 --amp

| Argument | Default | Description |
|---|---|---|
| --epochs | 5 | Number of training epochs |
| --batch-size | 1 | Batch size |
| --learning-rate | 1e-5 | Learning rate |
| --scale | 0.5 | Image downscale factor (reduce if running out of VRAM) |
| --validation | 10.0 | Validation split percentage (0-100) |
| --amp | off | Enable mixed precision (recommended) |
| --classes | 2 | Number of output classes |
| --bilinear | off | Use bilinear upsampling instead of transposed conv |
| --load | none | Resume from a .pth checkpoint |
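To make the `--scale` and `--validation` arguments concrete, here is a small arithmetic sketch. It assumes the fork keeps the straightforward truncating arithmetic used upstream; the function names `scaled_size` and `split_counts` are illustrative and do not exist in the repository.

```python
def scaled_size(width, height, scale):
    """Image size after applying --scale, truncated to whole pixels."""
    new_w, new_h = int(scale * width), int(scale * height)
    if new_w <= 0 or new_h <= 0:
        raise ValueError("scale is too small; the resized image would have no pixels")
    return new_w, new_h


def split_counts(n_samples, val_percent):
    """Number of (training, validation) samples for a --validation percentage."""
    n_val = int(n_samples * val_percent / 100)
    return n_samples - n_val, n_val
```

For a 1920x1080 image, `scaled_size(1920, 1080, 0.5)` gives `(960, 540)` and `scaled_size(1920, 1080, 0.125)` gives `(240, 135)`, which is why lowering `--scale` reduces VRAM use so sharply. With 200 image-mask pairs and the default split, `split_counts(200, 10.0)` gives `(180, 20)`.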
After training, training_log.json is saved to the project root and curve plots are generated in doc/.
Plots are auto-generated after training. To regenerate manually:
python plot_training.py
Output:
- doc/training_loss_curve.png: Training loss vs. step
- doc/validation_dice_curve.png: Validation Dice score vs. step
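For readers unfamiliar with the metric plotted in the validation curve, the Dice score measures the overlap between a predicted mask and the ground truth: twice the number of shared foreground pixels divided by the total foreground pixels in both masks. The following is a plain-Python illustration of that formula on flat binary masks. It is not the code in utils/dice_score.py, which operates on tensors, and the smoothing term `eps` is an assumption made for this sketch.

```python
def dice_coefficient(pred, target, eps=1e-6):
    """Dice = 2 * overlap / (size of pred + size of target) for flat 0/1 masks."""
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of pixels")
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    # eps keeps the ratio defined (and equal to 1) when both masks are empty.
    return (2 * intersection + eps) / (total + eps)
```

A score of 1.0 means the masks match exactly, and values near 0 mean almost no overlap. For example, a prediction `[1, 1, 0, 0]` against a target `[1, 0, 0, 0]` shares one plant pixel out of three in total, giving roughly 0.667.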
# Single image
python predict.py -i image.jpg -o output.jpg
# Multiple images with visualization
python predict.py -i image1.jpg image2.jpg --viz --no-save
docker build -t unet-plant .
docker run --rm --gpus all -p 7860:7860 unet-plant python app.py
Compared to the original milesial/Pytorch-UNet:
| Change | Details |
|---|---|
| Gradio Web UI | Added app.py with training control, inference testing, and dataset exploration |
| Training visualization | Added plot_training.py, auto-generates Loss / Dice curves after training |
| Windows one-click launch | Added start.bat for double-click startup |
| Plant segmentation defaults | n_classes=2, hardcoded mask values [0,0,0] / [120,200,20] in data_loading.py |
| DataLoader optimization | Reduced num_workers to 2 for Windows compatibility with high-res images |
| W&B disabled by default | wandb.init(mode='disabled') to avoid login prompts during local development |
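The hardcoded mask values mentioned in the table imply a mapping from RGB colour to class index. The sketch below shows the idea in plain Python. It is not the implementation in data_loading.py; the names `MASK_COLOURS` and `mask_to_classes` are hypothetical, and only the two colour values come from this README.

```python
# Colour-to-class table assumed from the README: background and plant.
MASK_COLOURS = {(0, 0, 0): 0, (120, 200, 20): 1}


def mask_to_classes(rgb_rows, colours=MASK_COLOURS):
    """Convert rows of (R, G, B) pixels into rows of class indices."""
    out = []
    for row in rgb_rows:
        try:
            out.append([colours[tuple(px)] for px in row])
        except KeyError as err:
            # Surface stray colours (e.g. anti-aliased edges) instead of guessing.
            raise ValueError(f"unexpected mask colour {err.args[0]}") from None
    return out
```

Raising on unknown colours is a deliberate choice for this sketch: masks exported with anti-aliasing or lossy compression often contain intermediate values that would otherwise be mislabelled silently.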
- Original U-Net paper: Ronneberger, O., Fischer, P., & Brox, T. (2015). U-Net: Convolutional Networks for Biomedical Image Segmentation
- Upstream repository: milesial/Pytorch-UNet
This project is distributed under the GPL-3.0 License. See LICENSE for details.