
breakdown's People

Contributors

aleksandradabrowska, henningsway, larmarange, pbiecek, teunbrand



breakdown's Issues

Rounded values

Somehow the values given by the broken function are now always rounded instead of exact (tested with an lm model):

e.g.
contribution
(Intercept) 2300
tenure = 3 -1900
MonthlyCharges = 74.4 350
SeniorCitizen = 0 14
final_prognosis 720
baseline: 0

while predict(model, data_in_test[analysed_user, ]) gives 721.206
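A possible workaround, assuming the object returned by broken() stores the raw numbers in a `contribution` component (the component name here is an assumption; inspect str(br) on your version to confirm). A minimal self-contained sketch with an lm model:

```r
library(breakDown)

# simple lm model on a built-in dataset, just to illustrate
fit <- lm(Sepal.Length ~ Sepal.Width + Petal.Length, data = iris)
br  <- broken(fit, iris[1, ])

str(br)                 # inspect which components hold the raw values
br$contribution         # contributions before any print rounding, if present
```

If the raw values are there, the rounding is only in the print method and the underlying numbers should still sum to the exact prediction.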

How to interpret / understand the results shown in the WINE plot(br) example

Hi Biecek,

breakDown R Pkg is great
and easy to use.

I followed your Wine quality example lm() model
in:
https://pbiecek.github.io/breakDown/

But my question is
on how to interpret / understand the results,
shown in plot(br) of the WINE example.

Q1) If the var "residual.sugar" is 1.2 (the highest positive value of all vars),
does it mean it is the most significant contributing var?

Q2) Why do some vars have positive values
(e.g. "residual.sugar" = 1.2)
while other vars in the plot show negative values?

Q3) What is the meaning of the plot bar colors?
Some vars in plot(br) are yellow, some are green.

and finally,
Q4) What is the meaning of "final_prognosis" = 5.6 (gray color bar)?

Just trying to learn and understand what the results mean in practical terms.
Thanks for your help Biecek!

RAY
San Francisco
using latest Rstudio, R and Ubuntu Linux

Add "top_features" argument

For models with a larger number of variables the plot becomes unreadable; allowing the user to pick the top variables (by absolute value of contribution) would be a quick fix.
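Until such an argument exists, a manual workaround sketch. It assumes the "broken" object exposes `variable` and `contribution` components, as its printed output suggests (check str(br) on your version):

```r
library(breakDown)

fit <- lm(Sepal.Length ~ ., data = iris)
br  <- broken(fit, iris[1, ])

# keep only the top-k variables by absolute contribution
top_k <- 3
br_df <- data.frame(variable = br$variable, contribution = br$contribution)
head(br_df[order(abs(br_df$contribution), decreasing = TRUE), ], top_k)
```

This only filters the table; rebuilding a filtered plot would still need support inside plot.broken itself, which is why a `top_features` argument would be the cleaner fix.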

Difficulty Understanding Boolean Predictors

I am having difficulty with interpretation in the situation where the outcome is continuous and the predictors are Boolean.

If I were to think of effects in terms of a linear model, I would turn coefficients on or off depending on whether the predictors were a 1 or a 0. breakDown does not seem to do this.

In the example below I have chosen a point that has 0s assigned to the predictors. Again, in a linear model prediction this would simply result in setting the coefficients of these predictors to 0. With breakdown I am seeing a negative effect for versicolor and setosa.

  1. How am I supposed to interpret the output in this situation?

  2. Is there a way to show that since these values are 0 for this observation that they are not contributing to the final prediction?

library(reprex)
library(tidyverse)
library(breakDown)

# data prep
iris_dummy <- iris %>% 
  mutate(setosa = ifelse(Species == "setosa",1,0),
         versicolor = ifelse(Species == "versicolor",1,0)) %>% 
  select(-Species)

# fit model
fit <- lm(Sepal.Length ~Sepal.Width + Petal.Length + setosa + versicolor, data = iris_dummy)

set.seed(42)

# pick an observation
no <- iris_dummy[sample(nrow(iris_dummy), 1), ]

# use broken
br <- broken(fit, no)

# the example `no` is not setosa or versicolor yet has breakdown effects
no
#>     Sepal.Length Sepal.Width Petal.Length Petal.Width setosa versicolor
#> 138          6.4         3.1          5.5         1.8      0          0

plot(br)

Created on 2018-10-19 by the reprex package (v0.2.0).

problem with running the break plots on sparse data

Hi, recently I tried to run the breakdown plots on a random forest model (ranger via the caret package) trained on sparse data (a TF-IDF matrix). The model is not doing a really good job, but still...

When using the DALEX package, and after creating an explainer without any problems, I got for this call:

variable_attribution(rf_explainer,
                     new_observation = x_df_test[ind_to_check, ],
                     type = "break_down")

the following error:

Error in `[.data.frame`(out, , obsLevels, drop = FALSE) : 
  undefined columns selected

Then, I switched to this breakDown package. First of all, after calling it like this:

broken(rf_mod, x_df_test[ind_to_check, ])

It tells me that:

Error in "data.frame" %in% class(data) : 
  argument "data" is missing, with no default

Thus, I changed my call to:

broken(rf_mod, x_df_test[ind_to_check, ], data = x_df_test)

and this time:

Error in yhats[[which.max(yhats_diff)]] : 
  attempt to select less than one element in get1index

The whole code is here: https://github.com/CaRdiffR/tidy_thursdays/blob/master/april_30_2020/predict_gross_clf.R

Strangely, it worked well on exactly the same pipeline but with a regression problem.

I use R 4.0.0 and the latest versions of the DALEX and breakDown packages.

Might be related to #29 .

Random Forest Inference Predictions -- Intercept is the only contribution

Hi,

I was wondering if breakDown can be used to examine the tree votes for non-probability ranger predictions?

I'm getting weird breakdowns in that the contribution to the predicted value comes entirely from the intercept, omitting the actual factors which I assumed are used in training. I'm asking because the examples only deal with probabilities.

                                                          contribution
(Intercept)                                                      6.154
- run = -1.68749498470884                                        0.000
- kernel_time = 393.481201                                       0.000
- total_time = 445.312100278895                                  0.000
- device = xeon_es-2697v2                                        0.000
- invocation = 0                                                 0.000
- branch_entropy_average_linear = 0.0625153                      0.000
- branch_entropy_yokota = -0.164396323864394                     0.000
- ninety_percent_branch_instructions = -0.457287868935426        0.000
- total_unique_branch_instructions = -0.279132672222695          0.000
- local_memory_address_entropy_10 = 4.85307                      0.000
- local_memory_address_entropy_9 = 5.78479                       0.000
- local_memory_address_entropy_8 = 0.286085532384769             0.000
- local_memory_address_entropy_7 = 0.355682140016823             0.000
- local_memory_address_entropy_6 = 0.404053324038373             0.000
- local_memory_address_entropy_5 = 0.372824371077543             0.000
- local_memory_address_entropy_4 = 0.391075370713963             0.000
- local_memory_address_entropy_3 = 0.408194624910907             0.000
- local_memory_address_entropy_2 = 0.430309441064799             0.000
- local_memory_address_entropy_1 = 0.425171485537762             0.000
- global_memory_address_entropy = 0.442245377596217              0.000
- ninety_percent_memory_footprint = -0.189804453106315           0.000
- total_memory_footprint = -0.246916191110329                    0.000
- stddev_simd_width = 0                                          0.000
- mean_simd_width = 1                                            0.000
- max_simd_width = -0.284600989261938                            0.000
- median_instructions_to_barrier = -0.109768646102735            0.000
- max_instructions_to_barrier = 435                              0.000
- min_instructions_to_barrier = 435                              0.000
- total_barriers_hit = 0                                         0.000
- operand_sum = 111360                                           0.000
- workitems = 256                                                0.000
- total_instruction_count = 111360                               0.000
- instructions_per_operand = 1                                   0.000
- barriers_per_instruction = 8.97989e-06                         0.000
- granularity = -0.483677232475196                               0.000
- opcode = -0.714987553238371                                    0.000
- kernel = invert_mapping                                        0.000
- size = tiny                                                    0.000
- application = kmeans                                           0.000
final_prognosis                                                  6.154
baseline:  0

This was generated with the following:

library(breakDown)
library(ranger)
library(ggplot2)

load("./train_dat.Rdf")
load("./test_dat.Rdf")

#build the model
rgd.aiwc <- ranger(log(kernel_time)~.,
                   data = train_dat,
                   num.trees = 505,
                   mtry = 30,
                   min.node.size = 9,
                   importance = "impurity",
                   splitrule = 'variance',
                   respect.unordered.factors = 'order')

#make predictions over all devices we've used for training
predict.function <- function(model, new_observation)
  predict(model, data = new_observation, type = 'response')$predictions[1]

#remove the measured time from the data to break down the random forest model
test_data <- subset(test_dat, select = -kernel_time)

explain <- broken(rgd.aiwc,
                  test_dat[1,], #selected element to understand which factors drive its prediction
                  data = test_dat,
                  predict.function = predict.function,
                  direction = "down"
                  )

print(explain)

p <- plot(explain) + ggtitle("Breakdown of AIWC contributions")

pdf("contributions.pdf")
print(p)
dev.off()

Your suggestions would be greatly appreciated!

submit to CRAN

Any new features to add before the package is submitted?

Arguments in broken()

The broken() function documentation doesn't contain any information about the direction of the exploration strategy (that information is only in broken.default()).
Also, the DALEX prediction_breakdown() function doesn't have a direction argument.

typo in the word cumulative

Hi

There is a misspelling of the word cumulative -- it appears as "cummulative".

That leads to a ggplot with a typo on one of the axes.

Is it possible to fix it in the next version?

breakDown too slow for xgboost

Hi Biecek,

The package is great, but I've run into a puzzling issue. I have an xgboost binary classification model (objective = "binary:logistic", metric = "auc") with roughly 1000 trees and 200 variables. When I try to break down the model for a new observation using broken(), it takes 1.5 hours per observation! Is that normal, or is something wrong?

Should I specify the prediction function? Could DMatrix or model.matrix objects be the source of the problem?
In this implementation, can I use a DMatrix object without using model.matrix? If so, how?
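One thing worth trying (a sketch, not a confirmed fix): pass a custom predict.function that converts each incoming observation to an xgb.DMatrix, so broken() can hand the model plain data frames. The wrapper below is hypothetical and assumes the new observation has the same columns, in the same order, as the matrix used for training:

```r
library(xgboost)
library(breakDown)

# hypothetical wrapper: convert the observation(s) to the format xgboost expects
predict.function <- function(model, new_observation) {
  predict(model, xgb.DMatrix(as.matrix(new_observation)))
}

# usage sketch (xgb_model, new_obs and train_df are your own objects):
# br <- broken(xgb_model, new_obs, data = train_df,
#              predict.function = predict.function)
```

Note that the long runtime is likely dominated by the number of model evaluations broken() makes per variable, so a faster predict function helps but may not remove the bottleneck for 200 variables.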

Thanks,

Amir

Help: Error in yhats[[which.min(yhats_diff)]]

Hi !

Can I get help fixing the error below?

Error in yhats[[which.min(yhats_diff)]] :
attempt to select less than one element in get1index

I only code this:

contribution <- broken(model = model, new_observation = df[1,features], data = train[,features])

and my model is a randomForest.

It seems it is trying to access position zero, which doesn't exist.

Thanks, Margarida

support for interaction terms

# works
lmModel = lm(len ~ supp + dose, data = ToothGrowth)

# does not work but would be great
lmModel = lm(len ~ supp * dose, data = ToothGrowth)
broken(lmModel, ToothGrowth[1, ])

Something off with coefficients?

Hi, I don't understand how the broken function calculates the coefficients (or is something off?).

In the lm function this is my test result:

summary(model)

Call:
lm(formula = TotalCharges ~ ., data = data_in_test)

Residuals:
Min 1Q Median 3Q Max
-1943.33 -453.71 -94.64 490.26 1887.26

Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) -2162.4583 21.9717 -98.420 < 2e-16 ***
MonthlyCharges 36.1234 0.3080 117.301 < 2e-16 ***
tenure 65.3606 0.3683 177.476 < 2e-16 ***
SeniorCitizen -86.7050 24.3449 -3.562 0.000371 ***

Test user:

-2162.4583 +
  data_in_test[analysed_user, ]$MonthlyCharges * 36.1234 +
  data_in_test[analysed_user, ]$tenure * 65.3606 +
  data_in_test[analysed_user, ]$SeniorCitizen * (-86.7050)

[1] 721.2045

While you get (you can see that the intercept is different):

lm_br
contribution
(Intercept) 2283.300
tenure = 3 -1923.025
MonthlyCharges = 74.4 346.850
SeniorCitizen = 0 14.081
final_prognosis 721.206
baseline: 0

  • strangely, the final prognosis is now the same for both lm and broken, but broken does not use the same coefficients as summary(model) in its calculations

One would expect the contributions in a waterfall plot to simply be Y = intercept + beta * value, etc., taken from the summary output?
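A quick sanity-check sketch: whatever the per-variable attribution, the last row of the breakdown (final_prognosis) should still equal the model prediction, even when individual rows differ from the raw beta * value terms. The component name `contribution` is an assumption; inspect str(lm_br) to confirm:

```r
# using the objects from the report above
tail(lm_br$contribution, 1)                            # final_prognosis row
unname(predict(model, data_in_test[analysed_user, ]))  # should match it
```

If those two numbers agree (as they do in the output above: 721.206), the difference lies only in how the total is split across variables, not in the total itself.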
