Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,19 @@
- Added support for the RFCI algorithm via `rfci()`, interfacing with
implementations from pcalg and Tetrad.

- The infix edge operators `%-->%` and `%!-->%` in `knowledge()` now accept `+`
on both sides to specify multiple variables, e.g. `A + B %-->% C + D` is
equivalent to `c(A, B) %-->% c(C, D)`.

- The infix edge operators `%-->%` and `%!-->%` in `knowledge()` now support
tidyselect set operations such as `!`, `&`, and `|` on either side, e.g.
`child_x1 %!-->% !starts_with("child")`.

## Deprecated

- The `summary()` methods for `Knowledge` and `Disco` objects are deprecated;
use `print()` instead.

## Bug fixes

- Fixed `knowledge()` dropping variables from `tier()` formulas when infix edge
Expand All @@ -14,6 +27,15 @@

## Improvements

- The `print()` methods for `Knowledge` and `Disco` objects now give a more
concise and readable summary.

- `Disco` objects now store the actual learned graph class in a `graph_class`
attribute.

- `Knowledge` objects generated from `knowledge()` are now verified to ensure
that the required edges do not contain a directed cycle.

- Improved the documentation.

- Reduced the number of package dependencies.
Expand Down
282 changes: 248 additions & 34 deletions R/disco-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,85 @@ as_disco.pcAlgo <- function(
if (!is_knowledge(kn)) {
stop("`kn` must be a Knowledge object.", call. = FALSE)
}
cg <- caugi::as_caugi(graph@graph, collapse = TRUE, class = class)
# Convert via the adjacency matrix rather than the `graphNEL` slot so that
# bidirected conflict edges (\pkg{pcalg} amat code 2, produced with
# `solve.confl = TRUE`) are preserved as `<->` instead of being collapsed to
# undirected edges.
amat <- methods::as(graph, "amat")
nodes <- colnames(amat)
edges <- .pcalg_amat_to_edges(amat, nodes)

if (nrow(edges) == 0L) {
cg <- caugi::caugi(nodes = nodes, class = class)
} else {
cg_class <- if (any(edges$edge == "<->")) "UNKNOWN" else class
cg <- caugi::caugi(
from = edges$from,
edge = edges$edge,
to = edges$to,
nodes = nodes,
class = cg_class
)
}
new_disco(cg, kn)
}

