Comments (5)

CloseChoice commented on September 26, 2024

Yes, it seems like the implementation differs.

CloseChoice commented on September 26, 2024

Hmm, the differences look pretty large. Please note that for this example we rely on a SHAP implementation from the XGBoost repo and simply call it under the hood, so we don't expect the results to be exactly identical. Still, I agree that the differences look too large.
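To put a number on "too large", a small summary of the gap between the two explanations could help. The sketch below assumes both results are `shap.Explanation` objects whose `.values` are same-shaped NumPy arrays; the helper name `summarize_gap` and its default tolerances are hypothetical choices, not part of SHAP.

```python
import numpy as np


def summarize_gap(cpu_values, gpu_values, rtol=1e-3, atol=1e-5):
    """Summarize element-wise differences between two SHAP value arrays.

    Hypothetical helper: inputs are assumed to be same-shaped arrays such as
    shap_values_cpu.values and shap_values_gpu.values.
    """
    a = np.asarray(cpu_values, dtype=float)
    b = np.asarray(gpu_values, dtype=float)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    diff = np.abs(a - b)
    return {
        "max_abs_diff": float(diff.max()),
        "mean_abs_diff": float(diff.mean()),
        # fraction of entries where the two explainers disagree in sign
        # (entries where either value is exactly zero do not count)
        "sign_mismatch_rate": float(np.mean(np.sign(a) * np.sign(b) < 0)),
        "allclose": bool(np.allclose(a, b, rtol=rtol, atol=atol)),
    }
```

For example, `summarize_gap(shap_values_cpu.values, shap_values_gpu.values)` would report whether the differences are small numerical noise or systematic, including how often the attributions flip sign.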

madakkmi commented on September 26, 2024

@CloseChoice, the issue also occurs with CatBoost (version 1.2.5). Please see the code and the results below.

import catboost
import shap

# get a dataset on income prediction
X, y = shap.datasets.adult()

# train a CatBoost model (any other tree model type would also work)
model = catboost.CatBoostClassifier()
model.fit(X, y)

# explain each sample's log loss on the CPU, with the first 1024 rows as background data
explainer = shap.explainers.Tree(model, X.iloc[:1024, :], feature_perturbation='interventional', model_output='log_loss')
shap_values_cpu = explainer(X, y)

# same settings, explained with the GPU implementation
explainer = shap.explainers.GPUTree(model, X.iloc[:1024, :], feature_perturbation='interventional', model_output='log_loss')
shap_values_gpu = explainer(X, y)

print(shap_values_cpu)
print(shap_values_gpu)


>>> print(shap_values_cpu)
.values =
array([[-2.45228340e-02, -4.55682448e-02,  1.49148056e-02, ...,
        -9.59332274e-03, -2.74721636e-02,  2.49793992e-04],
       [ 9.43002960e-02, -1.08479404e-01,  1.27293763e-01, ...,
        -1.52033860e-02, -2.89874974e-01,  3.28992919e-03],
       [-1.01427812e-02,  3.23109879e-03, -6.32459885e-02, ...,
        -1.30945537e-02, -2.66773530e-02, -1.15663806e-03],
       ...,
       [ 4.04812396e-02,  4.77366793e-03, -6.22465752e-02, ...,
        -1.11532688e-02, -2.74638473e-02, -3.50420660e-04],
       [-1.29751247e-01,  4.96752296e-04, -3.21895448e-02, ...,
        -7.16964093e-03, -1.02131182e-01, -7.98910372e-04],
       [-3.10874259e-01, -5.21471528e-02,  7.77226873e-02, ...,
         4.41351594e-04, -9.39966680e-02, -4.64903545e-03]])

.base_values =
array([<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7fde95c70390>>,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7fde95c70390>>,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7fde95c70390>>,
       ...,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7fde95c70390>>,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7fde95c70390>>,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7fde95c70390>>],
      dtype=object)

.data =
array([[39.,  7., 13., ...,  0., 40., 39.],
       [50.,  6., 13., ...,  0., 13., 39.],
       [38.,  4.,  9., ...,  0., 40., 39.],
       ...,
       [58.,  4.,  9., ...,  0., 40., 39.],
       [22.,  4.,  9., ...,  0., 20., 39.],
       [52.,  5.,  9., ...,  0., 40., 39.]])
>>> print(shap_values_gpu)
.values =
array([[ 2.62437046e-01, -6.89887345e-01,  5.82130373e-01, ...,
        -3.09314616e-02,  7.63727650e-02,  1.26645975e-02],
       [ 9.55442727e-01, -3.38494778e-01,  7.66370893e-01, ...,
        -2.25803982e-02, -1.10402429e+00,  1.59402527e-02],
       [ 4.88319218e-01,  1.59519743e-02, -3.86726022e-01, ...,
        -3.10675669e-02,  1.42025307e-01, -8.82095890e-04],
       ...,
       [ 1.18688166e+00,  3.16245370e-02, -3.60770345e-01, ...,
        -2.55552642e-02,  8.21248144e-02,  5.16887382e-03],
       [-1.93334663e+00,  7.00111827e-03, -3.63002360e-01, ...,
        -2.84506176e-02, -1.45065773e+00, -2.93454411e-03],
       [ 5.98957181e-01,  1.35710150e-01, -3.35983247e-01, ...,
        -1.98047664e-02,  1.61192924e-01,  1.42598096e-02]])

.base_values =
array([<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7fde99e71a50>>,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7fde99e71a50>>,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7fde99e71a50>>,
       ...,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7fde99e71a50>>,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7fde99e71a50>>,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7fde99e71a50>>],
      dtype=object)

.data =
array([[39.,  7., 13., ...,  0., 40., 39.],
       [50.,  6., 13., ...,  0., 13., 39.],
       [38.,  4.,  9., ...,  0., 40., 39.],
       ...,
       [58.,  4.,  9., ...,  0., 40., 39.],
       [22.,  4.,  9., ...,  0., 20., 39.],
       [52.,  5.,  9., ...,  0., 40., 39.]])
>>>
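One way to tell which explainer is closer to correct is to check local accuracy (additivity): each row's SHAP values plus the base value should reproduce the explained output. The sketch below assumes that with `model_output='log_loss'` the base value for a sample with label y is the mean log loss of the background rows when they are all given label y; treat that as an assumption, not a confirmed description of SHAP internals. Because `.base_values` in the output above holds bound methods instead of numbers, the base value is recomputed from predicted probabilities. The function names are hypothetical.

```python
import numpy as np


def binary_log_loss(p, y, eps=1e-15):
    """Per-sample binary cross-entropy, where p is the predicted P(y=1) and y is 0/1."""
    p = np.clip(np.asarray(p, dtype=float), eps, 1 - eps)
    y = np.asarray(y, dtype=float)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))


def log_loss_additivity_gap(shap_values, p_explained, y, p_background):
    """Per-row |sum(phi) - (loss(x, y) - base(y))| for a binary classifier.

    Assumption: base(y) is the mean log loss of the background rows
    when every background row is given the label y.
    """
    values = np.asarray(shap_values, dtype=float)
    y = np.asarray(y, dtype=float)
    p_background = np.asarray(p_background, dtype=float)
    loss_x = binary_log_loss(p_explained, y)
    # mean background loss if every background row had label 1, resp. label 0
    base_if_1 = binary_log_loss(p_background, np.ones_like(p_background)).mean()
    base_if_0 = binary_log_loss(p_background, np.zeros_like(p_background)).mean()
    base = np.where(y == 1, base_if_1, base_if_0)
    return np.abs(values.sum(axis=1) - (loss_x - base))
```

With the models above, `p_explained` could be `model.predict_proba(X)[:, 1]` and `p_background` could be `model.predict_proba(X.iloc[:1024, :])[:, 1]`. Under the stated assumption, whichever explainer yields gaps close to zero is the one that satisfies additivity.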

madakkmi commented on September 26, 2024

@CloseChoice, the issue also occurs with LightGBM (version 4.3.0). Please see the code and results below.

import lightgbm
import shap

# get a dataset on income prediction
X, y = shap.datasets.adult()

# train a LightGBM model
model = lightgbm.LGBMClassifier()
model.fit(X, y)

# explain each sample's log loss on the CPU, with the first 1024 rows as background data
explainer = shap.explainers.Tree(model, X.iloc[:1024, :], feature_perturbation='interventional', model_output='log_loss')
shap_values_cpu = explainer(X, y)

# same settings, explained with the GPU implementation
explainer = shap.explainers.GPUTree(model, X.iloc[:1024, :], feature_perturbation='interventional', model_output='log_loss')
shap_values_gpu = explainer(X, y)

print(shap_values_cpu)
print(shap_values_gpu)

>>> print(shap_values_cpu)
.values =
array([[-4.91080390e-03, -1.69228050e-02,  2.96835418e-02, ...,
        -9.80729848e-03, -3.05952129e-02,  1.58141686e-04],
       [ 9.99060617e-02, -8.55006509e-02,  1.56129534e-01, ...,
        -1.04579078e-02, -2.75666446e-01,  1.38203819e-03],
       [-4.93458344e-03,  3.79995951e-03, -5.60787516e-02, ...,
        -1.15025515e-02, -2.48505877e-02, -3.01114370e-04],
       ...,
       [ 4.21504108e-02,  6.12675405e-03, -4.99365184e-02, ...,
        -9.77849832e-03, -2.94363954e-02,  2.71079930e-05],
       [-1.47657198e-01,  1.88917643e-03, -2.48646541e-02, ...,
        -6.77186991e-03, -8.68297360e-02,  5.81489082e-05],
       [-3.92795784e-01, -2.82715164e-02,  2.97571946e-02, ...,
         1.81641138e-04, -8.22349042e-02, -1.36063861e-03]])

.base_values =
array([<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7f2368a6a0d0>>,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7f2368a6a0d0>>,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7f2368a6a0d0>>,
       ...,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7f2368a6a0d0>>,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7f2368a6a0d0>>,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7f2368a6a0d0>>],
      dtype=object)

.data =
array([[39.,  7., 13., ...,  0., 40., 39.],
       [50.,  6., 13., ...,  0., 13., 39.],
       [38.,  4.,  9., ...,  0., 40., 39.],
       ...,
       [58.,  4.,  9., ...,  0., 40., 39.],
       [22.,  4.,  9., ...,  0., 20., 39.],
       [52.,  5.,  9., ...,  0., 40., 39.]])
>>> print(shap_values_gpu)
.values =
array([[ 0.4804818 , -0.15829426,  0.58872831, ..., -0.02613338,
        -0.05505598,  0.00279381],
       [ 0.89829117, -0.28824031,  0.82394063, ..., -0.01517473,
        -0.9674257 ,  0.00623654],
       [ 0.41495359,  0.03114166, -0.35944128, ..., -0.02870829,
        -0.00318135, -0.00527164],
       ...,
       [ 0.91733944,  0.03161558, -0.24758764, ..., -0.02169394,
        -0.03218083,  0.00519772],
       [-1.38789368,  0.02243478, -0.17830381, ..., -0.02230209,
        -0.84264034, -0.01070848],
       [ 0.88696593,  0.09897842, -0.16394429, ..., -0.0109373 ,
         0.19775532,  0.00537567]])

.base_values =
array([<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7f2368a6bf50>>,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7f2368a6bf50>>,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7f2368a6bf50>>,
       ...,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7f2368a6bf50>>,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7f2368a6bf50>>,
       <bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7f2368a6bf50>>],
      dtype=object)

.data =
array([[39.,  7., 13., ...,  0., 40., 39.],
       [50.,  6., 13., ...,  0., 13., 39.],
       [38.,  4.,  9., ...,  0., 40., 39.],
       ...,
       [58.,  4.,  9., ...,  0., 40., 39.],
       [22.,  4.,  9., ...,  0., 20., 39.],
       [52.,  5.,  9., ...,  0., 40., 39.]])
>>>
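To see whether the CPU/GPU discrepancy is concentrated in a few features or spread across all of them, the features could be ranked by how much the two explainers disagree. A minimal sketch, assuming `feature_names` comes from `X.columns` and the arrays are the `.values` of the two Explanation objects; `rank_feature_gaps` is a hypothetical helper. The per-feature correlation distinguishes a positive rescaling (correlation near 1) from a genuinely different attribution.

```python
import numpy as np


def rank_feature_gaps(cpu_values, gpu_values, feature_names):
    """Rank features by mean absolute CPU/GPU SHAP difference, largest first.

    Returns (name, mean_abs_gap, pearson_corr) tuples; corr is nan when
    either column is constant.
    """
    a = np.asarray(cpu_values, dtype=float)
    b = np.asarray(gpu_values, dtype=float)
    gaps = np.abs(a - b).mean(axis=0)
    rows = []
    for j, name in enumerate(feature_names):
        # np.corrcoef is undefined for a constant column, so guard against it
        if a[:, j].std() > 0 and b[:, j].std() > 0:
            corr = float(np.corrcoef(a[:, j], b[:, j])[0, 1])
        else:
            corr = float("nan")
        rows.append((name, float(gaps[j]), corr))
    return sorted(rows, key=lambda r: r[1], reverse=True)
```

For example, `rank_feature_gaps(shap_values_cpu.values, shap_values_gpu.values, X.columns)` lists the features whose attributions diverge most, which may narrow down where the two implementations differ.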

madakkmi commented on September 26, 2024

@CloseChoice, please let me know whether you could fix this bug, or advise me on how I could solve it myself. Thank you.
