Comments (6)
Okay, we'll try to find the best hyper-parameter settings to achieve 65.04% val. accuracy on the R2B model.
If there's any improvement or news, we'll let you know.
Thanks for your help!
from zoo.
Hi, thanks for raising this. The training script in Zoo is indeed not what we used to produce the results we reported internally - that code is quite tightly coupled to our training infrastructure.
Specifically for the FP baseline (I think it makes sense to get that training working properly before looking at the rest), here are a few things I noticed comparing the logs of our internal run with the default training script in Zoo:
- We used weight decay instead of L2 regularization (LR 0.001, weight decay constant 1e-5)
- We used EfficientNet preprocessing - this is implemented in Zoo (here), but I'm not sure if it's used by default in these training scripts.
- We used input shape 224 x 224 x 3, which may be the default, but just in case.
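To illustrate the first point: weight decay and L2 regularization only coincide for plain SGD; with an adaptive optimizer like Adam they produce different updates. Here is a minimal plain-Python sketch of a single scalar Adam step (illustrative only, not our training code; the `adam_step` helper and its test values are made up, only LR=0.001 and 1e-5 come from this thread):

```python
# Illustrative sketch (not the Zoo training code): one scalar Adam step,
# showing why L2 regularization and decoupled weight decay differ under Adam.
# With L2, the decay term joins the gradient and gets rescaled by Adam's
# adaptive denominator; decoupled decay shrinks the weight directly.
import math

def adam_step(w, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8,
              l2=0.0, weight_decay=0.0):
    if l2:
        g = g + l2 * w                     # L2: decay enters the adaptive update
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)              # bias correction
    v_hat = v / (1 - b2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    if weight_decay:
        w = w - lr * weight_decay * w      # decoupled: applied to the weight itself
    return w, m, v

# Same constants, different resulting weights:
w_l2, _, _ = adam_step(1.0, 0.5, 0.0, 0.0, 1, l2=1e-5)
w_wd, _, _ = adam_step(1.0, 0.5, 0.0, 0.0, 1, weight_decay=1e-5)
```

With SGD (no momentum) the two variants would give identical updates; under Adam they do not, which is why the distinction matters here.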
Thanks for the quick reply!
I've checked what you advised and found that the second and third points were already covered by the default settings.
Regarding the first point (weight decay and L2 regularization), what do you mean by 'using weight decay instead of L2 regularization'?
Do you mean that you used the SGDW or AdamW optimizer proposed in this paper?
Also, it seems like you're suggesting LR=0.001 instead of the 0.1 used in the default setting.
Does that imply that you used the Adam optimizer with LR=0.001?
Thanks,
Best regards,
Hyungjun
No problem!
We used regular Adam as the optimizer. For weight decay, we added the following kernel_constraint to the Larq layers:
```python
import warnings
from typing import Optional

import tensorflow as tf
from larq import utils as lq_utils


@lq_utils.register_keras_custom_object
class WeightDecay(tf.keras.constraints.Constraint):
    def __init__(
        self,
        max_learning_rate: float,
        weight_decay_constant: float,
        optimizer: Optional[tf.keras.optimizers.Optimizer] = None,
    ) -> None:
        """Weight decay constraint which can be passed in as a `kernel_constraint`.

        This allows for applying weight decay correctly with any optimizer. This is
        not the case when using L2 regularization, which is not equivalent to using
        weight decay for anything other than SGD without momentum.

        When using this class, make sure to pass in the optimizer whose learning
        rate variable is updated during training.

        :param max_learning_rate: maximum learning rate, used to normalize the
            current learning rate.
        :param weight_decay_constant: strength of the weight decay.
        :param optimizer: Keras optimizer with an `lr` variable (which can
            optionally be a schedule).
        """
        self.optimizer = optimizer
        self.max_learning_rate = max_learning_rate
        self.weight_decay_constant = weight_decay_constant
        if self.max_learning_rate <= 0:
            warnings.warn(
                "WeightDecay: no weight decay will be applied as the received "
                "learning rate is 0."
            )
            self.multiplier = 0
        else:
            self.multiplier = self.weight_decay_constant / self.max_learning_rate

    def __call__(self, x):
        if isinstance(
            self.optimizer.lr, tf.keras.optimizers.schedules.LearningRateSchedule
        ):
            lr = self.optimizer.lr(self.optimizer.iterations)
        else:
            lr = self.optimizer.lr
        return (1.0 - lr * self.multiplier) * x

    def get_config(self):
        return {
            "max_learning_rate": self.max_learning_rate,
            "weight_decay_constant": self.weight_decay_constant,
        }
```
With max_learning_rate=0.001, weight_decay_constant=1e-5, and optimizer set to our (Adam) optimizer instance.
You may want to try both 0.1 and 0.001 as the learning rate (and max_learning_rate); I can't see which of those we used for the FP ResNet.
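For intuition, the constraint's arithmetic can be checked without TensorFlow: Keras applies the kernel_constraint to the updated kernel after each optimizer step, so the class above rescales the weights by 1 - lr * weight_decay_constant / max_learning_rate. A plain-Python sketch of just that arithmetic (illustrative, not the Larq code; the `apply_weight_decay` helper is made up):

```python
# Sketch of what the WeightDecay constraint does after each optimizer step:
#   w <- (1 - lr * weight_decay_constant / max_learning_rate) * w
# At lr == max_learning_rate this is the usual decoupled decay
#   w <- (1 - weight_decay_constant) * w,
# and a lower current lr decays the weights proportionally less.
def apply_weight_decay(w, lr, max_learning_rate=1e-3, weight_decay_constant=1e-5):
    multiplier = weight_decay_constant / max_learning_rate
    return (1.0 - lr * multiplier) * w

full = apply_weight_decay(2.0, lr=1e-3)   # full decay: 2.0 * (1 - 1e-5)
half = apply_weight_decay(2.0, lr=5e-4)   # half the decay at half the lr
```

This scaling with the current learning rate is why the class needs access to the live optimizer (or its schedule) rather than a fixed constant.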
Based on your comments, we tried the Adam optimizer for the ResNet18 FP model.
First, we trained the model with the regular Adam optimizer and L2 regularization. We didn't add the WeightDecay class you provided and just ran the training with LR=0.001 and WD=1e-5. That resulted in 70.01% val. accuracy, which is much higher than before.
Then, we tried weight decay instead of L2 regularization, as you suggested. It actually seems like your team followed this paper in using the Adam optimizer with a fixed weight decay (instead of L2 regularization). Can you confirm this?
Anyway, we added the code you provided to real_to_bin_nets.py and modified the ResNet18FPFactory like this:
```python
@factory
class ResNet18FPFactory(ResNet18Factory):
    model_name = Field("resnet_fp")
    input_quantizer = None
    kernel_quantizer = None

    optimizer = lambda self: tf.keras.optimizers.Adam(
        CosineDecayWithWarmup(
            max_learning_rate=self.learning_rate,
            warmup_steps=self.warmup_duration * self.steps_per_epoch,
            decay_steps=(self.epochs - self.warmup_duration) * self.steps_per_epoch,
        )
    )
    kernel_constraint = WeightDecay(
        max_learning_rate=1e-3, weight_decay_constant=1e-5, optimizer=optimizer
    )
```
Since we were not clear about how to pass the Adam optimizer instance declared here to the WeightDecay class's optimizer argument, we ended up defining the optimizer again right before the kernel_constraint.
We are not sure if this is the correct way to use weight decay with the Adam optimizer.
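What we were trying to achieve is roughly the following - both fields referencing one shared optimizer instance. This is a plain-Python sketch with stand-in classes (FakeAdam and ResNet18FPConfig are made up, and this is not Zookeeper's actual Field machinery):

```python
from functools import cached_property

class WeightDecay:
    """Stand-in for the constraint above: just records the optimizer it was given."""
    def __init__(self, max_learning_rate, weight_decay_constant, optimizer):
        self.optimizer = optimizer
        self.multiplier = weight_decay_constant / max_learning_rate

class FakeAdam:
    """Stand-in for tf.keras.optimizers.Adam in this sketch."""
    def __init__(self, lr):
        self.lr = lr

class ResNet18FPConfig:
    # cached_property ensures both fields see the *same* optimizer object,
    # instead of constructing a second instance for the constraint.
    @cached_property
    def optimizer(self):
        return FakeAdam(lr=1e-3)

    @cached_property
    def kernel_constraint(self):
        return WeightDecay(
            max_learning_rate=1e-3,
            weight_decay_constant=1e-5,
            optimizer=self.optimizer,  # same instance, so lr updates stay visible
        )
```

With this pattern, `cfg.kernel_constraint.optimizer is cfg.optimizer` holds, so the constraint would read the learning rate actually used during training.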
With this configuration, we achieved 67.85% val. accuracy, which is lower than the case without weight decay.
Please let me know if we've done anything wrong.
> First, we trained the model with the regular Adam optimizer and L2 regularization. We didn't add the WeightDecay class you provided and just ran the training with LR=0.001 and WD=1e-5. That resulted in 70.01% val. accuracy, which is much higher than before.
Ah, nice - since this is quite close to the expected 70.32%, I'd recommend just using this setup (Adam + L2). Our use of the WeightDecay class instead of L2 actually does not follow the paper, so since that didn't work for you anyway, it might be better to stick with L2.
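For reference, the Adam + L2 combination can be wired up in Keras roughly like this (a sketch, not our exact training code; only LR=0.001 and the 1e-5 constant come from this thread, and the layers are placeholders rather than the real ResNet18):

```python
# Hedged sketch of the "Adam + L2" setup: an L2 penalty on the kernels plus
# plain Adam. Constants from this thread; the model body is illustrative only.
import tensorflow as tf

regularizer = tf.keras.regularizers.l2(1e-5)

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(
        64, 7, strides=2, padding="same",
        kernel_regularizer=regularizer,   # L2 penalty, not decoupled decay
        input_shape=(224, 224, 3),
    ),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1000, kernel_regularizer=regularizer),
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss="sparse_categorical_crossentropy",
)
```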
As a note: we provided the classes in multi_stage_experiments.py primarily as an example of how to use Zoo's multi-stage infrastructure to implement a set of training steps - it is not the exact setup we used to train the pretrained weights available for our implementation of Real-to-Binary nets. Therefore, these exact steps are not expected to reproduce our weights or reported accuracies. (I've added this note to the code in #234.)
Because of some internal (cluster) infrastructure changes between when we ran those old experiments and now, it's quite difficult to recover these exact settings from the experiment logs that are still available (which is not all of them). Helping you further with reproducing R2B's exact training results would require us to redo these runs internally, for which I'm afraid we don't currently have the resources. When the authors share their code (brais-martinez/real2binary), hopefully you'll be able to get some insights from there.
Anyhow, best of luck with your reproduction! If you do figure it out and notice any obvious mistakes or easy fixes in our example training code, we'd be very happy to review and merge any PRs. :)