From 724e5842da0bf134ca803ad1853483e7ab4feed4 Mon Sep 17 00:00:00 2001 From: Shraddha Piparia Date: Mon, 25 May 2026 07:26:12 -0700 Subject: [PATCH 1/2] Add phase diagnostics and reorganize attribution analyses --- README.md | 9 +- configs/config_phase1.yaml | 36 +- configs/config_phase2.yaml | 15 +- .../09_phase2_block_attribution.py | 249 +++- .../09_phase2_block_attribution_updated.py | 1320 +++++++++++++++++ ...10_phase1_snp_attribution_within_blocks.py | 0 .../attribution/10_rank_shift_scatter.py | 324 ++++ .../01_analyze_VAEtraining_logs.py | 107 ++ .../02_subject_cluster_analysis.py | 0 .../04_cluster_stability_analysis.py | 0 .../05_attention_confounder_analysis.py | 0 .../06_phase1_phase2_block_comparison.py | 0 .../08_clinical_pc_embedding_alignment.py | 2 +- .../03_leave_hla_out_analysis.py | 0 .../{ => sensitivity}/07_17q21_validation.py | 0 .../sensitivity/11_top_snp_ld_check.py | 359 +++++ .../01_block_embedding_phenotype_analysis.py | 8 +- scripts/core/VAE_phase1.py | 897 +++++++++-- scripts/core/attention_phase2.py | 1029 ++++++++++++- scripts/core/plots_updated.py | 14 +- tests/test_phase2_diagnostics.py | 439 ++++++ tests/test_smoke_outputs.py | 74 +- tests/test_vae_diagnostics.py | 151 ++ 23 files changed, 4731 insertions(+), 302 deletions(-) rename scripts/analysis/{ => attribution}/09_phase2_block_attribution.py (81%) create mode 100644 scripts/analysis/attribution/09_phase2_block_attribution_updated.py rename scripts/analysis/{ => attribution}/10_phase1_snp_attribution_within_blocks.py (100%) create mode 100644 scripts/analysis/attribution/10_rank_shift_scatter.py create mode 100644 scripts/analysis/diagnostics/01_analyze_VAEtraining_logs.py rename scripts/analysis/{ => diagnostics}/02_subject_cluster_analysis.py (100%) rename scripts/analysis/{ => diagnostics}/04_cluster_stability_analysis.py (100%) rename scripts/analysis/{ => diagnostics}/05_attention_confounder_analysis.py (100%) rename scripts/analysis/{ => diagnostics}/06_phase1_phase2_block_comparison.py (100%) rename scripts/analysis/{ => diagnostics}/08_clinical_pc_embedding_alignment.py (99%) rename scripts/analysis/{ => sensitivity}/03_leave_hla_out_analysis.py (100%) rename scripts/analysis/{ => sensitivity}/07_17q21_validation.py (100%) create mode 100644 scripts/analysis/sensitivity/11_top_snp_ld_check.py rename scripts/analysis/{ => validation}/01_block_embedding_phenotype_analysis.py (99%) create mode 100644 tests/test_phase2_diagnostics.py create mode 100644 tests/test_vae_diagnostics.py diff --git a/README.md b/README.md index 917585d..5bd2fcb 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ losing local genomic context. This project takes a hierarchical approach: 1. **Phase 1 — Per-block β-VAE.** Each LD block is encoded independently into a low-dimensional latent vector that captures local haplotype structure. -2. **Phase 2 — Cross-block Transformer.** A Transformer aggregates all block embeddings +2. **Phase 2 — Cross-block Attention.** Phase 2 aggregates all block embeddings into a single subject-level representation. Learned attention weights identify which blocks are most informative for organizing genetic variation across individuals. @@ -152,6 +152,12 @@ biologically coherent without supervised training. --- +## Upcoming Work + +This repository now includes block-level and SNP-level attribution analyses. The next step is to map high-attribution blocks and variants to candidate genes, then annotate those genes using public eQTL, pQTL, disease-association, pathway, druggability, and cell-type expression evidence. + +--- + ## Repository Structure ``` @@ -176,7 +182,6 @@ pytest.ini WORKFLOW.md Step-by-step execution guide with CLI examples run_pipeline.sh Single entry point — runs full pipeline or --dry-run input check test_run.sh Smoke test — Phase 1 → Phase 2 on synthetic data (no restricted data needed) -CLAUDE.md AI assistance constraints and workflow summary ``` See [WORKFLOW.md](WORKFLOW.md) for full CLI instructions, expected inputs/outputs per diff --git a/configs/config_phase1.yaml b/configs/config_phase1.yaml index ced1795..27ea00d 100644 --- a/configs/config_phase1.yaml +++ b/configs/config_phase1.yaml @@ -1,7 +1,7 @@ data: raw_dir: "data/region_blocks" block_def: "data/block_plan/manifest.tsv" - output_dir: "results/output_regions" + output_dir: "results/output_regions_ord_weighted" runtime: device: "cpu" vae: @@ -12,28 +12,36 @@ vae: - [150, 299, 12] - [300, 799, 16] - [800, 1000000000, 16] - dropout: 0.3 - lr: 0.001 + dropout: 0.1 + lr: 0.002 batch_size: 64 epochs: 200 - beta_max: 0.5 + beta_max: 1 # tuned: best for both ORD and MSE after per-element-mean normalization beta_warmup: 50 - patience: 20 + patience: 5 grad_clip: 1.0 - val_frac: 0.2 + val_frac: 0.20 + test_frac: 0.10 seed: 42 cat_weight_clip: 10.0 -loss_functions: ["MSE","ORD","MSE_STD","CAT","BCE"] -# loss_functions: ["ORD"] + # VAE diagnostics / behavior-preserving defaults + free_bits: 0.05 + ld_corr_repeats: 1 + ld_corr_max_snps: 200 + ord_weighted: true # apply per-sample class weights to ordinal_loss + ord_weight_clip: 10.0 # same convention as cat_weight_clip + standardize_embeddings: true +# loss_functions: ["MSE","ORD","MSE_STD","CAT","BCE"] +loss_functions: ["ORD"] tuning: enabled: true - loss: "ORD" # loss function to tune on - metric: "bal_acc_va" # or "best_val_loss", "ld_corr_va" - blocks: ["region_17q21_core_sb1", "region_11q13_FCER1A","region_2q12_IL1RL1_cluster_sb3","region_5q21_PDE4D_sb55", "region_6p21_HLA_classII_sb1","region_5q31_type2_cytokine_sb9", "region_1q31_TNFSF_cluster_sb4", "control_OCA2_sb10"] # list of block_ids to tune on, empty means all + loss: "ORD" + metric: "concordance_gain_va" + blocks: ["region_17q21_core_sb1", "region_11q13_FCER1A","region_2q12_IL1RL1_cluster_sb3","region_5q21_PDE4D_sb55", "region_6p21_HLA_classII_sb1","region_5q31_type2_cytokine_sb9", "region_1q31_TNFSF_cluster_sb4", "control_OCA2_sb10"] params: - dropout: [0.1, 0.3, 0.5] - lr: [0.0005, 0.001, 0.002] - beta_max: [0.3, 0.5] + dropout: [0.1, 0.3] + lr: [0.001, 0.002] + beta_max: [0.5, 1.0, 2.0, 4.0] representative: blocks: ["region_17q21_core_sb1", "region_11q13_FCER1A","region_2q12_IL1RL1_cluster_sb3","region_5q21_PDE4D_sb55", "region_6p21_HLA_classII_sb1","region_5q31_type2_cytokine_sb9", "region_1q31_TNFSF_cluster_sb4", "control_OCA2_sb10"] # list of block_ids to tune on, empty means all metric: "bal_acc_va" # metric for top/bottom selection diff --git a/configs/config_phase2.yaml b/configs/config_phase2.yaml index 54949eb..b6fb515 100644 --- a/configs/config_phase2.yaml +++ b/configs/config_phase2.yaml @@ -1,16 +1,17 @@ phase1_dir: results/output_regions output_dir: results/output_regions2 attention: - d_model: 64 + d_model: 128 n_heads: 4 n_layers: 2 - d_ff: 128 - dropout: 0.1 + d_ff: 512 + n_pool_tokens: 4 + dropout: 0.30 + epochs: 150 + patience: 20 lr: 0.0005 - weight_decay: 0.0001 + weight_decay: 0.001 batch_size: 64 - epochs: 300 - patience: 30 grad_clip: 1.0 seed: 42 extract_self_attn: true @@ -23,4 +24,4 @@ clustering: umap_n_neighbors: 15 umap_min_dist: 0.1 umap_seed: 42 -loss_functions: [ORD] \ No newline at end of file +loss_functions: [ORD_W_Scaled] \ No newline at end of file diff --git a/scripts/analysis/09_phase2_block_attribution.py b/scripts/analysis/attribution/09_phase2_block_attribution.py similarity index 81% rename from scripts/analysis/09_phase2_block_attribution.py rename to scripts/analysis/attribution/09_phase2_block_attribution.py index 972ad19..c81ee44 100644 --- a/scripts/analysis/09_phase2_block_attribution.py +++ b/scripts/analysis/attribution/09_phase2_block_attribution.py @@ -60,6 +60,7 @@ warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=RuntimeWarning) +import yaml import numpy as np import pandas as pd @@ -89,10 +90,12 @@ from scripts.core.attention_phase2 import AttentionAggregator # ── default paths (relative to repo root) ───────────────────────────────────── -P1_EMBEDDINGS = "results/output_regions/ORD/embeddings/all_blocks.npy" -P2_CHECKPOINT = "results/output_regions2/ORD/models/attention_aggregator.pt" -P2_EMBEDDINGS = "results/output_regions2/ORD/embeddings/individual_embeddings.npy" -P2_EMB_CSV = "results/output_regions2/ORD/embeddings/individual_embeddings.csv" +P1_EMBEDDINGS = "results/output_regions/ORD_W_Scaled/embeddings/all_blocks.npy" +P1_LATENT_DIMS = "results/output_regions/ORD_W_Scaled/embeddings/all_blocks_latent_dims.npy" +P2_CONFIG = "configs/config_phase2.yaml" +P2_CHECKPOINT = "results/output_regions2/ORD_W_Scaled/models/attention_aggregator.pt" +P2_EMBEDDINGS = "results/output_regions2/ORD_W_Scaled/embeddings/individual_embeddings.npy" +P2_EMB_CSV = "results/output_regions2/ORD_W_Scaled/embeddings/individual_embeddings.csv" BLOCK_ORDER = "results/output_regions/block_order.csv" SUBJECTS_CSV = "results/output_regions/subjects.csv" PHENO_FILE = "metadata/COS_TRIO_pheno_1165.csv" @@ -112,19 +115,57 @@ FULL_PC2_CSV = "results/analysis/phase2_block_attribution/phase2_PC2_leave_one_block_out.csv" +# ── config and block-dim helpers ────────────────────────────────────────────── +def load_phase2_config(config_path: str) -> dict: + """Return the [attention] section of config_phase2.yaml, or {} if not found.""" + p = Path(config_path) + if not p.exists(): + warnings.warn(f"Phase 2 config not found: {p} — using hardcoded defaults") + return {} + with open(p) as f: + cfg = yaml.safe_load(f) + return cfg.get("attention", {}) + + +def load_block_dims(latent_dims_path: str, block_meta: pd.DataFrame): + """Return per-block latent dim list, or None for legacy shared projection. + + Tries all_blocks_latent_dims.npy first (authoritative actual dims after + clamping), then falls back to block_order.csv['latent_dim']. + """ + p = Path(latent_dims_path) + if p.exists(): + dims = np.load(p).astype(int).tolist() + counts = dict(pd.Series(dims).value_counts().sort_index()) + print(f"[model] using BlockProjector with per-block dims: {counts}") + return dims + if "latent_dim" in block_meta.columns: + dims = block_meta["latent_dim"].astype(int).tolist() + counts = dict(pd.Series(dims).value_counts().sort_index()) + print(f"[model] using BlockProjector (from block_order.csv) with per-block dims: {counts}") + return dims + print("[model] using legacy shared input projection") + return None + + # ── model loader ────────────────────────────────────────────────────────────── def load_model(ckpt_path: str, n_blocks: int, d_in: int, d_model: int = 64, n_heads: int = 4, n_layers: int = 2, d_ff: int = 128, - dropout: float = 0.0) -> AttentionAggregator: - """ - Load the AttentionAggregator from checkpoint. - dropout=0.0 at inference time (no randomness). + dropout: float = 0.0, + block_dims: list = None) -> AttentionAggregator: + """Load AttentionAggregator from checkpoint. + + Pass block_dims to instantiate with BlockProjector (heterogeneous per-block + latent dims, checkpoint keys input_proj.projectors.*). Omit for legacy + shared projection (checkpoint keys input_proj.0.weight). + dropout=0.0 at inference — no randomness during attribution. """ model = AttentionAggregator( n_blocks=n_blocks, d_in=d_in, d_model=d_model, n_heads=n_heads, n_layers=n_layers, d_ff=d_ff, dropout=dropout, + block_dims=block_dims, ) state = torch.load(ckpt_path, map_location="cpu", weights_only=True) model.load_state_dict(state) @@ -249,7 +290,7 @@ def run_lobo( For each block b in include_block_indices (default: all blocks): - Replace x[:, b, :] with the cross-subject population mean for block b. - Run encode → masked subject embeddings. - - Project onto PC1/PC2, compute per-subject delta vs baseline. + - Project onto PCA axes, compute per-subject delta vs baseline. Parameters ---------- @@ -261,9 +302,13 @@ def run_lobo( ------- results_PC1 : list of dicts results_PC2 : list of dicts + results_PC3 : list of dicts + results_PC4 : list of dicts + results_PC5 : list of dicts results_ige : list of dicts or None """ B = x_np.shape[1] + n_pc = base_scores.shape[1] # number of PCA components available assert len(block_names) == B, f"block_names length {len(block_names)} != B={B}" indices_to_run = list(range(B)) if include_block_indices is None else list(include_block_indices) @@ -273,36 +318,36 @@ def run_lobo( do_ige = (ridge_ige is not None and pheno_mask is not None and pheno_y is not None) base_ige_pred = ridge_predict(ridge_ige, base_emb[pheno_mask]) if do_ige else None - results_PC1, results_PC2, results_ige = [], [], [] + results_PC1, results_PC2, results_PC3, results_PC4, results_PC5, results_ige = \ + [], [], [], [], [], [] print(f"\n[lobo] running {len(indices_to_run)}/{B} blocks " - f"(mean masking, batch_size={batch_size}) ...") + f"(mean masking, batch_size={batch_size}, n_pc={n_pc}) ...") for loop_idx, b in enumerate(indices_to_run): x_masked = x_np.copy() x_masked[:, b, :] = block_means[b] # replace block b with population mean - masked_emb = encode_batched(model, x_masked, batch_size) # (N, 64) - masked_scores = project_embeddings(masked_emb, scaler, pca) # (N, 2) - - delta_PC1 = masked_scores[:, 0] - base_scores[:, 0] - delta_PC2 = masked_scores[:, 1] - base_scores[:, 1] - - results_PC1.append(dict( - block_index=b, - block_id=block_names[b], - target="emb_PC1", - mean_abs_delta_score=float(np.abs(delta_PC1).mean()), - mean_signed_delta_score=float(delta_PC1.mean()), - n_subjects=len(delta_PC1), - )) - results_PC2.append(dict( - block_index=b, - block_id=block_names[b], - target="emb_PC2", - mean_abs_delta_score=float(np.abs(delta_PC2).mean()), - mean_signed_delta_score=float(delta_PC2.mean()), - n_subjects=len(delta_PC2), - )) + masked_emb = encode_batched(model, x_masked, batch_size) # (N, d_model) + masked_scores = project_embeddings(masked_emb, scaler, pca) # (N, n_pc) + + for pc_idx, (results_list, label) in enumerate([ + (results_PC1, "emb_PC1"), + (results_PC2, "emb_PC2"), + (results_PC3, "emb_PC3"), + (results_PC4, "emb_PC4"), + (results_PC5, "emb_PC5"), + ]): + if pc_idx >= n_pc: + continue + delta = masked_scores[:, pc_idx] - base_scores[:, pc_idx] + results_list.append(dict( + block_index=b, + block_id=block_names[b], + target=label, + mean_abs_delta_score=float(np.abs(delta).mean()), + mean_signed_delta_score=float(delta.mean()), + n_subjects=len(delta), + )) if do_ige: masked_ige_pred = ridge_predict(ridge_ige, masked_emb[pheno_mask]) @@ -320,7 +365,8 @@ def run_lobo( n_run = len(indices_to_run) print(f" block {loop_idx+1:3d}/{n_run} done") - return results_PC1, results_PC2, results_ige if do_ige else None + return results_PC1, results_PC2, results_PC3, results_PC4, results_PC5, \ + (results_ige if do_ige else None) # ── output helpers ──────────────────────────────────────────────────────────── @@ -410,6 +456,9 @@ def write_readme(out_dir: Path, method_note: str): |------|--------| | phase2_PC1_leave_one_block_out.csv | Phase 2 embedding PC1 | | phase2_PC2_leave_one_block_out.csv | Phase 2 embedding PC2 | +| phase2_PC3_leave_one_block_out.csv | Phase 2 embedding PC3 | +| phase2_PC4_leave_one_block_out.csv | Phase 2 embedding PC4 | +| phase2_PC5_leave_one_block_out.csv | Phase 2 embedding PC5 | | phase2_log10Ige_leave_one_block_out.csv | Ridge-predicted log10IgE (if computed) | ## What this is NOT @@ -600,6 +649,15 @@ def run_noHLA_mode(args): out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) + # ── 0. Load Phase 2 config (hyperparameters) ────────────────────────────── + print(f"[config] Phase 2 config: {args.phase2_config}") + attn_cfg = load_phase2_config(args.phase2_config) + d_model = attn_cfg.get("d_model", 64) + n_heads = attn_cfg.get("n_heads", 4) + n_layers = attn_cfg.get("n_layers", 2) + d_ff = attn_cfg.get("d_ff", 128) + print(f" d_model={d_model}, n_heads={n_heads}, n_layers={n_layers}, d_ff={d_ff}") + # ── 1. Load Phase 1 input embeddings ────────────────────────────────────── print(f"\n[load] Phase 1 ORD embeddings: {args.p1_embeddings}") x_np = np.load(args.p1_embeddings) @@ -608,7 +666,7 @@ def run_noHLA_mode(args): if B != 174: raise ValueError(f"Expected B=174 blocks, got B={B}. Check {args.p1_embeddings}.") - # ── 2. Load block order ─────────────────────────────────────────────────── + # ── 2. Load block order and per-block latent dims ───────────────────────── print(f"[load] block_order: {args.block_order}") block_meta = pd.read_csv(args.block_order) if len(block_meta) != B: @@ -617,6 +675,8 @@ def run_noHLA_mode(args): ) block_names = block_meta["block_id"].tolist() + block_dims = load_block_dims(args.p1_latent_dims, block_meta) + # ── 3. Identify HLA vs non-HLA block indices ────────────────────────────── hla_indices = [i for i, name in enumerate(block_names) if HLA_PATTERN in name] non_hla_indices = [i for i in range(B) if i not in set(hla_indices)] @@ -664,17 +724,19 @@ def run_noHLA_mode(args): # ── 7. Load saved no-HLA embeddings ────────────────────────────────────── print(f"[load] saved no-HLA Phase 2 embeddings: {args.noHLA_embeddings}") base_emb_saved = np.load(args.noHLA_embeddings) - if base_emb_saved.shape != (N, 64): + if base_emb_saved.shape != (N, d_model): raise ValueError( - f"Expected shape ({N}, 64), got {base_emb_saved.shape}." + f"Expected shape ({N}, {d_model}), got {base_emb_saved.shape}." ) # ── 8. Load no-HLA model ────────────────────────────────────────────────── print(f"[model] loading no-HLA checkpoint: {args.noHLA_checkpoint}") - model = load_model(args.noHLA_checkpoint, n_blocks=B, d_in=d_in) + model = load_model(args.noHLA_checkpoint, n_blocks=B, d_in=d_in, + d_model=d_model, n_heads=n_heads, n_layers=n_layers, + d_ff=d_ff, dropout=0.0, block_dims=block_dims) n_params = sum(p.numel() for p in model.parameters()) print(f" AttentionAggregator (no-HLA): n_blocks={B}, d_in={d_in}, " - f"d_model=64 — {n_params:,} params") + f"d_model={d_model}, n_heads={n_heads}, n_layers={n_layers} — {n_params:,} params") # ── 9. Verify baseline (zeroed input → no-HLA checkpoint vs saved embeddings) print("[check] verifying no-HLA baseline forward pass ...") @@ -683,13 +745,13 @@ def run_noHLA_mode(args): # ── 10. Fit PCA on saved no-HLA embeddings (defines reference axes) ─────── print("[PCA] fitting PCA on saved no-HLA embeddings ...") - scaler, pca, base_scores = fit_pca(base_emb_saved, n_components=2) + scaler, pca, base_scores = fit_pca(base_emb_saved, n_components=5) var_exp = pca.explained_variance_ratio_ - print(f" no-HLA PC1={var_exp[0]*100:.1f}% " - f"PC2={var_exp[1]*100:.1f}% of embedding variance") + print(" no-HLA variance explained: " + + " ".join(f"PC{i+1}={v*100:.1f}%" for i, v in enumerate(var_exp))) # ── 11. LOBO on non-HLA blocks (mean masking computed from zeroed input) ── - res_PC1, res_PC2, _ = run_lobo( + res_PC1, res_PC2, res_PC3, res_PC4, res_PC5, _ = run_lobo( model=model, x_np=x_np_zeroed, # HLA blocks permanently zeroed base_emb=base_emb_saved, @@ -704,19 +766,26 @@ def run_noHLA_mode(args): # ── 12. Build and save result DataFrames ────────────────────────────────── df_noHLA_PC1 = finalise_df(res_PC1, block_meta, method="LOBO_noHLA_checkpoint_mean_mask") df_noHLA_PC2 = finalise_df(res_PC2, block_meta, method="LOBO_noHLA_checkpoint_mean_mask") - - out_PC1 = out_dir / "phase2_noHLA_PC1_leave_one_block_out.csv" - out_PC2 = out_dir / "phase2_noHLA_PC2_leave_one_block_out.csv" - df_noHLA_PC1.to_csv(out_PC1, index=False) - df_noHLA_PC2.to_csv(out_PC2, index=False) - print(f"\n[output] {out_PC1.name} ({len(df_noHLA_PC1)} blocks)") - print(f"[output] {out_PC2.name} ({len(df_noHLA_PC2)} blocks)") + df_noHLA_PC3 = finalise_df(res_PC3, block_meta, method="LOBO_noHLA_checkpoint_mean_mask") + df_noHLA_PC4 = finalise_df(res_PC4, block_meta, method="LOBO_noHLA_checkpoint_mean_mask") + df_noHLA_PC5 = finalise_df(res_PC5, block_meta, method="LOBO_noHLA_checkpoint_mean_mask") + + for pc_label, df_noHLA in [ + ("PC1", df_noHLA_PC1), ("PC2", df_noHLA_PC2), + ("PC3", df_noHLA_PC3), ("PC4", df_noHLA_PC4), ("PC5", df_noHLA_PC5), + ]: + out_path = out_dir / f"phase2_noHLA_{pc_label}_leave_one_block_out.csv" + df_noHLA.to_csv(out_path, index=False) + print(f"[output] {out_path.name} ({len(df_noHLA)} blocks)") # ── 13. Plots ───────────────────────────────────────────────────────────── - plot_top_blocks(df_noHLA_PC1, "no-HLA Phase 2 emb PC1", - out_dir / "top20_noHLA_PC1_attribution.png", top_n=args.top_n) - plot_top_blocks(df_noHLA_PC2, "no-HLA Phase 2 emb PC2", - out_dir / "top20_noHLA_PC2_attribution.png", top_n=args.top_n) + for pc_label, df_noHLA in [ + ("PC1", df_noHLA_PC1), ("PC2", df_noHLA_PC2), + ("PC3", df_noHLA_PC3), ("PC4", df_noHLA_PC4), ("PC5", df_noHLA_PC5), + ]: + plot_top_blocks(df_noHLA, f"no-HLA Phase 2 emb {pc_label}", + out_dir / f"top20_noHLA_{pc_label}_attribution.png", + top_n=args.top_n) # ── 14. Comparison table ────────────────────────────────────────────────── comp_path = out_dir / "phase2_full_vs_noHLA_rank_comparison.csv" @@ -751,6 +820,12 @@ def parse_args(): # shared paths p.add_argument("--p1-embeddings", default=P1_EMBEDDINGS, help="Phase 1 ORD all_blocks.npy (N, B, d_in)") + p.add_argument("--p1-latent-dims", default=P1_LATENT_DIMS, + help="all_blocks_latent_dims.npy — actual per-block latent dims " + "(used to build BlockProjector; falls back to block_order.csv)") + p.add_argument("--phase2-config", default=P2_CONFIG, + help="Phase 2 YAML config — reads attention hyperparameters " + "(d_model, n_heads, n_layers, d_ff)") p.add_argument("--block-order", default=BLOCK_ORDER, help="block_order.csv") p.add_argument("--subjects", default=SUBJECTS_CSV, @@ -797,6 +872,15 @@ def main(): out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) + # ── 0. Load Phase 2 config (hyperparameters) ────────────────────────────── + print(f"[config] Phase 2 config: {args.phase2_config}") + attn_cfg = load_phase2_config(args.phase2_config) + d_model = attn_cfg.get("d_model", 64) + n_heads = attn_cfg.get("n_heads", 4) + n_layers = attn_cfg.get("n_layers", 2) + d_ff = attn_cfg.get("d_ff", 128) + print(f" d_model={d_model}, n_heads={n_heads}, n_layers={n_layers}, d_ff={d_ff}") + # ── 1. Load Phase 1 input embeddings ────────────────────────────────────── print(f"\n[load] Phase 1 ORD embeddings: {args.p1_embeddings}") x_np = np.load(args.p1_embeddings) @@ -805,7 +889,7 @@ def main(): if B != 174: raise ValueError(f"Expected B=174 blocks, got B={B}. Check {args.p1_embeddings}.") - # ── 2. Load block order ─────────────────────────────────────────────────── + # ── 2. Load block order and per-block latent dims ───────────────────────── print(f"[load] block_order: {args.block_order}") block_meta = pd.read_csv(args.block_order) if len(block_meta) != B: @@ -815,6 +899,8 @@ def main(): block_names = block_meta["block_id"].tolist() print(f" {B} blocks confirmed. Sample: {block_names[:3]}") + block_dims = load_block_dims(args.p1_latent_dims, block_meta) + # ── 3. Load subject IIDs ────────────────────────────────────────────────── print(f"[load] subjects: {args.subjects}") subjects_csv = pd.read_csv(args.subjects) @@ -843,17 +929,19 @@ def main(): # ── 4. Load saved baseline embeddings ───────────────────────────────────── print(f"[load] saved Phase 2 embeddings: {args.p2_embeddings}") base_emb_saved = np.load(args.p2_embeddings) - if base_emb_saved.shape != (N, 64): + if base_emb_saved.shape != (N, d_model): raise ValueError( - f"Expected shape ({N}, 64), got {base_emb_saved.shape}." + f"Expected shape ({N}, {d_model}), got {base_emb_saved.shape}." ) # ── 5. Load model ───────────────────────────────────────────────────────── print(f"[model] loading checkpoint: {args.checkpoint}") - model = load_model(args.checkpoint, n_blocks=B, d_in=d_in) + model = load_model(args.checkpoint, n_blocks=B, d_in=d_in, + d_model=d_model, n_heads=n_heads, n_layers=n_layers, + d_ff=d_ff, dropout=0.0, block_dims=block_dims) n_params = sum(p.numel() for p in model.parameters()) - print(f" AttentionAggregator: n_blocks={B}, d_in={d_in}, d_model=64, " - f"n_heads=4, n_layers=2 — {n_params:,} params") + print(f" AttentionAggregator: n_blocks={B}, d_in={d_in}, d_model={d_model}, " + f"n_heads={n_heads}, n_layers={n_layers} — {n_params:,} params") # ── 6. Verify baseline (model vs saved) ─────────────────────────────────── print("[check] verifying baseline forward pass ...") @@ -862,9 +950,10 @@ def main(): # ── 7. Fit PCA on saved embeddings (defines the reference axes) ─────────── print("[PCA] fitting PCA on saved embeddings ...") - scaler, pca, base_scores = fit_pca(base_emb_saved, n_components=2) + scaler, pca, base_scores = fit_pca(base_emb_saved, n_components=5) var_exp = pca.explained_variance_ratio_ - print(f" PC1={var_exp[0]*100:.1f}% PC2={var_exp[1]*100:.1f}% of embedding variance") + print(" variance explained: " + + " ".join(f"PC{i+1}={v*100:.1f}%" for i, v in enumerate(var_exp))) # ── 8. Phenotype loading + ridge ────────────────────────────────────────── do_pheno = not args.skip_phenotype @@ -883,7 +972,7 @@ def main(): do_pheno = False # ── 9. Leave-one-block-out ──────────────────────────────────────────────── - res_PC1, res_PC2, res_ige = run_lobo( + res_PC1, res_PC2, res_PC3, res_PC4, res_PC5, res_ige = run_lobo( model=model, x_np=x_np, base_emb=base_emb_saved, @@ -904,11 +993,20 @@ def main(): # ── 10. Build and save result DataFrames ────────────────────────────────── df_PC1 = finalise_df(res_PC1, block_meta, method="LOBO_checkpoint_mean_mask") df_PC2 = finalise_df(res_PC2, block_meta, method="LOBO_checkpoint_mean_mask") + df_PC3 = finalise_df(res_PC3, block_meta, method="LOBO_checkpoint_mean_mask") + df_PC4 = finalise_df(res_PC4, block_meta, method="LOBO_checkpoint_mean_mask") + df_PC5 = finalise_df(res_PC5, block_meta, method="LOBO_checkpoint_mean_mask") df_PC1.to_csv(out_dir / "phase2_PC1_leave_one_block_out.csv", index=False) df_PC2.to_csv(out_dir / "phase2_PC2_leave_one_block_out.csv", index=False) + df_PC3.to_csv(out_dir / "phase2_PC3_leave_one_block_out.csv", index=False) + df_PC4.to_csv(out_dir / "phase2_PC4_leave_one_block_out.csv", index=False) + df_PC5.to_csv(out_dir / "phase2_PC5_leave_one_block_out.csv", index=False) print(f"\n[output] phase2_PC1_leave_one_block_out.csv ({len(df_PC1)} blocks)") print(f"[output] phase2_PC2_leave_one_block_out.csv ({len(df_PC2)} blocks)") + print(f"[output] phase2_PC3_leave_one_block_out.csv ({len(df_PC3)} blocks)") + print(f"[output] phase2_PC4_leave_one_block_out.csv ({len(df_PC4)} blocks)") + print(f"[output] phase2_PC5_leave_one_block_out.csv ({len(df_PC5)} blocks)") if do_pheno and res_ige is not None: df_ige = finalise_df(res_ige, block_meta, method="LOBO_checkpoint_mean_mask") @@ -920,6 +1018,12 @@ def main(): out_dir / "top20_PC1_attribution.png", top_n=args.top_n) plot_top_blocks(df_PC2, "Phase 2 emb PC2", out_dir / "top20_PC2_attribution.png", top_n=args.top_n) + plot_top_blocks(df_PC3, "Phase 2 emb PC3", + out_dir / "top20_PC3_attribution.png", top_n=args.top_n) + plot_top_blocks(df_PC4, "Phase 2 emb PC4", + out_dir / "top20_PC4_attribution.png", top_n=args.top_n) + plot_top_blocks(df_PC5, "Phase 2 emb PC5", + out_dir / "top20_PC5_attribution.png", top_n=args.top_n) if do_pheno and res_ige is not None: plot_top_blocks(df_ige, "log10IgE ridge-pred", out_dir / "top20_log10Ige_attribution.png", top_n=args.top_n) @@ -933,15 +1037,16 @@ def main(): def flag_focus(bid): return "YES" if any(p in str(bid) for p in focus) else "-" - print("\n══ Top 10 blocks — PC1 ════════════════════════════════") - top10_PC1 = df_PC1.head(10)[["rank", "block_id", "mean_abs_delta_score", "mean_signed_delta_score"]].copy() - top10_PC1["focus"] = top10_PC1["block_id"].apply(flag_focus) - print(top10_PC1.to_string(index=False)) - - print("\n══ Top 10 blocks — PC2 ════════════════════════════════") - top10_PC2 = df_PC2.head(10)[["rank", "block_id", "mean_abs_delta_score", "mean_signed_delta_score"]].copy() - top10_PC2["focus"] = top10_PC2["block_id"].apply(flag_focus) - print(top10_PC2.to_string(index=False)) + for pc_label, df_pc in [ + ("PC1", df_PC1), ("PC2", df_PC2), + ("PC3", df_PC3), ("PC4", df_PC4), ("PC5", df_PC5), + ]: + print(f"\n══ Top 10 blocks — {pc_label} ════════════════════════════════") + top10 = df_pc.head(10)[ + ["rank", "block_id", "mean_abs_delta_score", "mean_signed_delta_score"] + ].copy() + top10["focus"] = top10["block_id"].apply(flag_focus) + print(top10.to_string(index=False)) if do_pheno and res_ige is not None: print("\n══ Top 10 blocks — log10IgE (ridge) ══════════════════") diff --git a/scripts/analysis/attribution/09_phase2_block_attribution_updated.py b/scripts/analysis/attribution/09_phase2_block_attribution_updated.py new file mode 100644 index 0000000..0b9f55e --- /dev/null +++ b/scripts/analysis/attribution/09_phase2_block_attribution_updated.py @@ -0,0 +1,1320 @@ +#!/usr/bin/env python3 +""" +09_phase2_block_attribution.py + +Post-hoc Phase 2 block-level attribution via leave-one-block-out (LOBO) masking. + +Requires torch (install via the 'vae' conda environment or any env with PyTorch). +Run as: + KMP_DUPLICATE_LIB_OK=TRUE python scripts/analysis/09_phase2_block_attribution.py + +What this does +-------------- +1. Loads the saved Phase 2 AttentionAggregator checkpoint. +2. Passes the Phase 1 block embeddings (N=997, B=174, d_in=16) through the frozen + model to confirm baseline subject embeddings. +3. Fits PCA on the *saved* individual_embeddings.npy to define PC1/PC2 axes exactly + as produced by the original run. +4. For each of the 174 blocks, masks that block by replacing its values with the + cross-subject population mean for that block (mean masking, not zero masking, + to minimise distributional shift to the model). +5. Runs a fresh forward pass and projects new embeddings onto the original PC1/PC2 + axes. Computes mean absolute delta and mean signed delta per block. +6. Optionally attributes log10Ige: merges phenotype, fits a RidgeCV on the baseline + embedding, and computes delta in ridge-predicted score under each block mask. + +Attribution targets +------------------- +- Phase 2 embedding PC1 +- Phase 2 embedding PC2 +- log10Ige ridge-prediction score (if subject ID alignment is confirmed) + +Outputs (results/analysis/phase2_block_attribution/) +-------- +phase2_PC1_leave_one_block_out.csv +phase2_PC2_leave_one_block_out.csv +phase2_log10Ige_leave_one_block_out.csv (only if phenotype attribution is valid) +top20_PC1_attribution.png +top20_PC2_attribution.png +top20_log10Ige_attribution.png (only if valid) +README_phase2_block_attribution.md + +Limitations +----------- +- Masking affects downstream transformer contextualization in the original model. + A "perfect" attribution would re-run the model from Phase 1 genotype inputs + with block b completely ablated. This script goes one level up: it masks the + Phase 1 latent embedding for block b before feeding into the Phase 2 transformer. + Residual information about block b may persist in positional embeddings and in + the transformer's learned cross-block interactions. +- Attribution reflects embedding geometry change, not causal importance. +- Do not interpret high attribution as evidence that a block causes the phenotype. +""" + +import os +import sys +import warnings +import argparse +from pathlib import Path + +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=RuntimeWarning) + +import yaml +import numpy as np +import pandas as pd + +try: + import torch + import torch.nn as nn + import torch.nn.functional as F +except ImportError: + sys.exit( + "ERROR: torch not found.\n" + "Run in the vae conda environment:\n" + " KMP_DUPLICATE_LIB_OK=TRUE /path/to/vae/bin/python " + "scripts/analysis/09_phase2_block_attribution.py" + ) + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +from sklearn.decomposition import PCA +from sklearn.linear_model import RidgeCV +from sklearn.preprocessing import StandardScaler + +# ── add repo root to sys.path so we can import AttentionAggregator ──────────── +def find_repo_root(start: Path) -> Path: + """Walk upward until repo root is found.""" + for p in [start] + list(start.parents): + if (p / "scripts" / "core" / "attention_phase2.py").exists(): + return p + raise RuntimeError( + "Could not find repo root containing scripts/core/attention_phase2.py" + ) + + +_REPO_ROOT = find_repo_root(Path(__file__).resolve()) +sys.path.insert(0, str(_REPO_ROOT)) + +from scripts.core.attention_phase2 import AttentionAggregator + +# ── default paths (relative to repo root) ───────────────────────────────────── +P1_EMBEDDINGS = "results/output_regions/ORD_W_Scaled/embeddings/all_blocks.npy" +P1_LATENT_DIMS = "results/output_regions/ORD_W_Scaled/embeddings/all_blocks_latent_dims.npy" +P2_CONFIG = "configs/config_phase2.yaml" +P2_CHECKPOINT = "results/output_regions2/ORD_W_Scaled/models/attention_aggregator.pt" +P2_EMBEDDINGS = "results/output_regions2/ORD_W_Scaled/embeddings/individual_embeddings.npy" +P2_EMB_CSV = "results/output_regions2/ORD_W_Scaled/embeddings/individual_embeddings.csv" +BLOCK_ORDER = "results/output_regions/block_order.csv" +SUBJECTS_CSV = "results/output_regions/subjects.csv" +PHENO_FILE = "metadata/COS_TRIO_pheno_1165.csv" +DEFAULT_OUT = "results/analysis/phase2_block_attribution" + +BATCH_SIZE = 256 +PHENO_TARGET = "log10Ige" + +# ── no-HLA run paths ────────────────────────────────────────────────────────── +P2_CHECKPOINT_NOHLA = "results/output_regions2_noHLA/ORD_W_Scaled/models/attention_aggregator.pt" +P2_EMBEDDINGS_NOHLA = "results/output_regions2_noHLA/ORD_W_Scaled/embeddings/individual_embeddings.npy" +P2_EMB_CSV_NOHLA = "results/output_regions2_noHLA/ORD_W_Scaled/embeddings/individual_embeddings.csv" +HLA_PATTERN = "HLA_classII" + +# Full-run CSVs needed when building the comparison table +FULL_PC1_CSV = "results/analysis/phase2_block_attribution/phase2_PC1_leave_one_block_out.csv" +FULL_PC2_CSV = "results/analysis/phase2_block_attribution/phase2_PC2_leave_one_block_out.csv" + + +# ── config and block-dim helpers ────────────────────────────────────────────── +def load_phase2_config(config_path: str) -> dict: + """Return the [attention] section of config_phase2.yaml, or {} if not found.""" + p = Path(config_path) + if not p.exists(): + warnings.warn(f"Phase 2 config not found: {p} — using hardcoded defaults") + return {} + with open(p) as f: + cfg = yaml.safe_load(f) + return cfg.get("attention", {}) + + +def load_block_dims(latent_dims_path: str, block_meta: pd.DataFrame): + """Return per-block latent dim list, or None for legacy shared projection. + + Tries all_blocks_latent_dims.npy first (authoritative actual dims after + clamping), then falls back to block_order.csv['latent_dim']. + """ + p = Path(latent_dims_path) + if p.exists(): + dims = np.load(p).astype(int).tolist() + counts = dict(pd.Series(dims).value_counts().sort_index()) + print(f"[model] using BlockProjector with per-block dims: {counts}") + return dims + if "latent_dim" in block_meta.columns: + dims = block_meta["latent_dim"].astype(int).tolist() + counts = dict(pd.Series(dims).value_counts().sort_index()) + print(f"[model] using BlockProjector (from block_order.csv) with per-block dims: {counts}") + return dims + print("[model] using legacy shared input projection") + return None + + +# ── model loader ────────────────────────────────────────────────────────────── +def load_model(ckpt_path: str, n_blocks: int, d_in: int, + d_model: int = 64, n_heads: int = 4, + n_layers: int = 2, d_ff: int = 128, + dropout: float = 0.0, + block_dims: list = None, + n_pool_tokens: int = 1) -> AttentionAggregator: + """Load AttentionAggregator from checkpoint. + + Pass block_dims to instantiate with BlockProjector (heterogeneous per-block + latent dims, checkpoint keys input_proj.projectors.*). Omit for legacy + shared projection (checkpoint keys input_proj.0.weight). + dropout=0.0 at inference — no randomness during attribution. + """ + model = AttentionAggregator( + n_blocks=n_blocks, d_in=d_in, d_model=d_model, + n_heads=n_heads, n_layers=n_layers, d_ff=d_ff, + dropout=dropout, + block_dims=block_dims, + n_pool_tokens=n_pool_tokens, + ) + state = torch.load(ckpt_path, map_location="cpu", weights_only=True) + model.load_state_dict(state) + model.eval() + for p in model.parameters(): + p.requires_grad_(False) + return model + + +# ── batched inference ───────────────────────────────────────────────────────── +@torch.no_grad() +def encode_batched(model: AttentionAggregator, x_np: np.ndarray, + batch_size: int = BATCH_SIZE) -> np.ndarray: + """ + Run model.encode on x_np (N, B, d_in) in mini-batches. + Returns subject embeddings (N, emb_dim), where emb_dim = n_pool_tokens * d_model. + """ + N = x_np.shape[0] + results = [] + for start in range(0, N, batch_size): + batch = torch.tensor(x_np[start:start + batch_size], dtype=torch.float32) + emb, _, _ = model.encode(batch, return_self_attn=False) + results.append(emb.numpy()) + return np.concatenate(results, axis=0) + + +# ── PC projection ───────────────────────────────────────────────────────────── +def fit_pca(embeddings_np: np.ndarray, n_components: int = 2): + """ + Fit StandardScaler + PCA on embeddings. Returns (scaler, pca, scores). + """ + scaler = StandardScaler() + z = scaler.fit_transform(embeddings_np) + pca = PCA(n_components=n_components, random_state=42) + scores = pca.fit_transform(z) + return scaler, pca, scores + + +def project_embeddings(embeddings_np: np.ndarray, + scaler: StandardScaler, pca: PCA) -> np.ndarray: + """Project embeddings onto pre-fitted PCA axes. Returns (N, n_components).""" + z = scaler.transform(embeddings_np) + return pca.transform(z) + + +# ── phenotype loader ────────────────────────────────────────────────────────── +def load_phenotype(pheno_path: str, subject_iids: np.ndarray, + target: str = PHENO_TARGET): + """ + Merge phenotype file on subject IIDs. + + Returns + ------- + pheno_mask : bool array (N,) — subjects with non-NA phenotype + y : float array (N_pheno,) — phenotype values + """ + pheno = pd.read_csv(pheno_path, low_memory=False) + if "S_SUBJECTID" not in pheno.columns: + raise ValueError("Expected 'S_SUBJECTID' in phenotype file.") + pheno["IID"] = pheno["S_SUBJECTID"].astype(str).str.strip() + + if target not in pheno.columns: + raise ValueError(f"Target column '{target}' not in phenotype file.") + + pheno_sub = pheno[["IID", target]].dropna(subset=[target]) + iid_to_pheno = dict(zip(pheno_sub["IID"], pheno_sub[target])) + + pheno_vals = np.array([iid_to_pheno.get(iid, np.nan) for iid in subject_iids]) + pheno_mask = np.isfinite(pheno_vals) + + n_with = pheno_mask.sum() + print(f"[phenotype] {target}: {n_with}/{len(subject_iids)} subjects have data") + if n_with < 50: + warnings.warn(f"Only {n_with} subjects with {target} — phenotype attribution unreliable.") + + return pheno_mask, pheno_vals[pheno_mask] + + +# ── ridge regression for phenotype ──────────────────────────────────────────── +def fit_ridge(base_emb: np.ndarray, y: np.ndarray) -> RidgeCV: + """ + Fit RidgeCV on baseline embeddings for phenotype attribution. + Uses 5-fold CV over log-spaced alphas. + """ + alphas = np.logspace(-3, 4, 50) + scaler = StandardScaler() + X = scaler.fit_transform(base_emb) + ridge = RidgeCV(alphas=alphas, scoring="r2", cv=5) + ridge.fit(X, y) + # attach scaler so we can use it later + ridge._feature_scaler = scaler + r2_cv = ridge.best_score_ + print(f"[ridge] {PHENO_TARGET} baseline: best α={ridge.alpha_:.4g}, CV R²={r2_cv:.3f}") + return ridge + + +def ridge_predict(ridge: RidgeCV, emb: np.ndarray) -> np.ndarray: + X = ridge._feature_scaler.transform(emb) + return ridge.predict(X) + + +# ── leave-one-block-out loop ────────────────────────────────────────────────── +def run_lobo( + model: AttentionAggregator, + x_np: np.ndarray, + base_emb: np.ndarray, + scaler: StandardScaler, + pca: PCA, + base_scores: np.ndarray, + block_names: list, + batch_size: int = BATCH_SIZE, + include_block_indices=None, + ridge_PC1=None, + ridge_PC2=None, + pheno_mask=None, + pheno_y=None, + ridge_ige=None, +): + """ + Leave-one-block-out masking attribution. + + For each block b in include_block_indices (default: all blocks): + - Replace x[:, b, :] with the cross-subject population mean for block b. + - Run encode → masked subject embeddings. + - Project onto PCA axes, compute per-subject delta vs baseline. + + Parameters + ---------- + include_block_indices : list of int or None + Block indices to include in LOBO. None = all blocks. + Use to skip HLA blocks in a no-HLA run. + + Returns + ------- + results_PC1 : list of dicts + results_PC2 : list of dicts + results_PC3 : list of dicts + results_PC4 : list of dicts + results_PC5 : list of dicts + results_ige : list of dicts or None + """ + B = x_np.shape[1] + n_pc = base_scores.shape[1] # number of PCA components available + assert len(block_names) == B, f"block_names length {len(block_names)} != B={B}" + indices_to_run = list(range(B)) if include_block_indices is None else list(include_block_indices) + + # pre-compute per-block population means for mean-masking + block_means = x_np.mean(axis=0) # (B, d_in) + + do_ige = (ridge_ige is not None and pheno_mask is not None and pheno_y is not None) + base_ige_pred = ridge_predict(ridge_ige, base_emb[pheno_mask]) if do_ige else None + + results_PC1, results_PC2, results_PC3, results_PC4, results_PC5, results_ige = \ + [], [], [], [], [], [] + + print(f"\n[lobo] running {len(indices_to_run)}/{B} blocks " + f"(mean masking, batch_size={batch_size}, n_pc={n_pc}) ...") + for loop_idx, b in enumerate(indices_to_run): + x_masked = x_np.copy() + x_masked[:, b, :] = block_means[b] # replace block b with population mean + + masked_emb = encode_batched(model, x_masked, batch_size) # (N, emb_dim) + masked_scores = project_embeddings(masked_emb, scaler, pca) # (N, n_pc) + + for pc_idx, (results_list, label) in enumerate([ + (results_PC1, "emb_PC1"), + (results_PC2, "emb_PC2"), + (results_PC3, "emb_PC3"), + (results_PC4, "emb_PC4"), + (results_PC5, "emb_PC5"), + ]): + if pc_idx >= n_pc: + continue + delta = masked_scores[:, pc_idx] - base_scores[:, pc_idx] + results_list.append(dict( + block_index=b, + block_id=block_names[b], + target=label, + mean_abs_delta_score=float(np.abs(delta).mean()), + mean_signed_delta_score=float(delta.mean()), + n_subjects=len(delta), + )) + + if do_ige: + masked_ige_pred = ridge_predict(ridge_ige, masked_emb[pheno_mask]) + delta_ige = masked_ige_pred - base_ige_pred + results_ige.append(dict( + block_index=b, + block_id=block_names[b], + target="log10Ige_ridge_pred", + mean_abs_delta_score=float(np.abs(delta_ige).mean()), + mean_signed_delta_score=float(delta_ige.mean()), + n_subjects=int(pheno_mask.sum()), + )) + + if (loop_idx + 1) % 25 == 0 or loop_idx == len(indices_to_run) - 1: + n_run = len(indices_to_run) + print(f" block {loop_idx+1:3d}/{n_run} done") + + return results_PC1, results_PC2, results_PC3, results_PC4, results_PC5, \ + (results_ige if do_ige else None) + + +# ── output helpers ──────────────────────────────────────────────────────────── +def finalise_df(records: list, block_meta: pd.DataFrame, method: str) -> pd.DataFrame: + """Add rank, method, and block metadata columns.""" + df = pd.DataFrame(records) + df = df.sort_values("mean_abs_delta_score", ascending=False).reset_index(drop=True) + df["rank"] = df.index + 1 + df["method"] = method + df["notes"] = ( + "LOBO masking: block replaced by cross-subject population mean of Phase 1 " + "latent embedding; full Phase 2 forward pass re-run; checkpoint-based." + ) + # merge block metadata (pos, gene, n_snps) if available + meta_cols = [ + c for c in [ + "pos", "gene", "region_id", "region", "group", + "chr", "n_snps", "latent_dim" + ] + if c in block_meta.columns + ] + + if meta_cols: + df = df.merge(block_meta[["block_id"] + meta_cols], on="block_id", how="left") + + # If latent_dim was not in block_order.csv, infer it from padded Phase 1 dim later if needed. + df = add_normalized_attribution_columns(df) + + # Add normalized ranks + for col in [ + "mean_abs_delta_score", + "attr_per_snp", + "attr_per_sqrt_snp", + "attr_per_latent_dim", + "attr_per_sqrt_latent_dim", + "attr_per_snp_x_latent", + ]: + if col in df.columns: + rank_col = col + "_rank" + df[rank_col] = df[col].rank(ascending=False, method="min") + + return df + +def add_normalized_attribution_columns(df: pd.DataFrame) -> pd.DataFrame: + """ + Add size/capacity-normalized attribution metrics. + + Raw LOBO favors large/high-capacity blocks. These columns help separate: + - raw model dependence + - per-SNP impact + - per-latent-dimension impact + """ + df = df.copy() + + if "n_snps" in df.columns: + df["attr_per_snp"] = df["mean_abs_delta_score"] / df["n_snps"].replace(0, np.nan) + df["attr_per_sqrt_snp"] = df["mean_abs_delta_score"] / np.sqrt(df["n_snps"].replace(0, np.nan)) + + if "latent_dim" in df.columns: + df["attr_per_latent_dim"] = df["mean_abs_delta_score"] / df["latent_dim"].replace(0, np.nan) + df["attr_per_sqrt_latent_dim"] = df["mean_abs_delta_score"] / np.sqrt( + df["latent_dim"].replace(0, np.nan) + ) + + if "n_snps" in df.columns and "latent_dim" in df.columns: + df["attr_per_snp_x_latent"] = df["mean_abs_delta_score"] / ( + df["n_snps"].replace(0, np.nan) * df["latent_dim"].replace(0, np.nan) + ) + + return df + +def infer_region_id(block_id: str) -> str: + """ + Convert subblock IDs into region IDs. + Examples: + region_6p21_HLA_classII_sb4 -> region_6p21_HLA_classII + region_5q21_PDE4D_sb14 -> region_5q21_PDE4D + control_CFTR_sb3 -> control_CFTR + """ + s = str(block_id) + if "_sb" in s: + return s.rsplit("_sb", 1)[0] + return s + + +def write_region_summary(df: pd.DataFrame, out_path: Path, target_label: str): + """ + Summarize attribution at region level. + + sum_raw favors large regions. + mean/median/top3 are better for comparing regions with different subblock counts. + """ + d = df.copy() + + if "region_id" not in d.columns: + d["region_id"] = d["block_id"].apply(infer_region_id) + + agg_dict = { + "n_subblocks": ("block_id", "count"), + "sum_raw_attr": ("mean_abs_delta_score", "sum"), + "mean_raw_attr": ("mean_abs_delta_score", "mean"), + "median_raw_attr": ("mean_abs_delta_score", "median"), + "max_raw_attr": ("mean_abs_delta_score", "max"), + } + + for col in [ + "attr_per_snp", + "attr_per_sqrt_snp", + "attr_per_latent_dim", + "attr_per_sqrt_latent_dim", + ]: + if col in d.columns: + agg_dict[f"mean_{col}"] = (col, "mean") + agg_dict[f"max_{col}"] = (col, "max") + + region = ( + d.groupby("region_id") + .agg(**agg_dict) + .reset_index() + ) + + # Top-3 average raw attribution per region + top3 = ( + d.sort_values(["region_id", "mean_abs_delta_score"], ascending=[True, False]) + .groupby("region_id") + .head(3) + .groupby("region_id")["mean_abs_delta_score"] + .mean() + .reset_index(name="top3_mean_raw_attr") + ) + + region = region.merge(top3, on="region_id", how="left") + + # ranks + for col in [ + "sum_raw_attr", + "mean_raw_attr", + "median_raw_attr", + "max_raw_attr", + "top3_mean_raw_attr", + ]: + region[col + "_rank"] = region[col].rank(ascending=False, method="min") + + region = region.sort_values("mean_raw_attr", ascending=False) + region["target"] = target_label + region.to_csv(out_path, index=False) + print(f"[output] {out_path.name} ({len(region)} regions)") + return region + +def plot_top_blocks(df: pd.DataFrame, target_label: str, + out_path: Path, top_n: int = 20, + score_col: str = "mean_abs_delta_score", + focus_patterns=("HLA", "17q21", "IL1RL1", "FCER1A", "PDE4D")): + if score_col not in df.columns: + print(f"[plot] WARNING: {score_col} not found; using mean_abs_delta_score") + score_col = "mean_abs_delta_score" + + top = df.sort_values(score_col, ascending=False).head(top_n).copy() + colors = [] + for bid in top["block_id"]: + bid_s = str(bid) + if any(p in bid_s for p in focus_patterns): + colors.append("tomato") + else: + colors.append("steelblue") + + fig, ax = plt.subplots(figsize=(8, max(5, top_n * 0.35))) + y = np.arange(len(top)) + ax.barh(y, top[score_col].values, color=colors, alpha=0.85) + ax.set_yticks(y) + ax.set_yticklabels(top["block_id"].values, fontsize=8) + ax.invert_yaxis() + ax.set_xlabel(score_col) + ax.set_title(f"Top {top_n} blocks — {target_label}\n" + "(red = asthma-associated or control locus)") + + # signed delta as annotation + for i, (_, row) in enumerate(top.iterrows()): + signed = row["mean_signed_delta_score"] + ax.text(row[score_col] + 1e-4, i, + f" signed={signed:+.4f}", va="center", fontsize=7) + + plt.tight_layout() + fig.savefig(out_path, dpi=150) + plt.close(fig) + print(f"[plot] saved {out_path}") + + +def write_readme(out_dir: Path, method_note: str): + text = f"""# Phase 2 Block Attribution — README + +## What is this? + +This directory contains post-hoc **Phase 2 block-level attribution** results +for the block-based genotype embedding model. + +## What was done? + +Leave-one-block-out (LOBO) masking was applied to identify which of the 174 +genomic blocks changes the Phase 2 subject embedding axes (PC1, PC2) most +when that block's Phase 1 latent embedding is replaced by its cross-subject +population mean. + +## Method: Leave-One-Block-Out + +For each block b (b = 0 … 173): +1. The Phase 1 block embedding for block b across all 997 subjects is replaced + by the population mean of that block (mean masking). +2. The full Phase 2 AttentionAggregator forward pass is re-run on the masked + input (checkpoint-based — no retraining). +3. The new subject embeddings are projected onto the PC1/PC2 axes fitted on + the original (unmasked) embeddings. +4. Attribution = mean absolute change in PC score across subjects. +5. Mean signed change is also reported (direction of effect). + +{method_note} + +## Attribution Targets + +| File | Target | +|------|--------| +| phase2_PC1_leave_one_block_out.csv | Phase 2 embedding PC1 | +| phase2_PC2_leave_one_block_out.csv | Phase 2 embedding PC2 | +| phase2_PC3_leave_one_block_out.csv | Phase 2 embedding PC3 | +| phase2_PC4_leave_one_block_out.csv | Phase 2 embedding PC4 | +| phase2_PC5_leave_one_block_out.csv | Phase 2 embedding PC5 | +| phase2_log10Ige_leave_one_block_out.csv | Ridge-predicted log10IgE (if computed) | + +## What this is NOT + +- This is **not SNP-level attribution** — it operates at the block/region level. +- This is **not causal evidence** — masking a block reduces its information, but + a block may rank highly due to its correlation with other blocks or global + embedding geometry, not biological importance. +- This is **not retraining** — the Phase 2 model weights are frozen throughout. +- The phenotype target (log10IgE) uses a **ridge regression prediction score**, + not a direct phenotype-embedding association. + +## Why Leave-One-Block-Out? + +LOBO is simple, interpretable, and directly aligned with the block-based model +structure. It requires no gradient computation (post-hoc, frozen model) and +produces a single summary statistic per block. + +## Residual Limitations + +- Phase 1 embedding for a masked block is replaced with the population mean, + not zero. This avoids extreme distribution shift, but the positional embedding + added by the Transformer still encodes block position, so complete ablation + is not achieved. +- Blocks with correlated Phase 1 embeddings will show similar attribution + patterns. + +## Environment + +Requires PyTorch (`vae` conda environment). +Run: `KMP_DUPLICATE_LIB_OK=TRUE python scripts/analysis/09_phase2_block_attribution.py` +""" + (out_dir / "README_phase2_block_attribution.md").write_text(text) + print(f"[readme] saved README_phase2_block_attribution.md") + + +# ── comparison table ────────────────────────────────────────────────────────── +def build_comparison_table( + full_pc1_csv: str, full_pc2_csv: str, + df_noHLA_PC1: pd.DataFrame, df_noHLA_PC2: pd.DataFrame, + all_block_names: list, out_path: Path, +) -> pd.DataFrame: + """ + Join full-run and no-HLA LOBO results into a per-block rank comparison table. + + Columns + ------- + block_id, full_PC1_rank, full_PC2_rank, noHLA_PC1_rank, noHLA_PC2_rank, + rank_change_PC1 (noHLA - full, negative = rose in rank after HLA removal), + rank_change_PC2, is_PDE4D, is_HLA, is_17q21, is_IL1RL1, is_FCER1A + """ + df_full_PC1 = pd.read_csv(full_pc1_csv)[["block_id", "rank"]].rename( + columns={"rank": "full_PC1_rank"}) + df_full_PC2 = pd.read_csv(full_pc2_csv)[["block_id", "rank"]].rename( + columns={"rank": "full_PC2_rank"}) + npc1 = df_noHLA_PC1[["block_id", "rank"]].rename(columns={"rank": "noHLA_PC1_rank"}) + npc2 = df_noHLA_PC2[["block_id", "rank"]].rename(columns={"rank": "noHLA_PC2_rank"}) + + comp = (pd.DataFrame({"block_id": all_block_names}) + .merge(df_full_PC1, on="block_id", how="left") + .merge(df_full_PC2, on="block_id", how="left") + .merge(npc1, on="block_id", how="left") + .merge(npc2, on="block_id", how="left")) + + comp["rank_change_PC1"] = comp["noHLA_PC1_rank"] - comp["full_PC1_rank"] + comp["rank_change_PC2"] = comp["noHLA_PC2_rank"] - comp["full_PC2_rank"] + + for pat, col in [("PDE4D", "is_PDE4D"), ("HLA", "is_HLA"), + ("17q21", "is_17q21"), ("IL1RL1", "is_IL1RL1"), ("FCER1A", "is_FCER1A")]: + comp[col] = comp["block_id"].str.contains(pat, na=False) + + comp.to_csv(out_path, index=False) + print(f"[output] {out_path.name} ({len(comp)} blocks)") + return comp + + +def print_comparison_report( + comp: pd.DataFrame, + df_noHLA_PC1: pd.DataFrame, df_noHLA_PC2: pd.DataFrame, + n_hla: int, n_non_hla: int, +): + """Print the required post-run comparison report.""" + focus = ["HLA", "17q21", "IL1RL1", "FCER1A", "PDE4D"] + + def flag(bid): + return "YES" if any(p in str(bid) for p in focus) else "-" + + print("\n══ no-HLA LOBO summary ════════════════════════════════════") + print(f" HLA blocks excluded from LOBO : {n_hla}") + print(f" Non-HLA blocks in LOBO : {n_non_hla}") + + # Verify HLA truly absent from no-HLA results + hla_in_results = df_noHLA_PC1["block_id"].str.contains(HLA_PATTERN, na=False).sum() + if hla_in_results == 0: + print(" HLA blocks in no-HLA results : CONFIRMED ABSENT") + else: + print(f" WARNING: {hla_in_results} HLA block(s) found in no-HLA results — check logic") + + print("\n══ Top 10 no-HLA PC1 blocks ═══════════════════════════════") + top10_nPC1 = df_noHLA_PC1.head(10)[ + ["rank", "block_id", "mean_abs_delta_score", "mean_signed_delta_score"]].copy() + top10_nPC1["focus"] = top10_nPC1["block_id"].apply(flag) + print(top10_nPC1.to_string(index=False)) + + print("\n══ Top 10 no-HLA PC2 blocks ═══════════════════════════════") + top10_nPC2 = df_noHLA_PC2.head(10)[ + ["rank", "block_id", "mean_abs_delta_score", "mean_signed_delta_score"]].copy() + top10_nPC2["focus"] = top10_nPC2["block_id"].apply(flag) + print(top10_nPC2.to_string(index=False)) + + print("\n══ PDE4D rank comparison ═══════════════════════════════════") + pde4d_rows = comp[comp["is_PDE4D"]] + if pde4d_rows.empty: + print(" PDE4D: no matching blocks found in comparison table") + else: + for _, row in pde4d_rows.iterrows(): + full_r1 = int(row["full_PC1_rank"]) if pd.notna(row["full_PC1_rank"]) else None + full_r2 = int(row["full_PC2_rank"]) if pd.notna(row["full_PC2_rank"]) else None + noHLA_r1 = int(row["noHLA_PC1_rank"]) if pd.notna(row["noHLA_PC1_rank"]) else None + noHLA_r2 = int(row["noHLA_PC2_rank"]) if pd.notna(row["noHLA_PC2_rank"]) else None + chg1 = row["rank_change_PC1"] + chg2 = row["rank_change_PC2"] + print(f" {row['block_id']}") + print(f" PC1: full={full_r1} noHLA={noHLA_r1} " + f"change={int(chg1) if pd.notna(chg1) else 'n/a'}") + print(f" PC2: full={full_r2} noHLA={noHLA_r2} " + f"change={int(chg2) if pd.notna(chg2) else 'n/a'}") + + # Interpretation + pde4d_pc1_full = pde4d_rows["full_PC1_rank"].min() if not pde4d_rows.empty else None + pde4d_pc1_noHLA = pde4d_rows["noHLA_PC1_rank"].min() if not pde4d_rows.empty else None + pde4d_pc2_full = pde4d_rows["full_PC2_rank"].min() if not pde4d_rows.empty else None + pde4d_pc2_noHLA = pde4d_rows["noHLA_PC2_rank"].min() if not pde4d_rows.empty else None + + print("\n══ Interpretation ══════════════════════════════════════════") + if pde4d_pc1_noHLA is not None and pde4d_pc1_full is not None: + if pde4d_pc1_noHLA < pde4d_pc1_full: + print(f" PDE4D rose in PC1 after HLA removal " + f"(full rank {int(pde4d_pc1_full)} → no-HLA rank {int(pde4d_pc1_noHLA)}). " + "This is consistent with PDE4D being a secondary organizing axis " + "masked by HLA dominance in the full embedding space. " + "Do not interpret as evidence of causal importance.") + else: + print(f" PDE4D did not rise in PC1 after HLA removal " + f"(full rank {int(pde4d_pc1_full)}, no-HLA rank {int(pde4d_pc1_noHLA)}). " + "PDE4D does not clearly emerge as a dominant axis when HLA is removed.") + if pde4d_pc2_noHLA is not None and pde4d_pc2_full is not None: + if pde4d_pc2_noHLA < pde4d_pc2_full: + print(f" PDE4D rose in PC2 after HLA removal " + f"(full rank {int(pde4d_pc2_full)} → no-HLA rank {int(pde4d_pc2_noHLA)}).") + else: + print(f" PDE4D did not rise in PC2 after HLA removal " + f"(full rank {int(pde4d_pc2_full)}, no-HLA rank {int(pde4d_pc2_noHLA)}).") + +def print_region_diagnostic(region_df: pd.DataFrame, label: str): + focus_regions = ["HLA", "PDE4D", "17q21", "TNFSF", "SH2B3", "IL1RL1", "IL33", "FCER1A"] + + sub = region_df[ + region_df["region_id"].apply(lambda x: any(p in str(x) for p in focus_regions)) + ].copy() + + if sub.empty: + return + + cols = [ + "region_id", + "n_subblocks", + "sum_raw_attr", + "mean_raw_attr", + "max_raw_attr", + "top3_mean_raw_attr", + "sum_raw_attr_rank", + "mean_raw_attr_rank", + "top3_mean_raw_attr_rank", + ] + cols = [c for c in cols if c in sub.columns] + + print(f"\n══ Region-level attribution diagnostic — {label} ═════════════") + print(sub[cols].sort_values("mean_raw_attr_rank").to_string(index=False)) + +# ── validation helpers ──────────────────────────────────────────────────────── +def check_baseline_agreement(base_emb_model: np.ndarray, + base_emb_saved: np.ndarray, tol: float = 1e-4): + """ + Compare model-recomputed embeddings with saved embeddings. + Max absolute difference should be < tol; warn if larger. + """ + diff = np.abs(base_emb_model - base_emb_saved) + max_diff = diff.max() + mean_diff = diff.mean() + print(f"[check] baseline vs saved embeddings: max|Δ|={max_diff:.6f}, mean|Δ|={mean_diff:.6f}") + if max_diff > tol: + warnings.warn( + f"Baseline embeddings differ from saved by {max_diff:.4f} (> tol={tol}). " + "PC axes are fitted on saved embeddings, which is the reference. " + "This may indicate a version mismatch in the model or random seed drift." + ) + + +# ── no-HLA mode ────────────────────────────────────────────────────────────── +def run_noHLA_mode(args): + """ + Run leave-one-block-out attribution on the no-HLA Phase 2 model. + + The no-HLA Phase 2 model was trained on the full 174-block architecture + with HLA_classII block embeddings zeroed out before input. This function: + 1. Replicates that zero-masking on the Phase 1 embeddings. + 2. Verifies the no-HLA baseline matches saved no-HLA embeddings. + 3. Runs LOBO only on the non-HLA blocks (HLA blocks are always zero + and therefore not informative to attribute). + 4. Saves results and generates a full-vs-noHLA rank comparison table. + """ + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + # ── 0. Load Phase 2 config (hyperparameters) ────────────────────────────── + print(f"[config] Phase 2 config: {args.phase2_config}") + attn_cfg = load_phase2_config(args.phase2_config) + d_model = attn_cfg.get("d_model", 64) + n_heads = attn_cfg.get("n_heads", 4) + n_layers = attn_cfg.get("n_layers", 2) + d_ff = attn_cfg.get("d_ff", 128) + n_pool_tokens = int(attn_cfg.get("n_pool_tokens", 1)) + if n_pool_tokens < 1: + raise ValueError("attention.n_pool_tokens must be >= 1") + emb_dim = n_pool_tokens * d_model + print(f" d_model={d_model}, n_heads={n_heads}, n_layers={n_layers}, d_ff={d_ff}, " + f"n_pool_tokens={n_pool_tokens}, emb_dim={emb_dim}") + + # ── 1. Load Phase 1 input embeddings ────────────────────────────────────── + print(f"\n[load] Phase 1 ORD embeddings: {args.p1_embeddings}") + x_np = np.load(args.p1_embeddings) + N, B, d_in = x_np.shape + print(f" shape: N={N}, B={B}, d_in={d_in}") + if B != 174: + raise ValueError(f"Expected B=174 blocks, got B={B}. Check {args.p1_embeddings}.") + + # ── 2. Load block order and per-block latent dims ───────────────────────── + print(f"[load] block_order: {args.block_order}") + block_meta = pd.read_csv(args.block_order) + if len(block_meta) != B: + raise ValueError( + f"block_order.csv has {len(block_meta)} rows but Phase 1 embeddings have {B} blocks." + ) + block_names = block_meta["block_id"].tolist() + + block_dims = load_block_dims(args.p1_latent_dims, block_meta) + if block_dims is not None: + block_meta = block_meta.copy() + block_meta["latent_dim"] = block_dims + + # ── 3. Identify HLA vs non-HLA block indices ────────────────────────────── + hla_indices = [i for i, name in enumerate(block_names) if HLA_PATTERN in name] + non_hla_indices = [i for i in range(B) if i not in set(hla_indices)] + print(f" HLA blocks: {len(hla_indices)} (indices {hla_indices[:3]}…)") + print(f" Non-HLA blocks for LOBO: {len(non_hla_indices)}") + if len(hla_indices) == 0: + raise RuntimeError( + f"No blocks matching '{HLA_PATTERN}' found in block_order.csv. " + "Cannot identify HLA blocks to zero out." + ) + + # ── 4. Zero-mask HLA blocks (replicates no-HLA training preprocessing) ──── + x_np_zeroed = x_np.copy() + x_np_zeroed[:, hla_indices, :] = 0.0 + print(f"[mask] zeroed {len(hla_indices)} HLA blocks in Phase 1 embeddings") + + # ── 5. Load subject IIDs ────────────────────────────────────────────────── + print(f"[load] subjects: {args.subjects}") + subjects_csv_df = pd.read_csv(args.subjects) + if "IID" not in subjects_csv_df.columns: + raise ValueError(f"Expected 'IID' column in {args.subjects}.") + subject_iids = subjects_csv_df["IID"].astype(str).values + if len(subject_iids) != N: + raise ValueError( + f"subjects.csv has {len(subject_iids)} rows but embeddings have N={N}." + ) + + # ── 6. Verify IID order vs no-HLA embedding CSV ─────────────────────────── + print(f"[load] no-HLA embeddings CSV: {args.noHLA_emb_csv}") + emb_csv = pd.read_csv(args.noHLA_emb_csv) + if "IID" not in emb_csv.columns: + raise ValueError(f"Expected 'IID' column in {args.noHLA_emb_csv}.") + emb_iids = emb_csv["IID"].astype(str).values + if len(emb_iids) != N: + raise ValueError( + f"no-HLA individual_embeddings.csv has {len(emb_iids)} rows but N={N}." + ) + if not (subject_iids == emb_iids).all(): + raise ValueError( + "subjects.csv and no-HLA individual_embeddings.csv IID order does not match. " + "Cannot safely align row indices." + ) + print(f" {N} subjects confirmed, IID order verified.") + + # ── 7. Load saved no-HLA embeddings ────────────────────────────────────── + print(f"[load] saved no-HLA Phase 2 embeddings: {args.noHLA_embeddings}") + base_emb_saved = np.load(args.noHLA_embeddings) + if base_emb_saved.shape != (N, emb_dim): + raise ValueError( + f"Expected shape ({N}, {emb_dim}), got {base_emb_saved.shape}. " + "Check attention.d_model and attention.n_pool_tokens in the Phase 2 config." + ) + + # ── 8. Load no-HLA model ────────────────────────────────────────────────── + print(f"[model] loading no-HLA checkpoint: {args.noHLA_checkpoint}") + model = load_model(args.noHLA_checkpoint, n_blocks=B, d_in=d_in, + d_model=d_model, n_heads=n_heads, n_layers=n_layers, + d_ff=d_ff, dropout=0.0, block_dims=block_dims, + n_pool_tokens=n_pool_tokens) + n_params = sum(p.numel() for p in model.parameters()) + print(f" AttentionAggregator (no-HLA): n_blocks={B}, d_in={d_in}, " + f"d_model={d_model}, n_heads={n_heads}, n_layers={n_layers}, " + f"n_pool_tokens={n_pool_tokens}, emb_dim={emb_dim} — {n_params:,} params") + + # ── 9. Verify baseline (zeroed input → no-HLA checkpoint vs saved embeddings) + print("[check] verifying no-HLA baseline forward pass ...") + base_emb_model = encode_batched(model, x_np_zeroed, args.batch_size) + check_baseline_agreement(base_emb_model, base_emb_saved) + + # ── 10. Fit PCA on saved no-HLA embeddings (defines reference axes) ─────── + print("[PCA] fitting PCA on saved no-HLA embeddings ...") + scaler, pca, base_scores = fit_pca(base_emb_saved, n_components=5) + var_exp = pca.explained_variance_ratio_ + print(" no-HLA variance explained: " + + " ".join(f"PC{i+1}={v*100:.1f}%" for i, v in enumerate(var_exp))) + + # ── 11. LOBO on non-HLA blocks (mean masking computed from zeroed input) ── + res_PC1, res_PC2, res_PC3, res_PC4, res_PC5, _ = run_lobo( + model=model, + x_np=x_np_zeroed, # HLA blocks permanently zeroed + base_emb=base_emb_saved, + scaler=scaler, + pca=pca, + base_scores=base_scores, + block_names=block_names, + batch_size=args.batch_size, + include_block_indices=non_hla_indices, # skip HLA blocks in LOBO + ) + + # ── 12. Build and save result DataFrames ────────────────────────────────── + df_noHLA_PC1 = finalise_df(res_PC1, block_meta, method="LOBO_noHLA_checkpoint_mean_mask") + df_noHLA_PC2 = finalise_df(res_PC2, block_meta, method="LOBO_noHLA_checkpoint_mean_mask") + df_noHLA_PC3 = finalise_df(res_PC3, block_meta, method="LOBO_noHLA_checkpoint_mean_mask") + df_noHLA_PC4 = finalise_df(res_PC4, block_meta, method="LOBO_noHLA_checkpoint_mean_mask") + df_noHLA_PC5 = finalise_df(res_PC5, block_meta, method="LOBO_noHLA_checkpoint_mean_mask") + + for pc_label, df_noHLA in [ + ("PC1", df_noHLA_PC1), ("PC2", df_noHLA_PC2), + ("PC3", df_noHLA_PC3), ("PC4", df_noHLA_PC4), ("PC5", df_noHLA_PC5), + ]: + out_path = out_dir / f"phase2_noHLA_{pc_label}_leave_one_block_out.csv" + df_noHLA.to_csv(out_path, index=False) + print(f"[output] {out_path.name} ({len(df_noHLA)} blocks)") + + # ── 13. Plots ───────────────────────────────────────────────────────────── + for pc_label, df_noHLA in [ + ("PC1", df_noHLA_PC1), ("PC2", df_noHLA_PC2), + ("PC3", df_noHLA_PC3), ("PC4", df_noHLA_PC4), ("PC5", df_noHLA_PC5), + ]: + plot_top_blocks(df_noHLA, f"no-HLA Phase 2 emb {pc_label}", + out_dir / f"top20_noHLA_{pc_label}_attribution.png", + top_n=args.top_n) + + # ── 14. Comparison table ────────────────────────────────────────────────── + comp_path = out_dir / "phase2_full_vs_noHLA_rank_comparison.csv" + full_pc1 = args.full_pc1_csv + full_pc2 = args.full_pc2_csv + if Path(full_pc1).exists() and Path(full_pc2).exists(): + comp = build_comparison_table( + full_pc1, full_pc2, + df_noHLA_PC1, df_noHLA_PC2, + block_names, comp_path, + ) + print_comparison_report( + comp, df_noHLA_PC1, df_noHLA_PC2, + n_hla=len(hla_indices), n_non_hla=len(non_hla_indices), + ) + else: + missing = [p for p in [full_pc1, full_pc2] if not Path(p).exists()] + print(f"\n[WARNING] Full-run LOBO CSVs not found; comparison table skipped.") + for p in missing: + print(f" missing: {p}") + print(" Run first with --mode full to generate the full-run attribution.") + + print(f"\n[done] All no-HLA outputs written to: {out_dir}/") + + +# ── main ────────────────────────────────────────────────────────────────────── +def parse_args(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--mode", choices=["full", "noHLA"], default="full", + help="'full' = full-run LOBO (default); 'noHLA' = no-HLA run + comparison") + + # shared paths + p.add_argument("--p1-embeddings", default=P1_EMBEDDINGS, + help="Phase 1 ORD all_blocks.npy (N, B, d_in)") + p.add_argument("--p1-latent-dims", default=P1_LATENT_DIMS, + help="all_blocks_latent_dims.npy — actual per-block latent dims " + "(used to build BlockProjector; falls back to block_order.csv)") + p.add_argument("--phase2-config", default=P2_CONFIG, + help="Phase 2 YAML config — reads attention hyperparameters " + "(d_model, n_heads, n_layers, d_ff, n_pool_tokens)") + p.add_argument("--block-order", default=BLOCK_ORDER, + help="block_order.csv") + p.add_argument("--subjects", default=SUBJECTS_CSV, + help="subjects.csv (IID order matching Phase 1 embeddings)") + p.add_argument("--out-dir", default=DEFAULT_OUT) + p.add_argument("--batch-size", type=int, default=BATCH_SIZE) + p.add_argument("--top-n", type=int, default=20, + help="Number of top blocks to plot") + p.add_argument( + "--rank-by", + default="mean_abs_delta_score", + choices=[ + "mean_abs_delta_score", + "attr_per_snp", + "attr_per_sqrt_snp", + "attr_per_latent_dim", + "attr_per_sqrt_latent_dim", + "attr_per_snp_x_latent", + ], + help="Column used for top-block plots. Default = raw LOBO attribution." + ) + + # full-mode paths + p.add_argument("--checkpoint", default=P2_CHECKPOINT, + help="Phase 2 attention_aggregator.pt (full run)") + p.add_argument("--p2-embeddings", default=P2_EMBEDDINGS, + help="Saved Phase 2 individual_embeddings.npy (full run)") + p.add_argument("--p2-emb-csv", default=P2_EMB_CSV, + help="individual_embeddings.csv (full run, for IID order)") + p.add_argument("--pheno", default=PHENO_FILE, + help="Phenotype CSV with S_SUBJECTID and log10Ige") + p.add_argument("--skip-phenotype", action="store_true", + help="Skip log10Ige phenotype attribution (full mode only)") + + # no-HLA mode paths + p.add_argument("--noHLA-checkpoint", default=P2_CHECKPOINT_NOHLA, + help="no-HLA Phase 2 attention_aggregator.pt") + p.add_argument("--noHLA-embeddings", default=P2_EMBEDDINGS_NOHLA, + help="Saved no-HLA individual_embeddings.npy") + p.add_argument("--noHLA-emb-csv", default=P2_EMB_CSV_NOHLA, + help="no-HLA individual_embeddings.csv (for IID order)") + p.add_argument("--full-pc1-csv", default=FULL_PC1_CSV, + help="Full-run PC1 LOBO CSV (for comparison table)") + p.add_argument("--full-pc2-csv", default=FULL_PC2_CSV, + help="Full-run PC2 LOBO CSV (for comparison table)") + + return p.parse_args() + + +def main(): + args = parse_args() + + if args.mode == "noHLA": + run_noHLA_mode(args) + return + + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + # ── 0. Load Phase 2 config (hyperparameters) ────────────────────────────── + print(f"[config] Phase 2 config: {args.phase2_config}") + attn_cfg = load_phase2_config(args.phase2_config) + d_model = attn_cfg.get("d_model", 64) + n_heads = attn_cfg.get("n_heads", 4) + n_layers = attn_cfg.get("n_layers", 2) + d_ff = attn_cfg.get("d_ff", 128) + n_pool_tokens = int(attn_cfg.get("n_pool_tokens", 1)) + if n_pool_tokens < 1: + raise ValueError("attention.n_pool_tokens must be >= 1") + emb_dim = n_pool_tokens * d_model + print(f" d_model={d_model}, n_heads={n_heads}, n_layers={n_layers}, d_ff={d_ff}, " + f"n_pool_tokens={n_pool_tokens}, emb_dim={emb_dim}") + + # ── 1. Load Phase 1 input embeddings ────────────────────────────────────── + print(f"\n[load] Phase 1 ORD embeddings: {args.p1_embeddings}") + x_np = np.load(args.p1_embeddings) + N, B, d_in = x_np.shape + print(f" shape: N={N}, B={B}, d_in={d_in}") + if B != 174: + raise ValueError(f"Expected B=174 blocks, got B={B}. Check {args.p1_embeddings}.") + + # ── 2. Load block order and per-block latent dims ───────────────────────── + print(f"[load] block_order: {args.block_order}") + block_meta = pd.read_csv(args.block_order) + if len(block_meta) != B: + raise ValueError( + f"block_order.csv has {len(block_meta)} rows but Phase 1 embeddings have {B} blocks." + ) + block_names = block_meta["block_id"].tolist() + print(f" {B} blocks confirmed. Sample: {block_names[:3]}") + + block_dims = load_block_dims(args.p1_latent_dims, block_meta) + if block_dims is not None: + block_meta = block_meta.copy() + block_meta["latent_dim"] = block_dims + + # ── 3. Load subject IIDs ────────────────────────────────────────────────── + print(f"[load] subjects: {args.subjects}") + subjects_csv = pd.read_csv(args.subjects) + if "IID" not in subjects_csv.columns: + raise ValueError(f"Expected 'IID' column in {args.subjects}.") + subject_iids = subjects_csv["IID"].astype(str).values + if len(subject_iids) != N: + raise ValueError(f"subjects.csv has {len(subject_iids)} rows but embeddings have N={N}.") + + # Verify IID order matches embedding CSV + emb_csv = pd.read_csv(args.p2_emb_csv) + if "IID" not in emb_csv.columns: + raise ValueError(f"Expected 'IID' column in {args.p2_emb_csv}.") + emb_iids = emb_csv["IID"].astype(str).values + if len(emb_iids) != N: + raise ValueError( + f"individual_embeddings.csv has {len(emb_iids)} rows but N={N}." + ) + if not (subject_iids == emb_iids).all(): + raise ValueError( + "subjects.csv and individual_embeddings.csv IID order does not match. " + "Cannot safely align row indices." + ) + print(f" {N} subjects confirmed, IID order verified.") + + # ── 4. Load saved baseline embeddings ───────────────────────────────────── + print(f"[load] saved Phase 2 embeddings: {args.p2_embeddings}") + base_emb_saved = np.load(args.p2_embeddings) + if base_emb_saved.shape != (N, emb_dim): + raise ValueError( + f"Expected shape ({N}, {emb_dim}), got {base_emb_saved.shape}. " + "Check attention.d_model and attention.n_pool_tokens in the Phase 2 config." + ) + + # ── 5. Load model ───────────────────────────────────────────────────────── + print(f"[model] loading checkpoint: {args.checkpoint}") + model = load_model(args.checkpoint, n_blocks=B, d_in=d_in, + d_model=d_model, n_heads=n_heads, n_layers=n_layers, + d_ff=d_ff, dropout=0.0, block_dims=block_dims, + n_pool_tokens=n_pool_tokens) + n_params = sum(p.numel() for p in model.parameters()) + print(f" AttentionAggregator: n_blocks={B}, d_in={d_in}, d_model={d_model}, " + f"n_heads={n_heads}, n_layers={n_layers}, " + f"n_pool_tokens={n_pool_tokens}, emb_dim={emb_dim} — {n_params:,} params") + + # ── 6. Verify baseline (model vs saved) ─────────────────────────────────── + print("[check] verifying baseline forward pass ...") + base_emb_model = encode_batched(model, x_np, args.batch_size) + check_baseline_agreement(base_emb_model, base_emb_saved) + + # ── 7. Fit PCA on saved embeddings (defines the reference axes) ─────────── + print("[PCA] fitting PCA on saved embeddings ...") + scaler, pca, base_scores = fit_pca(base_emb_saved, n_components=5) + var_exp = pca.explained_variance_ratio_ + print(" variance explained: " + + " ".join(f"PC{i+1}={v*100:.1f}%" for i, v in enumerate(var_exp))) + + # ── 8. Phenotype loading + ridge ────────────────────────────────────────── + do_pheno = not args.skip_phenotype + pheno_mask = pheno_y = ridge_ige = None + + if do_pheno: + try: + pheno_mask, pheno_y = load_phenotype(args.pheno, subject_iids) + if pheno_mask.sum() >= 50: + ridge_ige = fit_ridge(base_emb_saved[pheno_mask], pheno_y) + else: + print(f"[phenotype] too few subjects ({pheno_mask.sum()}) — skipping") + do_pheno = False + except Exception as e: + print(f"[phenotype] WARNING: {e} — skipping phenotype attribution") + do_pheno = False + + # ── 9. Leave-one-block-out ──────────────────────────────────────────────── + res_PC1, res_PC2, res_PC3, res_PC4, res_PC5, res_ige = run_lobo( + model=model, + x_np=x_np, + base_emb=base_emb_saved, + scaler=scaler, + pca=pca, + base_scores=base_scores, + block_names=block_names, + batch_size=args.batch_size, + ridge_ige=ridge_ige if do_pheno else None, + pheno_mask=pheno_mask, + pheno_y=pheno_y, + ) + + method_str = ( + "Method: checkpoint-based forward pass; mean masking (population mean per block)." + ) + + # ── 10. Build and save result DataFrames ────────────────────────────────── + df_PC1 = finalise_df(res_PC1, block_meta, method="LOBO_checkpoint_mean_mask") + df_PC2 = finalise_df(res_PC2, block_meta, method="LOBO_checkpoint_mean_mask") + df_PC3 = finalise_df(res_PC3, block_meta, method="LOBO_checkpoint_mean_mask") + df_PC4 = finalise_df(res_PC4, block_meta, method="LOBO_checkpoint_mean_mask") + df_PC5 = finalise_df(res_PC5, block_meta, method="LOBO_checkpoint_mean_mask") + + df_PC1.to_csv(out_dir / "phase2_PC1_leave_one_block_out.csv", index=False) + df_PC2.to_csv(out_dir / "phase2_PC2_leave_one_block_out.csv", index=False) + df_PC3.to_csv(out_dir / "phase2_PC3_leave_one_block_out.csv", index=False) + df_PC4.to_csv(out_dir / "phase2_PC4_leave_one_block_out.csv", index=False) + df_PC5.to_csv(out_dir / "phase2_PC5_leave_one_block_out.csv", index=False) + print(f"\n[output] phase2_PC1_leave_one_block_out.csv ({len(df_PC1)} blocks)") + print(f"[output] phase2_PC2_leave_one_block_out.csv ({len(df_PC2)} blocks)") + print(f"[output] phase2_PC3_leave_one_block_out.csv ({len(df_PC3)} blocks)") + print(f"[output] phase2_PC4_leave_one_block_out.csv ({len(df_PC4)} blocks)") + print(f"[output] phase2_PC5_leave_one_block_out.csv ({len(df_PC5)} blocks)") + + # ── 10b. Optional phenotype attribution DataFrame ───────────────────────── + df_ige = None + if do_pheno and res_ige is not None: + df_ige = finalise_df( + res_ige, + block_meta, + method="LOBO_checkpoint_mean_mask" + ) + df_ige.to_csv( + out_dir / "phase2_log10Ige_leave_one_block_out.csv", + index=False + ) + print(f"[output] phase2_log10Ige_leave_one_block_out.csv ({len(df_ige)} blocks)") + + # ── 10c. Region-level summaries ─────────────────────────────────────────── + region_summaries = {} + + for pc_label, df_pc in [ + ("PC1", df_PC1), ("PC2", df_PC2), + ("PC3", df_PC3), ("PC4", df_PC4), ("PC5", df_PC5), + ]: + region_summaries[pc_label] = write_region_summary( + df_pc, + out_dir / f"region_summary_{pc_label}_attribution.csv", + target_label=pc_label, + ) + + if df_ige is not None: + region_summaries["log10Ige"] = write_region_summary( + df_ige, + out_dir / "region_summary_log10Ige_attribution.csv", + target_label="log10Ige", + ) + + # Print region diagnostics + for label, region_df in region_summaries.items(): + print_region_diagnostic(region_df, label) + + if do_pheno and res_ige is not None: + df_ige = finalise_df(res_ige, block_meta, method="LOBO_checkpoint_mean_mask") + df_ige.to_csv(out_dir / "phase2_log10Ige_leave_one_block_out.csv", index=False) + print(f"[output] phase2_log10Ige_leave_one_block_out.csv ({len(df_ige)} blocks)") + + # ── 11. Plots ───────────────────────────────────────────────────────────── + plot_top_blocks( + df_PC1, + "Phase 2 emb PC1", + out_dir / f"top20_PC1_attribution_by_{args.rank_by}.png", + top_n=args.top_n, + score_col=args.rank_by, + ) + plot_top_blocks( + df_PC2, + "Phase 2 emb PC2", + out_dir / f"top20_PC2_attribution_by_{args.rank_by}.png", + top_n=args.top_n, + score_col=args.rank_by, + ) + # plot_top_blocks(df_PC2, "Phase 2 emb PC2", + # out_dir / "top20_PC2_attribution.png", top_n=args.top_n) + plot_top_blocks(df_PC3, "Phase 2 emb PC3", + out_dir / "top20_PC3_attribution.png", top_n=args.top_n) + plot_top_blocks(df_PC4, "Phase 2 emb PC4", + out_dir / "top20_PC4_attribution.png", top_n=args.top_n) + plot_top_blocks(df_PC5, "Phase 2 emb PC5", + out_dir / "top20_PC5_attribution.png", top_n=args.top_n) + if do_pheno and res_ige is not None: + plot_top_blocks(df_ige, "log10IgE ridge-pred", + out_dir / "top20_log10Ige_attribution.png", top_n=args.top_n) + + # ── 12. README ──────────────────────────────────────────────────────────── + write_readme(out_dir, method_str) + + # ── 13. Print top-10 summary ────────────────────────────────────────────── + focus = ["HLA", "17q21", "IL1RL1", "FCER1A", "PDE4D"] + + def flag_focus(bid): + return "YES" if any(p in str(bid) for p in focus) else "-" + + for pc_label, df_pc in [ + ("PC1", df_PC1), ("PC2", df_PC2), + ("PC3", df_PC3), ("PC4", df_PC4), ("PC5", df_PC5), + ]: + print(f"\n══ Top 10 blocks — {pc_label} ════════════════════════════════") + top10 = df_pc.head(10)[ + ["rank", "block_id", "mean_abs_delta_score", "mean_signed_delta_score"] + ].copy() + top10["focus"] = top10["block_id"].apply(flag_focus) + print(top10.to_string(index=False)) + + if do_pheno and res_ige is not None: + print("\n══ Top 10 blocks — log10IgE (ridge) ══════════════════") + top10_ige = df_ige.head(10)[["rank", "block_id", "mean_abs_delta_score", "mean_signed_delta_score"]].copy() + top10_ige["focus"] = top10_ige["block_id"].apply(flag_focus) + print(top10_ige.to_string(index=False)) + else: + print("\n[phenotype] log10Ige attribution was skipped.") + + print(f"\n[done] All outputs written to: {out_dir}/") + + +if __name__ == "__main__": + main() diff --git a/scripts/analysis/10_phase1_snp_attribution_within_blocks.py b/scripts/analysis/attribution/10_phase1_snp_attribution_within_blocks.py similarity index 100% rename from scripts/analysis/10_phase1_snp_attribution_within_blocks.py rename to scripts/analysis/attribution/10_phase1_snp_attribution_within_blocks.py diff --git a/scripts/analysis/attribution/10_rank_shift_scatter.py b/scripts/analysis/attribution/10_rank_shift_scatter.py new file mode 100644 index 0000000..4c6922d --- /dev/null +++ b/scripts/analysis/attribution/10_rank_shift_scatter.py @@ -0,0 +1,324 @@ +#!/usr/bin/env python3 +""" +10_rank_shift_scatter.py + +Rank-shift scatter plot: full Phase 2 model vs no-HLA Phase 2 model. + +For each genomic block (non-HLA only) plots: + x = rank in full Phase 2 PC1 attribution + y = rank in no-HLA Phase 2 PC1 attribution + +Blocks that rise in importance after HLA removal appear in the upper-right +region (high full rank, low no-HLA rank). The diagonal y = x marks no change. + +Outputs +------- +results/analysis/phase2_block_attribution/phase2_PC1_rank_shift_scatter.png +results/analysis/phase2_block_attribution/phase2_PC2_rank_shift_scatter.png + +Run +--- + python scripts/analysis/10_rank_shift_scatter.py +""" + +import sys +import warnings +from pathlib import Path + +warnings.filterwarnings("ignore", category=FutureWarning) + +import numpy as np +import pandas as pd +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches + +# ── paths ───────────────────────────────────────────────────────────────────── +COMPARISON_CSV = "results/analysis/phase2_block_attribution/phase2_full_vs_noHLA_rank_comparison.csv" +FULL_PC1_CSV = "results/analysis/phase2_block_attribution/phase2_PC1_leave_one_block_out.csv" +FULL_PC2_CSV = "results/analysis/phase2_block_attribution/phase2_PC2_leave_one_block_out.csv" +NOHA_PC1_CSV = "results/analysis/phase2_block_attribution/phase2_noHLA_PC1_leave_one_block_out.csv" +NOHA_PC2_CSV = "results/analysis/phase2_block_attribution/phase2_noHLA_PC2_leave_one_block_out.csv" +OUT_DIR = "results/analysis/phase2_block_attribution" + + +# ── data loading ────────────────────────────────────────────────────────────── +def load_comparison() -> pd.DataFrame: + """ + Load the pre-built comparison table if available; otherwise build it from + the individual attribution CSVs. Only non-HLA blocks (those with finite + noHLA ranks) are returned. + """ + comp_path = Path(COMPARISON_CSV) + if comp_path.exists(): + df = pd.read_csv(comp_path) + # Derive is_SH2B3 flag if not already present + if "is_SH2B3" not in df.columns: + df["is_SH2B3"] = df["block_id"].str.contains("SH2B3", na=False) + return df + else: + # Fallback: merge from individual CSVs on block_id + print("[warn] comparison CSV not found; building from individual CSVs") + f_pc1 = pd.read_csv(FULL_PC1_CSV)[["block_id", "rank"]].rename( + columns={"rank": "full_PC1_rank"}) + f_pc2 = pd.read_csv(FULL_PC2_CSV)[["block_id", "rank"]].rename( + columns={"rank": "full_PC2_rank"}) + n_pc1 = pd.read_csv(NOHA_PC1_CSV)[["block_id", "rank"]].rename( + columns={"rank": "noHLA_PC1_rank"}) + n_pc2 = pd.read_csv(NOHA_PC2_CSV)[["block_id", "rank"]].rename( + columns={"rank": "noHLA_PC2_rank"}) + + df = (f_pc1.merge(f_pc2, on="block_id", how="outer") + .merge(n_pc1, on="block_id", how="outer") + .merge(n_pc2, on="block_id", how="outer")) + + for pat, col in [("PDE4D", "is_PDE4D"), ("HLA", "is_HLA"), + ("17q21", "is_17q21"), ("IL1RL1", "is_IL1RL1"), + ("FCER1A", "is_FCER1A"), ("SH2B3", "is_SH2B3")]: + df[col] = df["block_id"].str.contains(pat, na=False) + + df["rank_change_PC1"] = df["noHLA_PC1_rank"] - df["full_PC1_rank"] + df["rank_change_PC2"] = df["noHLA_PC2_rank"] - df["full_PC2_rank"] + return df + + +# ── plotting ────────────────────────────────────────────────────────────────── +def _assign_color(row) -> str: + if row.get("is_PDE4D", False): + return "tomato" + if row.get("is_IL1RL1", False): + return "darkorange" + if row.get("is_SH2B3", False): + return "mediumpurple" + if row.get("is_17q21", False): + return "forestgreen" + return "steelblue" + + +def _assign_size(row) -> float: + # Slightly larger dot for annotated loci + if any(row.get(k, False) for k in + ["is_PDE4D", "is_IL1RL1", "is_SH2B3", "is_17q21"]): + return 28 + return 14 + + +def _short_label(block_id: str) -> str: + """Return a compact readable label from block_id.""" + # e.g. region_5q21_PDE4D_sb33 → PDE4D_sb33 + parts = block_id.split("_") + # find locus token (contains letters + digits), skip 'region', 'control', 'chr', coords + skip = {"region", "control", "cluster", "core"} + label_parts = [p for p in parts if p.lower() not in skip + and not p[:2].isdigit() + and not (len(p) > 1 and p[0].isdigit() and p[-1].isalpha() + and len(p) <= 5)] + return "_".join(label_parts) if label_parts else block_id + + +def make_rank_shift_plot( + df: pd.DataFrame, + full_rank_col: str, + noHLA_rank_col: str, + rank_change_col: str, + pc_label: str, + out_path: Path, + top_n_pde4d_label: int = 5, + top_n_improvers_label: int = 3, +) -> None: + """ + Scatter: x = full rank, y = no-HLA rank, for non-HLA blocks only. + + Rank 1 is plotted at top-left (x-axis: 1 on left; y-axis inverted so 1 is + at top). Blocks above the diagonal improved after HLA removal. + """ + # Filter to blocks with both ranks (excludes HLA blocks which have NaN noHLA rank) + plot_df = df[df[full_rank_col].notna() & df[noHLA_rank_col].notna()].copy() + plot_df[full_rank_col] = plot_df[full_rank_col].astype(int) + plot_df[noHLA_rank_col] = plot_df[noHLA_rank_col].astype(int) + + n_hla_excluded = df["is_HLA"].sum() if "is_HLA" in df.columns else 0 + n_total = len(plot_df) + + colors = [_assign_color(row) for _, row in plot_df.iterrows()] + sizes = [_assign_size(row) for _, row in plot_df.iterrows()] + + fig, ax = plt.subplots(figsize=(7, 7)) + + # Background scatter (non-highlighted blocks) + is_other = ~(plot_df["is_PDE4D"] | plot_df.get("is_IL1RL1", False) | + plot_df.get("is_SH2B3", False) | plot_df.get("is_17q21", False)) + ax.scatter( + plot_df.loc[is_other, full_rank_col], + plot_df.loc[is_other, noHLA_rank_col], + s=12, color="steelblue", alpha=0.45, linewidths=0, zorder=2, + ) + + # Highlighted categories (plotted on top) + for mask, color, label, zorder in [ + (plot_df.get("is_17q21", pd.Series(False, index=plot_df.index)), "forestgreen", "17q21", 3), + (plot_df.get("is_IL1RL1", pd.Series(False, index=plot_df.index)), "darkorange", "IL1RL1", 4), + (plot_df.get("is_SH2B3", pd.Series(False, index=plot_df.index)), "mediumpurple","SH2B3", 4), + (plot_df["is_PDE4D"], "tomato", "PDE4D", 5), + ]: + sub = plot_df[mask] + if sub.empty: + continue + ax.scatter( + sub[full_rank_col], sub[noHLA_rank_col], + s=28, color=color, alpha=0.85, linewidths=0.4, + edgecolors="white", zorder=zorder, label=label, + ) + + # Diagonal reference line (y = x, no rank change) + max_rank = max(plot_df[full_rank_col].max(), plot_df[noHLA_rank_col].max()) + diag = np.arange(1, max_rank + 1) + ax.plot(diag, diag, color="black", linewidth=0.8, linestyle="--", + alpha=0.5, zorder=1, label="no change (y = x)") + + # Invert y-axis so rank 1 is at the top + ax.invert_yaxis() + + # ── Annotate top-N PDE4D blocks (best no-HLA rank) ──────────────────────── + pde4d_df = plot_df[plot_df["is_PDE4D"]].nsmallest(top_n_pde4d_label, noHLA_rank_col) + for _, row in pde4d_df.iterrows(): + lbl = _short_label(row["block_id"]) + x, y = row[full_rank_col], row[noHLA_rank_col] + ax.annotate( + lbl, xy=(x, y), + xytext=(x + max_rank * 0.04, y - max_rank * 0.025), + fontsize=6.5, color="tomato", + arrowprops=dict(arrowstyle="-", color="tomato", lw=0.6), + ) + + # ── Annotate top-N largest rank improvements (any category) ─────────────── + # rank improvement = full_rank - noHLA_rank = -rank_change + improvement = plot_df[full_rank_col] - plot_df[noHLA_rank_col] + top_improvers = plot_df[improvement == improvement.nlargest(top_n_improvers_label).iloc[-1]].index + # Use nlargest directly on the series + top_improver_idx = improvement.nlargest(top_n_improvers_label).index + # skip any already labelled by PDE4D annotation + labelled_ids = set(pde4d_df["block_id"]) + for idx in top_improver_idx: + row = plot_df.loc[idx] + if row["block_id"] in labelled_ids: + continue + lbl = _short_label(row["block_id"]) + x, y = row[full_rank_col], row[noHLA_rank_col] + ax.annotate( + lbl, xy=(x, y), + xytext=(x - max_rank * 0.10, y + max_rank * 0.025), + fontsize=6.5, color="dimgray", + arrowprops=dict(arrowstyle="-", color="dimgray", lw=0.6), + ) + + # ── Axes and labels ─────────────────────────────────────────────────────── + ax.set_xlabel(f"Full model rank ({pc_label})", fontsize=11) + ax.set_ylabel(f"No-HLA model rank ({pc_label})", fontsize=11) + ax.set_title( + f"Rank shift after HLA masking\n(Phase 2 {pc_label} attribution)", + fontsize=12, fontweight="bold", + ) + + # Axis limits: 0.5 padding so rank-1 dot isn't clipped + ax.set_xlim(0.5, max_rank + 0.5) + ax.set_ylim(max_rank + 0.5, 0.5) # y inverted + + # ── Legend ──────────────────────────────────────────────────────────────── + # Add a dummy entry to note that HLA blocks are excluded + existing_handles, existing_labels = ax.get_legend_handles_labels() + hla_patch = mpatches.Patch( + color="lightgray", + label=f"HLA blocks (n={n_hla_excluded}, no no-HLA rank)", + ) + other_patch = mpatches.Patch(color="steelblue", alpha=0.6, label=f"other ({n_total} blocks shown)") + handles = existing_handles + [other_patch, hla_patch] + ax.legend(handles=handles, fontsize=8, loc="lower right", framealpha=0.85) + + # ── Annotation: above-diagonal = improved ───────────────────────────────── + ax.text( + 0.03, 0.03, "↑ improved after HLA removal", + transform=ax.transAxes, fontsize=7.5, color="dimgray", + va="bottom", style="italic", + ) + + plt.tight_layout() + fig.savefig(out_path, dpi=150) + plt.close(fig) + print(f"[plot] saved {out_path}") + + +# ── summary table ───────────────────────────────────────────────────────────── +def print_top_improvers(df: pd.DataFrame, full_col: str, noHLA_col: str, + change_col: str, pc_label: str, top_n: int = 10) -> None: + valid = df[df[full_col].notna() & df[noHLA_col].notna()].copy() + valid["rank_improvement"] = valid[full_col] - valid[noHLA_col] + top = valid.nlargest(top_n, "rank_improvement")[ + ["block_id", full_col, noHLA_col, "rank_improvement", + "is_PDE4D", "is_HLA", "is_IL1RL1"] + ].copy() + top["is_PDE4D"] = top["is_PDE4D"].map({True: "YES", False: "-"}) + top["is_IL1RL1"] = top["is_IL1RL1"].map({True: "YES", False: "-"}) + + print(f"\n══ Top {top_n} rank improvements — {pc_label} " + f"(full rank − noHLA rank, positive = rose) ══") + print(top.to_string(index=False)) + + pde4d = valid[valid["is_PDE4D"]] + best_full = pde4d[full_col].min() + best_noHLA = pde4d[noHLA_col].min() + print(f"\n PDE4D best rank — full model: {int(best_full)} | " + f"no-HLA model: {int(best_noHLA)}") + if best_noHLA < best_full: + print(f" → PDE4D shifts toward rank 1 after HLA removal " + f"({int(best_full)} → {int(best_noHLA)})") + else: + print(f" → PDE4D does not clearly improve after HLA removal") + + +# ── main ───────────────────────────────────────────────────────────────────── +def main(): + out_dir = Path(OUT_DIR) + out_dir.mkdir(parents=True, exist_ok=True) + + df = load_comparison() + if "is_SH2B3" not in df.columns: + df["is_SH2B3"] = df["block_id"].str.contains("SH2B3", na=False) + + n_overlapping_PC1 = df["full_PC1_rank"].notna() & df["noHLA_PC1_rank"].notna() + n_overlapping_PC2 = df["full_PC2_rank"].notna() & df["noHLA_PC2_rank"].notna() + print(f"[data] loaded {len(df)} total blocks") + print(f" overlapping (both full + noHLA ranks): " + f"PC1={n_overlapping_PC1.sum()} PC2={n_overlapping_PC2.sum()}") + print(f" HLA blocks (excluded from scatter): {df['is_HLA'].sum()}") + + # ── PC1 plot ───────────────────────────────────────────────────────────── + make_rank_shift_plot( + df=df, + full_rank_col="full_PC1_rank", + noHLA_rank_col="noHLA_PC1_rank", + rank_change_col="rank_change_PC1", + pc_label="PC1", + out_path=out_dir / "phase2_PC1_rank_shift_scatter.png", + ) + print_top_improvers(df, "full_PC1_rank", "noHLA_PC1_rank", + "rank_change_PC1", "PC1") + + # ── PC2 plot ───────────────────────────────────────────────────────────── + make_rank_shift_plot( + df=df, + full_rank_col="full_PC2_rank", + noHLA_rank_col="noHLA_PC2_rank", + rank_change_col="rank_change_PC2", + pc_label="PC2", + out_path=out_dir / "phase2_PC2_rank_shift_scatter.png", + ) + print_top_improvers(df, "full_PC2_rank", "noHLA_PC2_rank", + "rank_change_PC2", "PC2") + + print(f"\n[done] outputs in {out_dir}/") + + +if __name__ == "__main__": + main() diff --git a/scripts/analysis/diagnostics/01_analyze_VAEtraining_logs.py b/scripts/analysis/diagnostics/01_analyze_VAEtraining_logs.py new file mode 100644 index 0000000..669a641 --- /dev/null +++ b/scripts/analysis/diagnostics/01_analyze_VAEtraining_logs.py @@ -0,0 +1,107 @@ +import pandas as pd +import numpy as np +from pathlib import Path +import matplotlib.pyplot as plt + +LOG_DIR = Path("/Users/shraddh_mac/Documents/GitHub/blockbased-genotype-embedding-analysis/results/output_regions_ord_weighted/ORD/logs") +SUMMARY = pd.read_csv("/Users/shraddh_mac/Documents/GitHub/blockbased-genotype-embedding-analysis/results/output_regions_ord_weighted/vae_summary.csv") +OUT = Path("/Users/shraddh_mac/Documents/GitHub/blockbased-genotype-embedding-analysis/results/output_regions_ord_weighted/training_dynamics") +OUT.mkdir(exist_ok=True) + +# ─── Load all logs into one long-form df ─── +rows = [] +for f in LOG_DIR.glob("*.csv"): + bid = f.stem + df = pd.read_csv(f) + df["block"] = bid + rows.append(df) +all_logs = pd.concat(rows, ignore_index=True) +print(f"Loaded {len(rows)} blocks, {len(all_logs)} total epoch-rows") + +# ─── Priority 1: convergence speed ─── +conv_rows = [] +for bid, g in all_logs.groupby("block"): + g = g.sort_values("epoch") + vmin = g["va_recon"].min() + vmax = g["va_recon"].iloc[0] + threshold = vmin + 0.01 * (vmax - vmin) # within 1% of best + converged = g[g["va_recon"] <= threshold] + ep_conv = int(converged["epoch"].iloc[0]) if len(converged) else int(g["epoch"].max()) + conv_rows.append({"block": bid, "epoch_converged": ep_conv, + "total_epochs": int(g["epoch"].max()), + "min_va_recon": float(vmin)}) +conv_df = pd.DataFrame(conv_rows).merge(SUMMARY[["block","n_snps","maf_mean","best_epoch"]], + on="block", how="left") +conv_df.to_csv(OUT/"convergence_summary.csv", index=False) + +fig, ax = plt.subplots(figsize=(10,4)) +ax.hist(conv_df["epoch_converged"], bins=30) +ax.set_xlabel("Epochs to within 1% of best val_recon") +ax.set_ylabel("Number of blocks") +ax.set_title("Convergence speed distribution (ORD weighted)") +plt.tight_layout(); plt.savefig(OUT/"convergence_hist.png", dpi=150); plt.close() + +# Does block size predict convergence speed? +fig, ax = plt.subplots(figsize=(10,4)) +ax.scatter(conv_df["n_snps"], conv_df["epoch_converged"], alpha=0.5) +ax.set_xlabel("n_snps"); ax.set_ylabel("epochs_to_converge") +ax.set_title("Convergence speed vs block size") +plt.tight_layout(); plt.savefig(OUT/"convergence_vs_size.png", dpi=150); plt.close() + +# ─── Priority 2 & 3: train/val curves + KL for representative blocks ─── +REP_BLOCKS = ["region_17q21_core_sb1", "region_11q13_FCER1A", + "region_2q12_IL1RL1_cluster_sb3", "region_5q21_PDE4D_sb55", + "region_6p21_HLA_classII_sb1", "region_5q31_type2_cytokine_sb9", + "region_1q31_TNFSF_cluster_sb4", "control_OCA2_sb10"] + +for bid in REP_BLOCKS: + g = all_logs[all_logs["block"] == bid].sort_values("epoch") + if g.empty: continue + fig, axes = plt.subplots(1, 3, figsize=(18, 4)) + axes[0].plot(g["epoch"], g["tr_recon"], label="train") + axes[0].plot(g["epoch"], g["va_recon"], label="val") + axes[0].axvline(50, ls="--", c="gray", alpha=0.5, label="β warmup end") + axes[0].set_title(f"{bid} — reconstruction loss") + axes[0].set_xlabel("epoch"); axes[0].legend() + + axes[1].plot(g["epoch"], g["tr_kl"], label="train") + axes[1].plot(g["epoch"], g["va_kl"], label="val") + axes[1].axvline(50, ls="--", c="gray", alpha=0.5) + axes[1].set_title(f"{bid} — KL divergence") + axes[1].set_xlabel("epoch"); axes[1].legend() + + axes[2].plot(g["epoch"], g["tr_loss"], label="train") + axes[2].plot(g["epoch"], g["va_loss"], label="val") + axes[2].axvline(50, ls="--", c="gray", alpha=0.5) + axes[2].set_title(f"{bid} — total loss") + axes[2].set_xlabel("epoch"); axes[2].legend() + + plt.tight_layout(); plt.savefig(OUT/f"{bid}_curves.png", dpi=150); plt.close() + +# ─── Priority 4: best-epoch distribution ─── +fig, ax = plt.subplots(figsize=(10,4)) +ax.hist(SUMMARY["best_epoch"], bins=30) +ax.axvline(50, ls="--", c="red", label="β warmup ends") +ax.set_xlabel("best_epoch"); ax.set_ylabel("number of blocks") +ax.set_title("Where best checkpoint occurred (ORD weighted)") +ax.legend(); plt.tight_layout() +plt.savefig(OUT/"best_epoch_hist.png", dpi=150); plt.close() + +print(f"Wrote analyses to {OUT}/") +print(f" Convergence: median epoch {conv_df['epoch_converged'].median():.0f}, " + f"max {conv_df['epoch_converged'].max():.0f}") +print(f" Best-epoch: median {SUMMARY['best_epoch'].median():.0f}, " + f"frac before warmup: {(SUMMARY['best_epoch'] < 50).mean() * 100:.1f}%") + +# ------------ Cross-block latent geometry ------------ +sub = SUMMARY[SUMMARY['loss'] == 'ORD'] +print(f"mu_var_median: min={sub['mu_var_median'].min():.3f}, " + f"median={sub['mu_var_median'].median():.3f}, " + f"max={sub['mu_var_median'].max():.3f}, " + f"IQR={sub['mu_var_median'].quantile(0.75) - sub['mu_var_median'].quantile(0.25):.3f}") + +# ------------ KL per-block consistency ------------ +print(f"kl_per_dim_median: min={sub['kl_per_dim_median'].min():.3f}, " + f"median={sub['kl_per_dim_median'].median():.3f}, " + f"max={sub['kl_per_dim_median'].max():.3f}, " + f"IQR={sub['kl_per_dim_median'].quantile(0.75) - sub['kl_per_dim_median'].quantile(0.25):.3f}") \ No newline at end of file diff --git a/scripts/analysis/02_subject_cluster_analysis.py b/scripts/analysis/diagnostics/02_subject_cluster_analysis.py similarity index 100% rename from scripts/analysis/02_subject_cluster_analysis.py rename to scripts/analysis/diagnostics/02_subject_cluster_analysis.py diff --git a/scripts/analysis/04_cluster_stability_analysis.py b/scripts/analysis/diagnostics/04_cluster_stability_analysis.py similarity index 100% rename from scripts/analysis/04_cluster_stability_analysis.py rename to scripts/analysis/diagnostics/04_cluster_stability_analysis.py diff --git a/scripts/analysis/05_attention_confounder_analysis.py b/scripts/analysis/diagnostics/05_attention_confounder_analysis.py similarity index 100% rename from scripts/analysis/05_attention_confounder_analysis.py rename to scripts/analysis/diagnostics/05_attention_confounder_analysis.py diff --git a/scripts/analysis/06_phase1_phase2_block_comparison.py b/scripts/analysis/diagnostics/06_phase1_phase2_block_comparison.py similarity index 100% rename from scripts/analysis/06_phase1_phase2_block_comparison.py rename to scripts/analysis/diagnostics/06_phase1_phase2_block_comparison.py diff --git a/scripts/analysis/08_clinical_pc_embedding_alignment.py b/scripts/analysis/diagnostics/08_clinical_pc_embedding_alignment.py similarity index 99% rename from scripts/analysis/08_clinical_pc_embedding_alignment.py rename to scripts/analysis/diagnostics/08_clinical_pc_embedding_alignment.py index d9e40e1..f83c424 100644 --- a/scripts/analysis/08_clinical_pc_embedding_alignment.py +++ b/scripts/analysis/diagnostics/08_clinical_pc_embedding_alignment.py @@ -55,7 +55,7 @@ warnings.filterwarnings("ignore", category=RuntimeWarning) # ── defaults ────────────────────────────────────────────────────────────────── -DEFAULT_EMBEDDINGS = "results/output_regions2/ORD/embeddings/individual_embeddings.csv" +DEFAULT_EMBEDDINGS = "results/output_regions2/ORD_W_Scaled/embeddings/individual_embeddings.csv" DEFAULT_CLINICAL = "metadata/2022-07-21_Endotype_CRA.tsv" DEFAULT_EIGENVEC = "metadata/ldpruned_997subs.eigenvec" DEFAULT_OUT = "results/analysis/clinical_pc_embedding_alignment_cra" diff --git a/scripts/analysis/03_leave_hla_out_analysis.py b/scripts/analysis/sensitivity/03_leave_hla_out_analysis.py similarity index 100% rename from scripts/analysis/03_leave_hla_out_analysis.py rename to scripts/analysis/sensitivity/03_leave_hla_out_analysis.py diff --git a/scripts/analysis/07_17q21_validation.py b/scripts/analysis/sensitivity/07_17q21_validation.py similarity index 100% rename from scripts/analysis/07_17q21_validation.py rename to scripts/analysis/sensitivity/07_17q21_validation.py diff --git a/scripts/analysis/sensitivity/11_top_snp_ld_check.py b/scripts/analysis/sensitivity/11_top_snp_ld_check.py new file mode 100644 index 0000000..43c7c2e --- /dev/null +++ b/scripts/analysis/sensitivity/11_top_snp_ld_check.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python3 +""" +11_top_snp_ld_check.py + +Compute pairwise LD (r²) among the top N attributed SNPs from a Phase 1 SNP +attribution run (output of 10_phase1_snp_attribution_within_blocks.py). + +Run as: + python scripts/analysis/11_top_snp_ld_check.py \\ + --attrib-csv results/analysis/snp_attribution/all_selected_blocks_snp_attribution.csv \\ + --block-id region_9p24_IL33 \\ + --target emb_PC1 \\ + --plink /path/to/plink \\ + --top-n 20 + +What this does +-------------- +1. Reads the SNP attribution CSV (per-block or combined output of script 10). +2. Filters to --block-id and --target. +3. Optionally filters to a single --run (full or noHLA); otherwise deduplicates + on snp_locus, keeping the highest mean_abs_delta_score per SNP. +4. Sorts by mean_abs_delta_score descending and takes the top --top-n SNPs. +5. Reads the block's .raw header to map each snp_locus to its PLINK .raw column + name. Fails loudly if any top SNP cannot be mapped. +6. Strips the trailing allele suffix to obtain exact PLINK variant IDs. +7. Writes a variant-ID file for PLINK --extract. +8. Verifies .bed/.bim/.fam exist, then runs: + --bfile --extract --r2 square --out + via subprocess.run(..., check=True). +9. Loads the resulting .ld matrix and saves a tidy long-format LD summary CSV. +10. Prints all commands and output paths. + +Outputs (--out-dir, default results/analysis/snp_ld_check/) +------- +_snp_loci.txt snp_locus IDs for the top N SNPs (human-readable) +_plink_ids.txt PLINK variant IDs for --extract +_top_snps.csv top-N SNP attribution rows +.ld raw PLINK r² square matrix +_ld_summary.csv tidy long-format r² pairs with attribution metadata + +Limitations +----------- +- LD is estimated from the per-block cohort genotypes; sample size equals the + full cohort, not a case-stratified subset. +- High r² between two top-ranked SNPs suggests redundant signal, not biological + co-regulation. Interpret alongside annotation and fine-mapping results. +- Do not interpret LD structure as evidence of causality. +""" + +import os +import sys +import argparse +import subprocess +from pathlib import Path + +import numpy as np +import pandas as pd + + +DEFAULT_OUT = "results/analysis/snp_ld_check" +RAW_DIR = "data/raw_plink" +BFILE_DIR = "data/region_blocks" + + +# ── SNP column name parsing ─────────────────────────────────────────────────── + +def parse_snp_locus(col_name: str) -> str: + """Strip trailing PLINK allele suffix: '12:111742373:A:G_A' -> '12:111742373:A:G'.""" + parts = col_name.rsplit("_", 1) + if len(parts) == 2 and 1 <= len(parts[1]) <= 2 and parts[1].isalpha(): + return parts[0] + return col_name + + +# ── .raw header reader ──────────────────────────────────────────────────────── + +def read_raw_header(raw_path: Path) -> list: + """Return SNP column names (cols 6+) from the first line of a .raw file.""" + with open(raw_path, "r") as fh: + header_line = fh.readline() + cols = header_line.strip().split() + if len(cols) < 7: + raise ValueError( + f".raw header has only {len(cols)} columns (expected ≥7): {raw_path}" + ) + return cols[6:] # skip FID IID PAT MAT SEX PHENOTYPE + + +# ── locus → PLINK variant ID mapping ───────────────────────────────────────── + +def map_loci_to_plink_ids(snp_loci: list, raw_snp_cols: list) -> dict: + """ + Build {snp_locus: plink_variant_id} from .raw header column names. + + plink_variant_id = parse_snp_locus(raw_col_name) — i.e. the variant ID + without the allele suffix, matching the .bim 2nd column. + + Raises ValueError if any locus is absent from the .raw header. + """ + locus_to_col = {} + for col in raw_snp_cols: + locus = parse_snp_locus(col) + locus_to_col[locus] = col + + result = {} + missing = [] + for locus in snp_loci: + if locus in locus_to_col: + result[locus] = parse_snp_locus(locus_to_col[locus]) + else: + missing.append(locus) + + if missing: + raise ValueError( + f"Cannot map {len(missing)} SNP locus ID(s) to .raw header column names:\n" + + "\n".join(f" {m}" for m in missing[:10]) + + ("\n ..." if len(missing) > 10 else "") + ) + return result + + +# ── PLINK helpers ───────────────────────────────────────────────────────────── + +def check_bfile(bfile_prefix: str) -> None: + """Raise FileNotFoundError if any of .bed/.bim/.fam are absent.""" + missing = [ + bfile_prefix + ext + for ext in (".bed", ".bim", ".fam") + if not Path(bfile_prefix + ext).exists() + ] + if missing: + raise FileNotFoundError( + "PLINK binary file(s) not found:\n" + + "\n".join(f" {m}" for m in missing) + ) + + +def run_plink_ld(plink_exe: str, bfile: str, + extract_file: str, out_prefix: str) -> Path: + """Run PLINK r² square matrix computation; return path to .ld output file.""" + cmd = [ + plink_exe, + "--bfile", bfile, + "--extract", extract_file, + "--r2", "square", + "--out", out_prefix, + ] + print("\n[plink] command:") + print(" " + " ".join(cmd)) + subprocess.run(cmd, check=True) + ld_path = Path(out_prefix + ".ld") + if not ld_path.exists(): + raise RuntimeError( + f"PLINK did not produce expected .ld output: {ld_path}\n" + "Check the PLINK log for errors." + ) + return ld_path + + +# ── LD matrix loader ────────────────────────────────────────────────────────── + +def load_ld_matrix(ld_path: Path, snp_ids: list) -> pd.DataFrame: + """ + Load a whitespace-delimited r² square matrix produced by PLINK --r2 square. + Return a tidy long-format DataFrame with one row per unique SNP pair (i < j). + """ + mat = np.loadtxt(str(ld_path)) + n = len(snp_ids) + if mat.ndim == 1 and n == 1: + mat = mat.reshape(1, 1) + if mat.shape != (n, n): + raise RuntimeError( + f"LD matrix shape {mat.shape} does not match {n} extracted SNPs. " + "Verify that all variant IDs are present in the .bim file." + ) + rows = [] + for i in range(n): + for j in range(i + 1, n): + rows.append(dict( + snp_a=snp_ids[i], + snp_b=snp_ids[j], + r2=float(mat[i, j]), + )) + return pd.DataFrame(rows) if rows else pd.DataFrame( + columns=["snp_a", "snp_b", "r2"] + ) + + +# ── Argparse ────────────────────────────────────────────────────────────────── + +def parse_args(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--attrib-csv", required=True, + help="SNP attribution CSV from script 10 " + "(per-block or all_selected_blocks_snp_attribution.csv)") + p.add_argument("--raw-dir", default=RAW_DIR, + help="Directory containing per-block .raw PLINK dosage files " + "(default: %(default)s)") + p.add_argument("--bfile-dir", default=BFILE_DIR, + help="Directory containing per-block PLINK binary files " + "(.bed/.bim/.fam). Expects / prefix. " + "Default: %(default)s") + p.add_argument("--block-id", required=True, + help="Block ID to compute LD for, e.g. region_9p24_IL33") + p.add_argument("--target", required=True, + choices=["emb_PC1", "emb_PC2", "log10Ige_ridge_pred"], + help="Attribution target to rank SNPs by") + p.add_argument("--run", default=None, choices=["full", "noHLA"], + help="Filter to a specific model run. " + "Default None = use all runs, keep best score per SNP locus.") + p.add_argument("--top-n", type=int, default=20, + help="Number of top SNPs by mean_abs_delta_score (default 20)") + p.add_argument("--plink", required=True, + help="Full path to the PLINK executable. Not assumed to be on PATH.") + p.add_argument("--out-dir", default=DEFAULT_OUT) + return p.parse_args() + + +# ── Main ────────────────────────────────────────────────────────────────────── + +def main(): + args = parse_args() + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + block_id = args.block_id + target = args.target + top_n = args.top_n + + safe_tgt = target.lower().replace(" ", "_").replace("/", "_") + prefix = f"{block_id}_{safe_tgt}_top{top_n}" + + # ── 1. Read attribution CSV ─────────────────────────────────────────────── + print(f"[load] attribution CSV: {args.attrib_csv}") + df = pd.read_csv(args.attrib_csv) + + required_cols = {"block_id", "snp_locus", "target", "mean_abs_delta_score"} + missing_cols = required_cols - set(df.columns) + if missing_cols: + sys.exit( + f"ERROR: Attribution CSV missing required columns: {sorted(missing_cols)}\n" + f"Available: {sorted(df.columns.tolist())}" + ) + + # ── 2. Filter to block_id and target ───────────────────────────────────── + df_f = df[(df["block_id"] == block_id) & (df["target"] == target)].copy() + if df_f.empty: + sys.exit( + f"ERROR: No rows for block_id='{block_id}' and target='{target}' " + f"in {args.attrib_csv}.\n" + f"Available block IDs: {sorted(df['block_id'].unique().tolist())}\n" + f"Available targets: {sorted(df['target'].unique().tolist())}" + ) + print(f" {len(df_f)} rows matched " + f"(block='{block_id}', target='{target}')") + + # ── 3. Optional run filter ──────────────────────────────────────────────── + if args.run is not None: + if "run" not in df_f.columns: + sys.exit( + "ERROR: Attribution CSV has no 'run' column; cannot filter by --run.\n" + "Remove --run to use all available rows." + ) + df_f = df_f[df_f["run"] == args.run].copy() + if df_f.empty: + sys.exit( + f"ERROR: No rows remain after filtering to run='{args.run}'.\n" + f"Available run values: " + f"{sorted(df['run'].dropna().unique().tolist())}" + ) + print(f" {len(df_f)} rows after --run='{args.run}' filter") + + # ── 4. Sort, deduplicate on snp_locus (keep best score), take top N ─────── + df_f = df_f.sort_values("mean_abs_delta_score", ascending=False) + df_f = df_f.drop_duplicates(subset="snp_locus", keep="first") + top_df = df_f.head(top_n).reset_index(drop=True) + + actual_n = len(top_df) + if actual_n < top_n: + print(f"[warn] only {actual_n} unique SNPs available; " + f"requested {top_n}") + print(f" {actual_n} SNP(s) selected for LD computation") + + # ── 5. Read .raw header → validate and map loci to PLINK variant IDs ────── + raw_path = Path(args.raw_dir) / f"{block_id}.raw" + if not raw_path.exists(): + raise FileNotFoundError(f".raw file not found: {raw_path}") + print(f"[raw] reading header: {raw_path}") + raw_snp_cols = read_raw_header(raw_path) + print(f" {len(raw_snp_cols)} SNP columns in .raw header") + + snp_loci = top_df["snp_locus"].tolist() + locus_map = map_loci_to_plink_ids(snp_loci, raw_snp_cols) + plink_ids = [locus_map[locus] for locus in snp_loci] + print(f" all {len(plink_ids)} loci mapped to PLINK variant IDs") + + # ── 6. Write snp_loci.txt and plink_ids.txt ─────────────────────────────── + loci_file = out_dir / f"{prefix}_snp_loci.txt" + ids_file = out_dir / f"{prefix}_plink_ids.txt" + + loci_file.write_text("\n".join(snp_loci) + "\n") + ids_file.write_text("\n".join(plink_ids) + "\n") + print(f"[output] {loci_file.name} ({len(snp_loci)} loci)") + print(f"[output] {ids_file.name} ({len(plink_ids)} PLINK IDs)") + + # ── 7. Verify PLINK executable and bfile ────────────────────────────────── + plink_exe = args.plink + if not Path(plink_exe).exists(): + raise FileNotFoundError(f"PLINK executable not found: {plink_exe}") + + bfile = str(Path(args.bfile_dir) / block_id) + check_bfile(bfile) + print(f"[plink] bfile prefix: {bfile}") + + # ── 8. Run PLINK LD ─────────────────────────────────────────────────────── + ld_out_prefix = str(out_dir / prefix) + ld_path = run_plink_ld(plink_exe, bfile, str(ids_file), ld_out_prefix) + print(f"[plink] LD matrix: {ld_path}") + + # ── 9. Load r² matrix → tidy long-format CSV ────────────────────────────── + ld_long = load_ld_matrix(ld_path, plink_ids) + + if not ld_long.empty: + score_map = dict(zip(plink_ids, top_df["mean_abs_delta_score"].values)) + ld_long["mean_abs_delta_score_a"] = ld_long["snp_a"].map(score_map) + ld_long["mean_abs_delta_score_b"] = ld_long["snp_b"].map(score_map) + ld_long = ld_long.sort_values("r2", ascending=False).reset_index(drop=True) + + ld_csv = out_dir / f"{prefix}_ld_summary.csv" + ld_long.to_csv(ld_csv, index=False) + print(f"[output] {ld_csv.name} ({len(ld_long)} pairs)") + + # ── 10. Save top-SNP attribution rows ───────────────────────────────────── + top_csv = out_dir / f"{prefix}_top_snps.csv" + top_df.to_csv(top_csv, index=False) + print(f"[output] {top_csv.name} ({len(top_df)} SNPs)") + + # ── 11. Summary ─────────────────────────────────────────────────────────── + print(f"\n══ Top {actual_n} SNPs — {target} — {block_id} ══") + disp_cols = ["snp_locus", "mean_abs_delta_score"] + if "rank_within_block" in top_df.columns: + disp_cols = ["rank_within_block"] + disp_cols + if "run" in top_df.columns: + disp_cols = disp_cols + ["run"] + print(top_df[disp_cols].to_string(index=False)) + + if not ld_long.empty: + high_ld = ld_long[ld_long["r2"] >= 0.8] + if not high_ld.empty: + print(f"\n[LD] {len(high_ld)} SNP pair(s) with r² ≥ 0.80:") + print(high_ld[["snp_a", "snp_b", "r2"]].head(10).to_string(index=False)) + else: + print("\n[LD] No SNP pairs with r² ≥ 0.80 among top SNPs.") + + print(f"\n[done] All outputs written to: {out_dir}/") + + +if __name__ == "__main__": + main() diff --git a/scripts/analysis/01_block_embedding_phenotype_analysis.py b/scripts/analysis/validation/01_block_embedding_phenotype_analysis.py similarity index 99% rename from scripts/analysis/01_block_embedding_phenotype_analysis.py rename to scripts/analysis/validation/01_block_embedding_phenotype_analysis.py index c857f2d..582cd3f 100644 --- a/scripts/analysis/01_block_embedding_phenotype_analysis.py +++ b/scripts/analysis/validation/01_block_embedding_phenotype_analysis.py @@ -66,13 +66,13 @@ warnings.filterwarnings("ignore", category=RuntimeWarning) # ── paths ────────────────────────────────────────────────────────────────────── -EMB_SUBJ_NPY = "results/output_regions2/ORD/embeddings/individual_embeddings.npy" -EMB_BLOCK_NPY = "results/output_regions2/ORD/embeddings/block_contextual_repr.npy" -ATTN_CSV = "results/output_regions2/ORD/embeddings/pooling_attention_weights.csv" +EMB_SUBJ_NPY = "results/output_regions2/ORD_W_Scaled/embeddings/individual_embeddings.npy" +EMB_BLOCK_NPY = "results/output_regions2/ORD_W_Scaled/embeddings/block_contextual_repr.npy" +ATTN_CSV = "results/output_regions2/ORD_W_Scaled/embeddings/pooling_attention_weights.csv" BLOCK_ORDER = "results/output_regions/block_order.csv" EIGENVEC_FILE = "metadata/ldpruned_997subs.eigenvec" PHENO_FILE = "metadata/COS_TRIO_pheno_1165.csv" -OUT_DIR = "results/output_regions2/ORD/all_blocks_pheno_analysis" +OUT_DIR = "results/output_regions2/ORD_W_Scaled/all_blocks_pheno_analysis" PC_COLS = [f"PC{i}" for i in range(1, 11)] diff --git a/scripts/core/VAE_phase1.py b/scripts/core/VAE_phase1.py index a182a79..411537d 100644 --- a/scripts/core/VAE_phase1.py +++ b/scripts/core/VAE_phase1.py @@ -21,6 +21,8 @@ import platform import socket import sklearn +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler from itertools import product from torch.utils.data import TensorDataset, DataLoader from pathlib import Path @@ -41,13 +43,8 @@ DEFAULT_CFG = { "data": { "raw_dir": "data/region_blocks", - "block_def": "data/block_plan/manifest.tsv", - # "block_def_ctrl": "data/block_plan/manifest_blocks_ctrl.tsv", - "output_dir": "results/output_regions", - # "raw_dir": "data/region_blocks", - # "block_def": "data/block_plan_test.tsv", - # # "block_def_ctrl": "data/block_plan/manifest_blocks_ctrl.tsv", - # "output_dir": "results/test_ord_mps_run1", + "block_def": "data/block_plan/manifest.tsv", + "output_dir": "results/output_regions_ord_weighted", }, "runtime": { "device": "cpu", @@ -55,13 +52,12 @@ "vae": { "latent_dim": 8, - # buckets: [min_snps, max_snps, latent_dim] "latent_dim_by_snps": [ [8, 49, 4], [50, 149, 8], [150, 299, 12], [300, 799, 16], - [800, 10**9, 16], + [800, 10_000_000, 16], ], "dropout": 0.30, @@ -75,6 +71,11 @@ "val_frac": 0.20, "seed": 42, "cat_weight_clip": 10.0, + "free_bits": 0.0, # per-dim KL floor; 0.0 = disabled (original behavior) + "ld_corr_repeats": 1, # >1 averages ld_corr over multiple SNP subsamples + "ld_corr_max_snps": 200, # cap on SNPs used per ld_corr evaluation + "ord_weighted": False, # apply per-sample class weights to ordinal_loss + "ord_weight_clip": 10.0, # same convention as cat_weight_clip }, # "loss_functions": ["MSE", "BCE", "MSE_STD","CAT","ORD"], "loss_functions": ["MSE"], @@ -94,7 +95,10 @@ "metric": "bal_acc_va", "top_k": 2, "bottom_k": 2 - } + }, + "pca": { + "standardize": True, # fit StandardScaler on train PCA scores; transform all subjects + }, } @@ -124,7 +128,6 @@ def set_seed(s): torch.manual_seed(s) if torch.cuda.is_available(): torch.cuda.manual_seed_all(s) - try: torch.use_deterministic_algorithms(True, warn_only=True) except Exception: @@ -148,20 +151,6 @@ def get_device(requested="auto"): print("[device] CPU") return torch.device("cpu") -# def get_device(): -# if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): -# try: -# torch.zeros(2, device="mps") # quick smoke-test -# print("[device] Apple MPS") -# return torch.device("mps") -# except Exception: -# pass -# if torch.cuda.is_available(): -# print("[device] CUDA") -# return torch.device("cuda") -# print("[device] CPU") -# return torch.device("cpu") - def latent_dim_for_p(p: int, cfg: dict) -> int: v = cfg.get("vae", {}) sched = v.get("latent_dim_by_snps", None) @@ -169,7 +158,8 @@ def latent_dim_for_p(p: int, cfg: dict) -> int: return int(v.get("latent_dim", 8)) for lo, hi, d in sched: - if p >= int(lo) and p <= int(hi): + hi_f = float(hi) # handles both plain ints and float("inf") from YAML + if p >= int(lo) and (hi_f == float("inf") or p <= hi_f): return int(d) return int(v.get("latent_dim", 8)) @@ -178,15 +168,10 @@ def latent_dim_for_p(p: int, cfg: dict) -> int: # 2. DATA LOADING # ────────────────────────────────────────────────────────────────── -# Your TSV columns (has header row) -BLOCK_COLS = [ - "block_id", "gene", "class", "subblock", "chr", - "from_bp", "to_bp", "snp_count_original", "out_prefix", "status" -] def load_block_defs(tsv: str) -> pd.DataFrame: df = pd.read_csv(tsv, sep="\t", header=0) - # keep only OK blocks + # keep only OK blocks from LD block formation script if "status" in df.columns: df = df[df["status"].astype(str).str.lower().eq("ok")].copy() print(f"[blocks] {len(df)} OK entries from {tsv}") @@ -295,7 +280,7 @@ class BlockVAE(nn.Module): def __init__(self, p, d=16, drop=0.3, loss_type="MSE", class_weights=None): super().__init__() self.p, self.d, self.loss_type = p, d, loss_type - h = [64, 32] if p < 100 else [128, 64] + h = [64, 32] if p < 100 else [128, 64] # hidden layer neurons based on block size # ---- encoder ---- layers, inp = [], p @@ -324,9 +309,9 @@ def __init__(self, p, d=16, drop=0.3, loss_type="MSE", class_weights=None): self.dec = nn.Sequential(*layers) if class_weights is not None: - self.register_buffer("ce_w", torch.tensor(class_weights, dtype=torch.float32)) + self.register_buffer("class_w", torch.tensor(class_weights, dtype=torch.float32)) else: - self.ce_w = None + self.class_w = None def encode(self, x): h = self.enc(x) @@ -351,13 +336,14 @@ def forward(self, x): z = self.reparam(mu, lv) return self.decode(z), mu, lv - def compute_loss(self, x_in, recon, mu, lv, beta, y=None): + def compute_loss(self, x_in, recon, mu, lv, beta, y=None, free_bits=0.0): """ x_in: float input to encoder, shape (B,P) recon: - CAT: logits shape (B,P,3) - else: shape (B,P) - y: required for CAT, long targets shape (B,P) with values {0,1,2} + y: required for CAT/ORD, long targets shape (B,P) with values {0,1,2} + free_bits: per-dimension KL floor; 0.0 disables (original behavior unchanged) """ B = x_in.size(0) @@ -368,87 +354,103 @@ def compute_loss(self, x_in, recon, mu, lv, beta, y=None): logits = recon.reshape(-1, 3) targets = y.reshape(-1) - ce = F.cross_entropy(logits, targets, reduction="sum", - weight=self.ce_w) / B + ce = F.cross_entropy(logits, targets, reduction="mean", + weight=self.ce_w) rl = ce elif self.loss_type == "ORD": if y is None: raise ValueError("ORD loss requires y targets (LongTensor 0/1/2).") - rl = ordinal_loss(recon, y) # see corrected ordinal_loss below + rl = ordinal_loss(recon, y, class_weights=self.class_w) elif self.loss_type == "BCE": - rl = F.binary_cross_entropy_with_logits(recon, x_in, reduction="sum") / B + rl = F.binary_cross_entropy_with_logits(recon, x_in, reduction="mean") else: - rl = F.mse_loss(recon, x_in, reduction="sum") / B + rl = F.mse_loss(recon, x_in, reduction="mean") - kl = -0.5 * torch.sum(1 + lv - mu.pow(2) - lv.exp()) / B + if free_bits > 0.0: + kl_dim = (-0.5 * (1 + lv - mu.pow(2) - lv.exp())).mean(dim=0) # (d,) + kl = torch.clamp(kl_dim, min=free_bits).sum() + else: + kl = -0.5 * torch.mean(1 + lv - mu.pow(2) - lv.exp()) return rl + beta * kl, rl, kl # ────────────────────────────────────────────────────────────────── # 4. DATA TRANSFORMS (per loss type) # ────────────────────────────────────────────────────────────────── -def prepare_data(geno, loss_type, tr_ix, va_ix): +def prepare_data(geno, loss_type, tr_ix, va_ix, te_ix=None): tr, va = geno[tr_ix].copy(), geno[va_ix].copy() + te = geno[te_ix].copy() if te_ix is not None and len(te_ix) > 0 else None stats = {} if loss_type == "CAT": # categorical - # tr/va are expected to be {0,1,2} ints (or floats that are exactly 0/1/2) - tr_x = torch.tensor(tr, dtype=torch.float32) # encoder input + # encoder inputs stay float32; targets are rounded/clipped to {0,1,2} + tr_x = torch.tensor(tr, dtype=torch.float32) va_x = torch.tensor(va, dtype=torch.float32) + tr_y = torch.tensor(np.clip(np.round(tr), 0, 2).astype(np.int64), dtype=torch.long) + va_y = torch.tensor(np.clip(np.round(va), 0, 2).astype(np.int64), dtype=torch.long) + te_x = torch.tensor(te, dtype=torch.float32) if te is not None else None + te_y = torch.tensor(np.clip(np.round(te), 0, 2).astype(np.int64), dtype=torch.long) \ + if te is not None else None + return (tr_x, va_x, tr_y, va_y, stats, te_x, te_y) - tr_y = torch.tensor(tr, dtype=torch.long) # CE targets - va_y = torch.tensor(va, dtype=torch.long) - - return (tr_x, va_x, tr_y, va_y, stats) - if loss_type == "ORD": tr_x = torch.tensor(tr, dtype=torch.float32) va_x = torch.tensor(va, dtype=torch.float32) - tr_y = torch.tensor(tr, dtype=torch.long) - va_y = torch.tensor(va, dtype=torch.long) - return (tr_x, va_x, tr_y, va_y, stats) + tr_y = torch.tensor(np.clip(np.round(tr), 0, 2).astype(np.int64), dtype=torch.long) + va_y = torch.tensor(np.clip(np.round(va), 0, 2).astype(np.int64), dtype=torch.long) + te_x = torch.tensor(te, dtype=torch.float32) if te is not None else None + te_y = torch.tensor(np.clip(np.round(te), 0, 2).astype(np.int64), dtype=torch.long) \ + if te is not None else None + return (tr_x, va_x, tr_y, va_y, stats, te_x, te_y) - if loss_type == "BCE": tr, va = tr / 2.0, va / 2.0 # → {0, 0.5, 1} - + te = te / 2.0 if te is not None else None elif loss_type == "MSE_STD": m = tr.mean(0, keepdims=True) s = tr.std(0, keepdims=True); s[s < 1e-8] = 1.0 stats = {"mean": m, "std": s} tr, va = (tr - m) / s, (va - m) / s - + te = (te - m) / s if te is not None else None + + te_x = torch.tensor(te, dtype=torch.float32) if te is not None else None return (torch.tensor(tr, dtype=torch.float32), - torch.tensor(va, dtype=torch.float32), stats) + torch.tensor(va, dtype=torch.float32), stats, te_x) def compute_class_weights(tr_geno_block, eps=1e-6): # tr_geno_block: numpy array (N,P) with values 0/1/2 (ints or exact floats) - y = tr_geno_block.reshape(-1).astype(np.int64) + y = np.clip(np.round(tr_geno_block.reshape(-1)), 0, 2).astype(np.int64) counts = np.bincount(y, minlength=3).astype(np.float64) freq = counts / (counts.sum() + eps) w = 1.0 / (freq + eps) w = w / w.mean() # normalize so average weight ~1 return w.tolist(), counts.tolist() -def ordinal_loss(logits2: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: +def ordinal_loss(logits2: torch.Tensor, targets: torch.Tensor, + class_weights: torch.Tensor = None) -> torch.Tensor: """ logits2: (B,P,2) logits2[...,0] = t0 logits2[...,1] = raw_delta (we enforce t1 = t0 + softplus(delta) so t1>=t0) targets: (B,P) in {0,1,2} - Returns: sum-over-features / mean-over-batch + class_weights: optional (3,) tensor; per-sample weight = class_weights[target] + Returns: per-element (weighted) mean over batch and features """ - B = targets.size(0) t0 = logits2[..., 0] - t1 = t0 + F.softplus(logits2[..., 1]) # enforce ordering - - p0 = torch.sigmoid(t0) # P(Y<=0) - p1 = torch.sigmoid(t1) # P(Y<=1) + t1 = t0 + F.softplus(logits2[..., 1]) + p0 = torch.sigmoid(t0) + p1 = torch.sigmoid(t1) p_cls = torch.stack([p0, p1 - p0, 1 - p1], dim=-1).clamp(1e-6, 1.0) - log_p = p_cls.log().gather(-1, targets.clamp(0, 2).unsqueeze(-1)).squeeze(-1) - return -log_p.sum() / B + log_p = p_cls.log().gather(-1, targets.clamp(0, 2).unsqueeze(-1)).squeeze(-1) # (B,P) + + if class_weights is not None: + # Per-element weight by true class + w = class_weights[targets.clamp(0, 2)] # (B,P) + # Weighted mean: sum(w * loss) / sum(w), so the result stays on a similar scale + return -(w * log_p).sum() / (w.sum() + 1e-12) + return -log_p.mean() # ────────────────────────────────────────────────────────────────── @@ -456,6 +458,7 @@ def ordinal_loss(logits2: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: # ────────────────────────────────────────────────────────────────── def train_block_vae(model, tr_t, va_t, cfg, device, log_csv, tr_y=None, va_y=None): v = cfg["vae"] + free_bits = float(v.get("free_bits", 0.0)) if tr_y is None: tr_ds = TensorDataset(tr_t) @@ -487,6 +490,7 @@ def train_block_vae(model, tr_t, va_t, cfg, device, log_csv, tr_y=None, va_y=Non ) best_val, wait, best_sd, log = float("inf"), 0, None, [] + best_recon = float("inf") best_epoch = 0 best_metrics = {} @@ -507,7 +511,7 @@ def train_block_vae(model, tr_t, va_t, cfg, device, log_csv, tr_y=None, va_y=Non y = y.to(device) opt.zero_grad() recon, mu, lv = model(x) - loss, recon_loss, kl_loss = model.compute_loss(x, recon, mu, lv, beta, y) + loss, recon_loss, kl_loss = model.compute_loss(x, recon, mu, lv, beta, y, free_bits=free_bits) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), v["grad_clip"]) opt.step() @@ -532,7 +536,7 @@ def train_block_vae(model, tr_t, va_t, cfg, device, log_csv, tr_y=None, va_y=Non if y is not None: y = y.to(device) recon, mu, lv = model(x) - loss, recon_loss, kl_loss = model.compute_loss(x, recon, mu, lv, beta, y) + loss, recon_loss, kl_loss = model.compute_loss(x, recon, mu, lv, beta, y, free_bits=free_bits) total_va_loss += loss.item() * x.size(0) total_va_recon += recon_loss.item() * x.size(0) total_va_kl += kl_loss.item() * x.size(0) @@ -551,20 +555,22 @@ def train_block_vae(model, tr_t, va_t, cfg, device, log_csv, tr_y=None, va_y=Non "va_kl": vK, }) - if vL < best_val: - best_val, wait = vL, 0 - best_epoch = ep - best_metrics = { - "best_val_loss": vL, - "best_val_recon": vR, - "best_val_kl": vK, - "best_tr_loss": tL, - "best_tr_recon": tR, - "best_tr_kl": tK, - } - best_sd = {k: w.cpu().clone() for k, w in model.state_dict().items()} - else: - wait += 1 + if ep > v["beta_warmup"]: + if vR < best_recon: + # best_val, wait = vL, 0 + best_recon, wait = vR, 0 + best_epoch = ep + best_metrics = { + "best_val_loss": vL, + "best_val_recon": vR, + "best_val_kl": vK, + "best_tr_loss": tL, + "best_tr_recon": tR, + "best_tr_kl": tK, + } + best_sd = {k: w.cpu().clone() for k, w in model.state_dict().items()} + else: + wait += 1 if wait > v["patience"]: print(f" early stop at epoch {ep}") @@ -687,7 +693,61 @@ def block_maf_stats(G012: np.ndarray) -> dict: "maf_frac_lt_20pct": float(np.mean(maf < 0.20)), } -def eval_genotype_metrics(model, x_t, loss_type, stats): +def pca_block_baseline(geno_tr: np.ndarray, geno_va: np.ndarray, n_components: int): + """ + Fit PCA on training genotype matrix; reconstruct validation set. + Uses the same number of components as the VAE latent dimension. + Returns (pca_conc_va, pca_r2_va) or (nan, nan) on any failure. + """ + try: + n_comp = min(n_components, geno_tr.shape[0] - 1, geno_tr.shape[1]) + if n_comp < 1: + raise ValueError( + f"effective n_components={n_comp} < 1 " + f"(requested {n_components}, train_n={geno_tr.shape[0]}, p={geno_tr.shape[1]})" + ) + pca = PCA(n_components=n_comp) + pca.fit(geno_tr.astype(np.float64)) + rec_cont = pca.inverse_transform( + pca.transform(geno_va.astype(np.float64)) + ).astype(np.float32) + rec_cont = np.clip(rec_cont, 0.0, 2.0) + truth_cont = geno_va.astype(np.float32) + # R² on continuous dosage — mirrors the r2_va definition in eval_genotype_metrics + var = float(np.var(truth_cont)) + r2 = float(1.0 - np.mean((truth_cont - rec_cont) ** 2) / (var + 1e-12)) + # concordance after rounding to {0,1,2} + pred = np.clip(np.round(rec_cont), 0, 2).astype(np.int8) + truth = np.clip(np.round(truth_cont), 0, 2).astype(np.int8) + conc = float(np.mean(pred == truth)) + return conc, r2 + except Exception as exc: + print(f"[pca_baseline] WARNING: failed (n_components={n_components}): {exc}") + return float("nan"), float("nan") + + +def extract_pca_block_emb(geno: np.ndarray, tr_ix: np.ndarray, n_components: int): + """Fit PCA on training subjects only; transform all subjects. + + Returns (emb, pca_model, n_components_used). + emb shape: (N_all, d) float32, where d = min(n_components, len(tr_ix)-1, n_features). + """ + n_comp = max(1, min(n_components, len(tr_ix) - 1, geno.shape[1])) + pca = PCA(n_components=n_comp) + pca.fit(geno[tr_ix].astype(np.float64)) + emb = pca.transform(geno.astype(np.float64)).astype(np.float32) + return emb, pca, n_comp + + +def majority_baseline_from_train(train_geno012: np.ndarray) -> np.ndarray: + """Return per-SNP majority genotype class (0/1/2) learned from training subjects only.""" + g = np.clip(np.round(train_geno012), 0, 2).astype(np.int8) + counts = np.stack([(g == c).sum(axis=0) for c in (0, 1, 2)], axis=1) # (P, 3) + return counts.argmax(axis=1).astype(np.int8) # (P,) + + +def eval_genotype_metrics(model, x_t, loss_type, stats, baseline_mode=None, + ld_corr_repeats=1, ld_corr_max_snps=200): """ Returns: - conc: exact match fraction (your current metric) @@ -730,16 +790,22 @@ def eval_genotype_metrics(model, x_t, loss_type, stats): pred = np.clip(np.round(pred_cont), 0, 2).astype(np.int8) truth = np.clip(np.round(truth_cont), 0, 2).astype(np.int8) - # ld = ld_corr_score(truth.round().clip(0,2).astype(np.float32), pred.clip(0, None), max_snps=200, seed=0) - ld = ld_corr_score(truth, pred, max_snps=200, seed=0) + if ld_corr_repeats <= 1: + ld = ld_corr_score(truth, pred, max_snps=ld_corr_max_snps, seed=0) + else: + ld = float(np.mean([ + ld_corr_score(truth, pred, max_snps=ld_corr_max_snps, seed=s) + for s in range(ld_corr_repeats) + ])) conc = float(np.mean(pred == truth)) # baseline: per-SNP majority-class predictor - # (compute mode of truth for each SNP column, then compare) - # truth: (N, P). mode per column: - # counts shape (P, 3) - counts = np.stack([(truth == g).sum(axis=0) for g in (0, 1, 2)], axis=1) # (P,3) - mode = counts.argmax(axis=1).astype(np.int8) # (P,) + # Baseline mode can be supplied from training data to avoid validation/test leakage. + if baseline_mode is not None: + mode = np.asarray(baseline_mode, dtype=np.int8) + else: + counts = np.stack([(truth == g).sum(axis=0) for g in (0, 1, 2)], axis=1) # (P,3) + mode = counts.argmax(axis=1).astype(np.int8) base_pred = np.broadcast_to(mode, truth.shape) # (N,P) base_conc = float(np.mean(base_pred == truth)) @@ -776,6 +842,47 @@ def eval_genotype_metrics(model, x_t, loss_type, stats): } +def compute_latent_diagnostics(model: "BlockVAE", va_x: torch.Tensor) -> dict: + """Encode validation set with the best checkpoint and compute per-dim latent stats. + + Returns scalar summary keys plus private array keys prefixed with '_'. + Callers should pop the '_*' keys to get the arrays before storing the dict. + """ + model.eval() + with torch.no_grad(): + mu, lv = model.encode(va_x) + + mu_np = mu.cpu().numpy() # (N, d) + lv_np = lv.cpu().numpy() # (N, d) + + kl_per_dim = -0.5 * (1.0 + lv_np - mu_np ** 2 - np.exp(lv_np)) # (N, d) + kl_per_dim_mean = kl_per_dim.mean(axis=0) # (d,) + mu_var_per_dim = mu_np.var(axis=0) # (d,) + sigma_per_dim = np.exp(0.5 * lv_np).mean(axis=0) # (d,) + + d = mu_np.shape[1] + n_active = int((mu_var_per_dim > 0.01).sum()) + + return { + "n_active_latents": n_active, + "frac_dims_collapsed": round(float(1.0 - n_active / d), 4) if d > 0 else float("nan"), + "latent_underused": bool(n_active < d / 2), + "kl_per_dim_min": round(float(kl_per_dim_mean.min()), 6), + "kl_per_dim_median": round(float(np.median(kl_per_dim_mean)), 6), + "kl_per_dim_max": round(float(kl_per_dim_mean.max()), 6), + "mu_var_min": round(float(mu_var_per_dim.min()), 6), + "mu_var_median": round(float(np.median(mu_var_per_dim)), 6), + "mu_var_max": round(float(mu_var_per_dim.max()), 6), + "sigma_min": round(float(sigma_per_dim.min()), 6), + "sigma_median": round(float(np.median(sigma_per_dim)), 6), + "sigma_max": round(float(sigma_per_dim.max()), 6), + # arrays — pop before logging to CSV + "_kl_per_dim": kl_per_dim_mean, + "_mu_var_per_dim": mu_var_per_dim, + "_sigma_per_dim": sigma_per_dim, + } + + # ────────────────────────────────────────────────────────────────── # 7. MAIN PHASE-1 PIPELINE # ────────────────────────────────────────────────────────────────── @@ -784,6 +891,7 @@ def run_tuning(cfg): tc = cfg.get("tuning", {}) print("\n══════ Tuning Mode ══════") + print("[tuning] test_frac is ignored in tuning mode; tuning uses train/validation only and does not evaluate test.") print(f"Tuning config - loss: {tc.get('loss', 'N/A')}") print(f"Tuning config - metric: {tc.get('metric', 'N/A')}") print(f"Tuning config - blocks from config: {tc.get('blocks', [])}") @@ -847,20 +955,22 @@ def run_tuning(cfg): # Results results = [] for bid in tuning_blocks: + # Reseed per-block so each block is independently reproducible + set_seed(vc["seed"] + (hash(bid) & 0xFFFF)) if bid not in blocks: continue G = blocks[bid]["geno"] p = G.shape[1] d_block = latent_dim_for_p(p, cfg) if lt in ["ORD", "CAT"]: - tr_t, va_t, tr_y, va_y, stats = prepare_data(G, lt, tr_ix, va_ix) + tr_t, va_t, tr_y, va_y, stats, *_ = prepare_data(G, lt, tr_ix, va_ix) else: - tr_t, va_t, stats = prepare_data(G, lt, tr_ix, va_ix) + tr_t, va_t, stats, *_ = prepare_data(G, lt, tr_ix, va_ix) tr_y = va_y = None for drop, lr, beta_max in combos: # Reset seed for comparability - set_seed(vc["seed"]) + set_seed(vc["seed"] + (hash(bid) & 0xFFFF) + (hash((drop, lr, beta_max)) & 0xFFFF)) cfg_copy = copy.deepcopy(cfg) cfg_copy["vae"]["dropout"] = drop @@ -875,11 +985,49 @@ def run_tuning(cfg): if lt in ["ORD", "CAT"] and (tr_y is None or va_y is None): print(f"[tuning][WARN] expected targets for {lt} but tr_y/va_y missing (tr_y={tr_y}, va_y={va_y})") - model = BlockVAE(p, d_block, drop, loss_type=lt) + if lt == "CAT": + w, _ = compute_class_weights(blocks[bid]["geno"][tr_ix]) + w = np.clip(np.array(w, dtype=np.float32), 0.25, vc.get("cat_weight_clip", 10.0)).tolist() + model = BlockVAE(p, d_block, drop, loss_type="CAT", class_weights=w) + elif lt == "ORD" and bool(vc.get("ord_weighted", False)): + w, _ = compute_class_weights(blocks[bid]["geno"][tr_ix]) + w = np.clip(np.array(w, dtype=np.float32), 0.25, vc.get("ord_weight_clip", 10.0)).tolist() + model = BlockVAE(p, d_block, drop, loss_type="ORD", class_weights=w) + else: + model = BlockVAE(p, d_block, drop, loss_type=lt) + t0 = time.time() - log, best_epoch, best_metrics = train_block_vae(model, tr_t, va_t, cfg_copy, dev, None, tr_y=tr_y, va_y=va_y) + log, best_epoch, best_metrics = train_block_vae( + model, tr_t, va_t, cfg_copy, dev, None, tr_y=tr_y, va_y=va_y + ) dt = time.time() - t0 - m_va = eval_genotype_metrics(model, va_t, lt, stats) + + # (2) Train-derived majority baseline → collapse-aware gain. + # Matches run_phase1 so tuning and production are comparable. + baseline_mode = majority_baseline_from_train(blocks[bid]["geno"][tr_ix]) + ld_reps = int(vc.get("ld_corr_repeats", 1)) + ld_max_snps = int(vc.get("ld_corr_max_snps", 200)) + m_va = eval_genotype_metrics( + model, va_t, lt, stats, + baseline_mode=baseline_mode, + ld_corr_repeats=ld_reps, + ld_corr_max_snps=ld_max_snps, + ) + + # (3) Latent diagnostics on the best checkpoint + # (train_block_vae already restored best_sd into `model`). + diag = compute_latent_diagnostics(model, va_t) + # Tuning rows only need scalars; drop the per-dim arrays. + for _k in ("_kl_per_dim", "_mu_var_per_dim", "_sigma_per_dim"): + diag.pop(_k, None) + + conc_va = m_va["conc"] + base_conc_va = m_va["base_conc"] + gain_va = ( + conc_va - base_conc_va + if np.isfinite(conc_va) and np.isfinite(base_conc_va) + else float("nan") + ) results.append({ "loss": lt, @@ -889,10 +1037,25 @@ def run_tuning(cfg): "beta_max": beta_max, "best_epoch": best_epoch, "best_val_loss": best_metrics["best_val_loss"], + # reconstruction-side metrics "bal_acc_va": m_va["bal_acc"], "ld_corr_va": m_va["ld_corr"], - "conc_va": m_va["conc"], - "runtime_sec": dt + "conc_va": conc_va, + # (2) baseline + gain — collapse-aware selection target + "base_conc_va": base_conc_va, + "concordance_gain_va": gain_va, + # (3) latent diagnostics + "n_active_latents": diag["n_active_latents"], + "frac_dims_collapsed": diag["frac_dims_collapsed"], + "latent_underused": diag["latent_underused"], + "kl_per_dim_min": diag["kl_per_dim_min"], + "kl_per_dim_median": diag["kl_per_dim_median"], + "kl_per_dim_max": diag["kl_per_dim_max"], + "mu_var_min": diag["mu_var_min"], + "mu_var_median": diag["mu_var_median"], + "mu_var_max": diag["mu_var_max"], + "sigma_median": diag["sigma_median"], + "runtime_sec": dt, }) # Save results @@ -900,25 +1063,66 @@ def run_tuning(cfg): rdf.to_csv(tuning_dir / "tuning_results.csv", index=False) # Aggregate with std - agg = rdf.groupby(["dropout", "lr", "beta_max"]).agg({ - "best_val_loss": ["mean", "std"], - "bal_acc_va": ["mean", "std"], - "ld_corr_va": ["mean", "std"], - "conc_va": ["mean", "std"], - "block": "count" - }).reset_index() - agg.columns = ["dropout", "lr", "beta_max", "mean_best_val_loss", "std_best_val_loss", "mean_bal_acc_va", "std_bal_acc_va", "mean_ld_corr_va", "std_ld_corr_va", "mean_conc_va", "std_conc_va", "n_blocks"] + # Columns to aggregate as (mean, std) across blocks. + agg_metric_cols = [ + "best_val_loss", + "bal_acc_va", + "ld_corr_va", + "conc_va", + "base_conc_va", + "concordance_gain_va", + "n_active_latents", + "frac_dims_collapsed", + "kl_per_dim_median", + "mu_var_median", + "sigma_median", + ] + agg_spec = {c: ["mean", "std"] for c in agg_metric_cols} + agg_spec["block"] = "count" + + agg = rdf.groupby(["dropout", "lr", "beta_max"]).agg(agg_spec).reset_index() + + # Flatten MultiIndex columns: ("bal_acc_va","mean") -> "mean_bal_acc_va". + new_cols = [] + for c in agg.columns: + if isinstance(c, tuple): + name, stat = c + if name == "block" and stat == "count": + new_cols.append("n_blocks") + elif stat == "": + new_cols.append(name) + else: + new_cols.append(f"{stat}_{name}") + else: + new_cols.append(c) + agg.columns = new_cols + + # Optional: also flag configs where any block collapsed, useful for filtering. + collapsed = ( + rdf.assign(_col=rdf["latent_underused"].astype(bool)) + .groupby(["dropout", "lr", "beta_max"])["_col"] + .sum() + .reset_index() + .rename(columns={"_col": "n_blocks_collapsed"}) + ) + agg = agg.merge(collapsed, on=["dropout", "lr", "beta_max"], how="left") + agg.to_csv(tuning_dir / "tuning_summary.csv", index=False) # Select best metric = tc.get("metric", "bal_acc_va") mean_col = f"mean_{metric}" - if metric in ["bal_acc_va", "ld_corr_va", "conc_va"]: + higher_is_better = { + "bal_acc_va", "ld_corr_va", "conc_va", + "concordance_gain_va", "n_active_latents", + } + lower_is_better = {"best_val_loss", "frac_dims_collapsed"} + if metric in higher_is_better: best_row = agg.loc[agg[mean_col].idxmax()] - elif metric == "best_val_loss": + elif metric in lower_is_better: best_row = agg.loc[agg[mean_col].idxmin()] else: - best_row = agg.iloc[0] + raise ValueError(f"[tuning] unknown selection metric: {metric}") best_params = { "loss": str(lt), @@ -1089,14 +1293,27 @@ def run_phase1(cfg, *, config_path=None): pd.DataFrame({"IID": subjects}).to_csv(out/"subjects.csv", index=False) # ---- split (same for every loss type) ---- - N = len(subjects); n_val = int(N * vc["val_frac"]) - perm = np.random.permutation(N) - va_ix, tr_ix = perm[:n_val], perm[n_val:] + N = len(subjects) + n_test = int(N * vc.get("test_frac", 0.0)) + n_val = int(N * vc["val_frac"]) + if n_val < 1: + raise ValueError(f"val_frac={vc['val_frac']} yields 0 validation subjects for N={N}") + if N - n_test - n_val < 2: + raise ValueError( + f"Train set too small (N={N}, n_test={n_test}, n_val={n_val}). " + "Reduce test_frac or val_frac." + ) + perm = np.random.permutation(N) + test_ix = perm[:n_test] + va_ix = perm[n_test:n_test + n_val] + tr_ix = perm[n_test + n_val:] np.save(out/"train_idx.npy", tr_ix) np.save(out/"val_idx.npy", va_ix) - print(f"[split] train {len(tr_ix)} / val {len(va_ix)}") + np.save(out/"test_idx.npy", test_ix) + print(f"[split] train {len(tr_ix)} / val {len(va_ix)} / test {len(test_ix)}") pd.DataFrame({"IID": subjects[tr_ix]}).to_csv(out / "train_subjects.csv", index=False) pd.DataFrame({"IID": subjects[va_ix]}).to_csv(out / "val_subjects.csv", index=False) + pd.DataFrame({"IID": subjects[test_ix]}).to_csv(out / "test_subjects.csv", index=False) # Representative blocks rep_cfg = cfg.get("representative", {}) @@ -1116,7 +1333,11 @@ def run_phase1(cfg, *, config_path=None): out_rep = out / "representative_blocks" out_rep.mkdir(exist_ok=True) + ld_reps = int(vc.get("ld_corr_repeats", 1)) + ld_max_snps = int(vc.get("ld_corr_max_snps", 200)) + rows = [] # summary collector + _cm_cache: dict = {} # (lt, bid) → confusion_matrix; filled during training for deferred auto-rep save for lt in cfg["loss_functions"]: print(f"\n{'═'*55}\n Loss: {lt}\n{'═'*55}") @@ -1127,9 +1348,15 @@ def run_phase1(cfg, *, config_path=None): emb_dict = {} # block_id → (N, d) for bid in block_ids: + # Reseed per-block so each block is independently reproducible + set_seed(vc["seed"] + (hash(bid) & 0xFFFF)) G = blocks[bid]["geno"] n0_tr, n1_tr, n2_tr = geno_class_counts(G[tr_ix]) n0_va, n1_va, n2_va = geno_class_counts(G[va_ix]) + if len(test_ix) > 0: + n0_test, n1_test, n2_test = geno_class_counts(G[test_ix]) + else: + n0_test = n1_test = n2_test = 0 p = G.shape[1] maf_stats = block_maf_stats(G) print(f"\n ── {bid} ({p} SNPs) ──") @@ -1137,16 +1364,21 @@ def run_phase1(cfg, *, config_path=None): d_block = latent_dim_for_p(p, cfg) if lt in ("CAT", "ORD"): - tr_t, va_t, tr_y, va_y, stats = prepare_data(G, lt, tr_ix, va_ix) + tr_t, va_t, tr_y, va_y, stats, te_t, te_y = prepare_data(G, lt, tr_ix, va_ix, test_ix) else: - tr_t, va_t, stats = prepare_data(G, lt, tr_ix, va_ix) - tr_y = va_y = None + tr_t, va_t, stats, te_t = prepare_data(G, lt, tr_ix, va_ix, test_ix) + tr_y = va_y = te_y = None if lt == "CAT": w, counts = compute_class_weights(blocks[bid]["geno"][tr_ix]) w = np.clip(np.array(w, dtype=np.float32), 0.25, vc.get("cat_weight_clip", 10.0)).tolist() model = BlockVAE(p, d_block, vc["dropout"], loss_type="CAT", class_weights=w) print(f" latent_dim {d_block} | CAT weights {w} counts {counts}") + elif lt == "ORD" and bool(vc.get("ord_weighted", False)): + w, counts = compute_class_weights(blocks[bid]["geno"][tr_ix]) + w = np.clip(np.array(w, dtype=np.float32), 0.25, vc.get("ord_weight_clip", 10.0)).tolist() + model = BlockVAE(p, d_block, vc["dropout"], loss_type="ORD", class_weights=w) + print(f" latent_dim {d_block} | ORD weights {w} counts {counts}") else: model = BlockVAE(p, d_block, vc["dropout"], loss_type=lt) print(f" latent_dim {d_block}") @@ -1158,33 +1390,92 @@ def run_phase1(cfg, *, config_path=None): log, best_epoch, best_metrics = train_block_vae(model, tr_t, va_t, cfg, dev,ld/"logs"/f"{bid}.csv",tr_y=tr_y, va_y=va_y) dt = time.time() - t0 - m_tr = eval_genotype_metrics(model, tr_t, lt, stats) - m_va = eval_genotype_metrics(model, va_t, lt, stats) + baseline_mode = majority_baseline_from_train(G[tr_ix]) + m_tr = eval_genotype_metrics(model, tr_t, lt, stats, baseline_mode=baseline_mode, + ld_corr_repeats=ld_reps, ld_corr_max_snps=ld_max_snps) + m_va = eval_genotype_metrics(model, va_t, lt, stats, baseline_mode=baseline_mode, + ld_corr_repeats=ld_reps, ld_corr_max_snps=ld_max_snps) + pca_conc_va, pca_r2_va = pca_block_baseline(G[tr_ix], G[va_ix], d_block) + has_test = te_t is not None + if has_test: + m_te = eval_genotype_metrics(model, te_t, lt, stats, baseline_mode=baseline_mode, + ld_corr_repeats=ld_reps, ld_corr_max_snps=ld_max_snps) + pca_conc_te, pca_r2_te = pca_block_baseline(G[tr_ix], G[test_ix], d_block) + else: + _nan = float("nan") + m_te = {k: _nan for k in ["conc", "base_conc", "bal_acc", "acc0", "acc1", "acc2", "r2", "ld_corr"]} + pca_conc_te = pca_r2_te = _nan + + # Latent collapse diagnostics (best checkpoint already loaded into model) + diag = compute_latent_diagnostics(model, va_t) + diag_dir = ld / "diagnostics" / bid + diag_dir.mkdir(parents=True, exist_ok=True) + np.save(diag_dir / "kl_per_dim.npy", diag.pop("_kl_per_dim")) + np.save(diag_dir / "mu_var_per_dim.npy", diag.pop("_mu_var_per_dim")) + np.save(diag_dir / "sigma_per_dim.npy", diag.pop("_sigma_per_dim")) + if diag.get("latent_underused"): + print(f" [WARN] latent underused: {diag['n_active_latents']}/{d_block} dims active") + + # KL/reconstruction balance at the best checkpoint + _bvk = best_metrics.get("best_val_kl", float("nan")) + _bvr = best_metrics.get("best_val_recon", float("nan")) + kl_recon_ratio = (_bvk / _bvr) if (np.isfinite(_bvr) and _bvr > 0) else float("nan") + beta_warmup_cfg = int(vc.get("beta_warmup", 50)) + beta_max_cfg = float(vc.get("beta_max", 0.5)) + beta_eff_at_best = round(beta_max_cfg * min(1.0, best_epoch / max(beta_warmup_cfg, 1)), 6) + best_before_full_beta = bool(best_epoch < beta_warmup_cfg) + te_str = f" | conc te {m_te['conc']:.4f}" if has_test and np.isfinite(m_te["conc"]) else "" print( - f" conc tr {m_tr['conc']:.4f} va {m_va['conc']:.4f} | " + f" conc tr {m_tr['conc']:.4f} va {m_va['conc']:.4f}{te_str} | " f"base va {m_va['base_conc']:.4f} | bal va {m_va['bal_acc']:.4f} | " f"r2 va {m_va['r2']:.4f} | ld_corr va {m_va['ld_corr']:.4f} ({dt:.1f}s)" ) - if bid in rep_blocks_set and "confusion_matrix" in m_va and m_va["confusion_matrix"] is not None: + _cm = m_va.get("confusion_matrix") + if _cm is not None: + _cm_cache[(lt, bid)] = _cm + if bid in rep_blocks_set and _cm is not None: rep_dir = out_rep / bid rep_dir.mkdir(exist_ok=True) - np.save(rep_dir / f"{bid}_{lt}_confusion.npy", m_va["confusion_matrix"]) - plot_confusion_matrix(m_va["confusion_matrix"], bid, lt, rep_dir) + np.save(rep_dir / f"{bid}_{lt}_confusion.npy", _cm) + plot_confusion_matrix(_cm, bid, lt, rep_dir) emb = extract_emb(model, G, lt, stats) + if cfg.get("vae", {}).get("standardize_embeddings", True): + scaler = StandardScaler() + scaler.fit(emb[tr_ix]) + emb = scaler.transform(emb).astype(np.float32) emb_dict[bid] = emb np.save(ld/"embeddings"/f"{bid}.npy", emb) torch.save(model.state_dict(), ld/"models"/f"{bid}.pt") fin = log[-1] + kl_final = fin.get("va_kl", float("nan")) + recon_final = fin.get("va_recon", float("nan")) + kl_at_best_over_kl_final = ( + round(_bvk / kl_final, 6) + if (np.isfinite(kl_final) and kl_final > 0) else float("nan") + ) + ord_w_active = (lt == "ORD" and bool(vc.get("ord_weighted", False))) rows.append(dict( loss=lt, block=bid, gene=blocks[bid]["gene"], latent_dim=d_block, n_snps=p, **maf_stats, params=npar, epochs=len(log), + ord_weighted=ord_w_active, best_epoch=best_epoch, n0_tr=n0_tr, n1_tr=n1_tr, n2_tr=n2_tr, n0_va=n0_va, n1_va=n1_va, n2_va=n2_va, + n0_test=n0_test, n1_test=n1_test, n2_test=n2_test, + # NOTE: best_val_loss is not comparable across loss functions (MSE/ORD/CAT) **best_metrics, + # KL/reconstruction balance at the best checkpoint + kl_recon_ratio=round(kl_recon_ratio, 6) if np.isfinite(kl_recon_ratio) else float("nan"), + beta_eff_at_best=beta_eff_at_best, + best_before_full_beta=best_before_full_beta, + kl_final=round(kl_final, 6) if np.isfinite(kl_final) else float("nan"), + recon_final=round(recon_final, 6) if np.isfinite(recon_final) else float("nan"), + kl_at_best_over_kl_final=kl_at_best_over_kl_final, + # Latent collapse diagnostics + **diag, conc_tr=round(m_tr["conc"],4), conc_va=round(m_va["conc"],4), base_conc_va=round(m_va["base_conc"],4), @@ -1195,6 +1486,24 @@ def run_phase1(cfg, *, config_path=None): acc2_va=round(m_va["acc2"],4) if not np.isnan(m_va["acc2"]) else np.nan, r2_va=round(m_va["r2"],4), ld_corr_va=round(m_va["ld_corr"], 4) if np.isfinite(m_va["ld_corr"]) else np.nan, + compression_ratio=round(d_block / p, 6), + pca_conc_va=round(pca_conc_va, 4) if np.isfinite(pca_conc_va) else np.nan, + pca_r2_va=round(pca_r2_va, 4) if np.isfinite(pca_r2_va) else np.nan, + vae_minus_pca_conc_va=round(m_va["conc"] - pca_conc_va, 4) if np.isfinite(pca_conc_va) else np.nan, + vae_minus_pca_r2_va=round(m_va["r2"] - pca_r2_va, 4) if np.isfinite(pca_r2_va) else np.nan, + conc_test=round(m_te["conc"], 4) if np.isfinite(m_te["conc"]) else np.nan, + base_conc_test=round(m_te["base_conc"], 4) if np.isfinite(m_te["base_conc"]) else np.nan, + concordance_gain_test=round(m_te["conc"] - m_te["base_conc"], 4) if np.isfinite(m_te["conc"]) and np.isfinite(m_te["base_conc"]) else np.nan, + bal_acc_test=round(m_te["bal_acc"], 4) if np.isfinite(m_te["bal_acc"]) else np.nan, + acc0_test=round(m_te["acc0"], 4) if np.isfinite(m_te["acc0"]) else np.nan, + acc1_test=round(m_te["acc1"], 4) if np.isfinite(m_te["acc1"]) else np.nan, + acc2_test=round(m_te["acc2"], 4) if np.isfinite(m_te["acc2"]) else np.nan, + r2_test=round(m_te["r2"], 4) if np.isfinite(m_te["r2"]) else np.nan, + ld_corr_test=round(m_te["ld_corr"], 4) if np.isfinite(m_te["ld_corr"]) else np.nan, + pca_conc_test=round(pca_conc_te, 4) if np.isfinite(pca_conc_te) else np.nan, + pca_r2_test=round(pca_r2_te, 4) if np.isfinite(pca_r2_te) else np.nan, + vae_minus_pca_conc_test=round(m_te["conc"] - pca_conc_te, 4) if np.isfinite(m_te["conc"]) and np.isfinite(pca_conc_te) else np.nan, + vae_minus_pca_r2_test=round(m_te["r2"] - pca_r2_te, 4) if np.isfinite(m_te["r2"]) and np.isfinite(pca_r2_te) else np.nan, sec=round(dt,1) )) @@ -1217,19 +1526,192 @@ def run_phase1(cfg, *, config_path=None): # ---- save block-order metadata (Phase 2 needs it) ---- meta = [dict(pos=i, block_id=b, gene=blocks[b]["gene"], - n_snps=blocks[b]["n_snps"]) for i, b in enumerate(block_ids)] + n_snps=blocks[b]["n_snps"], + latent_dim=latent_dim_for_p(blocks[b]["n_snps"], cfg)) + for i, b in enumerate(block_ids)] pd.DataFrame(meta).to_csv(out/"block_order.csv", index=False) + # ---- PCA baseline embeddings (Phase 2 compatible) ---- + print(f"\n{'═'*55}\n Loss: PCA (baseline)\n{'═'*55}") + pca_ld = out / "PCA" + (pca_ld / "embeddings").mkdir(parents=True, exist_ok=True) + + pca_standardize = cfg.get("pca", {}).get("standardize", True) + print(f" pca_standardize={pca_standardize}") + + pca_emb_dict: dict = {} + pca_rows: list = [] + + for bid in block_ids: + G_pca = blocks[bid]["geno"] + p_pca = G_pca.shape[1] + d_pca = latent_dim_for_p(p_pca, cfg) + print(f"\n ── {bid} ({p_pca} SNPs, latent_dim={d_pca}) ──") + try: + emb_pca_raw, pca_model, n_comp_used = extract_pca_block_emb(G_pca, tr_ix, d_pca) + evr = pca_model.explained_variance_ratio_ + + raw_mean_abs = float(np.abs(emb_pca_raw).mean()) + raw_std = float(emb_pca_raw.std()) + + if pca_standardize: + scaler = StandardScaler() + scaler.fit(emb_pca_raw[tr_ix]) + emb_pca = scaler.transform(emb_pca_raw).astype(np.float32) + else: + emb_pca = emb_pca_raw + + scaled_mean_abs = float(np.abs(emb_pca).mean()) + scaled_std = float(emb_pca.std()) + + pca_emb_dict[bid] = emb_pca + np.save(pca_ld / "embeddings" / f"{bid}.npy", emb_pca) + pca_rows.append(dict( + block_id=bid, + gene=blocks[bid]["gene"], + n_snps=p_pca, + latent_dim=d_pca, + n_components_used=n_comp_used, + explained_variance_ratio_sum=round(float(evr.sum()), 4), + explained_variance_ratio_pc1=round(float(evr[0]), 4), + raw_mean_abs=round(raw_mean_abs, 6), + raw_std=round(raw_std, 6), + scaled_mean_abs=round(scaled_mean_abs, 6), + scaled_std=round(scaled_std, 6), + )) + print( + f" n_comp={n_comp_used} | explained_var={evr.sum():.4f} | " + f"raw_std={raw_std:.4f} → scaled_std={scaled_std:.4f}" + ) + except Exception as exc: + print(f"[pca_embeddings] WARNING: {bid} failed — {exc}") + + if pca_emb_dict: + pca_dims_full = np.array([ + pca_emb_dict[b].shape[1] if b in pca_emb_dict + else latent_dim_for_p(blocks[b]["n_snps"], cfg) + for b in block_ids + ], dtype=np.int32) + pca_max_d = int(pca_dims_full.max()) + N_pca = next(iter(pca_emb_dict.values())).shape[0] + B_pca = len(block_ids) + pca_stack = np.zeros((N_pca, B_pca, pca_max_d), dtype=np.float32) + for j, b in enumerate(block_ids): + if b in pca_emb_dict: + e = pca_emb_dict[b].astype(np.float32) + pca_stack[:, j, :e.shape[1]] = e + np.save(pca_ld / "embeddings" / "all_blocks.npy", pca_stack) + np.save(pca_ld / "embeddings" / "all_blocks_latent_dims.npy", pca_dims_full) + pd.DataFrame({"IID": subjects}).to_csv(pca_ld / "subjects.csv", index=False) + # Use actual clamped dims, not requested d_block, so Phase 2 latent mask is correct. + pca_block_meta = [ + {**m, "latent_dim": int(pca_dims_full[i])} for i, m in enumerate(meta) + ] + pd.DataFrame(pca_block_meta).to_csv(pca_ld / "block_order.csv", index=False) + pd.DataFrame(pca_rows).to_csv(pca_ld / "pca_summary.csv", index=False) + print( + f"\n stacked PCA embeddings (padded) {pca_stack.shape} | max_d={pca_max_d} | " + f"dims={dict(pd.Series(pca_dims_full).value_counts().sort_index())}" + ) + else: + print("[pca_embeddings] WARNING: no PCA embeddings produced — PCA/ folder not written") + # ---- summary table ---- sdf = pd.DataFrame(rows) sdf.to_csv(out/"vae_summary.csv", index=False) print(f"\n{'═'*55}\n Phase 1 complete — summary\n{'═'*55}") print(sdf.to_string(index=False, max_cols=None)) + # ---- aggregate summary (paper-style, one row per loss function) ---- + agg_rows = [] + for lt_agg in sdf["loss"].unique(): + sub = sdf[sdf["loss"] == lt_agg].copy() + w = sub["n_snps"].values.astype(np.float64) + total_snps = int(sub["n_snps"].sum()) + total_latent = int(sub["latent_dim"].sum()) + + conc_vals = sub["conc_va"].values.astype(np.float64) + weighted_conc = float(np.nansum(conc_vals * w) / np.nansum(np.where(np.isfinite(conc_vals), w, 0))) + + pca_vals = sub["pca_conc_va"].values.astype(np.float64) if "pca_conc_va" in sub.columns else np.full(len(sub), np.nan) + valid = np.isfinite(pca_vals) + weighted_pca = ( + float(np.sum(pca_vals[valid] * w[valid]) / (np.sum(w[valid]) + 1e-12)) + if valid.any() else float("nan") + ) + vae_minus_pca = round(weighted_conc - weighted_pca, 4) if np.isfinite(weighted_pca) else float("nan") + + # test weighted metrics + conc_te_vals = sub["conc_test"].values.astype(np.float64) if "conc_test" in sub.columns else np.full(len(sub), np.nan) + valid_te = np.isfinite(conc_te_vals) + weighted_conc_te = ( + float(np.sum(conc_te_vals[valid_te] * w[valid_te]) / (np.sum(w[valid_te]) + 1e-12)) + if valid_te.any() else float("nan") + ) + pca_te_vals = sub["pca_conc_test"].values.astype(np.float64) if "pca_conc_test" in sub.columns else np.full(len(sub), np.nan) + valid_te_pca = np.isfinite(pca_te_vals) + weighted_pca_te = ( + float(np.sum(pca_te_vals[valid_te_pca] * w[valid_te_pca]) / (np.sum(w[valid_te_pca]) + 1e-12)) + if valid_te_pca.any() else float("nan") + ) + vae_minus_pca_te = round(weighted_conc_te - weighted_pca_te, 4) if np.isfinite(weighted_conc_te) and np.isfinite(weighted_pca_te) else float("nan") + + agg_rows.append({ + "loss": lt_agg, + "n_blocks": len(sub), + "total_snps": total_snps, + "total_latent_dim": total_latent, + "overall_compression_ratio": round(total_latent / (total_snps + 1e-12), 6), + "weighted_conc_va": round(weighted_conc, 4), + "weighted_pca_conc_va": round(weighted_pca, 4) if np.isfinite(weighted_pca) else float("nan"), + "weighted_vae_minus_pca_conc_va": vae_minus_pca, + "mean_bal_acc_va": round(float(sub["bal_acc_va"].mean()), 4), + "mean_acc0_va": round(float(sub["acc0_va"].mean()), 4), + "mean_acc1_va": round(float(sub["acc1_va"].mean()), 4), + "mean_acc2_va": round(float(sub["acc2_va"].mean()), 4), + "mean_r2_va": round(float(sub["r2_va"].mean()), 4), + "mean_pca_r2_va": round(float(sub["pca_r2_va"].mean()), 4) if "pca_r2_va" in sub.columns else float("nan"), + "mean_ld_corr_va": round(float(sub["ld_corr_va"].mean()), 4), + "weighted_conc_test": round(weighted_conc_te, 4) if np.isfinite(weighted_conc_te) else float("nan"), + "weighted_pca_conc_test": round(weighted_pca_te, 4) if np.isfinite(weighted_pca_te) else float("nan"), + "weighted_vae_minus_pca_conc_test": vae_minus_pca_te, + "mean_bal_acc_test": round(float(sub["bal_acc_test"].mean()), 4) if "bal_acc_test" in sub.columns else float("nan"), + "mean_acc0_test": round(float(sub["acc0_test"].mean()), 4) if "acc0_test" in sub.columns else float("nan"), + "mean_acc1_test": round(float(sub["acc1_test"].mean()), 4) if "acc1_test" in sub.columns else float("nan"), + "mean_acc2_test": round(float(sub["acc2_test"].mean()), 4) if "acc2_test" in sub.columns else float("nan"), + "mean_r2_test": round(float(sub["r2_test"].mean()), 4) if "r2_test" in sub.columns else float("nan"), + "mean_pca_r2_test": round(float(sub["pca_r2_test"].mean()), 4) if "pca_r2_test" in sub.columns else float("nan"), + "mean_ld_corr_test": round(float(sub["ld_corr_test"].mean()), 4) if "ld_corr_test" in sub.columns else float("nan"), + }) + + agg_df = pd.DataFrame(agg_rows) + agg_df.to_csv(out / "phase1_aggregate_summary.csv", index=False) + for _, agg_row in agg_df.iterrows(): + wpca_va = f"{agg_row['weighted_pca_conc_va']:.4f}" if np.isfinite(float(agg_row["weighted_pca_conc_va"])) else "nan" + wvae_te = f"{agg_row['weighted_conc_test']:.4f}" if np.isfinite(float(agg_row["weighted_conc_test"])) else "nan" + wpca_te = f"{agg_row['weighted_pca_conc_test']:.4f}" if np.isfinite(float(agg_row["weighted_pca_conc_test"])) else "nan" + print( + f"[phase1][{agg_row['loss']}] aggregate: " + f"{agg_row['total_snps']} SNPs | " + f"{agg_row['total_latent_dim']} latent dims | " + f"compression {agg_row['overall_compression_ratio']:.4f} | " + f"val VAE {agg_row['weighted_conc_va']:.4f} vs PCA {wpca_va} | " + f"test VAE {wvae_te} vs PCA {wpca_te}" + ) + # ---- representative blocks selection ---- if not manual_rep: rep_blocks = select_representative_blocks(sdf, cfg, block_ids) rep_blocks_set = set(rep_blocks) + # write confusion matrices for auto-selected blocks from cache + for bid in rep_blocks_set: + rep_dir = out_rep / bid + rep_dir.mkdir(exist_ok=True) + for lt in cfg["loss_functions"]: + _cm = _cm_cache.get((lt, bid)) + if _cm is not None: + np.save(rep_dir / f"{bid}_{lt}_confusion.npy", _cm) + plot_confusion_matrix(_cm, bid, lt, rep_dir) # ---- output organization ---- out_summary = out / "summary" @@ -1240,7 +1722,7 @@ def run_phase1(cfg, *, config_path=None): out_rep.mkdir(exist_ok=True) # Move summary - sdf.to_csv(out_summary / "vae_summary.csv", index=False) + agg_df.to_csv(out_summary / "phase1_aggregate_summary.csv", index=False) # Aggregate plots plot_aggregate_boxes(sdf, out_agg_plots) @@ -1283,7 +1765,144 @@ def run_phase1(cfg, *, config_path=None): # ────────────────────────────────────────────────────────────────── -# 8. CLI (validate_cfg merged from 01_phase1_block_embedding.py) +# 8. PCA-ONLY REGENERATION +# ────────────────────────────────────────────────────────────────── +def run_pca_only(cfg, *, config_path=None): + """Regenerate PCA/ baseline folder only, reusing existing train/val/test split. + + Raises FileNotFoundError if split index files are missing from output_dir — + they must be created by a prior run_phase1() call; we never create a new split. + """ + t0_run = time.time() + out = Path(cfg["data"]["output_dir"]) + + # ---- require existing split indices ---- + for fname in ("train_idx.npy", "val_idx.npy", "test_idx.npy"): + fp = out / fname + if not fp.exists(): + raise FileNotFoundError( + f"[pca-only] Required split file missing: {fp}\n" + "Run run_phase1 (or VAE_phase1.py without --pca-only) first to create the split." + ) + tr_ix = np.load(out / "train_idx.npy") + va_ix = np.load(out / "val_idx.npy") + test_ix = np.load(out / "test_idx.npy") + print(f"[pca-only] reusing split: train={len(tr_ix)} val={len(va_ix)} test={len(test_ix)}") + + # ---- load blocks ---- + print("\n══════ Loading data (pca-only) ══════") + bdf = load_block_defs(cfg["data"]["block_def"]) + blocks, subjects, block_ids = load_all_blocks(bdf, cfg["data"]["raw_dir"]) + + # ---- PCA ---- + pca_ld = out / "PCA" + (pca_ld / "embeddings").mkdir(parents=True, exist_ok=True) + + pca_standardize = cfg.get("pca", {}).get("standardize", True) + print(f" pca_standardize={pca_standardize}") + + pca_emb_dict: dict = {} + pca_rows: list = [] + + for bid in block_ids: + G_pca = blocks[bid]["geno"] + p_pca = G_pca.shape[1] + d_pca = latent_dim_for_p(p_pca, cfg) + print(f"\n ── {bid} ({p_pca} SNPs, latent_dim={d_pca}) ──") + try: + emb_pca_raw, pca_model, n_comp_used = extract_pca_block_emb(G_pca, tr_ix, d_pca) + evr = pca_model.explained_variance_ratio_ + + raw_mean_abs = float(np.abs(emb_pca_raw).mean()) + raw_std = float(emb_pca_raw.std()) + + if pca_standardize: + scaler = StandardScaler() + scaler.fit(emb_pca_raw[tr_ix]) + emb_pca = scaler.transform(emb_pca_raw).astype(np.float32) + else: + emb_pca = emb_pca_raw + + scaled_mean_abs = float(np.abs(emb_pca).mean()) + scaled_std = float(emb_pca.std()) + + pca_emb_dict[bid] = emb_pca + np.save(pca_ld / "embeddings" / f"{bid}.npy", emb_pca) + pca_rows.append(dict( + block_id=bid, + gene=blocks[bid]["gene"], + n_snps=p_pca, + latent_dim=d_pca, + n_components_used=n_comp_used, + explained_variance_ratio_sum=round(float(evr.sum()), 4), + explained_variance_ratio_pc1=round(float(evr[0]), 4), + raw_mean_abs=round(raw_mean_abs, 6), + raw_std=round(raw_std, 6), + scaled_mean_abs=round(scaled_mean_abs, 6), + scaled_std=round(scaled_std, 6), + )) + print( + f" n_comp={n_comp_used} | explained_var={evr.sum():.4f} | " + f"raw_std={raw_std:.4f} → scaled_std={scaled_std:.4f}" + ) + except Exception as exc: + print(f"[pca_embeddings] WARNING: {bid} failed — {exc}") + + if not pca_emb_dict: + raise RuntimeError("[pca-only] No PCA embeddings produced — check block data.") + + # ---- stack + save ---- + # Read block_order.csv for gene/n_snps metadata (written by earlier run_phase1). + # If not present, build it from loaded blocks. + block_order_fp = out / "block_order.csv" + if block_order_fp.exists(): + meta_df = pd.read_csv(block_order_fp) + meta = meta_df.to_dict("records") + else: + meta = [dict(pos=i, block_id=b, gene=blocks[b]["gene"], + n_snps=blocks[b]["n_snps"], + latent_dim=latent_dim_for_p(blocks[b]["n_snps"], cfg)) + for i, b in enumerate(block_ids)] + + pca_dims_full = np.array([ + pca_emb_dict[b].shape[1] if b in pca_emb_dict + else latent_dim_for_p(blocks[b]["n_snps"], cfg) + for b in block_ids + ], dtype=np.int32) + pca_max_d = int(pca_dims_full.max()) + N_pca = next(iter(pca_emb_dict.values())).shape[0] + B_pca = len(block_ids) + pca_stack = np.zeros((N_pca, B_pca, pca_max_d), dtype=np.float32) + for j, b in enumerate(block_ids): + if b in pca_emb_dict: + e = pca_emb_dict[b].astype(np.float32) + pca_stack[:, j, :e.shape[1]] = e + + np.save(pca_ld / "embeddings" / "all_blocks.npy", pca_stack) + np.save(pca_ld / "embeddings" / "all_blocks_latent_dims.npy", pca_dims_full) + pd.DataFrame({"IID": subjects}).to_csv(pca_ld / "subjects.csv", index=False) + pca_block_meta = [ + {**m, "latent_dim": int(pca_dims_full[i])} for i, m in enumerate(meta) + ] + pd.DataFrame(pca_block_meta).to_csv(pca_ld / "block_order.csv", index=False) + pd.DataFrame(pca_rows).to_csv(pca_ld / "pca_summary.csv", index=False) + + print( + f"\n stacked PCA embeddings (padded) {pca_stack.shape} | max_d={pca_max_d} | " + f"dims={dict(pd.Series(pca_dims_full).value_counts().sort_index())}" + ) + _write_run_metadata( + pca_ld, + config_path=config_path or out / "config_phase1.yaml", + cfg=cfg, + t0=t0_run, + t1=time.time(), + ) + print(f"\n[pca-only] complete (took {time.time() - t0_run:.1f}s)") + + +# ────────────────────────────────────────────────────────────────── +# 9. CLI (validate_cfg merged from 01_phase1_block_embedding.py) # ────────────────────────────────────────────────────────────────── def validate_cfg(cfg): """Pre-flight checks: verify required paths exist and create output dir.""" @@ -1305,6 +1924,8 @@ def validate_cfg(cfg): ap.add_argument("--config", default=None, help="Path to Phase1 config YAML (default: configs/config_phase1.yaml)") ap.add_argument("--tune", action="store_true", help="Run hyperparameter tuning mode") ap.add_argument("--dry-run", action="store_true", help="Display configuration without running") + ap.add_argument("--pca-only", action="store_true", + help="Regenerate PCA/ baseline folder only; reuses existing train/val/test split from output_dir") ap.add_argument("--save-config", action="store_true", help="write default YAML and exit") args = ap.parse_args() @@ -1331,17 +1952,30 @@ def validate_cfg(cfg): print(f"[phase1] output_dir={out_dir}") if args.dry_run: + if args.pca_only: + for fname in ("train_idx.npy", "val_idx.npy", "test_idx.npy"): + fp = out_dir / fname + status = "EXISTS" if fp.exists() else "MISSING" + print(f"[dry-run][pca-only] split file {fname}: {status}") print("[phase1] dry-run complete; no pipeline executed.") sys.exit(0) t0 = time.time() - if args.tune: + if args.pca_only: + run_pca_only(cfg, config_path=resolved_config) + elif args.tune: run_tuning(cfg) else: run_phase1(cfg, config_path=resolved_config) # Post-run output validation - if args.tune: + if args.pca_only: + pca_emb_dir = out_dir / "PCA" / "embeddings" + if pca_emb_dir.exists() and (pca_emb_dir / "all_blocks.npy").exists(): + print(f"[validation] PCA embeddings found: {pca_emb_dir}") + else: + print(f"[validation] WARNING: no PCA embeddings at {pca_emb_dir}") + elif args.tune: tuning_dir = out_dir / "tuning" for p in [tuning_dir / "best_params.yaml", tuning_dir / "tuning_results.csv", @@ -1361,6 +1995,11 @@ def validate_cfg(cfg): print(f"[validation] embeddings found for {lt}") else: print(f"[validation] WARNING: no embeddings for {lt}") + pca_emb_dir = out_dir / "PCA" / "embeddings" + if pca_emb_dir.exists() and (pca_emb_dir / "all_blocks.npy").exists(): + print(f"[validation] embeddings found for PCA") + else: + print(f"[validation] WARNING: no embeddings for PCA") print(f"[validation] phase1 complete: {out_dir}") print(f"\n[phase1] complete (took {time.time() - t0:.1f}s)") \ No newline at end of file diff --git a/scripts/core/attention_phase2.py b/scripts/core/attention_phase2.py index 73e39ca..b27e077 100644 --- a/scripts/core/attention_phase2.py +++ b/scripts/core/attention_phase2.py @@ -86,6 +86,8 @@ # New: self-attention extraction controls "extract_self_attn": True, "save_full_self_attn": False, # can be very large: (N, H, B, B) per layer + # Multi-token pooling: 1 = original single-query pooling (default, backward compat) + "n_pool_tokens": 1, }, "clustering": { "k_range": [2, 3, 4, 5, 6, 8, 10], @@ -95,6 +97,19 @@ "umap_min_dist": 0.1, "umap_seed": 42, }, + "diagnostics": { + "enabled": True, + "save_initial_block_repr": True, + "compute_contextualization_change": True, + "compute_phase1_phase2_join": True, + }, + "baselines": { + "enabled": True, + "run_pca": True, + "run_mean_pool": True, + "pca_n_components": None, # None → use attention.d_model with cap + "pca_sweep": [64, 128, 256, 512] + }, "loss_functions": ["ORD", "MSE", "MSE_STD"], } @@ -128,19 +143,36 @@ def set_seed(seed): def get_device(device_cfg): if device_cfg == "cpu": - print("[device] CPU (as configured)") + print("[device] CPU") return torch.device("cpu") - if device_cfg == "cuda" and torch.cuda.is_available(): - print(f"[device] CUDA — {torch.cuda.get_device_name(0)}") - return torch.device("cuda") - if device_cfg == "mps" and hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + if device_cfg == "cuda": + if torch.cuda.is_available(): + print(f"[device] CUDA — {torch.cuda.get_device_name(0)}") + return torch.device("cuda") + print("[device] CUDA requested but not available — falling back to CPU") + return torch.device("cpu") + if device_cfg == "mps": + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + try: + torch.zeros(2, device="mps") + print("[device] Apple MPS") + return torch.device("mps") + except Exception: + pass + print("[device] MPS requested but not available — falling back to CPU") + return torch.device("cpu") + # "auto" or anything else: try MPS → CUDA → CPU + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): try: torch.zeros(2, device="mps") - print("[device] Apple MPS") + print("[device] Apple MPS (auto)") return torch.device("mps") except Exception: pass - print("[device] CPU (fallback)") + if torch.cuda.is_available(): + print(f"[device] CUDA (auto) — {torch.cuda.get_device_name(0)}") + return torch.device("cuda") + print("[device] CPU (auto)") return torch.device("cpu") @@ -151,14 +183,16 @@ def load_phase1(p1_dir: str, loss_functions: list): """ Returns ------- - subjects : ndarray of IID strings (N,) - tr_ix : ndarray (n_train,) - va_ix : ndarray (n_val,) - block_meta : DataFrame [pos, block_id, gene, n_snps] - embeddings : dict {loss_type: ndarray (N, B, d)} + subjects : ndarray of IID strings (N,) + tr_ix : ndarray (n_train,) + va_ix : ndarray (n_val,) + te_ix : ndarray (n_test,) + block_meta : DataFrame [pos, block_id, gene, n_snps] + embeddings : dict {loss_type: ndarray (N, B, d)} + latent_dims_per_loss : dict {loss_type: list[int] | None} """ p1 = Path(p1_dir) - required = ["subjects.csv", "train_idx.npy", "val_idx.npy", "block_order.csv"] + required = ["subjects.csv", "train_idx.npy", "val_idx.npy", "test_idx.npy", "block_order.csv"] for f in required: fp = p1 / f if not fp.exists(): @@ -167,9 +201,15 @@ def load_phase1(p1_dir: str, loss_functions: list): subjects = pd.read_csv(p1 / "subjects.csv")["IID"].astype(str).values tr_ix = np.load(p1 / "train_idx.npy") va_ix = np.load(p1 / "val_idx.npy") + te_ix = np.load(p1 / "test_idx.npy") block_meta = pd.read_csv(p1 / "block_order.csv") + n_subjects = len(subjects) + n_blocks_meta = len(block_meta) + has_latent_dim = "latent_dim" in block_meta.columns + embeddings = {} + latent_dims_per_loss = {} # per-loss actual dims, None means fall back to block_meta N, B, d_in = None, None, None for lt in loss_functions: @@ -177,14 +217,57 @@ def load_phase1(p1_dir: str, loss_functions: list): if not fp.exists(): raise FileNotFoundError(f"Missing stacked embeddings: {fp}") emb = np.load(fp) # (N, B, d) + + # sanity: subject count + if emb.shape[0] != n_subjects: + raise ValueError( + f"[{lt}] all_blocks.npy subject count {emb.shape[0]} " + f"!= subjects.csv count {n_subjects}" + ) + # sanity: block count + if emb.shape[1] != n_blocks_meta: + raise ValueError( + f"[{lt}] all_blocks.npy block count {emb.shape[1]} " + f"!= block_order.csv count {n_blocks_meta}" + ) + + # Prefer all_blocks_latent_dims.npy (records actual post-clamping dims) over + # block_order.csv["latent_dim"] (records the requested dims) for the sanity + # check and for building the per-loss latent mask in run_phase2(). + dims_fp = p1 / lt / "embeddings" / "all_blocks_latent_dims.npy" + if dims_fp.exists(): + lt_dims = np.load(dims_fp).astype(int) + if len(lt_dims) != n_blocks_meta: + raise ValueError( + f"[{lt}] all_blocks_latent_dims.npy has {len(lt_dims)} entries " + f"!= block_order.csv count {n_blocks_meta}" + ) + expected_max_d = int(lt_dims.max()) + latent_dims_per_loss[lt] = lt_dims.tolist() + dim_source = "all_blocks_latent_dims.npy" + elif has_latent_dim: + expected_max_d = int(block_meta["latent_dim"].max()) + latent_dims_per_loss[lt] = None # run_phase2 falls back to block_meta + dim_source = "block_order.csv" + else: + expected_max_d = None + latent_dims_per_loss[lt] = None + dim_source = None + + if expected_max_d is not None and emb.shape[2] != expected_max_d: + raise ValueError( + f"[{lt}] all_blocks.npy dim {emb.shape[2]} " + f"!= max_d {expected_max_d} (from {dim_source})" + ) + if N is None: N, B, d_in = emb.shape elif emb.shape != (N, B, d_in): raise ValueError(f"Embedding shape mismatch for {lt}: expected {(N, B, d_in)}, got {emb.shape}") embeddings[lt] = emb - print(f" [{lt:8s}] loaded embeddings {emb.shape}") + print(f" [{lt:8s}] loaded embeddings {emb.shape} dim_source={dim_source}") - return subjects, tr_ix, va_ix, block_meta, embeddings + return subjects, tr_ix, va_ix, te_ix, block_meta, embeddings, latent_dims_per_loss # ============================================================ @@ -242,14 +325,40 @@ def forward(self, x, return_attn=False): ff = self.linear2(self.ff_dropout(self.activation(self.linear1(x_norm)))) x = x + self.dropout2(ff) - return x, attn_weights if return_attn else None + return x, (attn_weights if return_attn else None) + + +class BlockProjector(nn.Module): + """Per-block input projection when blocks have heterogeneous latent dims. + + Accepts a padded (batch, B, max_d) tensor. For block i, slices the first + block_dims[i] columns and projects them independently to d_model. + Drop-in replacement for a shared nn.Linear(d_in, d_model) when all + block_dims are equal. + """ + + def __init__(self, block_dims: list, d_model: int): + super().__init__() + self.block_dims = list(block_dims) + self.projectors = nn.ModuleList([ + nn.Sequential(nn.Linear(d_i, d_model), nn.LayerNorm(d_model), nn.GELU()) + for d_i in block_dims + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """x: (batch, B, max_d) → (batch, B, d_model)""" + tokens = [ + self.projectors[i](x[:, i, :self.block_dims[i]]) + for i in range(len(self.block_dims)) + ] + return torch.stack(tokens, dim=1) class AttentionAggregator(nn.Module): """ Transformer-style model: (B, d_in) frozen block embeddings - -> projected block tokens + -> projected block tokens (per-block via BlockProjector, or shared Linear) -> contextualized by self-attention -> pooled to one subject embedding -> decoded back to (B, d_in) @@ -270,6 +379,8 @@ def __init__( n_layers: int = 2, d_ff: int = 128, dropout: float = 0.1, + block_dims: list = None, + n_pool_tokens: int = 1, ): super().__init__() self.n_blocks = n_blocks @@ -277,12 +388,21 @@ def __init__( self.d_model = d_model self.n_heads = n_heads self.n_layers = n_layers + self.n_pool_tokens = n_pool_tokens + self.emb_dim = n_pool_tokens * d_model - self.input_proj = nn.Sequential( - nn.Linear(d_in, d_model), - nn.LayerNorm(d_model), - nn.GELU(), - ) + if block_dims is not None: + if len(block_dims) != n_blocks: + raise ValueError( + f"block_dims length {len(block_dims)} != n_blocks {n_blocks}" + ) + self.input_proj = BlockProjector(block_dims, d_model) + else: + self.input_proj = nn.Sequential( + nn.Linear(d_in, d_model), + nn.LayerNorm(d_model), + nn.GELU(), + ) self.pos_emb = nn.Parameter(torch.randn(1, n_blocks, d_model) * 0.02) @@ -297,23 +417,29 @@ def __init__( ]) self.post_norm = nn.LayerNorm(d_model) - # learned pooling query - self.pool_query = nn.Parameter(torch.randn(1, 1, d_model) * 0.02) + # learned pooling queries: (1, K, d_model) where K = n_pool_tokens + K = n_pool_tokens + self.pool_queries = nn.Parameter(torch.randn(1, K, d_model) * 0.02) self._scale = math.sqrt(d_model) + # embed_head and decoder operate on K*d_model; for K=1 this is identical to before self.embed_head = nn.Sequential( - nn.Linear(d_model, d_model), - nn.LayerNorm(d_model), + nn.Linear(K * d_model, K * d_model), + nn.LayerNorm(K * d_model), nn.GELU(), ) + # For K>1, avoid immediately compressing K*d_model back to a small d_ff. + # This makes multi-token pooling a true wider-bottleneck test. + decoder_hidden = max(d_ff, K * d_model) + self.decoder_hidden = decoder_hidden self.decoder = nn.Sequential( - nn.Linear(d_model, d_ff), + nn.Linear(K * d_model, decoder_hidden), nn.GELU(), nn.Dropout(dropout), - nn.Linear(d_ff, d_ff), + nn.Linear(decoder_hidden, decoder_hidden), nn.GELU(), - nn.Linear(d_ff, n_blocks * d_in), + nn.Linear(decoder_hidden, n_blocks * d_in), ) self._init_weights() @@ -333,8 +459,8 @@ def encode(self, x, return_self_attn=False): Returns ------- - embedding : (batch, d_model) - pool_attn : (batch, B) + embedding : (batch, K*d_model) , where K = model.n_pool_tokens + pool_attn : (batch, B) , mean over K tokens (backward-compat) h_blocks : (batch, B, d_model) self_attn_maps : list[n_layers] of (batch, n_heads, B, B), optional """ @@ -350,20 +476,29 @@ def encode(self, x, return_self_attn=False): h = self.post_norm(h) # pooling attention (NOT block->block; this is block->subject embedding weight) - q = self.pool_query.expand(batch_size, -1, -1) # (batch, 1, d_model) - scores = torch.bmm(q, h.transpose(1, 2)) / self._scale # (batch, 1, B) - pool_attn = F.softmax(scores, dim=-1) # (batch, 1, B) - pooled = torch.bmm(pool_attn, h).squeeze(1) # (batch, d_model) + # pool_queries: (1, K, d_model) → expand to (batch, K, d_model) + q = self.pool_queries.expand(batch_size, -1, -1) # (batch, K, d_model) + scores = torch.bmm(q, h.transpose(1, 2)) / self._scale # (batch, K, B) + pool_attn_full = F.softmax(scores, dim=-1) # (batch, K, B) + pooled = torch.bmm(pool_attn_full, h) # (batch, K, d_model) + pooled_flat = pooled.reshape(batch_size, self.emb_dim) # (batch, K*d_model) - embedding = self.embed_head(pooled) + embedding = self.embed_head(pooled_flat) # (batch, K*d_model) + + # Return mean over K for backward compatibility: (batch, B) + pool_attn = pool_attn_full.mean(dim=1) # (batch, B) if return_self_attn: - return embedding, pool_attn.squeeze(1), h, self_attn_maps - return embedding, pool_attn.squeeze(1), h + return embedding, pool_attn, h, self_attn_maps + return embedding, pool_attn, h def decode(self, z): return self.decoder(z).view(-1, self.n_blocks, self.d_in) + def get_initial_tokens(self, x: torch.Tensor) -> torch.Tensor: + """Return projected block tokens before self-attention: (batch, B, d_model).""" + return self.input_proj(x) + self.pos_emb + def forward(self, x, return_self_attn=False): if return_self_attn: emb, pool_attn, h_blocks, self_attn_maps = self.encode( @@ -380,7 +515,17 @@ def forward(self, x, return_self_attn=False): # ============================================================ # 4. TRAINING # ============================================================ -def train_attention_model(model, tr_t, va_t, cfg, device, log_csv): +def _masked_mse(recon: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """MSE restricted to real latent dimensions. + + mask : (B, max_d) float — 1 for real dims, 0 for padding. + Broadcast over the batch dimension; divide by the number of real elements. + """ + sq = (recon - target) ** 2 # (batch, B, max_d) + return (sq * mask).sum() / mask.sum() / recon.size(0) + + +def train_attention_model(model, tr_t, va_t, cfg, device, log_csv, latent_mask_t=None): ac = cfg["attention"] tr_dl = DataLoader( @@ -405,6 +550,8 @@ def train_attention_model(model, tr_t, va_t, cfg, device, log_csv): opt, patience=15, factor=0.5, min_lr=1e-6 ) + mk = latent_mask_t.to(device) if latent_mask_t is not None else None + best_val = float("inf") best_sd = None best_epoch = 0 @@ -422,7 +569,7 @@ def train_attention_model(model, tr_t, va_t, cfg, device, log_csv): xb = xb.to(device) opt.zero_grad() recon, _, _ = model(xb, return_self_attn=False) - loss = F.mse_loss(recon, xb) + loss = _masked_mse(recon, xb, mk) if mk is not None else F.mse_loss(recon, xb) loss.backward() if ac.get("grad_clip", 0) > 0: nn.utils.clip_grad_norm_(model.parameters(), ac["grad_clip"]) @@ -439,7 +586,7 @@ def train_attention_model(model, tr_t, va_t, cfg, device, log_csv): for (xb,) in va_dl: xb = xb.to(device) recon, _, _ = model(xb, return_self_attn=False) - loss = F.mse_loss(recon, xb) + loss = _masked_mse(recon, xb, mk) if mk is not None else F.mse_loss(recon, xb) va_loss_acc += loss.item() * xb.size(0) va_n += xb.size(0) va_loss = va_loss_acc / va_n @@ -571,8 +718,17 @@ def extract_all( return embs, pool_attns, recons, block_reprs, self_attn_mean, self_attn_full -def per_block_mse(recon, truth): - """(N, B, d) -> (B,) mean SE per block.""" +def per_block_mse(recon, truth, block_dims=None): + """(N, B, d) -> (B,) mean SE per block. + + When block_dims is given, MSE is computed only over the first d_i columns + for block i, ignoring zero-padding in the stacked array. + """ + if block_dims is not None: + result = np.empty(recon.shape[1], dtype=np.float64) + for i, d_i in enumerate(block_dims): + result[i] = np.mean((recon[:, i, :d_i] - truth[:, i, :d_i]) ** 2) + return result return np.mean((recon - truth) ** 2, axis=(0, 2)) @@ -699,11 +855,11 @@ def run_clustering(emb, cc, out_dir): def compute_umap(emb, cc): if not HAS_UMAP: return None - # return umap.UMAP( - # n_neighbors=cc["umap_n_neighbors"], - # min_dist=cc["umap_min_dist"], - # random_state=cc.get("umap_seed", 42), - # ).fit_transform(emb) + return umap.UMAP( + n_neighbors=cc["umap_n_neighbors"], + min_dist=cc["umap_min_dist"], + random_state=cc.get("umap_seed", 42), + ).fit_transform(emb) def _best_kmeans_key(labels): @@ -971,7 +1127,499 @@ def _plot_attention_comparison(all_res, out_dir): # ============================================================ -# 9. MAIN PHASE-2 PIPELINE +# 9. PHASE 2 DIAGNOSTICS AND BASELINES +# ============================================================ + +@torch.no_grad() +def _extract_initial_tokens(model, data_np, batch_size=256): + """Return projected block tokens before self-attention: (N, B, d_model).""" + model.eval() + dl = DataLoader( + TensorDataset(torch.tensor(data_np, dtype=torch.float32)), + batch_size=batch_size, + shuffle=False, + ) + out = [] + for (xb,) in dl: + out.append(model.get_initial_tokens(xb).cpu().numpy()) + return np.concatenate(out, axis=0) + + +@torch.no_grad() +def _extract_pool_attn_by_token(model, data_np, batch_size=256): + """Return per-token pooling attention (N, K, B). Only meaningful when K > 1. + + For K=1 this is (N, 1, B), which equals pool_attn[:, np.newaxis, :]. + """ + model.eval() + dl = DataLoader( + TensorDataset(torch.tensor(data_np, dtype=torch.float32)), + batch_size=batch_size, + shuffle=False, + ) + out = [] + for (xb,) in dl: + h = model.input_proj(xb) + model.pos_emb + for layer in model.transformer_layers: + h, _ = layer(h, return_attn=False) + h = model.post_norm(h) + q = model.pool_queries.expand(xb.size(0), -1, -1) + scores = torch.bmm(q, h.transpose(1, 2)) / model._scale + pool_attn_full = F.softmax(scores, dim=-1) # (batch, K, B) + out.append(pool_attn_full.cpu().numpy()) + return np.concatenate(out, axis=0) # (N, K, B) + + +def _pool_attn_entropy(pool_attn: np.ndarray) -> float: + """Mean Shannon entropy of per-subject pooling attention weights (nats).""" + eps = 1e-10 + return float(-(pool_attn * np.log(pool_attn + eps)).sum(axis=-1).mean()) + + +def _pool_attn_topk_mass(pool_attn: np.ndarray, k: int) -> float: + """Mean total attention mass in the top-k blocks per subject.""" + k = max(1, min(k, pool_attn.shape[-1])) + desc = np.sort(pool_attn, axis=-1)[:, ::-1] + return float(desc[:, :k].sum(axis=-1).mean()) + +def _pool_attn_token_entropy(pool_attn_by_token: np.ndarray) -> float: + """Mean entropy across subjects and pooling tokens. Shape: (N, K, B).""" + eps = 1e-10 + ent = -(pool_attn_by_token * np.log(pool_attn_by_token + eps)).sum(axis=-1) + return float(ent.mean()) + + +def _pool_attn_token_pairwise_corr(pool_attn_by_token: np.ndarray) -> float: + """Mean pairwise correlation among pooling-token attention profiles. + + Uses token-level mean attention profiles averaged across subjects. + Shape input: (N, K, B). + High value near 1 means tokens are redundant. + Lower value means tokens specialize differently. + """ + K = pool_attn_by_token.shape[1] + if K < 2: + return float("nan") + + token_profiles = pool_attn_by_token.mean(axis=0) # (K, B) + cors = [] + for i in range(K): + for j in range(i + 1, K): + a = token_profiles[i] + b = token_profiles[j] + if np.std(a) < 1e-12 or np.std(b) < 1e-12: + continue + cors.append(float(np.corrcoef(a, b)[0, 1])) + + return float(np.mean(cors)) if cors else float("nan") + + +def _pool_attn_token_diversity(pool_attn_by_token: np.ndarray) -> float: + """Simple diversity score = 1 - mean pairwise token correlation.""" + c = _pool_attn_token_pairwise_corr(pool_attn_by_token) + return float(1.0 - c) if np.isfinite(c) else float("nan") + +def compute_contextualization_change( + initial_np: np.ndarray, + contextual_np: np.ndarray, + block_names, + block_meta: pd.DataFrame = None, +) -> tuple: + """Compare initial projected tokens to post-Transformer contextual tokens. + + These diagnostics identify which blocks are most modified by cross-block + context — not which blocks are causal drivers. + + Returns (per_block_df, per_subject_df). + """ + delta = contextual_np - initial_np # (N, B, d_model) + l2 = np.linalg.norm(delta, axis=-1) # (N, B) + + eps = 1e-8 + init_n = initial_np / (np.linalg.norm(initial_np, axis=-1, keepdims=True) + eps) + ctx_n = contextual_np / (np.linalg.norm(contextual_np, axis=-1, keepdims=True) + eps) + cos_dist = 1.0 - (init_n * ctx_n).sum(axis=-1) # (N, B) + + mean_l2 = l2.mean(axis=0) + std_l2 = l2.std(axis=0) + mean_cd = cos_dist.mean(axis=0) + std_cd = cos_dist.std(axis=0) + ranks = (-mean_l2).argsort().argsort() + 1 # 1 = most changed + + row = {"block_id": block_names} + if block_meta is not None: + for col in ("n_snps", "latent_dim"): + if col in block_meta.columns: + row[col] = block_meta[col].values + row.update({ + "mean_context_delta_l2": np.round(mean_l2, 6), + "std_context_delta_l2": np.round(std_l2, 6), + "mean_context_delta_cosine": np.round(mean_cd, 6), + "std_context_delta_cosine": np.round(std_cd, 6), + "context_change_rank": ranks, + }) + per_block_df = pd.DataFrame(row) + + per_subj_df = pd.DataFrame({ + "mean_context_delta_l2_per_subject": np.round(l2.mean(axis=1), 6), + "mean_context_delta_cosine_per_subject": np.round(cos_dist.mean(axis=1), 6), + }) + return per_block_df, per_subj_df + + +def _save_attention_correlation_summary(df: pd.DataFrame, diag_dir: Path): + """Pearson and Spearman correlations between key block-level columns.""" + try: + from scipy.stats import pearsonr, spearmanr + _has_sp = True + except ImportError: + _has_sp = False + + pairs = [ + ("mean_pool_attn", "n_snps"), + ("mean_pool_attn", "n_active_latents"), + ("mean_pool_attn", "frac_dims_collapsed"), + ("mean_pool_attn", "ld_corr_va"), + ("mean_pool_attn", "phase2_recon_mse"), + ("mean_pool_attn", "mean_context_delta_l2"), + ("phase2_recon_mse", "mean_context_delta_l2"), + ] + rows = [] + for ca, cb in pairs: + base = { + "col_a": ca, "col_b": cb, + "pearson_r": float("nan"), "pearson_p": float("nan"), + "spearman_r": float("nan"), "spearman_p": float("nan"), + "n": 0, + } + if ca not in df.columns or cb not in df.columns: + rows.append(base) + continue + mask = df[ca].notna() & df[cb].notna() + n = int(mask.sum()) + base["n"] = n + if n < 3 or not _has_sp: + rows.append(base) + continue + a = df.loc[mask, ca].values.astype(float) + b = df.loc[mask, cb].values.astype(float) + pr, pp = pearsonr(a, b) + sr, sp = spearmanr(a, b) + rows.append({**base, + "pearson_r": round(float(pr), 4), + "pearson_p": round(float(pp), 4), + "spearman_r": round(float(sr), 4), + "spearman_p": round(float(sp), 4), + "n": n}) + pd.DataFrame(rows).to_csv( + diag_dir / "phase2_attention_diagnostic_summary.csv", index=False + ) + + +def run_phase2_block_diagnostics( + block_names, + pool_attn: np.ndarray, + blk_mse: np.ndarray, + ctx_per_block_df, + p1_dir: str, + lt: str, + lt_dir: Path, +) -> pd.DataFrame: + """Join Phase 2 block-level stats with Phase 1 vae_summary.csv (if available). + + Diagnostics identify blocks the aggregator relies on or modifies most — + not causal driver blocks. + """ + diag_dir = lt_dir / "diagnostics" + diag_dir.mkdir(parents=True, exist_ok=True) + + mean_pool = pool_attn.mean(axis=0) + std_pool = pool_attn.std(axis=0) + pool_rank = (-mean_pool).argsort().argsort() + 1 + + p2_df = pd.DataFrame({ + "block_id": block_names, + "mean_pool_attn": np.round(mean_pool, 6), + "std_pool_attn": np.round(std_pool, 6), + "pool_attn_rank": pool_rank, + "phase2_recon_mse": np.round(blk_mse, 6), + }) + + if ctx_per_block_df is not None: + ctx_cols = [c for c in [ + "block_id", "mean_context_delta_l2", "std_context_delta_l2", + "mean_context_delta_cosine", "std_context_delta_cosine", + "context_change_rank", + ] if c in ctx_per_block_df.columns] + p2_df = p2_df.merge(ctx_per_block_df[ctx_cols], on="block_id", how="left") + + # Phase 1 join (graceful: missing file or columns are silently skipped) + p1_summary = Path(p1_dir) / lt / "vae_summary.csv" + p1_want = [ + "block_id", "n_snps", "latent_dim", "n_active_latents", + "frac_dims_collapsed", "latent_underused", "conc_va", + "bal_acc_va", "ld_corr_va", "mean_r2_va", "r2_va", + ] + if p1_summary.exists(): + try: + p1_df = pd.read_csv(p1_summary) + avail = ["block_id"] + [c for c in p1_want[1:] if c in p1_df.columns] + p2_df = p2_df.merge(p1_df[avail], on="block_id", how="left") + except Exception: + pass + + p2_df.to_csv(diag_dir / "phase2_block_diagnostics.csv", index=False) + _save_attention_correlation_summary(p2_df, diag_dir) + return p2_df + +def run_pca_baseline_sweep( + emb_block, tr_ix, va_ix, te_ix, block_dims, latent_mask, + n_components_list, out_dir, +): + """Run PCA baseline at multiple k values; return DataFrame and best (smallest k).""" + rows = [] + for k in n_components_list: + try: + n_used, va_mse, te_mse = run_pca_baseline( + emb_block, tr_ix, va_ix, te_ix, block_dims, latent_mask, + k, out_dir / f"pca_k{k}", + ) + rows.append({ + "n_components_req": k, + "n_components_used": n_used, + "val_recon_mse": va_mse, + "test_recon_mse": te_mse, + }) + except Exception as e: + warnings.warn(f"PCA sweep k={k} failed: {e}") + df = pd.DataFrame(rows) + df.to_csv(out_dir / "pca_baseline_sweep.csv", index=False) + return df + +def run_pca_baseline( + emb_block: np.ndarray, + tr_ix: np.ndarray, + va_ix: np.ndarray, + te_ix: np.ndarray, + block_dims, + latent_mask, + n_components_req: int, + out_dir: Path, +) -> tuple: + """Concat + PCA subject-embedding baseline. Fits PCA on training subjects only. + + n_components_req : target number of components (from config or d_model default). + Always capped at min(n_train - 1, flat_dim, n_components_req). + Returns (n_components, val_recon_mse, test_recon_mse). + This is a Phase-2-level PCA over flattened Phase-1 block embeddings — + distinct from any PCA loss function used in Phase 1. + """ + if not HAS_SKLEARN: + return float("nan"), float("nan"), float("nan") + from sklearn.decomposition import PCA + + out_dir = Path(out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + N, B, max_d = emb_block.shape + + # Flatten: concatenate only real latent dims per block + if block_dims is not None: + flat = np.concatenate( + [emb_block[:, i, :int(block_dims[i])] for i in range(B)], axis=1 + ) + else: + flat = emb_block.reshape(N, -1) + flat_dim = flat.shape[1] + + n_train = len(tr_ix) + n_comp = max(1, min(n_train - 1, flat_dim, int(n_components_req))) + + pca = PCA(n_components=n_comp) + pca.fit(flat[tr_ix]) # fit on TRAIN ONLY + + pca_all = pca.transform(flat) # (N, n_comp) + recon_all = pca.inverse_transform(pca_all) # (N, flat_dim) + + def _mse(idx): + if len(idx) == 0: + return float("nan") + return float(np.mean((recon_all[idx] - flat[idx]) ** 2)) + + tr_mse = _mse(tr_ix) + va_mse = _mse(va_ix) + te_mse = _mse(te_ix) + + np.save(out_dir / "pca_subject_embeddings.npy", pca_all.astype(np.float32)) + pd.DataFrame( + pca_all, columns=[f"pc_{i}" for i in range(n_comp)] + ).to_csv(out_dir / "pca_subject_embeddings.csv", index=False) + + if recon_all.nbytes < 50 * 1024 * 1024: + np.save( + out_dir / "pca_reconstructions_flat.npy", recon_all.astype(np.float32) + ) + + pd.DataFrame({ + "requested_n_components": [int(n_components_req)], + "n_components": [n_comp], + "flat_dim": [flat_dim], + "pca_train_recon_loss": [round(tr_mse, 6)], + "pca_val_recon_loss": [round(va_mse, 6)], + "pca_test_recon_loss": [round(te_mse, 6) if np.isfinite(te_mse) else float("nan")], + "explained_variance_ratio": [round(float(pca.explained_variance_ratio_.sum()), 4)], + }).to_csv(out_dir / "pca_baseline_summary.csv", index=False) + + return n_comp, va_mse, te_mse + + +@torch.no_grad() +def run_mean_pool_baseline( + model, + emb_block: np.ndarray, + tr_ix: np.ndarray, + va_ix: np.ndarray, + te_ix: np.ndarray, + latent_mask, + latent_mask_t, + out_dir: Path, +) -> tuple: + """Two mean-pool baselines sharing the same projected tokens. + + raw : mean_token → decoder + (input_proj + pos_emb, mean across blocks, decode — no attn, no pooling) + embedhead: mean_token → embed_head → decoder + (adds the trained embed_head MLP before decoding) + + Both use the trained model's weights; neither trains new parameters. + Reconstruction loss is directly comparable to the Transformer model's MSE. + + Returns (raw_val_mse, raw_test_mse, embedhead_val_mse, embedhead_test_mse). + """ + out_dir = Path(out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + model.eval() + dl = DataLoader( + TensorDataset(torch.tensor(emb_block, dtype=torch.float32)), + batch_size=256, + shuffle=False, + ) + + K = model.n_pool_tokens + mean_embs, recons_raw, recons_eh = [], [], [] + for (xb,) in dl: + tokens = model.get_initial_tokens(xb) # (batch, B, d_model) + z_raw = tokens.mean(dim=1) # (batch, d_model) + # Expand to K*d_model so decoder input size matches (for K=1 this is a no-op) + z_raw_rep = z_raw.unsqueeze(1).expand(-1, K, -1).reshape(xb.size(0), model.emb_dim) + z_eh = model.embed_head(z_raw_rep) # (batch, K*d_model) + rec_raw = model.decode(z_raw_rep) # (batch, B, d_in) + rec_eh = model.decode(z_eh) # (batch, B, d_in) + mean_embs.append(z_raw.cpu().numpy()) # save d_model-dim mean token + recons_raw.append(rec_raw.cpu().numpy()) + recons_eh.append(rec_eh.cpu().numpy()) + + mean_embs_np = np.concatenate(mean_embs, axis=0) + recons_raw_np = np.concatenate(recons_raw, axis=0) + recons_eh_np = np.concatenate(recons_eh, axis=0) + + _lm = latent_mask[np.newaxis] if latent_mask is not None else None + _ld = float(latent_mask.sum()) if latent_mask is not None else None + + def _mse(recons_arr, idx): + if len(idx) == 0: + return float("nan") + r, t = recons_arr[idx], emb_block[idx] + if _lm is not None: + return float(np.sum((r - t) ** 2 * _lm) / (_ld * len(idx))) + return float(np.mean((r - t) ** 2)) + + raw_va = _mse(recons_raw_np, va_ix) + raw_te = _mse(recons_raw_np, te_ix) + eh_va = _mse(recons_eh_np, va_ix) + eh_te = _mse(recons_eh_np, te_ix) + + np.save(out_dir / "mean_pool_subject_embeddings.npy", mean_embs_np.astype(np.float32)) + pd.DataFrame( + mean_embs_np, columns=[f"d_{i}" for i in range(mean_embs_np.shape[1])] + ).to_csv(out_dir / "mean_pool_subject_embeddings.csv", index=False) + + pd.DataFrame({ + "raw_mean_pool_val_recon_loss": [round(raw_va, 6)], + "raw_mean_pool_test_recon_loss": [round(raw_te, 6) if np.isfinite(raw_te) else float("nan")], + "embedhead_mean_pool_val_recon_loss": [round(eh_va, 6)], + "embedhead_mean_pool_test_recon_loss": [round(eh_te, 6) if np.isfinite(eh_te) else float("nan")], + "note": ["raw=mean_token+decoder; embedhead=mean_token+embed_head+decoder"], + }).to_csv(out_dir / "mean_pool_baseline_summary.csv", index=False) + + return raw_va, raw_te, eh_va, eh_te + + +def _save_diagnostic_plots(block_diag_df, lt: str, out_dir: Path): + """Scatter and bar plots for Phase 2 diagnostics. Silent no-op if matplotlib absent.""" + if not HAS_PLT or block_diag_df is None: + return + out_dir.mkdir(parents=True, exist_ok=True) + + def _scatter(xcol, ycol, fname, xlabel, ylabel): + if xcol not in block_diag_df.columns or ycol not in block_diag_df.columns: + return + mask = block_diag_df[xcol].notna() & block_diag_df[ycol].notna() + x = block_diag_df.loc[mask, xcol].values.astype(float) + y = block_diag_df.loc[mask, ycol].values.astype(float) + if len(x) < 2: + return + fig, ax = plt.subplots(figsize=(6, 5)) + ax.scatter(x, y, s=30, alpha=0.7, color="steelblue") + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_title(f"{lt}") + plt.tight_layout() + plt.savefig(out_dir / fname, dpi=150) + plt.close() + + try: + _scatter( + "mean_pool_attn", "mean_context_delta_l2", + "attention_vs_context_delta.png", + "mean pool attention", "mean context delta L2", + ) + _scatter( + "n_active_latents", "mean_pool_attn", + "attention_vs_phase1_active_latents.png", + "n_active_latents (Phase 1)", "mean pool attention", + ) + _scatter( + "n_snps", "mean_pool_attn", + "attention_vs_n_snps.png", + "n_snps", "mean pool attention", + ) + + if "mean_context_delta_l2" in block_diag_df.columns: + top25 = ( + block_diag_df[["block_id", "mean_context_delta_l2"]] + .dropna() + .nlargest(25, "mean_context_delta_l2") + ) + if len(top25) > 0: + fig, ax = plt.subplots(figsize=(10, max(4, len(top25) * 0.35))) + y = np.arange(len(top25)) + ax.barh(y, top25["mean_context_delta_l2"].values, color="coral", alpha=0.85) + ax.set_yticks(y) + ax.set_yticklabels(top25["block_id"].values, fontsize=7) + ax.set_xlabel("Mean context delta L2") + ax.set_title(f"{lt} — top {len(top25)} blocks by contextualization change") + ax.invert_yaxis() + plt.tight_layout() + plt.savefig(out_dir / "context_delta_by_block.png", dpi=150) + plt.close() + except Exception: + pass + + +# ============================================================ +# 10. MAIN PHASE-2 PIPELINE # ============================================================ def run_phase2(cfg, *, config_path=None): t0_run = time.time() @@ -979,11 +1627,14 @@ def run_phase2(cfg, *, config_path=None): cc = cfg["clustering"] set_seed(ac["seed"]) - dev = get_device(ac.get("device", "cpu")) + dev = get_device(ac.get("device", "auto")) # Set deterministic behavior if dev.type == "cpu": - torch.use_deterministic_algorithms(True) + try: + torch.use_deterministic_algorithms(True, warn_only=True) + except Exception: + pass elif dev.type == "cuda": torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False @@ -995,13 +1646,12 @@ def run_phase2(cfg, *, config_path=None): yaml.dump(cfg, f, default_flow_style=False) print("\n══════ Step 8 · Loading Phase 1 frozen embeddings ══════") - subjects, tr_ix, va_ix, block_meta, embeddings = load_phase1( + subjects, tr_ix, va_ix, te_ix, block_meta, embeddings, latent_dims_per_loss = load_phase1( cfg["phase1_dir"], cfg["loss_functions"] ) N, B, d_in = list(embeddings.values())[0].shape block_names = block_meta["block_id"].values - print(f" N={N} B={B} d_in={d_in}") all_results = {} summary_rows = [] @@ -1012,16 +1662,84 @@ def run_phase2(cfg, *, config_path=None): print(f"{'═' * 55}") lt_dir = out / lt - for sub in ("logs", "models", "embeddings", "clustering", "plots", "attention_maps"): + for sub in ("logs", "models", "embeddings", "clustering", "plots", "attention_maps", + "baselines", "diagnostics"): (lt_dir / sub).mkdir(parents=True, exist_ok=True) emb_block = embeddings[lt] # (N, B, d_in) + # Build per-loss block_dims and latent_mask. + # Prefer all_blocks_latent_dims.npy (actual post-clamping dims) when available; + # fall back to block_order.csv["latent_dim"] (requested dims) otherwise. + _lt_dims = latent_dims_per_loss.get(lt) + if _lt_dims is not None: + block_dims = [int(d) for d in _lt_dims] + elif "latent_dim" in block_meta.columns: + block_dims = [int(d) for d in block_meta["latent_dim"]] + else: + block_dims = None + + if block_dims is not None: + dim_counts = dict(pd.Series(block_dims).value_counts().sort_index()) + print(f" N={N} B={B} max_d={d_in} per-block dims={dim_counts} [BlockProjector]") + latent_mask = np.zeros((B, d_in), dtype=np.float32) + for i, d_i in enumerate(block_dims): + d_i = int(d_i) + if d_i <= 0 or d_i > d_in: + raise ValueError(f"Invalid latent_dim for block {i}: {d_i}, max_d={d_in}") + latent_mask[i, :d_i] = 1.0 + latent_mask_t = torch.tensor(latent_mask) + else: + latent_mask = None + latent_mask_t = None + print(f" N={N} B={B} d_in={d_in} [shared input_proj — legacy]") + + n_test_subj = len(te_ix) + + # ------------------ pre-training baselines ------------------ + # Zero baseline: predict zeros. Train-mean baseline: predict training mean. + # Both use the same latent mask as the model loss so units are comparable. + # _masked_mse divides by mask.sum() (real dims) then by batch — same here. + _tr_mean = emb_block[tr_ix].mean(axis=0, keepdims=True) # (1, B, d_in) + _va_zero_sq = emb_block[va_ix] ** 2 + _va_mean_sq = (emb_block[va_ix] - _tr_mean) ** 2 + + if latent_mask is not None: + _lm = latent_mask[np.newaxis] # (1, B, d_in) + _ld = float(latent_mask.sum()) + val_zero_baseline = float(np.sum(_va_zero_sq * _lm) / (_ld * len(va_ix))) + val_train_mean_baseline = float(np.sum(_va_mean_sq * _lm) / (_ld * len(va_ix))) + if n_test_subj > 0: + _te_zero_sq = emb_block[te_ix] ** 2 + _te_mean_sq = (emb_block[te_ix] - _tr_mean) ** 2 + test_zero_baseline = float(np.sum(_te_zero_sq * _lm) / (_ld * n_test_subj)) + test_train_mean_baseline = float(np.sum(_te_mean_sq * _lm) / (_ld * n_test_subj)) + else: + test_zero_baseline = test_train_mean_baseline = float("nan") + else: + val_zero_baseline = float(np.mean(_va_zero_sq)) + val_train_mean_baseline = float(np.mean(_va_mean_sq)) + if n_test_subj > 0: + test_zero_baseline = float(np.mean(emb_block[te_ix] ** 2)) + test_train_mean_baseline = float(np.mean((emb_block[te_ix] - _tr_mean) ** 2)) + else: + test_zero_baseline = test_train_mean_baseline = float("nan") + + print(f" val zero-baseline loss: {val_zero_baseline:.6f}") + print(f" val train-mean-baseline loss: {val_train_mean_baseline:.6f}") + if n_test_subj > 0: + print(f" test zero-baseline loss: {test_zero_baseline:.6f}") + print(f" test train-mean-baseline loss: {test_train_mean_baseline:.6f}") + # ------------------ Step 9: train ------------------ print("\n Step 9 · Training attention aggregator ...") tr_t = torch.tensor(emb_block[tr_ix], dtype=torch.float32) va_t = torch.tensor(emb_block[va_ix], dtype=torch.float32) + K = int(ac.get("n_pool_tokens", 1)) + if K < 1: + raise ValueError("attention.n_pool_tokens must be >= 1") + model = AttentionAggregator( n_blocks=B, d_in=d_in, @@ -1030,24 +1748,49 @@ def run_phase2(cfg, *, config_path=None): n_layers=ac["n_layers"], d_ff=ac["d_ff"], dropout=ac["dropout"], + block_dims=block_dims, + n_pool_tokens=K, ) npar = sum(p.numel() for p in model.parameters()) print( f" architecture: {B}x{d_in} -> d_model={ac['d_model']} " - f"heads={ac['n_heads']} layers={ac['n_layers']} params={npar:,}" + f"heads={ac['n_heads']} layers={ac['n_layers']} " + f"n_pool_tokens={K} emb_dim={K * ac['d_model']} " + f"decoder_hidden={model.decoder_hidden} params={npar:,}" ) t0 = time.time() log, best_epoch, best_val_loss = train_attention_model( model, tr_t, va_t, cfg, dev, - lt_dir / "logs" / "attention_training.csv" + lt_dir / "logs" / "attention_training.csv", + latent_mask_t=latent_mask_t, ) dt = time.time() - t0 print(f" done in {dt:.1f}s ({len(log)} epochs, best at {best_epoch})") torch.save(model.state_dict(), lt_dir / "models" / "attention_aggregator.pt") + # ------------------ held-out test reconstruction loss ------------------ + model.eval() + if n_test_subj > 0: + te_t = torch.tensor(emb_block[te_ix], dtype=torch.float32) + te_dl = DataLoader(TensorDataset(te_t), batch_size=256, shuffle=False) + te_loss_acc = 0.0 + with torch.no_grad(): + for (xb,) in te_dl: + recon_te, _, _ = model(xb, return_self_attn=False) + loss_te = ( + _masked_mse(recon_te, xb, latent_mask_t) + if latent_mask_t is not None + else F.mse_loss(recon_te, xb) + ) + te_loss_acc += loss_te.item() * xb.size(0) + te_recon_loss = te_loss_acc / n_test_subj + else: + te_recon_loss = float("nan") + print(f" test reconstruction loss: {te_recon_loss:.6f}") + # ------------------ Step 10: extract ------------------ print("\n Step 10 · Extracting embeddings, pooling attention, and self-attention ...") final_emb, pool_attn, recon, block_repr, self_attn_mean, self_attn_full = extract_all( @@ -1058,8 +1801,12 @@ def run_phase2(cfg, *, config_path=None): save_full_self_attn=ac.get("save_full_self_attn", False), ) - global_mse = float(np.mean((recon - emb_block) ** 2)) - blk_mse = per_block_mse(recon, emb_block) + if latent_mask is not None: + sq = (recon - emb_block) ** 2 # (N, B, max_d) + global_mse = float(np.sum(sq * latent_mask[np.newaxis]) / (latent_mask.sum() * N)) + else: + global_mse = float(np.mean((recon - emb_block) ** 2)) + blk_mse = per_block_mse(recon, emb_block, block_dims=block_dims) print(f" individual embedding : {final_emb.shape}") print(f" pooling attention : {pool_attn.shape}") @@ -1072,6 +1819,15 @@ def run_phase2(cfg, *, config_path=None): np.save(lt_dir / "embeddings" / "reconstructions.npy", recon) np.save(lt_dir / "embeddings" / "block_contextual_repr.npy", block_repr) + # For K>1 also save per-token pooling attention (N, K, B) alongside the mean (N, B) + pool_attn_by_token = None + if K > 1: + pool_attn_by_token = _extract_pool_attn_by_token(model, emb_block) + np.save( + lt_dir / "embeddings" / "pooling_attention_weights_by_token.npy", + pool_attn_by_token, + ) + # ---- human-readable CSVs ---- emb_df = pd.DataFrame(final_emb, columns=[f"emb_{i}" for i in range(final_emb.shape[1])]) emb_df.insert(0, "IID", subjects) @@ -1131,6 +1887,121 @@ def run_phase2(cfg, *, config_path=None): plot_reconstruction_per_block(blk_mse, block_meta, lt, lt_dir / "plots") plot_self_attention_heatmaps(self_attn_mean, block_meta, lt, lt_dir / "plots") + # ── Phase 2 diagnostics and baselines ──────────────────────────────── + dc = cfg.get("diagnostics", DEFAULT_CFG.get("diagnostics", {})) + bc = cfg.get("baselines", DEFAULT_CFG.get("baselines", {})) + + # Attention entropy / top-k mass (always computed — cheap) + _entropy = _pool_attn_entropy(pool_attn) + _top1_mass = _pool_attn_topk_mass(pool_attn, 1) + _top5_mass = _pool_attn_topk_mass(pool_attn, min(5, B)) + _token_entropy = float("nan") + _token_pairwise_corr = float("nan") + _token_diversity = float("nan") + + if pool_attn_by_token is not None: + _token_entropy = _pool_attn_token_entropy(pool_attn_by_token) + _token_pairwise_corr = _pool_attn_token_pairwise_corr(pool_attn_by_token) + _token_diversity = _pool_attn_token_diversity(pool_attn_by_token) + # Self-attention mass statistics (layer 0, averaged over heads and subjects) + _sa_diag_entry_l0 = _sa_offdiag_entry_l0 = float("nan") + _sa_diag_mass_l0 = _sa_offdiag_mass_l0 = float("nan") + if self_attn_mean is not None and len(self_attn_mean) > 0: + _sa0 = self_attn_mean[0].mean(axis=0) # (B, B), averaged over heads + _eye = np.eye(B, dtype=bool) + _sa0_tot = float(_sa0.sum()) + _sa_diag_entry_l0 = float(_sa0[_eye].mean()) + _sa_offdiag_entry_l0 = float(_sa0[~_eye].mean()) if B > 1 else float("nan") + if _sa0_tot > 0: + _sa_diag_mass_l0 = float(_sa0[_eye].sum() / _sa0_tot) + _sa_offdiag_mass_l0 = float(_sa0[~_eye].sum() / _sa0_tot) if B > 1 else float("nan") + + # Part A — initial (pre-attention) block tokens + initial_repr = None + if dc.get("enabled", True) and dc.get("save_initial_block_repr", True): + initial_repr = _extract_initial_tokens(model, emb_block) + np.save(lt_dir / "embeddings" / "block_initial_repr.npy", initial_repr) + + # Part B — contextualization-change diagnostics + ctx_per_block_df = ctx_per_subj_df = None + if initial_repr is not None and dc.get("compute_contextualization_change", True): + ctx_per_block_df, ctx_per_subj_df = compute_contextualization_change( + initial_repr, block_repr, block_names, block_meta + ) + ctx_per_block_df.to_csv( + lt_dir / "embeddings" / "per_block_contextualization_change.csv", + index=False, + ) + ctx_per_subj_df.to_csv( + lt_dir / "embeddings" / "per_subject_contextualization_change.csv", + index=False, + ) + + # Part C — Phase 2 × Phase 1 block diagnostics join + block_diag_df = None + if dc.get("enabled", True) and dc.get("compute_phase1_phase2_join", True): + block_diag_df = run_phase2_block_diagnostics( + block_names, pool_attn, blk_mse, ctx_per_block_df, + cfg["phase1_dir"], lt, lt_dir, + ) + + # Part D — PCA subject-embedding baseline + # replace existing single-k PCA call with sweep + _pca_sweep_list = bc.get("pca_sweep") + if _pca_sweep_list: + pca_sweep_df = run_pca_baseline_sweep( + emb_block, tr_ix, va_ix, te_ix, block_dims, latent_mask, + _pca_sweep_list, lt_dir / "baselines", + ) + # for backward-compat reporting, pick k closest to d_model + _target = ac["d_model"] + closest = pca_sweep_df.iloc[(pca_sweep_df["n_components_req"] - _target).abs().argsort()[:1]] + pca_n_components = int(closest["n_components_used"].iloc[0]) + pca_val_recon = float(closest["val_recon_mse"].iloc[0]) + pca_te_recon = float(closest["test_recon_mse"].iloc[0]) + # pca_n_components = pca_val_recon = pca_te_recon = float("nan") + # if bc.get("enabled", True) and bc.get("run_pca", True): + # try: + # _pca_req = bc.get("pca_n_components") + # _pca_target = int(_pca_req) if _pca_req is not None else ac["d_model"] + # pca_n_components, pca_val_recon, pca_te_recon = run_pca_baseline( + # emb_block, tr_ix, va_ix, te_ix, block_dims, latent_mask, + # _pca_target, lt_dir / "baselines", + # ) + # print( + # f" PCA baseline : n_comp={pca_n_components}" + # f" val={pca_val_recon:.6f} test={pca_te_recon:.6f}" + # ) + # except Exception as _e: + # warnings.warn(f"PCA baseline failed: {_e}") + + # Part E — mean-pool baselines (raw and embed_head variants) + mp_raw_va = mp_raw_te = mp_eh_va = mp_eh_te = float("nan") + if bc.get("enabled", True) and bc.get("run_mean_pool", True): + try: + mp_raw_va, mp_raw_te, mp_eh_va, mp_eh_te = run_mean_pool_baseline( + model, emb_block, tr_ix, va_ix, te_ix, + latent_mask, latent_mask_t, lt_dir / "baselines", + ) + print( + f" mean-pool raw : val={mp_raw_va:.6f} test={mp_raw_te:.6f}" + ) + print( + f" mean-pool embedhead: val={mp_eh_va:.6f} test={mp_eh_te:.6f}" + ) + except Exception as _e: + warnings.warn(f"Mean-pool baseline failed: {_e}") + + # Part F — diagnostic scatter / bar plots + _save_diagnostic_plots(block_diag_df, lt, lt_dir / "plots") + + # Context-change aggregates for summary row + _mean_ctx_l2 = _mean_ctx_cos = _max_ctx_l2 = float("nan") + if ctx_per_block_df is not None: + _mean_ctx_l2 = round(float(ctx_per_block_df["mean_context_delta_l2"].mean()), 6) + _mean_ctx_cos = round(float(ctx_per_block_df["mean_context_delta_cosine"].mean()), 6) + _max_ctx_l2 = round(float(ctx_per_block_df["mean_context_delta_l2"].max()), 6) + # ------------------ stash for comparison ------------------ best_sil = float(cluster_metrics["silhouette"].max()) if len(cluster_metrics) > 0 else 0.0 @@ -1152,14 +2023,54 @@ def run_phase2(cfg, *, config_path=None): "final_tr_loss": log[-1]["tr_loss"], "final_va_loss": log[-1]["va_loss"], "best_va_loss": round(best_val_loss, 6), + "test_recon_loss": round(te_recon_loss, 6) if not np.isnan(te_recon_loss) else float("nan"), + "val_zero_baseline_loss": round(val_zero_baseline, 6), + "test_zero_baseline_loss": round(test_zero_baseline, 6) if not np.isnan(test_zero_baseline) else float("nan"), + "val_train_mean_baseline_loss": round(val_train_mean_baseline, 6), + "test_train_mean_baseline_loss": round(test_train_mean_baseline, 6) if not np.isnan(test_train_mean_baseline) else float("nan"), + "model_vs_mean_baseline_ratio": round(best_val_loss / val_train_mean_baseline, 4) if val_train_mean_baseline > 0 else float("nan"), "recon_mse": round(global_mse, 6), "best_silhouette": round(best_sil, 4), + # ── attention diagnostics ────────────────────────────────────────── + "mean_pool_attn_entropy": round(_entropy, 6), + "mean_pool_attn_top1_mass": round(_top1_mass, 6), + "mean_pool_attn_top5_mass": round(_top5_mass, 6), + "pool_token_entropy": round(_token_entropy, 6) if np.isfinite(_token_entropy) else float("nan"), + "pool_token_pairwise_corr": round(_token_pairwise_corr, 6) if np.isfinite(_token_pairwise_corr) else float("nan"), + "pool_token_diversity": round(_token_diversity, 6) if np.isfinite(_token_diversity) else float("nan"), + "mean_context_delta_l2": _mean_ctx_l2, + "mean_context_delta_cosine": _mean_ctx_cos, + "max_context_delta_l2": _max_ctx_l2, + "self_attn_diag_entry_mean_layer0": round(_sa_diag_entry_l0, 6) if np.isfinite(_sa_diag_entry_l0) else float("nan"), + "self_attn_offdiag_entry_mean_layer0": round(_sa_offdiag_entry_l0, 6) if np.isfinite(_sa_offdiag_entry_l0) else float("nan"), + "self_attn_diag_total_mass_layer0": round(_sa_diag_mass_l0, 6) if np.isfinite(_sa_diag_mass_l0) else float("nan"), + "self_attn_offdiag_total_mass_layer0": round(_sa_offdiag_mass_l0, 6) if np.isfinite(_sa_offdiag_mass_l0) else float("nan"), + # ── PCA baseline ─────────────────────────────────────────────────── + "pca_n_components": int(pca_n_components) if np.isfinite(float(pca_n_components)) else float("nan"), + "pca_val_recon_loss": round(pca_val_recon, 6) if np.isfinite(pca_val_recon) else float("nan"), + "pca_test_recon_loss": round(pca_te_recon, 6) if np.isfinite(pca_te_recon) else float("nan"), + "transformer_vs_pca_val_ratio": round(best_val_loss / pca_val_recon, 4) if (np.isfinite(pca_val_recon) and pca_val_recon > 0) else float("nan"), + "transformer_vs_pca_test_ratio": round(te_recon_loss / pca_te_recon, 4) if (np.isfinite(pca_te_recon) and pca_te_recon > 0 and np.isfinite(te_recon_loss)) else float("nan"), + # ── mean-pool baselines (raw and embed_head variants) ────────────── + "raw_mean_pool_val_recon_loss": round(mp_raw_va, 6) if np.isfinite(mp_raw_va) else float("nan"), + "raw_mean_pool_test_recon_loss": round(mp_raw_te, 6) if np.isfinite(mp_raw_te) else float("nan"), + "embedhead_mean_pool_val_recon_loss": round(mp_eh_va, 6) if np.isfinite(mp_eh_va) else float("nan"), + "embedhead_mean_pool_test_recon_loss": round(mp_eh_te, 6) if np.isfinite(mp_eh_te) else float("nan"), + "transformer_vs_raw_mean_pool_val_ratio": round(best_val_loss / mp_raw_va, 4) if (np.isfinite(mp_raw_va) and mp_raw_va > 0) else float("nan"), + "transformer_vs_raw_mean_pool_test_ratio": round(te_recon_loss / mp_raw_te, 4) if (np.isfinite(mp_raw_te) and mp_raw_te > 0 and np.isfinite(te_recon_loss)) else float("nan"), + # ── run info ────────────────────────────────────────────────────── "seconds": round(dt, 1), "seed": ac["seed"], "device": str(dev), "n_subjects": N, + "n_train": len(tr_ix), + "n_val": len(va_ix), + "n_test": n_test_subj, "n_blocks": B, "d_in": d_in, + "n_pool_tokens": K, + "embedding_dim": K * ac["d_model"], + "decoder_hidden": model.decoder_hidden, "loss_functions": str(cfg["loss_functions"]), }) @@ -1201,7 +2112,7 @@ def validate_cfg(cfg): if not phase1_dir.exists(): raise FileNotFoundError(f"phase1_dir missing: {phase1_dir}") - for f in ["subjects.csv", "train_idx.npy", "val_idx.npy", "block_order.csv"]: + for f in ["subjects.csv", "train_idx.npy", "val_idx.npy", "test_idx.npy", "block_order.csv"]: fp = phase1_dir / f if not fp.exists(): raise FileNotFoundError(f"Required Phase 1 file missing: {fp}") diff --git a/scripts/core/plots_updated.py b/scripts/core/plots_updated.py index f1c5022..90db230 100644 --- a/scripts/core/plots_updated.py +++ b/scripts/core/plots_updated.py @@ -24,7 +24,7 @@ import matplotlib.pyplot as plt -DEFAULT_OUTDIR = "/Users/shraddh_mac/geno_ld_attention/geno_ld_attention/results/plots_regions" +DEFAULT_OUTDIR = "/Users/shraddh_mac/Documents/GitHub/blockbased-genotype-embedding-analysis/results/plots_regions" # ---------------------------- @@ -144,11 +144,6 @@ def plot_runtime_vs_perf(df: pd.DataFrame, losses: list[str], outdir: Path) -> N plt.legend(frameon=False, ncol=min(3, len(losses))) savefig(outdir / "scatter_runtime_vs_bal_acc.png") - -# ---------------------------- -# NEW visuals you requested -# ---------------------------- - # 1) Per-class recall heatmap faceted by loss def plot_per_class_heatmaps(df: pd.DataFrame, losses: list[str], outdir: Path) -> None: require_cols(df, ["block", "loss", "acc0_va", "acc1_va", "acc2_va"], "per-class heatmap") @@ -365,9 +360,9 @@ def plot_kl_over_epochs(phase1_dir: Path, df_summary: pd.DataFrame, losses: list # ---------------------------- def main() -> None: ap = argparse.ArgumentParser() - ap.add_argument("--phase1_dir", type=str, default="/Users/shraddh_mac/geno_ld_attention/geno_ld_attention/results/output_regions", + ap.add_argument("--phase1_dir", type=str, default="/Users/shraddh_mac/Documents/GitHub/blockbased-genotype-embedding-analysis/results/output_regions", help="Path to Phase-1 output directory (contains vae_summary_ord.csv and per-loss logs/).") - ap.add_argument("--summary_csv", type=str, default="vae_summary_ord.csv", + ap.add_argument("--summary_csv", type=str, default="vae_summary.csv", help="Summary CSV filename inside phase1_dir.") ap.add_argument("--outdir", type=str, default=DEFAULT_OUTDIR, help="Directory to save plot PNGs.") @@ -375,7 +370,7 @@ def main() -> None: help="Optional explicit list of losses to plot (e.g. --losses ORD MSE MSE_STD CAT BCE).") ap.add_argument("--k", type=int, default=None, help="Optional: number of loss types to plot (uses prefer-first ordering). If omitted, uses all.") - ap.add_argument("--prefer", type=str, nargs="*", default=["ORD", "MSE", "MSE_STD", "CAT", "BCE"], + ap.add_argument("--prefer", type=str, nargs="*", default=["ORD_W_Scaled", "ORD", "MSE", "MSE_STD", "CAT", "BCE"], help="Preferred ordering used when selecting all or when --k is set.") args = ap.parse_args() @@ -403,7 +398,6 @@ def main() -> None: "Validation LD-correlation distribution by loss", "ld_corr_va") plot_runtime_vs_perf(dsel, losses, outdir) - # ----- New requested visuals ----- # 1) per-class recall heatmap per loss plot_per_class_heatmaps(dsel, losses, outdir) diff --git a/tests/test_phase2_diagnostics.py b/tests/test_phase2_diagnostics.py new file mode 100644 index 0000000..bc3761d --- /dev/null +++ b/tests/test_phase2_diagnostics.py @@ -0,0 +1,439 @@ +""" +Tests for Phase 2 diagnostics and baseline utilities. +All tests use tiny synthetic tensors; no real data or checkpoints required. +""" +import math +import numpy as np +import pytest +from pathlib import Path + +torch = pytest.importorskip("torch", reason="torch not installed") +pd = pytest.importorskip("pandas", reason="pandas not installed") + +from scripts.core.attention_phase2 import ( + AttentionAggregator, + compute_contextualization_change, + _extract_initial_tokens, + _extract_pool_attn_by_token, + _pool_attn_entropy, + _pool_attn_topk_mass, + run_pca_baseline, + run_mean_pool_baseline, + run_phase2_block_diagnostics, +) + + +# ── fixtures ─────────────────────────────────────────────────────────────────── + +@pytest.fixture(scope="module") +def tiny_model(): + torch.manual_seed(0) + m = AttentionAggregator( + n_blocks=4, d_in=6, d_model=8, n_heads=2, + n_layers=1, d_ff=16, dropout=0.0, + ) + m.eval() + return m + + +@pytest.fixture(scope="module") +def synthetic_data(): + rng = np.random.RandomState(1) + N, B, d_in = 20, 4, 6 + emb = rng.randn(N, B, d_in).astype(np.float32) + tr_ix = np.arange(14) + va_ix = np.arange(14, 17) + te_ix = np.arange(17, 20) + return emb, tr_ix, va_ix, te_ix, B, d_in + + +# ── Part A: get_initial_tokens / _extract_initial_tokens ────────────────────── + +def test_get_initial_tokens_shape(tiny_model): + """get_initial_tokens must return (batch, B, d_model).""" + x = torch.randn(5, 4, 6) + with torch.no_grad(): + tokens = tiny_model.get_initial_tokens(x) + assert tokens.shape == (5, 4, 8), f"Expected (5, 4, 8), got {tokens.shape}" + + +def test_extract_initial_tokens_shape(tiny_model, synthetic_data): + """_extract_initial_tokens must return (N, B, d_model) for the full dataset.""" + emb, *_ = synthetic_data + arr = _extract_initial_tokens(tiny_model, emb, batch_size=8) + assert arr.shape == (20, 4, 8), f"Expected (20, 4, 8), got {arr.shape}" + + +def test_initial_tokens_differ_from_contextual(tiny_model, synthetic_data): + """Initial tokens before self-attention should differ from contextual tokens after.""" + emb, *_ = synthetic_data + initial = _extract_initial_tokens(tiny_model, emb, batch_size=32) + x_t = torch.tensor(emb) + with torch.no_grad(): + _, _, contextual = tiny_model.encode(x_t, return_self_attn=False) + contextual_np = contextual.numpy() + # They must not be identical (transformer layers change the representation) + assert not np.allclose(initial, contextual_np), ( + "initial and contextual tokens are identical — Transformer may not be modifying them" + ) + + +# ── Part B: compute_contextualization_change ─────────────────────────────────── + +def test_ctx_change_required_columns(): + """per_block_df must contain all required columns.""" + N, B, d = 10, 4, 8 + rng = np.random.RandomState(0) + init = rng.randn(N, B, d).astype(np.float32) + ctx = init + rng.randn(N, B, d).astype(np.float32) * 0.1 + bnames = [f"block_{i}" for i in range(B)] + per_block, per_subj = compute_contextualization_change(init, ctx, bnames) + + required_pb = { + "block_id", "mean_context_delta_l2", "std_context_delta_l2", + "mean_context_delta_cosine", "std_context_delta_cosine", + "context_change_rank", + } + missing = required_pb - set(per_block.columns) + assert not missing, f"per_block missing columns: {missing}" + + required_ps = { + "mean_context_delta_l2_per_subject", + "mean_context_delta_cosine_per_subject", + } + missing2 = required_ps - set(per_subj.columns) + assert not missing2, f"per_subj missing columns: {missing2}" + + +def test_ctx_change_row_counts(): + """per_block must have B rows; per_subj must have N rows.""" + N, B, d = 10, 5, 8 + init = np.zeros((N, B, d), dtype=np.float32) + ctx = np.ones((N, B, d), dtype=np.float32) + bnames = [f"block_{i}" for i in range(B)] + per_block, per_subj = compute_contextualization_change(init, ctx, bnames) + assert len(per_block) == B, f"Expected {B} rows, got {len(per_block)}" + assert len(per_subj) == N, f"Expected {N} rows, got {len(per_subj)}" + + +def test_ctx_change_l2_nonneg(): + """Mean L2 context-change must be >= 0 for all blocks.""" + N, B, d = 8, 3, 6 + rng = np.random.RandomState(2) + init = rng.randn(N, B, d).astype(np.float32) + ctx = rng.randn(N, B, d).astype(np.float32) + bnames = [f"b{i}" for i in range(B)] + per_block, _ = compute_contextualization_change(init, ctx, bnames) + assert (per_block["mean_context_delta_l2"] >= 0).all() + + +def test_ctx_change_identical_inputs_zero_delta(): + """When initial == contextual, L2 delta must be ~0 for all blocks.""" + N, B, d = 6, 3, 4 + rng = np.random.RandomState(3) + x = rng.randn(N, B, d).astype(np.float32) + bnames = [f"b{i}" for i in range(B)] + per_block, per_subj = compute_contextualization_change(x, x, bnames) + assert np.allclose(per_block["mean_context_delta_l2"].values, 0.0, atol=1e-5) + assert np.allclose(per_subj["mean_context_delta_l2_per_subject"].values, 0.0, atol=1e-5) + + +def test_ctx_change_rank_is_permutation(): + """context_change_rank must be a permutation of 1..B.""" + N, B, d = 12, 5, 8 + rng = np.random.RandomState(4) + init = rng.randn(N, B, d).astype(np.float32) + ctx = rng.randn(N, B, d).astype(np.float32) + bnames = [f"b{i}" for i in range(B)] + per_block, _ = compute_contextualization_change(init, ctx, bnames) + ranks = sorted(per_block["context_change_rank"].tolist()) + assert ranks == list(range(1, B + 1)), f"Ranks not a permutation of 1..{B}: {ranks}" + + +# ── Pool attention entropy / top-k mass ──────────────────────────────────────── + +def test_pool_attn_entropy_positive(): + """Entropy must be > 0 for non-degenerate attention distributions.""" + pa = np.array([[0.25, 0.25, 0.25, 0.25], + [0.70, 0.10, 0.10, 0.10]]) + assert _pool_attn_entropy(pa) > 0 + + +def test_pool_attn_entropy_uniform_max(): + """Uniform distribution has higher entropy than peaked distribution.""" + B = 8 + uniform = np.full((20, B), 1.0 / B) + peaked = np.zeros((20, B)); peaked[:, 0] = 1.0 + assert _pool_attn_entropy(uniform) > _pool_attn_entropy(peaked) + + +def test_pool_attn_topk_mass_bounds(): + """top-1 mass in [0,1]; top-all mass == 1.0.""" + pa = np.array([[0.1, 0.5, 0.3, 0.1], + [0.25, 0.25, 0.25, 0.25]]) + m1 = _pool_attn_topk_mass(pa, 1) + m4 = _pool_attn_topk_mass(pa, 4) + assert 0.0 <= m1 <= 1.0 + assert math.isclose(m4, 1.0, abs_tol=1e-5) + + +# ── Part D: PCA baseline ─────────────────────────────────────────────────────── + +sklearn = pytest.importorskip("sklearn", reason="scikit-learn not installed") + + +def test_pca_baseline_n_components_bounded_by_n_train(synthetic_data, tmp_path): + """n_components must be <= n_train - 1 (PCA cannot exceed training samples).""" + emb, tr_ix, va_ix, te_ix, B, d_in = synthetic_data + n_comp, va_mse, te_mse = run_pca_baseline( + emb, tr_ix, va_ix, te_ix, + block_dims=None, latent_mask=None, + n_components_req=8, out_dir=tmp_path / "pca", + ) + assert n_comp <= len(tr_ix) - 1, ( + f"n_comp={n_comp} > n_train-1={len(tr_ix)-1}" + ) + assert np.isfinite(va_mse), "val MSE is not finite" + + +def test_pca_baseline_creates_required_files(synthetic_data, tmp_path): + """PCA baseline must create embeddings, CSV, and summary file.""" + emb, tr_ix, va_ix, te_ix, *_ = synthetic_data + run_pca_baseline( + emb, tr_ix, va_ix, te_ix, + block_dims=None, latent_mask=None, + n_components_req=8, out_dir=tmp_path / "pca2", + ) + assert (tmp_path / "pca2" / "pca_subject_embeddings.npy").exists() + assert (tmp_path / "pca2" / "pca_subject_embeddings.csv").exists() + assert (tmp_path / "pca2" / "pca_baseline_summary.csv").exists() + + +def test_pca_baseline_with_heterogeneous_block_dims(synthetic_data, tmp_path): + """PCA baseline must work when block_dims varies per block.""" + emb, tr_ix, va_ix, te_ix, B, d_in = synthetic_data + # block_dims must be <= d_in for each block + block_dims = [3, 4, 5, 6] + n_comp, va_mse, _ = run_pca_baseline( + emb, tr_ix, va_ix, te_ix, + block_dims=block_dims, latent_mask=None, + n_components_req=8, out_dir=tmp_path / "pca3", + ) + assert np.isfinite(va_mse) + # flat_dim = sum(block_dims) = 18; n_comp <= min(13, 18, 8) = 8 + assert n_comp <= 8 + + +def test_pca_embedding_shape(synthetic_data, tmp_path): + """PCA subject embeddings must have n_components columns.""" + emb, tr_ix, va_ix, te_ix, *_ = synthetic_data + n_comp, _, _ = run_pca_baseline( + emb, tr_ix, va_ix, te_ix, + block_dims=None, latent_mask=None, + n_components_req=6, out_dir=tmp_path / "pca4", + ) + arr = np.load(tmp_path / "pca4" / "pca_subject_embeddings.npy") + assert arr.shape == (20, n_comp), f"Expected (20, {n_comp}), got {arr.shape}" + + +# ── Part E: mean-pool baseline ───────────────────────────────────────────────── + +def test_mean_pool_creates_required_files(tiny_model, synthetic_data, tmp_path): + """Mean-pool baseline must create embeddings and summary files.""" + emb, tr_ix, va_ix, te_ix, *_ = synthetic_data + raw_va, raw_te, eh_va, eh_te = run_mean_pool_baseline( + tiny_model, emb, tr_ix, va_ix, te_ix, + latent_mask=None, latent_mask_t=None, + out_dir=tmp_path / "mp", + ) + assert (tmp_path / "mp" / "mean_pool_subject_embeddings.npy").exists() + assert (tmp_path / "mp" / "mean_pool_subject_embeddings.csv").exists() + assert (tmp_path / "mp" / "mean_pool_baseline_summary.csv").exists() + assert np.isfinite(raw_va), "raw val MSE is not finite" + assert np.isfinite(eh_va), "embedhead val MSE is not finite" + + +def test_mean_pool_embedding_shape(tiny_model, synthetic_data, tmp_path): + """Mean-pool embeddings must have shape (N, d_model).""" + emb, tr_ix, va_ix, te_ix, *_ = synthetic_data + run_mean_pool_baseline( + tiny_model, emb, tr_ix, va_ix, te_ix, + latent_mask=None, latent_mask_t=None, + out_dir=tmp_path / "mp_shape", + ) + arr = np.load(tmp_path / "mp_shape" / "mean_pool_subject_embeddings.npy") + assert arr.shape == (20, 8), f"Expected (20, 8), got {arr.shape}" # d_model=8 + + +def test_mean_pool_val_recon_finite(tiny_model, synthetic_data, tmp_path): + """Both mean-pool val MSEs (raw and embedhead) must be finite.""" + emb, tr_ix, va_ix, te_ix, *_ = synthetic_data + raw_va, _, eh_va, _ = run_mean_pool_baseline( + tiny_model, emb, tr_ix, va_ix, te_ix, + latent_mask=None, latent_mask_t=None, + out_dir=tmp_path / "mp_fin", + ) + assert np.isfinite(raw_va), "raw val MSE is not finite" + assert np.isfinite(eh_va), "embedhead val MSE is not finite" + + +def test_mean_pool_summary_columns(tiny_model, synthetic_data, tmp_path): + """mean_pool_baseline_summary.csv must have both variant column names.""" + emb, tr_ix, va_ix, te_ix, *_ = synthetic_data + run_mean_pool_baseline( + tiny_model, emb, tr_ix, va_ix, te_ix, + latent_mask=None, latent_mask_t=None, + out_dir=tmp_path / "mp_cols", + ) + df = pd.read_csv(tmp_path / "mp_cols" / "mean_pool_baseline_summary.csv") + for col in ("raw_mean_pool_val_recon_loss", "embedhead_mean_pool_val_recon_loss"): + assert col in df.columns, f"Missing column: {col}" + + +# ── Part C: Phase 2 block diagnostics (graceful with missing Phase 1) ────────── + +def test_block_diagnostics_no_crash_missing_phase1(tmp_path): + """run_phase2_block_diagnostics must not crash when vae_summary.csv is absent.""" + block_names = [f"block_{i}" for i in range(4)] + rng = np.random.RandomState(5) + pool_attn = rng.dirichlet(np.ones(4), size=20) + blk_mse = rng.rand(4) + run_phase2_block_diagnostics( + block_names, pool_attn, blk_mse, + ctx_per_block_df=None, + p1_dir=str(tmp_path / "nonexistent"), + lt="ORD", + lt_dir=tmp_path / "p2_diag", + ) + assert (tmp_path / "p2_diag" / "diagnostics" / "phase2_block_diagnostics.csv").exists() + + +def test_block_diagnostics_creates_correlation_summary(tmp_path): + """run_phase2_block_diagnostics must create phase2_attention_diagnostic_summary.csv.""" + block_names = [f"block_{i}" for i in range(5)] + rng = np.random.RandomState(6) + pool_attn = rng.dirichlet(np.ones(5), size=30) + blk_mse = rng.rand(5) + run_phase2_block_diagnostics( + block_names, pool_attn, blk_mse, + ctx_per_block_df=None, + p1_dir=str(tmp_path / "nonexistent"), + lt="ORD", + lt_dir=tmp_path / "p2_corr", + ) + assert ( + tmp_path / "p2_corr" / "diagnostics" / "phase2_attention_diagnostic_summary.csv" + ).exists() + + +def test_block_diagnostics_correct_n_rows(tmp_path): + """phase2_block_diagnostics.csv must have one row per block.""" + B = 6 + block_names = [f"block_{i}" for i in range(B)] + rng = np.random.RandomState(7) + pool_attn = rng.dirichlet(np.ones(B), size=25) + blk_mse = rng.rand(B) + run_phase2_block_diagnostics( + block_names, pool_attn, blk_mse, + ctx_per_block_df=None, + p1_dir=str(tmp_path / "nonexistent"), + lt="ORD", + lt_dir=tmp_path / "p2_rows", + ) + df = pd.read_csv(tmp_path / "p2_rows" / "diagnostics" / "phase2_block_diagnostics.csv") + assert len(df) == B, f"Expected {B} rows, got {len(df)}" + + +# ── Multi-token pooling ──────────────────────────────────────────────────────── + +@pytest.mark.parametrize("K", [1, 4]) +def test_multi_pool_embedding_dim(K): + """Embedding dim must be K * d_model.""" + torch.manual_seed(K) + d_model = 8 + m = AttentionAggregator( + n_blocks=4, d_in=6, d_model=d_model, n_heads=2, + n_layers=1, d_ff=16, dropout=0.0, n_pool_tokens=K, + ) + m.eval() + x = torch.randn(5, 4, 6) + with torch.no_grad(): + emb, pool_attn, _ = m.encode(x) + assert emb.shape == (5, K * d_model), ( + f"K={K}: expected emb shape (5, {K * d_model}), got {tuple(emb.shape)}" + ) + assert m.emb_dim == K * d_model + + +@pytest.mark.parametrize("K", [1, 4]) +def test_multi_pool_attn_shape_and_sums(K): + """pool_attn returned by encode() must be (batch, B) and sum to ~1 per subject.""" + torch.manual_seed(K + 10) + m = AttentionAggregator( + n_blocks=4, d_in=6, d_model=8, n_heads=2, + n_layers=1, d_ff=16, dropout=0.0, n_pool_tokens=K, + ) + m.eval() + x = torch.randn(7, 4, 6) + with torch.no_grad(): + _, pool_attn, _ = m.encode(x) + assert pool_attn.shape == (7, 4), ( + f"K={K}: expected pool_attn (7, 4), got {tuple(pool_attn.shape)}" + ) + sums = pool_attn.sum(dim=-1) + assert torch.allclose(sums, torch.ones(7), atol=1e-5), ( + f"K={K}: pool_attn row sums not ~1: {sums.tolist()}" + ) + + +@pytest.mark.parametrize("K", [1, 4]) +def test_multi_pool_forward_recon_shape(K): + """forward() must return recon of shape (batch, B, d_in) regardless of K.""" + torch.manual_seed(K + 20) + B, d_in = 4, 6 + m = AttentionAggregator( + n_blocks=B, d_in=d_in, d_model=8, n_heads=2, + n_layers=1, d_ff=16, dropout=0.0, n_pool_tokens=K, + ) + m.eval() + x = torch.randn(5, B, d_in) + with torch.no_grad(): + recon, emb, pool_attn = m(x, return_self_attn=False) + assert recon.shape == (5, B, d_in), ( + f"K={K}: expected recon (5, {B}, {d_in}), got {tuple(recon.shape)}" + ) + + +def test_multi_pool_by_token_shape(): + """_extract_pool_attn_by_token must return (N, K, B).""" + K, B, d_in = 4, 4, 6 + torch.manual_seed(99) + m = AttentionAggregator( + n_blocks=B, d_in=d_in, d_model=8, n_heads=2, + n_layers=1, d_ff=16, dropout=0.0, n_pool_tokens=K, + ) + m.eval() + rng = np.random.RandomState(9) + data = rng.randn(20, B, d_in).astype(np.float32) + out = _extract_pool_attn_by_token(m, data, batch_size=8) + assert out.shape == (20, K, B), f"Expected (20, {K}, {B}), got {out.shape}" + + +def test_k1_by_token_equals_mean_pool_attn(): + """For K=1, _extract_pool_attn_by_token squeezed must match encode() pool_attn.""" + torch.manual_seed(77) + B, d_in = 4, 6 + m = AttentionAggregator( + n_blocks=B, d_in=d_in, d_model=8, n_heads=2, + n_layers=1, d_ff=16, dropout=0.0, n_pool_tokens=1, + ) + m.eval() + rng = np.random.RandomState(8) + data = rng.randn(16, B, d_in).astype(np.float32) + by_token = _extract_pool_attn_by_token(m, data, batch_size=8) # (N, 1, B) + with torch.no_grad(): + _, pool_attn, _ = m.encode(torch.tensor(data)) # (N, B) + pool_attn_np = pool_attn.numpy() + assert np.allclose(by_token[:, 0, :], pool_attn_np, atol=1e-5), ( + "K=1 by_token[:, 0, :] must match encode() pool_attn" + ) diff --git a/tests/test_smoke_outputs.py b/tests/test_smoke_outputs.py index 5392a23..18af67a 100644 --- a/tests/test_smoke_outputs.py +++ b/tests/test_smoke_outputs.py @@ -15,10 +15,15 @@ _P1_OUT = _ROOT / "results" / "synthetic_test" _P2_OUT = _ROOT / "results" / "synthetic_test2" -_ALL_BLOCKS = _P1_OUT / "MSE" / "embeddings" / "all_blocks.npy" -_INDIV_EMB = _P2_OUT / "ORD" / "embeddings" / "individual_embeddings.npy" -_POOL_ATTN = _P2_OUT / "ORD" / "embeddings" / "pooling_attention_weights.csv" -_P2_SUMMARY = _P2_OUT / "phase2_summary.csv" +_ALL_BLOCKS = _P1_OUT / "MSE" / "embeddings" / "all_blocks.npy" +_INDIV_EMB = _P2_OUT / "ORD" / "embeddings" / "individual_embeddings.npy" +_POOL_ATTN = _P2_OUT / "ORD" / "embeddings" / "pooling_attention_weights.csv" +_P2_SUMMARY = _P2_OUT / "phase2_summary.csv" +_INITIAL_REPR = _P2_OUT / "ORD" / "embeddings" / "block_initial_repr.npy" +_CTX_CHANGE_BLOCK = _P2_OUT / "ORD" / "embeddings" / "per_block_contextualization_change.csv" +_P2_BLOCK_DIAG = _P2_OUT / "ORD" / "diagnostics" / "phase2_block_diagnostics.csv" +_PCA_SUMMARY = _P2_OUT / "ORD" / "baselines" / "pca_baseline_summary.csv" +_MP_SUMMARY = _P2_OUT / "ORD" / "baselines" / "mean_pool_baseline_summary.csv" @pytest.mark.skipif(not _ALL_BLOCKS.exists(), reason="Phase 1 synthetic outputs not generated") @@ -59,3 +64,64 @@ def test_phase2_summary_exists_and_nonempty(): import pandas as pd df = pd.read_csv(_P2_SUMMARY) assert len(df) > 0, "phase2_summary.csv is empty" + + +@pytest.mark.skipif( + not (_P2_SUMMARY.exists() and _INITIAL_REPR.exists()), + reason="Phase 2 synthetic outputs not generated (or pre-date diagnostics)" +) +def test_phase2_summary_diagnostic_columns(): + """phase2_summary.csv must contain the new diagnostic columns.""" + import pandas as pd + df = pd.read_csv(_P2_SUMMARY) + required = { + "mean_pool_attn_entropy", + "mean_context_delta_l2", + "pca_val_recon_loss", + "raw_mean_pool_val_recon_loss", + "embedhead_mean_pool_val_recon_loss", + } + missing = required - set(df.columns) + assert not missing, f"phase2_summary.csv missing columns: {missing}" + + +@pytest.mark.skipif(not _INITIAL_REPR.exists(), reason="Phase 2 synthetic outputs not generated") +def test_block_initial_repr_shape(): + """block_initial_repr.npy must be (N, B, d_model) = (30, 4, 16) for synthetic config.""" + import numpy as np + arr = np.load(_INITIAL_REPR) + assert arr.ndim == 3, f"Expected 3-D array, got shape {arr.shape}" + assert arr.shape[0] == 30, f"Expected 30 subjects, got {arr.shape[0]}" + + +@pytest.mark.skipif(not _CTX_CHANGE_BLOCK.exists(), reason="Phase 2 synthetic outputs not generated") +def test_per_block_ctx_change_has_block_rows(): + """per_block_contextualization_change.csv must have one row per block.""" + import pandas as pd + df = pd.read_csv(_CTX_CHANGE_BLOCK) + assert len(df) == 4, f"Expected 4 block rows, got {len(df)}" + assert "mean_context_delta_l2" in df.columns + + +@pytest.mark.skipif(not _P2_BLOCK_DIAG.exists(), reason="Phase 2 synthetic outputs not generated") +def test_phase2_block_diagnostics_exists(): + """phase2_block_diagnostics.csv must be non-empty.""" + import pandas as pd + df = pd.read_csv(_P2_BLOCK_DIAG) + assert len(df) > 0, "phase2_block_diagnostics.csv is empty" + + +@pytest.mark.skipif(not _PCA_SUMMARY.exists(), reason="Phase 2 synthetic outputs not generated") +def test_pca_baseline_summary_exists(): + """pca_baseline_summary.csv must exist and contain key columns.""" + import pandas as pd + df = pd.read_csv(_PCA_SUMMARY) + assert "pca_val_recon_loss" in df.columns + + +@pytest.mark.skipif(not _MP_SUMMARY.exists(), reason="Phase 2 synthetic outputs not generated") +def test_mean_pool_baseline_summary_exists(): + """mean_pool_baseline_summary.csv must exist and contain key columns.""" + import pandas as pd + df = pd.read_csv(_MP_SUMMARY) + assert "raw_mean_pool_val_recon_loss" in df.columns diff --git a/tests/test_vae_diagnostics.py b/tests/test_vae_diagnostics.py new file mode 100644 index 0000000..8929a4d --- /dev/null +++ b/tests/test_vae_diagnostics.py @@ -0,0 +1,151 @@ +""" +Tests for VAE latent-space diagnostics and free-bits support. +All tests use tiny synthetic tensors; no real data or checkpoints required. +""" +import math +import numpy as np +import pytest + +torch = pytest.importorskip("torch", reason="torch not installed") + +from scripts.core.VAE_phase1 import BlockVAE, compute_latent_diagnostics + + +# ── fixtures ────────────────────────────────────────────────────────────────── + +@pytest.fixture(scope="module") +def trained_vae(): + """Tiny MSE VAE with known weights (eval mode, no dropout).""" + torch.manual_seed(0) + model = BlockVAE(p=20, d=4, drop=0.0, loss_type="MSE") + model.eval() + return model + + +@pytest.fixture(scope="module") +def va_x(trained_vae): + torch.manual_seed(1) + return torch.rand(30, 20) * 2.0 # simulate genotype [0, 2] + + +# ── compute_latent_diagnostics: key presence and ranges ─────────────────────── + +def test_diag_required_keys(trained_vae, va_x): + diag = compute_latent_diagnostics(trained_vae, va_x) + required = { + "n_active_latents", "frac_dims_collapsed", "latent_underused", + "kl_per_dim_min", "kl_per_dim_median", "kl_per_dim_max", + "mu_var_min", "mu_var_median", "mu_var_max", + "sigma_min", "sigma_median", "sigma_max", + "_kl_per_dim", "_mu_var_per_dim", "_sigma_per_dim", + } + missing = required - diag.keys() + assert not missing, f"compute_latent_diagnostics missing keys: {missing}" + + +def test_n_active_latents_in_range(trained_vae, va_x): + diag = compute_latent_diagnostics(trained_vae, va_x) + d = 4 + assert 0 <= diag["n_active_latents"] <= d, ( + f"n_active_latents={diag['n_active_latents']} out of [0, {d}]" + ) + + +def test_frac_dims_collapsed_in_range(trained_vae, va_x): + diag = compute_latent_diagnostics(trained_vae, va_x) + fdc = diag["frac_dims_collapsed"] + assert 0.0 <= fdc <= 1.0, f"frac_dims_collapsed={fdc} out of [0,1]" + + +def test_frac_dims_collapsed_consistent_with_n_active(trained_vae, va_x): + diag = compute_latent_diagnostics(trained_vae, va_x) + d = 4 + expected_fdc = round(1.0 - diag["n_active_latents"] / d, 4) + assert math.isclose(diag["frac_dims_collapsed"], expected_fdc, abs_tol=1e-4), ( + f"frac_dims_collapsed={diag['frac_dims_collapsed']} inconsistent with " + f"n_active={diag['n_active_latents']} / d={d}" + ) + + +def test_latent_underused_is_bool(trained_vae, va_x): + diag = compute_latent_diagnostics(trained_vae, va_x) + assert isinstance(diag["latent_underused"], bool) + + +def test_diag_arrays_have_correct_length(trained_vae, va_x): + diag = compute_latent_diagnostics(trained_vae, va_x) + d = 4 + assert diag["_kl_per_dim"].shape == (d,), f"_kl_per_dim shape {diag['_kl_per_dim'].shape}" + assert diag["_mu_var_per_dim"].shape == (d,), f"_mu_var_per_dim shape {diag['_mu_var_per_dim'].shape}" + assert diag["_sigma_per_dim"].shape == (d,), f"_sigma_per_dim shape {diag['_sigma_per_dim'].shape}" + + +def test_kl_per_dim_nonneg(trained_vae, va_x): + """Per-dimension KL from N(mu,sigma) to N(0,1) must be >= 0 for a proper posterior.""" + diag = compute_latent_diagnostics(trained_vae, va_x) + kl = diag["_kl_per_dim"] + assert np.all(kl >= -1e-6), f"Negative per-dim KL: {kl}" + + +def test_sigma_per_dim_positive(trained_vae, va_x): + diag = compute_latent_diagnostics(trained_vae, va_x) + assert diag["sigma_min"] > 0, f"sigma_min={diag['sigma_min']} must be > 0" + + +# ── free-bits: zero disables, positive clamps ───────────────────────────────── + +def test_free_bits_zero_matches_standard_kl(): + """free_bits=0.0 must produce the same KL as the original formula.""" + torch.manual_seed(42) + model = BlockVAE(p=10, d=4, drop=0.0, loss_type="MSE") + model.eval() + x = torch.rand(8, 10) + recon, mu, lv = model(x) + + _, _, kl_std = model.compute_loss(x, recon, mu, lv, beta=0.5, free_bits=0.0) + # Reference: pure per-element mean over (B, d) — matches compute_loss exactly + kl_ref = (-0.5 * torch.mean(1 + lv - mu.pow(2) - lv.exp())).item() + assert math.isclose(kl_std.item(), kl_ref, rel_tol=1e-5), ( + f"free_bits=0.0 KL={kl_std.item():.6f} != reference {kl_ref:.6f}" + ) + + +def test_free_bits_positive_no_crash(): + torch.manual_seed(7) + model = BlockVAE(p=10, d=4, drop=0.0, loss_type="MSE") + x = torch.rand(8, 10) + recon, mu, lv = model(x) + loss, rl, kl = model.compute_loss(x, recon, mu, lv, beta=0.5, free_bits=0.5) + assert torch.isfinite(loss), f"loss is not finite with free_bits=0.5" + assert kl.item() >= 0.0, f"KL with free_bits should be >= 0" + + +def test_free_bits_positive_raises_kl(): + """With free_bits > 0, KL should be >= free_bits * d (all dims clamped at least).""" + torch.manual_seed(3) + model = BlockVAE(p=10, d=4, drop=0.0, loss_type="MSE") + # Force near-zero posterior: small weights → mu≈0, lv≈0 → KL per dim ≈ 0 + with torch.no_grad(): + for p in model.parameters(): + p.zero_() + x = torch.rand(8, 10) + recon, mu, lv = model(x) + free_bits = 0.3 + _, _, kl_fb = model.compute_loss(x, recon, mu, lv, beta=1.0, free_bits=free_bits) + assert kl_fb.item() >= free_bits * 4 - 1e-5, ( + f"KL={kl_fb.item():.4f} should be >= free_bits*d={free_bits*4:.4f}" + ) + + +# ── beta_eff_at_best is derivable from config (sanity check) ────────────────── + +def test_beta_eff_at_best_formula(): + beta_max = 0.5 + beta_warmup = 50 + best_epoch = 25 + expected = beta_max * min(1.0, best_epoch / beta_warmup) + assert math.isclose(expected, 0.25, rel_tol=1e-9) + + best_epoch_after = 100 + expected_full = beta_max * min(1.0, best_epoch_after / beta_warmup) + assert math.isclose(expected_full, beta_max, rel_tol=1e-9) From 674e2a812004acafc074788c85a48ba97c701968 Mon Sep 17 00:00:00 2001 From: Shraddha Piparia Date: Mon, 25 May 2026 07:41:38 -0700 Subject: [PATCH 2/2] Fix CAT loss class weight attribute in BlockVAE --- scripts/core/VAE_phase1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/core/VAE_phase1.py b/scripts/core/VAE_phase1.py index 411537d..a9a6f30 100644 --- a/scripts/core/VAE_phase1.py +++ b/scripts/core/VAE_phase1.py @@ -355,7 +355,7 @@ def compute_loss(self, x_in, recon, mu, lv, beta, y=None, free_bits=0.0): targets = y.reshape(-1) ce = F.cross_entropy(logits, targets, reduction="mean", - weight=self.ce_w) + weight=self.class_w) rl = ce elif self.loss_type == "ORD": if y is None: