Giter VIP home page Giter VIP logo

facenet-keras's Introduction

Facenet:人脸识别模型在Keras当中的实现


目录

  1. 仓库更新 Top News
  2. 相关仓库 Related code
  3. 性能情况 Performance
  4. 所需环境 Environment
  5. 注意事项 Attention
  6. 文件下载 Download
  7. 预测步骤 How2predict
  8. 训练步骤 How2train
  9. 参考资料 Reference

Top News

2022-03:进行了大幅度的更新,支持step、cos学习率下降法、支持adam、sgd优化器选择、支持学习率根据batch_size自适应调整。
BiliBili视频中的原仓库地址为:https://github.com/bubbliiiing/facenet-keras/tree/bilibili

2021-02:创建仓库,支持模型训练,大量的注释,多个可调整参数,lfw数据集评估等。

相关仓库

模型 路径
facenet https://github.com/bubbliiiing/facenet-keras
arcface https://github.com/bubbliiiing/arcface-keras
retinaface https://github.com/bubbliiiing/retinaface-keras
facenet + retinaface https://github.com/bubbliiiing/facenet-retinaface-keras

性能情况

训练数据集 权值文件名称 测试数据集 输入图片大小 accuracy
CASIA-WebFace facenet_mobilenet.h5 LFW 160x160 97.86%
CASIA-WebFace facenet_inception_resnetv1.h5 LFW 160x160 99.02%

所需环境

tensorflow-gpu==1.13.1
keras==2.1.5

文件下载

已经训练好的facenet_mobilenet.h5和facenet_inception_resnetv1.h5可以在百度网盘下载。
链接: https://pan.baidu.com/s/1XzwLpU3zPFW7QK045UDgkQ 提取码: 596k

训练用的CASIA-WebFaces数据集以及评估用的LFW数据集可以在百度网盘下载。
链接: https://pan.baidu.com/s/1qMxFR8H_ih0xmY-rKgRejw 提取码: bcrq

预测步骤

a、使用预训练权重

  1. 下载完库后解压,在model_data文件夹里已经有了facenet_mobilenet.h5,可直接运行predict.py输入:
img\1_001.jpg
img\1_002.jpg
  1. 也可以在百度网盘下载facenet_inception_resnetv1.h5,放入model_data,修改facenet.py文件的model_path后,输入:
img\1_001.jpg
img\1_002.jpg

b、使用自己训练的权重

  1. 按照训练步骤训练。
  2. 在facenet.py文件里面,在如下部分修改model_path和backbone使其对应训练好的文件;model_path对应logs文件夹下面的权值文件,backbone对应主干特征提取网络
_defaults = {
    "model_path"    : "model_data/facenet_mobilenet.h5",
    "input_shape"   : [160,160,3],
    "backbone"      : "mobilenet"
}
  1. 运行predict.py,输入
img\1_001.jpg
img\1_002.jpg

训练步骤

  1. 本文使用如下格式进行训练。
|-datasets
    |-people0
        |-123.jpg
        |-234.jpg
    |-people1
        |-345.jpg
        |-456.jpg
    |-...
  1. 下载好数据集,将训练用的CASIA-WebFaces数据集以及评估用的LFW数据集,解压后放在根目录。
  2. 在训练前利用txt_annotation.py文件生成对应的cls_train.txt。
  3. 利用train.py训练facenet模型,训练前,根据自己的需要选择backbone,model_path和backbone一定要对应。
  4. 运行train.py即可开始训练。

评估步骤

  1. 下载好评估数据集,将评估用的LFW数据集,解压后放在根目录
  2. 在eval_LFW.py设置使用的主干特征提取网络和网络权值。
  3. 运行eval_LFW.py来进行模型准确率评估。

Reference

https://github.com/davidsandberg/facenet
https://github.com/timesler/facenet-pytorch

facenet-keras's People

Contributors

bubbliiiing avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar

facenet-keras's Issues

up太棒了!

从b站过来的,看到更新立马上Github来了哈哈,大赞大赞

怎么将.h5转换成.tflite模型?

作者您好!

​ 我目前已经训练好模型,并成功运行predict.py,但是您的代码保存的是.h5模型,我想把他转换成.tflite模型,于是我自己写了一个转换脚本(freeze.py):

import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils import CustomObjectScope

def relu6(x):
  return K.relu(x, max_value=6)

with CustomObjectScope({'relu6': relu6}):
	converter = tf.lite.TFLiteConverter.from_keras_model_file("logs/ep010-loss0.387-val_loss0.758.h5")
	tflite_model = converter.convert()
	open("./facenet_keras.tflite", "wb").write(tflite_model)

运行此脚本报错,报错信息如下:

Traceback (most recent call last):
  File "freeze.py", line 42, in <module>
    converter = tf.lite.TFLiteConverter.from_keras_model_file("logs/ep010-loss0.387-val_loss0.758.h5")
  File "E:\Anaconda3\envs\facenet-keras\lib\site-packages\tensorflow\lite\python\lite.py", line 370, in from_keras_model_file
    keras_model = _keras.models.load_model(model_file)
  File "E:\Anaconda3\envs\facenet-keras\lib\site-packages\tensorflow\python\keras\engine\saving.py", line 232, in load_model
    raise ValueError('No model found in config file.')
