momo325/Pytorch-UNet

U-Net Plant Segmentation

Chinese README

A customized fork of milesial/Pytorch-UNet for plant semantic segmentation, featuring an integrated Gradio Web UI for training control, real-time monitoring, inference testing, and dataset exploration — all from the browser.

✨ Highlights

  • 🌿 Plant Segmentation — Pre-configured for 2-class (background + plant) segmentation
  • 🖥️ Gradio Web Interface — Train, predict, and explore data via a browser-based dashboard, no CLI needed
  • AMP & GPU Optimized — Mixed precision training with CUDA acceleration out of the box
  • 📊 Real-time Monitoring — Live training logs streamed to the web UI
  • 📈 Training Visualization — Auto-generates loss and Dice score curves after training
  • 🎯 Interactive Inference — Upload an image, select a checkpoint, and visualize segmentation results instantly

📁 Project Structure

Pytorch-UNet/
├── app.py                 # Gradio Web UI entry point
├── train.py               # Training script (CLI + subprocess backend)
├── predict.py             # CLI prediction script
├── evaluate.py            # Model evaluation (Dice score)
├── plot_training.py       # Training metrics visualization (Loss / Dice curves)
├── start.bat              # One-click launcher for Windows
├── hubconf.py             # PyTorch Hub integration
├── Dockerfile             # Docker support
├── requirements.txt       # Python dependencies
├── unet/
│   ├── unet_model.py      # U-Net model definition
│   └── unet_parts.py      # Building blocks (DoubleConv, Down, Up, etc.)
├── utils/
│   ├── data_loading.py    # Dataset classes (BasicDataset, CarvanaDataset)
│   ├── dice_score.py      # Dice loss & metric
│   └── utils.py           # Helper utilities
├── data/
│   ├── imgs/              # Input images
│   └── masks/             # Ground truth masks
├── doc/                   # Documentation & training curve plots
└── checkpoints/           # Saved model weights (.pth)

🚀 Quick Start

1. Clone & Setup

git clone https://github.com/momo325/Pytorch-UNet.git
cd Pytorch-UNet

# Create virtual environment
python -m venv .venv

# Activate (Windows)
.venv\Scripts\activate

# Activate (Linux/macOS)
source .venv/bin/activate

# Install dependencies (PyTorch must be installed separately)
pip install -r requirements.txt

Note: Install PyTorch with CUDA support before running pip install -r requirements.txt.
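
To confirm that PyTorch is installed and can see a GPU before installing the remaining dependencies, a minimal check along these lines may help. The helper name `describe_torch` is hypothetical and not part of this repository; it relies only on `torch.__version__` and `torch.cuda.is_available()`.

```python
def describe_torch() -> str:
    """Return a one-line summary of the local PyTorch installation."""
    try:
        import torch
    except ImportError:
        return "PyTorch is not installed"
    # is_available() is True only when this is a CUDA build and a usable GPU is visible.
    device = "CUDA available" if torch.cuda.is_available() else "CPU only"
    return f"PyTorch {torch.__version__} ({device})"


if __name__ == "__main__":
    print(describe_torch())
```

If the output reports "CPU only", training should still run, only without GPU acceleration.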

2. Prepare Data

Place your dataset in the following structure:

data/
├── imgs/        # RGB input images
│   ├── 001.png
│   ├── 002.png
│   └── ...
└── masks/       # Corresponding segmentation masks
    ├── 001.png
    ├── 002.png
    └── ...

  • Each image and its mask must share the same filename stem (for example, 001.png in both imgs/ and masks/)
  • Masks should use one distinct RGB value per class (e.g., [0,0,0] for background, [120,200,20] for plant)
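
The filename rule above can be checked before training with a short script. This is a sketch, not part of the repository: `find_unpaired` is a hypothetical helper that compares filename stems in the two folders and skips hidden files such as `.gitkeep`.

```python
from pathlib import Path


def _stems(directory: str) -> set[str]:
    """Collect filename stems of regular, non-hidden files in a directory."""
    return {
        p.stem
        for p in Path(directory).iterdir()
        if p.is_file() and not p.name.startswith(".")
    }


def find_unpaired(img_dir: str, mask_dir: str) -> tuple[list[str], list[str]]:
    """Return (image stems without a mask, mask stems without an image)."""
    img_stems = _stems(img_dir)
    mask_stems = _stems(mask_dir)
    return sorted(img_stems - mask_stems), sorted(mask_stems - img_stems)
```

Calling `find_unpaired("data/imgs", "data/masks")` from the project root should return two empty lists when every image has a matching mask.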

3. Launch the Web UI

Option A: One-click (Windows)

Double-click start.bat in the project root. It automatically activates the virtual environment and opens the browser.

Option B: Command line

python app.py

Open http://127.0.0.1:7860 in your browser.

📖 Usage

Web UI (Recommended)

The Gradio interface provides three tabs:

| Tab | Function |
| --- | --- |
| ⚙️ Training | Configure hyperparameters, start/stop training, view real-time console logs |
| 🖼️ Inference | Upload an image, select a checkpoint, adjust threshold, visualize segmentation output |
| 📂 Dataset Explorer | Randomly sample image-mask pairs to verify data alignment |

Web training workflow:

  1. Set hyperparameters (Epochs, Batch Size, Learning Rate, Scale, AMP toggle) in the Training tab
  2. Click "Start Training" — the right panel shows live training logs with progress bars, loss values, and Dice scores
  3. After training, model weights are automatically saved to checkpoints/
  4. Switch to the Inference tab, select the newly generated checkpoint, and test on your images
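
Outside the Web UI, the "newly generated checkpoint" from step 4 can be located by modification time. The sketch below assumes only that weights are written as `.pth` files into `checkpoints/`, as described above; `latest_checkpoint` is a hypothetical helper, not a function from `app.py`.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(checkpoint_dir: str = "checkpoints") -> Optional[Path]:
    """Return the most recently modified .pth file, or None if none exists."""
    directory = Path(checkpoint_dir)
    if not directory.is_dir():
        return None
    candidates = list(directory.glob("*.pth"))
    if not candidates:
        return None
    # The file with the largest modification timestamp is the newest one.
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

Sorting by modification time avoids depending on a particular checkpoint naming scheme, which this README does not specify.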

CLI Training

python train.py --epochs 5 --batch-size 4 --learning-rate 1e-5 --scale 0.125 --amp

| Argument | Default | Description |
| --- | --- | --- |
| --epochs | 5 | Number of training epochs |
| --batch-size | 1 | Batch size |
| --learning-rate | 1e-5 | Learning rate |
| --scale | 0.5 | Image downscale factor (reduce if running out of VRAM) |
| --validation | 10.0 | Validation split percentage (0-100) |
| --amp | off | Enable mixed precision (recommended) |
| --classes | 2 | Number of output classes |
| --bilinear | off | Use bilinear upsampling instead of transposed conv |
| --load | | Resume from a .pth checkpoint |
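
To make the effect of `--validation` and `--scale` concrete, the arithmetic can be sketched as follows. Both helpers are hypothetical, and the truncation with `int()` is an assumption about how `train.py` and the dataset class round; the actual code may differ in detail.

```python
def split_sizes(n_samples: int, validation_percent: float) -> tuple[int, int]:
    """Return (n_train, n_val) for a percentage-based validation split."""
    if not 0 <= validation_percent <= 100:
        raise ValueError("validation_percent must be between 0 and 100")
    # Assumption: the validation count is truncated, the remainder goes to training.
    n_val = int(n_samples * validation_percent / 100)
    return n_samples - n_val, n_val


def scaled_size(width: int, height: int, scale: float) -> tuple[int, int]:
    """Return the (width, height) an image would have after downscaling."""
    if not 0 < scale <= 1:
        raise ValueError("scale must be in the interval (0, 1]")
    return int(width * scale), int(height * scale)


if __name__ == "__main__":
    print(split_sizes(200, 10.0))          # (180, 20)
    print(scaled_size(1920, 1080, 0.5))    # (960, 540)
    print(scaled_size(1920, 1080, 0.125))  # (240, 135)
```

Halving the scale cuts the pixel count per image to a quarter, which is why lowering `--scale` is an effective response to out-of-memory errors.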

After training, training_log.json is saved to the project root and curve plots are generated in doc/.

Training Metrics Visualization

Plots are auto-generated after training. To regenerate manually:

python plot_training.py

Output:

  • doc/training_loss_curve.png — Training loss vs. step
  • doc/validation_dice_curve.png — Validation Dice score vs. step
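
For readers unfamiliar with the metric on the second plot, the Dice score measures the overlap between a predicted mask P and a ground-truth mask T as 2|P ∩ T| / (|P| + |T|). The sketch below illustrates that standard definition on flat binary lists; it is not the code in `utils/dice_score.py`, and the `epsilon` smoothing term is an assumption added to avoid division by zero.

```python
def dice_coefficient(pred: list[int], target: list[int], epsilon: float = 1e-6) -> float:
    """Dice score for two flat binary masks (0/1 values) of equal length."""
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of pixels")
    # For 0/1 masks, the elementwise product counts pixels set in both masks.
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    return (2 * intersection + epsilon) / (total + epsilon)


if __name__ == "__main__":
    # One overlapping pixel, two predicted, one in the target: roughly 2/3.
    print(round(dice_coefficient([1, 1, 0, 0], [1, 0, 0, 0]), 4))
```

A score of 1.0 means the two masks coincide exactly, and a score near 0 means they barely overlap.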

CLI Prediction

# Single image
python predict.py -i image.jpg -o output.jpg

# Multiple images with visualization
python predict.py -i image1.jpg image2.jpg --viz --no-save

Docker

docker build -t unet-plant .
docker run --rm --gpus all -p 7860:7860 unet-plant python app.py

🔧 Key Modifications from Upstream

Compared to the original milesial/Pytorch-UNet:

| Change | Details |
| --- | --- |
| Gradio Web UI | Added app.py with training control, inference testing, and dataset exploration |
| Training visualization | Added plot_training.py, auto-generates Loss / Dice curves after training |
| Windows one-click launch | Added start.bat for double-click startup |
| Plant segmentation defaults | n_classes=2, hardcoded mask values [0,0,0] / [120,200,20] in data_loading.py |
| DataLoader optimization | Reduced num_workers to 2 for Windows compatibility with high-res images |
| W&B disabled by default | wandb.init(mode='disabled') to avoid login prompts during local development |
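
The hardcoded mask values listed above imply a mapping from RGB colors to class indices. The sketch below illustrates that mapping in plain Python on nested lists; it is not the code in `data_loading.py`, which most likely operates on arrays, and both `PALETTE` and `mask_to_class_indices` are hypothetical names.

```python
# Assumed palette, taken from the mask values quoted in this README.
PALETTE = {
    (0, 0, 0): 0,        # background
    (120, 200, 20): 1,   # plant
}


def mask_to_class_indices(rgb_mask):
    """Convert rows of (R, G, B) pixels into rows of integer class indices."""
    indexed = []
    for y, row in enumerate(rgb_mask):
        indexed_row = []
        for x, pixel in enumerate(row):
            color = tuple(pixel)
            if color not in PALETTE:
                # Fail loudly: an unexpected color usually means a mislabeled mask.
                raise ValueError(f"unexpected mask color {color} at x={x}, y={y}")
            indexed_row.append(PALETTE[color])
        indexed.append(indexed_row)
    return indexed


if __name__ == "__main__":
    mask = [
        [(0, 0, 0), (120, 200, 20)],
        [(120, 200, 20), (0, 0, 0)],
    ]
    print(mask_to_class_indices(mask))  # [[0, 1], [1, 0]]
```

Raising on unknown colors, rather than silently assigning a class, makes it easier to catch masks that were exported with anti-aliasing or a different palette.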

📜 License

This project is distributed under the GPL-3.0 License. See LICENSE for details.
