Triplet-Loss

Fine-tune and train a CNN pre-trained on the ImageNet dataset, using triplet loss.

Implementation Details

The code is implemented with the Keras library on a TensorFlow backend in a Python environment. The data_format is set to 'channels_first' as the default in the keras.json file.

Files:

init-
    --Entry point of the code.
    --Downloads and splits the dataset (CIFAR-10 in the current example).
    --Fits the model on the training and validation data.
    --Predicts classes of the test data.
    
fine_tune_model-
    --Contains two functions: fine_tuned_models and the wrapper function, data_generator.
    --The wrapper function is not used in the final run of the implementation.
   
triplet_loss_functions-
    --Two different implementations of triplet loss.

Background

Triplet loss was popularized by Florian Schroff, Dmitry Kalenichenko and James Philbin in the paper that introduced FaceNet. It is computed using three images:

a. anchor image     -   The user-defined anchor.
b. positive image   -   Image of the same class as the anchor.
c. negative image   -   Image of a different class.

The premise of triplet loss is to make the distance between the embeddings of the negative pair exceed the distance between the embeddings of the positive pair by at least a margin (written m or alpha). The positive pair is the anchor and the positive image, whereas the negative pair is the anchor and the negative image.

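The mathematical function for triplet loss, in the FaceNet formulation for a single triplet, is as follows:

$$L(a, p, n) = \max\left(\lVert f(a) - f(p) \rVert_2^2 - \lVert f(a) - f(n) \rVert_2^2 + \alpha,\ 0\right)$$

Here f is the embedding produced by the network, a, p and n are the anchor, positive and negative images, and alpha is the margin.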
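As a worked illustration, the sketch below computes the FaceNet-style form of the loss, max(d(a, p) - d(a, n) + margin, 0) with squared Euclidean distances, for a single triplet in plain Python. The function names, the margin value and the toy embeddings are assumptions made for illustration; they are not taken from triplet_loss_functions.

```python
def squared_distance(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(u, v))


def triplet_loss(anchor, positive, negative, margin=0.2):
    """Hinge-style triplet loss for one (anchor, positive, negative) triplet."""
    pos_dist = squared_distance(anchor, positive)
    neg_dist = squared_distance(anchor, negative)
    # The loss is zero once the negative is farther than the positive by `margin`.
    return max(pos_dist - neg_dist + margin, 0.0)


# Toy 2-D embeddings, made up for illustration.
anchor = [0.0, 0.0]
positive = [0.0, 1.0]  # squared distance 1 from the anchor
negative = [3.0, 0.0]  # squared distance 9 from the anchor

# Well-separated triplet: 1 - 9 + 0.5 is negative, so the loss is clipped to 0.
easy = triplet_loss(anchor, positive, negative, margin=0.5)
# Swapping the roles violates the margin: 9 - 1 + 0.5 = 8.5.
hard = triplet_loss(anchor, negative, positive, margin=0.5)
```

A triplet that already satisfies the margin contributes nothing to the loss, which is why the choice of triplets matters during training.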

Triplet loss can be implemented directly as a loss function passed to the compile method, or as a merge mode, with the anchor, positive and negative embeddings of three individual images forming the three branches of the merge function.
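Either approach needs an anchor, a positive and a negative image for every training example. A minimal sketch of one way to pick such triplets from class labels is shown here. The helper name sample_triplets and the uniform random sampling strategy are assumptions for illustration and do not describe how this repository builds its inputs.

```python
import random
from collections import defaultdict


def sample_triplets(labels, num_triplets, seed=0):
    """Return (anchor, positive, negative) index triplets drawn from class labels."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    # Anchors must come from classes that have at least two samples.
    anchor_classes = [c for c, idx in by_class.items() if len(idx) >= 2]
    if not anchor_classes or len(by_class) < 2:
        raise ValueError("need two classes, one of them with two or more samples")

    triplets = []
    for _ in range(num_triplets):
        pos_class = rng.choice(anchor_classes)
        neg_class = rng.choice([c for c in by_class if c != pos_class])
        # Two distinct samples of the same class form the positive pair.
        anchor, positive = rng.sample(by_class[pos_class], 2)
        negative = rng.choice(by_class[neg_class])
        triplets.append((anchor, positive, negative))
    return triplets


labels = [0, 0, 1, 1, 2, 2]  # hypothetical class labels
triplets = sample_triplets(labels, num_triplets=5)
```

The returned indices can then be used to look up the three images that feed the three branches.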

Problem

The code trains and fine-tunes a CNN model (ResNet50), pre-trained on the ImageNet dataset, by replacing the classifier of the CNN and using triplet loss. The first 15 layers of ResNet50 are frozen to reduce overfitting to the new dataset.
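Freezing can be sketched as follows. Keras models expose a layers list, and each layer has a trainable attribute, so the helper only relies on those two names. The helper name freeze_first_layers is hypothetical, and the demo uses stand-in objects instead of a real ResNet50 so that it runs without TensorFlow installed.

```python
from types import SimpleNamespace


def freeze_first_layers(model, num_frozen=15):
    """Mark the first `num_frozen` layers as non-trainable; return how many were frozen."""
    frozen = model.layers[:num_frozen]
    for layer in frozen:
        layer.trainable = False
    return len(frozen)


# Stand-in for a Keras model: 20 "layers" that all start out trainable.
toy_model = SimpleNamespace(
    layers=[SimpleNamespace(name=f"layer_{i}", trainable=True) for i in range(20)]
)
num_frozen = freeze_first_layers(toy_model, num_frozen=15)
trainable_names = [layer.name for layer in toy_model.layers if layer.trainable]
```

With a real Keras model, the same loop would normally run before compile is called so that the frozen state takes effect during training.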

Note

The dataset used is CIFAR-10. However, the images in CIFAR-10 are 32x32, while CNN models like ResNet50 / AlexNet / VGG16 / GoogleNet etc. require image dimensions of at least 197x197. CIFAR-10 images can be either scaled up or padded, but neither is a good solution. It is better to use a dataset similar to ImageNet, with images of comparable size and information. The code uses CIFAR-10 anyway because it ships with the Keras library. Substitute your own dataset to run the code.
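To make the two workarounds concrete, here is a small pure-Python sketch of nearest-neighbour upscaling and centred zero-padding on a single-channel image stored as a list of rows. Both helper names are hypothetical, and the square, single-channel input is an assumption. For a 32x32 image, a factor of 7 gives 224x224, while padding to 197x197 adds 82 pixels before and 83 after along each axis.

```python
def upscale_nearest(image, factor):
    """Enlarge a 2-D image (list of rows) by an integer factor using nearest neighbour."""
    return [
        [pixel for pixel in row for _ in range(factor)]
        for row in image
        for _ in range(factor)
    ]


def pad_to_size(image, size, fill=0):
    """Centre a square 2-D image on a size x size canvas filled with `fill`."""
    current = len(image)
    before = (size - current) // 2
    after = size - current - before
    padded_rows = [[fill] * before + list(row) + [fill] * after for row in image]
    top = [[fill] * size for _ in range(before)]
    bottom = [[fill] * size for _ in range(after)]
    return top + padded_rows + bottom


tiny = [[1, 2], [3, 4]]            # stand-in for one channel of a 32x32 image
scaled = upscale_nearest(tiny, 2)  # each pixel becomes a 2x2 block
padded = pad_to_size(tiny, 5)      # 1 pixel of padding before, 2 after
```

Upscaling only repeats existing pixels and padding only adds empty borders, so neither adds information.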

triplet-loss's People

Contributors

shivsondhi

triplet-loss's Issues

Issue with input shape

Uncommenting the line that makes the input work with ResNet throws an error when reshaping the array.

Triplets generation

Where are the triplets generated in the code, and how is triplet_loss_functions called?
