| Title: | Bayesian Tree Ensembles for Survival Analysis and Causal Inference |
|---|---|
| Description: | Bayesian regression tree ensembles for survival analysis and causal inference. Implements BART, DART, Bayesian Causal Forests (BCF), and Horseshoe Forest models. Supports right-censored and interval-censored survival outcomes via accelerated failure time (AFT) formulations. Designed for high-dimensional prediction and heterogeneous treatment effect estimation. |
| Authors: | Tijn Jacobs [aut, cre] (ORCID: <https://orcid.org/0009-0003-6188-9296>) |
| Maintainer: | Tijn Jacobs <[email protected]> |
| License: | MIT + file LICENSE |
| Version: | 2.0.2 |
| Built: | 2026-06-03 06:59:32 UTC |
| Source: | https://github.com/tijn-jacobs/shrinkagetrees |
Converts the posterior draws stored in a ShrinkageTrees object into
an mcmc.list for use with the coda package's
convergence diagnostics (Gelman–Rubin \(\hat{R}\), effective sample
size, Geweke test, etc.).
## S3 method for class 'ShrinkageTrees'
as.mcmc.list(x, ...)
x |
A fitted model object of class "ShrinkageTrees". |
... |
Currently unused. |
Requires the suggested package coda. For single-chain fits the returned object contains one chain.
A mcmc.list object.
Each chain is an mcmc object whose columns include:
Posterior draws of the residual standard deviation (continuous and survival outcomes only).
summary.ShrinkageTrees which reports R-hat and ESS
automatically when coda is available.
# Fit a small two-chain model so between-chain diagnostics are defined
set.seed(1)
fit <- HorseTrees(
  y = rnorm(50),
  X_train = matrix(rnorm(250), 50, 5),
  N_post = 200,
  N_burn = 100,
  n_chains = 2
)
if (requireNamespace("coda", quietly = TRUE)) {
  mcmc_obj <- coda::as.mcmc.list(fit)
  # Inside `{ }` only the last value auto-prints, so print both explicitly
  print(coda::gelman.diag(mcmc_obj))
  print(coda::effectiveSize(mcmc_obj))
}
Post-hoc reweights the stored posterior CATE draws of a fitted causal
model to produce credible intervals for the population ATE (PATE)
that incorporate uncertainty in the covariate distribution
\(F_X\).
bayesian_bootstrap_ate(object, alpha = 0.05)
object |
Either a fitted |
alpha |
One minus the credible level. Default is 0.05 (a 95% credible interval). |
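At each MCMC iteration \(s = 1, \dots, S\) the conditional treatment effects
\(\tau^{(s)}(x_1), \dots, \tau^{(s)}(x_n)\)
are reweighted with Bayesian-bootstrap weights
\(w^{(s)} = (w_1^{(s)}, \dots, w_n^{(s)}) \sim \mathrm{Dirichlet}(1, \dots, 1)\)
to give a draw \(\mathrm{PATE}^{(s)} = \sum_{i=1}^{n} w_i^{(s)} \, \tau^{(s)}(x_i)\).
The collection \(\{\mathrm{PATE}^{(s)}\}_{s=1}^{S}\) approximates the
posterior of the PATE, integrating over \(\tau\) and
\(F_X\). The equal-weight mixed ATE (MATE),
\(\mathrm{MATE}^{(s)} = \frac{1}{n} \sum_{i=1}^{n} \tau^{(s)}(x_i)\),
is returned alongside for comparison.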
For reproducibility, call set.seed() before invoking the function
to fix the Dirichlet draws.
A list with
Posterior mean, credible
interval (named lower and upper), and full vector of
draws of the Bayesian-bootstrap PATE.
Same quantities for the equal-weight mixed ATE.
Number of observations and posterior draws used.
summary.CausalShrinkageForest,
plot.CausalShrinkageForest,
predict.CausalShrinkageForest
# Small toy causal model (continuous outcome; few trees and draws for speed)
set.seed(1)
n <- 40
p <- 3
X <- matrix(runif(n * p), ncol = p)
trt <- rbinom(n, 1, 0.5)
y <- X[, 1] + trt * (0.5 + X[, 2]) + rnorm(n)

fit <- CausalShrinkageForest(
  y = y,
  X_train_control = X,
  X_train_treat = X,
  treatment_indicator_train = trt,
  outcome_type = "continuous",
  number_of_trees_control = 5,
  number_of_trees_treat = 5,
  prior_type_control = "horseshoe",
  prior_type_treat = "horseshoe",
  local_hp_control = 0.1,
  global_hp_control = 0.1,
  local_hp_treat = 0.1,
  global_hp_treat = 0.1,
  N_post = 20,
  N_burn = 10,
  # Posterior CATE draws are needed for the bootstrap reweighting
  store_posterior_sample = TRUE,
  verbose = FALSE
)

bb <- bayesian_bootstrap_ate(fit, alpha = 0.05)
bb$pate_mean
bb$pate_ci
This function fits a (Bayesian) Causal Horseshoe Forest. It can be used to estimate conditional average treatment effects for survival data with high-dimensional covariates. The outcome is decomposed into a prognostic part (control) and a treatment effect part. For each of these, we specify a Horseshoe Trees regression function. Supports continuous, right-censored, and interval-censored outcomes.
CausalHorseForest(
  y = NULL,
  status = NULL,
  X_train_control,
  X_train_treat,
  treatment_indicator_train,
  X_test_control = NULL,
  X_test_treat = NULL,
  treatment_indicator_test = NULL,
  left_time = NULL,
  right_time = NULL,
  outcome_type = "continuous",
  timescale = "time",
  number_of_trees = 200,
  k = 0.1,
  power = 2,
  base = 0.95,
  p_grow = 0.4,
  p_prune = 0.4,
  nu = 3,
  q = 0.9,
  sigma = NULL,
  N_post = 5000,
  N_burn = 5000,
  delayed_proposal = 5,
  store_posterior_sample = FALSE,
  treatment_coding = "centered",
  propensity = NULL,
  propensity_test = NULL,
  n_chains = 1,
  verbose = TRUE
)
y |
Outcome vector. For survival, represents follow-up times (can be on
original or log scale depending on timescale). |
status |
Optional event indicator vector (1 = event occurred,
0 = censored). Required when outcome_type = "right-censored". |
X_train_control |
Covariate matrix for the control forest. Rows correspond to samples, columns to covariates. |
X_train_treat |
Covariate matrix for the treatment forest. Rows correspond to samples, columns to covariates. |
treatment_indicator_train |
Vector indicating treatment assignment for training samples (1 = treated, 0 = control). |
X_test_control |
Optional test covariate matrix for control forest. If
NULL, defaults to the column means of X_train_control. |
X_test_treat |
Optional test covariate matrix for treatment forest. If
NULL, defaults to the column means of X_train_treat. |
treatment_indicator_test |
Optional vector indicating treatment assignment for test samples. |
left_time |
Optional numeric vector of left (lower) time boundaries.
Required when outcome_type = "interval-censored". |
right_time |
Optional numeric vector of right (upper) time boundaries.
Required when outcome_type = "interval-censored". |
outcome_type |
Type of outcome: one of "continuous" (default), "right-censored", or "interval-censored". |
timescale |
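For survival outcomes: either "time" (follow-up times on the original scale; the default) or "log" (follow-up times already on the log scale). |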
number_of_trees |
Number of trees in each forest. Default is 200. |
k |
Horseshoe prior scale hyperparameter. Default is 0.1. Controls global-local shrinkage on step heights. |
power |
Power parameter for tree structure prior. Default is 2.0. |
base |
Base parameter for tree structure prior. Default is 0.95. |
p_grow |
Probability of proposing a grow move. Default is 0.4. |
p_prune |
Probability of proposing a prune move. Default is 0.4. |
nu |
Degrees of freedom for the error variance prior. Default is 3. |
q |
Quantile parameter for error variance prior. Default is 0.90. |
sigma |
Optional known standard deviation of the outcome. If
NULL, estimated from data. |
N_post |
Number of posterior samples to store. Default is 5000. |
N_burn |
Number of burn-in iterations. Default is 5000. |
delayed_proposal |
Number of delayed iterations before proposal updates. Default is 5. |
store_posterior_sample |
Logical; whether to store posterior samples of
predictions. Default is FALSE. |
treatment_coding |
Treatment coding scheme for the two-forest model.
One of |
propensity |
Optional numeric vector of propensity scores
|
propensity_test |
Optional numeric vector of propensity scores for
test observations. Only used when |
n_chains |
Number of independent MCMC chains to run. Default is
1. |
verbose |
Logical; whether to print verbose output during sampling.
Default is TRUE. |
The model separately regularizes the control and treatment trees using
Horseshoe priors with global-local shrinkage on the step heights.
This approach is designed for robust estimation of heterogeneous treatment
effects in high-dimensional settings.
It supports continuous, right-censored, and interval-censored survival
outcomes. For interval-censored data, provide left_time and
right_time instead of y and status; the event
indicators are derived internally following the
survival::Surv(type = "interval2") convention.
An S3 object of class "CausalShrinkageForest" containing:
Posterior mean predictions on training data (combined forest).
Posterior mean predictions on test data (combined forest).
Estimated control outcomes on training data.
Estimated control outcomes on test data.
Estimated treatment effects on training data.
Estimated treatment effects on test data.
Vector of posterior samples for the error standard deviation.
Average acceptance ratio in control forest.
Average acceptance ratio in treatment forest.
Matrix of posterior samples for
control predictions on training data (if store_posterior_sample = TRUE).
Matrix of posterior samples for
control predictions on test data (if store_posterior_sample = TRUE).
Matrix of posterior samples for
treatment effects on training data (if store_posterior_sample = TRUE).
Matrix of posterior samples for
treatment effects on test data (if store_posterior_sample = TRUE).
Model family: HorseTrees (non-causal, horseshoe prior),
ShrinkageTrees (non-causal, flexible prior),
CausalShrinkageForest (causal, flexible prior).
Survival wrappers: SurvivalBCF, SurvivalShrinkageBCF.
S3 methods: print.CausalShrinkageForest,
summary.CausalShrinkageForest,
predict.CausalShrinkageForest,
plot.CausalShrinkageForest.
# Example: Continuous outcome and homogeneous treatment effect
n <- 50
p <- 3
X_control <- matrix(runif(n * p), ncol = p)
X_treat <- matrix(runif(n * p), ncol = p)
treatment <- rbinom(n, 1, 0.5)
tau <- 2
y <- X_control[, 1] + (0.5 - treatment) * tau + rnorm(n)

fit <- CausalHorseForest(
  y = y,
  X_train_control = X_control,
  X_train_treat = X_treat,
  treatment_indicator_train = treatment,
  outcome_type = "continuous",
  number_of_trees = 5,
  N_post = 10,
  N_burn = 5,
  store_posterior_sample = TRUE,
  verbose = FALSE
)

## Example: Right-censored survival outcome
# Set data dimensions
n <- 100
p <- 1000

# Generate covariates
X <- matrix(runif(n * p), ncol = p)
X_treat <- X
treatment <- rbinom(n, 1, pnorm(X[, 1] - 1/2))

# Generate true (log) survival times depending on X and treatment
linpred <- X[, 1] - X[, 2] + (treatment - 0.5) *
  (1 + X[, 2] / 2 + X[, 3] / 3 + X[, 4] / 4)
true_time <- linpred + rnorm(n, 0, 0.5)

# Generate censoring times (on the same log scale)
censor_time <- log(rexp(n, rate = 1 / 5))

# Observed times and event indicator
time_obs <- pmin(true_time, censor_time)
status <- as.numeric(true_time == time_obs)

# Estimate propensity score using HorseTrees
fit_prop <- HorseTrees(
  y = treatment,
  X_train = X,
  outcome_type = "binary",
  number_of_trees = 200,
  N_post = 1000,
  N_burn = 1000
)

# Retrieve estimated probability of treatment (propensity score)
propensity <- fit_prop$train_probabilities

# Combine propensity score with covariates for control forest
X_control <- cbind(propensity, X)

# Fit the Causal Horseshoe Forest for survival outcome
fit_surv <- CausalHorseForest(
  y = time_obs,
  status = status,
  X_train_control = X_control,
  X_train_treat = X_treat,
  treatment_indicator_train = treatment,
  outcome_type = "right-censored",
  timescale = "log",
  number_of_trees = 200,
  k = 0.1,
  N_post = 1000,
  N_burn = 1000,
  store_posterior_sample = TRUE
)

## Evaluate and summarize results
# Evaluate C-index if survival package is available
if (requireNamespace("survival", quietly = TRUE)) {
  predicted_survtime <- fit_surv$train_predictions
  cindex_result <- survival::concordance(
    survival::Surv(time_obs, status) ~ predicted_survtime
  )
  c_index <- cindex_result$concordance
  cat("C-index:", round(c_index, 3), "\n")
} else {
  cat("Package 'survival' not available. Skipping C-index computation.\n")
}

# Compute posterior ATE samples: rows are posterior draws, so averaging each
# row over patients gives one ATE draw per iteration
ate_samples <- rowMeans(fit_surv$train_predictions_sample_treat)
mean_ate <- mean(ate_samples)
ci_95 <- quantile(ate_samples, probs = c(0.025, 0.975))
cat("Posterior mean ATE:", round(mean_ate, 3), "\n")
cat("95% credible interval: [", round(ci_95[1], 3), ", ",
    round(ci_95[2], 3), "]\n", sep = "")

# True ATE of the simulation: E[1 + X2/2 + X3/3 + X4/4] with
# X_j ~ Uniform(0, 1), i.e. 1 + 1/4 + 1/6 + 1/8 (approx. 1.541667)
true_ate <- 1 + 1/4 + 1/6 + 1/8

# Plot histogram of ATE samples
hist(
  ate_samples,
  breaks = 30,
  col = "steelblue",
  freq = FALSE,
  border = "white",
  xlab = "Average Treatment Effect (ATE)",
  main = "Posterior distribution of ATE"
)
abline(v = mean_ate, col = "orange3", lwd = 2)
abline(v = ci_95, col = "orange3", lty = 2, lwd = 2)
abline(v = true_ate, col = "darkred", lwd = 2)
legend(
  "topright",
  legend = c("Mean", "95% CI", "Truth"),
  # Colours must match the abline() calls above
  col = c("orange3", "orange3", "darkred"),
  lty = c(1, 2, 1),
  lwd = 2
)

## Plot individual CATE estimates
# Summarize posterior distribution per patient (columns are patients)
posterior_matrix <- fit_surv$train_predictions_sample_treat
posterior_mean <- colMeans(posterior_matrix)
posterior_ci <- apply(posterior_matrix, 2, quantile, probs = c(0.025, 0.975))
df_cate <- data.frame(
  mean = posterior_mean,
  lower = posterior_ci[1, ],
  upper = posterior_ci[2, ]
)

# Sort patients by posterior mean CATE
df_cate_sorted <- df_cate[order(df_cate$mean), ]
n_patients <- nrow(df_cate_sorted)

# Create the plot
plot(
  x = df_cate_sorted$mean,
  y = seq_len(n_patients),
  type = "n",
  xlab = "CATE per patient (95% credible interval)",
  ylab = "Patient index (sorted)",
  main = "Posterior CATE estimates",
  xlim = range(df_cate_sorted$lower, df_cate_sorted$upper)
)

# Add CATE intervals
segments(
  x0 = df_cate_sorted$lower,
  x1 = df_cate_sorted$upper,
  y0 = seq_len(n_patients),
  y1 = seq_len(n_patients),
  col = "steelblue"
)

# Add mean points
points(df_cate_sorted$mean, seq_len(n_patients), pch = 16, col = "orange3",
       lwd = 0.1)

# Add reference line at 0
abline(v = 0, col = "black", lwd = 2)
Fits a (Bayesian) Causal Shrinkage Forest model for estimating heterogeneous treatment effects.
This function generalizes CausalHorseForest by allowing flexible
global-local shrinkage priors on the step heights in both the control and treatment forests.
It supports continuous, right-censored, and interval-censored survival outcomes.
CausalShrinkageForest(
  y = NULL,
  status = NULL,
  X_train_control,
  X_train_treat,
  treatment_indicator_train,
  X_test_control = NULL,
  X_test_treat = NULL,
  treatment_indicator_test = NULL,
  left_time = NULL,
  right_time = NULL,
  outcome_type = "continuous",
  timescale = "time",
  number_of_trees_control = 200,
  number_of_trees_treat = 200,
  prior_type_control = "horseshoe",
  prior_type_treat = "horseshoe",
  local_hp_control = NULL,
  local_hp_treat = NULL,
  global_hp_control = NULL,
  global_hp_treat = NULL,
  a_dirichlet_control = 0.5,
  a_dirichlet_treat = 0.5,
  b_dirichlet_control = 1,
  b_dirichlet_treat = 1,
  rho_dirichlet_control = NULL,
  rho_dirichlet_treat = NULL,
  power_control = 2,
  power_treat = 2,
  base_control = 0.95,
  base_treat = 0.95,
  p_grow = 0.5,
  p_prune = 0.5,
  nu = 3,
  q = 0.9,
  sigma = NULL,
  N_post = 5000,
  N_burn = 5000,
  delayed_proposal = 5,
  store_posterior_sample = FALSE,
  treatment_coding = "centered",
  propensity = NULL,
  propensity_test = NULL,
  n_chains = 1,
  verbose = TRUE
)
y |
Outcome vector. Numeric. Represents continuous outcomes or follow-up times.
Set to NULL (the default) when outcome_type = "interval-censored". |
status |
Optional event indicator vector (1 = event occurred, 0 = censored).
Required when outcome_type = "right-censored". |
X_train_control |
Covariate matrix for the control forest. Rows correspond to samples, columns to covariates. |
X_train_treat |
Covariate matrix for the treatment forest. |
treatment_indicator_train |
Vector indicating treatment assignment for training samples (1 = treated, 0 = control). |
X_test_control |
Optional covariate matrix for control forest test data. Defaults to
column means of X_train_control. |
X_test_treat |
Optional covariate matrix for treatment forest test data. Defaults to
column means of X_train_treat. |
treatment_indicator_test |
Optional vector indicating treatment assignment for test data. |
left_time |
Optional numeric vector of left (lower) time boundaries.
Required when outcome_type = "interval-censored". |
right_time |
Optional numeric vector of right (upper) time boundaries.
Required when outcome_type = "interval-censored". |
outcome_type |
Type of outcome: one of "continuous" (default), "right-censored", or "interval-censored". |
timescale |
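For survival outcomes: either "time" (follow-up times on the original scale; the default) or "log" (follow-up times already on the log scale). |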
number_of_trees_control |
Number of trees in the control forest. Default is 200. |
number_of_trees_treat |
Number of trees in the treatment forest. Default is 200. |
prior_type_control |
Type of prior on control forest step heights. One of
|
prior_type_treat |
Type of prior on treatment forest step heights. Same options as
prior_type_control. |
local_hp_control |
Local hyperparameter controlling shrinkage on individual steps (control forest). Required for all prior types. |
local_hp_treat |
Local hyperparameter for treatment forest. |
global_hp_control |
Global hyperparameter for control forest. Required for horseshoe-type
priors; ignored for |
global_hp_treat |
Global hyperparameter for treatment forest. |
a_dirichlet_control |
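First shape parameter of the Beta prior used in the
Dirichlet–Sparse splitting rule for the control forest. Together with
b_dirichlet_control, it defines the Beta prior on the sparsity level of the
splitting probabilities. |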
a_dirichlet_treat |
First shape parameter of the Beta prior used in the Dirichlet–Sparse splitting rule for the treatment forest. |
b_dirichlet_control |
Second shape parameter of the Beta prior for the sparsity level in the control forest. Larger values shrink splitting probabilities more strongly toward uniform sparsity. |
b_dirichlet_treat |
Second shape parameter of the Beta prior governing sparsity in the treatment forest. |
rho_dirichlet_control |
Sparsity hyperparameter for the control forest. Represents the expected number of active predictors. If left NULL, it defaults to the number of covariates in the control forest. |
rho_dirichlet_treat |
Sparsity hyperparameter for the treatment forest, interpreted as the expected number of active predictors. Defaults to the number of covariates in the treatment forest if not specified. |
power_control |
Power parameter for the control forest tree structure prior splitting probability. |
power_treat |
Power parameter for the treatment forest tree structure prior splitting probability. |
base_control |
Base parameter for the control forest tree structure prior splitting probability. |
base_treat |
Base parameter for the treatment forest tree structure prior splitting probability. |
p_grow |
Probability of proposing a grow move. Default is 0.5. These are fixed at 0.5 for prior_type
|
p_prune |
Probability of proposing a prune move. Default is 0.5. These are fixed at 0.5 for prior_type
|
nu |
Degrees of freedom for the error variance prior. Default is 3. |
q |
Quantile parameter for error variance prior. Default is 0.90. |
sigma |
Optional known standard deviation of the outcome. If NULL, estimated from data. |
N_post |
Number of posterior samples to store. Default is 5000. |
N_burn |
Number of burn-in iterations. Default is 5000. |
delayed_proposal |
Number of delayed iterations before proposal updates. Default is 5. |
store_posterior_sample |
Logical; whether to store posterior samples of predictions.
Default is FALSE. |
treatment_coding |
Treatment coding scheme for the two-forest model.
One of |
propensity |
Optional numeric vector of propensity scores
|
propensity_test |
Optional numeric vector of propensity scores for
test observations. Only used when |
n_chains |
Number of independent MCMC chains to run. Default is
1. |
verbose |
Logical; whether to print verbose output. Default is TRUE. |
This function is a flexible generalization of CausalHorseForest.
The Causal Shrinkage Forest model decomposes the outcome into a prognostic
(control) and a treatment effect part. Each part is modeled by its own
shrinkage tree ensemble, with separate flexible global-local shrinkage
priors. It is particularly useful for estimating heterogeneous treatment
effects in high-dimensional settings. Further
methodological details on the Horseshoe Forest framework can be found in
Jacobs, van Wieringen & van der Pas (2025).
The horseshoe prior is the fully Bayesian global-local shrinkage
prior, where both the global and local shrinkage parameters are assigned
half-Cauchy distributions with scale hyperparameters global_hp and
local_hp, respectively. The global shrinkage parameter is defined
separately for each tree, allowing adaptive regularization per tree.
The horseshoe_fw prior (forest-wide horseshoe) is similar to
horseshoe, except that the global shrinkage parameter is shared
across all trees in the forest simultaneously.
The half-cauchy prior considers only local shrinkage and does not
include a global shrinkage component. It places a half-Cauchy prior on each
local shrinkage parameter with scale hyperparameter local_hp.
The dirichlet prior implements the Dirichlet–Sparse splitting rule of
Linero (2018), in which splitting probabilities follow a Dirichlet prior
whose concentration is controlled by a Beta sparsity parameter
(a_dirichlet, b_dirichlet) and an expected sparsity level
rho_dirichlet.
An S3 object of class "CausalShrinkageForest" containing:
Posterior mean predictions on training data (combined forest).
Posterior mean predictions on test data (combined forest).
Estimated control outcomes on training data.
Estimated control outcomes on test data.
Estimated treatment effects on training data.
Estimated treatment effects on test data.
Vector of posterior samples for the error standard deviation.
Average acceptance ratio in control forest.
Average acceptance ratio in treatment forest.
Matrix of posterior samples for control predictions on training data
(if store_posterior_sample = TRUE).
Matrix of posterior samples for control predictions on test data
(if store_posterior_sample = TRUE).
Matrix of posterior samples for treatment effects on training data
(if store_posterior_sample = TRUE).
Matrix of posterior samples for treatment effects on test data
(if store_posterior_sample = TRUE).
Jacobs, T., van Wieringen, W. N., & van der Pas, S. L. (2025). Horseshoe Forests for High-Dimensional Causal Survival Analysis. arXiv:2507.22004. https://doi.org/10.48550/arXiv.2507.22004
Chipman, H. A., George, E. I., & McCulloch, R. E. (2010). BART: Bayesian additive regression trees. Annals of Applied Statistics.
Linero, A. R. (2018). Bayesian regression trees for high-dimensional prediction and variable selection. Journal of the American Statistical Association.
Model family: CausalHorseForest (causal, horseshoe prior),
ShrinkageTrees (non-causal, flexible prior),
HorseTrees (non-causal, horseshoe prior).
Survival wrappers: SurvivalBCF, SurvivalShrinkageBCF.
S3 methods: print.CausalShrinkageForest,
summary.CausalShrinkageForest,
predict.CausalShrinkageForest,
plot.CausalShrinkageForest.
# Example: Continuous outcome, homogeneous treatment effect, two priors
n <- 50
p <- 3
X <- matrix(runif(n * p), ncol = p)
X_treat <- X_control <- X
treat <- rbinom(n, 1, X[, 1])
tau <- 2
y <- X[, 1] + (0.5 - treat) * tau + rnorm(n)

# Fit a standard Causal Horseshoe Forest
fit_horseshoe <- CausalShrinkageForest(
  y = y,
  X_train_control = X_control,
  X_train_treat = X_treat,
  treatment_indicator_train = treat,
  outcome_type = "continuous",
  number_of_trees_treat = 5,
  number_of_trees_control = 5,
  prior_type_control = "horseshoe",
  prior_type_treat = "horseshoe",
  local_hp_control = 0.1/sqrt(5),
  local_hp_treat = 0.1/sqrt(5),
  global_hp_control = 0.1/sqrt(5),
  global_hp_treat = 0.1/sqrt(5),
  N_post = 10,
  N_burn = 5,
  store_posterior_sample = TRUE,
  verbose = FALSE
)

# Fit a Causal Shrinkage Forest with half-cauchy prior
# (local shrinkage only, so no global hyperparameters are supplied)
fit_halfcauchy <- CausalShrinkageForest(
  y = y,
  X_train_control = X_control,
  X_train_treat = X_treat,
  treatment_indicator_train = treat,
  outcome_type = "continuous",
  number_of_trees_treat = 5,
  number_of_trees_control = 5,
  prior_type_control = "half-cauchy",
  prior_type_treat = "half-cauchy",
  local_hp_control = 1/sqrt(5),
  local_hp_treat = 1/sqrt(5),
  N_post = 10,
  N_burn = 5,
  store_posterior_sample = TRUE,
  verbose = FALSE
)

# Posterior mean CATEs (columns of the sample matrix are observations)
CATE_horseshoe <- colMeans(fit_horseshoe$train_predictions_sample_treat)
CATE_halfcauchy <- colMeans(fit_halfcauchy$train_predictions_sample_treat)

# Posteriors of the ATE (rows of the sample matrix are posterior draws)
post_ATE_horseshoe <- rowMeans(fit_horseshoe$train_predictions_sample_treat)
post_ATE_halfcauchy <- rowMeans(fit_halfcauchy$train_predictions_sample_treat)

# Posterior mean ATE
ATE_horseshoe <- mean(post_ATE_horseshoe)
ATE_halfcauchy <- mean(post_ATE_halfcauchy)

# Example: Interval-censored causal survival outcome
n <- 50
p <- 3
X_ic <- matrix(rnorm(n * p), ncol = p)
treat_ic <- rbinom(n, 1, 0.5)
true_t <- rexp(n, rate = exp(-X_ic[, 1] - 0.5 * treat_ic))

# Observation interval bracketing each true event time
left_t <- true_t * runif(n, 0.5, 1)
right_t <- true_t * runif(n, 1, 1.5)

# 15 exactly observed events: left_time == right_time
exact <- sample(n, 15)
left_t[exact] <- true_t[exact]
right_t[exact] <- true_t[exact]

# 10 right-censored observations: right_time = Inf
rc <- sample(setdiff(seq_len(n), exact), 10)
right_t[rc] <- Inf

fit_ic <- CausalShrinkageForest(
  left_time = left_t,
  right_time = right_t,
  X_train_control = X_ic,
  X_train_treat = X_ic,
  treatment_indicator_train = treat_ic,
  outcome_type = "interval-censored",
  number_of_trees_control = 5,
  number_of_trees_treat = 5,
  prior_type_control = "horseshoe",
  prior_type_treat = "horseshoe",
  local_hp_control = 0.1/sqrt(5),
  local_hp_treat = 0.1/sqrt(5),
  global_hp_control = 0.1/sqrt(5),
  global_hp_treat = 0.1/sqrt(5),
  N_post = 10,
  N_burn = 5,
  store_posterior_sample = TRUE,
  verbose = FALSE
)
Fits a Bayesian Horseshoe Trees model with a single learner.
Implements regularization on the step heights using a global-local Horseshoe
prior, controlled via the parameter k. Supports continuous, binary,
right-censored, and interval-censored (survival) outcomes.
HorseTrees( y = NULL, status = NULL, X_train, X_test = NULL, left_time = NULL, right_time = NULL, outcome_type = "continuous", timescale = "time", number_of_trees = 200, k = 0.1, power = 2, base = 0.95, p_grow = 0.4, p_prune = 0.4, nu = 3, q = 0.9, sigma = NULL, N_post = 1000, N_burn = 1000, delayed_proposal = 5, store_posterior_sample = TRUE, n_chains = 1, verbose = TRUE )HorseTrees( y = NULL, status = NULL, X_train, X_test = NULL, left_time = NULL, right_time = NULL, outcome_type = "continuous", timescale = "time", number_of_trees = 200, k = 0.1, power = 2, base = 0.95, p_grow = 0.4, p_prune = 0.4, nu = 3, q = 0.9, sigma = NULL, N_post = 1000, N_burn = 1000, delayed_proposal = 5, store_posterior_sample = TRUE, n_chains = 1, verbose = TRUE )
y |
Outcome vector. Numeric. Can represent continuous outcomes, binary
outcomes (0/1), or follow-up times for survival data. Set to NULL when outcome_type = "interval-censored"; supply left_time and right_time instead. |
status |
Optional censoring indicator vector (1 = event occurred,
0 = censored). Required if outcome_type = "right-censored". |
X_train |
Covariate matrix for training. Each row corresponds to an observation, and each column to a covariate. |
X_test |
Optional covariate matrix for test data. If NULL, defaults to the mean of the training covariates. |
left_time |
Optional numeric vector of left (lower) time boundaries.
Required when outcome_type = "interval-censored". |
right_time |
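Optional numeric vector of right (upper) time boundaries.
Required when outcome_type = "interval-censored". Use Inf for right-censored observations and set it equal to left_time for exactly observed event times. |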
outcome_type |
Type of outcome. One of "continuous" (default), "binary", "right-censored", or "interval-censored". |
timescale |
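Indicates the scale of follow-up times. Options are
"time" (default; follow-up times on the original scale) or "log" (follow-up times already on the log scale). Only relevant for survival outcomes. |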
number_of_trees |
Number of trees in the ensemble. Default is 200. |
k |
Horseshoe scale hyperparameter (default 0.1). This parameter
controls the overall level of shrinkage by setting the scale for both
global and local shrinkage components. The local and global hyperparameters
are parameterized as local_hp = global_hp = $k / \sqrt{m}$, where $m$ is
number_of_trees. |
power |
Power parameter for tree structure prior. Default is 2.0. |
base |
Base parameter for tree structure prior. Default is 0.95. |
p_grow |
Probability of proposing a grow move. Default is 0.4. |
p_prune |
Probability of proposing a prune move. Default is 0.4. |
nu |
Degrees of freedom for the prior on the error variance. Default is 3. |
q |
Quantile hyperparameter for the error variance prior. Default is 0.90. |
sigma |
Optional known value for error standard deviation. If NULL, estimated from data. |
N_post |
Number of posterior samples to store. Default is 1000. |
N_burn |
Number of burn-in iterations. Default is 1000. |
delayed_proposal |
Number of delayed iterations before proposal. Only for reversible updates. Default is 5. |
store_posterior_sample |
Logical; whether to store posterior samples for each iteration. Default is TRUE. |
n_chains |
Number of independent MCMC chains to run. Default is
1. |
verbose |
Logical; whether to print verbose output. Default is TRUE. |
For continuous outcomes, the model centers and optionally standardizes the
outcome using a prior guess of the standard deviation.
For binary outcomes, the function uses a probit link formulation.
For right-censored outcomes (survival data), the function can handle
follow-up times either on the original time scale or log-transformed.
For interval-censored outcomes, provide left_time and
right_time instead of y and status; the event
indicators are derived internally following the
survival::Surv(type = "interval2") convention.
Generalized implementation with multiple prior possibilities is given by
ShrinkageTrees.
An S3 object of class "ShrinkageTrees" with the following elements:
Vector of posterior mean predictions on the training data.
Vector of posterior mean predictions on the test
data (or on the mean covariate vector if X_test is not provided).
Vector of posterior samples of the error standard deviation (sigma).
Average acceptance ratio across trees during sampling.
Matrix of posterior samples of training
predictions (iterations in rows, observations in columns). Present only
if store_posterior_sample = TRUE.
Matrix of posterior samples of test
predictions. Present only if store_posterior_sample = TRUE.
Vector of posterior mean probabilities on the
training data (only for outcome_type = "binary").
Vector of posterior mean probabilities on the
test data (only for outcome_type = "binary").
Matrix of posterior samples of training
probabilities (only for outcome_type = "binary" and if
store_posterior_sample = TRUE).
Matrix of posterior samples of test
probabilities (only for outcome_type = "binary" and if
store_posterior_sample = TRUE).
Model family: ShrinkageTrees (flexible prior choice),
CausalHorseForest (causal inference),
CausalShrinkageForest (causal, flexible prior).
Survival wrappers: SurvivalBART, SurvivalDART.
S3 methods: print.ShrinkageTrees,
summary.ShrinkageTrees,
predict.ShrinkageTrees,
plot.ShrinkageTrees.
# Minimal example: continuous outcome n <- 25 p <- 5 X <- matrix(rnorm(n * p), ncol = p) y <- X[, 1] + rnorm(n) fit1 <- HorseTrees(y = y, X_train = X, outcome_type = "continuous", number_of_trees = 5, N_post = 75, N_burn = 25, verbose = FALSE) # Minimal example: binary outcome X <- matrix(rnorm(n * p), ncol = p) y <- ifelse(X[, 1] + rnorm(n) > 0, 1, 0) fit2 <- HorseTrees(y = y, X_train = X, outcome_type = "binary", number_of_trees = 5, N_post = 75, N_burn = 25, verbose = FALSE) # Minimal example: right-censored outcome X <- matrix(rnorm(n * p), ncol = p) time <- rexp(n, rate = 0.1) status <- rbinom(n, 1, 0.7) fit3 <- HorseTrees(y = time, status = status, X_train = X, outcome_type = "right-censored", number_of_trees = 5, N_post = 75, N_burn = 25, verbose = FALSE) # Minimal example: interval-censored outcome X <- matrix(rnorm(n * p), ncol = p) true_t <- rexp(n, rate = 0.1) left_t <- true_t * runif(n, 0.5, 1) right_t <- true_t * runif(n, 1, 1.5) # Mark some as exact, some as right-censored exact <- sample(n, 8); left_t[exact] <- true_t[exact]; right_t[exact] <- true_t[exact] rc <- sample(setdiff(seq_len(n), exact), 5); right_t[rc] <- Inf fit4 <- HorseTrees(left_time = left_t, right_time = right_t, X_train = X, outcome_type = "interval-censored", number_of_trees = 5, N_post = 75, N_burn = 25, verbose = FALSE) # Larger continuous example (not run automatically) n <- 100 p <- 100 X <- matrix(rnorm(100 * p), ncol = p) X_test <- matrix(rnorm(50 * p), ncol = p) y <- X[, 1] + X[, 2] - X[, 3] + rnorm(100, sd = 0.5) fit5 <- HorseTrees(y = y, X_train = X, X_test = X_test, outcome_type = "continuous", number_of_trees = 200, N_post = 2500, N_burn = 2500, store_posterior_sample = TRUE, verbose = TRUE) plot(fit5$sigma, type = "l", ylab = expression(sigma), xlab = "Iteration", main = "Sigma traceplot") hist(fit5$train_predictions_sample[, 1], main = "Posterior distribution of prediction outcome individual 1", xlab = "Prediction", breaks = 20)# Minimal example: continuous outcome n
<- 25 p <- 5 X <- matrix(rnorm(n * p), ncol = p) y <- X[, 1] + rnorm(n) fit1 <- HorseTrees(y = y, X_train = X, outcome_type = "continuous", number_of_trees = 5, N_post = 75, N_burn = 25, verbose = FALSE) # Minimal example: binary outcome X <- matrix(rnorm(n * p), ncol = p) y <- ifelse(X[, 1] + rnorm(n) > 0, 1, 0) fit2 <- HorseTrees(y = y, X_train = X, outcome_type = "binary", number_of_trees = 5, N_post = 75, N_burn = 25, verbose = FALSE) # Minimal example: right-censored outcome X <- matrix(rnorm(n * p), ncol = p) time <- rexp(n, rate = 0.1) status <- rbinom(n, 1, 0.7) fit3 <- HorseTrees(y = time, status = status, X_train = X, outcome_type = "right-censored", number_of_trees = 5, N_post = 75, N_burn = 25, verbose = FALSE) # Minimal example: interval-censored outcome X <- matrix(rnorm(n * p), ncol = p) true_t <- rexp(n, rate = 0.1) left_t <- true_t * runif(n, 0.5, 1) right_t <- true_t * runif(n, 1, 1.5) # Mark some as exact, some as right-censored exact <- sample(n, 8); left_t[exact] <- true_t[exact]; right_t[exact] <- true_t[exact] rc <- sample(setdiff(seq_len(n), exact), 5); right_t[rc] <- Inf fit4 <- HorseTrees(left_time = left_t, right_time = right_t, X_train = X, outcome_type = "interval-censored", number_of_trees = 5, N_post = 75, N_burn = 25, verbose = FALSE) # Larger continuous example (not run automatically) n <- 100 p <- 100 X <- matrix(rnorm(100 * p), ncol = p) X_test <- matrix(rnorm(50 * p), ncol = p) y <- X[, 1] + X[, 2] - X[, 3] + rnorm(100, sd = 0.5) fit5 <- HorseTrees(y = y, X_train = X, X_test = X_test, outcome_type = "continuous", number_of_trees = 200, N_post = 2500, N_burn = 2500, store_posterior_sample = TRUE, verbose = TRUE) plot(fit5$sigma, type = "l", ylab = expression(sigma), xlab = "Iteration", main = "Sigma traceplot") hist(fit5$train_predictions_sample[, 1], main = "Posterior distribution of prediction outcome individual 1", xlab = "Prediction", breaks = 20)
Gene expression and clinical covariates for ovarian cancer patients from
The Cancer Genome Atlas (TCGA-OV), combined with semi-synthetic
survival outcomes and treatment assignment. Real covariates (age,
FIGO stage, tumor grade, gene expression) are retained; survival
times, event indicator, and treatment assignment are simulated from a
known data-generating process so that the true treatment effect is
available for validation (see ovarian_truth).
ovarianovarian
A data frame with 357 rows (patients) and 1007 columns:
OS_time: Numeric. Observed survival time in days (simulated).
OS_event: Integer. Event indicator (simulated). 1 = event observed, 0 = right-censored.
treatment: Integer. Simulated treatment assignment.
1 = carboplatin, 0 = cisplatin. Driven primarily by
year_of_diagnosis as an instrumental variable
(cisplatin era pre-~2000, carboplatin after).
age: Integer. Age at initial pathologic diagnosis in years.
figo_stage: Integer. FIGO stage coded as 2 = Stage II, 3 = Stage III, 4 = Stage IV.
tumor_grade: Integer. Histologic tumor grade coded as 2 = G2, 3 = G3, 4 = G4. Rows with GX (unknown grade) were excluded.
year_of_diagnosis: Integer. Year of initial pathologic diagnosis (approx. 1992–2013). Used as an instrumental variable for treatment assignment in the DGP.
right_time, left_time: Numeric. Interval-censoring bounds
derived from the simulated survival times, suitable for passing
to the package's interval-censored survival interface
(right_time = Inf for right-censored observations,
left_time == right_time for exact events).
year_of_diagnosis.1: Integer. Duplicate of
year_of_diagnosis left in place from the data-assembly
join; retained for reproducibility and may be ignored.
ENSG...: Numeric. Log2(TPM + 1) normalised gene expression
levels for 997 Ensembl genes (columns named by versioned Ensembl
gene IDs, e.g. ENSG00000270372.1). Genes were selected as
the most variable transcripts across TCGA-OV samples, ranked by
median absolute deviation (MAD).
RNA-seq data were downloaded from the GDC portal using the
TCGAbiolinks package (STAR - Counts workflow). Expression values
were normalised to TPM and log2-transformed as log2(TPM + 1). Genes with
median TPM <= 1 across all samples were removed prior to MAD filtering.
Clinical data were obtained from the BCR Biotab clinical supplement.
Treatment assignment was derived from the drug table
(clinical_drug_ov), restricted to adjuvant (first-line) treatment
records. Samples were matched between expression and clinical data using
the 12-character TCGA patient barcode.
https://portal.gdc.cancer.gov/projects/TCGA-OV
Cancer Genome Atlas Research Network (2011). Integrated genomic analyses of ovarian carcinoma. Nature, 474, 609–615. doi:10.1038/nature10166
Colaprico, A. et al. (2016). TCGAbiolinks: an R/Bioconductor package for integrative analysis with GDC data. Nucleic Acids Research, 44(8). doi:10.1093/nar/gkv1507
data(ovarian) # Dimensions: 357 patients x 1007 columns (10 clinical/outcome + 997 gene columns) dim(ovarian) # Survival outcome head(ovarian[, c("OS_time", "OS_event", "treatment")]) # KM plot by treatment (strata ordered treatment = 0, then 1) if (requireNamespace("survival", quietly = TRUE)) { library(survival) fit <- survfit(Surv(OS_time, OS_event) ~ treatment, data = ovarian) plot(fit, col = c("blue", "red"), xlab = "Time (days)", ylab = "Survival") legend("topright", c("Cisplatin", "Carboplatin"), col = c("blue", "red"), lty = 1) }data(ovarian) # Dimensions: 357 patients x 1007 columns (10 clinical/outcome + 997 gene columns) dim(ovarian) # Survival outcome head(ovarian[, c("OS_time", "OS_event", "treatment")]) # KM plot by treatment (strata ordered treatment = 0, then 1) if (requireNamespace("survival", quietly = TRUE)) { library(survival) fit <- survfit(Surv(OS_time, OS_event) ~ treatment, data = ovarian) plot(fit, col = c("blue", "red"), xlab = "Time (days)", ylab = "Survival") legend("topright", c("Cisplatin", "Carboplatin"), col = c("blue", "red"), lty = 1) }
The simulated quantities that correspond to the
ovarian dataset. Because the ovarian outcomes and
treatment assignment are generated from a known data-generating
process, the underlying potential outcomes, prognostic function,
conditional treatment effect, and propensity score are available for
validating estimators of treatment effects under right- and
interval-censored survival.
ovarian_truthovarian_truth
A data frame with one row per patient in ovarian
and the following columns:
Numeric. True (uncensored) survival time on the log scale.
Numeric. True (uncensored) survival time on the original scale.
Numeric. True prognostic function
(expected log survival time at the reference treatment).
Numeric. True conditional average treatment effect
on the log-survival scale.
Numeric. True propensity for the treated group (carboplatin) used to simulate the observed assignment.
ovarian for the observed semi-synthesised data.
data(ovarian) data(ovarian_truth) stopifnot(nrow(ovarian) == nrow(ovarian_truth)) # True sample average treatment effect (SATE) on the log-survival scale: mean(ovarian_truth$true_tau)data(ovarian) data(ovarian_truth) stopifnot(nrow(ovarian) == nrow(ovarian_truth)) # True sample average treatment effect (SATE) on the log-survival scale: mean(ovarian_truth$true_tau)
A reduced and cleaned subset of the TCGA pancreatic ductal adenocarcinoma (PAAD)
dataset, derived from The Cancer Genome Atlas (TCGA) PAAD cohort. This version,
pdac, is smaller and simplified for practical analyses and package examples.
pdacpdac
A data frame with rows corresponding to patients and columns as described below.
This dataset was originally compiled and curated in the open-source pdacR
package by Torre-Healy et al. (2023), which harmonized and integrated the TCGA
PAAD gene expression and clinical data. The current version further reduces and
simplifies the data for efficient modeling demonstrations and survival analyses.
The data frame includes:
time: Overall survival time in months.
status: Event indicator; 1 = event occurred, 0 = censored.
treatment: Binary treatment indicator; 1 = radiation therapy, 0 = control.
age: Age at initial pathologic diagnosis (numeric).
sex: Binary sex indicator; 1 = male, 0 = female.
grade: Tumor differentiation grade (ordinal; 1 = well, 2 = moderate, 3 = poor, 4 = undifferentiated).
tumor.cellularity: Tumor cellularity estimate (numeric).
tumor.purity: Tumor purity class (binary; 1 = high, 0 = low).
absolute.purity: Absolute purity estimate (numeric).
moffitt.cluster: Moffitt transcriptional subtype (binary; 1 = basal-like, 0 = classical).
meth.leukocyte.percent: DNA methylation leukocyte estimate (numeric).
meth.purity.mode: DNA methylation purity mode (numeric).
stage: Nodal stage indicator (binary; 1 = n1, 0 = n0).
lymph.nodes: Number of lymph nodes examined (numeric).
Driver gene columns: Expression values of key driver genes (e.g., KRAS, TP53, CDKN2A, SMAD4, BRCA1, BRCA2).
Other gene columns: Expression values of ~3,000 most variable non-driver genes (based on median absolute deviation).
doi:10.1016/j.ccell.2017.07.007
Raphael BJ, et al. "Integrated genomic characterization of pancreatic ductal adenocarcinoma." Cancer Cell. 2017 Aug 14;32(2):185–203.e13. PMID: 28810144.
Torre-Healy LA, Kawalerski RR, Oh K, et al. "Open-source curation of a pancreatic ductal adenocarcinoma gene expression analysis platform (pdacR) supports a two-subtype model." Communications Biology. 2023; https://doi.org/10.1038/s42003-023-04461-6.
The Cancer Genome Atlas (TCGA), PAAD project, DbGaP: phs000178.
Visualises posterior draws using ggplot2. Requires the suggested package ggplot2.
## S3 method for class 'CausalShrinkageForest' plot( x, type = c("trace", "density", "ate", "cate", "vi"), forest = c("both", "control", "treat"), n_vi = 10, bayesian_bootstrap = TRUE, ... )## S3 method for class 'CausalShrinkageForest' plot( x, type = c("trace", "density", "ate", "cate", "vi"), forest = c("both", "control", "treat"), n_vi = 10, bayesian_bootstrap = TRUE, ... )
x |
A |
type |
Character; one of:
|
forest |
For |
n_vi |
Integer; number of top variables for |
bayesian_bootstrap |
Logical; only used when |
... |
Additional arguments (currently unused). |
A ggplot2 object, or (for type = "vi" with
forest = "both") a named list with elements control and
treat.
if (requireNamespace("ggplot2", quietly = TRUE)) { set.seed(1) n <- 60; p <- 5 X <- matrix(rnorm(n * p), ncol = p) w <- rbinom(n, 1, 0.5) y <- X[, 1] + w * 1.5 * (X[, 2] > 0) + rnorm(n, sd = 0.5) fit <- CausalShrinkageForest( y = y, X_train_control = X, X_train_treat = X, treatment_indicator_train = w, prior_type_control = "horseshoe", prior_type_treat = "horseshoe", local_hp_control = 0.1, global_hp_control = 0.1, local_hp_treat = 0.1, global_hp_treat = 0.1, number_of_trees_control = 5, number_of_trees_treat = 5, N_post = 50, N_burn = 25, store_posterior_sample = TRUE, verbose = FALSE ) plot(fit, type = "trace") plot(fit, type = "ate") plot(fit, type = "cate") }if (requireNamespace("ggplot2", quietly = TRUE)) { set.seed(1) n <- 60; p <- 5 X <- matrix(rnorm(n * p), ncol = p) w <- rbinom(n, 1, 0.5) y <- X[, 1] + w * 1.5 * (X[, 2] > 0) + rnorm(n, sd = 0.5) fit <- CausalShrinkageForest( y = y, X_train_control = X, X_train_treat = X, treatment_indicator_train = w, prior_type_control = "horseshoe", prior_type_treat = "horseshoe", local_hp_control = 0.1, global_hp_control = 0.1, local_hp_treat = 0.1, global_hp_treat = 0.1, number_of_trees_control = 5, number_of_trees_treat = 5, N_post = 50, N_burn = 25, store_posterior_sample = TRUE, verbose = FALSE ) plot(fit, type = "trace") plot(fit, type = "ate") plot(fit, type = "cate") }
Visualises posterior draws using ggplot2. Requires the suggested package ggplot2.
## S3 method for class 'ShrinkageTrees' plot( x, type = c("trace", "density", "vi", "survival"), n_vi = 10, obs = NULL, t_grid = NULL, level = 0.95, km = FALSE, ... )## S3 method for class 'ShrinkageTrees' plot( x, type = c("trace", "density", "vi", "survival"), n_vi = 10, obs = NULL, t_grid = NULL, level = 0.95, km = FALSE, ... )
x |
A |
type |
Character; one of:
|
n_vi |
Integer; number of top variables to display when
|
obs |
Integer vector of training-set observation indices for
individual survival curves, or |
t_grid |
Optional numeric vector of time points (on the original
time scale) at which to evaluate the survival function. If |
level |
Width of the pointwise credible band for
|
km |
Logical; if |
... |
Additional arguments (currently unused). |
A ggplot2 object.
if (requireNamespace("ggplot2", quietly = TRUE)) { set.seed(1) n <- 50; p <- 5 X <- matrix(rnorm(n * p), ncol = p) y <- X[, 1] + rnorm(n) # Fit a small continuous model fit <- ShrinkageTrees( y = y, X_train = X, prior_type = "horseshoe", local_hp = 0.1, global_hp = 0.1, number_of_trees = 5, N_post = 50, N_burn = 25, verbose = FALSE ) # Sigma traceplot -- check chain mixing plot(fit, type = "trace") # Overlaid posterior densities of sigma per chain plot(fit, type = "density") }if (requireNamespace("ggplot2", quietly = TRUE)) { set.seed(1) n <- 50; p <- 5 X <- matrix(rnorm(n * p), ncol = p) y <- X[, 1] + rnorm(n) # Fit a small continuous model fit <- ShrinkageTrees( y = y, X_train = X, prior_type = "horseshoe", local_hp = 0.1, global_hp = 0.1, number_of_trees = 5, N_post = 50, N_burn = 25, verbose = FALSE ) # Sigma traceplot -- check chain mixing plot(fit, type = "trace") # Overlaid posterior densities of sigma per chain plot(fit, type = "density") }
Plots posterior predictive survival curves for new observations from
a ShrinkageTreesPrediction object. Only available for survival
outcome types ("right-censored" or "interval-censored").
## S3 method for class 'ShrinkageTreesPrediction' plot(x, type = "survival", obs = NULL, t_grid = NULL, level = 0.95, ...)## S3 method for class 'ShrinkageTreesPrediction' plot(x, type = "survival", obs = NULL, t_grid = NULL, level = 0.95, ...)
x |
A |
type |
Character; currently only |
obs |
Integer vector of predicted-observation indices for individual
survival curves, or |
t_grid |
Optional numeric vector of time points (on the original
time scale) at which to evaluate the survival function. If |
level |
Width of the pointwise credible band. Default |
... |
Additional arguments (currently unused). |
A ggplot2 object.
predict.ShrinkageTrees,
plot.ShrinkageTrees
if (requireNamespace("ggplot2", quietly = TRUE)) { set.seed(1) n <- 40; p <- 3 X <- matrix(rnorm(n * p), ncol = p) X_test <- matrix(rnorm(10 * p), ncol = p) time <- rexp(n, rate = exp(0.5 * X[, 1])) status <- rbinom(n, 1, 0.7) fit_surv <- SurvivalBART( time = time, status = status, X_train = X, number_of_trees = 5, N_post = 50, N_burn = 25, store_posterior_sample = TRUE, verbose = FALSE ) pred <- predict(fit_surv, newdata = X_test) plot(pred, type = "survival") }if (requireNamespace("ggplot2", quietly = TRUE)) { set.seed(1) n <- 40; p <- 3 X <- matrix(rnorm(n * p), ncol = p) X_test <- matrix(rnorm(10 * p), ncol = p) time <- rexp(n, rate = exp(0.5 * X[, 1])) status <- rbinom(n, 1, 0.7) fit_surv <- SurvivalBART( time = time, status = status, X_train = X, number_of_trees = 5, N_post = 50, N_burn = 25, store_posterior_sample = TRUE, verbose = FALSE ) pred <- predict(fit_surv, newdata = X_test) plot(pred, type = "survival") }
Re-runs the MCMC sampler on new covariate data using the stored training
data and hyperparameters, returning posterior mean predictions and credible
interval bounds for three quantities: the prognostic function
(control-forest prediction $\mu(x)$), the Conditional Average
Treatment Effect (CATE, $\tau(x)$), and the total predicted
outcome (prognostic function plus treatment-effect contribution).
## S3 method for class 'CausalShrinkageForest' predict( object, newdata_control, newdata_treat, level = 0.95, bayesian_bootstrap = TRUE, ... )## S3 method for class 'CausalShrinkageForest' predict( object, newdata_control, newdata_treat, level = 0.95, bayesian_bootstrap = TRUE, ... )
object |
A fitted |
newdata_control |
A matrix of new covariates for the control forest,
with the same number of columns as |
newdata_treat |
A matrix of new covariates for the treatment forest,
with the same number of columns as |
level |
Credible interval width. Default |
bayesian_bootstrap |
Logical; if |
... |
Currently unused. |
The causal forest decomposes the expected outcome as
$\mathbb{E}[Y \mid x, w] = \mu(x) + w\,\tau(x)$,
where $\mu(x)$ is the prognostic function (control forest),
$\tau(x)$ is the CATE (treatment forest), and $w$ is the
treatment indicator.
For continuous outcomes and survival with
timescale = "log", all three components are on the response scale:
prognostic and total include the intercept shift
(), while cate is the pure additive treatment
effect with no intercept.
For survival with timescale = "time", predictions are
back-transformed to the original time scale:
prognostic: posterior expected baseline survival time
.
cate: multiplicative effect on survival time
; a value greater than 1 means treatment prolongs
survival.
total: posterior expected survival time under the observed
treatment .
A CausalShrinkageForestPrediction object with elements:
List with mean, lower, upper:
posterior summaries of the prognostic function
$\mu(x)$.
List with mean, lower, upper:
posterior summaries of the CATE $\tau(x)$.
List with mean, lower, upper:
posterior summaries of the total outcome
.
List with mean, lower, upper: posterior
summary of the average treatment effect over newdata. For
survival with timescale = "time", reported as a multiplicative
time ratio on the original scale.
matrix of posterior
CATE draws on the scale reported in cate.
Flag indicating whether the reported ATE CI used Dirichlet reweighting.
Number of test observations.
Credible level used.
Outcome type inherited from the fitted model.
Timescale inherited from the fitted model.
CausalHorseForest, CausalShrinkageForest,
print.CausalShrinkageForestPrediction,
summary.CausalShrinkageForestPrediction
Re-runs the MCMC sampler on new covariate data using the stored training data and hyperparameters, returning posterior mean predictions and credible interval bounds.
## S3 method for class 'ShrinkageTrees' predict(object, newdata, level = 0.95, ...)## S3 method for class 'ShrinkageTrees' predict(object, newdata, level = 0.95, ...)
object |
A fitted |
newdata |
A matrix (or object coercible to one) of new covariates with the same number of columns as the training data. |
level |
Credible interval width. Default |
... |
Currently unused. |
A ShrinkageTreesPrediction object with elements:
Posterior mean predictions (length nrow(newdata)).
Lower credible interval bound.
Upper credible interval bound.
Number of test observations.
Credible level used.
Outcome type inherited from the fitted model.
Timescale inherited from the fitted model (survival only).
(Survival only) N_post x n
matrix of posterior predictive draws on the original scale.
(Survival only) Posterior draws of sigma on the log-time
scale (length N_post).
HorseTrees, ShrinkageTrees,
print.ShrinkageTreesPrediction,
summary.ShrinkageTreesPrediction,
plot.ShrinkageTreesPrediction
Displays a concise summary of a fitted CausalShrinkageForest model
with per-forest columns for priors, tree counts, feature counts, and MCMC
acceptance ratios.
## S3 method for class 'CausalShrinkageForest' print(x, ...)## S3 method for class 'CausalShrinkageForest' print(x, ...)
x |
A fitted |
... |
Currently unused. |
Invisibly returns x.
summary.CausalShrinkageForest,
CausalHorseForest, CausalShrinkageForest
Displays a formatted table of posterior mean predictions and credible
interval bounds for the first n_head observations, with separate
sections for the prognostic function $\mu(x)$, the CATE $\tau(x)$,
and the total outcome.
## S3 method for class 'CausalShrinkageForestPrediction' print(x, n_head = 6, digits = 3, ...)## S3 method for class 'CausalShrinkageForestPrediction' print(x, n_head = 6, digits = 3, ...)
x |
A |
n_head |
Number of observations to display per section. Default |
digits |
Number of decimal places. Default |
... |
Currently unused. |
Invisibly returns x.
predict.CausalShrinkageForest,
summary.CausalShrinkageForestPrediction
Displays a concise summary of a fitted ShrinkageTrees model,
including outcome type, prior, MCMC settings, acceptance ratio, and
posterior mean sigma.
## S3 method for class 'ShrinkageTrees' print(x, ...)## S3 method for class 'ShrinkageTrees' print(x, ...)
x |
A fitted |
... |
Currently unused. |
Invisibly returns x.
summary.ShrinkageTrees, HorseTrees,
ShrinkageTrees
Displays a formatted table of posterior mean predictions and credible
interval bounds for the first n_head observations.
## S3 method for class 'ShrinkageTreesPrediction' print(x, n_head = 6, digits = 3, ...)## S3 method for class 'ShrinkageTreesPrediction' print(x, n_head = 6, digits = 3, ...)
x |
A |
n_head |
Number of observations to display. Default |
digits |
Number of decimal places. Default |
... |
Currently unused. |
Invisibly returns x.
predict.ShrinkageTrees,
summary.ShrinkageTreesPrediction
Displays a detailed summary of a CausalShrinkageForest model,
including model specification, treatment effect estimates, prognostic
function, posterior sigma, variable importance for each forest, and MCMC
diagnostics.
## S3 method for class 'summary.CausalShrinkageForest' print(x, n_vi = 10, ...)## S3 method for class 'summary.CausalShrinkageForest' print(x, n_vi = 10, ...)
x |
A |
n_vi |
Maximum number of variables to display per variable importance
table. Default |
... |
Currently unused. |
Invisibly returns x.
Displays distributional summaries (min, Q1, median, max) of the posterior mean predictions and credible interval bounds, separately for the prognostic function, CATE, and total outcome.
## S3 method for class 'summary.CausalShrinkageForestPrediction' print(x, digits = 3, ...)## S3 method for class 'summary.CausalShrinkageForestPrediction' print(x, digits = 3, ...)
x |
A |
digits |
Number of decimal places. Default |
... |
Currently unused. |
Invisibly returns x.
summary.CausalShrinkageForestPrediction
Displays a detailed summary of a ShrinkageTrees model, including
model specification, posterior sigma, prediction summaries, variable
importance, and MCMC diagnostics.
## S3 method for class 'summary.ShrinkageTrees' print(x, n_vi = 10, ...)## S3 method for class 'summary.ShrinkageTrees' print(x, n_vi = 10, ...)
x |
A |
n_vi |
Maximum number of variables to display in the variable
importance table. Default |
... |
Currently unused. |
Invisibly returns x.
Displays distributional summaries (min, Q1, median, max) of the posterior mean predictions and credible interval bounds.
## S3 method for class 'summary.ShrinkageTreesPrediction' print(x, digits = 3, ...)## S3 method for class 'summary.ShrinkageTreesPrediction' print(x, digits = 3, ...)
x |
A |
digits |
Number of decimal places. Default |
... |
Currently unused. |
Invisibly returns x.
summary.ShrinkageTreesPrediction
Fits a Bayesian Shrinkage Tree model with flexible global-local priors on the
step heights. This function generalizes HorseTrees by allowing
different global-local shrinkage priors on the step heights.
Supports continuous, binary, right-censored, and interval-censored outcomes.
ShrinkageTrees( y = NULL, status = NULL, X_train, X_test = NULL, left_time = NULL, right_time = NULL, outcome_type = "continuous", timescale = "time", number_of_trees = 200, prior_type = "horseshoe", local_hp = NULL, global_hp = NULL, a_dirichlet = 0.5, b_dirichlet = 1, rho_dirichlet = NULL, power = 2, base = 0.95, p_grow = 0.4, p_prune = 0.4, nu = 3, q = 0.9, sigma = NULL, N_post = 1000, N_burn = 1000, delayed_proposal = 5, store_posterior_sample = TRUE, n_chains = 1, verbose = TRUE )ShrinkageTrees( y = NULL, status = NULL, X_train, X_test = NULL, left_time = NULL, right_time = NULL, outcome_type = "continuous", timescale = "time", number_of_trees = 200, prior_type = "horseshoe", local_hp = NULL, global_hp = NULL, a_dirichlet = 0.5, b_dirichlet = 1, rho_dirichlet = NULL, power = 2, base = 0.95, p_grow = 0.4, p_prune = 0.4, nu = 3, q = 0.9, sigma = NULL, N_post = 1000, N_burn = 1000, delayed_proposal = 5, store_posterior_sample = TRUE, n_chains = 1, verbose = TRUE )
y |
Outcome vector. Numeric. Can represent continuous outcomes, binary
outcomes (0/1), or follow-up times for survival data. Set to NULL when outcome_type = "interval-censored"; supply left_time and right_time instead. |
status |
Optional censoring indicator vector (1 = event occurred,
0 = censored). Required if outcome_type = "right-censored". |
X_train |
Covariate matrix for training. Each row corresponds to an observation, and each column to a covariate. |
X_test |
Optional covariate matrix for test data. If NULL, defaults to the mean of the training covariates. |
left_time |
Optional numeric vector of left (lower) time boundaries.
Required when outcome_type = "interval-censored". |
right_time |
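Optional numeric vector of right (upper) time boundaries.
Required when outcome_type = "interval-censored". Use Inf for right-censored observations and set it equal to left_time for exactly observed event times. |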
outcome_type |
Type of outcome. One of "continuous" (default), "binary", "right-censored", or "interval-censored". |
timescale |
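Indicates the scale of follow-up times. Options are
"time" (default; follow-up times on the original scale) or "log" (follow-up times already on the log scale). Only relevant for survival outcomes. |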
number_of_trees |
Number of trees in the ensemble. Default is 200. |
prior_type |
Type of prior on the step heights. Options include
|
local_hp |
Local hyperparameter controlling shrinkage on individual
step heights. Should typically be set smaller than 1 / sqrt(number_of_trees).
Required for |
global_hp |
Global hyperparameter controlling overall shrinkage.
Must be specified for Horseshoe-type priors; ignored for
|
a_dirichlet |
First shape parameter of the Beta prior used in the
Dirichlet-Sparse splitting rule. Together with |
b_dirichlet |
Second shape parameter of the Beta prior for the
sparsity level. Larger values favour a smaller Dirichlet concentration and
hence sparser splitting probabilities (fewer covariates used). Only when |
rho_dirichlet |
Sparsity hyperparameter. If left NULL, it defaults to
the number of covariates. Only when |
power |
Power parameter for the tree structure prior. Default is 2.0. |
base |
Base parameter for the tree structure prior. Default is 0.95. |
p_grow |
Probability of proposing a grow move. Default is 0.4. |
p_prune |
Probability of proposing a prune move. Default is 0.4. |
nu |
Degrees of freedom for the error distribution prior. Default is 3. |
q |
Quantile hyperparameter for the error variance prior. Default is 0.90. |
sigma |
Optional known value for error standard deviation. If NULL, estimated from data. |
N_post |
Number of posterior samples to store. Default is 1000. |
N_burn |
Number of burn-in iterations. Default is 1000. |
delayed_proposal |
Number of delayed iterations before proposal. Only for reversible updates. Default is 5. |
store_posterior_sample |
Logical; whether to store posterior samples for each iteration. Default is TRUE. |
n_chains |
Number of independent MCMC chains to run. Default is
|
verbose |
Logical; whether to print verbose output. Default is TRUE. |
This function is a flexible generalization of HorseTrees.
Instead of using a single Horseshoe prior, it allows specifying different
global–local shrinkage configurations for the tree step heights. Further
methodological details on the Horseshoe Forest framework can be found in
Jacobs, van Wieringen & van der Pas (2025).
The horseshoe prior is a fully Bayesian global–local shrinkage
prior, where both the global and local shrinkage parameters are assigned
half-Cauchy distributions with scale hyperparameters global_hp and
local_hp, respectively. The global shrinkage parameter is defined
separately for each tree, allowing adaptive regularization per tree.
The horseshoe_fw prior (forest-wide horseshoe) is similar to
horseshoe, except that the global shrinkage parameter is shared
across all trees in the forest simultaneously.
The half-cauchy prior considers only local shrinkage and does not
include a global shrinkage component. It places a half-Cauchy prior on each
local shrinkage parameter with scale hyperparameter local_hp.
The standard prior (Chipman, George & McCulloch, 2010) corresponds to
the classical BART specification, where step heights are given a normal
prior with variance scaled by the number of trees. This prior does not
introduce a global shrinkage parameter and does not use global–local
structure.
The dirichlet prior implements the Dirichlet–Sparse splitting rule of
Linero (2018), in which splitting probabilities follow a Dirichlet prior
whose concentration is controlled by a Beta sparsity parameter
(a_dirichlet, b_dirichlet) and an expected sparsity level
rho_dirichlet.
An S3 object of class "ShrinkageTrees" containing the following elements:
Vector of posterior mean predictions on the training data.
Vector of posterior mean predictions on the test
data (or on mean covariate vector if X_test not provided).
Vector of posterior samples of the error standard deviation.
Average acceptance ratio across trees during sampling.
Matrix of posterior samples of training
predictions (iterations in rows, observations in columns). Present only if
store_posterior_sample = TRUE.
Matrix of posterior samples of test
predictions. Present only if store_posterior_sample = TRUE.
Vector of posterior mean probabilities on the
training data (only for outcome_type = "binary").
Vector of posterior mean probabilities on the
test data (only for outcome_type = "binary").
Matrix of posterior samples of training
probabilities (only for outcome_type = "binary" and if
store_posterior_sample = TRUE).
Matrix of posterior samples of test
probabilities (only for outcome_type = "binary" and if
store_posterior_sample = TRUE).
Jacobs, T., van Wieringen, W. N., & van der Pas, S. L. (2025). Horseshoe Forests for High-Dimensional Causal Survival Analysis. arXiv:2507.22004. https://doi.org/10.48550/arXiv.2507.22004 Chipman, H. A., George, E. I., & McCulloch, R. E. (2010). BART: Bayesian additive regression trees. Annals of Applied Statistics.
Linero, A. R. (2018). Bayesian regression trees for high-dimensional prediction and variable selection. Journal of the American Statistical Association.
Model family: HorseTrees (horseshoe prior),
CausalHorseForest (causal inference),
CausalShrinkageForest (causal, flexible prior).
Survival wrappers: SurvivalBART, SurvivalDART.
S3 methods: print.ShrinkageTrees,
summary.ShrinkageTrees,
predict.ShrinkageTrees,
plot.ShrinkageTrees.
# Example: Continuous outcome with ShrinkageTrees, two priors n <- 50 p <- 3 X <- matrix(runif(n * p), ncol = p) X_test <- matrix(runif(n * p), ncol = p) y <- X[, 1] + rnorm(n) # Fit ShrinkageTrees with standard horseshoe prior fit_horseshoe <- ShrinkageTrees(y = y, X_train = X, X_test = X_test, outcome_type = "continuous", number_of_trees = 5, prior_type = "horseshoe", local_hp = 0.1 / sqrt(5), global_hp = 0.1 / sqrt(5), N_post = 10, N_burn = 5, store_posterior_sample = TRUE, verbose = FALSE) # Fit ShrinkageTrees with half-Cauchy prior fit_halfcauchy <- ShrinkageTrees(y = y, X_train = X, X_test = X_test, outcome_type = "continuous", number_of_trees = 5, prior_type = "half-cauchy", local_hp = 1 / sqrt(5), N_post = 10, N_burn = 5, store_posterior_sample = TRUE, verbose = FALSE) # Posterior mean predictions pred_horseshoe <- colMeans(fit_horseshoe$train_predictions_sample) pred_halfcauchy <- colMeans(fit_halfcauchy$train_predictions_sample) # Posteriors of the mean (global average prediction) post_mean_horseshoe <- rowMeans(fit_horseshoe$train_predictions_sample) post_mean_halfcauchy <- rowMeans(fit_halfcauchy$train_predictions_sample) # Posterior mean prediction averages mean_pred_horseshoe <- mean(post_mean_horseshoe) mean_pred_halfcauchy <- mean(post_mean_halfcauchy) # Example: Interval-censored survival outcome n <- 50; p <- 3 X_ic <- matrix(rnorm(n * p), ncol = p) true_t <- rexp(n, rate = exp(-X_ic[, 1])) left_t <- true_t * runif(n, 0.5, 1) right_t <- true_t * runif(n, 1, 1.5) exact <- sample(n, 15) left_t[exact] <- true_t[exact]; right_t[exact] <- true_t[exact] rc <- sample(setdiff(seq_len(n), exact), 10); right_t[rc] <- Inf fit_ic <- ShrinkageTrees(left_time = left_t, right_time = right_t, X_train = X_ic, outcome_type = "interval-censored", prior_type = "horseshoe", local_hp = 0.1 / sqrt(5), global_hp = 0.1 / sqrt(5), number_of_trees = 5, N_post = 10, N_burn = 5, verbose = FALSE)# Example: Continuous outcome with ShrinkageTrees, two priors n <- 50 p <- 3 X <- 
matrix(runif(n * p), ncol = p) X_test <- matrix(runif(n * p), ncol = p) y <- X[, 1] + rnorm(n) # Fit ShrinkageTrees with standard horseshoe prior fit_horseshoe <- ShrinkageTrees(y = y, X_train = X, X_test = X_test, outcome_type = "continuous", number_of_trees = 5, prior_type = "horseshoe", local_hp = 0.1 / sqrt(5), global_hp = 0.1 / sqrt(5), N_post = 10, N_burn = 5, store_posterior_sample = TRUE, verbose = FALSE) # Fit ShrinkageTrees with half-Cauchy prior fit_halfcauchy <- ShrinkageTrees(y = y, X_train = X, X_test = X_test, outcome_type = "continuous", number_of_trees = 5, prior_type = "half-cauchy", local_hp = 1 / sqrt(5), N_post = 10, N_burn = 5, store_posterior_sample = TRUE, verbose = FALSE) # Posterior mean predictions pred_horseshoe <- colMeans(fit_horseshoe$train_predictions_sample) pred_halfcauchy <- colMeans(fit_halfcauchy$train_predictions_sample) # Posteriors of the mean (global average prediction) post_mean_horseshoe <- rowMeans(fit_horseshoe$train_predictions_sample) post_mean_halfcauchy <- rowMeans(fit_halfcauchy$train_predictions_sample) # Posterior mean prediction averages mean_pred_horseshoe <- mean(post_mean_horseshoe) mean_pred_halfcauchy <- mean(post_mean_halfcauchy) # Example: Interval-censored survival outcome n <- 50; p <- 3 X_ic <- matrix(rnorm(n * p), ncol = p) true_t <- rexp(n, rate = exp(-X_ic[, 1])) left_t <- true_t * runif(n, 0.5, 1) right_t <- true_t * runif(n, 1, 1.5) exact <- sample(n, 15) left_t[exact] <- true_t[exact]; right_t[exact] <- true_t[exact] rc <- sample(setdiff(seq_len(n), exact), 10); right_t[rc] <- Inf fit_ic <- ShrinkageTrees(left_time = left_t, right_time = right_t, X_train = X_ic, outcome_type = "interval-censored", prior_type = "horseshoe", local_hp = 0.1 / sqrt(5), global_hp = 0.1 / sqrt(5), number_of_trees = 5, N_post = 10, N_burn = 5, verbose = FALSE)
Returns an inspectable list with treatment effect estimates, prognostic function summaries, posterior sigma, variable importance for each forest, and MCMC diagnostics.
## S3 method for class 'CausalShrinkageForest' summary(object, bayesian_bootstrap = TRUE, ...)## S3 method for class 'CausalShrinkageForest' summary(object, bayesian_bootstrap = TRUE, ...)
object |
A fitted |
bayesian_bootstrap |
Logical; if |
... |
Currently unused. |
A summary.CausalShrinkageForest object with elements:
The original model call.
Outcome type.
Timescale for survival outcomes.
Prior specification for control and treatment forests.
MCMC settings.
Training and test data dimensions.
List with ate (posterior mean ATE),
cate_sd (SD of individual CATEs), and optionally
ate_lower, ate_upper (95 percent credible interval;
requires store_posterior_sample = TRUE) and
bayesian_bootstrap (the flag used to produce the CI).
Summary of the prognostic function (mean, SD, range).
Named vector with posterior mean, SD, and 95 percent credible interval of sigma (if estimated).
Variable importance for the control forest (if available).
Variable importance for the treatment forest (if available).
List with acceptance ratios for each forest.
print.summary.CausalShrinkageForest,
CausalHorseForest, CausalShrinkageForest
Returns distributional summaries (min, Q1, median, max) of the posterior mean predictions and credible interval bounds across all test observations, separately for the prognostic function, CATE, and total outcome.
## S3 method for class 'CausalShrinkageForestPrediction' summary(object, ...)## S3 method for class 'CausalShrinkageForestPrediction' summary(object, ...)
object |
A |
... |
Currently unused. |
A summary.CausalShrinkageForestPrediction object.
predict.CausalShrinkageForest,
print.summary.CausalShrinkageForestPrediction
Returns an inspectable list with posterior sigma summaries, prediction summaries, variable importance (posterior inclusion probabilities), and MCMC diagnostics.
## S3 method for class 'ShrinkageTrees' summary(object, ...)## S3 method for class 'ShrinkageTrees' summary(object, ...)
object |
A fitted |
... |
Currently unused. |
A summary.ShrinkageTrees object with elements:
The original model call.
Outcome type ("continuous", "binary",
"right-censored", or "interval-censored").
Timescale for survival outcomes ("time" or
"log").
Prior specification.
MCMC settings (trees, draws, burn-in).
Training and test data dimensions.
Named vector with posterior mean, SD, and 95 percent credible interval of sigma (continuous and survival outcomes only).
List with train (and optionally test)
prediction summaries (mean, SD, range).
Named vector of posterior inclusion probabilities, sorted decreasingly (if available).
MCMC acceptance ratio vector.
(When coda is installed) A list with
ess (effective sample size) and, for multi-chain fits,
rhat (Gelman–Rubin $\hat{R}$).
print.summary.ShrinkageTrees,
as.mcmc.list.ShrinkageTrees,
HorseTrees, ShrinkageTrees
Returns distributional summaries (min, Q1, median, max) of the posterior mean predictions and credible interval bounds across all observations.
## S3 method for class 'ShrinkageTreesPrediction' summary(object, ...)## S3 method for class 'ShrinkageTreesPrediction' summary(object, ...)
object |
A |
... |
Currently unused. |
A summary.ShrinkageTreesPrediction object.
predict.ShrinkageTrees,
print.summary.ShrinkageTreesPrediction
Fits an Accelerated Failure Time (AFT) model using the classical
Bayesian Additive Regression Trees (BART) prior:
$\log(T) = f(x) + \varepsilon$, with $\varepsilon \sim \mathcal{N}(0, \sigma^2)$.
Supports both right-censored and interval-censored survival outcomes.
SurvivalBART( time = NULL, status = NULL, X_train, X_test = NULL, timescale = "time", number_of_trees = 200, k = 2, N_post = 1000, N_burn = 1000, store_posterior_sample = TRUE, verbose = TRUE, left_time = NULL, right_time = NULL, ... )SurvivalBART( time = NULL, status = NULL, X_train, X_test = NULL, timescale = "time", number_of_trees = 200, k = 2, N_post = 1000, N_burn = 1000, store_posterior_sample = TRUE, verbose = TRUE, left_time = NULL, right_time = NULL, ... )
time |
Outcome vector of (non-negative) survival times. Required for
right-censored outcomes; set to |
status |
Event indicator (1 = event, 0 = censored). Required for right-censored outcomes; derived automatically for interval censoring. |
X_train |
Design matrix for training data. |
X_test |
Optional test matrix. If NULL, predictions are computed at
the column means of |
timescale |
Either |
number_of_trees |
Number of trees in the ensemble. Default is 200. |
k |
Scaling constant used to calibrate the prior variance of the step heights. |
N_post |
Number of posterior samples to store. |
N_burn |
Number of burn-in iterations. |
store_posterior_sample |
Logical; if |
verbose |
Logical; print sampling progress. |
left_time |
Optional numeric vector of left (lower) time boundaries
for interval-censored data. Exact events have
|
right_time |
Optional numeric vector of right (upper) time boundaries.
Use |
... |
Additional arguments passed to |
This function provides a survival-specific interface for classical BART under an AFT formulation for right-censored or interval-censored outcomes.
For right-censored data, supply time and status.
For interval-censored data, supply left_time and right_time
instead; event indicators are derived internally following the
survival::Surv(type = "interval2") convention.
Structural regularisation is induced through the standard Gaussian leaf prior and tree depth prior of Chipman, George & McCulloch (2010).
Users requiring alternative shrinkage priors (e.g., Horseshoe or
Dirichlet splitting priors) should use ShrinkageTrees
directly.
An object of class "ShrinkageTrees" fitted under a classical
BART prior within an AFT formulation.
See ShrinkageTrees for a full description of returned components.
Chipman, H. A., George, E. I., & McCulloch, R. E. (2010). Bayesian Additive Regression Trees. Annals of Applied Statistics.
Related models: SurvivalDART (Dirichlet sparsity),
HorseTrees (horseshoe prior),
ShrinkageTrees (general shrinkage priors).
S3 methods: print.ShrinkageTrees,
summary.ShrinkageTrees,
predict.ShrinkageTrees,
plot.ShrinkageTrees.
set.seed(1) n <- 30; p <- 5 X <- matrix(rnorm(n * p), ncol = p) time <- rexp(n, rate = exp(0.5 * X[, 1])) status <- rbinom(n, 1, 0.7) fit <- SurvivalBART(time = time, status = status, X_train = X, number_of_trees = 5, N_post = 50, N_burn = 25, verbose = FALSE) # S3 methods print(fit) smry <- summary(fit) # Posterior predictions on new data X_new <- matrix(rnorm(10 * p), ncol = p) pred <- predict(fit, newdata = X_new) print(pred) # Diagnostic plot (requires ggplot2) if (requireNamespace("ggplot2", quietly = TRUE)) { plot(fit, type = "trace") # Posterior survival curves for training data plot(fit, type = "survival") # Posterior predictive survival curves for new data plot(pred, type = "survival") plot(pred, type = "survival", obs = c(1, 5)) } # Interval-censored example set.seed(11) n <- 30; p <- 5 X <- matrix(rnorm(n * p), ncol = p) true_t <- rexp(n, rate = exp(0.5 * X[, 1])) left_t <- true_t * runif(n, 0.5, 1) right_t <- true_t * runif(n, 1, 1.5) exact <- sample(n, 10); left_t[exact] <- true_t[exact]; right_t[exact] <- true_t[exact] rc <- sample(setdiff(seq_len(n), exact), 5); right_t[rc] <- Inf fit_ic <- SurvivalBART(left_time = left_t, right_time = right_t, X_train = X, number_of_trees = 5, N_post = 50, N_burn = 25, verbose = FALSE)set.seed(1) n <- 30; p <- 5 X <- matrix(rnorm(n * p), ncol = p) time <- rexp(n, rate = exp(0.5 * X[, 1])) status <- rbinom(n, 1, 0.7) fit <- SurvivalBART(time = time, status = status, X_train = X, number_of_trees = 5, N_post = 50, N_burn = 25, verbose = FALSE) # S3 methods print(fit) smry <- summary(fit) # Posterior predictions on new data X_new <- matrix(rnorm(10 * p), ncol = p) pred <- predict(fit, newdata = X_new) print(pred) # Diagnostic plot (requires ggplot2) if (requireNamespace("ggplot2", quietly = TRUE)) { plot(fit, type = "trace") # Posterior survival curves for training data plot(fit, type = "survival") # Posterior predictive survival curves for new data plot(pred, type = "survival") plot(pred, type = "survival", obs = c(1, 
5)) } # Interval-censored example set.seed(11) n <- 30; p <- 5 X <- matrix(rnorm(n * p), ncol = p) true_t <- rexp(n, rate = exp(0.5 * X[, 1])) left_t <- true_t * runif(n, 0.5, 1) right_t <- true_t * runif(n, 1, 1.5) exact <- sample(n, 10); left_t[exact] <- true_t[exact]; right_t[exact] <- true_t[exact] rc <- sample(setdiff(seq_len(n), exact), 5); right_t[rc] <- Inf fit_ic <- SurvivalBART(left_time = left_t, right_time = right_t, X_train = X, number_of_trees = 5, N_post = 50, N_burn = 25, verbose = FALSE)
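Fits an Accelerated Failure Time (AFT) version of Bayesian Causal Forest (BCF):
$\log(T) = \mu(x) + Z\,\tau(x) + \varepsilon$, where $Z$ is the treatment
indicator (coded according to treatment_coding) and separate forests are used
for the prognostic (control) function $\mu(x)$ and the treatment effect
function $\tau(x)$.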
SurvivalBCF( time = NULL, status = NULL, X_train, treatment, timescale = "time", propensity = NULL, treatment_coding = "centered", number_of_trees_control = 200, number_of_trees_treat = 50, power_control = 2, base_control = 0.95, power_treat = 3, base_treat = 0.25, N_post = 1000, N_burn = 1000, store_posterior_sample = TRUE, verbose = TRUE, left_time = NULL, right_time = NULL, ... )SurvivalBCF( time = NULL, status = NULL, X_train, treatment, timescale = "time", propensity = NULL, treatment_coding = "centered", number_of_trees_control = 200, number_of_trees_treat = 50, power_control = 2, base_control = 0.95, power_treat = 3, base_treat = 0.25, N_post = 1000, N_burn = 1000, store_posterior_sample = TRUE, verbose = TRUE, left_time = NULL, right_time = NULL, ... )
time |
Outcome vector of (non-negative) survival times. Required for
right-censored outcomes; set to |
status |
Event indicator (1 = event, 0 = censored). Required for right-censored outcomes; derived automatically for interval censoring. |
X_train |
Design matrix for training data. |
treatment |
Treatment indicator (0/1) for training data. |
timescale |
Either |
propensity |
Optional vector of propensity scores. If provided,
it is appended to the control forest design matrix. Required when
|
treatment_coding |
Character string specifying how the treatment
indicator enters the model. One of |
number_of_trees_control |
Number of trees in the control forest. Default is 200. |
number_of_trees_treat |
Number of trees in the treatment forest. Default is 50. |
power_control, base_control
|
Tree-structure prior parameters for the control forest. |
power_treat, base_treat
|
Tree-structure prior parameters for the treatment forest. |
N_post |
Number of posterior samples to store. |
N_burn |
Number of burn-in iterations. |
store_posterior_sample |
Logical; if |
verbose |
Logical; print sampling progress. |
left_time |
Optional numeric vector of left (lower) time boundaries
for interval-censored data. Exact events have
|
right_time |
Optional numeric vector of right (upper) time boundaries.
Use |
... |
Additional arguments passed to |
This wrapper provides a survival-specific implementation using classical BART-style priors for both forests. Supports both right-censored and interval-censored survival outcomes.
This function implements a simplified AFT-BCF model for right-censored or interval-censored survival outcomes. Structural regularisation is induced through classical BART priors on the tree structure and leaf parameters.
For right-censored data, supply time and status.
For interval-censored data, supply left_time and right_time
instead; event indicators are derived internally following the
survival::Surv(type = "interval2") convention.
Users requiring alternative shrinkage priors (e.g., Horseshoe or Dirichlet
splitting priors) should use SurvivalShrinkageBCF or call
CausalShrinkageForest directly.
An object of class "CausalShrinkageForest" corresponding to a
survival BCF model under classical BART priors.
See CausalShrinkageForest for returned components.
Hahn, P. R., Murray, J. S., & Carvalho, C. M. (2020). Bayesian regression tree models for causal inference: Regularization, confounding, and heterogeneous effects. Bayesian Analysis.
Related models: SurvivalShrinkageBCF (Dirichlet sparsity),
CausalHorseForest (horseshoe prior),
CausalShrinkageForest (general shrinkage priors).
S3 methods: print.CausalShrinkageForest,
summary.CausalShrinkageForest,
predict.CausalShrinkageForest,
plot.CausalShrinkageForest.
set.seed(3) n <- 30; p <- 5 X <- matrix(rnorm(n * p), ncol = p) treatment <- rbinom(n, 1, 0.5) log_T <- X[, 1] + treatment * (-0.5) + rnorm(n) time <- exp(log_T) status <- rbinom(n, 1, 0.7) fit <- SurvivalBCF(time = time, status = status, X_train = X, treatment = treatment, number_of_trees_control = 5, number_of_trees_treat = 5, N_post = 50, N_burn = 25, verbose = FALSE) # S3 methods print(fit) smry <- summary(fit) # Posterior ATE cat("ATE:", round(smry$treatment_effect$ate, 3), "\n") # Diagnostic and treatment-effect plots (requires ggplot2) if (requireNamespace("ggplot2", quietly = TRUE)) { plot(fit, type = "trace") plot(fit, type = "ate") plot(fit, type = "cate") } # Interval-censored causal example set.seed(13) n <- 30; p <- 5 X <- matrix(rnorm(n * p), ncol = p) treatment <- rbinom(n, 1, 0.5) true_t <- exp(X[, 1] + treatment * (-0.5) + rnorm(n)) left_t <- true_t * runif(n, 0.5, 1) right_t <- true_t * runif(n, 1, 1.5) exact <- sample(n, 10); left_t[exact] <- true_t[exact]; right_t[exact] <- true_t[exact] rc <- sample(setdiff(seq_len(n), exact), 5); right_t[rc] <- Inf fit_ic <- SurvivalBCF(left_time = left_t, right_time = right_t, X_train = X, treatment = treatment, number_of_trees_control = 5, number_of_trees_treat = 5, N_post = 50, N_burn = 25, verbose = FALSE)set.seed(3) n <- 30; p <- 5 X <- matrix(rnorm(n * p), ncol = p) treatment <- rbinom(n, 1, 0.5) log_T <- X[, 1] + treatment * (-0.5) + rnorm(n) time <- exp(log_T) status <- rbinom(n, 1, 0.7) fit <- SurvivalBCF(time = time, status = status, X_train = X, treatment = treatment, number_of_trees_control = 5, number_of_trees_treat = 5, N_post = 50, N_burn = 25, verbose = FALSE) # S3 methods print(fit) smry <- summary(fit) # Posterior ATE cat("ATE:", round(smry$treatment_effect$ate, 3), "\n") # Diagnostic and treatment-effect plots (requires ggplot2) if (requireNamespace("ggplot2", quietly = TRUE)) { plot(fit, type = "trace") plot(fit, type = "ate") plot(fit, type = "cate") } # Interval-censored causal example 
set.seed(13) n <- 30; p <- 5 X <- matrix(rnorm(n * p), ncol = p) treatment <- rbinom(n, 1, 0.5) true_t <- exp(X[, 1] + treatment * (-0.5) + rnorm(n)) left_t <- true_t * runif(n, 0.5, 1) right_t <- true_t * runif(n, 1, 1.5) exact <- sample(n, 10); left_t[exact] <- true_t[exact]; right_t[exact] <- true_t[exact] rc <- sample(setdiff(seq_len(n), exact), 5); right_t[rc] <- Inf fit_ic <- SurvivalBCF(left_time = left_t, right_time = right_t, X_train = X, treatment = treatment, number_of_trees_control = 5, number_of_trees_treat = 5, N_post = 50, N_burn = 25, verbose = FALSE)
Fits an Accelerated Failure Time (AFT) model using the Dirichlet splitting prior (DART), which induces structural sparsity through a Beta-Dirichlet hierarchy on splitting probabilities. Supports both right-censored and interval-censored survival outcomes.
SurvivalDART( time = NULL, status = NULL, X_train, X_test = NULL, timescale = "time", number_of_trees = 200, a_dirichlet = 0.5, b_dirichlet = 1, rho_dirichlet = NULL, k = 2, N_post = 1000, N_burn = 1000, store_posterior_sample = TRUE, verbose = TRUE, left_time = NULL, right_time = NULL, ... )SurvivalDART( time = NULL, status = NULL, X_train, X_test = NULL, timescale = "time", number_of_trees = 200, a_dirichlet = 0.5, b_dirichlet = 1, rho_dirichlet = NULL, k = 2, N_post = 1000, N_burn = 1000, store_posterior_sample = TRUE, verbose = TRUE, left_time = NULL, right_time = NULL, ... )
time |
Outcome vector of (non-negative) survival times. Required for
right-censored outcomes; set to |
status |
Event indicator (1 = event, 0 = censored). Required for right-censored outcomes; derived automatically for interval censoring. |
X_train |
Design matrix for training data. |
X_test |
Optional test matrix. If NULL, predictions are computed at
the column means of |
timescale |
Either |
number_of_trees |
Number of trees in the ensemble. Default is 200. |
a_dirichlet, b_dirichlet
|
Beta hyperparameters controlling sparsity in the Dirichlet splitting rule. |
rho_dirichlet |
Expected number of active predictors. If NULL,
defaults to the number of covariates in |
k |
Scaling constant used to calibrate the prior variance of the step heights. |
N_post |
Number of posterior samples to store. |
N_burn |
Number of burn-in iterations. |
store_posterior_sample |
Logical; if |
verbose |
Logical; print sampling progress. |
left_time |
Optional numeric vector of left (lower) time boundaries
for interval-censored data. Exact events have
|
right_time |
Optional numeric vector of right (upper) time boundaries.
Use |
... |
Additional arguments passed to |
This function provides a survival-specific wrapper for DART under an AFT formulation for right-censored or interval-censored outcomes.
For right-censored data, supply time and status.
For interval-censored data, supply left_time and right_time
instead; event indicators are derived internally following the
survival::Surv(type = "interval2") convention.
Structural regularisation is induced through a Dirichlet prior on splitting probabilities, encouraging sparse feature usage in high-dimensional settings.
Users requiring alternative shrinkage priors on the leaf parameters
(e.g., Horseshoe or half-Cauchy priors) should use
ShrinkageTrees directly.
An object of class "ShrinkageTrees" fitted under a Dirichlet
splitting prior (DART) within an AFT formulation.
See ShrinkageTrees for a full description of returned components.
Related models: SurvivalBART (standard BART prior),
ShrinkageTrees (general shrinkage priors).
S3 methods: print.ShrinkageTrees,
summary.ShrinkageTrees,
predict.ShrinkageTrees,
plot.ShrinkageTrees.
# Right-censored example: survival times depend on the first covariate;
# event indicators are drawn independently (about 70% events).
set.seed(2)
n <- 30; p <- 5
X <- matrix(rnorm(n * p), ncol = p)
time <- rexp(n, rate = exp(0.5 * X[, 1]))
status <- rbinom(n, 1, 0.7)
fit <- SurvivalDART(time = time, status = status, X_train = X,
                    number_of_trees = 5, N_post = 50, N_burn = 25,
                    verbose = FALSE)

# S3 methods
print(fit)
smry <- summary(fit)

# Posterior predictions on new data
X_new <- matrix(rnorm(10 * p), ncol = p)
pred <- predict(fit, newdata = X_new)
print(pred)

# Variable importance and survival plots (requires ggplot2)
if (requireNamespace("ggplot2", quietly = TRUE)) {
  plot(fit, type = "vi", n_vi = 5)
  # Posterior survival curves for training data
  plot(fit, type = "survival")
  # Posterior predictive survival curves for new data
  plot(pred, type = "survival")
  plot(pred, type = "survival", obs = c(1, 5))
}

# Interval-censored example
set.seed(12)
n <- 30; p <- 5
X <- matrix(rnorm(n * p), ncol = p)
true_t <- rexp(n, rate = exp(0.5 * X[, 1]))
# Bracket each true time in an interval [left_t, right_t]
left_t <- true_t * runif(n, 0.5, 1)
right_t <- true_t * runif(n, 1, 1.5)
# 10 exactly observed events: both boundaries equal the true time
exact <- sample(n, 10)
left_t[exact] <- true_t[exact]
right_t[exact] <- true_t[exact]
# 5 right-censored observations: open-ended upper boundary
rc <- sample(setdiff(seq_len(n), exact), 5)
right_t[rc] <- Inf
fit_ic <- SurvivalDART(left_time = left_t, right_time = right_t,
                       X_train = X, number_of_trees = 5,
                       N_post = 50, N_burn = 25, verbose = FALSE)
Fits a survival version of a Bayesian Causal Forest (BCF) under an accelerated failure time (AFT) model, combining Dirichlet splitting priors with global-local shrinkage. Supports both right-censored and interval-censored survival outcomes.
# Supply either `time` + `status` (right-censored outcomes) or
# `left_time` + `right_time` (interval-censored outcomes).
SurvivalShrinkageBCF(
  time = NULL,
  status = NULL,
  X_train,
  treatment,
  timescale = "time",
  propensity = NULL,
  treatment_coding = "centered",
  a_dir = 0.5,
  b_dir = 1,
  number_of_trees_control = 200,
  number_of_trees_treat = 50,
  power_control = 2,
  base_control = 0.95,
  power_treat = 3,
  base_treat = 0.25,
  N_post = 1000,
  N_burn = 1000,
  store_posterior_sample = TRUE,
  verbose = TRUE,
  left_time = NULL,
  right_time = NULL,
  ...
)
time |
Outcome vector of (non-negative) survival times. Required for
right-censored outcomes; leave as NULL (the default) when supplying interval-censored data via left_time and right_time instead. |
status |
Event indicator (1 = event, 0 = censored). Required for right-censored outcomes; derived automatically for interval censoring. |
X_train |
Design matrix for training data. |
treatment |
Treatment indicator (0/1) for training data. |
timescale |
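Either "time" (the default) or "log", indicating the scale on which the supplied survival times are recorded. |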
propensity |
Optional vector of propensity scores. If provided,
it is appended to the control forest design matrix. Required when
|
treatment_coding |
Character string specifying how the treatment
indicator enters the model. One of |
a_dir |
First shape parameter of the Beta prior controlling the sparsity level in the Dirichlet splitting rule. |
b_dir |
Second shape parameter of the Beta prior controlling the sparsity level in the Dirichlet splitting rule. |
number_of_trees_control |
Number of trees in the control forest. Default is 200. |
number_of_trees_treat |
Number of trees in the treatment forest. Default is 50. |
power_control, base_control
|
Tree-structure prior parameters for the control forest. |
power_treat, base_treat
|
Tree-structure prior parameters for the treatment forest. |
N_post |
Number of posterior samples to store. |
N_burn |
Number of burn-in iterations. |
store_posterior_sample |
Logical; if TRUE (the default), the posterior draws are stored in the returned object. |
verbose |
Logical; print sampling progress. |
left_time |
Optional numeric vector of left (lower) time boundaries
for interval-censored data. Exact events have
left_time equal to right_time. |
right_time |
Optional numeric vector of right (upper) time boundaries.
Use Inf for right-censored observations. |
... |
Additional arguments passed to |
This wrapper extends SurvivalBCF by incorporating
Dirichlet sparsity in both the prognostic (control) and treatment
forests, while applying additional shrinkage to the control forest
via a half-Cauchy prior.
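The SurvivalShrinkageBCF model decomposes the (log) survival time as
$$\log T_i = \mu(x_i) + W_i\,\tau(x_i) + \varepsilon_i,$$
where $\mu(x)$ represents the prognostic (control) component,
$\tau(x)$ the heterogeneous treatment effect, $W_i$ the (coded) treatment
indicator, and $\varepsilon_i$ the residual error.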
In contrast to SurvivalBCF, this function:
Applies a Dirichlet splitting prior to both forests, inducing structural sparsity in variable selection.
Combines Dirichlet sparsity with additional half-Cauchy shrinkage in the control forest.
The Dirichlet prior follows the sparse splitting framework of Linero (2018),
where splitting probabilities are governed by a Beta-Dirichlet hierarchy.
The sparsity level is controlled by a_dir and b_dir.
Survival outcomes are modeled using an AFT formulation with censoring
handled via data augmentation. Both right-censored and interval-censored
data are supported. For interval-censored data, supply left_time
and right_time instead of time and status.
An object of class "CausalShrinkageForest" fitted with
Dirichlet splitting priors and additional shrinkage.
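Caron, A., Baio, G., & Manolopoulou, I. (2022). Shrinkage Bayesian Causal Forests for Heterogeneous Treatment Effects Estimation. Journal of Computational and Graphical Statistics, 31(4), 1202–1214. https://doi.org/10.1080/10618600.2022.2067549
Linero, A. R. (2018). Bayesian Regression Trees for High-Dimensional Prediction and Variable Selection. Journal of the American Statistical Association, 113(522), 626–636. https://doi.org/10.1080/01621459.2016.1264957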
Related models: SurvivalBCF (standard BART priors),
CausalShrinkageForest (general shrinkage priors),
CausalHorseForest (horseshoe prior).
S3 methods: print.CausalShrinkageForest,
summary.CausalShrinkageForest,
predict.CausalShrinkageForest,
plot.CausalShrinkageForest.
# Right-censored example: log survival times follow an AFT model with a
# prognostic effect of X[, 1] and a constant treatment effect of -0.5;
# event indicators are drawn independently (about 70% events).
set.seed(4)
n <- 30; p <- 5
X <- matrix(rnorm(n * p), ncol = p)
treatment <- rbinom(n, 1, 0.5)
log_T <- X[, 1] + treatment * (-0.5) + rnorm(n)
time <- exp(log_T)
status <- rbinom(n, 1, 0.7)
fit <- SurvivalShrinkageBCF(time = time, status = status, X_train = X,
                            treatment = treatment,
                            number_of_trees_control = 5,
                            number_of_trees_treat = 5,
                            N_post = 50, N_burn = 25, verbose = FALSE)

# S3 methods
print(fit)
smry <- summary(fit)

# Posterior ATE with 95% credible interval
cat("ATE:", round(smry$treatment_effect$ate, 3), "\n")

# Diagnostic and treatment-effect plots (requires ggplot2)
if (requireNamespace("ggplot2", quietly = TRUE)) {
  plot(fit, type = "trace")
  plot(fit, type = "cate")
}

# Interval-censored causal example
set.seed(14)
n <- 30; p <- 5
X <- matrix(rnorm(n * p), ncol = p)
treatment <- rbinom(n, 1, 0.5)
true_t <- exp(X[, 1] + treatment * (-0.5) + rnorm(n))
# Bracket each true time in an interval [left_t, right_t]
left_t <- true_t * runif(n, 0.5, 1)
right_t <- true_t * runif(n, 1, 1.5)
# 10 exactly observed events: both boundaries equal the true time
exact <- sample(n, 10)
left_t[exact] <- true_t[exact]
right_t[exact] <- true_t[exact]
# 5 right-censored observations: open-ended upper boundary
rc <- sample(setdiff(seq_len(n), exact), 5)
right_t[rc] <- Inf
fit_ic <- SurvivalShrinkageBCF(left_time = left_t, right_time = right_t,
                               X_train = X, treatment = treatment,
                               number_of_trees_control = 5,
                               number_of_trees_treat = 5,
                               N_post = 50, N_burn = 25, verbose = FALSE)