
nnpulearning's Introduction

Chainer implementation of non-negative PU learning and unbiased PU learning

This is code for reproducing non-negative PU learning [1] and unbiased PU learning [2] as described in the paper "Positive-Unlabeled Learning with Non-Negative Risk Estimator".

  • pu_loss.py contains a Chainer implementation of the risk estimator for non-negative PU (nnPU) learning and unbiased PU (uPU) learning.
  • train.py is example code for nnPU learning and uPU learning. The datasets are MNIST [3], preprocessed so that even digits form the P class and odd digits form the N class, and CIFAR10 [4], preprocessed so that artifacts form the P class and living things form the N class. The default setting uses 100 P data and 59900 U data from MNIST, and the class prior is the ratio of P class data in the U data.
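The MNIST setting above can be sketched as follows. This is a minimal illustration, not the repository's preprocessing code: binarize_mnist and make_pu_split are hypothetical helpers, and the PU label convention (+1 for labeled P, -1 for U) and the choice to put every remaining point into U are assumptions. With 60000 training digits and n_labeled=100, this yields 59900 U points.

```python
import numpy as np


def binarize_mnist(digits):
    """Map digit labels to the binary task: even -> +1 (P), odd -> -1 (N)."""
    digits = np.asarray(digits).astype(np.int64)
    return np.where(digits % 2 == 0, 1, -1)


def make_pu_split(y_binary, n_labeled, seed=0):
    """Label n_labeled randomly chosen positives as P; treat the rest as U.

    Returns the PU labels (+1 labeled P, -1 unlabeled, assumed convention)
    and the class prior: the fraction of true positives among the U data.
    """
    y_binary = np.asarray(y_binary)
    rng = np.random.RandomState(seed)  # RandomState works with NumPy 1.16
    positives = np.flatnonzero(y_binary == 1)
    labeled = rng.choice(positives, size=n_labeled, replace=False)
    t_pu = np.full(y_binary.shape, -1, dtype=np.int64)
    t_pu[labeled] = 1
    prior = float(np.mean(y_binary[t_pu == -1] == 1))
    return t_pu, prior
```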

Requirements

  • Python == 3.7
  • Numpy == 1.16
  • Chainer == 6.4
  • Scikit-learn == 0.21
  • Matplotlib == 3.0
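A small helper, not part of the repository, for checking an environment against the versions above. It compares only major.minor and assumes each package exposes `__version__` (scikit-learn is imported as `sklearn`); the names installed_versions and mismatches are illustrative.

```python
import importlib
import sys

# Versions from the Requirements list, compared by major.minor only.
REQUIRED = {
    "python": "3.7",
    "numpy": "1.16",
    "chainer": "6.4",
    "sklearn": "0.21",  # scikit-learn's import name
    "matplotlib": "3.0",
}


def installed_versions(names):
    """Map each name to its installed version string, or None if unavailable."""
    found = {}
    for name in names:
        if name == "python":
            found[name] = "%d.%d.%d" % tuple(sys.version_info[:3])
            continue
        try:
            found[name] = importlib.import_module(name).__version__
        except Exception:  # a missing or broken install counts as unavailable
            found[name] = None
    return found


def mismatches(installed, required):
    """Return (name, installed, required) for every major.minor mismatch."""
    bad = []
    for name, want in required.items():
        have = installed.get(name)
        if have is None or have.split(".")[:2] != want.split(".")[:2]:
            bad.append((name, have, want))
    return bad


if __name__ == "__main__":
    for name, have, want in mismatches(installed_versions(REQUIRED), REQUIRED):
        print("%s: installed %s, README lists %s" % (name, have, want))
```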

Quick start

You can run the MNIST example, which compares the performance of nnPU learning and uPU learning, on a GPU:

python3 train.py -g 0

There are also preset configurations for reproducing the results in [1].

  • --preset figure1: The setting of Figure 1
  • --preset exp-mnist: The setting of the MNIST experiment in the experiments section
  • --preset exp-cifar: The setting of the CIFAR10 experiment in the experiments section

You can see additional information by adding --help.

Example result

After running train.py, two figures and one log file are written to result/ by default. The errors are measured by the zero-one loss.

  • Training error in result/training_error.png


  • Test error in result/test_error.png

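For reference, a minimal sketch of the zero-one error, assuming real-valued classifier outputs g(x) and ground-truth labels in {+1, -1}. zero_one_error is a hypothetical helper, not the function train.py uses, and counting a score of exactly 0 as +1 is an assumption.

```python
import numpy as np


def zero_one_error(scores, labels):
    """Fraction of points whose predicted sign disagrees with the label."""
    scores = np.asarray(scores, dtype=float).ravel()
    labels = np.asarray(labels).ravel()
    # Predict -1 for negative scores, +1 otherwise (ties at 0 go to +1).
    predictions = np.where(scores < 0, -1, 1)
    return float(np.mean(predictions != labels))
```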

Reference

[1] Ryuichi Kiryo, Gang Niu, Marthinus Christoffel du Plessis, and Masashi Sugiyama. "Positive-Unlabeled Learning with Non-Negative Risk Estimator." Advances in Neural Information Processing Systems, 2017.

[2] Marthinus Christoffel du Plessis, Gang Niu, and Masashi Sugiyama. "Convex formulation for learning from positive and unlabeled data." Proceedings of the 32nd International Conference on Machine Learning, 2015.

[3] Yann LeCun. "The MNIST database of handwritten digits." http://yann.lecun.com/exdb/mnist/, 1998.

[4] Alex Krizhevsky and Geoffrey Hinton. "Learning multiple layers of features from tiny images." 2009.

nnpulearning's People

Contributors

kiryor, trellixvulnteam


nnpulearning's Issues

loss return question

Hello, thank you for this awesome project, which is very useful for helping me understand the method in the nnPU paper. But I am confused
about the nnPU loss. In pu_loss.py, line 51, the loss output returns gamma * negative_risk. I also saw this in Algorithm 1, line 9, of the nnPU paper. In the paper, the authors say that when negative_risk < -beta, the optimizer goes along negative_risk with a step size discounted by gamma to make the mini-batch less overfitted. I cannot understand this. Besides, I am confused about why there is no positive_risk when negative_risk < -beta.

Hoping for your reply, thank you.
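A minimal NumPy sketch of the branch described in this question. It assumes a hypothetical linear model g(x) = x @ w, the sigmoid loss l(z) = 1 / (1 + exp(z)), and plain gradient descent; nnpu_step and its arguments are illustrative names, not the repository's API. When negative_risk < -beta, the step uses only -gamma * negative_risk. Descending on that quantity is gradient ascent on negative_risk scaled by gamma, so the positive risk does not enter the gradient in that branch.

```python
import numpy as np


def sigmoid_loss(z):
    # Sigmoid loss l(z) = 1 / (1 + exp(z)): close to 0 for large margins z.
    return 1.0 / (1.0 + np.exp(z))


def sigmoid_loss_grad(z):
    # dl/dz = -l(z) * (1 - l(z))
    l = sigmoid_loss(z)
    return -l * (1.0 - l)


def nnpu_step(w, x_p, x_u, prior, lr=0.1, beta=0.0, gamma=1.0):
    """One gradient step for g(x) = x @ w; returns (new_w, reported, negative_risk)."""
    g_p, g_u = x_p @ w, x_u @ w
    positive_risk = prior * sigmoid_loss(g_p).mean()
    negative_risk = sigmoid_loss(-g_u).mean() - prior * sigmoid_loss(-g_p).mean()

    # Gradients w.r.t. w; note d/dw l(-x @ w) = -l'(-x @ w) * x.
    grad_pos = prior * (sigmoid_loss_grad(g_p)[:, None] * x_p).mean(axis=0)
    grad_neg = (-(sigmoid_loss_grad(-g_u)[:, None] * x_u).mean(axis=0)
                + prior * (sigmoid_loss_grad(-g_p)[:, None] * x_p).mean(axis=0))

    if negative_risk < -beta:
        # Defensive step: descending on -gamma * negative_risk pushes the
        # negative risk back up; the positive risk is left out of this step.
        grad = -gamma * grad_neg
        reported = positive_risk - beta
    else:
        # Ordinary step on positive_risk + negative_risk.
        grad = grad_pos + grad_neg
        reported = positive_risk + negative_risk
    return w - lr * grad, reported, negative_risk
```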

can't download MNIST dataset

I'm getting 'HTTP Error 504: Gateway Timeout' when downloading the MNIST dataset.
It seems that fetch_mldata is deprecated and its download server is no longer available.
Here are my suggested changes:

diff --git a/dataset.py b/dataset.py
index 9edd7dd..b07a9d0 100644
--- a/dataset.py
+++ b/dataset.py
@@ -3,11 +3,11 @@ import urllib.request
 import os
 import tarfile
 import pickle
-from sklearn.datasets import fetch_mldata
+from sklearn.datasets import fetch_openml
 
 
 def get_mnist():
-    mnist = fetch_mldata('MNIST original', data_home=".")
+    mnist = fetch_openml('mnist_784', data_home=".")
     x = mnist.data
     y = mnist.target
     # reshape to (#data, #channel, width, height)
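A sketch of how the rest of such a loader might look with fetch_openml. It assumes scikit-learn 0.22 or later, where the as_frame argument exists (on 0.21, drop that argument). It also assumes the OpenML copy of mnist_784 stores its targets as strings such as '7', so they are cast to integers. load_mnist_openml and to_chainer_layout are hypothetical names, and no pixel scaling is applied because the original code does not show one.

```python
import numpy as np


def to_chainer_layout(x, y):
    """Reshape flat 784-pixel rows to (#data, #channel, width, height)."""
    x = np.asarray(x, dtype=np.float32).reshape(-1, 1, 28, 28)
    y = np.asarray(y).astype(np.int32)  # e.g. '7' -> 7
    return x, y


def load_mnist_openml(data_home="."):
    # Imported lazily so to_chainer_layout can be used without scikit-learn.
    from sklearn.datasets import fetch_openml
    # as_frame=False requests NumPy arrays instead of a pandas DataFrame.
    mnist = fetch_openml("mnist_784", data_home=data_home, as_frame=False)
    return to_chainer_layout(mnist.data, mnist.target)
```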

Cannot execute trainer.run()

trainer.run()
epoch       train/nnPU/error  test/nnPU/error  train/uPU/error  test/uPU/error  elapsed_time
Traceback (most recent call last):
  File "", line 1, in
  File "/home/sesun/anaconda3/lib/python3.6/site-packages/chainer/training/trainer.py", line 299, in run
    entry.extension(self)
  File "/home/sesun/anaconda3/lib/python3.6/site-packages/chainer/training/extensions/evaluator.py", line 137, in __call__
    result = self.evaluate()
  File "", line 23, in evaluate
  File "", line 23, in
NameError: name 'np' is not defined

Could you please tell me what causes the problem? Thanks.

Request: Release the code in this repository under some open source license

Hey @kiryor, :)

First, kudos and thank you for your very cool research paper on Positive-Unlabeled Learning with Non-Negative Risk Estimator, and this accompanying implementation.

Second, I was wondering: Would you be willing to release the code in this repository under some open source license (may I suggest MIT or BSD 3-clause)?

It would enable people to use it in open source projects (which is what I hope to do).

As you may know, as long as you do not explicitly release code under an open source license, the rights are implicitly reserved to the author (in this case, obviously, only for the implementation, not the algorithm), and no kind of use is permitted.

By releasing it explicitly, however, you can reserve rights to it while explicitly allowing almost all kinds of use, and release yourself from any warranty or liability.

If I've convinced you, it's as easy as creating a LICENSE file, pasting the short text of the license into it, and putting in your name and the year:

Cheers,
Shay

Numpy ImportError in train.py

It seems that the last commit, which moved from the surrogate loss to the 0-1 loss, broke train.py:

summary = {key: np.zeros(4) for key in targets}

To reproduce the NIPS results, can we instead just use the commit from April last year, "fix bug of CIFAR10" (baef0a3)?

Where are the fitted models stored, and how can they be used?

Could you please give more detailed code for the following tasks after running train.py?

  1. Where and how to load the fitted model.
  2. How to use it for prediction on new data.
  3. Output files for the prediction results.

Thanks

Decreasing train / validation loss with decreasing metrics like accuracy , f1-score etc.

Hi, first of all, thank you for this amazing research and results.
I'm doing some research using nnPU with binary classification (with pytorch applications).

During my experiments, the training and validation losses decrease just as the loss does in the nnPU paper,
but the confusion matrix printed every 5 epochs looks odd (extremely low precision, recall, F1-score, etc.).

I suspect the problem occurs when each batch output of the model is mapped to the predicted class, because with all the same settings but a cross-entropy loss, the model learns the data very well.

I'm mapping the output of the model to the class with pred = torch.where(out[:batch_size] < 0, torch.tensor(-1), torch.tensor(1)), and I don't see anything wrong with that code. Is there any advice or information on how to deal with this situation?
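One arithmetic point worth ruling out: the mapping above produces predictions in {-1, +1}, so the ground-truth labels must use the same convention. If labels are stored as {1, 0}, every negative prediction counts as a mismatch against 0. A minimal NumPy sketch, assuming {+1, -1} labels; binary_metrics is a hypothetical helper that uses the same thresholding rule as the torch.where call.

```python
import numpy as np


def binary_metrics(scores, labels):
    """Precision, recall and F1 for sign-thresholded scores with {+1, -1} labels."""
    scores = np.asarray(scores, dtype=float).ravel()
    labels = np.asarray(labels).ravel()
    if not set(np.unique(labels)) <= {1, -1}:
        raise ValueError("labels must be +1/-1; remap {1, 0} with np.where(labels == 1, 1, -1)")
    pred = np.where(scores < 0, -1, 1)  # same rule as torch.where(out < 0, -1, 1)
    tp = np.sum((pred == 1) & (labels == 1))
    fp = np.sum((pred == 1) & (labels == -1))
    fn = np.sum((pred == -1) & (labels == 1))
    precision = float(tp / (tp + fp)) if tp + fp else 0.0
    recall = float(tp / (tp + fn)) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1
```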

Cannot get negative risk

Hi,

I tried to implement this algorithm in PyTorch, but found that the negative risk never goes negative. Because of this, nnPU and uPU have the same training loss history. Can you help me figure out what's wrong with my implementation?

Here is my pytorch code:

import torch
import torch.nn as nn


class PULoss(nn.Module):
    """Loss function for PU learning."""
    def __init__(self, prior, loss=(lambda x: torch.sigmoid(-x)), gamma=1, beta=0, nnPU=True):
        super(PULoss, self).__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior = prior
        self.gamma = gamma
        self.beta = beta
        self.loss_func = loss
        self.nnPU = nnPU
        self.positive = 1
        self.unlabeled = -1

    def forward(self, x, t):
        t = t[:, None]
        positive, unlabeled = (t==self.positive).float(), (t==self.unlabeled).float()
        n_positive, n_unlabeled = max(1., positive.sum().item()), max(1., unlabeled.sum().item())
        y_positive = self.loss_func(x)
        y_unlabeled = self.loss_func(x)
        positive_risk = torch.sum(self.prior * positive / n_positive * y_positive)
        negative_risk = torch.sum((unlabeled / n_unlabeled - self.prior * positive / n_positive) * y_unlabeled)
        objective = positive_risk + negative_risk
        if self.nnPU:
            if negative_risk.item() < -self.beta:  # Never reached in my experiments!
                objective = positive_risk - self.beta
                self.x_out = -self.gamma * negative_risk
            else:
                self.x_out = objective
        else:
            self.x_out = objective
        self.loss = objective
        return self.loss, self.x_out

[Figure: train_loss_history]

I don't know whether it is correct to rely on autograd here, but I am sure the forward pass is the same as pu_loss.py in this repository.

Thanks a lot!
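One thing worth checking in the posted code: in the nnPU and uPU risks, the loss of treating a point as negative is l(-g(x)), not l(g(x)), yet y_positive and y_unlabeled above are both computed as self.loss_func(x). The following is a hypothetical rewrite (the class name PULossSketch is ours), assuming the sigmoid loss and labels +1 for labeled P and -1 for U. It evaluates the negative-class loss at -x and returns both the risk to report and the quantity to backpropagate, so autograd differentiates -gamma * negative_risk in the clipped branch.

```python
import torch
import torch.nn as nn


class PULossSketch(nn.Module):
    """nnPU / uPU risk with the sigmoid loss; labels: +1 = labeled P, -1 = U.

    forward() returns (risk_to_report, quantity_to_backpropagate).
    """

    def __init__(self, prior, beta=0.0, gamma=1.0, nnpu=True):
        super().__init__()
        if not 0 < prior < 1:
            raise ValueError("The class prior should be in (0, 1)")
        self.prior, self.beta, self.gamma, self.nnpu = prior, beta, gamma, nnpu

    @staticmethod
    def loss_func(z):
        return torch.sigmoid(-z)  # sigmoid loss l(z) = 1 / (1 + exp(z))

    def forward(self, x, t):
        x, t = x.reshape(-1), t.reshape(-1)
        positive = (t == 1).float()
        unlabeled = (t == -1).float()
        n_p = max(1.0, positive.sum().item())
        n_u = max(1.0, unlabeled.sum().item())

        y_positive = self.loss_func(x)   # l(g(x)): cost of treating x as positive
        y_negative = self.loss_func(-x)  # l(-g(x)): cost of treating x as negative
        positive_risk = torch.sum(self.prior * positive / n_p * y_positive)
        negative_risk = torch.sum((unlabeled / n_u - self.prior * positive / n_p) * y_negative)
        objective = positive_risk + negative_risk

        if self.nnpu and negative_risk.item() < -self.beta:
            # Report the clipped risk, but let autograd differentiate
            # -gamma * negative_risk, which pushes negative_risk back up.
            return (positive_risk - self.beta).detach(), -self.gamma * negative_risk
        return objective.detach(), objective
```

In a training loop this would be used as `report, target = criterion(model(inputs), t)`, followed by `optimizer.zero_grad()`, `target.backward()` and `optimizer.step()`, logging `report`.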

Can't reproduce with preset

Hi!
I'm struggling to reproduce your results with the code in the repo.

I'm running this on Python 2 rather than Python 3 as the README instructs, but I still find the results puzzling.
Are there any known problems with incompatible dependencies?


finn@9090e0a1791c:/mnt/personal/experiments/nnPUlearning$ python train.py -g 0 --preset exp-mnist
(61000, 1, 28, 28)
training:(61000, 1, 28, 28)
test:(10000, 1, 28, 28)
prior: 0.491533333333
loss: sigmoid
batchsize: 30000
model: <class 'model.MultiLayerPerceptron'>
beta: 0.0
gamma: 1.0

epoch       nnPU/loss   test/nnPU/error  uPU/loss    test/uPU/error  elapsed_time
1           0.353452    0                -0.102437   0               1.67173
2           0.446729    0                -0.357114   0               2.59044
3           0.46871     0                -0.429388   0               3.45704
4           0.47941     0                -0.456596   0               4.3503
5           0.483416    0                -0.469491   0               5.20476
6           0.486629    0                -0.480274   0               6.08255
7           0.487832    0                -0.480892   0               6.94002
8           0.488755    0                -0.484769   0               7.85751
9           0.489408    0                -0.485066   0               8.72103
10          0.489744    0                -0.486602   0               9.60042
11          0.489895    0                -0.487998   0               10.4424
12          0.490064    0                -0.488483   0               11.2919
13          0.490143    0                -0.487867   0               12.1934
14          0.490136    0                -0.489183   0               13.0329
15          0.49013     0                -0.489783   0               13.8755
16          0.490192    0                -0.489936   0               14.7377
17          0.49013     0                -0.489952   0               15.605
18          0.490032    0                -0.48995    0               16.475
19          0.490035    0                -0.490022   0               17.3875
20          0.490001    0                -0.49       0               18.2469
21          0.489953    0                -0.490023   0               19.1238
22          0.490019    0                -0.490018   0               20.0064
23          0.490058    0                -0.489983   0               20.8473
24          0.490139    0                -0.489932   0               21.7471
25          0.490332    0                -0.490016   0               22.5976
26          0.490442    0                -0.490011   0               23.4543
27          0.490574    0                -0.49003    0               24.3041
28          0.490695    0                -0.490027   0               25.1836
29          0.4908      0                -0.490078   0               26.0536
30          0.490891    0                -0.490141   0               26.9988
31          0.490989    0                -0.490243   0               28.0987
32          0.491036    0                -0.490307   0               28.9555
33          0.491104    0                -0.490457   0               29.8123
34          0.49111     0                -0.490471   0               30.6697
35          0.491158    0                -0.490597   0               31.5942
36          0.491184    0                -0.490678   0               32.4575
37          0.491206    0                -0.490761   0               33.3116
38          0.491227    0                -0.490813   0               34.1664
39          0.491258    0                -0.490875   0               35.0255
40          0.491291    0                -0.490973   0               35.8946
41          0.491308    0                -0.490987   0               36.8193
42          0.491337    0                -0.491048   0               37.687
43          0.49135     0                -0.491067   0               38.5148
44          0.491369    0                -0.491119   0               39.3524
45          0.491382    0                -0.491137   0               40.1832
46          0.491402    0                -0.491174   0               41.075
47          0.491414    0                -0.491208   0               41.9089
48          0.491417    0                -0.491227   0               42.739
49          0.491398    0                -0.49126    0               43.5793
50          0.491398    0                -0.491259   0               44.4244
51          0.491377    0                -0.491295   0               45.2701
52          0.491329    0                -0.491319   0               46.1681
53          0.490777    0                -0.491328   0               47.0041
54          0.462806    0                -0.49135    0               47.8437
55          0.478266    0                -0.491363   0               48.6857
56          0.485979    0                -0.491378   0               49.5294
57          0.488192    0                -0.491389   0               50.433
58          0.489556    0                -0.491402   0               51.2755
59          0.490065    0                -0.491418   0               52.1367
60          0.490497    0                -0.491428   0               53.0053
61          0.490727    0                -0.491438   0               54.091
62          0.490808    0                -0.491447   0               54.9617
63          0.490878    0                -0.49145    0               55.8698
64          0.491001    0                -0.491458   0               56.7245
65          0.491036    0                -0.491455   0               57.5813
66          0.491094    0                -0.491458   0               58.4284
67          0.491102    0                -0.491463   0               59.2827
68          0.491122    0                -0.490919   0               60.2063
69          0.491154    0                -0.337696   0               61.0661
70          0.491159    0                -0.460063   0               61.9256
71          0.491166    0                -0.480137   0               62.7957
72          0.491178    0                -0.482818   0               63.6729
73          0.491199    0                -0.486562   0               64.54
74          0.49119     0                -0.4858     0               65.4714
75          0.491203    0                -0.488717   0               66.3372
76          0.491196    0                -0.489188   0               67.2088
77          0.491221    0                -0.489562   0               68.072
78          0.491232    0                -0.489815   0               68.95
79          0.49124     0                -0.489987   0               69.879
80          0.491253    0                -0.490068   0               70.7484
81          0.491272    0                -0.490254   0               71.6196
82          0.491277    0                -0.490278   0               72.5025
83          0.491309    0                -0.490461   0               73.3859
84          0.49132     0                -0.490497   0               74.2446
85          0.491338    0                -0.49053    0               75.1539
86          0.491351    0                -0.490575   0               76.0052
87          0.491364    0                -0.490628   0               76.8445
88          0.491384    0                -0.490645   0               77.6819
89          0.491396    0                -0.490712   0               78.5155
90          0.491406    0                -0.490723   0               79.4058
91          0.491422    0                -0.490764   0               80.4602
92          0.491431    0                -0.490784   0               81.2949
93          0.491434    0                -0.490806   0               82.1361
94          0.491446    0                -0.490856   0               82.9859
95          0.49145     0                -0.490853   0               83.8477
96          0.491441    0                -0.490897   0               84.7633
97          0.491444    0                -0.490908   0               85.6327
98          0.491357    0                -0.490955   0               86.5019
99          0.481058    0                -0.490958   0               87.3789
100         0.460474    0                -0.490989   0               88.2398

List of dependencies installed:

finn@9090e0a1791c:/mnt/personal/experiments/nnPUlearning$ pip2 list
annoy (1.9.1)
asn1crypto (0.23.0)
backports-abc (0.5)
backports.shutil-get-terminal-size (1.0.0)
backports.weakref (1.0rc1)
bcolz (1.1.2)
beautifulsoup4 (4.6.0)
bleach (1.5.0)
boto (2.48.0)
brewer2mpl (1.4.1)
bs4 (0.0.1)
bz2file (0.98)
cachetools (2.0.1)
certifi (2017.4.17)
cffi (1.11.2)
chainer (3.2.0)
chardet (3.0.4)
configparser (3.5.0)
cryptography (2.1.3)
cupy (2.2.0)
cycler (0.10.0)
decorator (4.1.2)
dill (0.2.7.1)
entrypoints (0.2.3)
enum34 (1.1.6)
fastrlock (0.3)
filelock (2.0.13)
funcsigs (1.0.2)
functools32 (3.2.3.post2)
future (0.16.0)
futures (3.1.1)
gapic-google-cloud-pubsub-v1 (0.14.1)
gensim (2.3.0)
ggplot (0.11.3)
google-auth (1.2.0)
google-auth-httplib2 (0.0.2)
google-cloud-core (0.22.1)
google-cloud-pubsub (0.22.0)
google-gax (0.15.15)
googleapis-common-protos (1.5.3)
grpc-google-cloud-pubsub-v1 (0.14.0)
grpc-google-iam-v1 (0.11.4)
grpcio (1.7.0)
h5py (2.7.0)
html5lib (0.9999999)
httplib2 (0.10.3)
hyperopt (0.1)
idna (2.6)
ipaddress (1.0.18)
ipykernel (4.6.1)
ipython (5.4.1)
ipython-genutils (0.2.0)
ipywidgets (6.0.0)
Jinja2 (2.9.6)
joblib (0.11)
jsonschema (2.6.0)
jupyter (1.0.0)
jupyter-client (5.0.1)
jupyter-console (5.1.0)
jupyter-core (4.3.0)
jupyterlab (0.28.12)
jupyterlab-launcher (0.5.5)
Keras (2.0.4)
Markdown (2.2.0)
MarkupSafe (1.0)
matplotlib (2.0.2)
mistune (0.7.4)
mock (2.0.0)
nbconvert (5.2.1)
nbformat (4.3.0)
networkx (2.0)
nltk (3.2.2)
nose (1.3.7)
notebook (5.0.0)
numpy (1.13.1)
oauth2client (3.0.0)
olefile (0.44)
pandas (0.18.1)
pandocfilters (1.4.1)
pathlib2 (2.3.0)
patsy (0.4.1)
pbr (3.0.1)
pexpect (4.2.1)
pickleshare (0.7.4)
Pillow (4.1.1)
pip (9.0.1)
ply (3.8)
prompt-toolkit (1.0.14)
protobuf (3.3.0)
psycopg2 (2.7.3.1)
ptyprocess (0.5.1)
pyasn1 (0.3.7)
pyasn1-modules (0.1.5)
pycparser (2.18)
Pygments (2.2.0)
pymongo (3.5.1)
pyOpenSSL (17.3.0)
pyparsing (2.2.0)
Pyste (0.9.10)
python-dateutil (2.6.0)
python-snappy (0.5)
pytz (2017.2)
PyYAML (3.12)
pyzmq (16.0.2)
qtconsole (4.3.0)
requests (2.18.4)
rsa (3.4.2)
scandir (1.5)
scikit-learn (0.18)
scipy (0.19.0)
setuptools (36.0.1)
simplegeneric (0.8.1)
singledispatch (3.4.0.3)
six (1.10.0)
skflow (0.1.0)
sklearn (0.0)
smart-open (1.5.3)
statsmodels (0.8.0)
subprocess32 (3.2.7)
tensorflow-gpu (1.2.0)
terminado (0.6)
testpath (0.3.1)
tflearn (0.2.1)
Theano (0.9.0)
torch (0.2.0.post1)
torchvision (0.2.0)
tornado (4.5.1)
traitlets (4.3.2)
urllib3 (1.22)
wcwidth (0.1.7)
webencodings (0.5.1)
Werkzeug (0.12.2)
wheel (0.29.0)
widgetsnbextension (2.0.0)
