From c7bccafee1ca26a9f8db8ad17716b025bdd2b23c Mon Sep 17 00:00:00 2001 From: Keltin Grimes <35310549+keltin13@users.noreply.github.com> Date: Mon, 13 Apr 2026 11:51:26 -0400 Subject: [PATCH] Add missing .float() call to min_k++ --- src/evals/metrics/mia/min_k_plus_plus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/evals/metrics/mia/min_k_plus_plus.py b/src/evals/metrics/mia/min_k_plus_plus.py index aa58214a9..94f99a2e0 100644 --- a/src/evals/metrics/mia/min_k_plus_plus.py +++ b/src/evals/metrics/mia/min_k_plus_plus.py @@ -32,7 +32,7 @@ def compute_score(self, sample_stats): sigma = torch.clamp(sigma, min=1e-6) scores = ( target_prob.float().cpu().numpy() - mu.float().cpu().numpy() - ) / torch.sqrt(sigma).cpu().numpy() + ) / torch.sqrt(sigma).float().cpu().numpy() # Take bottom k% as the attack score num_k = max(1, int(len(scores) * self.k))