
facenet's Introduction

Face Recognition using TensorFlow

This is a TensorFlow implementation of the face recognizer described in the paper "FaceNet: A Unified Embedding for Face Recognition and Clustering". The project also uses ideas from the paper "A Discriminative Feature Learning Approach for Deep Face Recognition" as well as the paper "Deep Face Recognition" from the Visual Geometry Group at Oxford.

Compatibility

The code is tested using TensorFlow r1.2 under Ubuntu 14.04 with Python 2.7 and Python 3.5. The test cases can be found here and the results can be found here.

News

| Date | Update |
|------------|--------|
| 2017-05-13 | Removed a bunch of older non-slim models. Moved the last bottleneck layer into the respective models. Corrected normalization of Center Loss. |
| 2017-05-06 | Added code to train a classifier on your own images. Renamed facenet_train.py to train_tripletloss.py and facenet_train_classifier.py to train_softmax.py. |
| 2017-03-02 | Added pretrained models that generate 128-dimensional embeddings. |
| 2017-02-22 | Updated to TensorFlow r1.0. Added Continuous Integration using Travis-CI. |
| 2017-02-03 | Added models where only trainable variables have been stored in the checkpoint. These are therefore significantly smaller. |
| 2017-01-27 | Added a model trained on a subset of the MS-Celeb-1M dataset. The LFW accuracy of this model is around 0.994. |
| 2017-01-02 | Updated the code to run with TensorFlow r0.12. Not sure if it runs with older versions of TensorFlow though. |

Pre-trained models

| Model name | LFW accuracy | Training dataset | Architecture |
|-----------------|-------|---------------|---------------------|
| 20170511-185253 | 0.987 | CASIA-WebFace | Inception ResNet v1 |
| 20170512-110547 | 0.992 | MS-Celeb-1M | Inception ResNet v1 |

Inspiration

The code is heavily inspired by the OpenFace implementation.

Training data

The CASIA-WebFace dataset has been used for training. This training set consists of a total of 453,453 images over 10,575 identities after face detection. Some performance improvement has been seen when the dataset is filtered before training; more information about how this was done will come later. The best performing model has been trained on a subset of the MS-Celeb-1M dataset. This dataset is significantly larger but also contains significantly more label noise, and therefore it is crucial to apply dataset filtering to it.

Pre-processing

Face alignment using MTCNN

One problem with the above approach seems to be that the Dlib face detector misses some of the hard examples (partial occlusion, silhouettes, etc.). This makes the training set too "easy", which causes the model to perform worse on other benchmarks. To solve this, other face landmark detectors have been tested. One face landmark detector that has proven to work very well in this setting is the Multi-task CNN (MTCNN). A Matlab/Caffe implementation can be found here, and it has been used for face alignment with very good results. A Python/TensorFlow implementation of MTCNN can be found here. This implementation does not give identical results to the Matlab/Caffe implementation, but the performance is very similar.
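For reference, a minimal detection sketch using the Python/TensorFlow port's align/detect_face module is shown below. The parameter values are commonly used defaults and the zero image is a placeholder; both are assumptions for illustration rather than anything prescribed by the repo.

    import numpy as np
    import tensorflow as tf
    import align.detect_face as detect_face  # MTCNN port shipped in the facenet repo

    # Commonly used detection parameters (assumptions, tune for your data)
    minsize = 20                 # minimum face size in pixels
    threshold = [0.6, 0.7, 0.7]  # P-Net, R-Net and O-Net score thresholds
    factor = 0.709               # image pyramid scale factor

    with tf.Graph().as_default():
        sess = tf.Session()
        # Build the three cascaded networks and load their pretrained weights
        pnet, rnet, onet = detect_face.create_mtcnn(sess, None)

    # Placeholder standing in for a real RGB image (H x W x 3)
    img = np.zeros((480, 640, 3), dtype=np.uint8)
    # bounding_boxes has one row per detected face: [x1, y1, x2, y2, score];
    # points holds five facial landmarks per face
    bounding_boxes, points = detect_face.detect_face(img, minsize, pnet, rnet, onet, threshold, factor)
    print('Detected %d face(s)' % bounding_boxes.shape[0])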

Running training

Currently, the best results are achieved by training the model as a classifier with the addition of Center loss. Details on how to train a model as a classifier can be found on the page Classifier training of Inception-ResNet-v1.
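For background, center loss keeps one running center per class and penalizes the squared distance between each embedding and its class center; the classifier is then trained on the softmax cross-entropy plus this term with a small weight. Below is a minimal sketch in the spirit of the repo's facenet.center_loss; the exact signature and update details may differ.

    import tensorflow as tf

    def center_loss(features, labels, alfa, nrof_classes):
        # Keep one non-trainable center per class; each center is updated as a
        # moving average (rate alfa) of the features assigned to that class
        nrof_features = features.get_shape()[1]
        centers = tf.get_variable('centers', [nrof_classes, nrof_features], dtype=tf.float32,
            initializer=tf.constant_initializer(0), trainable=False)
        labels = tf.reshape(labels, [-1])
        # Look up the center of each sample's class
        centers_batch = tf.gather(centers, labels)
        # Move the centers toward this batch's features
        diff = (1 - alfa) * (centers_batch - features)
        centers = tf.scatter_sub(centers, labels, diff)
        with tf.control_dependencies([centers]):
            # Penalize the squared distance between each feature and its class center
            loss = tf.reduce_mean(tf.square(features - centers_batch))
        return loss, centers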

Pre-trained model

Inception-ResNet-v1 model

A couple of pretrained models are provided. They are trained using softmax loss with the Inception-ResNet-v1 model. The datasets have been aligned using MTCNN.
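To use one of these models for inference, the general pattern is to load the graph and feed aligned face crops through it. A hedged sketch follows, assuming the model directory has been downloaded and extracted next to the script and that the graph exposes input:0, embeddings:0 and phase_train:0 tensors (check your model if in doubt).

    import numpy as np
    import tensorflow as tf
    import facenet  # from the repo's src directory

    with tf.Graph().as_default():
        with tf.Session() as sess:
            # Path to the downloaded model directory (or a frozen .pb file)
            facenet.load_model('20170512-110547')
            # Tensor names assumed from the repo's graphs
            images_placeholder = tf.get_default_graph().get_tensor_by_name('input:0')
            embeddings = tf.get_default_graph().get_tensor_by_name('embeddings:0')
            phase_train_placeholder = tf.get_default_graph().get_tensor_by_name('phase_train:0')
            # Dummy batch standing in for MTCNN-aligned, prewhitened 160x160 RGB faces
            images = np.zeros((2, 160, 160, 3), dtype=np.float32)
            emb = sess.run(embeddings, feed_dict={images_placeholder: images,
                                                  phase_train_placeholder: False})
            print(emb.shape)  # (2, 128) for the 128-dimensional models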

Performance

The accuracy on LFW for the model 20170512-110547 is 0.992±0.003. A description of how to run the test can be found on the page Validate on LFW.
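For intuition, the LFW evaluation classifies face pairs as same/different by thresholding the distance between their embeddings, with the threshold chosen by cross-validation. A simplified, single-fold sketch with hypothetical stand-in data (not the real LFW pairs):

    import numpy as np

    def pair_accuracy(dist, actual_issame, threshold):
        # A pair is predicted "same person" when the embedding distance is below the threshold
        return np.mean((dist < threshold) == actual_issame)

    # Hypothetical stand-ins for the squared distances and labels of the LFW pairs
    dist = np.random.rand(6000) * 4.0
    actual_issame = np.random.rand(6000) < 0.5
    # Sweep candidate thresholds, as the evaluation does within each fold
    thresholds = np.arange(0, 4, 0.01)
    accs = [pair_accuracy(dist, actual_issame, t) for t in thresholds]
    print('best accuracy %.3f at threshold %.2f' % (max(accs), thresholds[int(np.argmax(accs))]))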


facenet's Issues

facenet architecture

FaceNet has two network architectures:

  1. inception_resnet_v1

  2. inception_resnet_v2
    def inception_resnet_v1(inputs, is_training=True,
                            dropout_keep_prob=0.8,
                            bottleneck_layer_size=128,
                            reuse=None,
                            scope='InceptionResNetV1'):

    bottleneck_layer_size is the size of the fully connected output layer, which is also the embedding size.

    # Build the computation graph; prelogits is the output of the last layer
    prelogits, _ = network.inference(image_batch, args.keep_probability,
        phase_train=phase_train_placeholder, bottleneck_layer_size=args.embedding_size,
        weight_decay=args.weight_decay)
    # L2-normalize the final output; this normalized vector is the image's embedding
    embeddings = tf.nn.l2_normalize(prelogits, 1, 1e-10, name='embeddings')
    # Split embeddings into anchor, positive and negative and calculate triplet loss
    anchor, positive, negative = tf.unstack(tf.reshape(embeddings, [-1, 3, args.embedding_size]), 3, 1)
    # Compute the triplet loss from those three parts
    triplet_loss = facenet.triplet_loss(anchor, positive, negative, args.alpha)
    # Define the learning-rate schedule for the optimizer
    learning_rate = tf.train.exponential_decay(learning_rate_placeholder, global_step,
        args.learning_rate_decay_epochs*args.epoch_size, args.learning_rate_decay_factor,
        staircase=True)
    tf.summary.scalar('learning_rate', learning_rate)

    # Calculate the total losses
    # Collect the regularization losses
    regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    # The total loss is the triplet loss plus the regularization losses
    total_loss = tf.add_n([triplet_loss] + regularization_losses, name='total_loss')

    # Build a Graph that trains the model with one batch of examples and updates
    # the model parameters, using the optimizer and loss defined above
    train_op = facenet.train(total_loss, global_step, args.optimizer,
        learning_rate, args.moving_average_decay, tf.global_variables())

FaceNet training

FaceNet uses triplet loss to train the network: it minimizes the distance between an anchor and a positive sample of the same identity, and maximizes the distance between the anchor and a negative sample of a different identity.
While the network is training, the embeddings of the anchor, positive and negative samples change as well, so the triplets have to be regenerated. There are two ways to do this. The first is to generate triplets offline every n steps, using the most recent network checkpoint and computing the argmin and argmax on a subset of the data. The second is to generate triplets online, by selecting the hard positive/negative exemplars from within a mini-batch. FaceNet chooses the second: for every mini-batch it generates triplets according to the current embeddings, calculates the triplet loss, and updates the embeddings.
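For a triplet (a, p, n) the loss from the paper is max(||f(a)-f(p)||^2 - ||f(a)-f(n)||^2 + alpha, 0), averaged over the triplets in the batch. A minimal sketch of what facenet.triplet_loss computes (the repo's implementation may differ in details):

    import tensorflow as tf

    def triplet_loss(anchor, positive, negative, alpha):
        # Squared L2 distances within each triplet
        pos_dist = tf.reduce_sum(tf.square(anchor - positive), 1)
        neg_dist = tf.reduce_sum(tf.square(anchor - negative), 1)
        # Hinge at the margin alpha: only triplets violating the margin contribute
        basic_loss = pos_dist - neg_dist + alpha
        return tf.reduce_mean(tf.maximum(basic_loss, 0.0), 0)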

  1. At the start of each mini-batch, FaceNet samples a group of images from the dataset:

    # Sample images from the dataset. Arguments: the training set, how many
    # identities to sample per batch, and how many images per identity.
    def sample_people(dataset, people_per_batch, images_per_person):
        # Total number of images to sample
        nrof_images = people_per_batch * images_per_person
        # Number of identities in the dataset
        nrof_classes = len(dataset)
        # Index of each identity
        class_indices = np.arange(nrof_classes)
        # Shuffle the identities
        np.random.shuffle(class_indices)
        i = 0
        # Paths of the sampled images
        image_paths = []
        # Number of images sampled from each identity
        num_per_class = []
        # Identity index of each sampled image, used as the label
        sampled_class_indices = []
        # Sample images from these classes until we have enough
        while len(image_paths) < nrof_images:
            # Sample from the i-th identity
            class_index = class_indices[i]
            # Number of images of the i-th identity
            nrof_images_in_class = len(dataset[class_index])
            # Indices of those images
            image_indices = np.arange(nrof_images_in_class)
            np.random.shuffle(image_indices)
            # How many images to take from the i-th identity
            nrof_images_from_class = min(nrof_images_in_class, images_per_person, nrof_images - len(image_paths))
            idx = image_indices[0:nrof_images_from_class]
            # Paths of the images sampled from this identity
            image_paths_for_class = [dataset[class_index].image_paths[j] for j in idx]
            # Labels of the sampled images
            sampled_class_indices += [class_index] * nrof_images_from_class
            image_paths += image_paths_for_class
            # Record how many images were sampled from the i-th identity
            num_per_class.append(nrof_images_from_class)
            i += 1
        return image_paths, num_per_class

  2. Compute the embeddings of the sampled images and store them in emb_array; from this array, select the triplets:

    # Arguments: the embeddings of the sampled batch, the number of images per
    # identity, the image paths, the number of identities per batch, and the margin alpha
    def select_triplets(embeddings, nrof_images_per_class, image_paths, people_per_batch, alpha):
        """ Select the triplets for training
        """
        trip_idx = 0
        # Index in the embedding array where the current identity's images start
        emb_start_idx = 0
        num_trips = 0
        triplets = []

        # VGG Face: Choosing good triplets is crucial and should strike a balance between
        # selecting informative (i.e. challenging) examples and swamping training with examples that
        # are too hard. This is achieved by extending each pair (a, p) to a triplet (a, p, n) by sampling
        # the image n at random, but only between the ones that violate the triplet loss margin. The
        # latter is a form of hard-negative mining, but it is not as aggressive (and much cheaper) than
        # choosing the maximally violating example, as often done in structured output learning.
        # Iterate over the identities
        for i in xrange(people_per_batch):
            # Number of images of the i-th identity
            nrof_images = int(nrof_images_per_class[i])
            # Iterate over the images of the i-th identity
            for j in xrange(1, nrof_images):
                # Position of the j-th image's embedding in the embedding array
                a_idx = emb_start_idx + j - 1
                # Squared Euclidean distances between the j-th image and all other images
                neg_dists_sqr = np.sum(np.square(embeddings[a_idx] - embeddings), 1)
                # Iterate over every possible (anchor, positive) pair, denoted (a, p)
                for pair in xrange(j, nrof_images):  # For every possible positive pair.
                    # Position of the p-th image in the embedding array
                    p_idx = emb_start_idx + pair
                    # Squared Euclidean distance between a and p
                    pos_dist_sqr = np.sum(np.square(embeddings[a_idx] - embeddings[p_idx]))
                    # Images of the same identity must not be used as negatives, so set
                    # their distances to NaN (NaN comparisons are False, excluding them below)
                    neg_dists_sqr[emb_start_idx:emb_start_idx + nrof_images] = np.NaN
                    # all_neg = np.where(np.logical_and(neg_dists_sqr - pos_dist_sqr < alpha, pos_dist_sqr < neg_dists_sqr))[0]  # FaceNet selection
                    # Negatives from other identities whose distance to a, minus the
                    # distance between a and p, is less than alpha
                    all_neg = np.where(neg_dists_sqr - pos_dist_sqr < alpha)[0]  # VGG Face selection
                    # Number of candidate negatives
                    nrof_random_negs = all_neg.shape[0]
                    # If at least one negative satisfies the condition
                    if nrof_random_negs > 0:
                        # Pick one of them at random as n
                        rnd_idx = np.random.randint(nrof_random_negs)
                        n_idx = all_neg[rnd_idx]
                        # Keep (a, p, n) as a triplet
                        triplets.append((image_paths[a_idx], image_paths[p_idx], image_paths[n_idx]))
                        # print('Triplet %d: (%d, %d, %d), pos_dist=%2.6f, neg_dist=%2.6f (%d, %d, %d, %d, %d)' %
                        #     (trip_idx, a_idx, p_idx, n_idx, pos_dist_sqr, neg_dists_sqr[n_idx], nrof_random_negs, rnd_idx, i, j, emb_start_idx))
                        trip_idx += 1
                    num_trips += 1
            emb_start_idx += nrof_images
        np.random.shuffle(triplets)
        return triplets, num_trips, len(triplets)

  3. Compute the triplet loss, update the network, and thereby update the embeddings; then repeat from step 1. A simplified sketch of the full loop follows below.
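Putting the three steps together, one epoch of triplet training looks roughly like the sketch below. This is a simplification under stated assumptions, not the repo's exact code: compute_embeddings and train_on_batch are hypothetical stand-ins for the input-queue plumbing and train_op, and args carries the hyperparameters.

    import numpy as np

    def train_one_epoch(sess, train_set, args, compute_embeddings, train_on_batch):
        # Step 1: sample a group of identities and images
        image_paths, num_per_class = sample_people(
            train_set, args.people_per_batch, args.images_per_person)
        # Step 2: embed all sampled images with the current weights,
        # then mine hard triplets from those embeddings
        emb_array = compute_embeddings(sess, image_paths)
        triplets, _, nrof_triplets = select_triplets(
            emb_array, num_per_class, image_paths, args.people_per_batch, args.alpha)
        # Step 3: run the optimizer on the selected triplets; this updates the
        # network weights and therefore the embeddings used in the next round.
        # Step in multiples of three so triplets stay intact.
        triplet_paths = list(np.array(triplets).flatten())
        for start in range(0, len(triplet_paths), args.batch_size * 3):
            train_on_batch(sess, triplet_paths[start:start + args.batch_size * 3])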
