Transductive Propagation Network

Code for the ICLR 2019 paper "Learning to Propagate Labels: Transductive Propagation Network for Few-shot Learning".

PyTorch Version

https://github.com/csyanbin/TPN-pytorch

Requirements

  • Python 3.5
  • TensorFlow 1.3+
  • tqdm

Data Download (miniImagenet and tieredImagenet)

Please download the compressed tar files from https://github.com/renmengye/few-shot-ssl-public, then extract them and move the .pkl files into the data directories:

mkdir -p data/miniImagenet/data
tar -zxvf mini-imagenet.tar.gz
mv *.pkl data/miniImagenet/data

mkdir -p data/tieredImagenet/data
tar -xvf tiered-imagenet.tar
mv *.pkl data/tieredImagenet/data
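To confirm that the files ended up where the commands above put them, a small check such as the one below may help. It only assumes the two directory paths created above and does not assume any particular .pkl file names; the function name find_pickles is hypothetical and not part of this repository.

```python
import glob
import os

# Directories created by the mkdir/mv commands above (relative to the repo root).
DATA_DIRS = {
    "mini": os.path.join("data", "miniImagenet", "data"),
    "tiered": os.path.join("data", "tieredImagenet", "data"),
}


def find_pickles(root="."):
    """Map each dataset name to the sorted .pkl file names found in its directory.

    A directory that does not exist (or holds no .pkl files) maps to an empty list.
    """
    found = {}
    for name, rel_dir in DATA_DIRS.items():
        pattern = os.path.join(root, rel_dir, "*.pkl")
        found[name] = sorted(os.path.basename(p) for p in glob.glob(pattern))
    return found


if __name__ == "__main__":
    for name, files in sorted(find_pickles().items()):
        status = "ok" if files else "MISSING"
        print("{:<7} {:<8} {}".format(name, status, ", ".join(files)))
```

Run it from the repository root. An empty entry means the mv step did not place any .pkl files in that directory, which is expected if only one of the two datasets was downloaded.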

TPN mini-5way1shot

python train.py --gpu=0 --n_way=5 --n_shot=1 --n_test_way=5 --n_test_shot=1 --lr=0.001 --step_size=10000 --dataset=mini --exp_name=mini_TPN_5w1s_5tw1ts_rn300_k20 --rn=300 --alpha=0.99 --k=20
python test.py --gpu=0 --n_way=5 --n_shot=1 --n_test_way=5 --n_test_shot=1 --lr=0.001 --step_size=10000 --dataset=mini --exp_name=mini_TPN_5w1s_5tw1ts_rn300_k20 --rn=300 --alpha=0.99 --k=20 --iters=81500

TPN mini-5way5shot

python train.py --gpu=0 --n_way=5 --n_shot=5 --n_test_way=5 --n_test_shot=5 --lr=0.001 --step_size=10000 --dataset=mini --exp_name=mini_TPN_5w5s_5tw5ts_rn300_k20 --rn=300 --alpha=0.99 --k=20
python test.py --gpu=0 --n_way=5 --n_shot=5 --n_test_way=5 --n_test_shot=5 --lr=0.001 --step_size=10000 --dataset=mini --exp_name=mini_TPN_5w5s_5tw5ts_rn300_k20 --rn=300 --alpha=0.99 --k=20 --iters=50100

TPN tiered-5way1shot

python train.py --gpu=0 --n_way=5 --n_shot=1 --n_test_way=5 --n_test_shot=1 --lr=0.001 --step_size=25000 --dataset=tiered --exp_name=tiered_TPN_5w1s_5tw1ts_rn300_k20 --rn=300 --alpha=0.99 --k=20

TPN tiered-5way5shot

python train.py --gpu=0 --n_way=5 --n_shot=5 --n_test_way=5 --n_test_shot=5 --lr=0.001 --step_size=25000 --dataset=tiered --exp_name=tiered_TPN_5w5s_5tw5ts_rn300_k20 --rn=300 --alpha=0.99 --k=20
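The four experiment names above share a pattern, {dataset}_TPN_{n_way}w{n_shot}s_{n_test_way}tw{n_test_shot}ts_rn{rn}_k{k}, and each test command repeats the training flags plus --iters. The sketch below rebuilds both commands from one set of parameters so they cannot drift apart. The helper build_commands is hypothetical, and treating the name pattern as a convention (rather than something train.py or test.py requires) is an assumption; keeping --exp_name identical between training and testing is presumably what lets test.py locate the trained checkpoint. This README gives no --iters value for the tiered-ImageNet runs, so the helper returns no test command when iters is omitted.

```python
def build_commands(dataset, n_way, n_shot, n_test_way, n_test_shot,
                   step_size, rn=300, alpha=0.99, k=20, lr=0.001, gpu=0,
                   iters=None):
    """Return (train_cmd, test_cmd) strings mirroring the README commands.

    test_cmd is None when no checkpoint iteration (iters) is supplied.
    """
    # Experiment name following the pattern used by the four README examples.
    exp_name = "{}_TPN_{}w{}s_{}tw{}ts_rn{}_k{}".format(
        dataset, n_way, n_shot, n_test_way, n_test_shot, rn, k)
    # Flags shared by train.py and test.py, in the README's order.
    flags = ("--gpu={} --n_way={} --n_shot={} --n_test_way={} --n_test_shot={} "
             "--lr={} --step_size={} --dataset={} --exp_name={} --rn={} "
             "--alpha={} --k={}").format(gpu, n_way, n_shot, n_test_way, n_test_shot,
                                         lr, step_size, dataset, exp_name, rn, alpha, k)
    train_cmd = "python train.py " + flags
    test_cmd = None
    if iters is not None:
        test_cmd = "python test.py {} --iters={}".format(flags, iters)
    return train_cmd, test_cmd


if __name__ == "__main__":
    # mini-ImageNet 5-way 1-shot, with the checkpoint iteration given above.
    train_cmd, test_cmd = build_commands("mini", 5, 1, 5, 1, step_size=10000, iters=81500)
    print(train_cmd)
    print(test_cmd)
```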

Citation

If you use our code, please consider citing the following paper:

  • Yanbin Liu, Juho Lee, Minseop Park, Saehoon Kim, Eunho Yang, Sungju Hwang, Yi Yang. Learning to Propagate Labels: Transductive Propagation Network for Few-shot Learning. In Proceedings of the 7th International Conference on Learning Representations (ICLR), 2019.

@inproceedings{liu2019fewTPN,
	title={Learning to Propagate Labels: Transductive Propagation Network for Few-shot Learning},
	author={Yanbin Liu and
		Juho Lee and
		Minseop Park and
		Saehoon Kim and
		Eunho Yang and
		Sungju Hwang and
		Yi Yang},
	booktitle={International Conference on Learning Representations},
	year={2019},
}

Issues

MatrixInverse Error

Have you ever encountered this error? It appears when I run the code.
train epoc:0: 0%| | 0/100 [00:00<?, ?it/s]
2019-04-08 19:45:32.987229: I tensorflow/core/kernels/cuda_solvers.cc:159] Creating CudaSolver handles for stream 0x1564c380
2019-04-08 19:45:33.117242: W tensorflow/core/framework/op_kernel.cc:1318] OP_REQUIRES failed at matrix_inverse_op.cc:223 : Internal: tensorflow/core/kernels/cuda_solvers.cc:408: cuSolverDN call failed with status =6

Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1322, in _do_call
    return fn(*args)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1307, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1409, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InternalError: tensorflow/core/kernels/cuda_solvers.cc:408: cuSolverDN call failed with status =6
  [[Node: MatrixInverse = MatrixInverseT=DT_FLOAT, adjoint=false, _device="/job:localhost/replica:0/task:0/device:GPU:0"]]
  [[Node: Mean_2/_51 = _Recvclient_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_2112_Mean_2", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/media/hlf/51aa617b-dbbc-4ad1-af81-45cf8dfce172/hlf/code/TPN-master/train.py", line 213, in <module>
    _, summaries, step, ls, ac = sess.run([train_op, train_summary_op, global_step, ce_loss, acc], feed_dict={m.x: support, m.ys:s_labels, m.q: query, m.y:q_labels, m.phase:1})
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 900, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1135, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1316, in _do_run
    run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1335, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InternalError: tensorflow/core/kernels/cuda_solvers.cc:408: cuSolverDN call failed with status =6
  [[Node: MatrixInverse = MatrixInverseT=DT_FLOAT, adjoint=false, _device="/job:localhost/replica:0/task:0/device:GPU:0"]]
  [[Node: Mean_2/_51 = _Recvclient_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_2112_Mean_2", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]]

Caused by op 'MatrixInverse', defined at:
  File "/media/hlf/51aa617b-dbbc-4ad1-af81-45cf8dfce172/hlf/code/TPN-master/train.py", line 148, in <module>
    ce_loss,acc,sigma_value = m.construct()
  File "/media/hlf/51aa617b-dbbc-4ad1-af81-45cf8dfce172/hlf/code/TPN-master/models.py", line 88, in construct
    ce_loss, acc, sigma_value = self.label_prop(emb_x, emb_q, ys_one_hot)
  File "/media/hlf/51aa617b-dbbc-4ad1-af81-45cf8dfce172/hlf/code/TPN-master/models.py", line 139, in label_prop
    F = tf.matrix_inverse(tf.cast(F, tf.float32))
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_linalg_ops.py", line 1049, in matrix_inverse
    "MatrixInverse", input=input, adjoint=adjoint, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1718, in __init__
    self._traceback = self._graph._extract_stack() # pylint: disable=protected-access

InternalError (see above for traceback): tensorflow/core/kernels/cuda_solvers.cc:408: cuSolverDN call failed with status =6
  [[Node: MatrixInverse = MatrixInverseT=DT_FLOAT, adjoint=false, _device="/job:localhost/replica:0/task:0/device:GPU:0"]]
  [[Node: Mean_2/_51 = _Recvclient_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_2112_Mean_2", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]]
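The failing op is the tf.matrix_inverse call in label_prop (models.py, line 139). TPN computes closed-form label propagation, F* = (I - αS)^-1 Y, where S = D^-1/2 W D^-1/2 is the normalized k-nearest-neighbour affinity graph over the support and query embeddings and Y holds the one-hot support labels; the --alpha and --k flags in the commands above presumably set α and the neighbour count. Because S has eigenvalues in [-1, 1], I - αS is symmetric with eigenvalues in [1 - α, 1 + α], so for α = 0.99 its condition number is at most 199 as long as every node has nonzero degree. Running the same kind of computation on the CPU is one way to separate a numerical problem from a GPU or cuSolver setup problem. The NumPy sketch below is not the repository's code: the function name label_propagation is hypothetical, it uses a fixed Gaussian width sigma where TPN learns a per-example scale (presumably the sigma_value in the traceback), it symmetrizes the kNN graph by taking the union of neighbour sets, it adds a small epsilon to the degrees, and it calls np.linalg.solve instead of forming the inverse explicitly.

```python
import numpy as np


def label_propagation(support, support_labels, query, n_way, k=20, alpha=0.99, sigma=1.0):
    """Closed-form label propagation F* = (I - alpha*S)^-1 Y (sketch, not TPN's code).

    support: (Ns, d) embeddings; support_labels: (Ns,) ints in [0, n_way);
    query: (Nq, d) embeddings. Returns (Nq, n_way) propagated class scores.
    """
    x = np.concatenate([support, query], axis=0).astype(np.float64)
    n, n_s = x.shape[0], support.shape[0]
    # Squared Euclidean distances and Gaussian affinities with a fixed width sigma.
    sq = np.sum(x ** 2, axis=1)
    dist = np.maximum(sq[:, None] + sq[None, :] - 2.0 * (x @ x.T), 0.0)
    w = np.exp(-dist / (2.0 * sigma ** 2))
    np.fill_diagonal(w, 0.0)
    # Keep each node's k strongest edges (k clipped to the n - 1 other nodes),
    # then symmetrize by taking the union of the neighbour sets.
    k = min(k, n - 1)
    top = np.argsort(-w, axis=1)[:, :k]
    keep = np.zeros(w.shape, dtype=bool)
    np.put_along_axis(keep, top, True, axis=1)
    w = np.where(keep | keep.T, w, 0.0)
    # Symmetric normalization S = D^-1/2 W D^-1/2; the epsilon guards zero-degree nodes.
    d_inv_sqrt = 1.0 / np.sqrt(w.sum(axis=1) + 1e-12)
    s = d_inv_sqrt[:, None] * w * d_inv_sqrt[None, :]
    # Y: one-hot rows for support examples, zero rows for queries.
    y = np.zeros((n, n_way))
    y[np.arange(n_s), support_labels] = 1.0
    # Solve (I - alpha*S) F = Y rather than forming the inverse explicitly.
    f = np.linalg.solve(np.eye(n) - alpha * s, y)
    return f[n_s:]


if __name__ == "__main__":
    rng = np.random.RandomState(0)
    # Toy 2-way 1-shot episode: two well-separated clusters, 3 queries per class.
    support = np.array([[0.0, 0.0], [10.0, 10.0]])
    query = np.vstack([rng.normal(0.0, 0.3, (3, 2)), rng.normal(10.0, 0.3, (3, 2))])
    scores = label_propagation(support, np.array([0, 1]), query, n_way=2)
    print(scores.argmax(axis=1))
```

Because k is clipped to the number of other nodes, the sketch also runs on an episode with a single query.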

pkl.load()

Traceback (most recent call last):
  File "test.py", line 122, in <module>
    loader_test.load_data_pkl()
  File "/home/x_y/lyf/TPN/dataset_mini.py", line 86, in load_data_pkl
    data = pkl.load(f)
_pickle.UnpicklingError: invalid load key, '*'.

Hello, could you help me?
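The error means the unpickler hit a byte ('*') that is not a valid pickle opcode, which suggests the file is not a plain pickle (for example, an incomplete download or a file from a different archive); this is a guess, not a confirmed diagnosis. A hypothetical helper like inspect_pickle below loads each file outside of test.py and reports what it finds. It passes encoding="latin1", an assumption that only matters if the files were written by Python 2; dataset_mini.py itself may open them differently. The gzip check is a heuristic added for this sketch.

```python
import pickle


def inspect_pickle(path):
    """Return "ok" if `path` unpickles cleanly, otherwise a short diagnosis string."""
    with open(path, "rb") as f:
        head = f.read(16)
    if not head:
        return "empty file"
    # Heuristic: gzip streams start with these two magic bytes.
    if head[:2] == b"\x1f\x8b":
        return "gzip-compressed; decompress it before unpickling"
    try:
        with open(path, "rb") as f:
            # encoding="latin1" lets Python 3 read byte strings in Python 2 pickles.
            pickle.load(f, encoding="latin1")
    except (pickle.UnpicklingError, EOFError) as exc:
        return "not a valid pickle ({}); first bytes: {!r}".format(exc, head)
    return "ok"


if __name__ == "__main__":
    import sys

    # Usage (hypothetical script name): python inspect_pickle.py data/miniImagenet/data/*.pkl
    for p in sys.argv[1:]:
        print(p, "->", inspect_pickle(p))
```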

Semi-TPN

Hi!
I do not understand the sentence 'Transductive methods directly use test set as unlabeled data while semi-supervised learning usually has an extra unlabeled set.' in Sec. 4.4.
Could I obtain your code for Semi-TPN?
I am looking forward to your reply!
Thank you very much!

Code for MAML, Reptile, and Relation Net

Hello,

Very interesting paper, and thank you for uploading the code. The paper reports the results of applying the transductive idea to MAML, Reptile, and RelationNet in Table 1 and Table 2. Is there a way to run transductive MAML with this repository?

Question about query sample numbers

Hello, this is really nice work. I found that the model performs well with 5*1 test samples. But if there is only one query sample, can the model still recognize it, and how would I test that?
