Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions callbacks/batch_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,23 @@ def on_batch_end(self, hook_data: dict[str, Any]) -> None:
self.global_batch_step += 1
return

batch_loss_components = hook_data.get("batch_loss_components", {})
batch_loss_name = hook_data.get("batch_loss_name", "unknown_loss")
for loss_name, loss_value in batch_loss_components.items():
mlflow.log_metric(
f"batch/train/{batch_loss_name}/{loss_name}",
loss_value,
step=self.global_batch_step,
)
batch_loss_groups = hook_data.get("batch_loss_groups")
if batch_loss_groups is not None:
for group_name, batch_loss_components in batch_loss_groups.items():
for loss_name, loss_value in batch_loss_components.items():
mlflow.log_metric(
f"batch/train/{group_name}/{loss_name}",
loss_value,
step=self.global_batch_step,
)
else:
batch_loss_components = hook_data.get("batch_loss_components", {})
batch_loss_name = hook_data.get("batch_loss_name", "unknown_loss")
for loss_name, loss_value in batch_loss_components.items():
mlflow.log_metric(
f"batch/train/{batch_loss_name}/{loss_name}",
loss_value,
step=self.global_batch_step,
)

self.global_batch_step += 1
1 change: 0 additions & 1 deletion callbacks/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def _evaluate_split(
"""

model.eval()

with torch.no_grad():
for batch_idx, samples in enumerate(dataloader):
generated_predictions = model(samples["input"])
Expand Down
1 change: 0 additions & 1 deletion datasets/dataset_00/utils/ImagePostProcessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np
import torch
import torch.nn.functional as F
from torch.amp import autocast


class ImagePostProcessor:
Expand Down
33 changes: 0 additions & 33 deletions losses/L1Loss.py

This file was deleted.

29 changes: 16 additions & 13 deletions losses/WassersteinGeneratorCrossZamirskiLoss.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import torch
from torch import nn

from trainers.utils.wgan_gp import compute_generator_components


class WassersteinGeneratorCrossZamirskiLoss(nn.Module):
"""Generator loss combining L1 reconstruction and Wasserstein term."""

loss_name = "wasserstein_generator" # Stable MLflow namespace for this loss family.

def __init__(self, reconstruction_importance: float = 100.0) -> None:
def __init__(
self,
reconstruction_importance: float = 100.0,
use_adversarial_decay: bool = True,
) -> None:
"""Configure weighting for the reconstruction component.

Args:
Expand All @@ -16,6 +22,7 @@ def __init__(self, reconstruction_importance: float = 100.0) -> None:

super().__init__()
self.reconstruction_importance = reconstruction_importance
self.use_adversarial_decay = use_adversarial_decay

def forward(
self,
Expand Down Expand Up @@ -43,19 +50,15 @@ def forward(
ValueError: If critic output batch size does not match predictions batch size.
"""

if generated_predictions.shape != targets.shape:
raise ValueError("generated_predictions and targets must have the same shape.")

batch_size = generated_predictions.size(0)
if fake_classification_outputs.size(0) != batch_size:
raise ValueError(
"fake_classification_outputs batch size must match generated_predictions."
)

reconstruction_loss = torch.nn.functional.l1_loss(
generated_predictions, targets, reduction="mean"
components = compute_generator_components(
fake_classification_outputs=fake_classification_outputs,
generated_predictions=generated_predictions,
targets=targets,
epoch=epoch,
use_adversarial_decay=self.use_adversarial_decay,
)
adversarial_term = torch.mean(fake_classification_outputs) / (epoch + 1)
reconstruction_loss = components["reconstruction_term"]
adversarial_term = components["adversarial_term"]
total = self.reconstruction_importance * reconstruction_loss - adversarial_term
return {
"total": total,
Expand Down
30 changes: 14 additions & 16 deletions losses/WassersteinGradientPenaltyLoss.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch
from torch import nn

from trainers.utils.wgan_gp import compute_wgan_gp_terms


class WassersteinGradientPenaltyLoss(nn.Module):
"""WGAN-GP loss wrapper with trainer-compatible call signature."""
Expand All @@ -19,17 +21,17 @@ def __init__(self, gradient_penalty_importance: float = 10.0) -> None:

def forward(
self,
gradients: torch.Tensor,
real_classification_outputs: torch.Tensor,
fake_classification_outputs: torch.Tensor,
critic: nn.Module,
real_samples: torch.Tensor,
fake_samples: torch.Tensor,
**kwargs,
) -> dict[str, torch.Tensor]:
"""Compute critic loss with Wasserstein distance and gradient penalty.

Args:
gradients: Gradients of critic outputs w.r.t. interpolated inputs.
real_classification_outputs: Critic outputs for real samples.
fake_classification_outputs: Critic outputs for generated samples.
critic: Critic network used to score real/fake samples.
real_samples: Real target samples.
fake_samples: Generated samples.
**kwargs: Additional unused loss arguments.

Returns:
Expand All @@ -41,17 +43,13 @@ def forward(
ValueError: If fake critic output batch size does not match gradients batch size.
"""

batch_size = gradients.size(0)
if real_classification_outputs.size(0) != batch_size:
raise ValueError("real_classification_outputs batch size must match gradients.")
if fake_classification_outputs.size(0) != batch_size:
raise ValueError("fake_classification_outputs batch size must match gradients.")

gradients = gradients.view(batch_size, -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
wasserstein = torch.mean(fake_classification_outputs) - torch.mean(
real_classification_outputs
components = compute_wgan_gp_terms(
critic=critic,
real_samples=real_samples,
fake_samples=fake_samples,
)
gradient_penalty = components["gradient_penalty_unweighted"]
wasserstein = components["wasserstein_term"]
total = wasserstein + gradient_penalty * self.gradient_penalty_importance
return {
"total": total,
Expand Down
1 change: 0 additions & 1 deletion losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from .L1Loss import L1Loss
from .WassersteinGeneratorCrossZamirskiLoss import WassersteinGeneratorCrossZamirskiLoss
from .WassersteinGradientPenaltyLoss import WassersteinGradientPenaltyLoss
29 changes: 29 additions & 0 deletions models/unconditional_critic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch
from torch import nn


class UnconditionalCritic(nn.Module):
"""Small convolutional critic for single-channel image scoring."""

def __init__(self, in_channels: int = 1, base_channels: int = 64) -> None:
super().__init__()
self.layers = nn.Sequential(
nn.Conv2d(in_channels, base_channels, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(base_channels, base_channels * 2, kernel_size=4, stride=2, padding=1),
nn.InstanceNorm2d(base_channels * 2, affine=True),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(
base_channels * 2,
base_channels * 4,
kernel_size=4,
stride=2,
padding=1,
),
nn.InstanceNorm2d(base_channels * 4, affine=True),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(base_channels * 4, 1, kernel_size=4, stride=1, padding=1),
)

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return self.layers(inputs)
78 changes: 59 additions & 19 deletions train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"import numpy as np\n",
"import optuna\n",
"import torch\n",
"from models.unconditional_critic import UnconditionalCritic\n",
"from models.convnext_unet.unext import ConvNeXtUNet\n",
"\n",
"from callbacks.CallbackPipeline import CallbackPipeline\n",
Expand All @@ -29,14 +30,17 @@
")\n",
"from datasets.dataset_00.utils.ImagePostProcessor import ImagePostProcessor\n",
"from datasets.dataset_00.utils.ImagePreProcessor import ImagePreProcessor\n",
"from losses.L1Loss import L1Loss\n",
"from losses.WassersteinGeneratorCrossZamirskiLoss import (\n",
" WassersteinGeneratorCrossZamirskiLoss,\n",
")\n",
"from losses.WassersteinGradientPenaltyLoss import WassersteinGradientPenaltyLoss\n",
"from metrics.L1 import L1\n",
"from metrics.L2 import L2\n",
"from metrics.PearsonCorrelation import PearsonCorrelation\n",
"from metrics.PSNR import PSNR\n",
"from metrics.SSIM import SSIM\n",
"from splitters.HashSplitter import HashSplitter\n",
"from trainers.UNetTrainer import UNetTrainer"
"from trainers.WGANGPTrainer import WGANGPTrainer"
]
},
{
Expand Down Expand Up @@ -131,36 +135,57 @@
" hash_splitter: Any,\n",
" dataset: Any,\n",
" callbacks_args: dict[str, Any],\n",
" model_factory: Callable[[], torch.nn.Module],\n",
" generator_factory: Callable[[], torch.nn.Module],\n",
" discriminator_factory: Callable[[], torch.nn.Module],\n",
" **trainer_kwargs,\n",
" ):\n",
" self.trainer = trainer\n",
" self.hash_splitter = hash_splitter\n",
" self.dataset = dataset\n",
" self.callbacks_args = callbacks_args\n",
" self.model_factory = model_factory\n",
" self.generator_factory = generator_factory\n",
" self.discriminator_factory = discriminator_factory\n",
" self.trainer_kwargs = trainer_kwargs\n",
"\n",
" def __call__(self, trial: optuna.trial.Trial):\n",
" # Let Optuna choose a mini-batch size and learning rate for this trial.\n",
" # Let Optuna choose core optimization and loss weights for this trial.\n",
" batch_size = trial.suggest_int(\"batch_size\", 1, 8)\n",
" lr = trial.suggest_float(\"lr\", 1e-5, 1e-3, log=True)\n",
" gradient_penalty_importance = trial.suggest_float(\n",
" \"gradient_penalty_importance\", 1.0, 20.0\n",
" )\n",
" reconstruction_importance = trial.suggest_float(\n",
" \"reconstruction_importance\", 10.0, 200.0\n",
" )\n",
"\n",
" # Rebuild train/val loaders at the chosen batch size while keeping deterministic splits.\n",
" train_dataloader, val_dataloader, _ = self.hash_splitter(batch_size=batch_size)\n",
" self.trainer_kwargs[\"train_dataloader\"] = train_dataloader\n",
" self.trainer_kwargs[\"val_dataloader\"] = val_dataloader\n",
"\n",
" model = self.model_factory()\n",
" self.trainer_kwargs[\"model\"] = model\n",
" generator = self.generator_factory()\n",
" discriminator = self.discriminator_factory()\n",
" self.trainer_kwargs[\"generator\"] = generator\n",
" self.trainer_kwargs[\"discriminator\"] = discriminator\n",
"\n",
" optimizer_params = {\n",
" \"params\": model.parameters(),\n",
" generator_optimizer_params = {\n",
" \"params\": generator.parameters(),\n",
" \"lr\": lr,\n",
" \"betas\": (0.5, 0.999),\n",
" }\n",
" discriminator_optimizer_params = {\n",
" \"params\": discriminator.parameters(),\n",
" \"lr\": lr,\n",
" \"betas\": (0.5, 0.999),\n",
" }\n",
"\n",
" loss_trainer = L1Loss()\n",
" generator_loss = WassersteinGeneratorCrossZamirskiLoss(\n",
" reconstruction_importance=reconstruction_importance,\n",
" use_adversarial_decay=True,\n",
" )\n",
" discriminator_loss = WassersteinGradientPenaltyLoss(\n",
" gradient_penalty_importance=gradient_penalty_importance\n",
" )\n",
" loss_callbacks = L1(device=device)\n",
" metrics = [\n",
" L2(device=device),\n",
Expand All @@ -171,21 +196,30 @@
"\n",
" # Use a nested MLflow run so each Optuna trial has its own metrics/artifacts.\n",
" with mlflow.start_run(nested=True, run_name=f\"trial_{trial.number}\"):\n",
" optimizer = torch.optim.Adam(**optimizer_params)\n",
" self.trainer_kwargs[\"model_optimizer\"] = optimizer\n",
" generator_optimizer = torch.optim.Adam(**generator_optimizer_params)\n",
" discriminator_optimizer = torch.optim.Adam(**discriminator_optimizer_params)\n",
" self.trainer_kwargs[\"generator_optimizer\"] = generator_optimizer\n",
" self.trainer_kwargs[\"discriminator_optimizer\"] = discriminator_optimizer\n",
"\n",
" opt_params = optimizer.param_groups[0].copy()\n",
" opt_params = generator_optimizer.param_groups[0].copy()\n",
" del opt_params[\"params\"]\n",
" mlflow.log_params({f\"optimizer_{k}\": v for k, v in opt_params.items()})\n",
" mlflow.log_param(\"batch_size\", batch_size)\n",
" mlflow.set_tag(\"optimizer_class\", optimizer.__class__.__name__.lower())\n",
" mlflow.log_param(\"gradient_penalty_importance\", gradient_penalty_importance)\n",
" mlflow.log_param(\"reconstruction_importance\", reconstruction_importance)\n",
" mlflow.log_param(\"use_adversarial_decay\", True)\n",
" mlflow.set_tag(\"optimizer_class\", generator_optimizer.__class__.__name__.lower())\n",
"\n",
" self.trainer_kwargs[\"callbacks\"] = CallbackPipeline(\n",
" **self.callbacks_args | {\"metrics\": metrics, \"loss\": loss_callbacks}\n",
" )\n",
"\n",
" trainer_obj = self.trainer(\n",
" **self.trainer_kwargs | {\"model_loss\": loss_trainer}\n",
" **self.trainer_kwargs\n",
" | {\n",
" \"generator_loss\": generator_loss,\n",
" \"discriminator_loss\": discriminator_loss,\n",
" }\n",
" )\n",
" trainer_obj.train()\n",
"\n",
Expand Down Expand Up @@ -220,11 +254,14 @@
"mlflow.log_param(\"crop_size\", args.crop_size)\n",
"\n",
"description = \"\"\"\n",
"Optimization of a DAPI-to-Gold image-to-image translation model with:\n",
"Optimization of an unconditional WGAN-GP DAPI-to-Gold image-to-image translation model with:\n",
"- ConvNeXtUNet Generator\n",
"- Unconditional convolutional discriminator\n",
"- Single 2D crop input and single 2D crop target\n",
"- Cache-backed filtered nucleus crops generated from the configured data directory\n",
"- L1 optimization objective with L2, PSNR, SSIM, and Pearson correlation metric logging\n",
"- Generator objective: reconstruction-weighted L1 plus epoch-decayed adversarial term\n",
"- Discriminator objective: Wasserstein loss with gradient penalty\n",
"- Validation uses L1, L2, PSNR, SSIM, and Pearson correlation metric logging\n",
"\"\"\"\n",
"mlflow.set_tag(\"mlflow.note.content\", description)\n",
"\n",
Expand Down Expand Up @@ -322,16 +359,19 @@
"outputs": [],
"source": [
"optimization_manager = OptimizationManager(\n",
" trainer=UNetTrainer,\n",
" trainer=WGANGPTrainer,\n",
" hash_splitter=hash_splitter,\n",
" dataset=crop_image_dataset,\n",
" callbacks_args=callbacks_args,\n",
" model_factory=lambda: ConvNeXtUNet(\n",
" generator_factory=lambda: ConvNeXtUNet(\n",
" in_channels=1,\n",
" out_channels=1,\n",
" decoder_up_block=\"convt\",\n",
" ),\n",
" discriminator_factory=lambda: UnconditionalCritic(in_channels=1),\n",
" epochs=args.epochs,\n",
" image_postprocessor=image_postprocessor,\n",
" device=device,\n",
" max_train_batches=max_train_batches,\n",
")\n",
"\n",
Expand Down
Loading