
Surrogate Assisted Feature Extraction in R


Overview

The rSAFE package is a model-agnostic tool for making an interpretable white-box model more accurate with the help of an alternative black-box model, called a surrogate model. Based on a complicated model, such as a neural network or a random forest, new features are extracted and then used in fitting a simpler interpretable model, improving its overall performance.
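
The whole workflow can be condensed into a few calls. A minimal sketch (the names `black_box_model`, `X` and `y` are placeholders, not package objects; the demo below walks through each step with real data):

```r
library(rSAFE)
library(DALEX)

# 1. Wrap the fitted black-box (surrogate) model in a DALEX explainer
explainer <- explain(black_box_model, data = X, y = y)

# 2. Extract transformation propositions from the surrogate's responses
extractor <- safe_extraction(explainer, penalty = 25, verbose = FALSE)

# 3. Add the proposed categorical features to the data
new_data <- safely_transform_data(extractor, X, verbose = FALSE)

# 4. Keep the better form (original or transformed) of each feature
vars <- safely_select_variables(extractor, new_data, which_y = "y", verbose = FALSE)

# 5. Fit the interpretable white-box model on the selected features
white_box <- lm(y ~ ., data = new_data[, c("y", vars)])
```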

Installation

The package can be installed from GitHub using the code below:

install.packages("devtools")
devtools::install_github("ModelOriented/rSAFE")

Demo

In this vignette we present an example application of the rSAFE package to a regression problem. It is based on the apartments and apartmentsTest datasets, which come from the DALEX package but are also available in the rSAFE package. We will use these artificial datasets to predict the price per square meter of an apartment based on features such as construction year, surface, floor, number of rooms and district. Note that four of these variables are continuous, while the fifth is categorical.

library(rSAFE)
head(apartments)
#>   m2.price construction.year surface floor no.rooms    district
#> 1     5897              1953      25     3        1 Srodmiescie
#> 2     1818              1992     143     9        5     Bielany
#> 3     3643              1937      56     1        2       Praga
#> 4     3517              1995      93     7        3      Ochota
#> 5     3013              1992     144     6        5     Mokotow
#> 6     5795              1926      61     6        2 Srodmiescie

Building a black-box model

First we fit a random forest model to the original apartments dataset; this complex model will serve as the surrogate.

library(randomForest)
set.seed(111)
model_rf1 <- randomForest(m2.price ~ construction.year + surface + floor + no.rooms + district, data = apartments)

Creating an explainer

We also create an explainer object that will be used later to create new variables and, at the end, to compare the models' performance.

library(DALEX)
explainer_rf1 <- explain(model_rf1, data = apartmentsTest[1:3000,2:6], y = apartmentsTest[1:3000,1], label = "rf1", verbose = FALSE)
explainer_rf1
#> Model label:  rf1 
#> Model class:  randomForest.formula,randomForest 
#> Data head  :
#>      construction.year surface floor no.rooms    district
#> 1001              1976     131     3        5 Srodmiescie
#> 1002              1978     112     9        4     Mokotow

Creating a safe_extractor

Now we create a safe_extractor object using the rSAFE package and our surrogate model. Setting the argument verbose = FALSE suppresses the progress bar.

safe_extractor <- safe_extraction(explainer_rf1, penalty = 25, verbose = FALSE)

Now, let’s print a summary of the new object we have just created.

print(safe_extractor)
#> Variable 'construction.year' - selected intervals:
#>  (-Inf, 1937]
#>      (1937, 1992]
#>      (1992, Inf)
#> Variable 'surface' - selected intervals:
#>  (-Inf, 47]
#>      (47, 101]
#>      (101, Inf)
#> Variable 'floor' - selected intervals:
#>  (-Inf, 5]
#>      (5, Inf)
#> Variable 'no.rooms' - selected intervals:
#>  (-Inf, 3]
#>      (3, Inf)
#> Variable 'district' - created levels:
#>  Bemowo, Bielany, Ursus, Ursynow, Praga, Wola ->  Bemowo_Bielany_Praga_Ursus_Ursynow_Wola 
#>  Zoliborz, Mokotow, Ochota ->  Mokotow_Ochota_Zoliborz 
#>  Srodmiescie ->  Srodmiescie

We can see the proposed transformations for all the variables in our dataset.

In the plot below we can see which points have been chosen to be the breakpoints for a particular variable:

plot(safe_extractor, variable = "construction.year")

For factor variables we can observe in which order the levels have been merged and what the optimal clustering is:

plot(safe_extractor, variable = "district")

Transforming data

Now we can use our safe_extractor object to create new categorical features in the given dataset.

data1 <- safely_transform_data(safe_extractor, apartmentsTest[3001:6000,], verbose = FALSE)
| district | m2.price | construction.year | surface | floor | no.rooms | construction.year_new | surface_new | floor_new | no.rooms_new | district_new |
|---|---|---|---|---|---|---|---|---|---|---|
| Bielany | 3542 | 1979 | 21 | 6 | 1 | (1937, 1992] | (-Inf, 47] | (5, Inf) | (-Inf, 3] | Bemowo_Bielany_Praga_Ursus_Ursynow_Wola |
| Srodmiescie | 5631 | 1997 | 107 | 2 | 4 | (1992, Inf) | (101, Inf) | (-Inf, 5] | (3, Inf) | Srodmiescie |
| Bielany | 2989 | 1994 | 41 | 9 | 2 | (1992, Inf) | (-Inf, 47] | (5, Inf) | (-Inf, 3] | Bemowo_Bielany_Praga_Ursus_Ursynow_Wola |
| Ursynow | 3822 | 1968 | 28 | 2 | 2 | (1937, 1992] | (-Inf, 47] | (-Inf, 5] | (-Inf, 3] | Bemowo_Bielany_Praga_Ursus_Ursynow_Wola |
| Ursynow | 2337 | 1971 | 146 | 3 | 6 | (1937, 1992] | (101, Inf) | (-Inf, 5] | (3, Inf) | Bemowo_Bielany_Praga_Ursus_Ursynow_Wola |
| Ochota | 3381 | 1956 | 97 | 8 | 3 | (1937, 1992] | (47, 101] | (5, Inf) | (-Inf, 3] | Mokotow_Ochota_Zoliborz |

We can also perform feature selection if we wish. For each original feature, exactly one of its forms is kept: either the original one or the transformed one.

vars <- safely_select_variables(safe_extractor, data1, which_y = "m2.price", verbose = FALSE)
data1 <- data1[,c("m2.price", vars)]
print(vars)
#> [1] "surface"               "floor"                 "no.rooms"             
#> [4] "construction.year_new" "district_new"

It can be observed that for some features the original form was preferred and for others the transformed one.

Here are the first few rows for our data after feature selection:

| m2.price | surface | floor | no.rooms | construction.year_new | district_new |
|---|---|---|---|---|---|
| 3542 | 21 | 6 | 1 | (1937, 1992] | Bemowo_Bielany_Praga_Ursus_Ursynow_Wola |
| 5631 | 107 | 2 | 4 | (1992, Inf) | Srodmiescie |
| 2989 | 41 | 9 | 2 | (1992, Inf) | Bemowo_Bielany_Praga_Ursus_Ursynow_Wola |
| 3822 | 28 | 2 | 2 | (1937, 1992] | Bemowo_Bielany_Praga_Ursus_Ursynow_Wola |
| 2337 | 146 | 3 | 6 | (1937, 1992] | Bemowo_Bielany_Praga_Ursus_Ursynow_Wola |
| 3381 | 97 | 8 | 3 | (1937, 1992] | Mokotow_Ochota_Zoliborz |

Now we perform the same transformations on another subset of the data that will be used later in the explainers:

data2 <- safely_transform_data(safe_extractor, apartmentsTest[6001:9000,], verbose = FALSE)[,c("m2.price", vars)]

Creating white-box models on original and transformed datasets

Let’s fit the models to data containing newly created columns. We consider a linear model as a white-box model.

model_lm2 <- lm(m2.price ~ ., data = data1)
explainer_lm2 <- explain(model_lm2, data = data2, y = apartmentsTest[6001:9000,1], label = "lm2", verbose = FALSE)
set.seed(111)
model_rf2 <- randomForest(m2.price ~ ., data = data1)
explainer_rf2 <- explain(model_rf2, data2, apartmentsTest[6001:9000,1], label = "rf2", verbose = FALSE)

Moreover, we create a linear model based on the original apartments dataset, together with its corresponding explainer, in order to check whether our methodology improves the results.

model_lm1 <- lm(m2.price ~ ., data = apartments)
explainer_lm1 <- explain(model_lm1, data = apartmentsTest[1:3000,2:6], y = apartmentsTest[1:3000,1], label = "lm1", verbose = FALSE)

Comparing models performance

The final step is the comparison of all the models we have created.

mp_lm1 <- model_performance(explainer_lm1)
mp_rf1 <- model_performance(explainer_rf1)
mp_lm2 <- model_performance(explainer_lm2)
mp_rf2 <- model_performance(explainer_rf2)
plot(mp_lm1, mp_rf1, mp_lm2, mp_rf2, geom = "boxplot")

In the plot above we can see that the linear model based on the transformed features generally gives more accurate predictions than the one fitted to the original dataset.

References

The package was created by Anna Gierlak as part of a master's thesis at the Faculty of Mathematics and Information Science, Warsaw University of Technology.


Issues

Variable roles in tidymodels recipe and workflow... are they respected by rSAFE?

Example (I am playing with bicycle demand data from Kaggle):

bike_recipe <- recipe(count ~ ., data = bike_training) %>%
  step_date(datetime, features = c("doy", "dow", "month", "year"), abbr = TRUE) %>%
  update_role("datetime", new_role = "id_variable") %>%
  step_rm("atemp")

This will create time features out of the datetime index, and then datetime will not take part in the modelling.
I also removed the atemp variable altogether (temp and atemp were strongly correlated), so it does not take part in the modelling either.

Next I run the explainer:

explainer <- explain_tidymodels(bike_final_fit, data = bike_all %>% select(-count), y = bike_all$count)
safe_extractor <- safe_extraction(explainer)

The safe extractor seems to ignore the absence of datetime and atemp from the modelling process and proposes:

 Variable 'datetime' - selected intervals:
	(-Inf, 2011-02-16 23:00:00]
 	(2011-02-16 23:00:00, 2011-06-17 23:00:00]
 	(2011-06-17 23:00:00, 2012-04-15 23:00:00]
 	(2012-04-15 23:00:00, 2012-07-08 23:00:00]
 	(2012-07-08 23:00:00, Inf)
Variable 'season' - selected intervals:
	(-Inf, 3]
 	(3, Inf)
Variable 'holiday' - no transformation suggested.
Variable 'workingday' - no transformation suggested.
Variable 'weather' - selected intervals:
	(-Inf, 1]
 	(1, Inf)
Variable 'temp' - selected intervals:
	(-Inf, 12.3]
 	(12.3, 22.96]
 	(22.96, Inf)
Variable 'atemp' - selected intervals:
	(-Inf, 24.24]
 	(24.24, Inf)
Variable 'humidity' - selected intervals:
	(-Inf, 30]
 	(30, 48]
 	(48, 67]
 	(67, 84]
 	(84, Inf)
Variable 'windspeed' - selected intervals:
	(-Inf, 7.0015]
 	(7.0015, Inf)

How can I tell rSAFE that these two variables (one is the time index, the other was removed in the bake) are not taking part?
I am attaching my quick-and-dirty workflow:

timeseries_modelling_xgboost_short.zip
@agosiewska

for multiclass task safely_select_variables() gives error

Hi, for my multiclass task safely_select_variables() gives the following error:
Error in [.data.frame(data, , var_best) : undefined columns selected

The following is a dummy example (isomorphic to my original problem); would you please check my last 5 lines? I think I have messed them up. Thanks.

library(tidyverse)
library(mlr3verse)
library(DALEX)
library(DALEXtra)
library(rSAFE)

df <- data.frame(
  v = c(3.4, 5.6, 1.3, 9.8, 7.3, 4.6, 5.5, 2.3, 8.9, 7.1, 4.9, 6.5, 2.3, 4.1, 3.37, 3.4, 6.0, 2.3, 7.8, 3.7),
  w = c(34, 65, 23, 78, 37, 34, 65, 23, 78, 37, 34, 65, 23, 78, 37, 34, 65, 23, 78, 37),
  x = c('a', 'b', 'a', 'c', 'c', 'a', 'b', 'a', 'c', 'c', 'a', 'b', 'a', 'c', 'c', 'a', 'b', 'a', 'c', 'c'),
  y = c(TRUE, FALSE, TRUE, TRUE, FALSE, TRUE, FALSE, TRUE, TRUE, FALSE, TRUE, FALSE, TRUE, TRUE, FALSE, TRUE, FALSE, TRUE, TRUE, FALSE),
  z = c('alpha', 'alpha', 'delta', 'delta', 'phi', 'alpha', 'alpha', 'delta', 'delta', 'phi', 'alpha', 'alpha', 'delta', 'delta', 'phi', 'alpha', 'alpha', 'delta', 'delta', 'phi')
)

df_task <- TaskClassif$new(id = "my_df", backend = df, target = "z")
lrn_rf <- GraphLearner$new(po('encode') %>>% lrn("classif.ranger", predict_type = "prob"))
lrn_rf$train(df_task)

lrn_rf_exp <- explain_mlr3(lrn_rf,
                           data = df,
                           y = df$z,
                           label = "rf_exp")
safe_extractor <- safe_extraction(lrn_rf_exp, penalty = 25, verbose = FALSE)
sf_trafo_data <- safely_transform_data(safe_extractor, df, verbose = FALSE)
vars <- safely_select_variables(safe_extractor, sf_trafo_data, which_y = "z", class_pred = 'alpha', verbose = FALSE)

data2 <- safely_transform_data(safe_extractor, df, verbose = FALSE)[,c("z", vars)]

model_lm2 <- lm(z ~ ., data = data2)

missing trans function

I've got an error after running the following lines:

library(rSAFE)
library(randomForest)
library(DALEX)
set.seed(111)
model_rf1 <- randomForest(survived ~ ., data = titanic_imputed)
explainer_rf1 <- explain(model_rf1, data = titanic_imputed, y = titanic_imputed$survived == "yes", label = "rf1")
safe_extractor <- safe_extraction(explainer_rf1, penalty = 25, verbose = TRUE)
