Comments (10)

hbaniecki commented on May 13, 2024

Hi @agilebean! Can you provide the model, please? It might be a problem with the predict_function.
Your code works for me with DALEX v0.9.4 and iBreakDown v0.9.9:

library(dplyr)
random.case <- structure(list(anger = 0.166666666666667, anticipation = 0, disgust = 0.166666666666667, 
                              fear = 0.166666666666667, joy = 0, negative = 0.25, positive = 0.0833333333333333, 
                              sadness = 0.0833333333333333, surprise = 0.0833333333333333, 
                              trust = 0), class = "data.frame", row.names = c(NA, -1L))

training.set <- structure(list(.outcome = structure(c(3L, 4L, 5L, 4L, 4L, 5L,
    5L, 4L, 3L, 3L, 3L, 5L, 4L, 3L, 3L, 1L, 4L, 3L, 4L, 5L, 3L, 2L,
    5L, 5L, 5L), .Label = c("1", "2", "3", "4", "5"), class = "factor"),
  anger = c(0, 0.0434782608695652, 0, 0, 0, 0.1, 0, 0.037037037037037,
    0.0192307692307692, 0, 0, 0, 0, 0.0673076923076923, 0.181818181818182,
    0.0408163265306122, 0, 0, 0, 0.0285714285714286, 0.0526315789473684,
    0.0952380952380952, 0, 0.0441176470588235, 0),
  anticipation = c(0.333333333333333,
    0.217391304347826, 0.125, 0.15, 0.2, 0.2, 0.217391304347826,
    0.111111111111111, 0.173076923076923, 0.166666666666667,
    0.111111111111111, 0.157894736842105, 0.214285714285714,
    0.115384615384615, 0.0909090909090909, 0.0408163265306122,
    0, 0.166666666666667, 0, 0.114285714285714, 0.184210526315789,
    0.0476190476190476, 0.133333333333333, 0.102941176470588,
    0.176470588235294),
  disgust = c(0, 0, 0, 0, 0, 0, 0, 0.0185185185185185,
    0.0192307692307692, 0.0833333333333333, 0.0740740740740741,
    0, 0, 0.0288461538461538, 0, 0.0204081632653061, 0, 0, 0.111111111111111,
    0, 0, 0.0952380952380952, 0, 0.0294117647058824, 0),
  fear = c(0,
    0.0434782608695652, 0, 0.05, 0, 0, 0, 0.0185185185185185,
    0, 0, 0, 0, 0, 0.0673076923076923, 0, 0.0408163265306122,
    0, 0.0833333333333333, 0.111111111111111, 0, 0.0263157894736842,
    0.0952380952380952, 0, 0.0294117647058824, 0),
  joy = c(0,
    0.130434782608696, 0.166666666666667, 0.15, 0.233333333333333,
    0.2, 0.173913043478261, 0.166666666666667, 0.0961538461538462,
    0.166666666666667, 0.037037037037037, 0.210526315789474,
    0.214285714285714, 0.0961538461538462, 0.181818181818182,
    0.0204081632653061, 0.333333333333333, 0.0833333333333333,
    0.222222222222222, 0.2, 0.105263157894737, 0.0952380952380952,
    0.2, 0.147058823529412, 0.176470588235294),
  negative = c(0,
    0.0869565217391304, 0.0833333333333333, 0.1, 0, 0, 0, 0.0555555555555556,
    0.0769230769230769, 0.166666666666667, 0.0740740740740741,
    0.0526315789473684, 0.0714285714285714, 0.105769230769231,
    0.181818181818182, 0.204081632653061, 0, 0.166666666666667,
    0.222222222222222, 0.0285714285714286, 0.105263157894737,
    0.19047619047619, 0, 0.102941176470588, 0.0294117647058824),
  positive = c(0.333333333333333, 0.217391304347826, 0.291666666666667,
    0.4, 0.3, 0.3, 0.347826086956522, 0.333333333333333, 0.326923076923077,
    0.25, 0.259259259259259, 0.315789473684211, 0.285714285714286,
    0.240384615384615, 0.181818181818182, 0.244897959183673,
    0.333333333333333, 0.25, 0.222222222222222, 0.4, 0.342105263157895,
    0.238095238095238, 0.4, 0.235294117647059, 0.352941176470588),
  sadness = c(0.333333333333333, 0.0434782608695652, 0.0416666666666667,
    0, 0, 0, 0, 0.0185185185185185, 0.0576923076923077, 0, 0.0740740740740741,
    0, 0, 0.0480769230769231, 0.0909090909090909, 0.142857142857143,
    0, 0, 0.111111111111111, 0, 0.0526315789473684, 0.0952380952380952,
    0, 0.0441176470588235, 0.0294117647058824),
  surprise = c(0,
    0.0434782608695652, 0.0833333333333333, 0.05, 0.0666666666666667,
    0, 0.0434782608695652, 0.037037037037037, 0.0192307692307692,
    0, 0.111111111111111, 0.0526315789473684, 0, 0.0865384615384615,
    0, 0.0408163265306122, 0, 0, 0, 0.0285714285714286, 0.0526315789473684,
    0, 0.0666666666666667, 0.0735294117647059, 0.0294117647058824),
  trust = c(0, 0.173913043478261, 0.208333333333333, 0.1,
    0.2, 0.2, 0.217391304347826, 0.203703703703704, 0.211538461538462,
    0.166666666666667, 0.259259259259259, 0.210526315789474,
    0.214285714285714, 0.144230769230769, 0.0909090909090909,
    0.204081632653061, 0.333333333333333, 0.25, 0, 0.2, 0.0789473684210526,
    0.0476190476190476, 0.2, 0.191176470588235, 0.205882352941176)),
  row.names = c(NA, 25L), class = "data.frame")
target <- training.set$.outcome
features <- training.set %>% select(-.outcome)

TARGET.VALUE <- "1"

colnames(training.set)[1] <- "outcome"
model_object <- lm(outcome==TARGET.VALUE~., data = training.set)

DALEX.explainer <- DALEX::explain(
  model = model_object,
  data = features,
  y = training.set$outcome == TARGET.VALUE,
  label = "lm model",
  colorize = TRUE
)

DALEX.attribution <- DALEX.explainer %>% iBreakDown::local_attributions(random.case) 
DALEX.attribution

from ibreakdown.

agilebean commented on May 13, 2024

Ha, our messages crossed: I just added the model to the description!

agilebean commented on May 13, 2024

I just verified that I have iBreakDown_0.9.9 and, for DALEX, one patch release lower, i.e. DALEX_0.9.3.

agilebean commented on May 13, 2024

I just ran it again and got the same error. However, when I run the same analysis with the model trained as a regression instead of a classification, it WORKS! I double-checked just now.

agilebean commented on May 13, 2024

@hbaniecki Did you try it with the model I specified in the description above?

hbaniecki commented on May 13, 2024

Yes, it is a weird problem with data.frame/matrix behavior. I believe this part of the iBreakDown code

cummulative <- do.call(rbind, c(list(baseline_yhat), yhats_mean, list(target_yhat)))
contribution <- rbind(0,apply(cummulative, 2, diff))
contribution[1,] <- cummulative[1,]
contribution[nrow(contribution),] <- cummulative[nrow(contribution),]
could be handled better (it needs a fix).
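A minimal base-R sketch of the data.frame/matrix difference mentioned above, assuming the class probabilities arrive as a data.frame (as `predict(..., type = "prob")` does for caret models). The column names `class_1`/`class_2` and the probability values are made up for illustration, and this is not the iBreakDown code itself: it only shows that `rbind()` returns a data.frame as soon as one argument is a data.frame, while the same call on a numeric matrix returns a matrix.

```r
# Hypothetical class probabilities for 3 observations and 2 classes,
# shaped like the data.frame returned by predict(..., type = "prob")
probs_df <- data.frame(class_1 = c(0.2, 0.5, 0.9),
                       class_2 = c(0.8, 0.5, 0.1))
probs_mat <- as.matrix(probs_df)

# Stack one observation's prediction with the column means, similar in
# spirit to how the cumulative predictions are stacked with rbind()
stacked_df  <- rbind(probs_df[1, ], colMeans(probs_df))
stacked_mat <- rbind(probs_mat[1, ], colMeans(probs_mat))

# rbind() dispatched to rbind.data.frame, so the result is still a data.frame
class(stacked_df)
# With plain numeric input, the result is a numeric matrix
class(stacked_mat)
```

Code that later indexes or assigns rows of the stacked object as if it were a numeric matrix can then behave differently depending on which of the two it receives, which is why returning a matrix from the predict function sidesteps the problem.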

When running your example, a red warning appears in the explainer output saying that predict_function returns probabilities for multiple classes. For now, if you want to use local_attributions for one class (e.g. target = "1"), you can pass a custom predict_function to the explainer.

custom_predict_caret_oneclass <- function(model, data, target = "1") {
  return(predict(model, data, type = "prob")[, target])
}

DALEX.explainer <- DALEX::explain(
  model = model.rf,
  data = features,
  y = target == TARGET.VALUE,
  predict_function = custom_predict_caret_oneclass,
  label = paste0(model.rf$method, " model"),
  colorize = TRUE
)

DALEX.attribution <- DALEX.explainer %>%
  iBreakDown::local_attributions(random.case)

DALEX.attribution

agilebean commented on May 13, 2024

Great analysis.
Thanks for the one-class predict_function code!
But the bug itself is a bummer, since I need this for a publication.
Speaking of which, the issue about numbers on plots is extremely important for publications.

pbiecek commented on May 13, 2024

Thanks, the problem was that predict returns a data.frame instead of a matrix.
It is solved in the latest DALEX in the ema branch (it will be on master at the beginning of the week and on CRAN in a week).

In the meantime, you can use a user-defined predict_function:

DALEX.explainer <- DALEX::explain(
  model = model.rf,
  data = features,
  y = target == TARGET.VALUE,
  predict_function = function(m,x) as.matrix(predict(m, newdata = x, type = "prob")),
  label = paste0(model.rf$method, " model"),
  colorize = TRUE
)
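To see what the two workarounds in this thread actually hand to DALEX, here is a self-contained sketch that swaps caret for a hypothetical `mock_classifier` whose `predict()` method returns a data.frame of class probabilities. `mock_model`, `predict.mock_classifier`, `predict_as_matrix` and `new_data` are made up for illustration; only the bodies of the two predict functions mirror code from the comments above.

```r
# Hypothetical stand-in for a caret classifier (not a real caret object):
# its predict method returns a data.frame of class probabilities,
# one column per class label, as caret does for type = "prob"
mock_model <- structure(list(), class = "mock_classifier")

predict.mock_classifier <- function(object, newdata, type = "prob", ...) {
  p1 <- plogis(newdata$anger - newdata$joy)  # arbitrary made-up score
  data.frame(`1` = p1, `2` = 1 - p1, check.names = FALSE)
}

# Workaround 1 (from the thread): keep only the probability of one class
custom_predict_caret_oneclass <- function(model, data, target = "1") {
  predict(model, data, type = "prob")[, target]
}

# Workaround 2 (from the thread): convert the data.frame to a numeric matrix
predict_as_matrix <- function(m, x) {
  as.matrix(predict(m, newdata = x, type = "prob"))
}

new_data <- data.frame(anger = c(0.17, 0, 0.1), joy = c(0, 0.2, 0.1))
one_class <- custom_predict_caret_oneclass(mock_model, new_data)
all_classes <- predict_as_matrix(mock_model, new_data)

str(one_class)    # numeric vector, one value per row of new_data
str(all_classes)  # numeric matrix, one column per class
```

The one-class wrapper reduces the output to a single numeric vector, so the explanation refers to class "1" only, while the `as.matrix()` wrapper keeps every class but guarantees a numeric matrix instead of a data.frame.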

pbiecek commented on May 13, 2024

This is now fixed in the latest DALEX, starting with version 0.9.8, as in
https://github.com/ModelOriented/DALEX/tree/DALEX_1.0_ema_version
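If you are unsure whether your installed DALEX already includes this fix, a quick check along these lines may help. The helper name `dalex_has_fix()` is hypothetical; the `0.9.8` threshold is taken from the comment above.

```r
# Hypothetical helper: TRUE if an installed DALEX is at least min_version
dalex_has_fix <- function(min_version = "0.9.8") {
  # DALEX not installed at all: the fix is not available
  if (!requireNamespace("DALEX", quietly = TRUE)) {
    return(FALSE)
  }
  # package_version objects compare correctly against version strings
  utils::packageVersion("DALEX") >= min_version
}

dalex_has_fix()
```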

agilebean commented on May 13, 2024

I can confirm it works now: I just ran local_attributions() on a classification dataset.
Wonderful.
It returns this plot:
[Plot: local_attributions() output for the classification model]