Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ losing local genomic context. This project takes a hierarchical approach:

1. **Phase 1 — Per-block β-VAE.** Each LD block is encoded independently into a
low-dimensional latent vector that captures local haplotype structure.
2. **Phase 2 — Cross-block Transformer.** A Transformer aggregates all block embeddings
2. **Phase 2 — Cross-block Attention.** Phase 2 aggregates all block embeddings
into a single subject-level representation. Learned attention weights identify which
blocks are most informative for organizing genetic variation across individuals.

Expand Down
36 changes: 22 additions & 14 deletions configs/config_phase1.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
data:
raw_dir: "data/region_blocks"
block_def: "data/block_plan/manifest.tsv"
output_dir: "results/output_regions"
output_dir: "results/output_regions_ord_weighted"
runtime:
device: "cpu"
vae:
Expand All @@ -12,28 +12,36 @@ vae:
- [150, 299, 12]
- [300, 799, 16]
- [800, 1000000000, 16]
dropout: 0.3
lr: 0.001
dropout: 0.1
lr: 0.002
batch_size: 64
epochs: 200
beta_max: 0.5
beta_max: 1 # tuned: best for both ORD and MSE after per-element-mean normalization
beta_warmup: 50
patience: 20
patience: 5
grad_clip: 1.0
val_frac: 0.2
val_frac: 0.20
test_frac: 0.10
seed: 42
cat_weight_clip: 10.0
loss_functions: ["MSE","ORD","MSE_STD","CAT","BCE"]
# loss_functions: ["ORD"]
# VAE diagnostics / behavior-preserving defaults
free_bits: 0.05
ld_corr_repeats: 1
ld_corr_max_snps: 200
ord_weighted: true # apply per-sample class weights to ordinal_loss
ord_weight_clip: 10.0 # same convention as cat_weight_clip
standardize_embeddings: true
# loss_functions: ["MSE","ORD","MSE_STD","CAT","BCE"]
loss_functions: ["ORD"]
tuning:
enabled: true
loss: "ORD" # loss function to tune on
metric: "bal_acc_va" # or "best_val_loss", "ld_corr_va"
blocks: ["region_17q21_core_sb1", "region_11q13_FCER1A","region_2q12_IL1RL1_cluster_sb3","region_5q21_PDE4D_sb55", "region_6p21_HLA_classII_sb1","region_5q31_type2_cytokine_sb9", "region_1q31_TNFSF_cluster_sb4", "control_OCA2_sb10"] # list of block_ids to tune on, empty means all
loss: "ORD"
metric: "concordance_gain_va"
blocks: ["region_17q21_core_sb1", "region_11q13_FCER1A","region_2q12_IL1RL1_cluster_sb3","region_5q21_PDE4D_sb55", "region_6p21_HLA_classII_sb1","region_5q31_type2_cytokine_sb9", "region_1q31_TNFSF_cluster_sb4", "control_OCA2_sb10"]
params:
dropout: [0.1, 0.3, 0.5]
lr: [0.0005, 0.001, 0.002]
beta_max: [0.3, 0.5]
dropout: [0.1, 0.3]
lr: [0.001, 0.002]
beta_max: [0.5, 1.0, 2.0, 4.0]
representative:
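blocks: ["region_17q21_core_sb1", "region_11q13_FCER1A","region_2q12_IL1RL1_cluster_sb3","region_5q21_PDE4D_sb55", "region_6p21_HLA_classII_sb1","region_5q31_type2_cytokine_sb9", "region_1q31_TNFSF_cluster_sb4", "control_OCA2_sb10"]  # block_ids considered for representative selection; NOTE(review): earlier wording ("to tune on, empty means all") looks copied from tuning.blocks — confirm whether an empty list means all here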
metric: "bal_acc_va" # metric for top/bottom selection
Expand Down
15 changes: 8 additions & 7 deletions configs/config_phase2.yaml
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
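phase1_dir: results/output_regions  # NOTE(review): Phase 1 output_dir is now "results/output_regions_ord_weighted" — confirm this path is still the intended Phase 1 source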
output_dir: results/output_regions2
attention:
d_model: 64
d_model: 128
n_heads: 4
n_layers: 2
d_ff: 128
dropout: 0.1
d_ff: 512
n_pool_tokens: 4
dropout: 0.30
epochs: 150
patience: 20
lr: 0.0005
weight_decay: 0.0001
weight_decay: 0.001
batch_size: 64
epochs: 300
patience: 30
grad_clip: 1.0
seed: 42
extract_self_attn: true
Expand All @@ -23,4 +24,4 @@ clustering:
umap_n_neighbors: 15
umap_min_dist: 0.1
umap_seed: 42
loss_functions: [ORD]
loss_functions: [ORD_W_Scaled]
Loading
Loading