Cosine Annealing Warm-Restart Learning Rate Schedule

[Intro] This repo provides a PyTorch implementation of a cosine annealing warm-restart learning rate schedule, with support for warmup and resuming training. It also supports custom decay functions, enabling a variety of restart schemes. If you find it useful, please give it a star...

Code: https://github.com/Huangdebo/CAWB

1. Multi-step restarts

Once cawb_steps is set, the scheduler performs multi-step cosine annealing with restarts. At each restart, the starting learning rate is multiplied by a scale factor, step_scale. By adjusting step_scale and epoch_scale you can control whether the learning rate jumps up or down at each restart. You can also let intermediate steps stop short of a full annealing cycle, keeping the learning rate relatively high, to build more complex schedules.
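As an illustration (not the library's code), the effect of step_scale on the peak learning rate of each restart segment can be sketched as follows; `restart_peak_lr` and its arguments are hypothetical names:

```python
# Hypothetical sketch: the peak LR at each cosine restart, assuming the
# starting LR is multiplied by step_scale at every restart boundary.
def restart_peak_lr(base_lr, step_scale, restart_index):
    # restart_index = 0 is the initial cosine segment before any restart
    return base_lr * (step_scale ** restart_index)

# step_scale < 1 makes each restart peak lower; step_scale > 1 makes it higher
peaks = [restart_peak_lr(0.01, 0.7, i) for i in range(3)]
```

With step_scale = 0.7 the restart peaks decay geometrically, which is why the jumps in the schedule shrink over training.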

2. Plain cosine annealing

If cawb_steps is [], the scheduler falls back to plain cosine annealing: the learning rate decays over the full set of epochs according to the chosen lf function.
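For reference, a plain cosine-annealing multiplier of this kind can be written as a standalone LambdaLR-style lambda; the epoch count and the 0.35 floor below are assumed example values, not values from this repo:

```python
import math

epochs = 100   # assumed total epoch count
final = 0.35   # assumed fraction of the initial LR kept at the end

# LambdaLR-style multiplier: 1.0 at epoch 0, decaying smoothly to `final`
lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - final) + final
```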

3. Warmup

Setting warmup_epoch enables learning rate warmup. Once the warmup phase ends, the restart annealing schedule defined by cawb_steps takes over.
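A common warmup shape is a linear ramp; the sketch below is illustrative only (the library's exact warmup curve may differ), and `warmup_factor`/`start` are hypothetical names:

```python
# Linear warmup sketch: the LR multiplier ramps from `start` up to 1.0 over
# `warmup_epoch` epochs; from then on it is 1.0 and the annealing schedule
# alone shapes the learning rate.
def warmup_factor(epoch, warmup_epoch, start=0.1):
    if warmup_epoch <= 0 or epoch >= warmup_epoch:
        return 1.0
    return start + (1.0 - start) * epoch / warmup_epoch
```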

4. Resume

Setting last_epoch enables resuming: training picks up the learning rate schedule exactly where the interrupted run left off.
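Conceptually (an illustration, not the library's internals), resuming just means evaluating the same schedule from last_epoch onward, so the curve continues seamlessly; all constants below are assumed example values:

```python
import math

epochs = 100  # assumed total epoch count
# example cosine-decay multiplier with an assumed 0.35 floor
lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * 0.65 + 0.35

last_epoch = 40  # epoch the interrupted run stopped at (example value)
full = [lf(e) for e in range(epochs)]
resumed = [lf(e) for e in range(last_epoch, epochs)]
assert resumed == full[last_epoch:]  # resumed curve matches the tail exactly
```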

5. Custom decay functions

By supplying a custom decay function lf, you can implement a variety of restart schemes:

# cosine decay (commented out) vs. an active linear decay from 1.0 down to 0.2:
# lf = lambda x, y=opt.epochs: (((1 + math.cos(x * math.pi / y)) / 2) ** 1.0) * 0.9 + 0.1
lf = lambda x, y=opt.epochs: (1.0 - (x / y)) * 0.8 + 0.2
scheduler = CosineAnnealingWarmbootingLR(optimizer, epochs=opt.epochs, step_scale=0.7,
                                         steps=opt.cawb_steps, lf=lf, batchs=len(data), warmup_epoch=0)

6. Experimental results

For this experiment, 10,000 and 1,000 images were randomly selected from COCO2017 as the training and validation sets, respectively. The detector is yolov5s, trained with either the original cosine schedule or the CAWB schedule implemented here.

6.1 Learning rate:

# yolov5 (its one_cycle helper: y1 = 1.0, y2 = final LR fraction, steps = total epochs)
lf = lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

# this repo (CAWB)
lf = lambda x, y=opt.epochs: (((1 + math.cos(x * math.pi / y)) / 2) ** 1.0) * 0.65 + 0.35
scheduler = CosineAnnealingWarmbootingLR(optimizer, epochs=opt.epochs, steps=opt.cawb_steps, step_scale=0.7,
                                         lf=lf, batchs=len(train_loader), warmup_epoch=3, epoch_scale=4.0)

(Figure: learning rate curves of cos vs. CAWB)

6.2 mAP:

6.2.1 cos:

mAP_0.5 = 0.294;  mAP_0.5:0.95 = 0.161

6.2.2 CAWB:

mAP_0.5 = 0.302; mAP_0.5:0.95 = 0.165

7. Conclusion

With the CAWB schedule, both mAP_0.5 and mAP_0.5:0.95 improved slightly, and the upward trend was more pronounced; training for more epochs may yield a larger gain.

Varying CAWB's parameters yields many other learning rate schedules. The point of the abrupt jumps is to increase the chance of the network escaping local optima, so different datasets may suit different schedules. If you try it on other datasets, please open an issue with your results...

Code: https://github.com/Huangdebo/CAWB
