zerunniu/WaSeCom

WaSeCom: Wasserstein Distributionally Robust Semantic Communication

Introduction

This is a clean implementation of the WaSeCom (Wasserstein Distributionally Robust Wireless Semantic Communication) algorithm. WaSeCom is a bilevel optimization framework for distributionally robust semantic communication, designed to enhance the robustness of wireless semantic communication systems against semantic noise and channel noise.

This version removes the traditional DAE training methods and keeps WaSeCom as the only training approach, which makes the codebase clearer and more concise.

Core Features

  • Bilevel Optimization Framework: Simultaneously handles uncertainty at both semantic and channel layers
  • Wasserstein Distributional Robustness: Provides theoretically guaranteed robustness bounds
  • Model Agnostic: Compatible with existing architectures such as ViT and other Transformer models
  • End-to-End Training: Supports joint optimization from semantic encoder to channel decoder
  • Adversarial Training: Optional FGSM adversarial sample augmentation

Project Structure

wasecom_baseline/
├── models/
│   ├── wasecom_loss.py          # WaSeCom loss function implementation
│   ├── lit_wasecom.py           # WaSeCom Lightning training module
│   ├── dae_vit.py               # Dual-branch DAEViT model
│   ├── encoder.py               # ViT encoder
│   ├── decoder.py               # ViT decoder
│   └── transmitter.py           # Channel encoder/decoder
├── config/
│   ├── conf.py                  # Configuration parameters
│   ├── model_cfg.yaml           # Model configuration
│   └── data_cfg.yaml            # Data configuration
├── data/                        # Data loading modules
├── utils/                       # Utility functions
├── train.py                     # Training script
└── README.md                    # This document

Algorithm Principles

Bilevel Optimization Formulation

Inner Problem (Semantic Robustness):

min_{θ,φ} sup_{Q∈B_p(P̂,ρ)} E_Q[ℓ_s(x, g_φ(d_ω(h·c_ψ(f_θ(x)) + w)))]

Outer Problem (Channel Robustness):

min_{ψ,ω} sup_{R∈B_p(Q̂,σ)} E_R[ℓ_c(f_θ(x), d_ω(z))]

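Where:

  • f_θ: Semantic encoder
  • g_φ: Semantic decoder
  • c_ψ: Channel encoder
  • d_ω: Channel decoder
  • z = h·c_ψ(f_θ(x)) + w: Received signal, with channel coefficient h and additive noise w
  • ℓ_s, ℓ_c: Semantic and channel losses
  • B_p(P, r): p-Wasserstein ball of radius r centered at distribution P
  • P̂, Q̂: Nominal (empirical) distributions at the semantic and channel layers
  • ρ, σ: Wasserstein radii (set by --rho-semantic and --sigma-channel)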
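The sup over a Wasserstein ball is rarely computed directly. Under mild conditions it has the dual form inf over γ ≥ 0 of { γρ + E_P̂[ sup_{x'} ℓ(x') − γ·c(x', x) ] }, so for a fixed multiplier γ the worst case can be approximated per sample by gradient ascent on ℓ(x') − γ·c(x', x). The sketch below illustrates this Lagrangian-penalty approach, a standard Wasserstein DRO technique swapped in here for illustration; this README does not state that models/wasecom_loss.py works this way. It assumes squared Euclidean transport cost and a toy linear model with squared loss so the gradient can be written by hand; all function names are hypothetical.

```python
import numpy as np


def squared_loss(w, x, y):
    """Toy loss l(x) = (w.x - y)^2, standing in for l_s or l_c."""
    r = w @ x - y
    return r * r


def squared_loss_grad_x(w, x, y):
    """Gradient of the toy loss with respect to the input x."""
    return 2.0 * (w @ x - y) * w


def robust_surrogate(w, x, y, gamma=1.0, steps=15, step_size=0.1):
    """Approximate sup_{x'} [ l(x') - gamma * ||x' - x||^2 ] by gradient ascent.

    gamma plays the role of the dual variable of the Wasserstein constraint.
    Returns the worst-case input found and the penalized objective there.
    """
    x_adv = x.copy()
    for _ in range(steps):
        # Ascent direction: loss gradient minus transport-cost gradient
        g = squared_loss_grad_x(w, x_adv, y) - 2.0 * gamma * (x_adv - x)
        x_adv = x_adv + step_size * g
    penalized = squared_loss(w, x_adv, y) - gamma * np.sum((x_adv - x) ** 2)
    return x_adv, penalized


w = np.array([0.5, -0.25])
x = np.array([1.0, 2.0])
x_adv, value = robust_surrogate(w, x, y=1.0, gamma=1.0)
print(squared_loss(w, x, 1.0), squared_loss(w, x_adv, 1.0), value)
```

A larger γ keeps x' closer to x, which corresponds to a smaller effective radius. The ascent is only well behaved when γ exceeds the curvature of the loss along the perturbation (here ||w||²); otherwise the penalized sup is unbounded.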

Quick Start

Install Dependencies

# Create environment with conda
conda env create -f environment.yml
conda activate wasecom

# Or install with pip
pip install torch lightning torchmetrics torchinfo termcolor pyyaml python-box

Basic Training

python train.py --model-name dae_vit_tiny --dataset cifar10 --rho-semantic 0.1 --sigma-channel 0.1 --num-epochs 5 --batch-size 64

Adversarial Training Enhancement

python train.py --model-name dae_vit_small --dataset cifar10 --adversarial-training --fgsm --fgsm-epsilon 0.1 --rho-semantic 0.15 --sigma-channel 0.12 --num-epochs 5

Multi-GPU Training

python train.py --model-name dae_vit_base --dataset cifar10 --devices 3 --accelerator gpu --rho-semantic 0.1 --sigma-channel 0.1 --num-epochs 5

Parameter Description

WaSeCom Core Parameters

| Parameter | Default | Description |
|---|---|---|
| --rho-semantic | 0.1 | Semantic-layer Wasserstein radius (ρ) |
| --sigma-channel | 0.1 | Channel-layer Wasserstein radius (σ) |
| --lambda-reg | 1.0 | Regularization weight |
| --lambda-semantic | 1.0 | Semantic loss weight |
| --lambda-channel | 0.1 | Channel loss weight |
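The table does not spell out how the three λ weights are combined. A plausible reading, assumed here and not confirmed by the source, is a weighted sum of the semantic loss, the channel loss and the two regularization terms the framework logs. `LossWeights`, `total_loss` and their argument names are hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LossWeights:
    """Defaults mirror the --lambda-* flags in the table above."""
    lambda_semantic: float = 1.0
    lambda_channel: float = 0.1
    lambda_reg: float = 1.0


def total_loss(semantic_loss, channel_loss, semantic_reg, channel_reg,
               weights=LossWeights()):
    """Assumed combination: weighted task losses plus weighted regularizers."""
    task = (weights.lambda_semantic * semantic_loss
            + weights.lambda_channel * channel_loss)
    reg = weights.lambda_reg * (semantic_reg + channel_reg)
    return task + reg


print(total_loss(0.5, 2.0, 0.1, 0.2))
```

With the defaults, a unit of channel loss counts one tenth as much as a unit of semantic loss, which matches the tuning note that lambda-channel is typically an order of magnitude smaller than lambda-semantic.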

Optimization Parameters

| Parameter | Default | Description |
|---|---|---|
| --bilevel-steps | 5 | Number of bilevel optimization iterations |
| --semantic-steps | 3 | Number of inner (semantic) optimization steps |
| --channel-steps | 2 | Number of outer (channel) optimization steps |
| --dual-lr-ratio | 0.1 | Dual-variable learning-rate ratio |
| --lr | 0.0005 | Base learning rate |
| --weight-decay | 0.005 | Weight decay |
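How the three step counts nest is not documented here. Assuming each of the --bilevel-steps iterations runs --semantic-steps updates of (θ, φ) followed by --channel-steps updates of (ψ, ω), and that the dual learning rate is --lr scaled by --dual-lr-ratio, the per-batch schedule can be enumerated as follows. `bilevel_schedule` and `dual_learning_rate` are hypothetical helpers that only count updates; they do not train anything.

```python
def bilevel_schedule(bilevel_steps=5, semantic_steps=3, channel_steps=2):
    """List (iteration, block) pairs in the assumed update order."""
    schedule = []
    for it in range(bilevel_steps):
        schedule += [(it, "semantic")] * semantic_steps  # update theta, phi
        schedule += [(it, "channel")] * channel_steps    # update psi, omega
    return schedule


def dual_learning_rate(lr=0.0005, dual_lr_ratio=0.1):
    """Assumed: dual variables use the base learning rate times the ratio."""
    return lr * dual_lr_ratio


steps = bilevel_schedule()
print(len(steps))  # 25 parameter updates per batch with the defaults
print(dual_learning_rate())
```

Under this assumption the per-batch cost grows linearly with --bilevel-steps, which is the trade-off mentioned in the tuning section.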

Data and Model Parameters

| Parameter | Default | Description |
|---|---|---|
| --model-name | Required | Model name (dae_vit_tiny/small/base/large/huge) |
| --dataset | Required | Dataset (cifar10/cifar100/mnist/imagenette) |
| --batch-size | 64 | Batch size |
| --num-epochs | 100 | Number of training epochs |
| --noise-factor | 0.25 | Noise factor |
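The exact corruption controlled by --noise-factor is not described in this README. In denoising-autoencoder code it conventionally scales additive Gaussian noise on the input, followed by clipping to the valid pixel range; the sketch below assumes exactly that, with images in [0, 1]. `add_input_noise` is a hypothetical name.

```python
import numpy as np


def add_input_noise(images, noise_factor=0.25, rng=None):
    """Assumed DAE-style corruption: x + noise_factor * N(0, 1), clipped to [0, 1]."""
    rng = np.random.default_rng() if rng is None else rng
    noisy = images + noise_factor * rng.standard_normal(images.shape)
    return np.clip(noisy, 0.0, 1.0)


rng = np.random.default_rng(0)
batch = rng.random((64, 3, 32, 32))  # a CIFAR-10-sized batch in [0, 1)
noisy = add_input_noise(batch, 0.25, rng)
print(noisy.shape)
```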

Channel Model Parameters

| Parameter | Default | Description |
|---|---|---|
| --channel | rayleigh | Channel model (awgn/rayleigh/rician/none) |
| --snr | 25.0 | Training SNR (dB) |
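The --channel and --snr options correspond to the term h·c_ψ(f_θ(x)) + w in the inner problem. Below is a minimal sketch of the four choices, assuming complex baseband symbols, an independent flat-fading coefficient h per symbol with unit average power, and noise power set from the measured average signal power and the SNR in dB. The Rician K-factor and the function name `apply_channel` are assumptions, not values taken from transmitter.py.

```python
import numpy as np


def apply_channel(s, channel="rayleigh", snr_db=25.0, k_factor=1.0, rng=None):
    """Return z = h * s + w for a complex symbol array s."""
    rng = np.random.default_rng() if rng is None else rng
    if channel == "none":
        return s
    shape = s.shape

    def cn(power):
        # Circularly symmetric complex Gaussian samples with E|.|^2 = power
        return np.sqrt(power / 2) * (rng.standard_normal(shape)
                                     + 1j * rng.standard_normal(shape))

    if channel == "awgn":
        h = np.ones(shape, dtype=complex)
    elif channel == "rayleigh":
        h = cn(1.0)  # E|h|^2 = 1
    elif channel == "rician":
        los = np.sqrt(k_factor / (k_factor + 1))  # line-of-sight component
        h = los + cn(1.0 / (k_factor + 1))        # plus scattered component
    else:
        raise ValueError(f"unknown channel: {channel}")

    signal_power = np.mean(np.abs(s) ** 2)
    noise_power = signal_power / 10 ** (snr_db / 10)
    return h * s + cn(noise_power)


rng = np.random.default_rng(0)
s = np.ones(100_000, dtype=complex)
z = apply_channel(s, "awgn", snr_db=25.0, rng=rng)
print(np.mean(np.abs(z - s) ** 2))  # close to 10 ** -2.5
```

With "awgn" the channel reduces to additive noise (h = 1); the Rayleigh and Rician branches keep E|h|² = 1 so that --snr still describes the average received SNR.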

Adversarial Training Parameters

| Parameter | Default | Description |
|---|---|---|
| --adversarial-training | False | Enable adversarial training |
| --fgsm | False | Enable FGSM attack |
| --fgsm-epsilon | 0.1 | FGSM perturbation strength |
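FGSM (Fast Gradient Sign Method) perturbs each input by ε times the sign of the loss gradient with respect to that input, x_adv = x + ε·sign(∇_x ℓ), where ε corresponds to --fgsm-epsilon. The sketch uses a toy linear model with squared loss so the gradient is explicit; the clipping range [0, 1] and the function names are assumptions.

```python
import numpy as np


def fgsm(x, grad, epsilon=0.1, clip=(0.0, 1.0)):
    """Fast Gradient Sign Method: one signed-gradient step of size epsilon."""
    x_adv = x + epsilon * np.sign(grad)
    return np.clip(x_adv, *clip)  # assumed valid input range


# Toy example: squared loss of a linear model, gradient written by hand
w = np.array([0.8, -0.5, 0.3])
x = np.array([0.4, 0.6, 0.5])
y = 1.0


def loss(v):
    return float((w @ v - y) ** 2)


grad = 2.0 * (w @ x - y) * w
x_adv = fgsm(x, grad, epsilon=0.1)
print(loss(x), loss(x_adv))
```

Each coordinate moves by at most ε, so ε bounds the perturbation in the L∞ norm; in adversarial training the model is then also trained on x_adv.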

Experimental Results

Evaluation Metrics

The WaSeCom framework automatically logs the following metrics:

  1. Reconstruction Quality:

    • PSNR (Peak Signal-to-Noise Ratio)
    • SSIM (Structural Similarity Index)
    • MSE Loss
  2. Robustness Metrics:

    • Semantic regularization loss
    • Channel regularization loss
    • Performance retention across different SNRs
  3. Training Monitoring:

    • Semantic encoder loss
    • Channel encoder loss
    • Dual variable evolution
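PSNR is derived from the MSE as PSNR = 10·log10(MAX² / MSE), where MAX is the largest possible pixel value (1.0 for images scaled to [0, 1], 255 for 8-bit images). A minimal stand-alone version for checking logged values by hand (the repository lists torchmetrics as a dependency, but which metric implementation it uses is not stated here):

```python
import math


def psnr(mse, max_value=1.0):
    """Peak signal-to-noise ratio in dB for a given mean squared error."""
    if mse <= 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


print(psnr(0.01))  # 20.0
```

Each tenfold reduction in MSE adds 10 dB, so small PSNR differences at high quality correspond to sizable relative MSE changes.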

Performance Comparison

Compared to traditional DAE training, WaSeCom provides:

  • Better Semantic Robustness: More stable against input perturbations and distribution drift
  • Stronger Channel Adaptability: More consistent performance across different SNR and channel conditions
  • Theoretical Guarantees: Provides theoretical bounds for distributional robustness

Hyperparameter Tuning Recommendations

Initial Settings

For most scenarios, the following initial parameters are recommended:

--rho-semantic 0.1 \
--sigma-channel 0.1 \
--lambda-reg 1.0 \
--lambda-semantic 1.0 \
--lambda-channel 0.1 \
--bilevel-steps 5 \
--semantic-steps 3 \
--channel-steps 2

Tuning Strategy

  1. Wasserstein Radius Tuning:

    • rho-semantic too large: Over-regularization, performance degradation
    • rho-semantic too small: Insufficient robustness
    • Recommended range: [0.001, 0.5]
    • sigma-channel follows similar principles
  2. Loss Weight Tuning:

    • lambda-channel is typically one order of magnitude smaller than lambda-semantic
    • lambda-reg controls overall regularization strength
  3. Optimization Steps Tuning:

    • Increasing bilevel-steps improves convergence quality but increases computational cost
    • Setting semantic-steps greater than channel-steps usually works better
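Because the recommended range for rho-semantic spans more than two orders of magnitude, a log-spaced grid is a natural way to sweep it. The sketch below only builds command lines from the training flags documented above; the grid size, the fixed sigma-channel value and the helper names `log_grid` and `sweep_commands` are illustrative assumptions.

```python
import math
import shlex


def log_grid(low=0.001, high=0.5, n=6):
    """n values spaced evenly in log10 between low and high (inclusive)."""
    a, b = math.log10(low), math.log10(high)
    return [10 ** (a + i * (b - a) / (n - 1)) for i in range(n)]


def sweep_commands(n=6, sigma_channel=0.1):
    """One train.py command per rho-semantic value, sigma-channel held fixed."""
    base = ["python", "train.py", "--model-name", "dae_vit_tiny",
            "--dataset", "cifar10"]
    cmds = []
    for rho in log_grid(n=n):
        args = base + ["--rho-semantic", f"{rho:.4g}",
                       "--sigma-channel", str(sigma_channel)]
        cmds.append(shlex.join(args))
    return cmds


for cmd in sweep_commands():
    print(cmd)
```

Once a good rho-semantic is found, the same grid can be reused for sigma-channel, which follows similar principles.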

Troubleshooting

Out of Memory

  • Reduce --batch-size
  • Reduce --bilevel-steps
  • Use a smaller model (dae_vit_tiny)

Training Instability

  • Lower learning rate
  • Reduce Wasserstein radii
  • Enable gradient clipping (disabled by default)
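This README does not name the train.py option that turns gradient clipping on. In PyTorch Lightning, clipping is normally configured on the Trainer (for example `gradient_clip_val` together with `gradient_clip_algorithm="norm"`). What global-norm clipping does to a set of gradients is sketched below in NumPy; `clip_by_global_norm` is a hypothetical name.

```python
import numpy as np


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale all gradients by one factor so their joint L2 norm is <= max_norm."""
    total = np.sqrt(sum(float(np.sum(g * g)) for g in grads))
    scale = min(1.0, max_norm / (total + eps))
    return [g * scale for g in grads], total


grads = [np.array([3.0, 0.0]), np.array([0.0, 4.0])]  # joint norm 5
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
print(norm_before, clipped)
```

Because every gradient is scaled by the same factor, the update direction is preserved and only its length is capped, which damps the occasional large steps that destabilize min-max training.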

Slow Convergence

  • Increase --bilevel-steps
  • Adjust --dual-lr-ratio
  • Check data preprocessing
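--dual-lr-ratio scales the step size of the dual variables. For a Wasserstein constraint, the dual objective γρ + E[ sup_{x'} ℓ(x') − γ·c(x', x) ] has derivative ρ minus the average transport cost of the worst-case points with respect to γ, so a projected gradient step raises γ when perturbations travel farther than ρ and lowers it otherwise. Whether LitWaSeCom updates its dual variables exactly this way is an assumption; `dual_step` is a hypothetical helper.

```python
def dual_step(gamma, transport_cost, rho=0.1, lr=0.0005, dual_lr_ratio=0.1):
    """One projected gradient step on gamma >= 0 for the Wasserstein dual.

    transport_cost: average c(x*, x) of the current worst-case points.
    """
    eta = lr * dual_lr_ratio
    grad = rho - transport_cost          # d/dgamma of the dual objective
    return max(0.0, gamma - eta * grad)  # descend, then project onto gamma >= 0


gamma = 1.0
gamma = dual_step(gamma, transport_cost=0.3)  # cost exceeds rho, so gamma grows
print(gamma)
```

A larger --dual-lr-ratio lets γ track the radius constraint faster, at the price of noisier dual updates.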

Visualization and Logging

TensorBoard

# Start training
python train.py --logger-backend tensorboard [other parameters]

# View logs
tensorboard --logdir logs --bind_all

CSV Logging

python train.py --logger-backend csv [other parameters]

# Logs saved in logs/version_*/metrics.csv
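In a Lightning-style metrics.csv, metrics logged at different points end up on different rows, leaving the other columns empty. The standard-library reader below picks the newest logs/version_*/metrics.csv by version number (so version_10 beats version_9) and skips empty cells. The metric names val_psnr and train_loss in the demo are hypothetical, so check the header of your own file for the actual column names; the demo writes to a temporary directory instead of logs.

```python
import csv
import glob
import os
import tempfile


def _version_number(path):
    """Extract N from .../version_N/metrics.csv."""
    return int(os.path.basename(os.path.dirname(path)).rsplit("_", 1)[-1])


def latest_metrics_file(log_dir="logs"):
    """Return the metrics.csv of the highest-numbered version_* folder."""
    files = glob.glob(os.path.join(log_dir, "version_*", "metrics.csv"))
    if not files:
        raise FileNotFoundError(f"no metrics.csv under {log_dir}/version_*")
    return max(files, key=_version_number)


def read_metric(path, name):
    """Collect (step, value) pairs for one metric, skipping empty cells."""
    values = []
    with open(path, newline="") as f:
        for row in csv.DictReader(f):
            cell = row.get(name, "")
            if cell != "":
                values.append((int(row["step"]), float(cell)))
    return values


# Self-contained demo in a throwaway directory (use "logs" in practice)
demo = tempfile.mkdtemp()
for version, step in [(0, 4), (10, 9)]:
    os.makedirs(os.path.join(demo, f"version_{version}"))
    with open(os.path.join(demo, f"version_{version}", "metrics.csv"), "w", newline="") as f:
        f.write(f"epoch,step,train_loss,val_psnr\n0,{step},0.5,\n0,{step},,21.5\n")
path = latest_metrics_file(demo)
print(read_metric(path, "val_psnr"))
```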

Extensions and Customization

Adding New Loss Functions

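Subclass WaSeComLoss from models/wasecom_loss.py:

from models.wasecom_loss import WaSeComLoss

class CustomWaSeComLoss(WaSeComLoss):
    def __init__(self, custom_weight=0.01, **kwargs):
        super().__init__(**kwargs)
        # Add custom parameters (custom_weight is an illustrative example)
        self.custom_weight = custom_weight

    def semantic_regularizer(self, semantic_features, **kwargs):
        # Custom semantic regularization; return a scalar loss term.
        # Illustrative placeholder: L2 penalty on the semantic features
        return self.custom_weight * semantic_features.pow(2).mean()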

Supporting New Model Architectures

  1. Ensure the model's forward method accepts a return_features=True argument
  2. Return a dictionary containing required intermediate features
  3. Adjust parameter grouping in LitWaSeCom
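The feature keys LitWaSeCom expects are not listed in this README. Purely as an illustration of the contract in steps 1 and 2, the toy model below returns only the reconstruction by default and a dictionary of intermediate outputs when return_features=True. The class name, the key names (semantic_features, received, reconstruction) and the identity channel codec are hypothetical.

```python
import numpy as np


class ToySemanticModel:
    """Illustrative stand-in for a WaSeCom-compatible model (not DAEViT)."""

    def __init__(self, dim=8, latent=4, seed=0):
        rng = np.random.default_rng(seed)
        self.enc = rng.standard_normal((latent, dim)) / np.sqrt(dim)     # plays f_theta
        self.dec = rng.standard_normal((dim, latent)) / np.sqrt(latent)  # plays g_phi

    def forward(self, x, noise_std=0.0, return_features=False, rng=None):
        rng = np.random.default_rng() if rng is None else rng
        feats = self.enc @ x  # semantic features
        # Identity channel encoder/decoder plus additive Gaussian noise
        received = feats + noise_std * rng.standard_normal(feats.shape)
        recon = self.dec @ received
        if not return_features:
            return recon
        return {
            "semantic_features": feats,
            "received": received,
            "reconstruction": recon,
        }


model = ToySemanticModel()
out = model.forward(np.ones(8), return_features=True)
print(sorted(out))
```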

Citation

If you use this implementation, please cite the original paper:

@article{wasecom2025,
  title={Distributionally Robust Wireless Semantic Communication with Large AI Models},
  journal={IEEE Journal on Selected Areas in Communications},
  year={2025}
}

License

This project is released under the MIT License.

Contact

For questions or suggestions, please submit an Issue or Pull Request.
