interpretable random forests

Random forests are powerful but often opaque. The standard “variable importance” plot — showing mean decrease in accuracy or node impurity per predictor — summarizes the entire forest at once. But what happens as the forest grows?
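
For reference, the whole-forest summary described above can be produced in a few lines. This is a minimal sketch using `mtcars` with an arbitrary seed and tree count; for a regression forest, `randomForest` reports the two measures as `%IncMSE` (permutation-based) and `IncNodePurity` (impurity-based).

```r
library(randomForest)

set.seed(1)  # arbitrary seed, only for reproducibility
rf <- randomForest(mpg ~ ., data = mtcars, importance = TRUE, ntree = 500)

# matrix with one row per predictor and one column per importance measure
imp <- importance(rf)

# predictors sorted by permutation importance, largest first
imp[order(imp[, "%IncMSE"], decreasing = TRUE), ]

# the standard dot chart of both measures for the finished forest
varImpPlot(rf)
```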

I wanted to see variable importance evolve tree by tree:

[Embedded post on X showing the animation]

This kind of animation reveals something important: how quickly variable importance rankings stabilize. If rankings settle early, the ordering does not depend much on how many trees were grown, and it is safer to interpret. If they keep shifting, the ranking may be unreliable for drawing scientific conclusions.

In the example above, one predictor clearly dominates, so the ranking stabilizes fast. But with more evenly matched predictors, the randomness of bagging and feature selection at each split could cause the top-ranked variable to fluctuate — both as a single forest grows and across independently trained forests.
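
A quick way to check the second kind of fluctuation is to train several forests that differ only in their random seed and count how often each predictor comes out on top. The sketch below is an illustration, not part of the original analysis: the helper name `top_var`, the 25 seeds, and the 100 trees per forest are all arbitrary choices of mine.

```r
library(randomForest)

# hypothetical helper: name of the top-ranked predictor (by %IncMSE)
# in a forest trained with the given seed
top_var <- function(seed, ntree = 100) {
  set.seed(seed)
  fit <- randomForest(mpg ~ ., data = mtcars, importance = TRUE, ntree = ntree)
  imp <- importance(fit)[, "%IncMSE"]
  names(imp)[which.max(imp)]
}

# train 25 independent forests and tabulate the winner of each
tops <- vapply(1:25, top_var, character(1))
table(tops)
```

If one name accounts for nearly every forest, the top rank is stable across runs; a table split between several names suggests the predictors are too closely matched to rank confidently at this forest size.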

Ranking stability has been studied in bioinformatics [1] and remote sensing [2], but it deserves wider attention anywhere random forests are used for inference rather than pure prediction.

Below is a minimal R example to reproduce an animation like this:

# minimal example
library(randomForest)
library(ggplot2)
library(dplyr)
library(purrr)
library(colormap)
library(tree)
library(plotrix)
library(cowplot)
library(gridGraphics)
library(magick)
library(forcats)
# source the reprtree package for plotting individual trees generated by RF
# clone repo at https://github.com/richpauloo/reprtree
# and change the file path below to the directory containing the R files
invisible(
  lapply(
    list.files("/Users/richpauloo/GitHub/reprtree/R", full.names = TRUE),
    source
  )
)
# example data == mtcars
df <- mtcars
# number of trees to grow in random forest
nn <- 500
# function to fit one single-tree random forest (ntree = 1) for a given seed;
# it is called nn times below, once per seed
run_rf <- function(rand_seed){
  set.seed(rand_seed)
  one_tr <- randomForest(mpg ~ .,
                         data = df,
                         importance = TRUE,
                         ntree = 1)
  return(one_tr)
}
# list to store output of each model
l <- lapply(1:nn, run_rf)
# number of predictors in RF mod
npred <- length(names(l[[1]]$forest$xlevels))
# extract importance of each CART model
impdf <- map(l, importance) %>%
  map(as.data.frame) %>%
  map( ~ { .$var = rownames(.); rownames(.) <- NULL; return(.) } ) %>%
  bind_rows() %>%
  mutate(tree_num = rep(1:nn, each = npred)) # add tree number
# summarised var imp
tot_mse <- group_by(impdf, var) %>%
  summarise(`%IncMSE` = mean(`%IncMSE`)) %>%
  arrange(-`%IncMSE`)
# ranked variables
rv <- tot_mse$var
impdf$var <- factor(impdf$var, levels = rv)
# vector of trees to plot
# here I plot every 10 trees for speed, but this can be changed
plt_vec <- c(1, seq(10, nn, 10))
# initialize lists for: varimp, trees, plot titles, and combined plots
pl <- tl <- pt <- bp <- vector("list", length = length(plt_vec))
for(i in seq_along(plt_vec)){
  # cumulative variable importance with each tree's addition
  pl[[i]] <- filter(impdf, tree_num %in% 1:plt_vec[i]) %>%
    group_by(var) %>%
    summarise(mse = mean(`%IncMSE`)) %>%
    ggplot(aes(forcats::fct_rev(var), mse, fill=var)) +
    geom_col() +
    coord_flip(ylim = c(0, max(tot_mse$`%IncMSE`))) +
    scale_fill_viridis_d() +
    labs(x = "Variable", y = "Importance (% Inc MSE)", fill = "Variable",
         title = paste0("Tree ", plt_vec[i])) +
    theme_minimal() +
    theme(legend.position = "bottom",
          plot.title = element_text(size=25))
  # make tree plots
  plot.getTree(l[[plt_vec[i]]], k = 1, npred = npred, rv = rv)
  tl[[i]] <- recordPlot()
  # make plot titles
  pt[[i]] <- ggdraw() +
    draw_label(
      paste0("Tree ", plt_vec[i]),
      fontface = 'bold',
      x = 0,
      hjust = 0
    ) +
    theme(
      # add margin on the left of the drawing canvas,
      # so title is aligned with left edge of first plot
      plot.margin = margin(0, 0, 0, 7)
    )
  # combine the importance plot and the tree plot side by side
  # (the title in pt[[i]] is built above but not added to the grid;
  # the importance plot already carries the "Tree N" title)
  bp[[i]] <- plot_grid(pl[[i]], tl[[i]])
}
# use magick to turn plots into a GIF.
# WARNING: magick doesn't handle hundreds of plots well in my experience
# and it may be better to print them into a single PDF, then render the
# GIF elsewhere. Also beware of temporary files that magick creates...
# Also, this animation may be too large to fit in your viewer, so
# be sure to expand it!
# img <- image_graph(1000, 600, res = 96)
# for(i in seq_along(plt_vec)){ print( bp[[i]] ) }
# dev.off()
# animation <- image_animate(img, fps = 2)
# print(animation)
#
# # save to working directory
# image_write(animation, "anim.gif")
# alternatively, print every frame to a single PDF and make the GIF elsewhere,
# like https://ezgif.com/maker
pdf("all.pdf", width = 12, height = 7)
invisible(lapply(bp, print))
dev.off()
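
The animation shows stabilization visually. As a rough numerical companion, one could track the rank correlation between the running importance after each tree and the final importance. The sketch below is self-contained and follows the same single-tree approach as the code above, but the helper `one_tree_imp`, the smaller `n_trees = 200`, and the choice of Spearman correlation are my own assumptions rather than part of the original code.

```r
library(randomForest)

n_trees <- 200  # hypothetical size, smaller than the 500 used above, for speed

# %IncMSE of one single-tree forest for a given seed (hypothetical helper)
one_tree_imp <- function(seed) {
  set.seed(seed)
  fit <- randomForest(mpg ~ ., data = mtcars, importance = TRUE, ntree = 1)
  importance(fit)[, "%IncMSE"]
}

# rows = predictors, columns = trees
imp_mat <- sapply(seq_len(n_trees), one_tree_imp)

# running mean of importance; the result has one row per tree count
# and one column per predictor
cum_imp <- apply(imp_mat, 1, function(x) cumsum(x) / seq_along(x))

# Spearman correlation between the ranking after k trees and the final ranking
final <- cum_imp[n_trees, ]
rho <- apply(cum_imp, 1, function(x) cor(x, final, method = "spearman"))

plot(seq_len(n_trees), rho, type = "l",
     xlab = "Number of trees",
     ylab = "Spearman correlation with final ranking")
```

A curve that climbs to 1 within the first few dozen trees matches the "settles early" case; a curve that keeps wobbling indicates the ordering is still changing.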

A natural extension would be to plot out-of-bag error alongside importance as trees accumulate. I’d expect both to stabilize together.
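
As a starting point for that extension, a regression forest fit with `randomForest` already stores the out-of-bag error after each tree in its `mse` component, so the error curve needs no refitting. A minimal sketch, again on `mtcars` with an arbitrary seed:

```r
library(randomForest)

set.seed(1)  # arbitrary seed, only for reproducibility
rf <- randomForest(mpg ~ ., data = mtcars, ntree = 500)

# for regression forests, rf$mse[k] is the OOB mean squared error
# of the forest made of the first k trees
oob <- rf$mse

plot(seq_along(oob), oob, type = "l",
     xlab = "Number of trees", ylab = "OOB MSE")
```

Calling `plot(rf)` draws the same curve; extracting `rf$mse` directly makes it easier to place the error next to the cumulative importance bars in each animation frame.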