#' @title Convert a \pkg{pcalg} CPDAG Adjacency Matrix to caugi Edges
#'
#' @description
#' Decodes a \pkg{pcalg} `amat` of type `"cpdag"` into a data frame of caugi
#' edge triplets. Every unordered pair `(i, j)` with `i < j` is inspected once
#' and classified as follows:
#' \itemize{
#'   \item both entries `0`: no edge;
#'   \item either entry `2` (the bidirected conflict code, only produced with
#'     `solve.confl = TRUE`): `i <-> j`;
#'   \item `amat[i,j] = 1, amat[j,i] = 1`: `i --- j`;
#'   \item `amat[i,j] = 1, amat[j,i] = 0`: `j --> i`;
#'   \item anything else (in particular `amat[i,j] = 0, amat[j,i] = 1`):
#'     `i --> j`.
#' }
#' Edges are returned in row-by-row pair order: `(1,2), (1,3), ..., (2,3), ...`.
#'
#' @param amat A \pkg{pcalg} adjacency matrix of type `"cpdag"`. Its
#'   off-diagonal entries must not be `NA`.
#' @param nodes The node names (column names of `amat`).
#' @returns A data frame with character columns `from`, `edge`, and `to`
#'   (zero rows when there are fewer than two nodes or no edges).
#' @keywords internal
#' @noRd
.pcalg_amat_to_edges <- function(amat, nodes) {
  nodes <- as.character(nodes)
  np <- length(nodes)

  if (np < 2L) {
    return(data.frame(
      from = character(0),
      edge = character(0),
      to = character(0),
      stringsAsFactors = FALSE
    ))
  }

  # Enumerate all pairs (i, j) with i < j, row by row, so the resulting edge
  # order is (1,2), (1,3), ..., (1,np), (2,3), ...
  first <- seq_len(np - 1L)
  i <- rep(first, times = np - first)
  j <- unlist(
    lapply(first, function(k) seq.int(k + 1L, np)),
    use.names = FALSE
  )

  # Matrix indexing pulls amat[i, j] and amat[j, i] for every pair at once,
  # avoiding repeated vector growth inside a double loop.
  a <- amat[cbind(i, j)]
  b <- amat[cbind(j, i)]
  if (anyNA(a) || anyNA(b)) {
    stop("`amat` must not contain missing values.", call. = FALSE)
  }

  # Drop the pairs that are not adjacent.
  adjacent <- !(a == 0 & b == 0)
  i <- i[adjacent]
  j <- j[adjacent]
  a <- a[adjacent]
  b <- b[adjacent]

  # Classify in priority order: bidirected, undirected, j -> i, else i -> j.
  bidirected <- a == 2 | b == 2
  undirected <- !bidirected & a == 1 & b == 1
  into_i <- !bidirected & a == 1 & b == 0

  edge <- rep("-->", length(i))
  edge[bidirected] <- "<->"
  edge[undirected] <- "---"

  from <- nodes[i]
  to <- nodes[j]
  # amat[i,j] = 1, amat[j,i] = 0 => j -> i, so swap the endpoints.
  from[into_i] <- nodes[j[into_i]]
  to[into_i] <- nodes[i[into_i]]

  data.frame(from = from, edge = edge, to = to, stringsAsFactors = FALSE)
}

