Soft Bayesian Sum of Trees Models

The SoftBart package

This package implements the methodology described in the paper

  • Linero, A.R. and Yang, Y. (2017). Bayesian tree ensembles that adapt to smoothness and sparsity. (In preparation)

Installation

The package can be installed with the devtools package:

library(devtools)
install_github("theodds/SoftBART")

Note that if you are on OSX, you may need to first run the following commands from the terminal:

curl -O http://r.research.att.com/libs/gfortran-4.8.2-darwin13.tar.bz2
sudo tar fvxz gfortran-4.8.2-darwin13.tar.bz2 -C /

Usage

The package is designed to mirror the functionality of the BayesTree package. The primary function is softbart, which is used in essentially the same manner as the bart function in BayesTree.

The following minimal example fits the model to data simulated from "Friedman's example":

## Load libraries (scales provides muted(), used in the plot below)
library(SoftBart)
library(scales)

## Functions used to generate fake data
set.seed(1234)
f_fried <- function(x) 10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 + 
                      10 * x[,4] + 5 * x[,5]
    
gen_data <- function(n_train, n_test, P, sigma) {
    X <- matrix(runif(n_train * P), nrow = n_train)
    mu <- f_fried(X)
    X_test <- matrix(runif(n_test * P), nrow = n_test)
    mu_test <- f_fried(X_test)
    Y <- mu + sigma * rnorm(n_train)
    Y_test <- mu_test + sigma * rnorm(n_test)
        
    return(list(X = X, Y = Y, mu = mu, X_test = X_test, Y_test = Y_test, mu_test = mu_test))
}

## Simulate a dataset
sim_data <- gen_data(250, 100, 50, 1)
    
## Fit the model
fit <- softbart(X = sim_data$X, Y = sim_data$Y, X_test = sim_data$X_test, 
                hypers = Hypers(sim_data$X, sim_data$Y, num_tree = 50, temperature = 1),
                opts = Opts(num_burn = 5000, num_save = 5000, update_tau = TRUE))
    
## Plot the fit (note: interval estimates are not prediction intervals, 
## so they do not cover the predictions at the nominal rate)
plot(fit)

## Look at posterior model inclusion probabilities for each predictor. 

posterior_probs <- function(fit) colMeans(fit$var_counts > 0)
plot(posterior_probs(fit), 
     col = ifelse(posterior_probs(fit) > 0.5, muted("blue"), muted("green")), 
     pch = 20)

## Root mean squared error of the posterior mean against the true regression function
rmse <- function(x,y) sqrt(mean((x-y)^2))

rmse(fit$y_hat_test_mean, sim_data$mu_test)
rmse(fit$y_hat_train_mean, sim_data$mu)
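
To make the posterior_probs summary above concrete, here is a small base-R sketch on a made-up matrix. It assumes that var_counts has one row per saved iteration and one column per predictor, with each entry counting how often that predictor is used in a splitting rule; the numbers are hypothetical and are not output from the package.

```r
## Toy stand-in for fit$var_counts (hypothetical values):
## 4 saved iterations (rows) by 3 predictors (columns).
var_counts <- rbind(c(3, 0, 1),
                    c(2, 0, 0),
                    c(4, 1, 0),
                    c(1, 0, 0))

## Fraction of iterations in which each predictor is used at least once
incl_probs <- colMeans(var_counts > 0)
print(incl_probs)
## [1] 1.00 0.25 0.25
```

Under this reading, a predictor with a value near 1 appears in the ensemble in nearly every saved iteration, while a value near 0 means the sampler rarely splits on it.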

Accessing the model from R

In more complex settings, one may wish to incorporate the SoftBART model as a component within a larger model. In this case, it is possible to construct a SoftBART object within R and do a single Gibbs sampling update.

## Default hyperparameters and sampler options
hypers <- Hypers(sim_data$X, sim_data$Y)
opts <- Opts()

## Construct the forest, burn in, then sample from the posterior
forest <- MakeForest(hypers, opts)
mu_hat <- forest$do_gibbs(sim_data$X, sim_data$Y, sim_data$X_test, opts$num_burn)
mu_hat <- forest$do_gibbs(sim_data$X, sim_data$Y, sim_data$X_test, opts$num_save)

The do_gibbs function takes as input the data used to do the update, an additional set of points at which to predict, and the number of iterations to run the sampler. The code above first burns in and then samples from the posterior, and is essentially equivalent to using softbart.

By default, the probability vector s will not be updated until at least num_burn / 2 iterations have been run. The number of iterations run so far can be checked by calling forest$num_gibbs. The s vector itself can be obtained by calling forest$get_s(), and in the future we will add features to deal with other components as well.

WARNING: if you are going to do this, you need to preprocess X and X_test by hand, so that all values lie in [0,1]. The softbart function, in addition to doing the sampling, also preprocesses X and X_test by applying a quantile transformation.
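
As a rough illustration of that preprocessing step, the following base-R sketch maps each column to [0,1] using the empirical CDF of the training column. The helper quantile_normalize is hypothetical and is not part of the package; the transformation softbart applies internally may differ in its details, so treat this as one possible way to satisfy the [0,1] requirement rather than a reproduction of it.

```r
## Hypothetical helper (not part of SoftBart): map each column of X and
## X_test to [0, 1] using the empirical CDF of the training column.
quantile_normalize <- function(X, X_test) {
    X_out <- X
    X_test_out <- X_test
    for (j in seq_len(ncol(X))) {
        F_j <- ecdf(X[, j])                  # empirical CDF of training column j
        X_out[, j] <- F_j(X[, j])            # training values land in (0, 1]
        X_test_out[, j] <- F_j(X_test[, j])  # test values land in [0, 1]
    }
    list(X = X_out, X_test = X_test_out)
}

## Example on predictors that are far outside [0, 1]
set.seed(1)
X_raw <- matrix(rnorm(200, mean = 5, sd = 3), nrow = 50)
X_test_raw <- matrix(rnorm(80, mean = 5, sd = 3), nrow = 20)

scaled <- quantile_normalize(X_raw, X_test_raw)
range(scaled$X)
range(scaled$X_test)
```

The transformed matrices scaled$X and scaled$X_test can then be passed to Hypers and do_gibbs in place of the raw predictors.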
