Comments (4)
Thanks for reaching out. I haven't observed this and I'm wondering whether you can provide a simple setup to reproduce this phenomenon.
BTW, there is a known issue that can be fixed by setting `degenerated_to_sgd=False`
(more discussions can be found at: #54).
I have run into the same issue, trying to implement RAdam. Here's a pure (num)python implementation:
```python
import numpy as np


class RADAM:
    def __init__(self, fg, x0, alpha, beta1=0.9, beta2=0.999):
        """Create a new RADAM optimizer.

        Parameters
        ----------
        fg : callable
            a function which returns (f, g), where f is the scalar cost and
            g is the vector gradient
        x0 : numpy.ndarray
            the parameter vector immediately prior to optimization
        alpha : float
            the step size
        beta1 : float
            the decay rate of the first moment (mean of gradient)
        beta2 : float
            the decay rate of the second moment (uncentered variance)
        """
        self.fg = fg
        self.x0 = x0
        self.alpha = alpha
        self.beta1 = beta1
        self.beta2 = beta2
        self.x = x0.copy()
        self.m = np.zeros_like(x0)
        self.v = np.zeros_like(x0)
        self.eps = np.finfo(x0.dtype).eps
        self.rhoinf = 2 / (1 - beta2) - 1
        self.iter = 0

    def step(self):
        """Perform one iteration of optimization."""
        self.iter += 1
        k = self.iter
        beta1 = self.beta1
        beta2 = self.beta2
        beta2k = beta2**k
        f, g = self.fg(self.x)
        # update the biased moment estimates;
        # torch: exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2) == v
        self.m = beta1*self.m + (1-beta1) * g
        self.v = beta2*self.v + (1-beta2) * (g*g)
        mhat = self.m / (1 - beta1**k)
        # going to use this many times, local lookup is cheaper
        rhoinf = self.rhoinf
        rho = rhoinf - (2*k*beta2k)/(1-beta2k)
        x = self.x
        if rho >= 5:  # 5 was 4 in the paper, but PyTorch uses 5, most others too
            # the paper writes l = np.sqrt((1-beta2k)/self.v), which seems to
            # blow up all the time; must be a typo (missing sqrt on v).
            # torch computes vhat the same as Adam, assume that's the typo:
            l = np.sqrt(1 - beta2k) / (np.sqrt(self.v) + self.eps)  # NOQA
            num = (rho - 4) * (rho - 2) * rhoinf
            den = (rhoinf - 4) * (rhoinf - 2) * rho
            r = np.sqrt(num/den)
            self.x = x - self.alpha * r * mhat * l
        else:
            self.x = x - self.alpha * mhat
        return x, f, g


def runN(optimizer, N):
    """Yield (x, f, g) from N steps of the optimizer."""
    for _ in range(N):
        yield optimizer.step()
```
A minimal working example that blows up:
```python
import numpy as np
from scipy.optimize import rosen, rosen_der


def fg(x):
    return rosen(x), rosen_der(x)


x0 = np.array([-2.0, 2.0])
opt = RADAM(fg, x0, 1e-2)
hist = []
xh = []
for xk, fk, gk in runN(opt, 1000):
    hist.append(float(fk))
    xh.append(xk.copy())
```
I do not observe this behavior with vanilla Adam, Yogi, Adagrad, RMSprop, or other optimizers. Any thoughts? @LiyuanLucasLiu
@brandondube thanks for providing the example.
I believe this is a known issue and can be fixed by setting `degenerated_to_sgd=False`
(in your case, you can simply delete the `else: self.x = x - self.alpha * mhat`
part).
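Applied to the NumPy implementation above, that fix amounts to the branch below (a sketch, using the same names as the class; `rectified_update` is a hypothetical helper, not part of the package):

```python
import numpy as np


def rectified_update(x, mhat, l, r, rho, alpha):
    # With degenerated_to_sgd disabled, only the rectified branch moves x;
    # while rho < 5 the parameter step is skipped and only the moment
    # estimates m and v keep accumulating.
    if rho >= 5:
        return x - alpha * r * mhat * l
    return x  # no un-rectified momentum step
```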
More discussions can be found at: #54 (comment).
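For context on why deleting that branch is cheap: with the default `beta2=0.999`, `rho` only sits below the threshold of 5 for the first five iterations, so the `else` branch (and hence the blow-up) is purely a start-of-run effect. A quick check:

```python
beta2 = 0.999
rhoinf = 2 / (1 - beta2) - 1  # = 1999

for k in (1, 2, 3, 4, 5, 6):
    beta2k = beta2 ** k
    rho = rhoinf - (2 * k * beta2k) / (1 - beta2k)
    # rho grows roughly like k at the start: below 5 through k=5,
    # above it from k=6 onward
    print(k, "rectified" if rho >= 5 else "fallback")
```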
Thanks, that was it. I made a different choice: dividing g by its norm in the fallback branch. This increases the range of stable learning rates, although not all that much.

```python
gsq = g * g
invgnorm = 1 / np.sqrt(gsq.sum())
self.x = x - self.alpha * invgnorm * g
```
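As a sanity check on that choice: the normalized fallback always takes a step of length exactly alpha, regardless of the gradient scale (a sketch; `normalized_fallback` is a hypothetical name for illustration):

```python
import numpy as np


def normalized_fallback(x, g, alpha):
    # step along -g with length exactly alpha, independent of |g|
    gsq = g * g
    invgnorm = 1 / np.sqrt(gsq.sum())
    return x - alpha * invgnorm * g
```

At the Rosenbrock start point above, the gradient norm is over a thousand, so this caps what would otherwise be a huge first step.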