From 54fdc400379da5c2692d629c3d9c0e6b66c9429f Mon Sep 17 00:00:00 2001 From: Cameron Mattson Date: Mon, 8 Jun 2026 13:21:29 -0600 Subject: [PATCH 1/8] This branch is for the cross zamirski wasserstein GAN GP models From 2d0814f3c112cdb4c3eac817f6b07ad05469345e Mon Sep 17 00:00:00 2001 From: Cameron Mattson Date: Mon, 8 Jun 2026 13:35:03 -0600 Subject: [PATCH 2/8] Add unconditional WGAN-GP training stack --- callbacks/batch_logging.py | 26 ++- .../WassersteinGeneratorCrossZamirskiLoss.py | 29 ++-- losses/WassersteinGradientPenaltyLoss.py | 30 ++-- models/unconditional_critic.py | 29 ++++ train.ipynb | 78 ++++++--- train.py | 83 ++++++--- trainers/WGANGPTrainer.py | 160 ++++++++++++++++++ trainers/utils/__init__.py | 6 + trainers/utils/wgan_gp.py | 151 +++++++++++++++++ 9 files changed, 516 insertions(+), 76 deletions(-) create mode 100644 models/unconditional_critic.py create mode 100644 trainers/WGANGPTrainer.py create mode 100644 trainers/utils/__init__.py create mode 100644 trainers/utils/wgan_gp.py diff --git a/callbacks/batch_logging.py b/callbacks/batch_logging.py index 24758fb..b711f51 100644 --- a/callbacks/batch_logging.py +++ b/callbacks/batch_logging.py @@ -40,13 +40,23 @@ def on_batch_end(self, hook_data: dict[str, Any]) -> None: self.global_batch_step += 1 return - batch_loss_components = hook_data.get("batch_loss_components", {}) - batch_loss_name = hook_data.get("batch_loss_name", "unknown_loss") - for loss_name, loss_value in batch_loss_components.items(): - mlflow.log_metric( - f"batch/train/{batch_loss_name}/{loss_name}", - loss_value, - step=self.global_batch_step, - ) + batch_loss_groups = hook_data.get("batch_loss_groups") + if batch_loss_groups is not None: + for group_name, batch_loss_components in batch_loss_groups.items(): + for loss_name, loss_value in batch_loss_components.items(): + mlflow.log_metric( + f"batch/train/{group_name}/{loss_name}", + loss_value, + step=self.global_batch_step, + ) + else: + batch_loss_components = hook_data.get("batch_loss_components", {}) + batch_loss_name = hook_data.get("batch_loss_name", "unknown_loss") + for loss_name, loss_value in batch_loss_components.items(): + mlflow.log_metric( + f"batch/train/{batch_loss_name}/{loss_name}", + loss_value, + step=self.global_batch_step, + ) self.global_batch_step += 1 diff --git a/losses/WassersteinGeneratorCrossZamirskiLoss.py b/losses/WassersteinGeneratorCrossZamirskiLoss.py index 2b7036a..9265228 100644 --- a/losses/WassersteinGeneratorCrossZamirskiLoss.py +++ b/losses/WassersteinGeneratorCrossZamirskiLoss.py @@ -1,13 +1,19 @@ import torch from torch import nn +from trainers.utils.wgan_gp import compute_generator_components + class WassersteinGeneratorCrossZamirskiLoss(nn.Module): """Generator loss combining L1 reconstruction and Wasserstein term.""" loss_name = "wasserstein_generator" # Stable MLflow namespace for this loss family. - def __init__(self, reconstruction_importance: float = 100.0) -> None: + def __init__( + self, + reconstruction_importance: float = 100.0, + use_adversarial_decay: bool = True, + ) -> None: """Configure weighting for the reconstruction component. Args: @@ -16,6 +22,7 @@ def __init__(self, reconstruction_importance: float = 100.0) -> None: super().__init__() self.reconstruction_importance = reconstruction_importance + self.use_adversarial_decay = use_adversarial_decay def forward( self, @@ -43,19 +50,15 @@ def forward( ValueError: If critic output batch size does not match predictions batch size. """ - if generated_predictions.shape != targets.shape: - raise ValueError("generated_predictions and targets must have the same shape.") - - batch_size = generated_predictions.size(0) - if fake_classification_outputs.size(0) != batch_size: - raise ValueError( - "fake_classification_outputs batch size must match generated_predictions." - ) - - reconstruction_loss = torch.nn.functional.l1_loss( - generated_predictions, targets, reduction="mean" + components = compute_generator_components( + fake_classification_outputs=fake_classification_outputs, + generated_predictions=generated_predictions, + targets=targets, + epoch=epoch, + use_adversarial_decay=self.use_adversarial_decay, ) - adversarial_term = torch.mean(fake_classification_outputs) / (epoch + 1) + reconstruction_loss = components["reconstruction_term"] + adversarial_term = components["adversarial_term"] total = self.reconstruction_importance * reconstruction_loss - adversarial_term return { "total": total, diff --git a/losses/WassersteinGradientPenaltyLoss.py b/losses/WassersteinGradientPenaltyLoss.py index 6c9371f..7541ca0 100644 --- a/losses/WassersteinGradientPenaltyLoss.py +++ b/losses/WassersteinGradientPenaltyLoss.py @@ -1,6 +1,8 @@ import torch from torch import nn +from trainers.utils.wgan_gp import compute_wgan_gp_terms + class WassersteinGradientPenaltyLoss(nn.Module): """WGAN-GP loss wrapper with trainer-compatible call signature.""" @@ -19,17 +21,17 @@ def __init__(self, gradient_penalty_importance: float = 10.0) -> None: def forward( self, - gradients: torch.Tensor, - real_classification_outputs: torch.Tensor, - fake_classification_outputs: torch.Tensor, + critic: nn.Module, + real_samples: torch.Tensor, + fake_samples: torch.Tensor, **kwargs, ) -> dict[str, torch.Tensor]: """Compute critic loss with Wasserstein distance and gradient penalty. Args: - gradients: Gradients of critic outputs w.r.t. interpolated inputs. - real_classification_outputs: Critic outputs for real samples. - fake_classification_outputs: Critic outputs for generated samples. + critic: Critic network used to score real/fake samples. + real_samples: Real target samples. + fake_samples: Generated samples. **kwargs: Additional unused loss arguments. Returns: @@ -41,17 +43,13 @@ def forward( ValueError: If fake critic output batch size does not match gradients batch size. """ - batch_size = gradients.size(0) - if real_classification_outputs.size(0) != batch_size: - raise ValueError("real_classification_outputs batch size must match gradients.") - if fake_classification_outputs.size(0) != batch_size: - raise ValueError("fake_classification_outputs batch size must match gradients.") - - gradients = gradients.view(batch_size, -1) - gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() - wasserstein = torch.mean(fake_classification_outputs) - torch.mean( - real_classification_outputs + components = compute_wgan_gp_terms( + critic=critic, + real_samples=real_samples, + fake_samples=fake_samples, ) + gradient_penalty = components["gradient_penalty_unweighted"] + wasserstein = components["wasserstein_term"] total = wasserstein + gradient_penalty * self.gradient_penalty_importance return { "total": total, diff --git a/models/unconditional_critic.py b/models/unconditional_critic.py new file mode 100644 index 0000000..776949e --- /dev/null +++ b/models/unconditional_critic.py @@ -0,0 +1,29 @@ +import torch +from torch import nn + + +class UnconditionalCritic(nn.Module): + """Small convolutional critic for single-channel image scoring.""" + + def __init__(self, in_channels: int = 1, base_channels: int = 64) -> None: + super().__init__() + self.layers = nn.Sequential( + nn.Conv2d(in_channels, base_channels, kernel_size=4, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(base_channels, base_channels * 2, kernel_size=4, stride=2, padding=1), + nn.InstanceNorm2d(base_channels * 2, affine=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d( + base_channels * 2, + base_channels * 4, + kernel_size=4, + stride=2, + padding=1, + ), + nn.InstanceNorm2d(base_channels * 4, affine=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(base_channels * 4, 1, kernel_size=4, stride=1, padding=1), + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return self.layers(inputs) diff --git a/train.ipynb b/train.ipynb index 61fb833..b817a95 100644 --- a/train.ipynb +++ b/train.ipynb @@ -17,6 +17,7 @@ "import numpy as np\n", "import optuna\n", "import torch\n", + "from models.unconditional_critic import UnconditionalCritic\n", "from models.convnext_unet.unext import ConvNeXtUNet\n", "\n", "from callbacks.CallbackPipeline import CallbackPipeline\n", @@ -29,14 +30,17 @@ ")\n", "from datasets.dataset_00.utils.ImagePostProcessor import ImagePostProcessor\n", "from datasets.dataset_00.utils.ImagePreProcessor import ImagePreProcessor\n", - "from losses.L1Loss import L1Loss\n", + "from losses.WassersteinGeneratorCrossZamirskiLoss import (\n", + " WassersteinGeneratorCrossZamirskiLoss,\n", + ")\n", + "from losses.WassersteinGradientPenaltyLoss import WassersteinGradientPenaltyLoss\n", "from metrics.L1 import L1\n", "from metrics.L2 import L2\n", "from metrics.PearsonCorrelation import PearsonCorrelation\n", "from metrics.PSNR import PSNR\n", "from metrics.SSIM import SSIM\n", "from splitters.HashSplitter import HashSplitter\n", - "from trainers.UNetTrainer import UNetTrainer" + "from trainers.WGANGPTrainer import WGANGPTrainer" ] }, { @@ -131,36 +135,57 @@ " hash_splitter: Any,\n", " dataset: Any,\n", " callbacks_args: dict[str, Any],\n", - " model_factory: Callable[[], torch.nn.Module],\n", + " generator_factory: Callable[[], torch.nn.Module],\n", + " discriminator_factory: Callable[[], torch.nn.Module],\n", " **trainer_kwargs,\n", " ):\n", " self.trainer = trainer\n", " self.hash_splitter = hash_splitter\n", " self.dataset = dataset\n", " self.callbacks_args = callbacks_args\n", - " self.model_factory = model_factory\n", + " self.generator_factory = generator_factory\n", + " self.discriminator_factory = discriminator_factory\n", " self.trainer_kwargs = trainer_kwargs\n", "\n", " def __call__(self, trial: optuna.trial.Trial):\n", - " # Let Optuna choose a mini-batch size and learning rate for this trial.\n", + " # Let Optuna choose core optimization and loss weights for this trial.\n", " batch_size = trial.suggest_int(\"batch_size\", 1, 8)\n", " lr = trial.suggest_float(\"lr\", 1e-5, 1e-3, log=True)\n", + " gradient_penalty_importance = trial.suggest_float(\n", + " \"gradient_penalty_importance\", 1.0, 20.0\n", + " )\n", + " reconstruction_importance = trial.suggest_float(\n", + " \"reconstruction_importance\", 10.0, 200.0\n", + " )\n", "\n", " # Rebuild train/val loaders at the chosen batch size while keeping deterministic splits.\n", " train_dataloader, val_dataloader, _ = self.hash_splitter(batch_size=batch_size)\n", " self.trainer_kwargs[\"train_dataloader\"] = train_dataloader\n", " self.trainer_kwargs[\"val_dataloader\"] = val_dataloader\n", "\n", - " model = self.model_factory()\n", - " self.trainer_kwargs[\"model\"] = model\n", + " generator = self.generator_factory()\n", + " discriminator = self.discriminator_factory()\n", + " self.trainer_kwargs[\"generator\"] = generator\n", + " self.trainer_kwargs[\"discriminator\"] = discriminator\n", "\n", - " optimizer_params = {\n", - " \"params\": model.parameters(),\n", + " generator_optimizer_params = {\n", + " \"params\": generator.parameters(),\n", + " \"lr\": lr,\n", + " \"betas\": (0.5, 0.999),\n", + " }\n", + " discriminator_optimizer_params = {\n", + " \"params\": discriminator.parameters(),\n", " \"lr\": lr,\n", " \"betas\": (0.5, 0.999),\n", " }\n", "\n", - " loss_trainer = L1Loss()\n", + " generator_loss = WassersteinGeneratorCrossZamirskiLoss(\n", + " reconstruction_importance=reconstruction_importance,\n", + " use_adversarial_decay=True,\n", + " )\n", + " discriminator_loss = WassersteinGradientPenaltyLoss(\n", + " gradient_penalty_importance=gradient_penalty_importance\n", + " )\n", " loss_callbacks = L1(device=device)\n", " metrics = [\n", " L2(device=device),\n", @@ -171,21 +196,30 @@ "\n", " # Use a nested MLflow run so each Optuna trial has its own metrics/artifacts.\n", " with mlflow.start_run(nested=True, run_name=f\"trial_{trial.number}\"):\n", - " optimizer = torch.optim.Adam(**optimizer_params)\n", - " self.trainer_kwargs[\"model_optimizer\"] = optimizer\n", + " generator_optimizer = torch.optim.Adam(**generator_optimizer_params)\n", + " discriminator_optimizer = torch.optim.Adam(**discriminator_optimizer_params)\n", + " self.trainer_kwargs[\"generator_optimizer\"] = generator_optimizer\n", + " self.trainer_kwargs[\"discriminator_optimizer\"] = discriminator_optimizer\n", "\n", - " opt_params = optimizer.param_groups[0].copy()\n", + " opt_params = generator_optimizer.param_groups[0].copy()\n", " del opt_params[\"params\"]\n", " mlflow.log_params({f\"optimizer_{k}\": v for k, v in opt_params.items()})\n", " mlflow.log_param(\"batch_size\", batch_size)\n", - " mlflow.set_tag(\"optimizer_class\", optimizer.__class__.__name__.lower())\n", + " mlflow.log_param(\"gradient_penalty_importance\", gradient_penalty_importance)\n", + " mlflow.log_param(\"reconstruction_importance\", reconstruction_importance)\n", + " mlflow.log_param(\"use_adversarial_decay\", True)\n", + " mlflow.set_tag(\"optimizer_class\", generator_optimizer.__class__.__name__.lower())\n", "\n", " self.trainer_kwargs[\"callbacks\"] = CallbackPipeline(\n", " **self.callbacks_args | {\"metrics\": metrics, \"loss\": loss_callbacks}\n", " )\n", "\n", " trainer_obj = self.trainer(\n", - " **self.trainer_kwargs | {\"model_loss\": loss_trainer}\n", + " **self.trainer_kwargs\n", + " | {\n", + " \"generator_loss\": generator_loss,\n", + " \"discriminator_loss\": discriminator_loss,\n", + " }\n", " )\n", " trainer_obj.train()\n", "\n", @@ -220,11 +254,14 @@ "mlflow.log_param(\"crop_size\", args.crop_size)\n", "\n", "description = \"\"\"\n", - "Optimization of a DAPI-to-Gold image-to-image translation model with:\n", + "Optimization of an unconditional WGAN-GP DAPI-to-Gold image-to-image translation model with:\n", "- ConvNeXtUNet Generator\n", + "- Unconditional convolutional discriminator\n", "- Single 2D crop input and single 2D crop target\n", "- Cache-backed filtered nucleus crops generated from the configured data directory\n", - "- L1 optimization objective with L2, PSNR, SSIM, and Pearson correlation metric logging\n", + "- Generator objective: reconstruction-weighted L1 plus epoch-decayed adversarial term\n", + "- Discriminator objective: Wasserstein loss with gradient penalty\n", + "- Validation uses L1, L2, PSNR, SSIM, and Pearson correlation metric logging\n", "\"\"\"\n", "mlflow.set_tag(\"mlflow.note.content\", description)\n", "\n", @@ -322,16 +359,19 @@ "outputs": [], "source": [ "optimization_manager = OptimizationManager(\n", - " trainer=UNetTrainer,\n", + " trainer=WGANGPTrainer,\n", " hash_splitter=hash_splitter,\n", " dataset=crop_image_dataset,\n", " callbacks_args=callbacks_args,\n", - " model_factory=lambda: ConvNeXtUNet(\n", + " generator_factory=lambda: ConvNeXtUNet(\n", " in_channels=1,\n", " out_channels=1,\n", " decoder_up_block=\"convt\",\n", " ),\n", + " discriminator_factory=lambda: UnconditionalCritic(in_channels=1),\n", " epochs=args.epochs,\n", + " image_postprocessor=image_postprocessor,\n", + " device=device,\n", " max_train_batches=max_train_batches,\n", ")\n", "\n", diff --git a/train.py b/train.py index 2796464..a6aaf85 100644 --- a/train.py +++ b/train.py @@ -9,6 +9,7 @@ import numpy as np import optuna import torch +from models.unconditional_critic import UnconditionalCritic from models.convnext_unet.unext import ConvNeXtUNet from callbacks.CallbackPipeline import CallbackPipeline @@ -21,14 +22,17 @@ ) from datasets.dataset_00.utils.ImagePostProcessor import ImagePostProcessor from datasets.dataset_00.utils.ImagePreProcessor import ImagePreProcessor -from losses.L1Loss import L1Loss +from losses.WassersteinGeneratorCrossZamirskiLoss import ( + WassersteinGeneratorCrossZamirskiLoss, +) +from losses.WassersteinGradientPenaltyLoss import WassersteinGradientPenaltyLoss from metrics.L1 import L1 from metrics.L2 import L2 from metrics.PearsonCorrelation import PearsonCorrelation from metrics.PSNR import PSNR from metrics.SSIM import SSIM from splitters.HashSplitter import HashSplitter -from trainers.UNetTrainer import UNetTrainer +from trainers.WGANGPTrainer import WGANGPTrainer @dataclass(frozen=True) class DatasetConfig: @@ -123,7 +127,8 @@ def __init__( hash_splitter: Any, dataset: Any, callbacks_args: dict[str, Any], - model_factory: Callable[[], torch.nn.Module], + generator_factory: Callable[[], torch.nn.Module], + discriminator_factory: Callable[[], torch.nn.Module], **trainer_kwargs, ): """Store dependencies for Optuna-driven training trials. @@ -133,7 +138,8 @@ def __init__( hash_splitter: Callable that returns train/val/test dataloaders. dataset: Dataset associated with the optimization run. callbacks_args: Static callback arguments reused across trials. - model_factory: Callable that creates a new model instance per trial. + generator_factory: Callable that creates a new generator instance per trial. + discriminator_factory: Callable that creates a new discriminator instance per trial. **trainer_kwargs: Shared trainer keyword arguments. """ @@ -141,7 +147,8 @@ def __init__( self.hash_splitter = hash_splitter self.dataset = dataset self.callbacks_args = callbacks_args - self.model_factory = model_factory + self.generator_factory = generator_factory + self.discriminator_factory = discriminator_factory self.trainer_kwargs = trainer_kwargs def __call__(self, trial: optuna.trial.Trial): @@ -154,25 +161,44 @@ def __call__(self, trial: optuna.trial.Trial): Best validation loss reported by the trainer. """ - # Let Optuna choose a mini-batch size and learning rate for this trial. + # Let Optuna choose core optimization and loss weights for this trial. batch_size = trial.suggest_int("batch_size", 1, 8) lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True) + gradient_penalty_importance = trial.suggest_float( + "gradient_penalty_importance", 1.0, 20.0 + ) + reconstruction_importance = trial.suggest_float( + "reconstruction_importance", 10.0, 200.0 + ) # Rebuild train/val loaders at the chosen batch size while keeping deterministic splits. train_dataloader, val_dataloader, _ = self.hash_splitter(batch_size=batch_size) self.trainer_kwargs["train_dataloader"] = train_dataloader self.trainer_kwargs["val_dataloader"] = val_dataloader - model = self.model_factory() - self.trainer_kwargs["model"] = model + generator = self.generator_factory() + discriminator = self.discriminator_factory() + self.trainer_kwargs["generator"] = generator + self.trainer_kwargs["discriminator"] = discriminator - optimizer_params = { - "params": model.parameters(), + generator_optimizer_params = { + "params": generator.parameters(), + "lr": lr, + "betas": (0.5, 0.999), + } + discriminator_optimizer_params = { + "params": discriminator.parameters(), "lr": lr, "betas": (0.5, 0.999), } - loss_trainer = L1Loss() + generator_loss = WassersteinGeneratorCrossZamirskiLoss( + reconstruction_importance=reconstruction_importance, + use_adversarial_decay=True, + ) + discriminator_loss = WassersteinGradientPenaltyLoss( + gradient_penalty_importance=gradient_penalty_importance + ) loss_callbacks = L1(device=device) metrics = [ L2(device=device), @@ -183,21 +209,32 @@ def __call__(self, trial: optuna.trial.Trial): # Use a nested MLflow run so each Optuna trial has its own metrics/artifacts. with mlflow.start_run(nested=True, run_name=f"trial_{trial.number}"): - optimizer = torch.optim.Adam(**optimizer_params) - self.trainer_kwargs["model_optimizer"] = optimizer + generator_optimizer = torch.optim.Adam(**generator_optimizer_params) + discriminator_optimizer = torch.optim.Adam(**discriminator_optimizer_params) + self.trainer_kwargs["generator_optimizer"] = generator_optimizer + self.trainer_kwargs["discriminator_optimizer"] = discriminator_optimizer - opt_params = optimizer.param_groups[0].copy() + opt_params = generator_optimizer.param_groups[0].copy() del opt_params["params"] mlflow.log_params({f"optimizer_{k}": v for k, v in opt_params.items()}) mlflow.log_param("batch_size", batch_size) - mlflow.set_tag("optimizer_class", optimizer.__class__.__name__.lower()) + mlflow.log_param("gradient_penalty_importance", gradient_penalty_importance) + mlflow.log_param("reconstruction_importance", reconstruction_importance) + mlflow.log_param("use_adversarial_decay", True) + mlflow.set_tag( + "optimizer_class", generator_optimizer.__class__.__name__.lower() + ) self.trainer_kwargs["callbacks"] = CallbackPipeline( **self.callbacks_args | {"metrics": metrics, "loss": loss_callbacks} ) trainer_obj = self.trainer( - **self.trainer_kwargs | {"model_loss": loss_trainer} + **self.trainer_kwargs + | { + "generator_loss": generator_loss, + "discriminator_loss": discriminator_loss, + } ) trainer_obj.train() @@ -226,11 +263,14 @@ def __call__(self, trial: optuna.trial.Trial): mlflow.log_param("crop_size", args.crop_size) description = """ -Optimization of a DAPI-to-Gold image-to-image translation model with: +Optimization of an unconditional WGAN-GP DAPI-to-Gold image-to-image translation model with: - ConvNeXtUNet Generator +- Unconditional convolutional discriminator - Single 2D crop input and single 2D crop target - Cache-backed filtered nucleus crops generated from the configured data directory -- L1 optimization objective with L2, PSNR, SSIM, and Pearson correlation metric logging +- Generator objective: reconstruction-weighted L1 plus epoch-decayed adversarial term +- Discriminator objective: Wasserstein loss with gradient penalty +- Validation uses L1, L2, PSNR, SSIM, and Pearson correlation metric logging """ mlflow.set_tag("mlflow.note.content", description) @@ -323,16 +363,19 @@ def __call__(self, trial: optuna.trial.Trial): } optimization_manager = OptimizationManager( - trainer=UNetTrainer, + trainer=WGANGPTrainer, hash_splitter=hash_splitter, dataset=crop_image_dataset, callbacks_args=callbacks_args, - model_factory=lambda: ConvNeXtUNet( + generator_factory=lambda: ConvNeXtUNet( in_channels=1, out_channels=1, decoder_up_block="convt", ), + discriminator_factory=lambda: UnconditionalCritic(in_channels=1), epochs=args.epochs, + image_postprocessor=image_postprocessor, + device=device, max_train_batches=max_train_batches, ) diff --git a/trainers/WGANGPTrainer.py b/trainers/WGANGPTrainer.py new file mode 100644 index 0000000..76ab7de --- /dev/null +++ b/trainers/WGANGPTrainer.py @@ -0,0 +1,160 @@ +from typing import Any, Union + +import torch +from torch.utils.data import DataLoader + + +class WGANGPTrainer: + """Train a generator-discriminator pair with unconditional WGAN-GP.""" + + def __init__( + self, + generator: torch.nn.Module, + discriminator: torch.nn.Module, + generator_optimizer: torch.optim.Optimizer, + discriminator_optimizer: torch.optim.Optimizer, + generator_loss: torch.nn.Module, + discriminator_loss: torch.nn.Module, + train_dataloader: Union[torch.utils.data.Dataset, DataLoader], + val_dataloader: Union[torch.utils.data.Dataset, DataLoader], + callbacks: Any, + image_postprocessor: Any = lambda x: x, + epochs: int = 10, + device: Union[str, torch.device] = "cuda", + max_train_batches: int | None = None, + ) -> None: + self.generator = generator + self.discriminator = discriminator + self.generator_optimizer = generator_optimizer + self.discriminator_optimizer = discriminator_optimizer + self.generator_loss = generator_loss + self.discriminator_loss = discriminator_loss + self.train_dataloader = train_dataloader + self.val_dataloader = val_dataloader + self.callbacks = callbacks + self.image_postprocessor = image_postprocessor + self.epochs = epochs + self.device = ( + device if isinstance(device, torch.device) else torch.device(device) + ) + self.max_train_batches = max_train_batches + + @property + def best_loss_value(self): + return self.callbacks.best_loss_value + + @staticmethod + def _detach_components( + batch_loss_components: dict[str, torch.Tensor], + ) -> tuple[torch.Tensor, dict[str, float]]: + if not isinstance(batch_loss_components, dict): + raise TypeError("Loss modules must return dict[str, torch.Tensor].") + if "total" not in batch_loss_components: + raise ValueError("Loss module output must include a 'total' key.") + + loss = batch_loss_components["total"] + if not torch.is_tensor(loss) or loss.ndim != 0: + raise ValueError("Loss module output 'total' must be a scalar torch.Tensor.") + + detached_loss_components: dict[str, float] = {} + for name, value in batch_loss_components.items(): + if not torch.is_tensor(value) or value.ndim != 0: + raise ValueError(f"Loss module output '{name}' must be a scalar torch.Tensor.") + detached_loss_components[name] = value.detach().item() + + return loss, detached_loss_components + + def train(self) -> None: + train_data = {"continue_training": True, "device": self.device} + + self.generator = self.generator.to(self.device) + self.discriminator = self.discriminator.to(self.device) + + for epoch in range(self.epochs): + train_data["epoch"] = epoch + train_data["callback_hook"] = "on_epoch_start" + self.callbacks(**train_data) + + self.generator.train() + self.discriminator.train() + + for batch, batch_data in enumerate(self.train_dataloader): + train_data["callback_hook"] = "on_batch_start" + train_data["batch"] = batch + train_data["batch_data"] = batch_data + self.callbacks(**train_data) + + inputs = batch_data["input"].to(self.device) + targets = batch_data["target"].to(self.device) + + fake_targets_for_discriminator = self.image_postprocessor( + self.generator(inputs) + ) + discriminator_outputs = self.discriminator_loss( + critic=self.discriminator, + real_samples=targets, + fake_samples=fake_targets_for_discriminator.detach(), + ) + discriminator_loss, discriminator_components = self._detach_components( + discriminator_outputs + ) + + self.discriminator_optimizer.zero_grad() + discriminator_loss.backward() + self.discriminator_optimizer.step() + + generated_predictions = self.image_postprocessor(self.generator(inputs)) + fake_classification_outputs = self.discriminator(generated_predictions) + generator_outputs = self.generator_loss( + fake_classification_outputs=fake_classification_outputs, + generated_predictions=generated_predictions, + targets=targets, + epoch=epoch, + loss_mask=batch_data.get("loss_mask"), + ) + generator_loss, generator_components = self._detach_components( + generator_outputs + ) + + self.generator_optimizer.zero_grad() + generator_loss.backward() + self.generator_optimizer.step() + + train_data["generated_predictions"] = generated_predictions + train_data["model_update_loss"] = generator_loss + train_data["batch_loss_components"] = generator_components + train_data["batch_loss_name"] = getattr( + self.generator_loss, + "loss_name", + self.generator_loss.__class__.__name__, + ) + train_data["batch_loss_groups"] = { + "generator": generator_components, + "discriminator": discriminator_components, + } + train_data["model"] = self.generator + train_data["callback_hook"] = "on_batch_end" + + self.callbacks(**train_data) + + if not train_data["continue_training"]: + break + + if ( + self.max_train_batches is not None + and (batch + 1) >= self.max_train_batches + ): + break + + train_data["callback_hook"] = "on_epoch_end" + train_data["continue_training"] = self.callbacks( + train_dataloader=self.train_dataloader, + val_dataloader=self.val_dataloader, + **train_data, + ) + + if not train_data["continue_training"]: + break + + def __call__(self) -> None: + self.train() diff --git a/trainers/utils/__init__.py b/trainers/utils/__init__.py new file mode 100644 index 0000000..a50daea --- /dev/null +++ b/trainers/utils/__init__.py @@ -0,0 +1,6 @@ +from .wgan_gp import ( + compute_generator_components, + compute_wgan_gp_components, + compute_wgan_gp_primitives, + compute_wgan_gp_terms, +) diff --git a/trainers/utils/wgan_gp.py b/trainers/utils/wgan_gp.py new file mode 100644 index 0000000..0bbc3f6 --- /dev/null +++ b/trainers/utils/wgan_gp.py @@ -0,0 +1,151 @@ +from collections.abc import Callable + +import torch + + +def compute_wgan_gp_primitives( + critic: Callable[[torch.Tensor], torch.Tensor], + real_samples: torch.Tensor, + fake_samples: torch.Tensor, +) -> dict[str, torch.Tensor]: + """Compute interpolation and critic outputs for WGAN-GP. + + Args: + critic: Critic/discriminator callable that maps samples to scores. + real_samples: Batch of real samples. + fake_samples: Batch of generated samples. + + Returns: + Dictionary containing: + - gradients: d(critic(interpolated))/d(interpolated) per sample. + - real_classification_outputs: Critic outputs for real samples. + - fake_classification_outputs: Critic outputs for fake samples. + + Raises: + ValueError: If real and fake batch sizes differ. + ValueError: If critic outputs are empty. + """ + + if real_samples.size(0) != fake_samples.size(0): + raise ValueError("real_samples and fake_samples must have matching batch size.") + + batch_size = real_samples.size(0) + alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1) + alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype) + + interpolated_samples = ( + alpha * real_samples + (1.0 - alpha) * fake_samples + ).requires_grad_(True) + + interpolated_outputs = critic(interpolated_samples) + real_classification_outputs = critic(real_samples) + fake_classification_outputs = critic(fake_samples) + + if interpolated_outputs.numel() == 0: + raise ValueError("critic output for interpolated samples must not be empty.") + + grad_outputs = torch.ones_like(interpolated_outputs) + gradients = torch.autograd.grad( + outputs=interpolated_outputs, + inputs=interpolated_samples, + grad_outputs=grad_outputs, + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + + return { + "gradients": gradients, + "real_classification_outputs": real_classification_outputs, + "fake_classification_outputs": fake_classification_outputs, + } + + +def compute_wgan_gp_components( + gradients: torch.Tensor, + real_classification_outputs: torch.Tensor, + fake_classification_outputs: torch.Tensor, +) -> dict[str, torch.Tensor]: + """Compute unweighted WGAN-GP components from precomputed primitives. + + Args: + gradients: d(critic(interpolated))/d(interpolated) per sample. + real_classification_outputs: Critic outputs for real samples. + fake_classification_outputs: Critic outputs for fake samples. + + Returns: + Dictionary containing: + - wasserstein_term: mean(fake) - mean(real). + - gradient_penalty_unweighted: mean((||grad||_2 - 1)^2). + + Raises: + ValueError: If critic output batch sizes do not match gradients batch size. + """ + + batch_size = gradients.size(0) + if real_classification_outputs.size(0) != batch_size: + raise ValueError("real_classification_outputs batch size must match gradients.") + if fake_classification_outputs.size(0) != batch_size: + raise ValueError("fake_classification_outputs batch size must match gradients.") + + flat_gradients = gradients.view(batch_size, -1) + gradient_penalty_unweighted = ((flat_gradients.norm(2, dim=1) - 1) ** 2).mean() + wasserstein_term = torch.mean(fake_classification_outputs) - torch.mean( + real_classification_outputs + ) + + return { + "wasserstein_term": wasserstein_term, + "gradient_penalty_unweighted": gradient_penalty_unweighted, + } + + +def compute_wgan_gp_terms( + critic: Callable[[torch.Tensor], torch.Tensor], + real_samples: torch.Tensor, + fake_samples: torch.Tensor, +) -> dict[str, torch.Tensor]: + """Compute WGAN-GP primitives and derived unweighted components. + + This convenience wrapper preserves a one-call API and combines + ``compute_wgan_gp_primitives`` and ``compute_wgan_gp_components``. + """ + + primitives = compute_wgan_gp_primitives( + critic=critic, + real_samples=real_samples, + fake_samples=fake_samples, + ) + components = compute_wgan_gp_components(**primitives) + return primitives | components + + +def compute_generator_components( + fake_classification_outputs: torch.Tensor, + generated_predictions: torch.Tensor, + targets: torch.Tensor, + epoch: int = 0, + use_adversarial_decay: bool = True, +) -> dict[str, torch.Tensor]: + """Compute unconditional WGAN generator terms with optional epoch decay.""" + + if generated_predictions.shape != targets.shape: + raise ValueError("generated_predictions and targets must have the same shape.") + + batch_size = generated_predictions.size(0) + if fake_classification_outputs.size(0) != batch_size: + raise ValueError( + "fake_classification_outputs batch size must match generated_predictions." + ) + + reconstruction_term = torch.nn.functional.l1_loss( + generated_predictions, targets, reduction="mean" + ) + adversarial_term = torch.mean(fake_classification_outputs) + if use_adversarial_decay: + adversarial_term = adversarial_term / (epoch + 1) + + return { + "reconstruction_term": reconstruction_term, + "adversarial_term": adversarial_term, + } From 8463f2737a7c29f0c24c886719d89e13c3c290a4 Mon Sep 17 00:00:00 2001 From: Cameron Mattson Date: Mon, 8 Jun 2026 13:35:03 -0600 Subject: [PATCH 3/8] Remove obsolete L1 trainer path --- losses/L1Loss.py | 33 -------- losses/__init__.py | 1 - trainers/UNetTrainer.py | 179 ---------------------------------------- 3 files changed, 213 deletions(-) delete mode 100644 losses/L1Loss.py delete mode 100644 trainers/UNetTrainer.py diff --git a/losses/L1Loss.py b/losses/L1Loss.py deleted file mode 100644 index 6b5dd63..0000000 --- a/losses/L1Loss.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch -from torch import nn - - -class L1Loss(nn.Module): - """Training loss wrapper with trainer-compatible call signature.""" - - loss_name = "l1" # Stable MLflow metric namespace for this loss family. - - def forward( - self, - generated_predictions: torch.Tensor, - targets: torch.Tensor, - **kwargs, - ) -> dict[str, torch.Tensor]: - """Compute mean L1 training loss for one batch. - - Args: - generated_predictions: Model predictions. - targets: Ground-truth targets with matching shape. - **kwargs: Additional unused loss arguments. - - Returns: - Dictionary with scalar mean absolute error under ``total``. - - Raises: - ValueError: If prediction and target shapes differ. - """ - - if generated_predictions.shape != targets.shape: - raise ValueError("The generated predictions and targets must be the same shape.") - total = torch.nn.functional.l1_loss(generated_predictions, targets, reduction="mean") - return {"total": total} diff --git a/losses/__init__.py b/losses/__init__.py index 389b832..297d26a 100644 --- a/losses/__init__.py +++ b/losses/__init__.py @@ -1,3 +1,2 @@ -from .L1Loss import L1Loss from .WassersteinGeneratorCrossZamirskiLoss import WassersteinGeneratorCrossZamirskiLoss from .WassersteinGradientPenaltyLoss import WassersteinGradientPenaltyLoss diff --git a/trainers/UNetTrainer.py b/trainers/UNetTrainer.py deleted file mode 100644 index f1906bd..0000000 --- a/trainers/UNetTrainer.py +++ /dev/null @@ -1,179 +0,0 @@ -from typing import Any, Union - -import torch -from torch import nn -from torch.utils.data import DataLoader - - -class UNetTrainer: - """ - Orchestrates training and evaluation of image-to-image translation models. - """ - - def __init__( - self, - model: torch.nn.Module, - model_optimizer: torch.optim.Optimizer, - model_loss: nn.Module, - train_dataloader: Union[torch.utils.data.Dataset, DataLoader], - val_dataloader: Union[torch.utils.data.Dataset, DataLoader], - callbacks: Any, - image_postprocessor: Any = lambda x: x, - epochs: int = 10, - device: Union[str, torch.device] = "cuda", - use_amp: bool = True, - max_train_batches: int | None = None, - ) -> None: - """Initialize trainer state and optional AMP scaler. - - Args: - model: Trainable image-to-image model. - model_optimizer: Optimizer used for parameter updates. - model_loss: Loss module used for backpropagation. Must return - ``dict[str, torch.Tensor]`` with required scalar key - ``"total"`` and optional additional scalar components. - train_dataloader: Training dataloader. - val_dataloader: Validation dataloader used by callbacks. - callbacks: Callback dispatcher used for hooks and logging. - image_postprocessor: Postprocessor applied to model outputs. - epochs: Maximum number of training epochs. - device: Target device for model and tensors. - use_amp: Whether to use automatic mixed precision. - max_train_batches: Optional cap on train batches per epoch. - """ - - self.model = model - self.model_optimizer = model_optimizer - self.model_loss = model_loss - self.train_dataloader = train_dataloader - self.val_dataloader = val_dataloader - self.callbacks = callbacks - self.image_postprocessor = image_postprocessor - self.epochs = epochs - self.device = ( - device if isinstance(device, torch.device) else torch.device(device) - ) - self.use_amp = use_amp # Automatic Mixed Precision (AMP) - self.max_train_batches = max_train_batches - # Stable loss identifier used to namespace batch metrics in MLflow. - self.loss_name = getattr(self.model_loss, "loss_name", self.model_loss.__class__.__name__) - - if self.use_amp: - if self.device.type == "cuda": - self.scaler = torch.amp.GradScaler("cuda") - else: - self.scaler = torch.amp.GradScaler("cpu") - else: - self.scaler = None - - @property - def best_loss_value(self): - """Expose best validation loss tracked by callbacks.""" - - return self.callbacks.best_loss_value - - def train(self) -> None: - """Run the training loop with callback hooks and optional early stopping. - - The trainer enforces a strict loss output contract: - ``model_loss(...) -> dict[str, torch.Tensor]`` with required scalar key - ``"total"`` used for optimization. - - At each batch end, callback hook data includes: - - ``model_update_loss``: Scalar tensor used for backward pass. - - ``batch_loss_name``: Stable loss identifier for metric naming. - - ``batch_loss_components``: Detached scalar loss components as - ``dict[str, float]`` for logging. - """ - - train_data = {} - train_data["continue_training"] = True - train_data["device"] = self.device - - self.model = self.model.to(self.device) - - for epoch in range(self.epochs): - train_data["epoch"] = epoch - train_data["callback_hook"] = "on_epoch_start" - self.callbacks(**train_data) - - self.model.train() - - for batch, batch_data in enumerate(self.train_dataloader): - train_data["callback_hook"] = "on_batch_start" - train_data["batch"] = batch - train_data["batch_data"] = batch_data - self.callbacks(**train_data) - - inputs = batch_data["input"].to(self.device) - targets = batch_data["target"].to(self.device) - - with torch.amp.autocast( - enabled=self.use_amp, device_type=self.device.type - ): - generated_predictions = self.image_postprocessor(self.model(inputs)) - batch_loss_components = self.model_loss( - targets=targets, - generated_predictions=generated_predictions, - loss_mask=batch_data.get("loss_mask"), - ) - - if not isinstance(batch_loss_components, dict): - raise TypeError("model_loss must return dict[str, torch.Tensor].") - if "total" not in batch_loss_components: - raise ValueError("model_loss output must include a 'total' key.") - - loss = batch_loss_components["total"] - if not torch.is_tensor(loss) or loss.ndim != 0: - raise ValueError("model_loss['total'] must be a scalar torch.Tensor.") - - detached_loss_components: dict[str, float] = {} - for name, value in batch_loss_components.items(): - if not torch.is_tensor(value) or value.ndim != 0: - raise ValueError( - f"model_loss['{name}'] must be a scalar torch.Tensor." - ) - detached_loss_components[name] = value.detach().item() - - train_data["generated_predictions"] = generated_predictions - train_data["model_update_loss"] = loss - train_data["batch_loss_components"] = detached_loss_components - train_data["batch_loss_name"] = self.loss_name - - self.model_optimizer.zero_grad() - if self.use_amp and self.scaler is not None: - self.scaler.scale(loss).backward() - self.scaler.step(self.model_optimizer) - self.scaler.update() - else: - loss.backward() - self.model_optimizer.step() - - train_data["model"] = self.model - train_data["callback_hook"] = "on_batch_end" - - self.callbacks(**train_data) - - if not train_data["continue_training"]: - break - - if ( - self.max_train_batches is not None - and (batch + 1) >= self.max_train_batches - ): - break - - train_data["callback_hook"] = "on_epoch_end" - train_data["continue_training"] = self.callbacks( - train_dataloader=self.train_dataloader, - val_dataloader=self.val_dataloader, - **train_data, - ) - - if not train_data["continue_training"]: - break - - def __call__(self) -> None: - """Execute training when the trainer is called like a function.""" - - self.train() From 550f48d331bd804e3b906e2a64beb604e7342518 Mon Sep 17 00:00:00 2001 From: Cameron Mattson Date: Mon, 8 Jun 2026 14:17:30 -0600 Subject: [PATCH 4/8] Reduce Optuna max batch size --- callbacks/evaluation.py | 1 - train.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/callbacks/evaluation.py b/callbacks/evaluation.py index e5ce780..0d0fb79 100644 --- a/callbacks/evaluation.py +++ b/callbacks/evaluation.py @@ -66,7 +66,6 @@ def _evaluate_split( """ model.eval() - with torch.no_grad(): for batch_idx, samples in enumerate(dataloader): generated_predictions = model(samples["input"]) diff --git a/train.py b/train.py index a6aaf85..ca804c0 100644 --- a/train.py +++ b/train.py @@ -162,7 +162,7 @@ def __call__(self, trial: optuna.trial.Trial): """ # Let Optuna choose core optimization and loss weights for this trial. - batch_size = trial.suggest_int("batch_size", 1, 8) + batch_size = trial.suggest_int("batch_size", 1, 5) lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True) gradient_penalty_importance = trial.suggest_float( "gradient_penalty_importance", 1.0, 20.0 From 53d29e5156faabaedf98ad7b6969492aaf82fcad Mon Sep 17 00:00:00 2001 From: Cameron Mattson Date: Mon, 8 Jun 2026 14:20:36 -0600 Subject: [PATCH 5/8] Updated to mention batch size purpose when splitting --- train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train.py b/train.py index ca804c0..bd8c975 100644 --- a/train.py +++ b/train.py @@ -329,6 +329,7 @@ def __call__(self, trial: optuna.trial.Trial): val_frac=0.125, ) +# Use a fixed batch size here only to iterate splits while choosing preview images. train_dataloader, val_dataloader, _ = hash_splitter(batch_size=16) train_crop_dataset_idxs = SampleImages( datastruct=train_dataloader, image_fraction=1 / 512 From 18feba25676f63059278e8b4ebd570cf9c13d28c Mon Sep 17 00:00:00 2001 From: Cameron Mattson Date: Mon, 8 Jun 2026 14:56:46 -0600 Subject: [PATCH 6/8] Removed old amp references --- datasets/dataset_00/utils/ImagePostProcessor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/datasets/dataset_00/utils/ImagePostProcessor.py b/datasets/dataset_00/utils/ImagePostProcessor.py index d552ceb..9751804 100644 --- a/datasets/dataset_00/utils/ImagePostProcessor.py +++ b/datasets/dataset_00/utils/ImagePostProcessor.py @@ -3,7 +3,6 @@ import numpy as np import torch import torch.nn.functional as F -from torch.amp import autocast class ImagePostProcessor: From 8b43851a05791af6e2f61735da219353d11d6885 Mon Sep 17 00:00:00 2001 From: Cameron Mattson Date: Mon, 8 Jun 2026 15:02:51 -0600 Subject: [PATCH 7/8] Changed batch size to accomodate memory constraint --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index bd8c975..c45796e 100644 --- a/train.py +++ b/train.py @@ -162,7 +162,7 @@ def __call__(self, trial: optuna.trial.Trial): """ # Let Optuna choose core optimization and loss weights for this trial. - batch_size = trial.suggest_int("batch_size", 1, 5) + batch_size = trial.suggest_int("batch_size", 1, 2) lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True) gradient_penalty_importance = trial.suggest_float( "gradient_penalty_importance", 1.0, 20.0 From e75f34344e294ece4381d9cef935eb58feb39f2d Mon Sep 17 00:00:00 2001 From: Cameron Mattson Date: Mon, 8 Jun 2026 15:10:01 -0600 Subject: [PATCH 8/8] Avoid generator autograd during critic update Wrap the discriminator-step generator forward in torch.no_grad() and remove the now-redundant detach on the fake samples passed to the critic loss. This preserves the two-step WGAN-GP training behavior while avoiding construction of an unnecessary generator autograd graph during the critic update, reducing memory and compute overhead. --- trainers/WGANGPTrainer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/trainers/WGANGPTrainer.py b/trainers/WGANGPTrainer.py index 76ab7de..a2f9a52 100644 --- a/trainers/WGANGPTrainer.py +++ b/trainers/WGANGPTrainer.py @@ -87,13 +87,14 @@ def train(self) -> None: inputs = batch_data["input"].to(self.device) targets = batch_data["target"].to(self.device) - fake_targets_for_discriminator = self.image_postprocessor( - self.generator(inputs) - ) + with torch.no_grad(): + fake_targets_for_discriminator = self.image_postprocessor( + self.generator(inputs) + ) discriminator_outputs = self.discriminator_loss( critic=self.discriminator, real_samples=targets, - fake_samples=fake_targets_for_discriminator.detach(), + fake_samples=fake_targets_for_discriminator, ) discriminator_loss, discriminator_components = self._detach_components( discriminator_outputs