Comments (3)

connortann commented on May 28, 2024

Thank you for the bug report. The example given above is not reproducible; to allow us to investigate further, please would you provide a minimal reproducible example according to this guide?

goslak commented on May 28, 2024

Cemlyn commented on May 28, 2024

I've seen a similar error and can reproduce it using the code below.

Setting the `check_additivity` argument to `False` stops the assertion error from being raised; however, I'm not sure whether that is a correct fix.

```python
import shap
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.nn.functional import binary_cross_entropy
from sklearn.datasets import make_classification

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

N_EPOCHS = 10
LEARNING_RATE = 0.1

class LogisticRegression(nn.Module):
    def __init__(self, num_features):
        super().__init__()
        self.linear = nn.Linear(num_features, 1)

    def forward(self, vec):
        return F.sigmoid(self.linear(vec))

    def save_model(self, loc):
        with open(loc, "w", encoding="utf-8") as file:
            model_params = ""
            for n, param in enumerate(self.parameters()):
                model_params += f"{n},{param.data}\n"
            file.write(model_params)


model = LogisticRegression(num_features=6).to(device)

optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)

X,y = make_classification(n_samples=1000,n_features=6,n_informative=4)
X = torch.as_tensor(X, dtype=torch.float32).to(device)
y = torch.as_tensor(y, dtype=torch.float32).unsqueeze(dim=1).to(device)

for i in range(N_EPOCHS):
    model.zero_grad()
    y_pred = model(X)
    loss = binary_cross_entropy(y_pred, y)
    loss.backward()
    optimizer.step()

e = shap.DeepExplainer(model, X)

shap_values = e.shap_values(X, check_additivity=True)
```
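For context, `check_additivity` tests SHAP's local-accuracy property: for every explained sample, the base (expected) value plus the sum of that sample's SHAP values should reproduce the model output. The sketch below is a minimal illustration of that property only. It uses NumPy and a plain linear model instead of shap or PyTorch, and `linear_shap_values` and `additivity_gap` are hypothetical helpers, not part of the shap API. It also assumes independent features, for which the SHAP values of a linear model have the closed form `w_i * (x_i - mean_i)`. It does not diagnose the error reported above.

```python
import numpy as np


def linear_shap_values(w, b, X, background):
    """Hypothetical helper: exact SHAP values for f(x) = x @ w + b,
    assuming independent features (expectation taken over `background`)."""
    mean_bg = background.mean(axis=0)
    phi = (X - mean_bg) * w                 # one attribution per sample and feature
    base_value = float(mean_bg @ w + b)     # equals the mean of f over the background
    return phi, base_value


def additivity_gap(phi, base_value, outputs):
    """Hypothetical helper: largest |f(x) - (base_value + sum_i phi_i)|
    over all explained samples, i.e. what an additivity check inspects."""
    return float(np.max(np.abs(outputs - (base_value + phi.sum(axis=1)))))


rng = np.random.default_rng(0)
background = rng.normal(size=(1000, 6))     # stands in for the background data
X = rng.normal(size=(5, 6))                 # samples to explain
w = rng.normal(size=6)
b = 0.3

phi, base = linear_shap_values(w, b, X, background)
logits = X @ w + b                          # linear model output (no sigmoid)
print("max additivity gap:", additivity_gap(phi, base, logits))
```

For a purely linear model the gap is at floating-point level. The model in the repro applies a sigmoid, so its output is not linear in the inputs and DeepExplainer has to approximate the nonlinearity's contribution; the assertion fires when those approximate attributions and the model output disagree by more than shap's tolerance. Passing `check_additivity=False` only skips that comparison; it does not make the values add up.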
