Comments (31)
Yep a switch for using a DRQN architecture would be great. For now I'd go for using `histLen` as the number of frames to use BPTT on for a single-frame DRQN. Would be good to base it on the `rnn` library, especially since it now has the optimised `SeqLSTM`.
from atari.
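A single-frame DRQN along these lines would sample runs of `histLen` consecutive transitions from replay and truncate BPTT at that length. A minimal plain-Python sketch of the sequence sampling (the `sample_sequence` helper and its shape are assumptions, not the repo's API):

```python
import random

def sample_sequence(replay, hist_len, rng=random):
    """Sample hist_len consecutive transitions ending at a random index,
    so BPTT can be truncated to hist_len steps."""
    if len(replay) < hist_len:
        raise ValueError("not enough transitions stored")
    end = rng.randrange(hist_len - 1, len(replay))  # inclusive end index
    return replay[end - hist_len + 1:end + 1]

# Example: replay of 10 transitions, sequences of length 4
replay = list(range(10))
seq = sample_sequence(replay, 4)
assert len(seq) == 4
assert seq == list(range(seq[0], seq[0] + 4))  # consecutive run
```

A real implementation would also need to avoid sequences that straddle episode boundaries.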
This is the Caffe implementation from the paper: https://github.com/mhauskn/dqn/tree/recurrent. Although I've never looked at Caffe, it will probably help.
@Kaixhin I see you started working on this, cool. I'll have some time now, so I'll look at the multigpu and async modes.
@lake4790k Almost have something working. Disabling this line lets the DRQN train, as otherwise it crashes here, somehow propagating a batch of size 20 forward but expecting the normal batch size of 32 backwards. I'm new to the `rnn` library, so let me know if you have any ideas. Performance is considerably slower, which will be due to having to process several time steps sequentially. This is in line with Appendix B in that paper though.
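The size-20 batch is consistent with a validation memory that doesn't divide evenly by the batch size (e.g. 500 transitions in batches of 32 leave a final batch of 20), which a stateful RNN can choke on if it still carries sizes from the previous batch. A plain-Python illustration of the uneven split (the 500/32 figures here are assumptions):

```python
def batch_sizes(n, batch_size):
    """Sizes of the minibatches when n samples are split into chunks of batch_size."""
    sizes = [batch_size] * (n // batch_size)
    if n % batch_size:
        sizes.append(n % batch_size)  # smaller final batch
    return sizes

sizes = batch_sizes(500, 32)
assert sizes[-1] == 20               # the odd-sized final batch
assert all(s == 32 for s in sizes[:-1])
```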
@Kaixhin Awesome! I have no experience with `rnn` either, I will need to study it to have an idea. I have two 980 Tis and will be able to run longer experiments to see if it goes anywhere.
@lake4790k I'd have to delve into the original paper/code, but it looks like they train the network every step (as opposed to every 4). This seems like it'll be a problem for BPTT. In any case, if you haven't used `rnn` before I'll focus on this.
@Kaixhin cool, I'll have my hands full with async for now, but in the meantime I'll be able to help with running longer DRQN experiments on my workstation when you think it's worth trying.
Here's the result of running `./run.sh demo -recurrent true`, so I'm reasonably confident that the DRQN is capable of learning, but I'm not testing this further for now so I'm leaving this issue open. In any case, I still haven't solved this issue (which I mentioned above).
Pinging @JoostvDoorn since he's contributed to `rnn` and may have ideas about the minibatch problem, performance improvements, and whether it's possible to save and restore state before and after training (and if that should be done since the parameters have changed slightly).
@Kaixhin I will have a look later.
@Kaixhin I'm not getting the error you mentioned when doing validation on the last batch with size 20 when running `demo`. I'm using the `master` code, which has `sequencer:remember('both')` enabled. You mention you had to disable that to not crash...? `master` runs fine for me as it is.
I think this is in the `rnn` branch. This may or may not be a bug when using FastLSTM with the nngraph version. Setting `nn.FastLSTM.usenngraph = false` changed the error for me, but I only got the chance to look at this for a moment.
OK, so there are two issues:

1. With `nn.FastLSTM.usenngraph = true`: `nngraph/gmodule.lua:335: split(4) cannot split 32 outputs`. This is an issue in both `rnn` and `master`.
2. With `nn.FastLSTM.usenngraph = false`: `Wrong size for view. Input size: 20x1x3. Output size: 32x3`. This is only in `rnn`, because @Kaixhin fixed #16 in `master` (but not in `rnn`), which returns before doing the `backward` during validation, because it is not even needed, so maybe no issue after all?
- With `nn.FastLSTM.usenngraph = true`, I get the same error as @lake4790k. This seems to be Element-Research/rnn#172. Which is a shame, as apparently it's significantly faster with this flag enabled (see Element-Research/rnn#182).
- Yes, so if you remove the `return` on line 374 in `master` then it fails, so I consider this a bug, albeit one that is being hidden by that return. Why is this occurring even when `states` is `20x4x1x24x24` and `QCurr` is `20x1x3`? If the error is dependent on previous batches then the learning must be incorrect. I was wrong: removing `sequencer:remember('both')` doesn't stop the crash.
@Kaixhin re: 2, agreed, this error is bad, so returning before is not a solution. I'm not sure if learning is bad with the normal batch sizes; it could be that a batch size change is just not handled properly somewhere. I tried an isolated `FastLSTM` + `Sequencer` net, and there switching batch sizes worked fine, weird. I'm looking at adding LSTM to async; once I get that working I will experiment with this further.
@lake4790k I also tried a simple `FastLSTM` + `Sequencer` net with different batch sizes: no problem. I agree that it's likely some module is not switching its internal variables to the correct size, but finding out exactly where the problem lies is tricky. It may be that I haven't set up the recurrency correctly, but apart from this batch size issue it seems to work fine.
@Kaixhin I need to refresh `async` from `master` for the recurrent work. Should I do a merge or rebase (I'm thinking merge rather than rebase)? Does it even matter when merging back from `async` to `master` eventually?
@lake4790k I'd go with a merge, since it preserves history correctly. It's better to make sure all the changes in `master` are integrated sooner rather than later.
Done the merge and added `recurrent` support for 1-step Q in `async`. This is 7 minutes of training; it seems to work well:
The agent sees only the latest frame per step and backpropagates by unrolling 5 steps on every step; weights are updated every 5 steps (or at a terminal), and no `Sequencer` is needed in this algo. I used `sharedRmsProp` and kept the `ReLU` after the `FastLSTM` to have a setup comparable to my usual `async` testing.
Pretty cool that it works. I'll see if it performs similarly with a flickering Catch, as they did in the paper with flickering Pong. Also, in the async paper they added a half-size LSTM layer after the linear layer instead of replacing it; I will try that as well (although the DRQN paper says replacing is best).
Will add support for the n-step methods as well. There it's a bit trickier to get right, as there are steps taken forwards and backwards to calculate the n-step returns, so I will have to take care that forwards/backwards are correct for the LSTM as well.
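The update schedule described above (backprop through the unrolled steps every 5 steps, or at a terminal, keeping the hidden state across truncation points) can be sketched as a plain loop; the names and structure are hypothetical, not the repo's code:

```python
def update_points(episode_len, t_max=5):
    """Steps at which a truncated-BPTT update would fire:
    every t_max steps, plus the terminal step."""
    points, steps_since = [], 0
    for t in range(1, episode_len + 1):
        steps_since += 1
        terminal = (t == episode_len)
        if steps_since == t_max or terminal:
            points.append(t)   # backprop through the unrolled steps
            steps_since = 0    # gradient is truncated; hidden state is kept
    return points

assert update_points(12) == [5, 10, 12]  # update at 5, 10, and the terminal
```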
Also tried replacing `FastLSTM` with `GRU`, with everything else the same; interestingly, that did not converge even after running it longer.
@lake4790k Do you have the flickering catch version somewhere?
@JoostvDoorn haven't got around to it since, but probably takes a few lines to add to rlenvs.
@JoostvDoorn I can add that to `rlenvs.Catch` if you want? You may also be interested in the `obscured` option I set up, which blanks a strip of screen at the bottom so that the agent has to infer the motion of the ball properly. Quick enough to test by adding `opt.obscured = true` in `Setup.lua`.
@JoostvDoorn Done. Just get the latest version of `rlenvs` and this repo. `-flickering` is a probability between 0 and 1 of the screen blanking out.
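An option like `-flickering` amounts to a thin observation wrapper that blanks the screen with the given probability. A sketch of the idea (this is an analogue, not the rlenvs implementation):

```python
import random

def flicker(observation, p, rng=random):
    """Return a blank observation with probability p, else the real one."""
    blank = [0.0] * len(observation)
    return blank if rng.random() < p else observation

obs = [0.2, 0.5, 0.9]
assert flicker(obs, 0.0) == obs              # p = 0: never blanks
assert flicker(obs, 1.0) == [0.0, 0.0, 0.0]  # p = 1: always blanks
```

With partial observability like this, a frame-stacking DQN can average over flicker, but a recurrent agent has to carry the information in its hidden state.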
@Kaixhin Great, thanks.
Have you tried storing the state instead of calling forget for every time step? I am doing this now; it takes longer to train, but it will probably converge. I agree this has to do with the changing state distribution, but we cannot really let the agent explore without considering the history to take full advantage of the LSTM.
@JoostvDoorn I thought that this line would actually set `remember` for all internal modules, but I'm not certain? If that is not the case then yes, I agree it should be set on the LSTM units themselves.
In summary, in `Agent:observe`, the only place that `forget` is called is at a terminal state. Of course, when learning it should call `forget` before passing the minibatch through, and after learning as well. This means that `memSampleFreq` is the maximum amount of history the LSTMs keep during training, but they receive the entire history during validation/evaluation.
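The forget schedule described here can be made concrete with a toy run-length count: `forget` fires at terminals and at every learning step, so at most `memSampleFreq` steps of hidden state accumulate during training (names and the learning-every-N-steps trigger are assumptions for illustration):

```python
def max_history(step_terminals, mem_sample_freq):
    """Longest run of steps between forget calls, where forget fires at
    terminals and at every learning step (every mem_sample_freq steps)."""
    run, longest = 0, 0
    for t, terminal in enumerate(step_terminals, start=1):
        run += 1
        longest = max(longest, run)
        if terminal or t % mem_sample_freq == 0:
            run = 0  # forget: hidden state is cleared
    return longest

# 20 non-terminal steps, learning every 4 steps: history never exceeds 4
assert max_history([False] * 20, 4) == 4
```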
@Kaixhin Yes, that line is enough; I will change that in my pull request.
I missed `memSampleFreq`, so I assumed it was calling forget every time. I guess `memSampleFreq >= histLen` is a good thing here, such that training and updating have a similar distribution. Do note though that the 5th action will update based on the 2nd, 3rd, 4th, and 5th states in the Q-learning update, while the policy followed will only be based on the 5th state, right?
@JoostvDoorn Yep, `memSampleFreq >= histLen` would be sensible. Sorry, not sure I understand your last question though. During learning updates for recurrent networks, `histLen` is used to determine the sequence length of the states fed in (no concatenating frames in time as with a normal DQN). During training the hidden state will go back until the last time `forget` was called (and `forget` is called every `memSampleFreq`).
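The DQN-vs-DRQN difference in how `histLen` is used can be stated as input shapes (assuming 24x24 single-channel frames, consistent with the `20x4x1x24x24` size reported earlier in the thread):

```python
def input_shape(batch, hist_len, channels, h, w, recurrent):
    """DQN stacks histLen frames in the channel dim; DRQN feeds them
    as a time dimension of single frames."""
    if recurrent:
        return (batch, hist_len, channels, h, w)  # sequence for BPTT
    return (batch, hist_len * channels, h, w)     # stacked frames

assert input_shape(32, 4, 1, 24, 24, recurrent=False) == (32, 4, 24, 24)
assert input_shape(32, 4, 1, 24, 24, recurrent=True) == (32, 4, 1, 24, 24)
```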
I guess like this: forget is called at the first time step, so the LSTM will not have accumulated any information at this point; once here it will start accumulating state information (note however that on `torch.uniform() < epsilon` we don't accumulate info, which is a bug). Now after calling `Agent:learn` we call forget again. Then once the episode continues and reaches the point here, the state information is the same as at the start of the episode; depending on the environment this is a problem.
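The bug amounts to skipping the network forward pass on exploratory (epsilon) steps, so the LSTM's hidden state misses those observations. A sketch of the fix: always run the forward pass to update the hidden state, and only discard the greedy action (all names here are hypothetical stand-ins):

```python
import random

class CountingNet:
    """Stand-in network that counts forward passes (hidden-state updates)."""
    def __init__(self):
        self.forwards = 0
    def forward(self, obs):
        self.forwards += 1
        return 0  # greedy action stub

def act(net, obs, epsilon, n_actions, rng=random):
    greedy = net.forward(obs)  # always forward: hidden state sees every frame
    if rng.random() < epsilon:
        return rng.randrange(n_actions)  # explore, but state was still updated
    return greedy

net = CountingNet()
for _ in range(10):
    act(net, None, 1.0, 3)
assert net.forwards == 10  # hidden state updated even on random actions
```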
Thanks for spotting the bug. @lake4790k please check 626712b to make sure async agents are accounted for as well.
@JoostvDoorn If I understand correctly then the issue is that the agent can't retain information during training because `observe` is interspersed with `forget` calls during `learn`? That's what I was wondering about above. My reasoning comes from the `rnn` docs. Also, it would be prohibitive to keep old states from before `learn` and pass them all through the network before starting again.
@Kaixhin yes this is needed for async, just created #47 to do it a bit differently.