benjilu / forestError
A Unified Framework for Random Forest Prediction Error Estimation
Home Page: https://www.jmlr.org/papers/v22/18-558.html
I am trying to run this with a fairly large dataset, but my computer (64GB of RAM) always runs out of memory. Is it possible to reduce the memory consumption of quantForestError somehow?
This is one of the most exciting RF papers I've read in a while. I've homebrewed a lot of conditional bias estimators in the past for industry applications, but nothing with a real theoretical foundation, and in early trials this is radically improving on my prior work. Thank you!
A practical issue I'm running into: this method is quite memory-intensive (by necessity) when run for large `X.test` samples and/or when `forest` has many unique terminal nodes (e.g. from a high number of `X.train` samples or many trees or both). Calls to `quantForestError` that are too large result in `Error: vector memory exhausted (limit reached?)`.
I'm concluding from your paper and some experiments with the software that `quantForestError` results are deterministic once the input `forest` is trained, so it is safe to split large dataframes of `X.test` into pieces to prevent overloading memory. It also seems to run in linear time in the number of samples in `X.test`. This should allow for convenient parallel execution of smaller calls by batch on `X.test`. However, `ranger::predict` by default uses all available cores (and I believe `randomForestSRC` does the same), which leads to problems in parallel execution of `quantForestError`.
A simple solution would be to enable a user input for `nCores` to tune this setting manually, defaulting to the underlying model class's default behavior. I wanted to raise the issue more generally, though, because the `vector memory exhausted` errors (or the candidate solution to parcel out the `X.test` workload) aren't necessarily obvious and might merit some documentation. I'd be happy to work on a vignette or something if useful.
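In the meantime, batched calls can at least be parallelized at the process level. Here is a minimal sketch, assuming a per-batch function that would wrap a `quantForestError` call (shown with a toy function, since the proposed `nCores` control doesn't exist yet and ranger's internal threading is a separate knob):

```r
library(parallel)

# Sketch: fan per-batch work out to worker processes. In real use, `fn`
# would wrap a quantForestError() call for one X.test batch; mc.cores
# caps only our worker count, not ranger's internal prediction threads.
run_batched <- function(batches, fn, workers = 2) {
  parallel::mclapply(batches, fn, mc.cores = workers)
}

# Toy demonstration of the batching shape:
res <- run_batched(as.list(1:3), function(x) x * 2, workers = 1)
```

Note that `mclapply` is serial-only on Windows (`mc.cores` must be 1 there); `parLapply` with an explicit cluster would be the portable alternative.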
The actual limits will depend on hardware (I'm running this on an iMac with 64GB RAM) but below is a code example. On my machine it doesn't fail until something between 1k and 50k test rows; I didn't get an exact breakdown point. It should be possible to find a heuristic for when X.test needs to be broken up given knowledge of available RAM, but that's hard in the wild.
library(tidyverse)
library(ranger)
library(forestError)
library(tictoc)
train_rows <- 100000
test_rows <- 50000
set.seed(1234)
fake_train <- dplyr::tibble(
  a = rpois(train_rows, 500),
  b = a + rnorm(train_rows, 50, 50),
  c = b + rnorm(train_rows, 50, 50),
  d = c + rnorm(train_rows, 50, 50),
  e = d + rnorm(train_rows, 50, 50),
  f = e + rnorm(train_rows, 50, 50),
  g = f + rnorm(train_rows, 50, 50),
  h = g + rnorm(train_rows, 50, 50),
  i = h + rnorm(train_rows, 50, 50)
)
fake_test <- dplyr::tibble(
  a = rpois(test_rows, 500),
  b = a + rnorm(test_rows, 50, 50),
  c = b + rnorm(test_rows, 50, 50),
  d = c + rnorm(test_rows, 50, 50),
  e = d + rnorm(test_rows, 50, 50),
  f = e + rnorm(test_rows, 50, 50),
  g = f + rnorm(test_rows, 50, 50),
  h = g + rnorm(test_rows, 50, 50),
  i = h + rnorm(test_rows, 50, 50)
)
cor(fake_train)
cor(fake_test)
rf1 <- ranger::ranger(a ~ .,
  data = fake_train,
  num.trees = 100,
  seed = 1234,
  keep.inbag = TRUE)
tic()
condbias_100rows <- forestError::quantForestError(
  forest = rf1,
  X.train = dplyr::select(fake_train, -a),
  Y.train = fake_train$a,
  X.test = dplyr::select(fake_test, -a)[1:100, ]
)
toc()
tic()
condbias_1000rows <- forestError::quantForestError(
  forest = rf1,
  X.train = dplyr::select(fake_train, -a),
  Y.train = fake_train$a,
  X.test = dplyr::select(fake_test, -a)[1:1000, ]
)
toc()
tic()
condbias_50000rows <- forestError::quantForestError(
  forest = rf1,
  X.train = dplyr::select(fake_train, -a),
  Y.train = fake_train$a,
  X.test = dplyr::select(fake_test, -a)
)
toc()
# Confirming that running chunks of X.test together or in pieces
# leads to the same result
all.equal(condbias_100rows$estimates,
          condbias_1000rows$estimates[1:100, ])
all.equal(condbias_100rows$qerror,
          condbias_1000rows$qerror)
all.equal(condbias_100rows$perror,
          condbias_1000rows$perror)
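Since the piecewise results match, a small chunking helper makes the workaround concrete. A sketch (the 5000-row batch size is arbitrary, and the commented-out call just mirrors the ones above):

```r
# Sketch: bound quantForestError's memory use by scoring X.test in row
# batches and stacking the per-batch estimates. This is safe because the
# estimates are deterministic once the forest is fixed, as checked above.
batch_indices <- function(n_rows, batch_size) {
  split(seq_len(n_rows), ceiling(seq_len(n_rows) / batch_size))
}

# est_by_batch <- lapply(
#   batch_indices(nrow(fake_test), 5000),
#   function(idx) forestError::quantForestError(
#     forest = rf1,
#     X.train = dplyr::select(fake_train, -a),
#     Y.train = fake_train$a,
#     X.test = dplyr::select(fake_test, -a)[idx, ]
#   )$estimates
# )
# estimates <- do.call(rbind, est_by_batch)
```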
I ran across this very nice paper recently: https://besjournals.onlinelibrary.wiley.com/doi/full/10.1111/2041-210X.13650
It proposes a very intuitive metric for measuring distance from a test observation's X values to the predictor space of the training set and calls it a Dissimilarity Index.
This specific application is for spatial models with a corresponding block-bootstrap design, but the same calculation could be done with in-bag/out-of-bag splits using the data ingested by forestError.
Thought you might find this interesting -- I would consider it very germane to the same tasks where I'm using forestError and might be a worthwhile addition to the exported data.
@benjilu hope all's well since we last caught up on the package -- still using it all the time.
Under the current design, forestError calls scale reasonably well out to 50k or 100k observations in the test set for forests with moderate (100-200) tree counts. Beyond that point, either in test set observations or tree count, the multiplication in row counts that occurs in the edgelist join of train_node to test_node starts to break memory limits.
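To make that blow-up concrete, here's a toy version of such a join (column names invented for illustration, not the package's internals): every node shared between the two tables contributes (train rows in node) x (test rows in node) rows to the result.

```r
# Toy illustration of the row multiplication in a node-wise join:
# 2 nodes, each with 3 training rows and 4 test rows.
train_node <- data.frame(node = rep(1:2, each = 3), oob_err = rnorm(6))
test_node  <- data.frame(node = rep(1:2, each = 4), test_row = 1:8)
joined <- merge(test_node, train_node, by = "node")
nrow(joined)  # 2 nodes x (4 test x 3 train) = 24 rows
```

Scale the node counts up to a large forest and a 100k-row test set, and the joined table is what exhausts memory.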
It's still possible to take a large test set and iterate, doing e.g.
huge_test_df %>%
  split(.$batch_id) %>%
  purrr::map(
    .f = ~quantForestError(
      forest = forest,
      X.train = trainset[c("x1", "x2", "x3")],
      X.test = .x[c("x1", "x2", "x3")],
      Y.train = trainset$y
    )
  )
... but that recomputes everything to do with the training set in every iteration over batch ID.
So, here's a proposal for a moderately large refactoring of the quantForestError function into independently reusable components. The main objective would be to separate the two costly parts of the computation: (1) turn a forest into the trainset tree/node OOB error data structure we use internally, and return it; (2) take in the training error data and return the computed test set statistics.
This can easily be wrapped in a single function (like now) so existing code doesn't break.
`quantForestError` would gain two optional parameters (e.g. `use_train_nodes = NULL` and `return_train_nodes = FALSE`):

- If `return_train_nodes` is TRUE, return the final computed form of `long_train_nodes`.
- If `use_train_nodes` is not null, it must be a `long_train_nodes` data object returned from a prior computation, so we can skip the steps required to create it on a second pass.

Internally, we'd want to change a few things:

- Construction of `long_train_nodes` should happen in a (probably also exported) function called within `quantForestError`.
- Computation of the `estimates` dataframe for a specific test set should get broken out.

A few other benefits of doing this, besides scalability, would be:

- Computing `long_train_nodes` separately allows for experimentation on other ways of summarizing the forest errors, beyond just bias and quantile statistics; here I'm thinking of e.g. clustering and second-stage bias correction models beyond a node-wise mean. Those wouldn't fit naturally inside a single function but could be other bolt-ons later.

I'm pretty happy to do the work and send a PR over the next couple of weeks, but this is a very large set of changes and wanted to check first. Happy to fork/modify on my own outside of the main/CRAN version if you prefer that too.
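To sketch the shape of the split with a stand-in mock (all function names and columns here are invented for illustration, not the package's internals):

```r
# Mock of the proposed two-pass pattern: build the per-node
# training-error summary once, then reuse it across test batches.
# Stand-ins only -- the real long_train_nodes construction and
# estimate computation live inside quantForestError today.
build_train_nodes <- function(train_errs) {
  # stand-in for constructing long_train_nodes (once, expensively)
  aggregate(err ~ node, data = train_errs, FUN = mean)
}
score_batch <- function(train_nodes, test_nodes) {
  # stand-in for the per-batch join + estimates step (cheap, repeated)
  merge(data.frame(node = test_nodes), train_nodes, by = "node")
}

train_errs <- data.frame(node = c(1, 1, 2, 2), err = c(0.5, 1.5, 2.0, 4.0))
nodes_tbl <- build_train_nodes(train_errs)  # computed once
batch1 <- score_batch(nodes_tbl, c(1, 2))   # reused per batch
batch2 <- score_batch(nodes_tbl, c(2, 2))
```

The point is only the data flow: the expensive first stage is computed once and passed into every second-stage call, which is what `use_train_nodes`/`return_train_nodes` would enable.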
Thanks for the package! Do you think it could be extended to multiclass probability predictions?
We are predicting probabilities for each class and want to have prediction intervals for these probabilities. Example:
library(ranger)
train_idx <- sample(nrow(iris), 2/3 * nrow(iris))
dat_train <- iris[train_idx, ]
dat_test <- iris[-train_idx, ]
rf <- ranger(Species ~ ., dat_train, probability = TRUE)
predict(rf, dat_test)$predictions