

Discrete Key / Value Bottleneck - Pytorch

Implementation of Discrete Key / Value Bottleneck, in Pytorch.

Install

$ pip install discrete-key-value-bottleneck-pytorch

Usage

import torch
from discrete_key_value_bottleneck_pytorch import DiscreteKeyValueBottleneck

key_value_bottleneck = DiscreteKeyValueBottleneck(
    dim = 512,                  # input dimension
    dim_memory = 512,           # output dimension, i.e. the dimension of each memory across all heads (defaults to the input dimension)
    num_memory_codebooks = 2,   # number of memory codebooks: the embedding is split into 2 pieces of 256, each is quantized and outputs 256, and the two are flattened together to 512
    num_memories = 256,         # number of memories
    decay = 0.9,                # the exponential moving average decay, lower means the keys will change faster
)

embeds = torch.randn(1, 1024, 512)  # from pretrained encoder

memories = key_value_bottleneck(embeds)

memories.shape # (1, 1024, 512)  # (batch, seq, memory / values dimension)

# now you can use the memories for the downstream decoder
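The lookup the bottleneck performs can be sketched in plain PyTorch. This is a minimal sketch, not the library's implementation: the function name `kv_bottleneck_lookup` is hypothetical, keys are matched by Euclidean distance, and the embedding is split into contiguous chunks as the `num_memory_codebooks` comment above describes. The shapes mirror that comment: two codebooks of 256 memories, each chunk 256-dimensional, flattened back to 512.

```python
import torch


def kv_bottleneck_lookup(x, keys, values):
    """Hypothetical sketch of a discrete key-value lookup (not the library's code).

    x:      (batch, seq, dim) embeddings from a frozen encoder
    keys:   (C, num_memories, dim // C), one key codebook per head
    values: (C, num_memories, dim_value), one learnable value per key
    returns memories (batch, seq, C * dim_value) and key indices (batch, seq, C)
    """
    b, n, d = x.shape
    c, m, dk = keys.shape
    assert d == c * dk, "dim must split evenly across the C codebooks"
    # split each embedding into C chunks and group them by codebook: (C, b*n, dk)
    heads = x.reshape(b * n, c, dk).transpose(0, 1)
    # nearest key per chunk under Euclidean distance: (C, b*n)
    indices = torch.cdist(heads, keys).argmin(dim=-1)
    # fetch the value paired with each selected key: (C, b*n, dim_value)
    memories = values[torch.arange(c).unsqueeze(1), indices]
    # put the C heads back side by side and flatten them: (b, n, C * dim_value)
    memories = memories.transpose(0, 1).reshape(b, n, c * values.shape[-1])
    return memories, indices.transpose(0, 1).reshape(b, n, c)


# shapes follow the README comment: 2 codebooks, 256 memories, 512 = 2 x 256
keys = torch.randn(2, 256, 256)                        # not trained by gradients
values = torch.randn(2, 256, 256, requires_grad=True)  # the trainable part
embeds = torch.randn(1, 1024, 512)
memories, indices = kv_bottleneck_lookup(embeds, keys, values)
print(memories.shape)  # torch.Size([1, 1024, 512])
```

Because `argmin` is not differentiable, a downstream loss sends gradients only to `values`, never to `keys`. That is consistent with the README's description of the keys being updated separately, by the exponential moving average that `decay` controls.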

You can also pass the pretrained encoder to the bottleneck, and it will invoke the encoder automatically. Example with the vit-pytorch library:

$ pip install vit-pytorch

Then

import torch

# import vision transformer

from vit_pytorch import SimpleViT
from vit_pytorch.extractor import Extractor

vit = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# train vit, or load pretrained

vit = Extractor(vit, return_embeddings_only = True)

# then

from discrete_key_value_bottleneck_pytorch import DiscreteKeyValueBottleneck

enc_with_bottleneck = DiscreteKeyValueBottleneck(
    encoder = vit,         # pass the frozen encoder into the bottleneck
    dim = 512,             # input dimension
    num_memories = 256,    # number of memories
    dim_memory = 2048,     # dimension of the output memories
    decay = 0.9,           # the exponential moving average decay, lower means the keys will change faster
)

images = torch.randn(1, 3, 256, 256)  # input to encoder

memories = enc_with_bottleneck(images) # (1, 64, 2048)   # (64 patches)
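What passing a frozen encoder into the bottleneck amounts to can be sketched as a small wrapper module. Both classes below are hypothetical stand-ins, not the library's or vit-pytorch's code: `ToyPatchEncoder` replaces the ViT extractor with a single patch-embedding convolution, and an `nn.Linear` replaces the bottleneck. The point is the freezing pattern (no gradients, eval mode) and the shape arithmetic: a 256x256 image cut into 32x32 patches gives (256 / 32) ** 2 = 64 tokens.

```python
import torch
from torch import nn


class FrozenEncoderBottleneck(nn.Module):
    """Hypothetical wrapper (not the library's class): frozen encoder -> bottleneck."""

    def __init__(self, encoder: nn.Module, bottleneck: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.bottleneck = bottleneck
        for p in self.encoder.parameters():
            p.requires_grad_(False)  # encoder weights are never updated

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        self.encoder.eval()    # keep dropout / norm layers in inference mode
        with torch.no_grad():  # no activations are stored for the encoder
            embeds = self.encoder(images)
        return self.bottleneck(embeds)


class ToyPatchEncoder(nn.Module):
    """Stand-in for the ViT extractor: one 512-d embedding per 32x32 patch."""

    def __init__(self, dim: int = 512, patch_size: int = 32):
        super().__init__()
        self.proj = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # (b, dim, 8, 8) -> (b, dim, 64) -> (b, 64, dim)
        return self.proj(images).flatten(2).transpose(1, 2)


encoder = ToyPatchEncoder()
# nn.Linear(512, 2048) stands in for the bottleneck with dim_memory = 2048
model = FrozenEncoderBottleneck(encoder, nn.Linear(512, 2048))
out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 64, 2048])
```

Only the bottleneck's parameters still require gradients, so an optimizer built from `model.parameters()` effectively trains just that part.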

Todo

  • work off multiple encoders' embedding spaces, allowing for shared or separate memory spaces, to aid exploration in this line of research

Citations

@inproceedings{Trauble2022DiscreteKB,
    title   = {Discrete Key-Value Bottleneck},
    author  = {Frederik Trauble and Anirudh Goyal and Nasim Rahaman and Michael Curtis Mozer and Kenji Kawaguchi and Yoshua Bengio and Bernhard Scholkopf},
    year    = {2022}
}


Contributors

lucidrains


Issues

Update to how the paper proposes extracting the codebooks' embeddings

In the latest version of the paper, the authors propose applying a random projection to the encoder embedding to obtain the C embedding heads used to look up the keys of each codebook, instead of taking a slice of the encoder's actual output embedding. This changes the implementation slightly; just flagging it so everybody is aware.
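The difference between the two schemes can be seen in a few lines of PyTorch. This is a sketch under assumptions: the function names are hypothetical, and the projection is drawn from a scaled Gaussian, which may differ from the paper's exact initialization. Slicing gives head c only its own block of dimensions, while a fixed random projection gives every head a different mix of the whole embedding.

```python
import torch

torch.manual_seed(0)


def slice_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Earlier scheme: head c sees only dims [c*d/C, (c+1)*d/C) of the embedding."""
    b, n, d = x.shape
    return x.reshape(b, n, num_heads, d // num_heads)


def project_heads(x: torch.Tensor, proj: torch.Tensor) -> torch.Tensor:
    """Scheme described above: head c is x @ proj[c], a fixed random projection."""
    # x: (b, n, d), proj: (C, d, e) -> (b, n, C, e)
    return torch.einsum("bnd,cde->bnce", x, proj)


dim, num_heads, dim_head = 512, 4, 128
# frozen random projections, scaled so unit-variance inputs give roughly unit-variance heads
proj = torch.randn(num_heads, dim, dim_head) / dim ** 0.5

x = torch.randn(2, 64, dim)
print(slice_heads(x, num_heads).shape)  # torch.Size([2, 64, 4, 128])
print(project_heads(x, proj).shape)     # torch.Size([2, 64, 4, 128])
```

With projections, the per-head dimension (`dim_head` here) is no longer tied to `dim / C`, and because the projection is frozen it adds no trainable parameters.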

How to initialize `enc_with_bottleneck.vq` as described in the paper?

Hi,

  1. When I use the following code snippet to load a pretrained model, as in the paper's setting, the output dimension is (batch_size, 2048). How can I feed it into enc_with_bottleneck?
resnet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
  2. With the setting above, how can I use that pretrained model to initialize my enc_with_bottleneck.vq? In particular, how should I initialize the codebook's keys?

  3. I tried to solve my problem with the following code snippet, but evaluation on the test set shows an accuracy of 0.1 (about the same as random guessing).

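import torch
import torch.optim as optim
from tqdm import tqdm
from discrete_key_value_bottleneck_pytorch import DiscreteKeyValueBottleneck

# assumes `device`, `trainloader` and the DINO `resnet50` above are already defined
enc_with_bottleneck = DiscreteKeyValueBottleneck(
    dim = 384,                  # input dimension
    num_memory_codebooks = 64,  # number of memory codebooks
    num_memories = 128,         # number of memories
    dim_memory = 32,            # dimension of the output memories
    decay = 0.8,                # the exponential moving average decay, lower means the keys will change faster
)
enc_with_bottleneck.to(device)
optimizer_kvib = optim.SGD(enc_with_bottleneck.parameters(), lr = 0.001, momentum = 0.8)

# train the vector quantizer (keys) on the frozen encoder's outputs
for x in enc_with_bottleneck.vq.parameters():
    x.requires_grad = True
optimizer_kvib_vq = optim.SGD(enc_with_bottleneck.vq.parameters(), lr = 0.001, momentum = 0.8)
for e in range(5):
    for i, data in tqdm(enumerate(trainloader, 0)):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        x = resnet50(inputs)               # (batch_size, 2048)
        print(x.shape)
        x = torch.unsqueeze(x, dim = 1)    # (batch_size, 1, 2048): a sequence of length 1
        optimizer_kvib_vq.zero_grad()
        _, _, loss = enc_with_bottleneck.vq(x)  # returns (quantized, indices, loss)
        loss.backward()
        optimizer_kvib_vq.step()

# freeze the vector quantizer afterwards
for x in enc_with_bottleneck.vq.parameters():
    x.requires_grad = False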

I look forward to your reply
zhang ruiyuan
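On the second question (initializing the keys), one common recipe, which is assumed here rather than taken from the paper or the library, is to adapt the keys to the frozen encoder's outputs without labels using the k-means-style exponential moving average known from VQ-VAE, then freeze them and train only the values (and the decoder) with the labels. The `EMAKeys` class below is hypothetical plain PyTorch; its configuration (64 codebooks, 128 keys, decay 0.8, 2048-d inputs split into 32-d chunks) mirrors the post. The README describes `dim` as the input dimension, so with pooled 2048-d ResNet-50 features it presumably needs to be 2048 rather than 384.

```python
import torch
import torch.nn.functional as F


class EMAKeys:
    """Hypothetical helper: C key codebooks adapted without labels by a k-means-style EMA."""

    def __init__(self, first_batch, num_codebooks, num_keys, decay=0.9, eps=1e-5):
        self.num_codebooks, self.decay, self.eps = num_codebooks, decay, eps
        heads = self._heads(first_batch)  # (C, N, dim // C)
        assert heads.shape[1] >= num_keys, "need at least num_keys samples to initialise from"
        # initialise every key from a randomly chosen real embedding chunk
        idx = torch.randperm(heads.shape[1])[:num_keys]
        self.keys = heads[:, idx].clone()                        # (C, num_keys, dim // C)
        self.cluster_size = torch.ones(num_codebooks, num_keys)  # EMA count per key
        self.embed_sum = self.keys.clone()                       # EMA sum of assigned chunks

    def _heads(self, x):
        # (N, dim) -> (C, N, dim // C): one contiguous chunk per codebook
        return x.reshape(x.shape[0], self.num_codebooks, -1).transpose(0, 1).contiguous()

    @torch.no_grad()
    def update(self, x):
        heads = self._heads(x)
        indices = torch.cdist(heads, self.keys).argmin(dim=-1)           # (C, N)
        onehot = F.one_hot(indices, self.keys.shape[1]).to(heads.dtype)  # (C, N, num_keys)
        d = self.decay  # lower decay: the new batch weighs more, so keys move faster
        self.cluster_size.mul_(d).add_(onehot.sum(dim=1), alpha=1 - d)
        self.embed_sum.mul_(d).add_(onehot.transpose(1, 2) @ heads, alpha=1 - d)
        # Laplace smoothing so a rarely chosen key never divides by ~0
        n = self.cluster_size.sum(dim=-1, keepdim=True)
        size = (self.cluster_size + self.eps) / (n + self.keys.shape[1] * self.eps) * n
        self.keys = self.embed_sum / size.unsqueeze(-1)  # each key = EMA mean of its chunks
        return indices


torch.manual_seed(0)
# random stand-ins for pooled frozen-encoder features, e.g. 2048-d ResNet-50 outputs
batches = [torch.randn(256, 2048) for _ in range(10)]
ema = EMAKeys(batches[0], num_codebooks=64, num_keys=128, decay=0.8)
for batch in batches:
    ema.update(batch)
print(ema.keys.shape)  # torch.Size([64, 128, 32])
```

Once the keys settle, they would be frozen, and a learnable values tensor of shape (64, 128, dim_memory) trained together with the classifier on the labels. No optimizer is involved in the key stage: the EMA update replaces gradient descent for the keys.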
