Neural-Decision-Forests

An implementation of Deep Neural Decision Forests (dNDF) in PyTorch.

Features

  • Two-stage optimization as in the original paper, Deep Neural Decision Forests (fix the neural network and optimize $\pi$, then optimize $\Theta$ with the class probability distribution in each leaf node fixed)
  • Joint training of $\pi$ and $\Theta$, as proposed by chrischoy in his work Fully Differentiable Deep Neural Decision Forest
  • Shallow Neural Decision Forest (sNDF)
  • Deep Neural Decision Forest (dNDF)
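The $\pi$ step of the two-stage scheme can be illustrated with a small NumPy sketch. This is not this repository's code. The breadth-first node layout, the sigmoid decision nodes $d_n(x)$ and the helper names `routing_probabilities` and `update_leaf_distributions` are assumptions made for illustration. With the network frozen, $\mu_\ell(x)$ (the probability of reaching leaf $\ell$) is fixed, and each leaf distribution is updated with the fixed-point rule from the dNDF paper, $\pi_{\ell,y}^{(t+1)} = \frac{1}{Z_\ell^{(t)}} \sum_{(x,y') \in T} \mathbb{1}[y = y']\, \frac{\pi_{\ell,y}^{(t)}\, \mu_\ell(x)}{P_T[y \mid x]}$. The $\Theta$ step (ordinary backpropagation of $-\log P_T[y \mid x]$ with $\pi$ held fixed) is not shown.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def routing_probabilities(d):
    """Return mu of shape (batch, n_leaves): probability that each sample reaches each leaf.

    d holds the decision-node outputs d_n(x) in breadth-first order
    (root first, then its two children, and so on), shape (batch, n_leaves - 1).
    Assumption: a sample goes left at node n with probability d_n(x), right with 1 - d_n(x).
    """
    batch, n_inner = d.shape
    n_leaves = n_inner + 1
    depth = n_leaves.bit_length() - 1  # assumes a complete binary tree
    mu = np.ones((batch, 1))
    start = 0
    for level in range(depth):
        width = 2 ** level
        d_level = d[:, start:start + width]
        # every path splits into its left child (d) and its right child (1 - d)
        mu = np.stack([mu * d_level, mu * (1.0 - d_level)], axis=2).reshape(batch, 2 * width)
        start += width
    return mu


def update_leaf_distributions(pi, mu, y, n_iters=20):
    """Fixed-point update of the leaf class distributions pi with the network frozen.

    pi: (n_leaves, n_classes), one distribution per leaf
    mu: (batch, n_leaves), routing probabilities from the frozen network
    y:  (batch,), integer class labels
    """
    onehot = np.eye(pi.shape[1])[y]                # (batch, n_classes)
    rows = np.arange(len(y))
    for _ in range(n_iters):
        p_y = (mu @ pi)[rows, y]                   # P_T[y_i | x_i], shape (batch,)
        # sum_i 1[y_i = c] * mu[i, l] / P_T[y_i | x_i], shape (n_leaves, n_classes)
        counts = (mu / p_y[:, None]).T @ onehot
        pi = pi * counts
        pi = pi / pi.sum(axis=1, keepdims=True)    # divide by Z_l
    return pi


# Usage with random data and hypothetical linear routing weights for a depth-3 tree.
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 5))
w = rng.normal(size=(5, 7))                        # 7 inner nodes, 8 leaves
mu = routing_probabilities(sigmoid(x @ w))
y = rng.integers(0, 3, size=64)
pi0 = np.full((8, 3), 1.0 / 3)                     # uniform initial leaf distributions
pi = update_leaf_distributions(pi0, mu, y)
probs = mu @ pi                                    # forest prediction P_T[y | x]
```

This update is an EM step for a mixture whose latent variable is the leaf, so the training log-likelihood does not decrease from one iteration to the next.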

Datasets

MNIST, UCI_Adult, UCI_Letter and UCI_Yeast datasets are available. For datasets other than MNIST, go to the corresponding directory and run the get_data.sh script.

Requirements

  • Python 3.x
  • PyTorch >= 1.0.0
  • numpy
  • sklearn

Usage

python train.py --ARG=VALUE

For example, to train the sNDF on MNIST with alternating optimization, run:

python train.py -dataset mnist -n_class 10 -gpuid 0 -n_tree 80 -tree_depth 10 -batch_size 1000 -epochs 100

Results

Without spending much time on hyperparameter tuning, and without bells and whistles, I obtained the following accuracies (training $\pi$ and $\Theta$ separately):

Dataset      sNDF    dNDF
MNIST        0.9794  0.9963
UCI_Adult    0.8558  NA
UCI_Letter   0.9507  NA
UCI_Yeast    0.6031  NA

By adding a nonlinearity to the routing function, the accuracy reaches 0.6502 on UCI_Yeast and 0.9753 on UCI_Letter.
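The sentence above does not specify which nonlinearity was used, so the sketch below is only an illustration of the idea under stated assumptions. It contrasts a linear routing function, where each decision node computes $d_n(x) = \sigma(w_n^\top x)$, with a hypothetical variant that inserts one tanh hidden layer before the per-node sigmoid. The function names, the tanh choice and the hidden width of 32 are all assumptions.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def linear_routing(x, w):
    """d_n(x) = sigmoid(w_n . x): each decision node is a linear split of the input."""
    return sigmoid(x @ w)


def nonlinear_routing(x, w_hidden, w_out):
    """Hypothetical variant: a tanh hidden layer shared by all nodes, then a per-node sigmoid."""
    return sigmoid(np.tanh(x @ w_hidden) @ w_out)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 16))            # 4 samples, 16 features
w = rng.normal(size=(16, 7))            # 7 decision nodes (a depth-3 tree)
w_hidden = rng.normal(size=(16, 32))    # hypothetical hidden width of 32
w_out = rng.normal(size=(32, 7))
d_lin = linear_routing(x, w)
d_nonlin = nonlinear_routing(x, w_hidden, w_out)
```

Both variants return one probability per decision node, so they can feed the same leaf-routing computation; only the decision boundaries change from hyperplanes to curved surfaces.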

Note

Some users may find that the loss becomes NaN, which can be caused by an output probability of exactly zero. Make sure you have normalized your data and used a large enough tree size and tree depth. If you want to keep your tree settings, a workaround is to clamp the output value.
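Both suggestions can be sketched in NumPy. The helper names `standardize` and `clamped_nll` and the value `eps=1e-6` are assumptions, not values taken from this repository. In PyTorch the same clamp would be `torch.log(torch.clamp(probs, min=eps))` before computing the loss.

```python
import numpy as np


def standardize(x, eps=1e-8):
    """Scale each feature to zero mean and unit variance (statistics computed from x itself)."""
    return (x - x.mean(axis=0)) / (x.std(axis=0) + eps)


def clamped_nll(probs, y, eps=1e-6):
    """Mean negative log-likelihood of the true class.

    probs: (batch, n_classes) forest output; y: (batch,) integer labels.
    Clamping to [eps, 1] keeps log() finite when a probability is exactly zero.
    """
    p_true = probs[np.arange(len(y)), y]
    return -np.mean(np.log(np.clip(p_true, eps, 1.0)))


x = standardize(np.random.default_rng(0).normal(loc=5.0, scale=3.0, size=(100, 4)))
probs = np.array([[1.0, 0.0],
                  [0.5, 0.5]])
y = np.array([1, 0])             # the first sample's true class has probability 0
loss = clamped_nll(probs, y)     # finite (about 7.25) instead of inf
```

Without the clip, the first sample contributes `-log(0) = inf` to the loss, and the gradients derived from it become NaN.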
