interpretable random forests

October 30, 2019

The “variable importance” plot, which shows the mean decrease in accuracy or node impurity for each predictor, is a classic tool for interpreting random forest models.
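To make the regression version of this metric (`%IncMSE` in the randomForest package) concrete, here is a minimal sketch that recomputes permutation importance by hand on `mtcars`. It assumes the procedure described in the randomForest documentation: for each tree, record the out-of-bag (OOB) MSE, permute one predictor among that tree's OOB rows, record the MSE again, then average the differences over all trees. `perm_importance()` is a hypothetical helper written only for this illustration. Because the permutations are random, its result should land near, but not exactly on, `importance(rf, type = 1, scale = FALSE)`.

```r
library(randomForest)

set.seed(42)
rf <- randomForest(mpg ~ ., data = mtcars, ntree = 100,
                   importance = TRUE, keep.inbag = TRUE)

# Hypothetical helper (not part of randomForest): permutation importance of
# one predictor, computed per tree on that tree's OOB rows, then averaged.
perm_importance <- function(rf, data, response, predictor) {
  y <- data[[response]]
  X <- data[, setdiff(names(data), response), drop = FALSE]
  deltas <- rep(NA_real_, rf$ntree)
  for (t in seq_len(rf$ntree)) {
    oob <- rf$inbag[, t] == 0            # rows not in tree t's bootstrap sample
    if (!any(oob)) next
    X_oob <- X[oob, , drop = FALSE]
    # predictions of tree t alone on its OOB rows
    pred <- predict(rf, X_oob, predict.all = TRUE)$individual[, t]
    # shuffle the predictor among the OOB rows and predict again
    X_perm <- X_oob
    X_perm[[predictor]] <- X_perm[[predictor]][sample.int(nrow(X_perm))]
    pred_perm <- predict(rf, X_perm, predict.all = TRUE)$individual[, t]
    # increase in OOB MSE after breaking the predictor's link to the response
    deltas[t] <- mean((y[oob] - pred_perm)^2) - mean((y[oob] - pred)^2)
  }
  mean(deltas, na.rm = TRUE)             # aggregate across trees
}

wt_imp  <- perm_importance(rf, mtcars, "mpg", "wt")
raw_imp <- importance(rf, type = 1, scale = FALSE)["wt", 1]
print(c(sketch = wt_imp, randomForest = raw_imp))
```

Indexing with `sample.int()` rather than calling `sample()` on the values avoids R's gotcha where `sample(x)` on a single number samples from `1:x`.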

This metric aggregates the decrease in accuracy (for regression, the increase in mean squared error) across all trees in the forest. I wanted to watch a forest “grow” tree by tree, alongside the cumulative variable importance:

Animations like the one above can help visualize an important property of random forest models: the stability of their variable importance rankings. Ranking stability is critical in scientific work, which requires interpretable models, and a model whose ranking shifts from run to run is harder to interpret than one whose ranking holds steady. To the best of my knowledge, variable importance ranking stability has been studied in bioinformatics [1] and remote sensing [2].

In the example random forest model above, one predictor was much more predictive than the others, and the variable importance ranking was relatively stable. With different data, however, a few predictors might trade places as the most important variable, because each tree is grown on a random bootstrap sample of the training data and each split considers only a random subset of the predictors. This may happen within a single forest as it grows, or between different fully grown forests.
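One way to probe the between-forest case is to fit several fully grown forests that differ only in their random seed and compare their rankings. The sketch below does this on `mtcars`; the ten seeds, `ntree = 300`, and ranking by `%IncMSE` are arbitrary assumptions, and the result depends on the data and seeds, so no particular ranking is implied.

```r
library(randomForest)

seeds <- 1:10

# Rank predictors by permutation importance (%IncMSE) in each of several
# fully grown forests that differ only in their random seed.
ranks <- sapply(seeds, function(s) {
  set.seed(s)
  rf <- randomForest(mpg ~ ., data = mtcars, ntree = 300, importance = TRUE)
  imp <- importance(rf, type = 1)[, 1]   # named vector of %IncMSE
  rank(-imp)                             # 1 = most important
})
colnames(ranks) <- paste0("seed_", seeds)

# most important predictor in each forest, and how often each one wins
top_var <- rownames(ranks)[apply(ranks, 2, which.min)]
print(table(top_var))

# best and worst rank each predictor reached across the forests
rank_range <- t(apply(ranks, 1, range))
colnames(rank_range) <- c("best", "worst")
print(rank_range)
```

A predictor whose best and worst ranks are far apart is one whose place in the importance plot should be read with caution.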

Below is a minimal example in R that reproduces an animation like the one above by fitting 500 single-tree forests and averaging their variable importance cumulatively. In a next iteration, it would be edifying to add the MSE/ntree plot (out-of-bag MSE versus number of trees). I expect that as the number of trees increases and the out-of-bag error stabilizes, the variable importance ranking stabilizes too.

```r
# minimal example
library(randomForest)
library(ggplot2)
library(dplyr)
library(purrr)
library(colormap)
library(tree)
library(plotrix)
library(cowplot)
library(gridGraphics)
library(magick)
library(forcats)

# source the reprtree package for plotting individual trees generated by RF:
# clone the repo at https://github.com/richpauloo/reprtree
# and change the file path below to the directory containing the R files
invisible(
  lapply(
    list.files("/Users/richpauloo/GitHub/reprtree/R", full.names = TRUE),
    source
  )
)

# example data == mtcars
df <- mtcars

# number of trees to grow in the random forest
nn <- 500

# fit a random forest with a single tree (one bootstrap sample, random
# subsets of predictors at each split) and permutation importance
run_rf <- function(rand_seed) {
  set.seed(rand_seed)
  one_tr <- randomForest(mpg ~ .,
                         data       = df,
                         importance = TRUE,
                         ntree      = 1)
  return(one_tr)
}

# list to store the output of each single-tree model
l <- lapply(1:nn, run_rf)

# number of predictors in the RF model
npred <- length(names(l[[1]]$forest$xlevels))

# extract the importance of each single-tree model into one long data frame
impdf <- map(l, importance) %>%
  map(as.data.frame) %>%
  map(~ { .$var <- rownames(.); rownames(.) <- NULL; return(.) }) %>%
  bind_rows() %>%
  mutate(tree_num = rep(1:nn, each = npred)) # add tree number

# summarised variable importance over all nn trees
tot_mse <- group_by(impdf, var) %>%
  summarise(`%IncMSE` = mean(`%IncMSE`)) %>%
  arrange(-`%IncMSE`)

# variables ranked by overall importance
rv <- tot_mse$var
impdf$var <- factor(impdf$var, levels = rv)

# vector of trees to plot
# here I plot every 10 trees for speed, but this can be changed
plt_vec <- c(1, seq(10, nn, 10))

# cumulative mean importance for every frame
cum_imp <- map(plt_vec, ~ filter(impdf, tree_num <= .x) %>%
                 group_by(var) %>%
                 summarise(mse = mean(`%IncMSE`)))

# shared axis limits: early cumulative means can exceed the final average
# (or be negative), so use the range over all frames to avoid clipped bars
ylims <- range(0, unlist(map(cum_imp, "mse")))

# initialize lists for: varimp plots, tree plots, plot titles, combined plots
pl <- tl <- pt <- bp <- vector("list", length = length(plt_vec))

for (i in seq_along(plt_vec)) {
  # cumulative variable importance with each tree's addition
  pl[[i]] <- ggplot(cum_imp[[i]], aes(forcats::fct_rev(var), mse, fill = var)) +
    geom_col() +
    coord_flip(ylim = ylims) +
    scale_fill_viridis_d() +
    labs(x = "Variable", y = "Importance (% Inc MSE)", fill = "Variable",
         title = paste0("Tree ", plt_vec[i])) +
    theme_minimal() +
    theme(legend.position = "bottom",
          plot.title = element_text(size = 25))

  # make tree plots (plot.getTree() comes from the reprtree files sourced above)
  plot.getTree(l[[plt_vec[i]]], k = 1, npred = npred, rv = rv)
  tl[[i]] <- recordPlot()

  # make stand-alone plot titles (optional: not used in bp below,
  # because pl[[i]] already carries the "Tree N" title)
  pt[[i]] <- ggdraw() +
    draw_label(
      paste0("Tree ", plt_vec[i]),
      fontface = "bold",
      x = 0,
      hjust = 0
    ) +
    theme(
      # add margin on the left of the drawing canvas,
      # so title is aligned with left edge of first plot
      plot.margin = margin(0, 0, 0, 7)
    )

  # combine the importance plot and the tree plot side by side
  bp[[i]] <- plot_grid(pl[[i]], tl[[i]])
}

# use magick to turn plots into a GIF.
# WARNING: in my experience magick doesn't handle hundreds of plots well,
# and it may be better to print them into a single PDF, then render the
# GIF elsewhere. Also beware of the temporary files that magick creates.
# This animation may also be too large to fit in your viewer, so
# be sure to expand it!
# img <- image_graph(1000, 600, res = 96)
# for (i in seq_along(plt_vec)) { print(bp[[i]]) }
# dev.off()
# animation <- image_animate(img, fps = 2)
# print(animation)
#
# # save to working directory
# image_write(animation, "anim.gif")

# alternatively, print every frame to a single PDF and make the GIF
# elsewhere, like https://ezgif.com/maker
pdf("all.pdf", width = 12, height = 7)
invisible(lapply(bp, print))
dev.off()
```
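For the MSE/ntree plot mentioned above as a next iteration, a minimal sketch might look like the following. It relies on the `mse` component that randomForest returns for regression forests, which holds the out-of-bag MSE after each number of trees (`plot(rf)` draws the same curve in base graphics). Fitting a separate 500-tree forest with `set.seed(1)` is an assumption for illustration: this forest is not the collection of single-tree models used for the animation, so lining the curve up with the animation frames would take more work.

```r
library(randomForest)
library(ggplot2)

set.seed(1)
rf <- randomForest(mpg ~ ., data = mtcars, ntree = 500)

# rf$mse[k]: out-of-bag mean squared error of the forest's first k trees
oob <- data.frame(ntree = seq_along(rf$mse), mse = rf$mse)

p <- ggplot(oob, aes(ntree, mse)) +
  geom_line() +
  labs(x = "Number of trees", y = "OOB MSE") +
  theme_minimal()
print(p)
```

If the expectation above holds, the frames in which the importance ranking stops reshuffling should roughly coincide with the region where this curve flattens out.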

  1. Calle, M. Luz, and Víctor Urrea. “Letter to the editor: stability of random forest importance measures.” Briefings in Bioinformatics 12.1 (2010): 86-89.

  2. Behnamian, Amir, et al. “A systematic approach for variable selection with random forests: achieving stable variable importance values.” IEEE Geoscience and Remote Sensing Letters 14.11 (2017): 1988-1992.
