interpretable random forests
October 30, 2019
The “variable importance” plot that shows the mean decrease in accuracy or node impurity per predictor is a classic metric used to interpret random forest models.
This metric aggregates mean decrease in accuracy across all trees in the forest. I wanted to watch a forest “grow” tree by tree, alongside the cumulative variable importance:
What does an interpretable random forest (RF) 🌲🌳 #datavis look like? Out-of-the-box 📦RF in #rstats and #python3 computes variable importance over *all* trees, but how do we get there? Here's a RF of 300 trees, tree-by-tree, showing cumulative variable importance. #DataScience pic.twitter.com/ODyrYRfUya
— Rich Pauloo (@RichPauloo) May 4, 2019
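For reference, here is a minimal sketch of the out-of-the-box importance the tweet refers to, assuming the `randomForest` package and the built-in `mtcars` data (the seed and the choice of 300 trees are illustrative, not taken from the animation's code). It fits one forest and pulls out both classic measures, aggregated over all trees.

```r
library(randomForest)

set.seed(1)  # arbitrary seed, for reproducibility only

# fit a regression forest; importance = TRUE requests permutation importance
rf <- randomForest(mpg ~ ., data = mtcars, importance = TRUE, ntree = 300)

# type = 1: mean decrease in accuracy (%IncMSE for regression)
# type = 2: mean decrease in node impurity (IncNodePurity for regression)
imp_acc <- importance(rf, type = 1)
imp_pur <- importance(rf, type = 2)

# predictors ordered from most to least important by %IncMSE
ranked <- sort(imp_acc[, 1], decreasing = TRUE)
print(ranked)

# the classic dot chart showing both measures
varImpPlot(rf)
```

Both measures are summaries over the whole forest, which is why they say nothing about how the ranking evolved as trees were added.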
Animations like the one above may help visualize an important behavior of machine learning models: the stability of random forest variable importance rankings. Ranking stability is critical to scientific pursuits, which require interpretable models. If the ranking changes from one fit to the next, it is hard to say which predictors matter, so unstable models are less interpretable than stable ones. To the best of my knowledge, variable importance ranking stability has been studied in the context of bioinformatics [1] and remote sensing [2].
In the example random forest model above, one predictor was much more predictive than the others, and the variable importance ranking was relatively stable. However, with different data, it is conceivable that a few predictors might trade places as the most important variable, because the training observations are bootstrapped for each tree and the candidate predictors are randomly sampled at each split. This may happen within a single forest as it grows, or between different fully-grown forests.
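One rough way to check the between-forest case is to refit the same forest under several seeds and compare the resulting rankings. The sketch below assumes the `randomForest` package and `mtcars`; `rank_vars()` is a hypothetical helper written for this illustration, and the 20 seeds and 200 trees are arbitrary choices.

```r
library(randomForest)

# hypothetical helper: fit one forest and return predictors ranked by %IncMSE
rank_vars <- function(seed, ntree = 200) {
  set.seed(seed)
  rf <- randomForest(mpg ~ ., data = mtcars, importance = TRUE, ntree = ntree)
  imp <- importance(rf, type = 1)[, 1]
  names(sort(imp, decreasing = TRUE))
}

seeds <- 1:20
# character matrix: rows are rank positions, columns are forests
rankings <- sapply(seeds, rank_vars)

# how often each predictor takes the top rank across forests
top_counts <- sort(table(rankings[1, ]), decreasing = TRUE)
print(top_counts)

# rank position of every predictor in every forest
vars <- sort(rankings[, 1])
pos <- apply(rankings, 2, function(r) match(vars, r))
rownames(pos) <- vars

# best and worst rank each predictor reached; a wide range means an unstable rank
rank_range <- t(apply(pos, 1, range))
colnames(rank_range) <- c("best", "worst")
print(rank_range)
```

If a single predictor holds the top rank in every forest, `top_counts` has one entry. Several entries mean the top rank switched between forests.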
Below is a minimal example in R to reproduce an animation like the one above. In a future iteration, it would be edifying to add the MSE/ntree plot. I expect that as the number of trees increases and the out-of-bag error stabilizes, the variable importance ranking also stabilizes.
# minimal example
library(randomForest)
library(ggplot2)
library(dplyr)
library(purrr)
library(colormap)
library(tree)
library(plotrix)
library(cowplot)
library(gridGraphics)
library(magick)
library(forcats)

# source reprtree package for plotting individual trees generated by RF
# clone repo at https://github.com/richpauloo/reprtree
# and change the file path below to the directory containing the R files
invisible(
  lapply(
    list.files("/Users/richpauloo/GitHub/reprtree/R", full.names = TRUE),
    source
  )
)

# example data == mtcars
df <- mtcars

# number of trees to grow in random forest
nn <- 500

# function to fit a single-tree forest (one CART model) for a given seed
run_rf <- function(rand_seed){
  set.seed(rand_seed)
  one_tr <- randomForest(mpg ~ .,
                         data = df,
                         importance = TRUE,
                         ntree = 1)
  return(one_tr)
}

# list to store output of each model
l <- lapply(1:nn, run_rf)

# number of predictors in RF mod
npred <- length(names(l[[1]]$forest$xlevels))

# extract importance of each CART model
impdf <- map(l, importance) %>%
  map(as.data.frame) %>%
  map( ~ { .$var = rownames(.); rownames(.) <- NULL; return(.) } ) %>%
  bind_rows() %>%
  mutate(tree_num = rep(1:nn, each = npred)) # add tree number

# variable importance averaged over all trees
tot_mse <- group_by(impdf, var) %>%
  summarise(`%IncMSE` = mean(`%IncMSE`)) %>%
  arrange(-`%IncMSE`)

# ranked variables
rv <- tot_mse$var
impdf$var <- factor(impdf$var, levels = rv)

# vector of trees to plot
# here I plot every 10 trees for speed, but this can be changed
plt_vec <- c(1, seq(10, nn, 10))

# initialize lists for: varimp, trees, plot titles, and combined plots
pl <- tl <- pt <- bp <- vector("list", length = length(plt_vec))

for(i in seq_along(plt_vec)){
  # cumulative variable importance with each tree's addition
  pl[[i]] <- filter(impdf, tree_num %in% 1:plt_vec[i]) %>%
    group_by(var) %>%
    summarise(mse = mean(`%IncMSE`)) %>%
    ggplot(aes(forcats::fct_rev(var), mse, fill = var)) +
    geom_col() +
    coord_flip(ylim = c(0, max(tot_mse$`%IncMSE`))) +
    scale_fill_viridis_d() +
    labs(x = "Variable", y = "Importance (% Inc MSE)", fill = "Variable",
         title = paste0("Tree ", plt_vec[i])) +
    theme_minimal() +
    theme(legend.position = "bottom",
          plot.title = element_text(size = 25))

  # make tree plots
  plot.getTree(l[[plt_vec[i]]], k = 1, npred = npred, rv = rv)
  tl[[i]] <- recordPlot()

  # make plot titles
  pt[[i]] <- ggdraw() +
    draw_label(
      paste0("Tree ", plt_vec[i]),
      fontface = 'bold',
      x = 0,
      hjust = 0
    ) +
    theme(
      # add margin on the left of the drawing canvas,
      # so title is aligned with left edge of first plot
      plot.margin = margin(0, 0, 0, 7)
    )

  # combine the importance plot and the tree plot side by side
  bp[[i]] <- plot_grid(pl[[i]], tl[[i]])
}

# use magick to turn plots into a GIF (uncomment to run).
# WARNING: magick doesn't handle hundreds of plots well in my experience
# and it may be better to print them into a single PDF, then render the
# GIF elsewhere. Also beware of temporary files that magick creates...
# Also, this animation may be too large to fit in your viewer, so
# be sure to expand it!
# img <- image_graph(1000, 600, res = 96)
# for(i in seq_along(plt_vec)){ print( bp[[i]] ) }
# dev.off()
# animation <- image_animate(img, fps = 2)
# print(animation)
#
# # save to working directory
# image_write(animation, "anim.gif")

# print all frames to a PDF and make the GIF elsewhere,
# like https://ezgif.com/maker
pdf("all.pdf", width = 12, height = 7)
invisible(lapply(bp, print))
dev.off()
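The MSE/ntree plot mentioned before the example could be sketched as follows. This is an assumption-laden illustration, not part of the original example: it fits one ordinary 500-tree forest (rather than 500 single-tree forests with different seeds) and reads the `mse` element that `randomForest` stores for regression forests, which holds the out-of-bag MSE after each tree is added.

```r
library(randomForest)
library(ggplot2)

set.seed(1)  # arbitrary seed, for reproducibility only

# one ordinary forest of 500 trees on the same example data
rf <- randomForest(mpg ~ ., data = mtcars, ntree = 500)

# rf$mse has one entry per tree: the out-of-bag MSE of the forest so far
oob <- data.frame(ntree = seq_along(rf$mse), mse = rf$mse)

# out-of-bag error as a function of forest size
p <- ggplot(oob, aes(ntree, mse)) +
  geom_line() +
  labs(x = "Number of trees", y = "Out-of-bag MSE") +
  theme_minimal()
print(p)
```

Placing this curve next to the cumulative importance bars would show whether the ranking settles at roughly the same forest size at which the error curve flattens.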
[1] Calle, M. Luz, and Víctor Urrea. “Letter to the editor: stability of random forest importance measures.” Briefings in Bioinformatics 12.1 (2010): 86-89.
[2] Behnamian, Amir, et al. “A systematic approach for variable selection with random forests: achieving stable variable importance values.” IEEE Geoscience and Remote Sensing Letters 14.11 (2017): 1988-1992.