# Fancy causal quartet
#
# Author: Andrew Heiss
#
# Published: September 6, 2024

library(tidyverse)
library(tinytable)
library(ggdag)
library(quartets)
library(MoMAColors)
library(patchwork)
library(ggtext)

# Plot setup ----

# Fonts: ggtext rich-text labels use Noto Sans and ggdag's node labels use
# the bold cut. GeomDagText is fetched with `:::`, i.e. straight from ggdag's
# namespace rather than through its exported functions.
update_geom_defaults(geom = "richtext", new = list(family = "Noto Sans"))
update_geom_defaults(
  geom = ggdag:::GeomDagText,
  new = list(family = "Noto Sans Bold")
)

# Colours: the MoMA "ustwo" palette, plus a lookup from each DAG node role to
# its colour (fed to `scale_color_manual()` in the DAG panels).
clrs <- MoMAColors::moma.colors("ustwo")

node_clrs <- c(
  Treatment = clrs[1],
  Outcome = clrs[6],
  `Third thing` = clrs[4],
  Other = "grey50"
)
# DAG panels ----

# Node layout shared by the three-node DAGs: x at bottom-left (1, 1),
# y at bottom-right (3, 1), z at top-centre (2, 2).
dag_coords_xyz <- list(
  x = c(x = 1, y = 3, z = 2),
  y = c(x = 1, y = 1, z = 2)
)

# Draw one DAG panel in the shared style.
#
# dag:    a DAG built with `dagify()`.
# title:  panel title.
# margin: optional `unit()` passed to `theme(plot.margin = ...)`; NULL leaves
#         the theme_dag() margin untouched.
#
# Nodes named x, y and z are coloured as "Treatment", "Outcome" and
# "Third thing"; any other node (u1/u2 in the M-bias DAG) becomes "Other",
# which `node_clrs` maps to grey. Returns the ggplot object.
plot_dag_panel <- function(dag, title, margin = NULL) {
  panel <- dag |>
    tidy_dagitty() |>
    mutate(var_type = case_match(
      name,
      "x" ~ "Treatment",
      "y" ~ "Outcome",
      "z" ~ "Third thing",
      .default = "Other"
    )) |>
    ggplot(aes(x = x, y = y, xend = xend, yend = yend)) +
    geom_dag_edges() +
    geom_dag_point(aes(color = var_type)) +
    geom_dag_text() +
    scale_color_manual(values = node_clrs, guide = "none") +
    # Extra vertical padding so node points are not clipped at the edges.
    scale_y_continuous(expand = expansion(mult = 0.2)) +
    labs(title = title) +
    theme_dag(base_family = "Noto Sans")

  if (!is.null(margin)) {
    panel <- panel + theme(plot.margin = margin)
  }

  panel
}

# (1) Collider: x causes y, and both x and y cause z.
plot_collider <- dagify(
  y ~ x,
  z ~ x + y,
  coords = dag_coords_xyz
) |>
  plot_dag_panel("(1) Collider")

# (2) Confounder: z causes both x and y, and x causes y.
plot_confounder <- dagify(
  y ~ x + z,
  x ~ z,
  coords = dag_coords_xyz
) |>
  plot_dag_panel("(2) Confounder", margin = unit(c(0, 0, 0, 2), "lines"))

# (3) Mediator: x affects y only through z.
plot_mediator <- dagify(
  y ~ z,
  z ~ x,
  coords = dag_coords_xyz
) |>
  plot_dag_panel("(3) Mediator")

# (4) M-bias: z is a collider of u1 (which causes x) and u2 (which causes y).
plot_m <- dagify(
  y ~ x + u2,
  x ~ u1,
  z ~ u1 + u2,
  coords = list(
    x = c(x = 1, y = 3, z = 2, u1 = 1.5, u2 = 2.5),
    y = c(x = 1, y = 1, z = 2, u1 = 2.5, u2 = 2.5)
  )
) |>
  plot_dag_panel("(4) M-bias", margin = unit(c(2, 0, 0, 2), "lines"))

# Arrange the four DAG panels in a 2 x 2 grid with patchwork: `|` places
# panels side by side within a row and `/` stacks the two rows.
(plot_collider | plot_confounder) / (plot_mediator | plot_m)

# Per-dataset summary statistics for the facet labels.
#
# `nest_by(dataset)` gives one row per dataset, holding that dataset's rows in
# the list-column `data`; the row-wise `mutate()` then computes:
#   naive_cor     - correlation between exposure (x) and covariate (z)
#   naive_effect  - exposure slope in outcome ~ exposure (unadjusted)
#   actual_effect - exposure slope in outcome ~ exposure + covariate
#                   (adjusted for z)
# Adjusting for z is only appropriate in the confounder DAG. In the collider
# and M-bias DAGs z is a collider, and in the mediator DAG z carries the whole
# x -> y effect. The label therefore reports the adjusted slope as "Adjusted
# effect" rather than calling it the true effect, and it names the
# correlation as cor(x, z) so it is not mistaken for the x-y correlation of
# the scatterplot it sits on.
quartet_labs <- causal_quartet |>
  nest_by(dataset) |>
  mutate(
    naive_cor = round(cor(data$exposure, data$covariate), 2),
    # Extract by coefficient name rather than position.
    naive_effect = round(
      coef(lm(outcome ~ exposure, data = data))[["exposure"]], 2
    ),
    actual_effect = round(
      coef(lm(outcome ~ exposure + covariate, data = data))[["exposure"]], 2
    )
  ) |>
  mutate(label = glue::glue("
  Cor(x, z): **{naive_cor}**<br>
  Naive effect: **{naive_effect}**<br>
  Adjusted effect: **{actual_effect}**
  ")) |>
  # Drop the rowwise grouping created by nest_by().
  ungroup()

# Faceted scatterplot of outcome against exposure for the quartet datasets.
# Each panel gets a linear (lm) fit line and that dataset's summary label
# from `quartet_labs`, left-aligned at x = -5, y = 3.
ggplot(causal_quartet, aes(x = exposure, y = outcome)) +
  geom_point() +
  geom_smooth(method = "lm", formula = y ~ x, color = clrs[2]) +
  # A zero-offset blurred shadow gives a soft halo behind the label box.
  # ggfx::with_shadow() documents the shadow colour argument as `colour`;
  # `color` is neither an exact nor a partial match for it, so it would not
  # set the shadow colour.
  ggfx::with_shadow(
    geom_richtext(
      data = quartet_labs, aes(x = -5, y = 3, label = label),
      hjust = 0, size = 3, label.color = NA
    ),
    x_offset = 0, y_offset = 0, sigma = 2, colour = "grey95"
  ) +
  labs(x = "Exposure (x)", y = "Outcome (y)") +
  # coord_cartesian() sets the visible x-range without dropping points; the
  # lower limit matches the labels' x = -5 anchor.
  coord_cartesian(xlim = c(-5, 4.5)) +
  facet_wrap(vars(dataset)) +
  theme_minimal(base_size = 12, base_family = "Noto Sans") +
  theme(
    axis.title.x = element_text(hjust = 0),
    axis.title.y = element_text(hjust = 1),
    strip.text = element_text(hjust = 0, size = rel(1)),
    panel.grid.minor = element_blank()
  )