Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/recipe/recipe.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ requirements:
- r-readr
- r-rfast
- r-rlang
- r-softimpute
- r-stringr
- r-susier
- r-tibble
Expand Down Expand Up @@ -131,6 +132,7 @@ requirements:
- r-readr
- r-rfast
- r-rlang
- r-softimpute
- r-stringr
- r-susier
- r-tibble
Expand Down
3 changes: 1 addition & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ Suggests:
SNPRelate,
snpStats,
testthat,
VariantAnnotation,
xgboost
VariantAnnotation
Remotes:
stephenslab/fsusieR,
stephenslab/mvsusieR,
Expand Down
6 changes: 4 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ export(bayes_c_weights)
export(bayes_l_weights)
export(bayes_n_weights)
export(bayes_r_weights)
export(build_mrmash_prior_matrices)
export(build_top_loci)
export(check_ld)
export(classify_variant_type)
Expand Down Expand Up @@ -176,9 +177,11 @@ export(mr_analysis)
export(mr_ash_rss_weights)
export(mr_format)
export(mrash_weights)
export(mrmash_rss_weights)
export(mrmash_weights)
export(mrmash_wrapper)
export(multivariate_analysis_pipeline)
export(mvsusie_rss_weights)
export(mvsusie_weights)
export(nSnps)
export(normalize_variant_id)
Expand Down Expand Up @@ -229,6 +232,7 @@ export(trim_ctwas_variants)
export(twas_analysis)
export(twas_joint_z)
export(twas_multivariate_weights_pipeline)
export(twas_multivariate_weights_sumstat_pipeline)
export(twas_pipeline)
export(twas_predict)
export(twas_weights)
Expand All @@ -240,7 +244,6 @@ export(univariate_analysis_pipeline)
export(update_mash_model_cov)
export(wald_test_pval)
export(writeSumstatsVcf)
export(xgboost_imputation)
export(xqtl_enrichment_wrapper)
export(z_to_pvalue)
exportClasses(AlleleQCResult)
Expand Down Expand Up @@ -320,7 +323,6 @@ import(dplyr)
import(tibble)
import(tidyr)
importFrom(BiocParallel,MulticoreParam)
importFrom(BiocParallel,SerialParam)
importFrom(BiocParallel,bplapply)
importFrom(BiocParallel,bpparam)
importFrom(BiocParallel,bpworkers)
Expand Down
108 changes: 0 additions & 108 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -850,114 +850,6 @@ filter_molecular_events <- function(events, filters, condition = NULL, remove_al
return(filtered_events)
}

