My implementation of "Attention Is All You Need" by Vaswani et al. [1] using the JAX framework.
Please see the model code here.
The model code defines a number of generator functions, each of which takes a config dictionary as input and returns two functions:
- a compute function that implements the forward pass of that component (e.g. feedforward network, dropout layer, encoder)
- a gen_params() function that generates the initial weights consumed by the compute function
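As a minimal sketch of this pattern (the function name `feedforward` and the config keys `d_model`/`d_ff` are hypothetical, not taken from the actual model code):

```python
import jax
import jax.numpy as jnp

def feedforward(config):
    """Hypothetical generator in the style described above: takes a config
    dict and returns (compute, gen_params)."""
    d_model, d_ff = config["d_model"], config["d_ff"]

    def gen_params(key):
        # Generate initial weights for the compute function.
        k1, k2 = jax.random.split(key)
        return {
            "W1": jax.random.normal(k1, (d_model, d_ff)) * 0.02,
            "b1": jnp.zeros(d_ff),
            "W2": jax.random.normal(k2, (d_ff, d_model)) * 0.02,
            "b2": jnp.zeros(d_model),
        }

    def compute(params, x):
        # Position-wise feedforward: two linear layers with a ReLU between.
        h = jax.nn.relu(x @ params["W1"] + params["b1"])
        return h @ params["W2"] + params["b2"]

    return compute, gen_params

# Usage: build the pair from a config, then initialize and apply.
compute, gen_params = feedforward({"d_model": 8, "d_ff": 32})
params = gen_params(jax.random.PRNGKey(0))
y = compute(params, jnp.ones((2, 4, 8)))  # output shape (2, 4, 8)
```

Keeping parameters in a plain dict and passing them explicitly to the compute function is idiomatic in JAX, since it keeps both functions pure and compatible with `jax.jit` and `jax.grad`.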
[1] Vaswani et al., "Attention Is All You Need," 2017. arXiv:1706.03762.