kiryor / nnpulearning
Non-negative Positive-Unlabeled (nnPU) and unbiased Positive-Unlabeled (uPU) learning: code reproducing the experiments on MNIST and CIFAR10
License: Other
I'm getting 'HTTP Error 504: Gateway Timeout' when downloading the MNIST dataset.
It seems that fetch_mldata is deprecated and its download server is no longer available.
Here are my suggested changes:
diff --git a/dataset.py b/dataset.py
index 9edd7dd..b07a9d0 100644
--- a/dataset.py
+++ b/dataset.py
@@ -3,11 +3,11 @@ import urllib.request
 import os
 import tarfile
 import pickle
-from sklearn.datasets import fetch_mldata
+from sklearn.datasets import fetch_openml
 
 
 def get_mnist():
-    mnist = fetch_mldata('MNIST original', data_home=".")
+    mnist = fetch_openml('mnist_784', data_home=".")
     x = mnist.data
     y = mnist.target
     # reshape to (#data, #channel, width, height)
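The one-line swap may not be enough on newer scikit-learn. `fetch_openml` (which accepts `as_frame` since 0.22) can return a pandas DataFrame, and the mnist_784 labels come back as strings such as '5' rather than numbers, so code that compares labels with integers can silently change behavior. Below is a minimal sketch of `get_mnist` under those assumptions; the helper `to_chainer_layout` is hypothetical, and the 60000/10000 split assumes the common convention that OpenML's mnist_784 lists the training images first.

```python
import numpy as np


def to_chainer_layout(x, y):
    """Reshape flat 784-pixel rows to (#data, #channel, width, height) float32 in [0, 1]
    and cast labels to int32 (OpenML returns them as strings such as '5')."""
    x = np.asarray(x, dtype=np.float32).reshape(-1, 1, 28, 28) / 255.0
    y = np.asarray(y).astype(np.int32)
    return x, y


def get_mnist(data_home="."):
    # Imported here so the helper above can be used without scikit-learn installed.
    from sklearn.datasets import fetch_openml
    # as_frame=False asks for NumPy arrays instead of a pandas DataFrame.
    mnist = fetch_openml("mnist_784", version=1, data_home=data_home, as_frame=False)
    x, y = to_chainer_layout(mnist.data, mnist.target)
    # Assumption: the first 60000 rows are the training set, the last 10000 the test set.
    return (x[:60000], y[:60000]), (x[60000:], y[60000:])
```

The label cast is the part most likely to matter downstream: any code that binarizes classes by comparing `y` with integers would otherwise compare strings with ints.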
It seems that the last commit, which moved from the surrogate loss to the 0-1 loss, broke train.py (line 196 in fe443ef).
To reproduce the NIPS results, can we just use the commit from April last year, "fix bug of CIFAR10" (baef0a3), instead?
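The issue does not say exactly what the commit changed beyond the move to the 0-1 loss. As general background only: the 0-1 loss is piecewise constant in the score g(x), so its gradient is zero almost everywhere and it cannot drive gradient-based training by itself; the usual practice is to train on a smooth surrogate such as the sigmoid loss and use the 0-1 loss only to report error. The sketch below (hypothetical helper names) checks this numerically with a finite difference.

```python
import numpy as np


def sigmoid_loss(t, g):
    """Surrogate loss l(t, g) = sigmoid(-t * g); smooth, so it has useful gradients."""
    return 1.0 / (1.0 + np.exp(t * g))


def zero_one_loss(t, g):
    """0-1 loss: 1 if the predicted label (g < 0 -> -1, else +1) differs from t."""
    pred = np.where(g < 0, -1, 1)
    return (pred != t).astype(np.float64)


def numerical_grad(f, t, g, eps=1e-4):
    # Central finite difference with respect to the score g.
    return (f(t, g + eps) - f(t, g - eps)) / (2 * eps)


print(numerical_grad(sigmoid_loss, 1, 0.5))   # negative: increasing g lowers the loss
print(numerical_grad(zero_one_loss, 1, 0.5))  # 0.0: no training signal
```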
Hi,
I tried to implement this algorithm in PyTorch, but found that the negative risk never goes negative. Because of this, nnPU and uPU have the same training loss history. Can you help me figure out what's wrong with my implementation?
Here is my PyTorch code:
import torch
import torch.nn as nn


class PULoss(nn.Module):
    """Loss function for PU learning."""

    def __init__(self, prior, loss=(lambda x: torch.sigmoid(-x)), gamma=1, beta=0, nnPU=True):
        super(PULoss, self).__init__()
        if not 0 < prior < 1:
            raise NotImplementedError("The class prior should be in (0, 1)")
        self.prior = prior
        self.gamma = gamma
        self.beta = beta
        self.loss_func = loss
        self.nnPU = nnPU
        self.positive = 1
        self.unlabeled = -1

    def forward(self, x, t):
        t = t[:, None]
        positive, unlabeled = (t == self.positive).float(), (t == self.unlabeled).float()
        n_positive, n_unlabeled = max(1., positive.sum().item()), max(1., unlabeled.sum().item())
        y_positive = self.loss_func(x)
        y_unlabeled = self.loss_func(x)
        positive_risk = torch.sum(self.prior * positive / n_positive * y_positive)
        negative_risk = torch.sum((unlabeled / n_unlabeled - self.prior * positive / n_positive) * y_unlabeled)
        objective = positive_risk + negative_risk
        if self.nnPU:
            if negative_risk.item() < -self.beta:  # This branch is never reached in my experiments!
                objective = positive_risk - self.beta
                self.x_out = -self.gamma * negative_risk
            else:
                self.x_out = objective
        else:
            self.x_out = objective
        self.loss = objective
        return self.loss, self.x_out
I'm not sure whether it is correct to rely on autograd here, but I'm sure the forward pass is the same as pu_loss.py in this repository.
Thanks a lot!
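Comparing the code above with the risk in the nnPU paper, one likely culprit is that `y_unlabeled` uses the same loss as `y_positive`. The negative risk needs the loss of labeling a sample negative, l(g(x), -1), which for l(z) = sigmoid(-z) is sigmoid(g(x)), i.e. `loss_func(-x)`. With `loss_func(x)` in both places, the prior-weighted positive terms cancel in the objective, which reduces to the average loss on the unlabeled data; that plausibly explains why the nnPU branch never fires. A second point, stated as an assumption: to mimic the Chainer version's backward pass (which, as far as we can tell, differentiates `x_out` rather than the reported loss), call `.backward()` on the second return value. A corrected sketch:

```python
import torch
import torch.nn as nn


class PULoss(nn.Module):
    """nnPU/uPU risk (sketch). Labels: +1 = positive, -1 = unlabeled."""

    def __init__(self, prior, loss=lambda z: torch.sigmoid(-z), gamma=1.0, beta=0.0, nnpu=True):
        super().__init__()
        if not 0 < prior < 1:
            raise ValueError("The class prior should be in (0, 1)")
        self.prior, self.gamma, self.beta, self.nnpu = prior, gamma, beta, nnpu
        self.loss_func = loss

    def forward(self, x, t):
        """Return (loss to report, tensor to call .backward() on)."""
        x = x.reshape(-1)                              # scores g(x), shape (N,)
        positive = (t.reshape(-1) == 1).float()
        unlabeled = (t.reshape(-1) == -1).float()
        n_positive = max(1.0, positive.sum().item())
        n_unlabeled = max(1.0, unlabeled.sum().item())
        y_positive = self.loss_func(x)                 # l(g(x), +1)
        y_unlabeled = self.loss_func(-x)               # l(g(x), -1): note the sign flip
        positive_risk = torch.sum(self.prior * positive / n_positive * y_positive)
        negative_risk = torch.sum((unlabeled / n_unlabeled - self.prior * positive / n_positive) * y_unlabeled)
        objective = positive_risk + negative_risk
        if self.nnpu and negative_risk.item() < -self.beta:
            # Report the clipped risk, but step along -gamma * negative_risk.
            return positive_risk - self.beta, -self.gamma * negative_risk
        return objective, objective
```

In a training loop this would be used as `loss, x_out = criterion(out, t)`, then `optimizer.zero_grad()`, `x_out.backward()`, `optimizer.step()`, logging `loss.item()`.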
Hello, thank you for this awesome project, which has been very useful for understanding the method in the nnPU paper. But I am confused about the nnPU loss. In the file pu_loss.py, line 51, the loss output returns -gamma * negative_risk; I also saw this in Algorithm 1, line 9, of the nnPU paper. In the paper, the authors say that when negative_risk < -beta, the update goes along -∇negative_risk with a step size discounted by gamma, to make the mini-batch less overfitted. I cannot understand this. I am also confused about why there is no positive_risk when negative_risk < -beta.
Hoping for your reply, thank you.
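One way to read that branch: the quantity negative_risk = R_u^-(g) - π R_p^-(g) estimates (1 - π) R_n^-(g), which can never be negative. When its mini-batch estimate falls below -β, that is taken as a sign of overfitting, so instead of descending on the objective the update does gradient ascent on negative_risk alone, with its step scaled by γ, to push the estimate back up. The positive risk is left out of that step because only the negative part has gone wrong; the reported loss is still positive_risk - β. The sketch below expresses the branch as gradient coefficients and runs it on a hypothetical one-parameter toy problem; it is an illustration, not the repository's code.

```python
def nnpu_step(positive_risk, negative_risk, beta=0.0, gamma=1.0):
    """Return (reported_loss, c_pos, c_neg): the update follows the gradient of
    c_pos * positive_risk + c_neg * negative_risk."""
    if negative_risk < -beta:
        # Overfitting detected: ignore positive_risk and ascend on negative_risk,
        # scaled by gamma, to push it back toward >= -beta.
        return positive_risk - beta, 0.0, -gamma
    # Otherwise ordinary descent on the full objective.
    return positive_risk + negative_risk, 1.0, 1.0


# Toy model (hypothetical): negative_risk(w) = w, positive_risk(w) = 0.1.
w, lr = -0.5, 0.1
for _ in range(3):
    loss, c_pos, c_neg = nnpu_step(0.1, w)
    grad = c_pos * 0.0 + c_neg * 1.0  # d positive_risk/dw = 0, d negative_risk/dw = 1
    w -= lr * grad                    # gradient step on the selected target
print(round(w, 6))  # -0.2: the negative risk climbs back toward zero
```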
Hey @kiryor, :)
First, kudos and thank you for your very cool research paper on Positive-Unlabeled Learning with Non-Negative Risk Estimator, and this accompanying implementation.
Second, I was wondering whether you would be willing to release the code in this repository under an open source license (may I suggest MIT or BSD 3-clause?).
It would enable people to use it in open source projects (which is what I hope to do).
As you perhaps know, as long as you do not explicitly release code under an open source license, the rights are implicitly reserved to the author (in this case, obviously, only for the implementation, not the algorithm), and no kind of use is permitted.
By releasing it explicitly, however, you can keep your rights while allowing almost all kinds of use, and release yourself from any warranty or liability.
If I've convinced you, it's as easy as creating a LICENSE file, pasting the short text of the license inside, and putting in your name and the year:
Cheers,
Shay
Hi, first of all, thank you for this amazing research and these results.
I'm doing some research using nnPU for binary classification (with PyTorch).
In my experiments, the training and validation losses decrease just as the loss does in the nnPU paper, but the confusion matrix I print every 5 epochs looks odd (extremely low scores for metrics such as precision, recall, and F1-score).
I suspect the problem occurs when each batch of model outputs is mapped to predicted classes, because the same setup trained with cross-entropy loss instead learns the data very well.
I map the model output to a class with pred = torch.where(out[:batch_size] < 0, torch.tensor(-1), torch.tensor(1)), which seems fine to me. Is there any advice or information on how to deal with this situation?
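Two common pitfalls are worth ruling out; the issue does not show the output shape or the label encoding, so both are assumptions. First, if `out` has shape (N, 1) and the labels have shape (N,), comparing them broadcasts to an (N, N) matrix and any metric computed from it is meaningless. Second, the thresholded predictions are in {-1, +1}, so test labels encoded as {0, 1} never match on the negative class. The sketch below uses NumPy to stay self-contained (in PyTorch, `out.view(-1)` plays the role of the flatten); the helper names are hypothetical.

```python
import numpy as np


def scores_to_labels(scores):
    """Map raw scores g(x) to {-1, +1} (g < 0 -> -1, otherwise +1) as a flat (N,) array."""
    # Flatten first: an (N, 1) output compared with (N,) labels would broadcast to (N, N).
    return np.where(np.asarray(scores).reshape(-1) < 0, -1, 1)


def binary_report(pred, truth, positive=1):
    """Accuracy, precision, recall, F1 for `positive`; pred and truth must share one encoding."""
    pred, truth = np.asarray(pred).reshape(-1), np.asarray(truth).reshape(-1)
    if pred.shape != truth.shape:
        raise ValueError("pred and truth must have the same number of samples")
    tp = int(np.sum((pred == positive) & (truth == positive)))
    fp = int(np.sum((pred == positive) & (truth != positive)))
    fn = int(np.sum((pred != positive) & (truth == positive)))
    accuracy = float(np.mean(pred == truth))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return accuracy, precision, recall, f1


scores = np.array([[2.0], [-1.0], [0.5], [-3.0]])  # model output with shape (N, 1)
truth = np.array([1, -1, -1, -1])                  # labels in the same {-1, +1} encoding

unflattened = np.where(scores < 0, -1, 1) == truth
print(unflattened.shape)  # (4, 4): comparing without flattening
print(binary_report(scores_to_labels(scores), truth))  # accuracy 0.75, precision 0.5, recall 1.0
print(binary_report(scores_to_labels(scores), np.array([1, 0, 0, 0]))[0])  # 0.25 with {0, 1} labels
```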
Hi!
I'm struggling to reproduce your results with the code in the repo.
I'm running this on Python 2 rather than Python 3 as the README instructs, but I still find the results puzzling.
Are there any known dependency incompatibilities?
finn@9090e0a1791c:/mnt/personal/experiments/nnPUlearning$ python train.py -g 0 --preset exp-mnist
(61000, 1, 28, 28)
training:(61000, 1, 28, 28)
test:(10000, 1, 28, 28)
prior: 0.491533333333
loss: sigmoid
batchsize: 30000
model: <class 'model.MultiLayerPerceptron'>
beta: 0.0
gamma: 1.0
epoch nnPU/loss test/nnPU/error uPU/loss test/uPU/error elapsed_time
1 0.353452 0 -0.102437 0 1.67173
2 0.446729 0 -0.357114 0 2.59044
3 0.46871 0 -0.429388 0 3.45704
4 0.47941 0 -0.456596 0 4.3503
5 0.483416 0 -0.469491 0 5.20476
6 0.486629 0 -0.480274 0 6.08255
7 0.487832 0 -0.480892 0 6.94002
8 0.488755 0 -0.484769 0 7.85751
9 0.489408 0 -0.485066 0 8.72103
10 0.489744 0 -0.486602 0 9.60042
11 0.489895 0 -0.487998 0 10.4424
12 0.490064 0 -0.488483 0 11.2919
13 0.490143 0 -0.487867 0 12.1934
14 0.490136 0 -0.489183 0 13.0329
15 0.49013 0 -0.489783 0 13.8755
16 0.490192 0 -0.489936 0 14.7377
17 0.49013 0 -0.489952 0 15.605
18 0.490032 0 -0.48995 0 16.475
19 0.490035 0 -0.490022 0 17.3875
20 0.490001 0 -0.49 0 18.2469
21 0.489953 0 -0.490023 0 19.1238
22 0.490019 0 -0.490018 0 20.0064
23 0.490058 0 -0.489983 0 20.8473
24 0.490139 0 -0.489932 0 21.7471
25 0.490332 0 -0.490016 0 22.5976
26 0.490442 0 -0.490011 0 23.4543
27 0.490574 0 -0.49003 0 24.3041
28 0.490695 0 -0.490027 0 25.1836
29 0.4908 0 -0.490078 0 26.0536
30 0.490891 0 -0.490141 0 26.9988
31 0.490989 0 -0.490243 0 28.0987
32 0.491036 0 -0.490307 0 28.9555
33 0.491104 0 -0.490457 0 29.8123
34 0.49111 0 -0.490471 0 30.6697
35 0.491158 0 -0.490597 0 31.5942
36 0.491184 0 -0.490678 0 32.4575
37 0.491206 0 -0.490761 0 33.3116
38 0.491227 0 -0.490813 0 34.1664
39 0.491258 0 -0.490875 0 35.0255
40 0.491291 0 -0.490973 0 35.8946
41 0.491308 0 -0.490987 0 36.8193
42 0.491337 0 -0.491048 0 37.687
43 0.49135 0 -0.491067 0 38.5148
44 0.491369 0 -0.491119 0 39.3524
45 0.491382 0 -0.491137 0 40.1832
46 0.491402 0 -0.491174 0 41.075
47 0.491414 0 -0.491208 0 41.9089
48 0.491417 0 -0.491227 0 42.739
49 0.491398 0 -0.49126 0 43.5793
50 0.491398 0 -0.491259 0 44.4244
51 0.491377 0 -0.491295 0 45.2701
52 0.491329 0 -0.491319 0 46.1681
53 0.490777 0 -0.491328 0 47.0041
54 0.462806 0 -0.49135 0 47.8437
55 0.478266 0 -0.491363 0 48.6857
56 0.485979 0 -0.491378 0 49.5294
57 0.488192 0 -0.491389 0 50.433
58 0.489556 0 -0.491402 0 51.2755
59 0.490065 0 -0.491418 0 52.1367
60 0.490497 0 -0.491428 0 53.0053
61 0.490727 0 -0.491438 0 54.091
62 0.490808 0 -0.491447 0 54.9617
63 0.490878 0 -0.49145 0 55.8698
64 0.491001 0 -0.491458 0 56.7245
65 0.491036 0 -0.491455 0 57.5813
66 0.491094 0 -0.491458 0 58.4284
67 0.491102 0 -0.491463 0 59.2827
68 0.491122 0 -0.490919 0 60.2063
69 0.491154 0 -0.337696 0 61.0661
70 0.491159 0 -0.460063 0 61.9256
71 0.491166 0 -0.480137 0 62.7957
72 0.491178 0 -0.482818 0 63.6729
73 0.491199 0 -0.486562 0 64.54
74 0.49119 0 -0.4858 0 65.4714
75 0.491203 0 -0.488717 0 66.3372
76 0.491196 0 -0.489188 0 67.2088
77 0.491221 0 -0.489562 0 68.072
78 0.491232 0 -0.489815 0 68.95
79 0.49124 0 -0.489987 0 69.879
80 0.491253 0 -0.490068 0 70.7484
81 0.491272 0 -0.490254 0 71.6196
82 0.491277 0 -0.490278 0 72.5025
83 0.491309 0 -0.490461 0 73.3859
84 0.49132 0 -0.490497 0 74.2446
85 0.491338 0 -0.49053 0 75.1539
86 0.491351 0 -0.490575 0 76.0052
87 0.491364 0 -0.490628 0 76.8445
88 0.491384 0 -0.490645 0 77.6819
89 0.491396 0 -0.490712 0 78.5155
90 0.491406 0 -0.490723 0 79.4058
91 0.491422 0 -0.490764 0 80.4602
92 0.491431 0 -0.490784 0 81.2949
93 0.491434 0 -0.490806 0 82.1361
94 0.491446 0 -0.490856 0 82.9859
95 0.49145 0 -0.490853 0 83.8477
96 0.491441 0 -0.490897 0 84.7633
97 0.491444 0 -0.490908 0 85.6327
98 0.491357 0 -0.490955 0 86.5019
99 0.481058 0 -0.490958 0 87.3789
100 0.460474 0 -0.490989 0 88.2398
List of dependencies installed:
finn@9090e0a1791c:/mnt/personal/experiments/nnPUlearning$ pip2 list
annoy (1.9.1)
asn1crypto (0.23.0)
backports-abc (0.5)
backports.shutil-get-terminal-size (1.0.0)
backports.weakref (1.0rc1)
bcolz (1.1.2)
beautifulsoup4 (4.6.0)
bleach (1.5.0)
boto (2.48.0)
brewer2mpl (1.4.1)
bs4 (0.0.1)
bz2file (0.98)
cachetools (2.0.1)
certifi (2017.4.17)
cffi (1.11.2)
chainer (3.2.0)
chardet (3.0.4)
configparser (3.5.0)
cryptography (2.1.3)
cupy (2.2.0)
cycler (0.10.0)
decorator (4.1.2)
dill (0.2.7.1)
entrypoints (0.2.3)
enum34 (1.1.6)
fastrlock (0.3)
filelock (2.0.13)
funcsigs (1.0.2)
functools32 (3.2.3.post2)
future (0.16.0)
futures (3.1.1)
gapic-google-cloud-pubsub-v1 (0.14.1)
gensim (2.3.0)
ggplot (0.11.3)
google-auth (1.2.0)
google-auth-httplib2 (0.0.2)
google-cloud-core (0.22.1)
google-cloud-pubsub (0.22.0)
google-gax (0.15.15)
googleapis-common-protos (1.5.3)
grpc-google-cloud-pubsub-v1 (0.14.0)
grpc-google-iam-v1 (0.11.4)
grpcio (1.7.0)
h5py (2.7.0)
html5lib (0.9999999)
httplib2 (0.10.3)
hyperopt (0.1)
idna (2.6)
ipaddress (1.0.18)
ipykernel (4.6.1)
ipython (5.4.1)
ipython-genutils (0.2.0)
ipywidgets (6.0.0)
Jinja2 (2.9.6)
joblib (0.11)
jsonschema (2.6.0)
jupyter (1.0.0)
jupyter-client (5.0.1)
jupyter-console (5.1.0)
jupyter-core (4.3.0)
jupyterlab (0.28.12)
jupyterlab-launcher (0.5.5)
Keras (2.0.4)
Markdown (2.2.0)
MarkupSafe (1.0)
matplotlib (2.0.2)
mistune (0.7.4)
mock (2.0.0)
nbconvert (5.2.1)
nbformat (4.3.0)
networkx (2.0)
nltk (3.2.2)
nose (1.3.7)
notebook (5.0.0)
numpy (1.13.1)
oauth2client (3.0.0)
olefile (0.44)
pandas (0.18.1)
pandocfilters (1.4.1)
pathlib2 (2.3.0)
patsy (0.4.1)
pbr (3.0.1)
pexpect (4.2.1)
pickleshare (0.7.4)
Pillow (4.1.1)
pip (9.0.1)
ply (3.8)
prompt-toolkit (1.0.14)
protobuf (3.3.0)
psycopg2 (2.7.3.1)
ptyprocess (0.5.1)
pyasn1 (0.3.7)
pyasn1-modules (0.1.5)
pycparser (2.18)
Pygments (2.2.0)
pymongo (3.5.1)
pyOpenSSL (17.3.0)
pyparsing (2.2.0)
Pyste (0.9.10)
python-dateutil (2.6.0)
python-snappy (0.5)
pytz (2017.2)
PyYAML (3.12)
pyzmq (16.0.2)
qtconsole (4.3.0)
requests (2.18.4)
rsa (3.4.2)
scandir (1.5)
scikit-learn (0.18)
scipy (0.19.0)
setuptools (36.0.1)
simplegeneric (0.8.1)
singledispatch (3.4.0.3)
six (1.10.0)
skflow (0.1.0)
sklearn (0.0)
smart-open (1.5.3)
statsmodels (0.8.0)
subprocess32 (3.2.7)
tensorflow-gpu (1.2.0)
terminado (0.6)
testpath (0.3.1)
tflearn (0.2.1)
Theano (0.9.0)
torch (0.2.0.post1)
torchvision (0.2.0)
tornado (4.5.1)
traitlets (4.3.2)
urllib3 (1.22)
wcwidth (0.1.7)
webencodings (0.5.1)
Werkzeug (0.12.2)
wheel (0.29.0)
widgetsnbextension (2.0.0)
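As a reading aid for the log above (not a diagnosis), it can help to compute the uPU and nnPU risks of two degenerate classifiers under the sigmoid loss, using the prior printed in the log. A classifier that labels everything negative scores about π on both risks; one that memorizes the labeled positives as positive and the unlabeled data as negative drives the uPU risk to about -π, the negative-risk overfitting that nnPU is designed to prevent, while its nnPU risk stays at about 0. The loss columns settling near 0.4915 and -0.49 are consistent with these reference values, though other classifiers could produce similar numbers. The sketch assumes the uPU and nnPU risk formulas from the paper; the helper names are hypothetical.

```python
import numpy as np


def sigmoid_loss(z):
    # l(z) = sigmoid(-z); pass g(x) for label +1, -g(x) for label -1.
    return 1.0 / (1.0 + np.exp(z))


def pu_risks(g_p, g_u, prior):
    """Return (uPU risk, nnPU risk) for scores on positive (g_p) and unlabeled (g_u) samples."""
    r_p_pos = np.mean(sigmoid_loss(g_p))   # positives classified as +1
    r_p_neg = np.mean(sigmoid_loss(-g_p))  # positives classified as -1
    r_u_neg = np.mean(sigmoid_loss(-g_u))  # unlabeled classified as -1
    upu = prior * r_p_pos + r_u_neg - prior * r_p_neg
    nnpu = prior * r_p_pos + max(0.0, r_u_neg - prior * r_p_neg)
    return upu, nnpu


prior = 0.491533333333  # value printed in the log above
big = 50.0              # saturating score, so each loss term is ~0 or ~1
print(pu_risks(np.full(5, -big), np.full(5, -big), prior))  # all-negative: both ~ prior
print(pu_risks(np.full(5, big), np.full(5, -big), prior))   # memorizing: uPU ~ -prior, nnPU ~ 0
```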
trainer.run()
epoch train/nnPU/error test/nnPU/error train/uPU/error test/uPU/error elapsed_time
Traceback (most recent call last):
File "", line 1, in
File "/home/sesun/anaconda3/lib/python3.6/site-packages/chainer/training/trainer.py", line 299, in run
entry.extension(self)
File "/home/sesun/anaconda3/lib/python3.6/site-packages/chainer/training/extensions/evaluator.py", line 137, in __call__
result = self.evaluate()
File "", line 23, in evaluate
File "", line 23, in
NameError: name 'np' is not defined
Could you please tell me what causes the problem? Thanks.
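The traceback points to line 23 of an `evaluate` function whose file name is empty, which suggests it was defined interactively (for example in a notebook or REPL) rather than imported from a module. Python looks up a global such as `np` in the namespace where the function was defined, so the likely fix is to run `import numpy as np` in that same session before `trainer.run()`. This is an inference from the traceback, not a confirmed diagnosis; the snippet below only illustrates the name-lookup rule with plain Python, and its `evaluate` is hypothetical.

```python
import numpy as np

source = '''
def evaluate(values):
    return np.mean(values)
'''

namespace = {}  # stands in for a session that never ran `import numpy as np`
exec(source, namespace)
try:
    namespace["evaluate"]([1, 2, 3])
except NameError as err:
    print(err)  # name 'np' is not defined

namespace["np"] = np  # equivalent to running `import numpy as np` in that session
print(namespace["evaluate"]([1, 2, 3]))  # 2.0
```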
Could you please give more detailed code for the tasks below, to be run after train.py?