#' XGBoost-based iterative imputation of missing values
#'
#' Imputes missing values in a numeric matrix by iteratively training
#' per-column XGBoost models on observed entries and predicting missing ones.
#' Columns that are entirely missing are removed. Initial imputation uses
#' column means. Iteration stops when the relative squared change between
#' successive imputed matrices stops decreasing, or after \code{maxiter}
#' iterations.
#'
#' @param data Numeric matrix with missing values (NA).
#' @param maxiter Maximum number of imputation iterations (default 10).
#'   When smaller than 1, no model is fitted and the mean-imputed matrix is
#'   returned.
#' @param max_depth Maximum tree depth for XGBoost (default 2).
#' @param nrounds Number of boosting rounds per variable (default 50).
#' @param decreasing Logical. If TRUE, impute variables with most missing
#'   values first. Default FALSE (fewest missing first).
#' @param num_workers Number of parallel workers for BiocParallel. Default 1
#'   (sequential).
#' @param verbose Logical, print progress (default TRUE).
#' @return The imputed matrix with the same dimensions as the input (minus
#'   any all-NA columns). If the stopping criterion fired before
#'   \code{maxiter} was reached, the matrix from the iteration preceding the
#'   last one is returned; otherwise the matrix from the final iteration.
#' @importFrom BiocParallel MulticoreParam SerialParam bplapply
#' @export
xgboost_imputation <- function(data, maxiter = 10L, max_depth = 2L,
                               nrounds = 50L, decreasing = FALSE,
                               num_workers = 1L, verbose = TRUE) {
  if (!requireNamespace("xgboost", quietly = TRUE)) {
    stop("Package 'xgboost' is required for xgboost_imputation")
  }

  xmis <- as.matrix(data)
  n <- nrow(xmis)

  # Remove completely missing columns
  all_na <- colSums(is.na(xmis)) == n
  if (any(all_na)) {
    if (verbose) {
      message("Removed ", sum(all_na), " column(s) with all entries missing.")
    }
    xmis <- xmis[, !all_na, drop = FALSE]
  }
  p <- ncol(xmis)

  # Initial mean imputation
  ximp <- xmis
  col_means <- colMeans(xmis, na.rm = TRUE)
  for (j in seq_len(p)) {
    ximp[is.na(xmis[, j]), j] <- col_means[j]
  }

  # Missing value locations and the order in which columns are imputed
  NAloc <- is.na(xmis)
  noNAvar <- colSums(NAloc)
  sort_j <- order(noNAvar, decreasing = decreasing)
  nzsort_j <- sort_j[noNAvar[sort_j] > 0]

  if (length(nzsort_j) == 0) {
    if (verbose) message("No missing values to impute.")
    return(ximp)
  }

  # With no iterations requested there is nothing to refine; the previous
  # implementation failed here with a subscript error on an empty history.
  if (maxiter < 1L) {
    return(ximp)
  }

  # Set up BiocParallel
  if (num_workers > 1L) {
    BPPARAM <- BiocParallel::MulticoreParam(workers = num_workers)
  } else {
    BPPARAM <- BiocParallel::SerialParam()
  }

  # Fit one column on its observed rows and predict its missing rows. `ximp`
  # is looked up in the enclosing frame when the function is called, so every
  # iteration sees the matrix as it stood at the start of that iteration.
  impute_one <- function(var_idx) {
    misi <- NAloc[, var_idx]
    obsi <- !misi
    obsY <- ximp[obsi, var_idx]
    obsX <- ximp[obsi, -var_idx, drop = FALSE]
    misX <- ximp[misi, -var_idx, drop = FALSE]

    xgb_train <- xgboost::xgb.DMatrix(data = obsX, label = obsY)
    xgb_pred <- xgboost::xgb.DMatrix(data = misX)
    model <- xgboost::xgb.train(
      params = list(max_depth = max_depth, verbosity = 0),
      data = xgb_train, nrounds = nrounds
    )
    list(var_idx = var_idx, predicted = predict(model, xgb_pred))
  }

  iter <- 0L
  conv_new <- 0
  conv_old <- Inf
  # Only the previous iteration is ever returned, so a single copy is kept
  # instead of a history of `maxiter` full matrices.
  ximp_old <- ximp

  while (conv_new < conv_old && iter < maxiter) {
    if (iter > 0L) conv_old <- conv_new
    if (verbose) message("  XGBoost iteration ", iter + 1L, " in progress...")

    ximp_old <- ximp

    results <- BiocParallel::bplapply(nzsort_j, impute_one, BPPARAM = BPPARAM)
    for (res in results) {
      ximp[NAloc[, res$var_idx], res$var_idx] <- res$predicted
    }

    iter <- iter + 1L

    # Convergence: relative change in imputed values
    sq_change <- sum((ximp - ximp_old)^2)
    conv_new <- sq_change / sum(ximp^2)
    if (is.na(conv_new)) {
      # 0/0 (all-zero matrix that did not change) counts as converged; any
      # other undefined ratio ends the loop instead of crashing `while`.
      conv_new <- if (isTRUE(sq_change == 0)) 0 else Inf
    }
  }

  # Return last improving iteration: the current matrix when the iteration
  # budget was exhausted (or only one iteration ran), otherwise the matrix
  # from before the final, non-improving iteration.
  if (iter == maxiter || iter <= 1L) ximp else ximp_old
}

#' Robust Mahalanobis Distance
#'
Expand Down
149 changes: 96 additions & 53 deletions R/mrmash_wrapper.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,25 +135,16 @@ mrmash_wrapper <- function(X,
)
}

prior_grid <- compute_grid(bhat = sumstats$Bhat, sbhat = sumstats$Shat)

# Compute canonical matrices, if requested
if (isTRUE(canonical_prior_matrices)) {
canonical_prior_matrices <- mr.mashr::compute_canonical_covs(ncol(Y),
singletons = TRUE,
hetgrid = c(0, 0.25, 0.5, 0.75, 1)
)
if (!is.null(data_driven_prior_matrices)) {
S0_raw <- c(canonical_prior_matrices, data_driven_prior_matrices$U)
} else {
S0_raw <- canonical_prior_matrices
}
} else {
S0_raw <- data_driven_prior_matrices$U
}

# Compute prior covariance
S0 <- mr.mashr::expand_covs(S0_raw, prior_grid, zeromat = TRUE)
# Build prior covariance via shared helper (also used by mrmash_rss_weights)
prior_built <- build_mrmash_prior_matrices(
Bhat = sumstats$Bhat, Shat = sumstats$Shat,
K = ncol(Y),
data_driven_prior_matrices = data_driven_prior_matrices,
canonical_prior_matrices = canonical_prior_matrices,
prior_grid = prior_grid
)
S0 <- prior_built$S0
prior_grid <- prior_built$prior_grid
time1 <- proc.time()

