This is a clean implementation of the WaSeCom (Wasserstein Distributionally Robust Wireless Semantic Communication) algorithm. WaSeCom is a bilevel optimization framework for distributionally robust semantic communication, designed to enhance the robustness of wireless semantic communication systems against semantic noise and channel noise.
This version drops the traditional DAE training methods and uses WaSeCom as the sole training approach, which keeps the codebase clearer and more concise.
- ✅ Bilevel Optimization Framework: Simultaneously handles uncertainty at both semantic and channel layers
- ✅ Wasserstein Distributional Robustness: Provides theoretically guaranteed robustness bounds
- ✅ Model Agnostic: Compatible with existing architectures like ViT/Transformer
- ✅ End-to-End Training: Supports joint optimization from semantic encoder to channel decoder
- ✅ Adversarial Training: Optional FGSM adversarial sample augmentation
wasecom_baseline/
├── models/
│ ├── wasecom_loss.py # WaSeCom loss function implementation
│ ├── lit_wasecom.py # WaSeCom Lightning training module
│ ├── dae_vit.py # Dual-branch DAEViT model
│ ├── encoder.py # ViT encoder
│ ├── decoder.py # ViT decoder
│ └── transmitter.py # Channel encoder/decoder
├── config/
│ ├── conf.py # Configuration parameters
│ ├── model_cfg.yaml # Model configuration
│ └── data_cfg.yaml # Data configuration
├── data/ # Data loading modules
├── utils/ # Utility functions
├── train.py # Training script
└── README.md # This document
Inner Problem (Semantic Robustness):
min_{θ,φ} sup_{Q∈B_p(P̂,ρ)} E_Q[ℓ_s(x, g_φ(d_ω(h·c_ψ(f_θ(x)) + w)))]
Outer Problem (Channel Robustness):
min_{ψ,ω} sup_{R∈B_p(Q̂,σ)} E_R[ℓ_c(f_θ(x), d_ω(z))]
Where:
- f_θ: Semantic encoder
- g_φ: Semantic decoder
- c_ψ: Channel encoder
- d_ω: Channel decoder
- B_p(·,·): Wasserstein ball
- ρ, σ: Wasserstein radii
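To make the `sup` over a Wasserstein ball concrete, the sketch below covers only the simplest case: p = 1 and a loss that is L-Lipschitz in its input, where Kantorovich-Rubinstein duality gives sup_{Q∈B_1(P̂,ρ)} E_Q[ℓ] ≤ E_P̂[ℓ] + ρ·L. The toy loss and the function names are illustrative assumptions and are not part of this repository; the loss in `models/wasecom_loss.py` may use a different relaxation.

```python
from statistics import fmean


def empirical_risk(samples, loss):
    """Average loss under the empirical distribution P-hat."""
    return fmean(loss(x) for x in samples)


def w1_robust_upper_bound(samples, loss, lipschitz_const, rho):
    """Upper bound on the worst-case risk over a Wasserstein-1 ball of radius rho.

    Only valid when `loss` is lipschitz_const-Lipschitz in its input.
    """
    return empirical_risk(samples, loss) + rho * lipschitz_const


def toy_loss(x):
    """Hypothetical absolute-error loss |2x - 1|, which is 2-Lipschitz in x."""
    return abs(2.0 * x - 1.0)


samples = [0.0, 0.5, 1.0, 1.5]
nominal = empirical_risk(samples, toy_loss)
robust = w1_robust_upper_bound(samples, toy_loss, lipschitz_const=2.0, rho=0.1)

# Shifting every sample by rho yields a distribution inside the ball,
# so its risk cannot exceed the robust bound.
shifted_risk = empirical_risk([x + 0.1 for x in samples], toy_loss)
print(f"nominal={nominal:.3f} shifted={shifted_risk:.3f} bound={robust:.3f}")
```

A larger ρ widens the gap between the nominal risk and the bound, which is the trade-off the radius parameters control.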
# Create environment with conda
conda env create -f environment.yml
conda activate wasecom
# Or install with pip
pip install torch lightning torchmetrics torchinfo termcolor pyyaml python-box

# Basic training
python train.py --model-name dae_vit_tiny --dataset cifar10 --rho-semantic 0.1 --sigma-channel 0.1 --num-epochs 5 --batch-size 64

# Training with FGSM adversarial augmentation
python train.py --model-name dae_vit_small --dataset cifar10 --adversarial-training --fgsm --fgsm-epsilon 0.1 --rho-semantic 0.15 --sigma-channel 0.12 --num-epochs 5

# Multi-GPU training
python train.py --model-name dae_vit_base --dataset cifar10 --devices 3 --accelerator gpu --rho-semantic 0.1 --sigma-channel 0.1 --num-epochs 5

| Parameter | Default | Description |
|---|---|---|
| `--rho-semantic` | 0.1 | Semantic layer Wasserstein radius (ρ) |
| `--sigma-channel` | 0.1 | Channel layer Wasserstein radius (σ) |
| `--lambda-reg` | 1.0 | Regularization weight |
| `--lambda-semantic` | 1.0 | Semantic loss weight |
| `--lambda-channel` | 0.1 | Channel loss weight |

| Parameter | Default | Description |
|---|---|---|
| `--bilevel-steps` | 5 | Number of bilevel optimization iterations |
| `--semantic-steps` | 3 | Number of inner semantic optimization steps |
| `--channel-steps` | 2 | Number of outer channel optimization steps |
| `--dual-lr-ratio` | 0.1 | Dual variable learning rate ratio |
| `--lr` | 0.0005 | Base learning rate |
| `--weight-decay` | 0.005 | Weight decay |
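
The step-count flags above multiply together, so they dominate the per-batch cost. The sketch below shows one plausible way they could interact: each bilevel iteration runs the semantic steps, then the channel steps, then one dual-variable update at a scaled learning rate. The ordering, the single dual update, and the function name are assumptions made for illustration; the actual loop lives in `models/lit_wasecom.py` and may differ.

```python
from collections import Counter


def bilevel_schedule(bilevel_steps=5, semantic_steps=3, channel_steps=2,
                     lr=0.0005, dual_lr_ratio=0.1):
    """Return the ordered (phase, learning_rate) updates for one batch.

    Assumed structure, not taken from the repository.
    """
    dual_lr = lr * dual_lr_ratio  # assumed meaning of --dual-lr-ratio
    schedule = []
    for _ in range(bilevel_steps):
        schedule += [("semantic", lr)] * semantic_steps  # inner problem
        schedule += [("channel", lr)] * channel_steps    # outer problem
        schedule.append(("dual", dual_lr))               # dual variable update
    return schedule


schedule = bilevel_schedule()
counts = Counter(phase for phase, _ in schedule)
print(dict(counts), "total updates per batch:", len(schedule))
```

Under these assumptions the defaults give 30 updates per batch, which is why reducing `--bilevel-steps` is an effective way to cut compute.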

| Parameter | Default | Description |
|---|---|---|
| `--model-name` | Required | Model name (dae_vit_tiny/small/base/large/huge) |
| `--dataset` | Required | Dataset (cifar10/cifar100/mnist/imagenette) |
| `--batch-size` | 64 | Batch size |
| `--num-epochs` | 100 | Number of training epochs |
| `--noise-factor` | 0.25 | Noise factor |

| Parameter | Default | Description |
|---|---|---|
| `--channel` | rayleigh | Channel model (awgn/rayleigh/rician/none) |
| `--snr` | 25.0 | Training SNR (dB) |
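
As a rough illustration of what `--channel` and `--snr` control, the sketch below applies the `h·c + w` model from the formulation to a list of complex symbols. It is a standalone, standard-library approximation and not the code in `models/transmitter.py`; the function names are hypothetical, and the Rician case is left out because it needs a K-factor that this document does not specify.

```python
import math
import random


def noise_std_from_snr(signal_power, snr_db):
    """Per-component std of complex Gaussian noise for a target SNR in dB."""
    noise_power = signal_power / (10.0 ** (snr_db / 10.0))
    return math.sqrt(noise_power / 2.0)  # split across real and imaginary parts


def apply_channel(symbols, snr_db=25.0, channel="rayleigh", rng=None):
    """Return h * c + w for each complex symbol c."""
    if channel == "none":
        return list(symbols)
    if channel not in ("awgn", "rayleigh"):
        raise ValueError(f"unsupported channel in this sketch: {channel}")
    if rng is None:
        rng = random.Random()
    signal_power = sum(abs(s) ** 2 for s in symbols) / len(symbols)
    sigma = noise_std_from_snr(signal_power, snr_db)
    received = []
    for s in symbols:
        if channel == "rayleigh":
            # Complex Gaussian fading coefficient with unit average power
            h = complex(rng.gauss(0.0, math.sqrt(0.5)), rng.gauss(0.0, math.sqrt(0.5)))
        else:
            h = 1.0  # AWGN: no fading
        w = complex(rng.gauss(0.0, sigma), rng.gauss(0.0, sigma))
        received.append(h * s + w)
    return received


rx = apply_channel([1 + 0j, -1 + 0j, 1j, -1j], snr_db=25.0, rng=random.Random(0))
print(rx)
```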

| Parameter | Default | Description |
|---|---|---|
| `--adversarial-training` | False | Enable adversarial training |
| `--fgsm` | False | Enable FGSM attack |
| `--fgsm-epsilon` | 0.1 | FGSM perturbation strength |
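
For readers unfamiliar with FGSM, the perturbation is `x_adv = x + ε · sign(∇_x loss)`, where ε corresponds to `--fgsm-epsilon`. The sketch below applies that rule to a plain list with a hand-computed gradient. The squared-error toy loss and the clipping to [0, 1] are assumptions for illustration; the repository computes gradients through the model and may normalize inputs differently.

```python
def sign(value):
    """Return -1, 0, or 1 according to the sign of value."""
    return (value > 0) - (value < 0)


def fgsm_perturb(x, grad, epsilon=0.1, clip_min=0.0, clip_max=1.0):
    """One FGSM step: move each input by epsilon in the loss-increasing direction."""
    return [
        min(clip_max, max(clip_min, xi + epsilon * sign(gi)))
        for xi, gi in zip(x, grad)
    ]


# Toy loss sum((x_i - t_i)^2), whose gradient with respect to x is 2 * (x - t).
x = [0.2, 0.5, 0.95]
target = [0.5, 0.5, 0.0]
grad = [2.0 * (xi - ti) for xi, ti in zip(x, target)]

x_adv = fgsm_perturb(x, grad, epsilon=0.1)
print(x_adv)
```

The second element is left unchanged because its gradient is zero, and the third is clipped at the upper bound.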
The WaSeCom framework automatically logs the following metrics:
- Reconstruction Quality:
  - PSNR (Peak Signal-to-Noise Ratio)
  - SSIM (Structural Similarity Index)
  - MSE Loss
- Robustness Metrics:
  - Semantic regularization loss
  - Channel regularization loss
  - Performance retention across different SNRs
- Training Monitoring:
  - Semantic encoder loss
  - Channel encoder loss
  - Dual variable evolution
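
PSNR and MSE are directly related, which helps when reading the logs: PSNR = 10 · log10(MAX² / MSE). The sketch below computes both on flat lists of pixel values, assuming pixels are scaled to [0, 1]. It is a standard-library illustration only; the repository lists `torchmetrics` as a dependency and presumably uses it for the logged values.

```python
import math


def mse(original, reconstructed):
    """Mean squared error between two equal-length sequences."""
    return sum((a - b) ** 2 for a, b in zip(original, reconstructed)) / len(original)


def psnr(original, reconstructed, max_value=1.0):
    """Peak signal-to-noise ratio in dB; infinite for a perfect reconstruction."""
    err = mse(original, reconstructed)
    if err == 0:
        return math.inf
    return 10.0 * math.log10(max_value ** 2 / err)


original = [0.0, 0.5, 1.0]
reconstructed = [0.1, 0.6, 0.9]  # every pixel is off by 0.1
print(f"MSE={mse(original, reconstructed):.4f} PSNR={psnr(original, reconstructed):.2f} dB")
```

Each 10x reduction in MSE adds 10 dB of PSNR, so small PSNR differences at high quality correspond to very small absolute errors.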
Compared to traditional DAE training, WaSeCom provides:
- ✅ Better Semantic Robustness: More stable against input perturbations and distribution drift
- ✅ Stronger Channel Adaptability: More consistent performance across different SNR and channel conditions
- ✅ Theoretical Guarantees: Provides theoretical bounds for distributional robustness
For most scenarios, the following initial parameters are recommended:
--rho-semantic 0.1 \
--sigma-channel 0.1 \
--lambda-reg 1.0 \
--lambda-semantic 1.0 \
--lambda-channel 0.1 \
--bilevel-steps 5 \
--semantic-steps 3 \
--channel-steps 2

- Wasserstein Radius Tuning:
  - `rho-semantic` too large: over-regularization, performance degradation
  - `rho-semantic` too small: insufficient robustness
  - Recommended range: [0.001, 0.5]
  - `sigma-channel` follows similar principles
- Loss Weight Tuning:
  - `lambda-channel` is typically one order of magnitude smaller than `lambda-semantic`
  - `lambda-reg` controls overall regularization strength
- Optimization Steps Tuning:
  - Increasing `bilevel-steps` improves convergence quality but increases computational cost
  - Usually `semantic-steps > channel-steps` works better
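
Because the recommended radius range spans more than two orders of magnitude, a log-spaced sweep is usually a more efficient way to search it than a linear one. The sketch below generates such a grid and the matching `train.py` commands using only flags documented above. The helper names are hypothetical, and setting `--sigma-channel` equal to `--rho-semantic` is a simplifying assumption.

```python
def log_spaced(low, high, num):
    """Return num values spaced evenly on a log scale from low to high."""
    ratio = (high / low) ** (1.0 / (num - 1))
    return [low * ratio ** i for i in range(num)]


def sweep_commands(model_name, dataset, radii):
    """Build one training command per radius value."""
    return [
        f"python train.py --model-name {model_name} --dataset {dataset} "
        f"--rho-semantic {rho:.4g} --sigma-channel {rho:.4g}"
        for rho in radii
    ]


radii = log_spaced(0.001, 0.5, 5)
for command in sweep_commands("dae_vit_tiny", "cifar10", radii):
    print(command)
```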

- Out of memory:
  - Reduce `--batch-size`
  - Reduce `--bilevel-steps`
  - Use a smaller model (`dae_vit_tiny`)
- Unstable training:
  - Lower learning rate
  - Reduce Wasserstein radii
  - Enable gradient clipping (disabled by default)
- Poor convergence:
  - Increase `--bilevel-steps`
  - Adjust `--dual-lr-ratio`
  - Check data preprocessing
# Start training
python train.py --logger-backend tensorboard [other parameters]
# View logs
tensorboard --logdir logs --bind_all

python train.py --logger-backend csv [other parameters]
# Logs saved in logs/version_*/metrics.csv

Inherit from the `WaSeComLoss` class in `models/wasecom_loss.py`:
# Assumes the repository root is on the import path
from models.wasecom_loss import WaSeComLoss

class CustomWaSeComLoss(WaSeComLoss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Add custom parameters here

    def semantic_regularizer(self, semantic_features, **kwargs):
        # Custom semantic regularization goes here
        pass

- Ensure the model supports the `return_features=True` parameter
- Return a dictionary containing the required intermediate features
- Adjust parameter grouping in `LitWaSeCom`
If you use this implementation, please cite the original paper:
@article{wasecom2025,
title={Distributionally Robust Wireless Semantic Communication with Large AI Models},
journal={IEEE Journal on Selected Areas in Communications},
year={2025}
}

This project follows the MIT License.
For questions or suggestions, please submit an Issue or Pull Request.