SimSiam-TF

Minimal implementation of SimSiam (Exploring Simple Siamese Representation Learning by Xinlei Chen & Kaiming He) in TensorFlow 2.

The purpose of this repository is to demonstrate the workflow of SimSiam, NOT to implement it note for note. At the same time, I have tried not to miss any of the major components discussed in the paper. For the experiments, I use the Flowers dataset.

The following figure (taken from the paper) depicts the SimSiam workflow:

The authors also provide PyTorch-like pseudocode in the paper (how cool!):

```python
# f: backbone + projection mlp
# h: prediction mlp
for x in loader: # load a minibatch x with n samples
    x1, x2 = aug(x), aug(x) # random augmentation
    z1, z2 = f(x1), f(x2) # projections, n-by-d
    p1, p2 = h(z1), h(z2) # predictions, n-by-d
    L = D(p1, z2)/2 + D(p2, z1)/2 # loss
    L.backward() # back-propagate
    update(f, h) # SGD update

def D(p, z): # negative cosine similarity
    z = z.detach() # stop gradient
    p = normalize(p, dim=1) # l2-normalize
    z = normalize(z, dim=1) # l2-normalize
    return -(p*z).sum(dim=1).mean()
```
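
For readers working outside PyTorch, here is a minimal NumPy sketch of `D` and the symmetrized loss from the pseudocode. NumPy has no automatic differentiation, so the stop-gradient cannot be expressed here; in TensorFlow 2 the corresponding call is `tf.stop_gradient(z)`, and `tf.math.l2_normalize(x, axis=1)` plays the role of `normalize`. The helpers `l2_normalize` and `simsiam_loss` are hypothetical names, not taken from the notebooks.

```python
import numpy as np


def l2_normalize(x, eps=1e-12):
    # Row-wise L2 normalization, the counterpart of normalize(x, dim=1).
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.maximum(norms, eps)


def D(p, z):
    # Negative cosine similarity averaged over the batch (rows = samples).
    # NumPy has no autograd, so there is no gradient to stop here; in TF 2
    # the stop-gradient line would be z = tf.stop_gradient(z).
    p = l2_normalize(p)
    z = l2_normalize(z)
    return -np.mean(np.sum(p * z, axis=1))


def simsiam_loss(p1, p2, z1, z2):
    # Symmetrized loss from the pseudocode: L = D(p1, z2)/2 + D(p2, z1)/2
    return 0.5 * D(p1, z2) + 0.5 * D(p2, z1)
```

For any nonzero `x`, `D(x, x)` evaluates to -1, the smallest value the loss can take. A collapsed network whose predictions and projections are all the same vector reaches that minimum trivially, which is the degenerate solution the stop-gradient helps avoid.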

The authors emphasize the stop-gradient operation (`z.detach()` in the pseudocode), which helps the network avoid collapsing solutions, such as mapping every input to the same constant vector. Further details are available in the paper. SimSiam does not need large batch sizes, momentum encoders, memory banks, negative samples, or similar components that are important in other modern self-supervised learning frameworks for visual recognition. This makes SimSiam an easily approachable framework for practical problems.

About the notebooks

  • SimSiam_Pre_training.ipynb: Pre-trains a ResNet50 using SimSiam.
  • SimSiam_Evaluation.ipynb: Evaluates the ResNet50 pre-trained in SimSiam_Pre_training.ipynb using linear evaluation.
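
Linear evaluation keeps the pre-trained backbone frozen and trains only a linear classifier on top of its features, so the resulting accuracy reflects the quality of the learned representations. The NumPy sketch below illustrates the protocol with full-batch softmax regression on precomputed features. The function names, learning rate, and epoch count are assumptions for illustration. In Keras, a common way to do the same thing is to set the backbone's `trainable` attribute to `False` and add a `Dense` layer on top, though the notebook may differ in its details.

```python
import numpy as np


def train_linear_probe(features, labels, num_classes, lr=0.01, epochs=200):
    # The backbone is frozen, so `features` are fixed embeddings computed once;
    # only the weights W and bias b of a single softmax layer are learned.
    n, d = features.shape
    W = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[labels]
    for _ in range(epochs):
        logits = features @ W + b
        logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n  # gradient of mean cross-entropy w.r.t. logits
        W -= lr * features.T @ grad  # full-batch gradient descent step
        b -= lr * grad.sum(axis=0)
    return W, b


def linear_eval_accuracy(features, labels, W, b):
    # Top-1 accuracy of the linear classifier on held-out features.
    preds = np.argmax(features @ W + b, axis=1)
    return float(np.mean(preds == labels))
```

In use, `features` would be the frozen backbone's embeddings of the training images: fit with `train_linear_probe(x_train, y_train, num_classes)`, then report `linear_eval_accuracy(x_val, y_val, W, b)` on the validation split.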

Results

| Pre-training Schedule | Validation Accuracy (Linear Evaluation) |
|---|---|
| 50 epochs | 45.64% |
| 75 epochs | 44.91% |

I think these scores can be improved with further hyperparameter tuning and regularization.

Supervised training (results are taken from here and here):

| Training Type | Validation Accuracy |
|---|---|
| Supervised ImageNet-trained ResNet50 Features | 48.36% |
| From Scratch Training with ResNet50 | 63.64% |

Observations

The figure below shows the training loss curves for the two pre-training schedules (50 epochs and 75 epochs):

The loss plateaus after 35 epochs. To improve on this, we can experiment with the following components:

  • data augmentation pipeline
  • architectures of the two MLP heads
  • learning rate schedule used during pre-training

and so on.
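
One option for the learning-rate item above is cosine decay, which the SimSiam paper uses together with linear scaling of the base rate (lr × BatchSize / 256). The plain-Python sketch below shows both pieces; the function names are hypothetical, and whether the notebooks already use this schedule is not stated here. In TensorFlow 2, `tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate, decay_steps)` computes the same curve when `alpha` is left at its default of 0.

```python
import math


def scaled_lr(base_lr, batch_size):
    # Linear scaling rule: the effective rate grows with the batch size.
    return base_lr * batch_size / 256


def cosine_decay_lr(step, total_steps, initial_lr):
    # Decays from initial_lr at step 0 to 0 at total_steps (total_steps > 0).
    # Steps past total_steps are clamped, so the rate stays at 0 afterwards.
    progress = min(step, total_steps) / total_steps
    return 0.5 * initial_lr * (1.0 + math.cos(math.pi * progress))
```

With `total_steps` set to the number of pre-training steps, the rate falls to half its initial value at the midpoint and reaches zero at the end of training.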

Pre-trained weights

Acknowledgements

Thanks to Connor Shorten's video on the paper, which gave me a quick overview of it. Thanks to the ML-GDE program for providing the GCP credits used to run the experiments.
