n-grammer-flax

Implementation of N-Grammer: Augmenting Transformers with latent n-grams in Flax
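As a rough illustration of what "latent n-grams" means, here is a minimal NumPy sketch: each per-head vector is assigned to its nearest centroid, and each cluster id is combined with the id of the previous position to form a bigram id. The function names `cluster_ids` and `bigram_ids` are hypothetical and not part of this package, and the plain modulo is a stand-in for whatever hashing the library actually uses to map bigram ids into the n-gram vocabulary.

```python
import numpy as np

def cluster_ids(x, means):
    """Assign each per-head vector to its nearest centroid.

    x:     (seq_len, num_heads, dim_per_head)
    means: (num_heads, num_clusters, dim_per_head)
    returns integer ids of shape (seq_len, num_heads)
    """
    # squared distance from every vector to every centroid of the same head
    diff = x[:, :, None, :] - means[None, :, :, :]  # (seq, heads, clusters, dim)
    dists = (diff ** 2).sum(axis=-1)                # (seq, heads, clusters)
    return dists.argmin(axis=-1)

def bigram_ids(ids, num_clusters, ngram_vocab_size):
    """Combine each cluster id with its predecessor into a single bigram id."""
    # the first position has no predecessor; id 0 is used as padding (assumption)
    prev = np.concatenate([np.zeros_like(ids[:1]), ids[:-1]], axis=0)
    combined = ids + num_clusters * prev
    # simple modulo as a stand-in for a proper hash into the n-gram vocabulary
    return combined % ngram_vocab_size

# toy data: 2 heads, 3 clusters, 2 dimensions per head
means = np.array([[[0., 0.], [10., 0.], [0., 10.]]] * 2)
true_ids = np.array([[0, 1], [2, 2], [1, 0], [1, 1]])
x = np.stack([means[np.arange(2), row] for row in true_ids])  # (4, 2, 2)

ids = cluster_ids(x, means)
print(bigram_ids(ids, num_clusters=3, ngram_vocab_size=8))
```

In the full method the bigram ids index an embedding table per head, and the looked-up n-gram embeddings are combined with the original embeddings; that part is omitted here.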

Usage

from n_grammer_flax.n_grammer_flax import PQNgrammer
import jax

# separate PRNG keys for the dummy input, the parameters, and the batch statistics
key0, key1, key2 = jax.random.split(jax.random.PRNGKey(0), 3)

init_rngs = {'params': key1,
             'batch_stats': key2}

# dummy input of shape (batch, seq_len, num_heads * dim_per_head)
x = jax.random.normal(key0, shape=(1, 1024, 32 * 16))

pq_ngram = PQNgrammer(
    num_clusters = 1024, # number of clusters
    num_heads = 16, # number of attention heads
    dim_per_head = 32, # dimensions of each attention head
    ngram_vocab_size = 768 * 256, # n-gram vocab size
    ngram_emb_dim = 16, # n-gram embedding dimension
    decay = 0.99)

init_variables = pq_ngram.init(init_rngs, x)
# 'batch_stats' is marked mutable so that the updated statistics are returned
out, mutated_variables = pq_ngram.apply(init_variables, x, mutable=['batch_stats'])

print('mutated variables.shape:\n', jax.tree_util.tree_map(lambda x: x.shape, mutated_variables))
print('output.shape:\n', out.shape)

This prints:

mutated variables.shape:
 FrozenDict({
    batch_stats: {
        ProductQuantization_0: {
            means: (32, 1024, 16),
        },
    },
})
output.shape:
 (1, 1024, 512)
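
The `means` entry under `batch_stats` presumably holds the product-quantization centroids, and `decay` presumably controls how quickly they follow new data. A minimal NumPy sketch of one such exponential-moving-average update for a single head is shown below. It assumes an EMA k-means scheme like the one in the PyTorch reference implementation; the update in this package may differ, and `ema_update_means` is a hypothetical helper, not part of the library.

```python
import numpy as np

def ema_update_means(means, x, ids, decay=0.99):
    """One EMA step for the centroids of a single head.

    means: (num_clusters, dim) current centroids
    x:     (n, dim) vectors seen in this batch
    ids:   (n,) cluster assignment of each vector
    """
    num_clusters = means.shape[0]
    one_hot = np.eye(num_clusters)[ids]   # (n, num_clusters)
    counts = one_hot.sum(axis=0)          # vectors assigned to each cluster
    sums = one_hot.T @ x                  # (num_clusters, dim)
    batch_means = sums / np.maximum(counts, 1)[:, None]
    # clusters that received no vectors keep their old centroid
    batch_means = np.where(counts[:, None] > 0, batch_means, means)
    return decay * means + (1 - decay) * batch_means

old_means = np.array([[0., 0.], [10., 10.]])
batch = np.array([[2., 0.], [4., 0.]])
assignments = np.array([0, 0])
new_means = ema_update_means(old_means, batch, assignments, decay=0.5)
print(new_means)
```

With a value such as `decay = 0.99`, each batch moves a centroid only 1% of the way toward the batch average, so the centroids change slowly during training.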

Acknowledgement

This project was made possible by the TPU Research Cloud (TRC) program. Thank you, Google!

Reference

Thanks to lucidrains for his concise PyTorch implementation of N-Grammer. n-grammer-flax is inspired by and tested against his project (see n_grammer_flax_test.py for details).

It is also inspired by the official JAX implementation: https://github.com/tensorflow/lingvo/tree/master/lingvo/jax

Citations

@inproceedings{roy2022ngrammer,
    title   = {N-Grammer: Augmenting Transformers with latent n-grams},
    author  = {Aurko Roy and Rohan Anil and Guangda Lai and Benjamin Lee and Jeffrey Zhao and Shuyuan Zhang and Shibo Wang and Ye Zhang and Shen Wu and Rigel Swavely and Tao (Alex) Yu and Phuong Dao and Christopher Fifty and Zhifeng Chen and Yonghui Wu},
    year    = {2022},
    url     = {https://arxiv.org/abs/2207.06366}
}

Contributors

yiyixuxu