ValueError: No model found in config file.

经过我一番百度之后,发现在代码里边,是只保存了模型的权重,原来代码如下:

checkpoint = ModelCheckpoint(os.path.join(save_dir, "ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5"), monitor = 'val_loss', save_weights_only = True, save_best_only = False, period = save_period)

from_keras_model_file这个函数会调用load_model这个函数,所以载入只有权重的模型会报错,于是我作出了以下修改,将 save_weights_only = True改为 save_weights_only = False

checkpoint = ModelCheckpoint(os.path.join(save_dir, "ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5"), monitor = 'val_loss', save_weights_only = False, save_best_only = False, period = save_period)

重新训练了后,得到了若干.h5模型,我再次运行此脚本,得到的报错如下:

Traceback (most recent call last):
  File "freeze.py", line 42, in <module>
    converter = tf.lite.TFLiteConverter.from_keras_model_file("logs/ep009-loss0.231-val_loss0.473.h5")
  File "E:\Anaconda3\envs\facenet-keras\lib\site-packages\tensorflow\lite\python\lite.py", line 370, in from_keras_model_file
    keras_model = _keras.models.load_model(model_file)
  File "E:\Anaconda3\envs\facenet-keras\lib\site-packages\tensorflow\python\keras\engine\saving.py", line 266, in load_model
    sample_weight_mode=sample_weight_mode)
  File "E:\Anaconda3\envs\facenet-keras\lib\site-packages\tensorflow\python\training\checkpointable\base.py", line 442, in _method_wrapper
    method(self, *args, **kwargs)
  File "E:\Anaconda3\envs\facenet-keras\lib\site-packages\tensorflow\python\keras\engine\training.py", line 273, in compile
    loss_functions.append(training_utils.get_loss_function(loss.get(name)))
  File "E:\Anaconda3\envs\facenet-keras\lib\site-packages\tensorflow\python\keras\engine\training_utils.py", line 873, in get_loss_function
    return losses.get(loss)
  File "E:\Anaconda3\envs\facenet-keras\lib\site-packages\tensorflow\python\keras\losses.py", line 594, in get
    return deserialize(identifier)
  File "E:\Anaconda3\envs\facenet-keras\lib\site-packages\tensorflow\python\keras\losses.py", line 585, in deserialize
    printable_module_name='loss function')
  File "E:\Anaconda3\envs\facenet-keras\lib\site-packages\tensorflow\python\keras\utils\generic_utils.py", line 212, in deserialize_keras_object
    function_name)
ValueError: Unknown loss function:_triplet_loss

这个报错说是找到损失函数。于是我尝试在脚本里边先load_model,再将这个model传给from_keras_model_file,代码如下:

import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils import CustomObjectScope
from keras.models import load_model

from nets.facenet_training import triplet_loss

def relu6(x):
  return K.relu(x, max_value=6)

with CustomObjectScope({'relu6': relu6}):
    model = load_model("logs/ep009-loss0.231-val_loss0.473.h5",custom_objects={'relu6': relu6,'triplet_loss':triplet_loss})
    converter = tf.lite.TFLiteConverter.from_keras_model_file(model)
    tflite_model = converter.convert()
    open("./facenet_keras.tflite", "wb").write(tflite_model)

得到的报错依然是这个:

Traceback (most recent call last):
  File "freeze.py", line 45, in <module>
    model = load_model("logs/ep009-loss0.231-val_loss0.473.h5",custom_objects={'relu6': relu6,'triplet_loss':triplet_loss})
  File "E:\Anaconda3\envs\facenet-keras\lib\site-packages\keras\models.py", line 274, in load_model
    sample_weight_mode=sample_weight_mode)
  File "E:\Anaconda3\envs\facenet-keras\lib\site-packages\keras\engine\training.py", line 626, in compile
    loss_functions.append(losses.get(loss.get(name)))
  File "E:\Anaconda3\envs\facenet-keras\lib\site-packages\keras\losses.py", line 122, in get
    return deserialize(identifier)
  File "E:\Anaconda3\envs\facenet-keras\lib\site-packages\keras\losses.py", line 114, in deserialize
    printable_module_name='loss function')
  File "E:\Anaconda3\envs\facenet-keras\lib\site-packages\keras\utils\generic_utils.py", line 164, in deserialize_keras_object
    ':' + function_name)
ValueError: Unknown loss function:_triplet_loss

希望作者能指点一下迷津,谢谢!

能否使用自己搭建的backbone网络训练,能否使用自己的数据集

我用自己搭建的mobilenet-v2跑训练,报错Graph disconnected: cannot obtain value for tensor Tensor("input_2:0", shape=(?, 224, 224, 3), dtype=float32) at layer "input_2". The following previous layers were accessed without issue: []

另外我想训练自己的数据集,可以直接替换database文件,并按up主的格式整理数据,然后就可以训练了吗

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.