
shparkley's Introduction

Shparkley: Scaling Shapley Values with Spark

Shparkley is a PySpark implementation of Shapley values that uses a Monte Carlo approximation algorithm.

Given a dataset and a machine learning model, Shparkley can compute Shapley values for every feature of a given feature vector. Shparkley also supports training weights and is model-agnostic.
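To illustrate the idea behind the Monte Carlo approximation (this is a single-machine sketch of the general algorithm, not the package's distributed implementation): sample a random background row and a random feature permutation, then switch features one by one from the background value to the investigated row's value, accumulating each feature's marginal contribution to the prediction.

```python
import random

def monte_carlo_shapley(predict, x, background, num_samples=200, seed=0):
    """Approximate Shapley values for one feature vector `x` (a dict).

    `predict` maps a feature dict to a score; `background` is a list of
    feature dicts sampled from the training data.
    """
    rng = random.Random(seed)
    features = list(x)
    shapley = {f: 0.0 for f in features}
    for _ in range(num_samples):
        z = rng.choice(background)   # random background row
        order = features[:]
        rng.shuffle(order)           # random feature permutation
        hybrid = dict(z)
        prev = predict(hybrid)
        for f in order:
            hybrid[f] = x[f]         # switch feature f to x's value
            curr = predict(hybrid)
            shapley[f] += curr - prev  # marginal contribution of f
            prev = curr
    return {f: v / num_samples for f, v in shapley.items()}

# Toy linear model: for a linear model, each feature's Shapley value is
# its weight times (x - background), regardless of permutation order.
predict = lambda row: 2.0 * row["a"] + 1.0 * row["b"]
background = [{"a": 0.0, "b": 0.0}]
values = monte_carlo_shapley(predict, {"a": 1.0, "b": 1.0}, background)
```

Shparkley distributes exactly this kind of sampling work across a Spark cluster, which is what makes it tractable for large background datasets.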

pip install shparkley

You must have Apache Spark installed on your machine/cluster.

from collections import OrderedDict
from typing import List

import pandas as pd
from sklearn.base import ClassifierMixin

from affirm.model_interpretation.shparkley.estimator_interface import OrderedSet, ShparkleyModel
from affirm.model_interpretation.shparkley.spark_shapley import compute_shapley_for_sample


class MyShparkleyModel(ShparkleyModel):
    """
    You need to wrap your model with this interface (by subclassing ShparkleyModel)
    """
    def __init__(self, model: ClassifierMixin, required_features: OrderedSet):
        self._model = model
        self._required_features = required_features

    def predict(self, feature_matrix: List[OrderedDict]) -> List[float]:
        """
        Generates one prediction per row, taking in a list of ordered dictionaries (one per row).
        """
        pd_df = pd.DataFrame.from_dict(feature_matrix)
        preds = self._model.predict_proba(pd_df)[:, 1]
        return preds

    def _get_required_features(self) -> OrderedSet:
        """
        An ordered set of feature column names
        """
        return self._required_features

row = dataset.filter(dataset.row_id == 'xxxx').rdd.first()
shparkley_wrapped_model = MyShparkleyModel(my_model, my_feature_names)  # my_feature_names: OrderedSet of feature column names

# You need to sample your dataset based on convergence criteria.
# More samples result in more accurate Shapley values.
# Repartitioning and caching the sampled dataframe will speed up computation.
sampled_df = training_df.sample(0.1, True).repartition(75).cache()

shapley_scores_by_feature = compute_shapley_for_sample(
    df=sampled_df,
    model=shparkley_wrapped_model,
    row_to_investigate=row,
    weight_col_name='training_weight_column_name'
)
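The result, `shapley_scores_by_feature`, associates each feature name with its estimated Shapley value (assuming a mapping-like return, as the variable name suggests). A common next step is to rank features by absolute contribution; a sketch with made-up scores:

```python
# Hypothetical output of compute_shapley_for_sample: feature -> Shapley value.
shapley_scores_by_feature = {
    "fico": -0.08,
    "loan_amount": 0.03,
    "number_of_delinquencies": 0.12,
}

# Rank features by the magnitude of their contribution to the prediction.
ranked = sorted(shapley_scores_by_feature.items(),
                key=lambda kv: abs(kv[1]), reverse=True)
for feature, value in ranked:
    print(f"{feature}: {value:+.3f}")
```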

shparkley's People

Contributors

niloygupta, ijoseph


shparkley's Issues

Can Shparkley package generate shap values for an entire validation dataset?

As shown in the simple.ipynb file, the Shparkley package generates Shapley values for a single data point. If we pass several rows to be investigated, does Shparkley provide Shapley values for all of them?

Current:
query_row = Row(fico=600, loan_amount=300, number_of_delinquencies=1, repaid_all_previous_affirm_loans=0)
shapley_values_shparkley = compute_shapley_for_sample(
    df=train_spark_df,
    model=model_with_shparkley_interface,
    row_to_investigate=query_row,
)

Expected:
query_rows = [
    Row(fico=600, loan_amount=300, number_of_delinquencies=1, repaid_all_previous_affirm_loans=0),
    Row(fico=700, loan_amount=350, number_of_delinquencies=0, repaid_all_previous_affirm_loans=0),
    Row(fico=680, loan_amount=370, number_of_delinquencies=1, repaid_all_previous_affirm_loans=1),
]
shapley_values_shparkley = compute_shapley_for_sample(
    df=train_spark_df,
    model=model_with_shparkley_interface,
    row_to_investigate=query_rows,
)
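Since compute_shapley_for_sample investigates one row at a time, a straightforward way to handle several rows is to call it once per row. A sketch of that loop; the real `compute_shapley_for_sample`, dataframe, and wrapped model come from the Shparkley setup above, and trivial placeholders stand in for them here so the loop itself runs standalone:

```python
def compute_shapley_for_sample(df, model, row_to_investigate):
    # Placeholder standing in for Shparkley's real function, which returns
    # per-feature Shapley values for the investigated row.
    return {feature: 0.0 for feature in row_to_investigate}

sampled_df, wrapped_model = None, None  # placeholders for the Spark objects

query_rows = [
    {"fico": 600, "loan_amount": 300, "number_of_delinquencies": 1},
    {"fico": 700, "loan_amount": 350, "number_of_delinquencies": 0},
]

# One compute_shapley_for_sample call per row to investigate.
scores_per_row = [
    compute_shapley_for_sample(df=sampled_df, model=wrapped_model,
                               row_to_investigate=row)
    for row in query_rows
]
```

Note that each call re-scans the sampled dataframe, so caching it (as in the README example) matters more when investigating many rows.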

PicklingError: Could not serialize object: TypeError: can't pickle _abc_data objects

I wanted to try out this package because it implements a PySpark version of Shapley value generation. I copy-pasted the simple.ipynb file into my environment to check that the basics work, but the code breaks at input cell [32]. Screenshots are attached; could anyone please look into them?
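This error is commonly seen when an abc-based class (such as a ShparkleyModel subclass) is serialized by an older cloudpickle, bundled with older PySpark versions, on Python 3.7. The usual workarounds are upgrading PySpark or defining the subclass in an importable module rather than a notebook cell. A quick local sanity check (with a hypothetical wrapper class) is to round-trip the object through pickle before submitting the Spark job:

```python
import pickle

class WrappedModel:
    # Defined at module top level so pickle can serialize instances
    # by referencing the class rather than its internals.
    def __init__(self, threshold):
        self.threshold = threshold

wrapped = WrappedModel(0.5)
restored = pickle.loads(pickle.dumps(wrapped))
```

If this local round-trip fails with the same error, the problem is in the serialization environment rather than in Spark itself.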
