-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathattack.py
More file actions
35 lines (27 loc) · 913 Bytes
/
attack.py
File metadata and controls
35 lines (27 loc) · 913 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.nn.functional as F
def fgsm_attack(model, images, epsilon):
    """
    Craft FGSM adversarial examples for image reconstruction models (e.g., DAEViT).

    A single signed-gradient step of size ``epsilon`` is taken in the direction
    that increases the MSE between the model output and the input itself, and
    the result is clamped to [0, 1].

    Args:
        model: reconstruction model, called as ``model(images)``; its output
            is compared with the input via ``F.mse_loss``.
        images: (B, C, H, W) input. Not modified; the attack runs on a
            detached copy.
        epsilon: perturbation strength (scale applied to the gradient sign).
    Returns:
        perturbed images: a detached tensor with the same shape as ``images``,
        clamped to [0, 1].
    """
    # Private leaf copy, so the caller's tensor (and any graph attached to it)
    # is never touched.
    adv_input = images.clone().detach()

    # Re-enable autograd locally so the attack also works when the caller is
    # inside a ``torch.no_grad()`` evaluation loop.
    with torch.enable_grad():
        adv_input.requires_grad_(True)
        outputs = model(adv_input)
        # The input doubles as the reconstruction target, so the gradient
        # flows through both arguments of the loss.
        loss = F.mse_loss(outputs, adv_input)
        # Differentiate w.r.t. the input only: unlike ``loss.backward()`` this
        # neither clears nor fills the ``.grad`` of the model parameters.
        (input_grad,) = torch.autograd.grad(loss, adv_input)

    # Build the result from detached tensors so no autograd history leaks out.
    perturbed = adv_input.detach() + epsilon * input_grad.sign()
    return torch.clamp(perturbed, 0, 1)
def add_noise_with_snr(x: torch.Tensor, snr_db: float) -> torch.Tensor:
    """
    Add zero-mean Gaussian noise to ``x`` at a requested signal-to-noise ratio.

    The signal power is the mean of ``x ** 2`` over the whole tensor; the noise
    standard deviation is chosen so that signal power / noise power equals the
    SNR given in decibels. The noisy result is clamped to [0, 1].

    Args:
        x: input tensor.
        snr_db: target signal-to-noise ratio in dB.
    Returns:
        noisy tensor with the same shape as ``x``, clamped to [0, 1].
    """
    # Convert the SNR from decibels to a linear power ratio.
    linear_snr = 10 ** (snr_db / 10)
    # One scalar power estimate for the entire tensor.
    mean_power = torch.mean(x * x)
    sigma = torch.sqrt(mean_power / linear_snr)
    noisy = x + sigma * torch.randn_like(x)
    return noisy.clamp(0, 1)