if (B_init_method == "glasso") {
Expand Down Expand Up @@ -372,46 +363,36 @@ autoselect_mixsd <- function(gmin, gmax, mult = 2) {
#' Compute covariance matrix using FLASH
#'
#' Estimates a covariance matrix from a data matrix Y using empirical Bayes
#' matrix factorization (flashier). Falls back to an identity matrix on failure
#' if error_cache is provided.
#' matrix factorization (\code{flashier::flash}). When the FLASH fit finds
#' no shared factors, the returned covariance is diagonal with entries
#' \code{residuals_sd^2}; otherwise the factor contribution is added.
#' FLASH errors are not caught; callers should handle them explicitly or
#' supply a pre-computed prior covariance instead.
#'
#' @param Y Numeric matrix (samples x conditions).
#' @param error_cache Optional file path to save diagnostics on FLASH failure.
#' When NULL (default), errors propagate; when set, saves a list with data
#' and message to this path and falls back to the identity matrix.
#' @return A covariance matrix of dimension ncol(Y) x ncol(Y), rescaled by
#' column standard deviations.
#' the column standard deviations of Y.
#' @export
compute_cov_flash <- function(Y, error_cache = NULL) {
covar <- diag(ncol(Y))
tryCatch({
fl <- flashier::flash(Y, var.type = 2,
prior.family = c(flashier::prior.normal(),
flashier::prior.normal.scale.mix()),
backfit = TRUE, verbose.lvl = 0)
if (fl$n.factors == 0) {
covar <- diag(fl$residuals.sd^2)
} else {
fsd <- sapply(fl$fitted.g[[1]], "[[", "sd")
covar <- diag(fl$residuals.sd^2) + crossprod(t(fl$flash.fit$EF[[2]]) * fsd)
}
if (nrow(covar) == 0) {
covar <- diag(ncol(Y))
stop("Computed covariance matrix has zero rows")
}
}, error = function(e) {
if (!is.null(error_cache)) {
saveRDS(list(data = Y, message = warning(e)), error_cache)
warning("FLASH failed. Using Identity matrix instead.")
warning(e)
} else {
stop(e)
}
})
compute_cov_flash <- function(Y) {
  # Estimate an ncol(Y) x ncol(Y) covariance matrix from a FLASH fit of Y and
  # rescale it so its standard deviations equal the column SDs of Y.
  #   Y: numeric matrix (samples x conditions).
  # Errors raised by flashier are not caught here.
  #
  # flashier >= 1.0 API: var_type / ebnm_fn / verbose (renamed from var.type
  # / prior.family / verbose.lvl). Prior families now come from `ebnm`.
  # The first ebnm function is the prior on the loadings (rows of Y), the
  # second the prior on the factors (columns of Y).
  fl <- flashier::flash(
    Y,
    var_type = 2,
    ebnm_fn = c(ebnm::ebnm_normal, ebnm::ebnm_normal_scale_mixture),
    backfit = TRUE,
    verbose = 0
  )

  # `nrow` is explicit because diag(x) with a length-one numeric x builds an
  # identity matrix of size as.integer(x) rather than a 1 x 1 matrix.
  resid_cov <- diag(fl$residuals_sd^2, nrow = length(fl$residuals_sd))

  if (fl$n_factors == 0) {
    covar <- resid_cov
  } else {
    # For a row y = F l + e, Cov(y) = F diag(Var(l_k)) F' + diag(resid^2), so
    # each factor is scaled by the prior SD of its *loadings*. The pre-1.0
    # code read that SD from `fitted.g[[1]]`; the matching field is L_ghat.
    # sqrt(sum(pi * sd^2)) is the marginal SD of a mean-zero normal mixture
    # and reduces to `sd` for a single normal component.
    # NOTE(review): this replaces F_ghat with L_ghat -- confirm against a
    # flashier >= 1.0 fit before merging.
    loading_sd <- vapply(
      fl$L_ghat,
      function(g) sqrt(sum(g$pi * g$sd^2)),
      numeric(1)
    )
    covar <- resid_cov + crossprod(t(fl$F_pm) * loading_sd)
  }

  if (nrow(covar) == 0) {
    stop("compute_cov_flash: FLASH produced an empty covariance matrix.")
  }

  # Rescale the implied correlation matrix by the observed column SDs.
  s <- apply(Y, 2, sd, na.rm = TRUE)
  s <- diag(s, nrow = length(s))
  s %*% cov2cor(covar) %*% s
}

#' Compute diagonal covariance matrix
Expand All @@ -424,3 +405,65 @@ compute_cov_flash <- function(Y, error_cache = NULL) {
compute_cov_diag <- function(Y) {
  # Diagonal covariance matrix holding the per-column sample variances of Y
  # (missing values are ignored); the result is ncol(Y) x ncol(Y).
  col_vars <- apply(Y, 2, var, na.rm = TRUE)
  # `nrow` is explicit because diag(x) with a length-one numeric x builds an
  # identity matrix of size as.integer(x) instead of a 1 x 1 matrix holding
  # x, which silently broke the single-column case.
  diag(col_vars, nrow = length(col_vars))
}

#' Build mr.mash prior covariance matrices
#'
#' Shared helper used by both \code{\link{mrmash_wrapper}} (individual-level)
#' and \code{\link{mrmash_rss_weights}} (summary statistics). Constructs the
#' \code{S0} list of prior covariance matrices via the canonical mixture
#' (\code{mr.mashr::compute_canonical_covs}) and optional data-driven
#' matrices, expanded over a scaling grid via
#' \code{mr.mashr::expand_covs}. The prior grid is derived from \code{Bhat}
#' and \code{Shat} via \code{\link{compute_grid}} when not supplied.
#'
#' @param Bhat Numeric matrix of effect-size estimates (variants x conditions).
#'   Only used when \code{K} or \code{prior_grid} is NULL.
#' @param Shat Numeric matrix of standard errors (variants x conditions).
#'   Only used when \code{prior_grid} is NULL.
#' @param K Number of conditions. When NULL, inferred from \code{ncol(Bhat)}.
#' @param data_driven_prior_matrices Optional list with element \code{U}
#'   (list of raw covariance matrices) computed e.g. by
#'   \code{\link{compute_cov_flash}} / \code{\link{compute_cov_diag}}.
#' @param canonical_prior_matrices Logical. When TRUE (the default),
#'   include the standard canonical mixture from
#'   \code{mr.mashr::compute_canonical_covs()}. When not TRUE,
#'   \code{data_driven_prior_matrices} must be supplied and its \code{U}
#'   element must be non-empty.
#' @param prior_grid Optional pre-computed scaling grid (numeric vector).
#'   When NULL, derived from \code{Bhat}, \code{Shat} via
#'   \code{compute_grid()}.
#' @param hetgrid Heterogeneity grid passed to
#'   \code{mr.mashr::compute_canonical_covs()}. Default
#'   \code{c(0, 0.25, 0.5, 0.75, 1)}, matching the individual-level wrapper.
#' @param singletons Whether to include single-condition prior components.
#'   Default TRUE.
#' @return A list with components \code{S0} (the expanded list of prior
#'   covariance matrices) and \code{prior_grid} (the scaling grid that was
#'   used).
#' @export
build_mrmash_prior_matrices <- function(Bhat, Shat, K = NULL,
                                        data_driven_prior_matrices = NULL,
                                        canonical_prior_matrices = TRUE,
                                        prior_grid = NULL,
                                        hetgrid = c(0, 0.25, 0.5, 0.75, 1),
                                        singletons = TRUE) {
  use_canonical <- isTRUE(canonical_prior_matrices)

  # Validate the arguments first so that a bad call is reported as such,
  # whether or not the optional dependency is installed.
  if (is.null(data_driven_prior_matrices) && !use_canonical) {
    stop("Supply data_driven_prior_matrices or set canonical_prior_matrices = TRUE.")
  }
  data_driven <- data_driven_prior_matrices$U
  if (!use_canonical && length(data_driven) == 0L) {
    # Without this guard an absent or empty `U` reaches expand_covs() as the
    # only source of prior matrices.
    stop(
      "data_driven_prior_matrices$U must be a non-empty list of covariance ",
      "matrices when canonical_prior_matrices is not TRUE."
    )
  }

  if (!requireNamespace("mr.mashr", quietly = TRUE)) {
    stop("Package 'mr.mashr' is required.")
  }

  if (is.null(K)) K <- ncol(Bhat)
  if (is.null(prior_grid)) prior_grid <- compute_grid(bhat = Bhat, sbhat = Shat)

  if (use_canonical) {
    canonical <- mr.mashr::compute_canonical_covs(
      K,
      singletons = singletons,
      hetgrid = hetgrid
    )
    # c(list, NULL) leaves the list unchanged, so a missing `U` simply
    # yields the canonical matrices on their own.
    S0_raw <- c(canonical, data_driven)
  } else {
    S0_raw <- data_driven
  }

  S0 <- mr.mashr::expand_covs(S0_raw, prior_grid, zeromat = TRUE)
  list(S0 = S0, prior_grid = prior_grid)
}
Loading