Giter VIP home page Giter VIP logo

contrastive-predictive-coding-images's Introduction

Representation Learning with Contrastive Predictive Coding for images

This repository contains a Keras implementation of contrastive-predictive-coding for images, an algorithm fully described in:

The goal of unsupervised representation learning is to capture semantic information about the world, recognizing regular patterns in the data without using annotations. This paper presents a new method called Contrastive Predictive Coding (CPC) that can do so across multiple applications. This implementation covers the case of CPC applied to images (vision).

In a nutshell, an input image is divided into a grid of overlapping patches, and each patch is embedded using an encoding network. From these embeddings, the model computes a context vector at each position using masked convolutions (do not have access to future pixels). These context vectors propagate and integrate spatial information. Given a row in the grid of vectors, the model predicts entire rows below at different offsets (2 rows below, 3 rows below, etc.). These predictions should match the embedded vectors computed at the very beginning. The model is optimized to find the correct embedding when vectors from other patches are considered.

My code is optimized for readability and it is meant to be used as a resource to understand how CPC works. Therefore, I would like to explain a few concepts that I found challenging to understand in the papers.

CPC algorithm - context

In this figure, horizontal lines in the left represent embedded input image rows (7 embedding vectors), and triangles correspond to masked convolutions. We can see how all the orange rows (0 to 4) contribute to the context vector of row 3 (center). Although masked convolutions prevent information from lower rows to flow to the context, notice how input row 4 makes its way to the end. This is because patches are extracted with overlapping, that is, patches encoded in row 3 (left) contain pixels from row 4 below (in yellow). For this reason, we should never optimize the CPC model to predict just one row below, since the information would flow directly from the input preventing it from learning anything useful.

Once the context vectors are computed, we can proceed with the actual row predictions. We cannot predict the row below, but we can predict the rest. Depending on how far we would like to predict (offset), we will use different prediction networks. In total, there are 5 prediction networks.

CPC algorithm - offset

Given a set of context vectors, we apply each prediction network to all rows, however, not all predictions will be used. These predictions are aligned with the embedded input image rows taking into account the row offset used during the prediction (see bottom of the figure for a detailed mapping). This operation is performed in cpc_model.py > CPCLayer > align_embeddings().

To train the CPC algorithm, I have created a toy dataset based on 64x64 color MNIST.

CPC algorithm - samples

Disclaimer: this code is provided as is, if you encounter a bug please report it as an issue. Your help will be much welcomed!

Usage

  • Execute python train_cpc.py to train the CPC model.
  • Execute python train_classifier.py to train a classifier on top of the CPC encoder.

Requisites

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.