Yes, it seems like the implementation differs.
Hmm, the differences look pretty large. Please note that for this example we just rely on a SHAP implementation from the XGBoost repo and call it under the hood, so the results are not expected to be exactly the same. Still, I agree that the differences look a bit too large.
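To put a number on "pretty large", the two result arrays can be compared directly. This is a minimal sketch, not part of SHAP: `compare_shap_arrays` is a hypothetical helper, and the `rtol`/`atol` defaults are arbitrary illustrative choices rather than any tolerance SHAP itself uses.

```python
import numpy as np


def compare_shap_arrays(cpu_values, gpu_values, rtol=1e-3, atol=1e-5):
    """Summarize how far two same-shaped SHAP value arrays diverge."""
    cpu = np.asarray(cpu_values, dtype=float)
    gpu = np.asarray(gpu_values, dtype=float)
    if cpu.shape != gpu.shape:
        raise ValueError(f"shape mismatch: {cpu.shape} vs {gpu.shape}")
    abs_diff = np.abs(cpu - gpu)
    max_diff = float(abs_diff.max())
    # Scale by the largest CPU magnitude (not element-wise) so near-zero entries don't inflate the ratio.
    scale = float(np.abs(cpu).max())
    if scale > 0:
        max_rel = max_diff / scale
    else:
        max_rel = 0.0 if max_diff == 0 else float("inf")
    return {
        "max_abs_diff": max_diff,
        "mean_abs_diff": float(abs_diff.mean()),
        "max_rel_diff": max_rel,
        # Illustrative tolerances only; tighten or loosen as needed.
        "allclose": bool(np.allclose(gpu, cpu, rtol=rtol, atol=atol)),
    }
```

For the outputs in this thread, one would call `compare_shap_arrays(shap_values_cpu.values, shap_values_gpu.values)`; the `.values` attribute holds the raw SHAP array of a `shap.Explanation`.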
@CloseChoice, the issue also occurs with CatBoost (version 1.2.5). Please see the code and the results.
```python
import catboost
import shap

# get a dataset on income prediction
X, y = shap.datasets.adult()

# train a CatBoost model
model = catboost.CatBoostClassifier()
model.fit(X, y)

# CPU TreeExplainer; model_output='log_loss' needs the labels y at call time
explainer = shap.explainers.Tree(model, X.iloc[:1024, :], feature_perturbation='interventional', model_output='log_loss')
shap_values_cpu = explainer(X, y)

# same configuration on the GPU explainer
explainer = shap.explainers.GPUTree(model, X.iloc[:1024, :], feature_perturbation='interventional', model_output='log_loss')
shap_values_gpu = explainer(X, y)

print(shap_values_cpu)
print(shap_values_gpu)
```
```text
>>> print(shap_values_cpu)
.values =
array([[-2.45228340e-02, -4.55682448e-02, 1.49148056e-02, ...,
-9.59332274e-03, -2.74721636e-02, 2.49793992e-04],
[ 9.43002960e-02, -1.08479404e-01, 1.27293763e-01, ...,
-1.52033860e-02, -2.89874974e-01, 3.28992919e-03],
[-1.01427812e-02, 3.23109879e-03, -6.32459885e-02, ...,
-1.30945537e-02, -2.66773530e-02, -1.15663806e-03],
...,
[ 4.04812396e-02, 4.77366793e-03, -6.22465752e-02, ...,
-1.11532688e-02, -2.74638473e-02, -3.50420660e-04],
[-1.29751247e-01, 4.96752296e-04, -3.21895448e-02, ...,
-7.16964093e-03, -1.02131182e-01, -7.98910372e-04],
[-3.10874259e-01, -5.21471528e-02, 7.77226873e-02, ...,
4.41351594e-04, -9.39966680e-02, -4.64903545e-03]])
.base_values =
array([<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7fde95c70390>>,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7fde95c70390>>,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7fde95c70390>>,
...,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7fde95c70390>>,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7fde95c70390>>,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7fde95c70390>>],
dtype=object)
.data =
array([[39., 7., 13., ..., 0., 40., 39.],
[50., 6., 13., ..., 0., 13., 39.],
[38., 4., 9., ..., 0., 40., 39.],
...,
[58., 4., 9., ..., 0., 40., 39.],
[22., 4., 9., ..., 0., 20., 39.],
[52., 5., 9., ..., 0., 40., 39.]])
>>> print(shap_values_gpu)
.values =
array([[ 2.62437046e-01, -6.89887345e-01, 5.82130373e-01, ...,
-3.09314616e-02, 7.63727650e-02, 1.26645975e-02],
[ 9.55442727e-01, -3.38494778e-01, 7.66370893e-01, ...,
-2.25803982e-02, -1.10402429e+00, 1.59402527e-02],
[ 4.88319218e-01, 1.59519743e-02, -3.86726022e-01, ...,
-3.10675669e-02, 1.42025307e-01, -8.82095890e-04],
...,
[ 1.18688166e+00, 3.16245370e-02, -3.60770345e-01, ...,
-2.55552642e-02, 8.21248144e-02, 5.16887382e-03],
[-1.93334663e+00, 7.00111827e-03, -3.63002360e-01, ...,
-2.84506176e-02, -1.45065773e+00, -2.93454411e-03],
[ 5.98957181e-01, 1.35710150e-01, -3.35983247e-01, ...,
-1.98047664e-02, 1.61192924e-01, 1.42598096e-02]])
.base_values =
array([<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7fde99e71a50>>,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7fde99e71a50>>,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7fde99e71a50>>,
...,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7fde99e71a50>>,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7fde99e71a50>>,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7fde99e71a50>>],
dtype=object)
.data =
array([[39., 7., 13., ..., 0., 40., 39.],
[50., 6., 13., ..., 0., 13., 39.],
[38., 4., 9., ..., 0., 40., 39.],
...,
[58., 4., 9., ..., 0., 40., 39.],
[22., 4., 9., ..., 0., 20., 39.],
[52., 5., 9., ..., 0., 40., 39.]])
```
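Because both explainers were run with `model_output='log_loss'`, SHAP's local-accuracy property gives one way to judge which output is plausible: according to the `TreeExplainer` documentation, the SHAP values for each sample, plus the base value, should sum to that sample's log loss. The sketch below assumes you have predicted probabilities (for example `model.predict_proba(X)[:, 1]`) and numeric per-sample base values; in the printouts above `.base_values` shows bound methods rather than numbers, so how to obtain them is left open here. `binary_log_loss` and `additivity_residual` are hypothetical helpers, not SHAP APIs.

```python
import numpy as np


def binary_log_loss(y_true, p_pred, eps=1e-15):
    """Per-sample binary cross-entropy using the natural log."""
    y = np.asarray(y_true, dtype=float)
    # Clip so log(0) cannot occur for fully confident predictions.
    p = np.clip(np.asarray(p_pred, dtype=float), eps, 1 - eps)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))


def additivity_residual(shap_values, base_values, y_true, p_pred):
    """Return loss_i - (base_i + sum_j phi_ij) for every sample i.

    If the SHAP values explain the log loss, these residuals should be close to 0.
    A scalar base value is broadcast to all samples.
    """
    phi = np.asarray(shap_values, dtype=float)
    base = np.broadcast_to(np.asarray(base_values, dtype=float), (phi.shape[0],))
    return binary_log_loss(y_true, p_pred) - (base + phi.sum(axis=1))
```

Running this separately on the CPU and GPU values (with the same `y` and probabilities) would show which of the two satisfies local accuracy, if either does.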
@CloseChoice, the issue also occurs with LightGBM (version 4.3.0). Please see below.
```python
import lightgbm
import shap

# get a dataset on income prediction
X, y = shap.datasets.adult()

# train a LightGBM model
model = lightgbm.LGBMClassifier()
model.fit(X, y)

# CPU TreeExplainer; model_output='log_loss' needs the labels y at call time
explainer = shap.explainers.Tree(model, X.iloc[:1024, :], feature_perturbation='interventional', model_output='log_loss')
shap_values_cpu = explainer(X, y)

# same configuration on the GPU explainer
explainer = shap.explainers.GPUTree(model, X.iloc[:1024, :], feature_perturbation='interventional', model_output='log_loss')
shap_values_gpu = explainer(X, y)

print(shap_values_cpu)
print(shap_values_gpu)
```
```text
>>> print(shap_values_cpu)
.values =
array([[-4.91080390e-03, -1.69228050e-02, 2.96835418e-02, ...,
-9.80729848e-03, -3.05952129e-02, 1.58141686e-04],
[ 9.99060617e-02, -8.55006509e-02, 1.56129534e-01, ...,
-1.04579078e-02, -2.75666446e-01, 1.38203819e-03],
[-4.93458344e-03, 3.79995951e-03, -5.60787516e-02, ...,
-1.15025515e-02, -2.48505877e-02, -3.01114370e-04],
...,
[ 4.21504108e-02, 6.12675405e-03, -4.99365184e-02, ...,
-9.77849832e-03, -2.94363954e-02, 2.71079930e-05],
[-1.47657198e-01, 1.88917643e-03, -2.48646541e-02, ...,
-6.77186991e-03, -8.68297360e-02, 5.81489082e-05],
[-3.92795784e-01, -2.82715164e-02, 2.97571946e-02, ...,
1.81641138e-04, -8.22349042e-02, -1.36063861e-03]])
.base_values =
array([<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7f2368a6a0d0>>,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7f2368a6a0d0>>,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7f2368a6a0d0>>,
...,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7f2368a6a0d0>>,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7f2368a6a0d0>>,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._tree.TreeExplainer object at 0x7f2368a6a0d0>>],
dtype=object)
.data =
array([[39., 7., 13., ..., 0., 40., 39.],
[50., 6., 13., ..., 0., 13., 39.],
[38., 4., 9., ..., 0., 40., 39.],
...,
[58., 4., 9., ..., 0., 40., 39.],
[22., 4., 9., ..., 0., 20., 39.],
[52., 5., 9., ..., 0., 40., 39.]])
>>> print(shap_values_gpu)
.values =
array([[ 0.4804818 , -0.15829426, 0.58872831, ..., -0.02613338,
-0.05505598, 0.00279381],
[ 0.89829117, -0.28824031, 0.82394063, ..., -0.01517473,
-0.9674257 , 0.00623654],
[ 0.41495359, 0.03114166, -0.35944128, ..., -0.02870829,
-0.00318135, -0.00527164],
...,
[ 0.91733944, 0.03161558, -0.24758764, ..., -0.02169394,
-0.03218083, 0.00519772],
[-1.38789368, 0.02243478, -0.17830381, ..., -0.02230209,
-0.84264034, -0.01070848],
[ 0.88696593, 0.09897842, -0.16394429, ..., -0.0109373 ,
0.19775532, 0.00537567]])
.base_values =
array([<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7f2368a6bf50>>,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7f2368a6bf50>>,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7f2368a6bf50>>,
...,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7f2368a6bf50>>,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7f2368a6bf50>>,
<bound method TreeExplainer.__dynamic_expected_value of <shap.explainers._gpu_tree.GPUTreeExplainer object at 0x7f2368a6bf50>>],
dtype=object)
.data =
array([[39., 7., 13., ..., 0., 40., 39.],
[50., 6., 13., ..., 0., 13., 39.],
[38., 4., 9., ..., 0., 40., 39.],
...,
[58., 4., 9., ..., 0., 40., 39.],
[22., 4., 9., ..., 0., 20., 39.],
[52., 5., 9., ..., 0., 40., 39.]])
```
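In both the CatBoost and LightGBM runs the GPU values look much larger in magnitude than the CPU values. A quick way to check whether this is a roughly uniform scale factor or something feature-specific is to compare mean |SHAP| per feature. This is a sketch only: `per_feature_scale_ratio` is a hypothetical helper, and comparing importance rankings this way is one reasonable diagnostic among several.

```python
import numpy as np


def per_feature_scale_ratio(cpu_values, gpu_values, feature_names=None):
    """Return ({feature: mean|GPU SHAP| / mean|CPU SHAP|}, same_importance_ranking)."""
    cpu_imp = np.abs(np.asarray(cpu_values, dtype=float)).mean(axis=0)
    gpu_imp = np.abs(np.asarray(gpu_values, dtype=float)).mean(axis=0)
    # Features with zero CPU importance yield inf or nan instead of raising.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = gpu_imp / cpu_imp
    if feature_names is None:
        feature_names = [f"f{j}" for j in range(cpu_imp.size)]
    # Compare the order of features sorted by descending mean |SHAP|.
    same_ranking = bool(
        np.array_equal(np.argsort(-cpu_imp, kind="stable"), np.argsort(-gpu_imp, kind="stable"))
    )
    return dict(zip(feature_names, ratio.tolist())), same_ranking
```

With the objects above this could be called as `per_feature_scale_ratio(shap_values_cpu.values, shap_values_gpu.values, list(X.columns))`. Roughly equal ratios across features would point to a global scaling difference; widely varying ratios or a different ranking would suggest the two implementations disagree in a more structural way.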
@CloseChoice, please let me know whether you could fix this bug, or guide me on how I could solve this issue myself. Thank you.