
soft-decision-tree's Introduction

Soft-Decision-Tree

Soft-Decision-Tree is a PyTorch implementation of Distilling a Neural Network Into a Soft Decision Tree, a paper recently published on arXiv about adapting the decision tree algorithm to neural networks. "If we could take the knowledge acquired by the neural net and express the same knowledge in a model that relies on hierarchical decisions instead, explaining a particular decision would be much easier."
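The core idea is that each inner node makes a soft (probabilistic) routing decision instead of a hard split: a learned linear filter passed through a sigmoid gives the probability of taking the right branch. A minimal sketch of such a node (an illustration, not this repo's exact code; the paper's inverse temperature is written here as `beta`):

```python
import torch
import torch.nn as nn

class SoftInnerNode(nn.Module):
    """One inner node of a soft decision tree: a learned linear filter
    whose sigmoid output is the probability of routing the input to
    the right child."""
    def __init__(self, input_dim, beta=1.0):
        super().__init__()
        self.fc = nn.Linear(input_dim, 1)
        self.beta = beta  # inverse temperature sharpening the decision

    def forward(self, x):
        # probability in (0, 1) of taking the right branch
        return torch.sigmoid(self.beta * self.fc(x))
```

The prediction of the full tree is then a mixture of the leaf distributions, weighted by the product of these routing probabilities along each root-to-leaf path.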

Requirements

pytorch
torchvision

Result

I achieved 92.95% test-set accuracy on MNIST after 40 epochs, without exploring the hyper-parameters thoroughly (the paper achieved 94.45%). Higher accuracy might be achievable with a hyper-parameter search or by training for more epochs (if you manage it, please let me know :) )

Usage

$ python main.py

soft-decision-tree's People

Contributors

kimhc6028


soft-decision-tree's Issues

I can't solve it

(python352) C:\Users\ZQ>python
Python 3.5.2 |Continuum Analytics, Inc.| (default, Jul 5 2016, 11:41:13) [MSC v.1900 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> print(torch.__version__)
0.3.1.post2
>>> quit()

(python352) C:\Users\ZQ>python D:\soft_decision_tree\main.py
directory ./data already exists
Traceback (most recent call last):
  File "D:\soft_decision_tree\main.py", line 53, in <module>
    transforms.Normalize((0.1307,), (0.3081,))
  File "D:\Anaconda\envs\python352\lib\site-packages\torchvision\datasets\mnist.py", line 54, in __init__
    os.path.join(self.root, self.processed_folder, self.training_file))
  File "D:\Anaconda\envs\python352\lib\site-packages\torch\serialization.py", line 267, in load
    return _load(f, map_location, pickle_module)
  File "D:\Anaconda\envs\python352\lib\site-packages\torch\serialization.py", line 420, in _load
    result = unpickler.load()
AttributeError: Can't get attribute '_rebuild_tensor_v2' on <module 'torch._utils' from 'D:\Anaconda\envs\python352\lib\site-packages\torch\_utils.py'>

When I try to run this demo, I get an error and cannot figure out why:

Traceback (most recent call last):
  File "C:/Users/woai_fish/Desktop/android/soft-decision-tree-master/main.py", line 100, in <module>
    model.train_(train_loader, epoch)
  File "C:\Users\woai_fish\Desktop\android\soft-decision-tree-master\model.py", line 193, in train_
    loss, output = self.cal_loss(data, self.target_onehot)
  File "C:\Users\woai_fish\Desktop\android\soft-decision-tree-master\model.py", line 133, in cal_loss
    leaf_accumulator = self.root.cal_prob(x, self.path_prob_init)
  File "C:\Users\woai_fish\Desktop\android\soft-decision-tree-master\model.py", line 58, in cal_prob
    self.prob = self.forward(x)  # probability of selecting right node
  File "C:\Users\woai_fish\Desktop\android\soft-decision-tree-master\model.py", line 48, in forward
    return (F.sigmoid(self.beta * self.fc(x)))
  File "E:\Conda\envs\android\lib\site-packages\torch\autograd\variable.py", line 757, in __mul__
    return self.mul(other)
  File "E:\Conda\envs\android\lib\site-packages\torch\autograd\variable.py", line 301, in mul
    return Mul.apply(self, other)
  File "E:\Conda\envs\android\lib\site-packages\torch\autograd\_functions\basic_ops.py", line 50, in forward
    return a.mul(b)
RuntimeError: inconsistent tensor size at d:\downloads\pytorch-master-1\torch\lib\th\generic/THTensorMath.c:847

Where is the 0.5 test?

Thank you for sharing this code. I was wondering why you commented out that forward section in the soft decision tree, and why you are not using the 0.5 test for choosing the left or right node?

Just another question: the output dimension in args is the number of classes, right?

best regards
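For context, the "0.5 test" refers to the paper's deterministic inference option: instead of averaging the predictions of all leaves, descend the single path where each node's routing probability is compared against 0.5. A rough sketch of that greedy descent (the `leaf`, `left`, `right`, and `forward` names are assumptions; this repo may structure its nodes differently):

```python
def greedy_descend(node, x):
    """Follow the most probable branch at every inner node (the '0.5 test')
    instead of computing the soft mixture over all leaves. Assumes each
    node exposes a `leaf` flag, `left`/`right` children, and a `forward`
    returning the right-branch probability."""
    while not node.leaf:
        p_right = float(node.forward(x))  # probability of the right branch
        node = node.right if p_right >= 0.5 else node.left
    return node  # the single leaf whose distribution is the prediction
```

This trades a little accuracy for a prediction that is explainable as one sequence of hierarchical decisions.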

exponential increase in the temporal scale

Really great implementation!
I have a question about the implementation. In the last paragraph of the regularizers section, the authors mention an "exponential increase in the temporal scale of the window used to compute the running average". Is this feature implemented in this codebase? I didn't find it.
Thanks:)
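For reference, that part of the paper can be approximated with an exponential moving average whose effective window doubles with node depth. A sketch of one possible reading (an assumption about the paper's intent, not this repo's code):

```python
def update_running_avg(avg, new_value, depth):
    """Exponential moving average of a node's routing probability.
    The effective window doubles with depth, a rough rendering of the
    paper's 'exponential increase in the temporal scale'."""
    window = 2 ** depth            # deeper nodes average over longer history
    decay = 1.0 - 1.0 / window     # EMA decay giving roughly that window
    return decay * avg + (1.0 - decay) * new_value
```

Deeper nodes see fewer examples with significant path probability, so averaging them over a longer window keeps the regularizer's statistics stable.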

What to change to use Cross entropy

Thank you for your code. I was wondering what I should change to use nn.CrossEntropyLoss instead of the loss stated in the paper?

Thank you
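One thing to note: `nn.CrossEntropyLoss` expects raw logits, but a soft decision tree already outputs class probabilities (a mixture over leaf distributions). The equivalent is negative log-likelihood on the log of those probabilities, sketched here (function name and the epsilon guard are my own, not this repo's API):

```python
import torch
import torch.nn.functional as F

def cross_entropy_on_tree_output(output_probs, target):
    """Cross-entropy for a model that outputs probabilities rather than
    logits: take the log and apply NLL. The small epsilon guards against
    log(0) on near-zero leaf probabilities."""
    return F.nll_loss(torch.log(output_probs + 1e-12), target)
```

Feeding the probabilities straight into `nn.CrossEntropyLoss` would apply a second softmax and distort the loss.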

Hello, a question about bigger input sizes

First of all, thank you for such good code. I want to ask: when I use a larger input, such as 224 × 224 × 3, I find that training has no effect. Do I need to change some parts of the code?

Feature importance

Is it possible to retrieve feature importances, as with original decision trees?
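There is no exact analogue of impurity-based importance, but since every inner node is a linear filter over the input features, the absolute filter weights give a rough proxy. A sketch under that assumption (it presumes each inner node holds an `nn.Linear` called `fc`, which may not match this repo's attribute names):

```python
import torch

def feature_importance(inner_nodes):
    """Rough per-feature importance for a soft decision tree: average the
    absolute weights of every inner node's linear filter. Assumes each
    node exposes an nn.Linear named `fc` mapping features to one logit."""
    weights = torch.stack(
        [node.fc.weight.abs().squeeze(0) for node in inner_nodes]
    )
    return weights.mean(dim=0)  # one non-negative score per input feature
```

A refinement would be to weight each node's contribution by how much path probability flows through it on a validation set.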

loss is nan

I tried to use my own dataset, a table with many discrete values such as 0, 1, 2, and found that the loss is NaN:

Train Epoch: 1 [0/49626 (0%)]	Loss: nan, Accuracy: 28/64 (43.0000%)
Train Epoch: 1 [640/49626 (1%)]	Loss: nan, Accuracy: 30/64 (46.0000%)
Train Epoch: 1 [1280/49626 (3%)]	Loss: nan, Accuracy: 36/64 (56.0000%)
Train Epoch: 1 [1920/49626 (4%)]	Loss: nan, Accuracy: 35/64 (54.0000%)
Train Epoch: 1 [2560/49626 (5%)]	Loss: nan, Accuracy: 40/64 (62.0000%)
Train Epoch: 1 [3200/49626 (6%)]	Loss: nan, Accuracy: 29/64 (45.0000%)
Train Epoch: 1 [3840/49626 (8%)]	Loss: nan, Accuracy: 33/64 (51.0000%)
Train Epoch: 1 [4480/49626 (9%)]	Loss: nan, Accuracy: 28/64 (43.0000%)
Train Epoch: 1 [5120/49626 (10%)]	Loss: nan, Accuracy: 30/64 (46.0000%)
Train Epoch: 1 [5760/49626 (12%)]	Loss: nan, Accuracy: 28/64 (43.0000%)
Train Epoch: 1 [6400/49626 (13%)]	Loss: nan, Accuracy: 38/64 (59.0000%)
Train Epoch: 1 [7040/49626 (14%)]	Loss: nan, Accuracy: 36/64 (56.0000%)
Train Epoch: 1 [7680/49626 (15%)]	Loss: nan, Accuracy: 27/64 (42.0000%)
Train Epoch: 1 [8320/49626 (17%)]	Loss: nan, Accuracy: 29/64 (45.0000%)
Train Epoch: 1 [8960/49626 (18%)]	Loss: nan, Accuracy: 31/64 (48.0000%)
Train Epoch: 1 [9600/49626 (19%)]	Loss: nan, Accuracy: 32/64 (50.0000%)
Train Epoch: 1 [10240/49626 (21%)]	Loss: nan, Accuracy: 34/64 (53.0000%)
Train Epoch: 1 [10880/49626 (22%)]	Loss: nan, Accuracy: 34/64 (53.0000%)
Train Epoch: 1 [11520/49626 (23%)]	Loss: nan, Accuracy: 29/64 (45.0000%)
Train Epoch: 1 [12160/49626 (24%)]	Loss: nan, Accuracy: 34/64 (53.0000%)
Train Epoch: 1 [12800/49626 (26%)]	Loss: nan, Accuracy: 39/64 (60.0000%)
Train Epoch: 1 [13440/49626 (27%)]	Loss: nan, Accuracy: 34/64 (53.0000%)
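A NaN loss with raw tabular inputs is often caused by unscaled features saturating the sigmoids, or by taking the log of a zero leaf probability in the loss. Standardizing each feature column first is a cheap thing to try (a generic sketch, independent of this repo):

```python
import torch

def standardize_columns(x, eps=1e-8):
    """Zero-mean, unit-variance scaling per feature column. The epsilon
    guards against division by zero on constant columns."""
    mean = x.mean(dim=0, keepdim=True)
    std = x.std(dim=0, keepdim=True)
    return (x - mean) / (std + eps)
```

If the NaN persists, clamping the model's output probabilities away from zero before the log (as in `torch.log(p + 1e-12)`) is the other usual suspect to check.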

Error when learning rate is big

python3 main.py --max-depth 4 --lr 1
...
Train Epoch: 3 [17280/60000 (29%)] Loss: 0.522095, Accuracy: 54/64 (84.0000%)
Traceback (most recent call last):
  File "main.py", line 79, in <module>
    model.train_(train_loader, epoch)
  File "/content/model.py", line 190, in train_
    correct += pred.eq(target.data).cpu().sum()
RuntimeError: Expected object of backend CPU but got backend CUDA for argument #2 'other'

This is so strange and I can't fix it.
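That error means `pred` and `target.data` live on different devices (one on the GPU, one on the CPU), so `eq` cannot compare them. Moving both to the same device before comparing avoids it; a hedged sketch of the idea (the helper name is mine, not the repo's):

```python
import torch

def count_correct(pred, target):
    """Compare predictions and targets on the CPU regardless of where each
    tensor currently lives, avoiding the CPU/CUDA backend mismatch raised
    by eq() on tensors from different devices."""
    return pred.cpu().eq(target.cpu()).sum().item()
```

In the traceback's line this amounts to changing `pred.eq(target.data).cpu().sum()` so both operands are on the same device before `eq` runs.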
