diff --git a/src/evals/metrics/mia/min_k_plus_plus.py b/src/evals/metrics/mia/min_k_plus_plus.py index aa58214a9..94f99a2e0 100644 --- a/src/evals/metrics/mia/min_k_plus_plus.py +++ b/src/evals/metrics/mia/min_k_plus_plus.py @@ -32,7 +32,7 @@ def compute_score(self, sample_stats): sigma = torch.clamp(sigma, min=1e-6) scores = ( target_prob.float().cpu().numpy() - mu.float().cpu().numpy() - ) / torch.sqrt(sigma).cpu().numpy() + ) / torch.sqrt(sigma).float().cpu().numpy() # Take bottom k% as the attack score num_k = max(1, int(len(scores) * self.k))