Giter VIP home page Giter VIP logo

retinaface-pytorch's Introduction

Retinaface:人脸检测模型在Pytorch当中的实现


目录

  1. 仓库更新 Top News
  2. 性能情况 Performance
  3. 所需环境 Environment
  4. 文件下载 Download
  5. 预测步骤 How2predict
  6. 训练步骤 How2train
  7. 评估步骤 Eval
  8. 参考资料 Reference

Top News

2022-03:进行了大幅度的更新,支持step、cos学习率下降法、支持adam、sgd优化器选择、支持学习率根据batch_size自适应调整。
BiliBili视频中的原仓库地址为:https://github.com/bubbliiiing/retinaface-pytorch/tree/bilibili

2020-09:仓库创建,支持模型训练,大量的注释,多个主干的选择,多个可调整参数。

性能情况

训练数据集 权值文件名称 测试数据集 输入图片大小 Easy Medium Hard
Widerface-Train Retinaface_mobilenet0.25.pth Widerface-Val 1280x1280 89.76% 86.96% 74.69%
Widerface-Train Retinaface_resnet50.pth Widerface-Val 1280x1280 94.72% 93.13% 84.48%

所需环境

pytorch==1.2.0

文件下载

训练所需的Retinaface_resnet50.pth等文件可以在百度云下载。
链接: https://pan.baidu.com/s/1Jt9Bo2UVP03bmEMuUpk_9Q 提取码: qknw

数据集可以在如下连接里下载。
链接: https://pan.baidu.com/s/1bsgay9iMihPlAKE49aWNTA 提取码: bhee

预测步骤

a、使用预训练权重

  1. 下载完库后解压,运行predict.py,输入
img/timg.jpg
  1. 在predict.py里面进行设置可以进行fps测试和video视频检测。

b、使用自己训练的权重

  1. 按照训练步骤训练。
  2. 在retinaface.py文件里面,在如下部分修改model_path和backbone使其对应训练好的文件。
_defaults = {
    "model_path"        : 'model_data/Retinaface_mobilenet0.25.pth',
    "backbone"          : 'mobilenet',
    "confidence"        : 0.5,
    "nms_iou"           : 0.45,
    "cuda"              : True,
    #----------------------------------------------------------------------#
    #   是否需要进行图像大小限制。
    #   开启后,会将输入图像的大小限制为input_shape。否则使用原图进行预测。
    #   可根据输入图像的大小自行调整input_shape,注意为32的倍数,如[640, 640, 3]
    #----------------------------------------------------------------------#
    "input_shape"       : [1280, 1280, 3],
    "letterbox_image"   : True
}
  1. 运行predict.py,输入
img/timg.jpg
  1. 在predict.py里面进行设置可以进行fps测试和video视频检测。

训练步骤

  1. 本文使用widerface数据集进行训练。
  2. 可通过上述百度网盘下载widerface数据集。
  3. 覆盖根目录下的data文件夹。
  4. 根据自己需要选择从头开始训练还是在已经训练好的权重下训练,需要修改train.py文件下的代码,在训练时需要注意backbone和权重文件的对应。 使用mobilenet为主干特征提取网络的示例如下:
    从头开始训练需要将pretrained设置为True,并且注释train.py里面的权值载入部分:
backbone = "mobilenet"
#-------------------------------#
#   是否使用主干特征提取网络
#   的预训练权重
#-------------------------------#
pretrained = True
model = RetinaFace(cfg=cfg, pretrained = pretrained).train()

在已经训练好的权重下训练:

backbone = "mobilenet"
#-------------------------------------------#
#   权值文件的下载请看README
#   权值和主干特征提取网络一定要对应
#-------------------------------------------#
model = RetinaFace(cfg=cfg, pretrained = pretrained).train()
model_path = "model_data/Retinaface_mobilenet0.25.pth"
# 加快模型训练的效率
print('Loading weights into state dict...')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path, map_location=device)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) ==  np.shape(v)}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
print('Finished!')
  1. 可以在logs文件夹里面获得训练好的权值文件。

评估步骤

  1. 在retinaface.py文件里面,在如下部分修改model_path和backbone使其对应训练好的文件。
 = {
    "model_path"        : 'model_data/Retinaface_mobilenet0.25.pth',
    "backbone"          : 'mobilenet',
    "confidence"        : 0.5,
    "nms_iou"           : 0.45,
    "cuda"              : True,
    #----------------------------------------------------------------------#
    #   是否需要进行图像大小限制。
    #   开启后,会将输入图像的大小限制为input_shape。否则使用原图进行预测。
    #   可根据输入图像的大小自行调整input_shape,注意为32的倍数,如[640, 640, 3]
    #----------------------------------------------------------------------#
    "input_shape"       : [1280, 1280, 3],
    "letterbox_image"   : True
}
  1. 下载好百度网盘上上传的数据集,其中包括了验证集,解压在根目录下。
  2. 运行evaluation.py即可开始评估。

Reference

https://github.com/biubug6/Pytorch_Retinaface

retinaface-pytorch's People

Contributors

bubbliiiing avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar

retinaface-pytorch's Issues

训练数据集压缩包文件损坏

训练数据集压缩包文件下载后无法解压,提示压缩文件损坏了,强行解压之后发现里面的图片是不全的,请问可以重新再上传一下数据集文件吗?谢谢您!

数据集的大小

你好,请问如果用自己的数据大约需要多少图片能达到挺好的效果,loss最低能降到多少?

640x640输入的准确率有下降

  1. 使用1280x1280的输入 + cpu, 基本能复现大佬的数据
    ==================== Results ====================
    Easy Val AP: 0.897754843526703
    Medium Val AP: 0.8698380489389713
    Hard Val AP: 0.747567637370379
    =================================================

  2. 修改为640x640输入+ cpu后, 准确率下降了不少
    ==================== Results ====================
    Easy Val AP: 0.8683975540701973
    Medium Val AP: 0.7995705363978524
    Hard Val AP: 0.5136680165318652
    =================================================

批量图片检测人脸

@bubbliiiing ,您好,代码考虑支持批量图片检测人脸吗?如果支持批量图片检测人脸,retinaface.py下的detect_img该怎么改动?

随着训练,显存占用增大?

哈喽,老师~
请问您有遇到过随着训练epoch的增加,显存逐渐增大的问题嘛?
我更换了backbone后,会不断出现这样的情况~~

关于从训练

问题描述:

不使用预训练模型,从头开始训练,请问要达到您提供的精度需要调整什么参数,我目前没有达到项目给出的预训练精度。

期待你的回复。

祝好!

人脸库质量

请问现在我有同一个人的十张人脸照片,我想找出哪张人脸是质量最好的、最适合放进人脸库的照片。如何实现呢?我的想法是通过五个关键点的位置来对比,有没有其他更好的方法呢

从头训练

up您好,您从头训练mobilenet-retinaface的时候是直接使用wider-face训练的吗,有没有先训练backbone的预训练模型?我从头训练发现模型达不到您所列的精度,求指教

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.