
Comments (4)

LiyuanLucasLiu commented on August 20, 2024

Thanks for reaching out. I haven't observed this and I'm wondering whether you can provide a simple setup to reproduce this phenomenon.

BTW, there is a known issue that can be fixed by setting degenerated_to_sgd=False (more discussions can be found at: #54)
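
If you are using the RAdam class from this repository, the flag is passed at construction time. A minimal sketch, assuming the usual PyTorch setup (the toy model here is purely illustrative):

import torch
from radam import RAdam  # the optimizer shipped in this repository

# toy model, only to have parameters to optimize
model = torch.nn.Linear(10, 1)

# degenerated_to_sgd=False disables the fallback plain-SGD step taken on
# iterations where the variance rectification term is undefined
optimizer = RAdam(model.parameters(), lr=1e-3, degenerated_to_sgd=False)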


brandondube commented on August 20, 2024

I have run into the same issue, trying to implement RAdam. Here's a pure (num)python implementation:

import numpy as np


class RADAM:
    def __init__(self, fg, x0, alpha, beta1=0.9, beta2=0.999):
        """Create a new RADAM optimizer.

        Parameters
        ----------
        fg : callable
            a function which returns (f, g) where f is the scalar cost, and
            g is the vector gradient.
        x0 : numpy.ndarray
            the parameter vector immediately prior to optimization
        alpha : float
            the step size
        beta1 : float
            the decay rate of the first moment (mean of gradient)
        beta2 : float
            the decay rate of the second moment (uncentered variance)

        """
        self.fg = fg
        self.x0 = x0
        self.alpha = alpha
        self.beta1 = beta1
        self.beta2 = beta2
        self.x = x0.copy()
        self.m = np.zeros_like(x0)
        self.v = np.zeros_like(x0)
        self.eps = np.finfo(x0.dtype).eps
        self.rhoinf = 2 / (1-beta2) - 1
        self.iter = 0

    def step(self):
        """Perform one iteration of optimization."""
        self.iter += 1
        k = self.iter
        beta1 = self.beta1
        beta2 = self.beta2
        beta2k = beta2**k

        f, g = self.fg(self.x)
        # update momentum estimates
        self.m = beta1*self.m + (1-beta1) * g
        self.v = beta2*self.v + (1-beta2) * (g*g)
        # torch exp_avg_sq.mul_(beta2).addcmul_(grad,grad,value=1-beta2)
        # == v

        mhat = self.m / (1 - beta1**k)

        # going to use this many times, local lookup is cheaper
        rhoinf = self.rhoinf
        rho = rhoinf - (2*k*beta2k)/(1-beta2k)
        x = self.x
        if rho >= 5:  # 5 was 4 in the paper, but PyTorch uses 5, most others too
            # l = np.sqrt((1-beta2k)/self.v)  # NOQA
            # commented out l exactly as in paper
            # seems to blow up all the time, must be a typo; missing sqrt(v)
            # torch computes vhat same as ADAM, assume that's the typo
            l = np.sqrt(1 - beta2k) / (np.sqrt(self.v)+self.eps)  # NOQA
            num = (rho - 4) * (rho - 2) * rhoinf
            den = (rhoinf - 4) * (rhoinf - 2) * rho
            r = np.sqrt(num/den)
            self.x = x - self.alpha * r * mhat * l
        else:
            self.x = x - self.alpha * mhat
        return x, f, g

def runN(optimizer, N):
    for _ in range(N):
        yield optimizer.step()
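
For reference, here is the update rule the class implements, transcribed from Algorithm 2 of the RAdam paper as I read it (the code adds eps to sqrt(v) the way torch does, and uses the threshold 5 rather than the paper's 4):

\rho_\infty = \frac{2}{1-\beta_2} - 1, \qquad \rho_t = \rho_\infty - \frac{2 t \beta_2^t}{1-\beta_2^t}

l_t = \sqrt{\frac{1-\beta_2^t}{v_t}}, \qquad r_t = \sqrt{\frac{(\rho_t-4)(\rho_t-2)\,\rho_\infty}{(\rho_\infty-4)(\rho_\infty-2)\,\rho_t}}

\theta_t = \begin{cases} \theta_{t-1} - \alpha\, r_t\, l_t\, \hat m_t, & \rho_t > 4 \\ \theta_{t-1} - \alpha\, \hat m_t, & \text{otherwise} \end{cases}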

A minimal working example that blows up:

import numpy as np
from scipy.optimize import rosen, rosen_der
def fg(x):
    f = rosen(x)
    g = rosen_der(x)
    return f, g

x0 = np.array([-2.0, 2.0])

opt = RADAM(fg, x0, 1e-2)
hist = []
xh = []
for xk, fk, gk in runN(opt, 1000):
    hist.append(float(fk))
    xh.append(xk.copy())

I do not observe this behavior with vanilla Adam, Yogi, Adagrad, RMSprop, or other optimizers. Any thoughts? @LiyuanLucasLiu


LiyuanLucasLiu commented on August 20, 2024

@brandondube thanks for providing the example.

I believe this is a known issue and can be fixed by setting degenerated_to_sgd=False (in your case, you can simply delete the else: self.x = x - self.alpha * mhat part).

More discussions can be found at: #54 (comment).
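
Concretely, for the numpy class above, the suggestion amounts to replacing the if/else at the end of step() with the following sketch (only the else branch is removed):

        if rho >= 5:
            l = np.sqrt(1 - beta2k) / (np.sqrt(self.v) + self.eps)  # NOQA
            num = (rho - 4) * (rho - 2) * rhoinf
            den = (rhoinf - 4) * (rhoinf - 2) * rho
            r = np.sqrt(num/den)
            self.x = x - self.alpha * r * mhat * l
        # no else: leave x unchanged while rho < 5, rather than taking the
        # un-rectified SGD-like step that diverges here
        return x, f, g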


brandondube commented on August 20, 2024

Thanks, that was it. I made a different choice: scaling g by the inverse of its norm. This widens the range of stable learning rates, although not by much.

        else:
            gsq = g * g  # squared gradient, as in the v update
            invgnorm = 1 / np.sqrt(gsq.sum())
            self.x = x - self.alpha * invgnorm * g
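
In effect the warmup iterations become normalized gradient descent: each step has length alpha regardless of the gradient's magnitude, so the early high-variance steps stay bounded.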

