mllp's Introduction

Our new work

For better scalability and classification performance, please refer to our newer work, RRL.

Multilayer Logical Perceptrons

This is a PyTorch implementation of Multilayer Logical Perceptrons (MLLP) and Random Binarization (RB) method to learn Concept Rule Sets (CRS) for transparent classification tasks, as described in our paper: Transparent Classification with Multilayer Logical Perceptrons and Random Binarization.


If you need a model that is both transparent (i.e., has an interpretable inner structure) and accurate, this code may be useful for you. CRS is a hierarchical rule-set model that is transparent and interpretable, and the discrete CRS can be learned efficiently by gradient descent through the continuous MLLP combined with the RB method.
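For intuition, here is a minimal sketch (ours, not the mllp package's code) of the key idea: a logical AND over binarized inputs is relaxed into a differentiable product, which is what allows the network to be trained by gradient descent.

import torch

def soft_conjunction(h, w):
    # Differentiable AND. h holds input activations in [0, 1] and w holds
    # connection weights in [0, 1]. Each term 1 - w * (1 - h) equals 1 when
    # w = 0 (input ignored) and equals h when w = 1, so with binary weights
    # the product is exactly the conjunction of the selected inputs.
    return torch.prod(1 - w * (1 - h), dim=-1)

h = torch.tensor([1.0, 0.0, 1.0])  # three binarized features
w = torch.tensor([1.0, 0.0, 1.0])  # the rule selects features 1 and 3
print(soft_conjunction(h, w))      # tensor(1.) -> the rule fires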

Installation

Clone the repository and run:

python3 setup.py install

Requirements

  • torch>=1.0.1
  • torchvision>=0.2.2
  • scikit-learn>=0.21.2
  • numpy>=1.16.3
  • pandas>=0.24.2
  • matplotlib>=3.0.0
  • CUDA (optional, for running on GPU)
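The dependencies can also be installed directly with pip (note that the PyPI package name for sklearn is scikit-learn):

pip3 install "torch>=1.0.1" "torchvision>=0.2.2" "scikit-learn>=0.21.2" "numpy>=1.16.3" "pandas>=0.24.2" "matplotlib>=3.0.0"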

Run the demo

UCI data sets

We put 12 UCI data sets in the dataset folder. The descriptions of these data sets are listed in DataSetDesc.

You can specify one data set in the dataset folder and train the model as follows:

# tic-tac-toe data set
python3 experiments.py -d tic-tac-toe &

The demo will first read the data set and its information file, then discretize and binarize the data.
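As a rough illustration of what this preprocessing does (a sketch, not the project's actual pipeline), a continuous feature is cut into intervals and then one-hot encoded so that every network input is binary:

import pandas as pd

# Illustrative discretize-then-binarize step; the demo's real pipeline
# lives in the project's data-handling code.
ages = pd.Series([22, 37, 58, 45], name="age")
bins = pd.cut(ages, bins=[0, 30, 50, 100])  # discretize into intervals
onehot = pd.get_dummies(bins)               # binarize via one-hot encoding
print(onehot.astype(int))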

After data preprocessing, the demo trains the MLLP on the training set. The training log file (log.txt) can be found in the log_folder directory. During training, you can check the training loss and the evaluation results on the validation set (or training set) with:

tail -f log_folder/tic-tac-toe_k5_ki0_useValidationSetFalse_e401_bs64_lr0.01_lrdr0.75_lrde100_wd0.0_p0.0_useNOTFalse_L64/log.txt

After training, the evaluation result on the test set is shown at the end of log.txt:

[INFO] - ============================================================
[INFO] - Test:
	Accuracy of MLLP Model: 0.9895833333333334
	Accuracy of CRS  Model: 1.0
[INFO] - Test:
	F1 Score of MLLP Model: 0.989158667419537
	F1 Score of CRS  Model: 1.0
[INFO] - ============================================================

The training loss curve is plotted in plot_file.pdf.

Moreover, the trained MLLP model is saved in model.pth, and the extracted CRS is printed in crs.txt:

class_negative:
       r1,6:	 [' 2_o', ' 5_o', ' 8_o']
      r1,16:	 [' 7_o', ' 8_o', ' 9_o']
      r1,20:	 [' 1_o', ' 5_o', ' 9_o']
      r1,24:	 [' 1_x', ' 2_o', ' 3_x', ' 6_x', ' 7_o', ' 9_o']
      r1,27:	 [' 1_o', ' 4_o', ' 7_o']
      r1,39:	 [' 3_x', ' 4_x', ' 6_o', ' 7_o', ' 8_x', ' 9_x']
      r1,40:	 [' 2_x', ' 3_o', ' 5_o', ' 6_x', ' 8_o', ' 9_x']
      r1,48:	 [' 4_o', ' 5_o', ' 6_o']
      r1,50:	 [' 3_o', ' 6_o', ' 9_o']
      r1,55:	 [' 1_x', ' 4_o', ' 6_x', ' 7_x', ' 8_x', ' 9_o']
      r1,58:	 [' 1_o', ' 2_x', ' 4_x', ' 6_o', ' 8_o', ' 9_x']
      r1,60:	 [' 3_o', ' 5_o', ' 7_o']
      r1,62:	 [' 1_o', ' 2_o', ' 3_o']
class_positive:
       r1,3:	 [' 3_x', ' 5_x', ' 7_x']
       r1,5:	 [' 2_x', ' 5_x', ' 8_x']
      r1,26:	 [' 1_x', ' 2_x', ' 3_x']
      r1,29:	 [' 7_x', ' 8_x', ' 9_x']
      r1,37:	 [' 3_x', ' 6_x', ' 9_x']
      r1,38:	 [' 4_x', ' 5_x', ' 6_x']
      r1,51:	 [' 1_x', ' 4_x', ' 7_x']
      r1,52:	 [' 1_x', ' 5_x', ' 9_x']

Here 2_o denotes that the second square contains an o, so the line r1,6: [' 2_o', ' 5_o', ' 8_o'] denotes a conjunctive rule over those three conditions. The two rule sets, class_negative and class_positive, are disjunctions of their rules and are used for label prediction.
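Written out in logical form, the rule and rule sets above read:

r1,6 = 2_o ∧ 5_o ∧ 8_o
class_negative = r1,6 ∨ r1,16 ∨ ... ∨ r1,62
class_positive = r1,3 ∨ r1,5 ∨ ... ∨ r1,52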

Try another data set with more arguments specified:

# adult data set
# GPU is recommended: training a large network on CPU can take a long time.
python3 experiments.py -d adult -e 800 -bs 64 -lr 0.005 -p 0.9 --use_not --use_validation_set -s 256_256_64 &

If GPU is available, the demo will run on GPU automatically.
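Device selection presumably follows the standard PyTorch pattern; conceptually (a sketch, not the script's exact code):

import torch
import torch.nn as nn

# Standard PyTorch device selection: prefer CUDA when it is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 2).to(device)    # stand-in for the MLLP network
x = torch.randn(1, 4, device=device)  # inputs must live on the same device
print(model(x).device)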

Your own data sets

You can use the demo to train MLLP and CRS on your own data set by putting the data and data information files in the dataset folder. Please read DataSetDesc for specific guidelines.
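For orientation only, this usually means a comma-separated data file plus an information file describing each feature; the exact names and layout below are a guess, so defer to DataSetDesc:

dataset/my-data.data   # hypothetical example: one sample per line, comma-separated values
dataset/my-data.info   # hypothetical example: one line per feature, giving its name and type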

Available arguments

List all the available arguments and their default values by:

$ python3 experiments.py --help
usage: experiments.py [-h] [-d DATA_SET] [-k KFOLD] [-ki ITH_KFOLD]
                      [--use_validation_set] [-e EPOCH] [-bs BATCH_SIZE]
                      [-lr LEARNING_RATE] [-lrdr LR_DECAY_RATE]
                      [-lrde LR_DECAY_EPOCH] [-wd WEIGHT_DECAY]
                      [-p RANDOM_BINARIZATION_RATE] [--use_not] [-s STRUCTURE]

optional arguments:
  -h, --help            show this help message and exit
  -d DATA_SET, --data_set DATA_SET
                        Set the data set for training. All the data sets in
                        the dataset folder are available. (default: tic-tac-
                        toe)
  -k KFOLD, --kfold KFOLD
                        Set the k of K-Folds cross-validation. (default: 5)
  -ki ITH_KFOLD, --ith_kfold ITH_KFOLD
                        Do the i-th validation, 0 <= ki < k. (default: 0)
  --use_validation_set  Use the validation set for parameters tuning.
                        (default: False)
  -e EPOCH, --epoch EPOCH
                        Set the total epoch. (default: 401)
  -bs BATCH_SIZE, --batch_size BATCH_SIZE
                        Set the batch size. (default: 64)
  -lr LEARNING_RATE, --learning_rate LEARNING_RATE
                        Set the initial learning rate. (default: 0.01)
  -lrdr LR_DECAY_RATE, --lr_decay_rate LR_DECAY_RATE
                        Set the learning rate decay rate. (default: 0.75)
  -lrde LR_DECAY_EPOCH, --lr_decay_epoch LR_DECAY_EPOCH
                        Set the learning rate decay epoch. (default: 100)
  -wd WEIGHT_DECAY, --weight_decay WEIGHT_DECAY
                        Set the weight decay (L2 penalty). (default: 0.0)
  -p RANDOM_BINARIZATION_RATE, --random_binarization_rate RANDOM_BINARIZATION_RATE
                        Set the rate of random binarization. It is important
                        for CRS extractions from deep MLLPs. (default: 0.0)
  --use_not             Use the NOT (~) operator in logical rules. It will
                        enhance model capability but make the CRS more
                        complex. (default: False)
  -s STRUCTURE, --structure STRUCTURE
                        Set the structure of network. Only the number of nodes
                        in middle layers are needed. E.g., 64, 64_32_16. The
                        total number of middle layers should be odd. (default: 64)

Tutorial

You can use the mllp package in your code easily after installation.

The tutorial is shown in the Jupyter notebook tutorial.ipynb.
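As a rough, hypothetical outline of what that might look like (the import path, constructor signature, and method names below are placeholders; the notebook shows the real API):

import numpy as np
# Hypothetical sketch -- consult tutorial.ipynb for the actual classes
# and their interfaces.
from mllp.models import MLLP               # assumed import path

n_features, n_classes = 27, 2
X_train = np.random.randint(0, 2, (100, n_features))  # binarized inputs
y_train = np.random.randint(0, n_classes, 100)

net = MLLP([n_features, 64, n_classes])    # assumed constructor signature
net.train_model(X_train, y_train)          # assumed training method
print(net.extract_rules())                 # assumed rule-extraction method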

Citation

If our work is helpful to you, please kindly cite our paper as:

@inproceedings{wang2020transparent,
  title={Transparent classification with multilayer logical perceptrons and random binarization},
  author={Wang, Zhuo and Zhang, Wei and Liu, Ning and Wang, Jianyong},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={34},
  number={04},
  pages={6331--6339},
  year={2020}
}

License

MIT license


mllp's Issues

How to handle a node (conjunction rule) in the second-to-last layer (2L-1) that has multiple edges (weight = 1) to class nodes in the last layer (2L)

(1) Hello. Regarding this situation described in the paper: [screenshot from the paper]
My understanding is that, after weight discretization, during forward propagation a conjunction rule in the second-to-last layer can have edges (all with weight 1) to both class nodes in the last layer (in binary classification), in which case the first class node is chosen as the label for that rule. But isn't that somewhat unprincipled, or doesn't it degrade performance (it would effectively amount to random classification)? In RRL, the weights of the last layer are not discretized and stay in [-1, 1], but then the resulting rule sets are less intuitive; you can only obtain an importance score for each rule.
[screenshot]
(2) Another small question is about the initialization of the network structure. With -s 64_32_16, shouldn't the resulting middle layers have the structure [64, 64, 32, 32, 16] (ending with a conjunction layer of 16 nodes)? The network actually built does not seem to look like this.
[screenshot] [screenshot]
Perhaps I am misunderstanding something. Thanks!

Open question about the conjunction layer

Hello, I really like your work.
I have a small question about the conjunction layer. For continuous features, a feature is first discretized and then one-hot encoded. Take feature A: if the random discretization thresholds are 10, 20, and 30, the resulting one-hot features might be (-inf, 10), [10, 20), [30, +inf). After training, if a node in the conjunction layer has edges to several input nodes, those inputs are conjoined. The problem is that if, say, both the (-inf, 10) node and the [10, 20) node are connected to the same conjunction-layer node, the resulting rule is (-inf, 10) & [10, 20), which clearly can never hold, so the rule will never be activated.
I would like to know how the authors handle or think about this case.
