
Comments (6)

Hyungjun-K1m commented on July 24, 2024

Okay, we'll try to find the best hyper-parameter settings to achieve 65.04% val. accuracy on the R2B model.
If there's any improvement or news, we'll let you know.

Thanks for your help!


leonoverweel commented on July 24, 2024

Hi, thanks for raising this. The training script on Zoo is indeed not what we used to produce the results we reported internally; our internal training code is quite tightly coupled to our infrastructure.

Specifically for the FP baseline (I think it makes sense to get that training properly before looking at the rest), here are a few things I noticed when comparing the logs of our internal run with the default training script on Zoo:

  • We used weight decay instead of L2 regularization (LR 0.001, weight decay constant 1e-5); see the sketch after this list for the distinction.

  • We used EfficientNet preprocessing - this is implemented in Zoo (here), but I'm not sure if it's used by default in these training scripts.

  • We used input shape 224 x 224 x 3, which may be the default, but just in case.
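To illustrate the distinction in the first point (a minimal sketch only, not our internal training code): with L2 regularization the penalty is added to the loss and therefore gets rescaled by Adam's adaptive moments, whereas decoupled weight decay shrinks the weights directly each step, independently of the gradient scaling.

import tensorflow as tf

# (a) L2 regularization: the penalty flows through the loss, so Adam's adaptive
#     scaling also affects the decay term.
l2_layer = tf.keras.layers.Dense(
    10, kernel_regularizer=tf.keras.regularizers.l2(1e-5)
)

# (b) Decoupled weight decay: shrink the kernels directly after each optimizer
#     step, independent of the gradient scaling (illustrative helper only).
def apply_weight_decay(model, lr, weight_decay_constant=1e-5):
    for var in model.trainable_variables:
        if "kernel" in var.name:  # decay kernels only, not biases or BN params
            var.assign(var * (1.0 - lr * weight_decay_constant))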


Hyungjun-K1m commented on July 24, 2024

Thanks for the quick reply!

I've checked what you advised and found that the second and third points were already covered by the default settings.
Regarding the first point (weight decay vs. L2 regularization), what do you mean by 'using weight decay instead of L2 regularization'?
Do you mean that you used the SGDW or AdamW optimizer proposed in this paper?
Also, it seems like you suggest using LR=0.001 instead of the 0.1 used in the default setting.
Does that imply that you used the Adam optimizer with LR=0.001?

Thanks,
Best regards,
Hyungjun


leonoverweel commented on July 24, 2024

No problem!

We used regular Adam as the optimizer. For weight decay, we added the following kernel_constraint to the Larq layers:

import warnings
from typing import Optional

import tensorflow as tf
from larq import utils as lq_utils

@lq_utils.register_keras_custom_object
class WeightDecay(tf.keras.constraints.Constraint):
    def __init__(
        self,
        max_learning_rate: float,
        weight_decay_constant: float,
        optimizer: Optional[tf.keras.optimizers.Optimizer] = None,
    ) -> None:
        """Weight decay constraint which can be passed in as a `kernel_constraint`.
        This allows for applying weight decay correctly with any optimizer. This is not
        the case when using l2 regularization, which is not equivalent to using weight
        decay for anything other than SGD without momentum.
        When using this class, make sure to pass in the optimizer whose learning rate
        is updated during training.
        :param max_learning_rate: maximum learning rate, used to normalize the current
            learning rate.
        :param optimizer: keras optimizer that has a lr variable (which can optionally be a schedule).
        :param weight_decay_constant: strength of the weight decay.
        """
        self.optimizer = optimizer if optimizer is not None else max_learning_rate
        self.max_learning_rate = max_learning_rate
        self.weight_decay_constant = weight_decay_constant

        if self.max_learning_rate <= 0:
            warnings.warn(
                "WeightDecay: no weight decay will be applied as the received learning rate is 0."
            )
            self.multiplier = 0
        else:
            self.multiplier = self.weight_decay_constant / self.max_learning_rate

    def __call__(self, x):
        if isinstance(
            self.optimizer.lr, tf.keras.optimizers.schedules.LearningRateSchedule,
        ):
            lr = self.optimizer.lr(self.optimizer.iterations)
        else:
            lr = self.optimizer.lr
        return (1.0 - lr * self.multiplier) * x

    def get_config(self):
        return {
            "max_learning_rate": self.max_learning_rate,
            "weight_decay_constant": self.weight_decay_constant,
        }

With max_learning_rate=0.001, weight_decay_constant=1e-5, and optimizer set to our (Adam) optimizer instance.
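For example, a minimal sketch of how such a constraint might be wired up, using the WeightDecay class above (the layer and its arguments are placeholders, not our actual model code):

import larq as lq
import tensorflow as tf

# Sketch only: share the same Adam instance between the constraint and model.compile().
max_lr = 1e-3
optimizer = tf.keras.optimizers.Adam(learning_rate=max_lr)

conv = lq.layers.QuantConv2D(
    64,
    kernel_size=3,
    kernel_constraint=WeightDecay(
        max_learning_rate=max_lr,
        weight_decay_constant=1e-5,
        optimizer=optimizer,  # the same instance later passed to model.compile()
    ),
)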

You may want to try both 0.1 and 0.001 as learning rate (and max_learning_rate); I can't see which of those we used for the FP ResNet.


Hyungjun-K1m commented on July 24, 2024

Based on your comments, we tried the Adam optimizer for the ResNet18 FP model.
First, we trained the model with the regular Adam optimizer and L2 regularization. We didn't add the WeightDecay class you provided and just ran the training with LR=0.001, WD=1e-5. That resulted in 70.01% val. accuracy, which is much higher than before.
Then, we tried weight decay instead of L2 regularization as you suggested. Actually, it seems like your team indeed followed this paper in using the Adam optimizer with decoupled weight decay (instead of L2 regularization). Can you confirm this?

Anyway, we added the code you provided to real_to_bin_nets.py and also modified the ResNet18FPFactory like this:

@factory
class ResNet18FPFactory(ResNet18Factory):
    model_name = Field("resnet_fp")
    input_quantizer = None
    kernel_quantizer = None
    optimizer = lambda self: tf.keras.optimizers.Adam(
        CosineDecayWithWarmup(
            max_learning_rate=self.learning_rate,
            warmup_steps=self.warmup_duration * self.steps_per_epoch,
            decay_steps=(self.epochs - self.warmup_duration) * self.steps_per_epoch,
        )
    )
    kernel_constraint = WeightDecay(
        max_learning_rate=1e-3, weight_decay_constant=1e-5, optimizer=optimizer
    )

Since we were not clear about how to pass the Adam optimizer instance declared here to the WeightDecay class argument, we defined the optimizer again right before the kernel_constraint.
We are not sure if this approach is the correct way to use weight decay with the Adam optimizer.
With this configuration, we achieved 67.85% val. accuracy, which is lower than the case without weight decay.
Please let me know if we've done anything wrong.


leonoverweel commented on July 24, 2024

First, we trained the model with the regular Adam optimizer and L2 regularization. We didn't add the WeightDecay class you provided and just ran the training with LR=0.001, WD=1e-5. That resulted in 70.01% val. accuracy, which is much higher than before.

Ah, nice - since this is quite close to the expected 70.32%, I'd recommend just using this setup (Adam + L2). Our use of the WeightDecay class instead of L2 actually does not follow the paper, so since that didn't work as well for you anyway, it might be better to just stick with L2.
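For reference, that recommended setup looks roughly like this (a sketch only; build_resnet18_fp is a hypothetical helper standing in for however you construct the FP ResNet-18 with 224 x 224 x 3 inputs):

import tensorflow as tf

# Sketch of the Adam + L2 baseline discussed above (LR 1e-3, L2 constant 1e-5).
# `build_resnet18_fp` is a placeholder, not part of Zoo.
model = build_resnet18_fp(
    input_shape=(224, 224, 3),
    kernel_regularizer=tf.keras.regularizers.l2(1e-5),
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)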


As a note: we provided the classes in multi_stage_experiments.py primarily as an example of how to use Zoo's multi-stage infrastructure to implement a set of training steps - it is not the exact setup we used to train the pretrained weights available for our implementation of Real-to-Binary nets. Therefore, these exact steps are not expected to reproduce our weights or reported accuracies. (I've added this note to the code in #234.)

Because of some internal (cluster) infrastructure changes between when we ran those old experiments and now, it's quite difficult to recover the exact settings from the experiment logs that are still available (which is not all of them). Helping you further with reproducing R2B's exact training results would require us to redo these runs internally, for which I'm afraid we don't currently have the resources. When the authors share their code (brais-martinez/real2binary), hopefully you'll be able to get some insights from there.

Anyhow, best of luck with your reproduction! If you do figure it out and notice any obvious mistakes or easy fixes in our example training code, we'd be very happy to review and merge any PRs. :)