#' @inheritParams as_disco
#' @export
as_disco.fciAlgo <- function(
Expand Down Expand Up @@ -184,7 +259,7 @@ as_disco.EssGraph <- function(

#' @title Print a Disco Object
#' @param x A `Disco` object.
#' @inheritParams print.Knowledge
#' @param ... Additional arguments (not used).
#' @returns Invisibly returns the `Disco` object.
#' @examples
#' data(tpc_example)
Expand All @@ -199,43 +274,186 @@ as_disco.EssGraph <- function(
#' cd_tges <- tpc(engine = "causalDisco", test = "fisher_z")
#' disco_cd_tges <- disco(data = tpc_example, method = cd_tges, knowledge = kn)
#' print(disco_cd_tges)
#' print(disco_cd_tges, wide_vars = TRUE)
#' print(disco_cd_tges, compact = TRUE)
#'
#' @exportS3Method print Disco
print.Disco <- function(
x,
compact = FALSE,
wide_vars = FALSE,
...
) {
print.Disco <- function(x, ...) {
.check_if_pkgs_are_installed(
pkgs = c("cli", "tibble"),
pkgs = c("cli"),
function_name = "print.Disco"
)

cli::cli_h1("caugi graph")

# Graph info
graph_class <- x$caugi@graph_class
cg <- x$caugi
graph_class <- x$graph_type
if (is.null(graph_class)) {
graph_class <- cg@graph_class
}
nd <- nodes(cg)
ed <- edges(cg)
n_nodes <- nrow(nd)
n_edges <- nrow(ed)

cli::cli_text("Graph class: {.strong {graph_class}}")
kn <- x$knowledge
n_tiers <- if (!is.null(kn$tiers)) nrow(kn$tiers) else 0L
n_vars <- if (!is.null(kn$vars)) nrow(kn$vars) else 0L
n_required <- sum(kn$edges$status == "required", na.rm = TRUE)
n_forbidden <- sum(kn$edges$status == "forbidden", na.rm = TRUE)
kn_has_content <- n_tiers > 0L || n_required > 0L || n_forbidden > 0L

cg <- x$caugi
if (compact) {
cli::cli_text("{nrow(edges(cg))} edges, {nrow(nodes(cg))} nodes")
kn_parts <- character(0)
if (n_tiers > 0L) {
kn_parts <- c(kn_parts, paste0(n_tiers, " tier", if (n_tiers != 1L) "s"))
}
if (n_required > 0L) {
kn_parts <- c(kn_parts, paste0(n_required, " required"))
}
if (n_forbidden > 0L) {
kn_parts <- c(kn_parts, paste0(n_forbidden, " forbidden"))
}
kn_str <- if (length(kn_parts) > 0L) {
paste0(" | Knowledge: ", paste(kn_parts, collapse = ", "))
} else {
print_section("Edges", edges(cg))
print_section("Nodes", nodes(cg))
""
}

# Knowledge info
print.Knowledge(x$knowledge, compact = compact, wide_vars = wide_vars, ...)
cat(sprintf(
"<Disco %s: %d nodes | %d edges%s>\n",
graph_class,
n_nodes,
n_edges,
kn_str
))

if (kn_has_content) {
cat("Learned graph:\n")
}
.print_item_line("nodes", nd$name)
if (n_edges > 0L) {
.print_item_line("edges", paste0(ed$from, ed$edge, ed$to))
}

if (kn_has_content) {
cat("Knowledge:\n")
.print_knowledge_body(kn)
}

invisible(x)
}

# Report whether a knowledge object carries any content: at least one tier
# row, or at least one edge whose status is "required" or "forbidden".
# A `NULL` knowledge object has no content.
.knowledge_has_content <- function(kn) {
  if (is.null(kn)) {
    return(FALSE)
  }

  tier_tbl <- kn$tiers
  tier_count <- if (is.null(tier_tbl)) 0L else nrow(tier_tbl)

  # Missing statuses are ignored when counting edges.
  edge_status <- kn$edges$status
  required_count <- sum(edge_status == "required", na.rm = TRUE)
  forbidden_count <- sum(edge_status == "forbidden", na.rm = TRUE)

  any_tiers <- tier_count > 0L
  any_required <- required_count > 0L
  any_forbidden <- forbidden_count > 0L
  any_tiers || any_required || any_forbidden
}

# Map a graph class label to the label reported for a learned graph.
# "PDAG" becomes "MPDAG" when background knowledge was supplied and "CPDAG"
# otherwise; every other label is passed through; `NULL` maps to "UNKNOWN".
.disco_graph_type <- function(graph_class, has_knowledge) {
  if (!is.null(graph_class)) {
    reported <- switch(
      graph_class,
      PDAG = {
        if (has_knowledge) {
          "MPDAG"
        } else {
          "CPDAG"
        }
      },
      PAG = "PAG",
      `RFCI-PAG` = "RFCI-PAG",
      graph_class
    )
    return(reported)
  }
  "UNKNOWN"
}

#' @title Verify the Semantic Graph Class of a Learned Graph
#'
#' @description
#' A learned graph is not guaranteed to be a valid CPDAG/MPDAG: finite-sample
#' statistical errors, violations of faithfulness, or latent confounding can
#' all produce conflicting orientations. For a claimed class of `"CPDAG"` or
#' `"MPDAG"`, this helper tests the claim with `caugi::is_cpdag()` or
#' `caugi::is_mpdag()`. If the check does not return `TRUE` (including when it
#' errors), a message is emitted and the class is downgraded to `"PDAG"` when
#' `caugi::is_pdag()` returns `TRUE`, and to `"UNKNOWN"` otherwise. Any other
#' claimed class is returned unchanged without being checked.
#'
#' @param cg A [caugi::caugi] object.
#' @param claimed The semantic class proposed by `.disco_graph_type()`.
#' @param has_knowledge Whether background knowledge was supplied; only
#'   selects the explanation given in the message.
#'
#' @returns The verified semantic class: `claimed` if valid, otherwise `"PDAG"`
#' or `"UNKNOWN"`.
#' @keywords internal
#' @noRd
.validate_graph_type <- function(cg, claimed, has_knowledge) {
  # Only CPDAG/MPDAG claims are verified; everything else passes through.
  needs_check <- claimed %in% c("CPDAG", "MPDAG")
  if (!needs_check) {
    return(claimed)
  }

  # A failing validity check is treated as "not confirmed", not as an error.
  as_na <- function(e) NA

  claim_holds <- tryCatch(
    {
      if (claimed == "CPDAG") {
        caugi::is_cpdag(cg)
      } else {
        caugi::is_mpdag(cg)
      }
    },
    error = as_na
  )
  if (isTRUE(claim_holds)) {
    return(claimed)
  }

  pdag_ok <- tryCatch(caugi::is_pdag(cg), error = as_na)
  fallback <- "UNKNOWN"
  if (isTRUE(pdag_ok)) {
    fallback <- "PDAG"
  }

  if (has_knowledge) {
    reason <- "the background knowledge conflicts with the structure learned from the data"
  } else {
    reason <- "of conflicting edge orientations, which can happen due to statistical errors in finite samples, violations of faithfulness, or latent confounding"
  }

  template <- "The learned graph is not a valid %s because %s; it is reported as %s instead."
  message(sprintf(template, claimed, reason, fallback))

  fallback
}

# Print `items` as a comma-separated list after a "<label>: " prefix,
# wrapping onto continuation lines (padded to the prefix width) whenever
# appending the next item would exceed the console width option. At most
# `max_items` items are shown; the rest are summarised as
# "... and N more". Nothing is printed for empty `items`.
.print_item_line <- function(label, items, max_items = 10L) {
  if (!length(items)) {
    return(invisible())
  }

  line_width <- getOption("width", 80L)
  prefix <- paste0(" ", label, ": ")
  indent <- strrep(" ", nchar(prefix))

  # Truncate long lists and append a summary entry for the omitted items.
  n_hidden <- max(0L, length(items) - max_items)
  entries <- items
  if (n_hidden > 0L) {
    entries <- c(
      items[seq_len(max_items)],
      sprintf("... and %d more", n_hidden)
    )
  }

  # The first entry always goes on the prefix line; later entries are
  # appended with ", " or moved to a new padded line when they do not fit.
  finished <- character(0)
  current <- paste0(prefix, entries[[1L]])
  for (entry in entries[-1L]) {
    addition <- paste0(", ", entry)
    if (nchar(current) + nchar(addition) > line_width) {
      finished <- c(finished, current)
      current <- paste0(indent, entry)
    } else {
      current <- paste0(current, addition)
    }
  }
  finished <- c(finished, current)

  cat(paste0(paste(finished, collapse = "\n"), "\n"))
}

#' @title Summarize a Disco Object
#'
#' @description
#' `r lifecycle::badge("deprecated")`
#'
#' `summary()` for `Disco` objects is deprecated. Use [print()] instead.
#'
#' @param object A `Disco` object.
#' @param ... Additional arguments (not used).
#' @returns Invisibly returns the `Disco` object.
Expand All @@ -251,20 +469,16 @@ print.Disco <- function(
#' )
#' cd_tges <- tpc(engine = "causalDisco", test = "fisher_z")
#' disco_cd_tges <- disco(data = tpc_example, method = cd_tges, knowledge = kn)
#' summary(disco_cd_tges)
#' print(disco_cd_tges)
#'
#' @exportS3Method summary Disco
summary.Disco <- function(object, ...) {
cg <- object$caugi
# Graph info
cli::cli_h1("caugi graph summary")
cli::cli_text("Graph class: {.strong {cg@graph_class}}")
cli::cli_text("Nodes: {.strong {nrow(nodes(cg))}}")
cli::cli_text("Edges: {.strong {nrow(edges(cg))}}")

# Knowledge info
summary.Knowledge(object$knowledge, ...)

lifecycle::deprecate_warn(
when = "1.2.0",
what = "summary.Disco()",
with = "print.Disco()"
)
print(object, ...)
invisible(object)
}

Expand Down
Loading