xujinfan / reinforced-continual-learning
Code for the paper "Ju Xu, Zhanxing Zhu. Reinforced Continual Learning. NIPS 2018."
Dear @xujinfan
I want to reproduce the setting where the search space is 5 in convolutional layers and 25 in fully-connected layers, for RCL, DEN, and PGN.
Please comment on how to define the target variable for the evaluate function.
I made a basic 5-layer CNN model with two conv layers (5x5, 32) and (5x5, 64), two max-pool layers, and one fully-connected layer (1024), as below:
https://github.com/yhgon/Reinforced-Continual-Learning/blob/master/RCL_CNN.py#L53
import tensorflow as tf

with tf.Graph().as_default() as g:
    with tf.name_scope("before"):
        # placeholders: flattened 28x28 MNIST images and one-hot labels
        inputs = tf.placeholder(shape=(None, 784), dtype=tf.float32)
        y = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        # conv1: 5x5 kernel, 1 -> 32 channels; conv2: 5x5 kernel, 32 -> 64 channels
        w1 = tf.Variable(tf.truncated_normal([5, 5, 1, 32], stddev=0.1))
        b1 = tf.Variable(tf.constant(0.1, shape=(32,)))
        w2 = tf.Variable(tf.truncated_normal([5, 5, 32, 64], stddev=0.1))
        b2 = tf.Variable(tf.constant(0.1, shape=(64,)))
        # fc1: 2*2*64 -> 1024; readout: 1024 -> 10
        w3 = tf.Variable(tf.truncated_normal([2 * 2 * 64, 1024], stddev=0.1))
        b3 = tf.Variable(tf.constant(0.1, shape=(1024,)))
        w4 = tf.Variable(tf.truncated_normal([1024, 10], stddev=0.1))
        b4 = tf.Variable(tf.constant(0.1, shape=(10,)))

        ## model
        inputs_shape = inputs.get_shape().as_list()
        print("DEBUG input_shape before:", inputs_shape)
        inputx = tf.reshape(inputs, shape=[-1, 28, 28, 1])  # back to 28x28x1
        inputs_shape = inputx.get_shape().as_list()
        print("DEBUG input_shape after :", inputs_shape)

        # stride-2 conv + stride-2 max-pool, twice: 28 -> 14 -> 7 -> 4 -> 2
        conv1 = tf.nn.relu(tf.nn.conv2d(inputx, w1, strides=[1, 2, 2, 1], padding='SAME') + b1)
        conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        conv2 = tf.nn.relu(tf.nn.conv2d(conv1, w2, strides=[1, 2, 2, 1], padding='SAME') + b2)
        conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

        conv2_shape = conv2.get_shape().as_list()
        print("DEBUG conv2_shape before :", conv2_shape)
        # flatten the 2x2x64 feature maps to a 256-dim vector
        conv2 = tf.reshape(conv2, [-1, conv2_shape[1] * conv2_shape[2] * conv2_shape[3]])
        conv2_shape = con2_shape = conv2.get_shape().as_list()
        print("DEBUG conv2_shape after :", conv2_shape)

        fcn = tf.nn.relu(tf.matmul(conv2, w3) + b3)
        output3 = tf.matmul(fcn, w4) + b4
        # cross-entropy plus L2 weight decay on all weight matrices
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=output3)) + \
               0.0001 * (tf.nn.l2_loss(w1) + tf.nn.l2_loss(w2) + tf.nn.l2_loss(w3) + tf.nn.l2_loss(w4))
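To run it end to end, a train op and accuracy metric can be added in the same scope (a minimal sketch; the Adam optimizer and learning rate are my assumptions, not taken from the repo):

        # hypothetical additions, continuing the scope above
        train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
        correct = tf.equal(tf.argmax(output3, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))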
Task 0 (initial training) works well with the following arguments:
--batch_size 1024 --n_epochs 30 --n_tasks 10 --n_layers 2 --hidden_size 10 --num_layers 2 --max_trials 40 --actions_num 2
DEBUG : task0/10 IF
DEBUG input_shape before: [None, 784]
DEBUG input_shape after : [None, 28, 28, 1]
DEBUG conv2_shape before : [None, 2, 2, 64]
DEBUG conv2_shape after : [None, 256]
task 0/10 epoch 29 train_step 53248/55000
task 0/10 epoch 29 train_step 54272/55000
task 0/10 epoch 29 train_step 55296/55000
task 0/10 test accuracy: 0.9629999995231628 IF
task 1/10 trial 0/40 *********actions for [13, 14] ELSE
DEBUG : 0 (5, 1, 32)
DEBUG : 1 ()
DEBUG : 2 (5, 32, 64)
DEBUG : 3 ()
DEBUG : 4 (1024,)
DEBUG : 5 ()
DEBUG : TODO
I wonder how to set the action for the 2D conv layers and the range for var_list.
The current actions for [13, 14] use shape (5, 1, 32) instead of (5x5, 1, 32).
IMHO, the right behavior for the target would be to control the 32, i.e. the number of output channels.
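What I would expect (a hypothetical sketch, not the repo's code) is that the controller's action k controls the last kernel dimension, i.e. how many output channels get added for the new task:

import tensorflow as tf

# Hypothetical helper (my sketch, not in the repo): grow a conv kernel's output
# channels by k new filters, freezing the old filters so earlier tasks are preserved.
def expand_conv_kernel(old_kernel, k):
    kh, kw, c_in, c_out = old_kernel.get_shape().as_list()
    new_filters = tf.Variable(tf.truncated_normal([kh, kw, c_in, k], stddev=0.1))
    old_frozen = tf.stop_gradient(old_kernel)            # previous filters stay fixed
    return tf.concat([old_frozen, new_filters], axis=3)  # (kh, kw, c_in, c_out + k)

w1 = tf.Variable(tf.truncated_normal([5, 5, 1, 32], stddev=0.1))
w1_expanded = expand_conv_kernel(w1, 3)  # action k=3: (5, 5, 1, 32) -> (5, 5, 1, 35)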
In the FCN case, task 1's var_list is below:
#DEBUG : 0 (312,)
#DEBUG : 1 ()
#DEBUG : 2 (128,)
#DEBUG : 3 ()
#DEBUG : 4 (10,)
#DEBUG : 5 ()
How can I control the target here?
In your paper, do you also control the conv filter size, such as 3x3, 5x5, or 1x1?
Hi @xujinfan, thank you for the implementation. After checking the code, I have some questions about the implementation of the policy generator, e.g. at line 47 of policy_gradient.py:
self.loss = -tf.log(picked_action_prob)*self.target
the loss function of the RNN is the advantage multiplied by the negative log-probabilities of the selected actions, where
baseline_value = self.value_estimator.predict(self.state, self.sess)
advantage = reward - baseline_value
Intuitively, the advantage is going to decrease in order to make the loss function decrease, but we eventually want the reward to become much larger. Am I misunderstanding your idea in the paper? Could you please explain this in more detail?
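(For reference, here is a toy check of my understanding of REINFORCE with a baseline; this is my own sketch, not the repo's code. Descending the surrogate loss -log pi(a) * advantage shifts probability toward high-reward actions, even though the surrogate itself is not the reward.)

import numpy as np

# Toy REINFORCE-with-baseline check (my sketch, not the repo's code): for a
# 2-action softmax policy, stepping along advantage * d(log pi(a))/d(logits)
# drives the policy toward the higher-reward action.
logits = np.zeros(2)
rewards = np.array([1.0, 0.0])   # action 0 yields the higher reward
baseline, lr = 0.5, 0.1          # fixed baseline for simplicity
rng = np.random.default_rng(0)
for _ in range(2000):
    p = np.exp(logits) / np.exp(logits).sum()
    a = rng.choice(2, p=p)
    advantage = rewards[a] - baseline
    grad_logp = np.eye(2)[a] - p             # d log pi(a) / d logits for a softmax
    logits += lr * advantage * grad_logp     # descend the surrogate loss
print(np.exp(logits) / np.exp(logits).sum()) # p(action 0) approaches 1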
Thanks,
Dear @xujinfan
Thanks for sharing this great work.
While trying to reproduce your results, I hit the error below at RCL.py#L89 when Task_id=3:
ValueError: Variable policy_estimator/LSTM/multi_rnn_cell/cell_0/basic_lstm_cell/kernel already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:
File "/mnt/dataset/lsun/Reinforced-Continual-Learning/policy_gradient.py", line 38, in __init__
(cell_output, hidden_state) = cell(inputs, hidden_state)
The detailed log is below; I just added print statements for debugging purposes. https://github.com/yhgon/Reinforced-Continual-Learning/blob/master/RCL.py
task 0/4 epoch 3 train_step 54272/55000
task 0/4 epoch 3 train_step 54784/55000
task 0/4 epoch 3 train_step 55296/55000
task 0/4 test accuracy: 0.9463000297546387 IF
DEBUG : task 1/4 ELSE
WARNING:tensorflow:From /mnt/dataset/lsun/Reinforced-Continual-Learning/policy_gradient.py:32: BasicLSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This class is deprecated, please use tf.nn.rnn_cell.LSTMCell, which supports all the feature this cell currently has. Please replace the existing code with tf.nn.rnn_cell.LSTMCell(name='basic_lstm_cell').
WARNING:tensorflow:From /mnt/dataset/lsun/Reinforced-Continual-Learning/policy_gradient.py:51: get_global_step (from tensorflow.contrib.framework.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Please switch to tf.train.get_global_step
DEBUG : taks1/4 ELSE finish Controller
DEBUG : taks1/4 ELSE result Controller
DEBUG : taks1/4 ELSE start trial loop
task 1/4 trial 0/2 *********actions for [2, 18] ELSE
task:1, epoch 0/4 55296/55000 test accuracy:0.15389999747276306 for evaluate action
task:1, epoch 3/4 55296/55000 test accuracy:0.5547999739646912 for evaluate action
task 1/4 trial 0/2, test accuracy: 0.5547999739646912 ELSE
reward: 0.5562000017166138 ELSE
DEBUG : task 1/4 trial 0/2 ELSE done best_reward
DEBUG : task 1/4 trial 0/2 ELSE end trial internal loop for train_control
task 1/4 trial 1/2 *********actions for [3, 0] ELSE
task:1, epoch 0/4 55296/55000 test accuracy:0.1509000062942505 for evaluate action
task:1, epoch 3/4 55296/55000 test accuracy:0.6255000233650208 for evaluate action
task 1/4 trial 1/2, test accuracy: 0.6255000233650208 ELSE
reward: 0.6373000046730042 ELSE
DEBUG : task 1/4 trial 1/2 ELSE done best_reward
DEBUG : task 1/4 trial 1/2 ELSE end trial internal loop for train_control
DEBUG : task 1/4 ELSE end trial loop
DEBUG : task 1/4 ELSE end controller session
DEBUG : task 1/4 ELSE end result append
DEBUG : task 1/4 ELSE end self vars
DEBUG : task 1/4 ELSE end loop
DEBUG : task 2/4 ELSE start
Traceback (most recent call last):
Traceback (most recent call last):
File "RCL.py", line 170, in <module>
jason = RCL(args)
File "RCL.py", line 32, in __init__
self.train()
File "RCL.py", line 93, in train
controller = Controller(self.args)
File "/mnt/dataset/lsun/Reinforced-Continual-Learning/policy_gradient.py", line 104, in __init__
self.policy_estimator = PolicyEstimator(args)
File "/mnt/dataset/lsun/Reinforced-Continual-Learning/policy_gradient.py", line 38, in __init__
(cell_output, hidden_state) = cell(inputs, hidden_state)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 233, in __call__
return super(RNNCell, self).__call__(inputs, state)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/layers/base.py", line 374, in __call__
outputs = super(Layer, self).__call__(inputs, *args, **kwargs)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 757, in __call__
outputs = self.call(inputs, *args, **kwargs)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1486, in call
cur_inp, new_state = cell(cur_inp, cur_state)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 370, in __call__
*args, **kwargs)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/layers/base.py", line 374, in __call__
outputs = super(Layer, self).__call__(inputs, *args, **kwargs)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 746, in __call__
self.build(input_shapes)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/keras/utils/tf_utils.py", line 149, in wrapper
output_shape = fn(instance, input_shape)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 716, in build
shape=[input_depth + h_depth, 4 * self._num_units])
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 495, in add_variable
return self.add_weight(*args, **kwargs)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/layers/base.py", line 288, in add_weight
getter=vs.get_variable)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 609, in add_weight
aggregation=aggregation)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/checkpointable/base.py", line 639, in _add_variable_with_custom_getter
**kwargs_for_getter)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/variable_scope.py", line 1487, in get_variable
aggregation=aggregation)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/variable_scope.py", line 1237, in get_variable
aggregation=aggregation)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/variable_scope.py", line 523, in get_variable
return custom_getter(**custom_getter_kwargs)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 236, in _rnn_get_variable
variable = getter(*args, **kwargs)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/variable_scope.py", line 492, in _true_getter
aggregation=aggregation)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/variable_scope.py", line 861, in _get_single_variable
name, "".join(traceback.format_list(tb))))
ValueError: Variable policy_estimator/LSTM/multi_rnn_cell/cell_0/basic_lstm_cell/kernel already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:
File "/mnt/dataset/lsun/Reinforced-Continual-Learning/policy_gradient.py", line 38, in __init__
(cell_output, hidden_state) = cell(inputs, hidden_state)
File "/mnt/dataset/lsun/Reinforced-Continual-Learning/policy_gradient.py", line 104, in __init__
self.policy_estimator = PolicyEstimator(args)
File "RCL.py", line 93, in train
controller = Controller(self.args)
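For anyone hitting the same thing, two workarounds that match the error message (sketches of my guess, not the author's fix): the second Controller rebuilds policy_estimator/LSTM/... in the same default graph, so the kernel variable already exists.

import tensorflow as tf

# Option 1 (my guess): discard the previous task's graph before building a new
# Controller, so the old variable names are gone:
tf.reset_default_graph()

# Option 2 (my guess): inside PolicyEstimator.__init__, let the rebuilt LSTM
# reuse the existing variables instead of raising (scope name taken from the
# error message):
with tf.variable_scope("policy_estimator", reuse=tf.AUTO_REUSE):
    pass  # build the LSTM cell and unroll it here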