

Finding Influential Training Samples for Gradient Boosted Decision Trees

This repository implements the LeafRefit and LeafInfluence methods described in the paper Finding Influential Training Samples for Gradient Boosted Decision Trees.

The paper deals with the problem of finding influential training samples using the Influence Functions framework from classical statistics, which was recently revisited in the paper "Understanding Black-box Predictions via Influence Functions". The classical approach, however, is only applicable to smooth parametric models. In our paper, we introduce LeafRefit and LeafInfluence, methods for extending the Influence Functions framework to non-parametric Gradient Boosted Decision Tree ensembles.
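To make the classical framework concrete, here is a sketch of an influence-function estimate on a model where it does apply: L2-regularized linear regression. This illustrates the classical approach in the form popularized by Koh and Liang, not LeafRefit or LeafInfluence, and all function names are hypothetical. For every training sample, the code estimates how much the loss on one test point would change if that sample were removed, then compares the estimate with brute-force retraining. It relies only on NumPy.

```python
import numpy as np


def fit_ridge(X, y, lam, w=None):
    """Minimize (1/n) * sum_i w_i * 0.5 * (x_i . theta - y_i)**2 + 0.5 * lam * ||theta||**2."""
    n, d = X.shape
    w = np.ones(n) if w is None else np.asarray(w, dtype=float)
    Xw = X * w[:, None]  # row i scaled by its sample weight
    # Zero gradient: (X^T W X / n + lam * I) theta = X^T W y / n.
    return np.linalg.solve(Xw.T.dot(X) / n + lam * np.eye(d), Xw.T.dot(y) / n)


def influence_on_test_loss(X, y, x_test, y_test, lam):
    """First-order estimate of the change in the test loss 0.5*(x_test.theta - y_test)**2
    when each training sample is removed (its weight goes from 1 to 0)."""
    n, d = X.shape
    theta = fit_ridge(X, y, lam)
    H = X.T.dot(X) / n + lam * np.eye(d)            # Hessian of the training objective
    g_test = (x_test.dot(theta) - y_test) * x_test  # gradient of the test loss
    s_test = np.linalg.solve(H, g_test)             # H^{-1} g_test without forming H^{-1}
    g_train = (X.dot(theta) - y)[:, None] * X       # row i: gradient of sample i's loss
    # Removing sample i shifts theta by about (1/n) H^{-1} g_i,
    # so the test loss changes by about (1/n) g_test^T H^{-1} g_i.
    return g_train.dot(s_test) / n


def retrained_loss_changes(X, y, x_test, y_test, lam):
    """Ground truth: retrain once per removed sample and measure the test loss."""
    def test_loss(theta):
        return 0.5 * (x_test.dot(theta) - y_test) ** 2

    base = test_loss(fit_ridge(X, y, lam))
    changes = np.empty(len(X))
    for i in range(len(X)):
        w = np.ones(len(X))
        w[i] = 0.0  # drop sample i but keep the same 1/n normalization
        changes[i] = test_loss(fit_ridge(X, y, lam, w)) - base
    return changes


rng = np.random.RandomState(0)
X = rng.randn(200, 3)
true_theta = np.array([1.0, -2.0, 0.5])
y = X.dot(true_theta) + 0.5 * rng.randn(200)
x_test = rng.randn(3)
y_test = x_test.dot(true_theta) + 3.0  # a deliberately mislabeled test point
approx = influence_on_test_loss(X, y, x_test, y_test, lam=0.1)
exact = retrained_loss_changes(X, y, x_test, y_test, lam=0.1)
print(np.corrcoef(approx, exact)[0, 1])
```

The estimate needs only one fit plus one linear solve, while the ground truth needs one refit per training sample. That gap is what makes influence functions attractive. It also shows why the classical recipe needs gradients and a Hessian with respect to model parameters, which a tree ensemble does not provide in this form.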

Requirements

We recommend using the Anaconda Python distribution for easy installation.

Python packages

The following Python 2.7 packages are required:

Note: the versions specified below are those with which the experiments reported in the paper were run.

  • numpy==1.14.0
  • scipy==0.19.1
  • pandas==0.20.3
  • scikit-learn==0.19.0
  • matplotlib==2.0.2
  • tensorflow==1.6.0rc0
  • tqdm==4.19.5
  • ipywidgets>=7.0.0 (for Jupyter Notebook rendering)
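To check an existing environment against the versions above, a small helper can compare installed versions with the pins. This is a hypothetical sketch, not part of the repository. `PINNED` simply restates the list above, and ipywidgets is left out because it only has a lower bound.

```python
# Exact versions from the list above (ipywidgets omitted: it only has a minimum).
PINNED = {
    "numpy": "1.14.0",
    "scipy": "0.19.1",
    "pandas": "0.20.3",
    "scikit-learn": "0.19.0",
    "matplotlib": "2.0.2",
    "tensorflow": "1.6.0rc0",
    "tqdm": "4.19.5",
}


def pin_mismatches(installed, pinned=PINNED):
    """Return {package: (expected, found)} for every pinned package whose
    installed version differs; found is None when the package is missing."""
    return {name: (want, installed.get(name))
            for name, want in pinned.items()
            if installed.get(name) != want}


print(pin_mismatches({"numpy": "1.14.0", "scipy": "1.0.0"}))
```

The `installed` mapping can be filled with `importlib.metadata.version(name)` on Python 3.8+, or with `pkg_resources.get_distribution(name).version` from setuptools under Python 2.7.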

The create_influence_boosting_env.sh script creates the influence_boosting Conda environment with the required packages installed. To run it, execute the following from the influence_boosting directory:

bash create_influence_boosting_env.sh

CatBoost

The code in this repository uses CatBoost for an implementation of GBDT. We tested our package with CatBoost version 0.6 built from GitHub. Installation instructions are available in the CatBoost documentation.

Note: if you are using the influence_boosting environment described above, make sure to install CatBoost specifically for this environment.

export_catboost

Since CatBoost is written in C++, in order to use CatBoost models with our Python package, we also include export_catboost, a binary that exports a saved CatBoost model to a human-readable JSON file.

This repository assumes that a program named export_catboost is available in the shell. To ensure that, you can do the following:

  • Select one of the two binaries, export_catboost_macosx or export_catboost_linux, depending on your OS.
  • Copy it to export_catboost in the root repository directory.
  • Add the path to the root repository directory to the PATH environment variable.
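The first two steps above can be scripted, and the third one checked. The sketch below is a hypothetical helper, not something shipped with the repository. It assumes the two binaries sit in the repository root under the names given above, and it uses `shutil.which`, which requires Python 3.3+ (under Python 2.7, `distutils.spawn.find_executable` plays the same role).

```python
import os
import shutil
import sys


def binary_for_platform(platform=sys.platform):
    """Pick which bundled binary to use (step 1)."""
    if platform.startswith("darwin"):
        return "export_catboost_macosx"
    if platform.startswith("linux"):
        return "export_catboost_linux"
    raise ValueError("no bundled export_catboost binary for platform %r" % platform)


def install_export_catboost(repo_dir, platform=sys.platform):
    """Copy the platform binary to <repo_dir>/export_catboost (step 2)."""
    src = os.path.join(repo_dir, binary_for_platform(platform))
    dst = os.path.join(repo_dir, "export_catboost")
    shutil.copy(src, dst)  # shutil.copy also copies the permission bits
    return dst


def export_catboost_on_path(path=None):
    """True if an executable named export_catboost is found on PATH (step 3).
    `path` overrides the PATH environment variable when given."""
    return shutil.which("export_catboost", path=path) is not None
```

The PATH change itself happens in the shell, for example by adding `export PATH="$PATH:/path/to/influence_boosting"` (placeholder path) to the shell profile; `export_catboost_on_path()` then confirms that it took effect.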

Note: since CatBoost's treatment of categorical features can be fairly complicated, export_catboost currently supports numerical features only.
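Because only numerical features are supported, it can help to reject categorical columns before training a model that will later go through export_catboost. A minimal sketch with pandas, assuming the features are held in a DataFrame; the helper names are hypothetical.

```python
import pandas as pd


def non_numeric_columns(df):
    """Names of columns whose dtype is not numeric (e.g. strings, categoricals)."""
    return [col for col in df.columns
            if not pd.api.types.is_numeric_dtype(df[col])]


def require_numeric(df):
    """Raise if df contains features that export_catboost cannot handle."""
    bad = non_numeric_columns(df)
    if bad:
        raise ValueError("non-numeric features are not supported by "
                         "export_catboost: %s" % ", ".join(map(str, bad)))
    return df


features = pd.DataFrame({"age": [31, 45], "income": [1.2, 3.4], "city": ["A", "B"]})
print(non_numeric_columns(features))
```

If one-hot encoding is acceptable for the task, `pd.get_dummies` can expand categorical columns into numeric indicator columns before training.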

Example

An example experiment showing the API and a use-case of Influence Functions can be found in the influence_for_error_fixing.ipynb notebook.

Note: in this notebook, CatBoost parameters are loaded from the catboost_params.json file. In particular, the task_type parameter is set to CPU by default. If you have a GPU with CUDA available on your machine and compiled CatBoost with GPU support, you can change this parameter to GPU in order to train CatBoost faster on GPU. The majority of the experiments in the paper were conducted using the GPU mode.
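Switching to GPU mode only involves the task_type value. The sketch below loads catboost_params.json with an optional override. The function name is hypothetical, and it does not claim to mirror how the notebook itself reads the file; editing the JSON file directly has the same effect.

```python
import json


def load_catboost_params(path="catboost_params.json", task_type=None):
    """Read CatBoost parameters from JSON; if task_type is given, override it.
    All other keys are returned unchanged."""
    with open(path) as f:
        params = json.load(f)
    if task_type is not None:
        if task_type not in ("CPU", "GPU"):
            raise ValueError("task_type must be 'CPU' or 'GPU', got %r" % task_type)
        params["task_type"] = task_type
    return params
```

The returned dict can then be passed as keyword arguments to a CatBoost model constructor, e.g. `CatBoostClassifier(**params)`. Whether GPU training actually works depends on CatBoost having been compiled with GPU support.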
