Comments (9)
Also, you don't need this line of code:
https://github.com/philipperemy/keras-attention-mechanism/blob/master/attention_lstm.py#L25
You can pass name='name' to any layer.
from keras-attention.
@hamelsmu Yes, the Reshape layer is redundant and adds nothing to the model (everything is done by the Permute layer).
It is there mostly to enforce the correct shape. The output of the Permute layer is reported as (?, ?), and adding the Reshape layer makes the real shapes explicit (they are static and known at compilation time). I wanted to reflect this idea of static shapes (vs. dynamic shapes).
Thanks for your feedback! Highly appreciated!
a = Dense(TIME_STEPS, activation='softmax', name='attention_vec')(a)
if SINGLE_ATTENTION_VECTOR:
    a = Lambda(lambda x: K.mean(x, axis=1), name='attention_vec')(a)  # this is the attention vector!
    a = RepeatVector(input_dim)(a)
Is this what you meant? Removing the else clause and adding name='attention_vec' before the if?
Yeah, that's right.
It would not work here because two different layers would be defined with the same name:
RuntimeError: The name "attention_vec" is used 2 times in the model. All layer names should be unique.
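The error follows from the proposed snippet naming both the Dense layer and the Lambda layer 'attention_vec' whenever SINGLE_ATTENTION_VECTOR is true. As a rough illustration only, the pure-Python sketch below flags repeated names in a list. `find_duplicate_names` is a hypothetical helper written for this note and is not how Keras performs its own check.

```python
from collections import Counter


def find_duplicate_names(layer_names):
    """Return the names that occur more than once, in first-seen order.

    Hypothetical helper for illustration; not part of Keras.
    """
    counts = Counter(layer_names)
    return [name for name in dict.fromkeys(layer_names) if counts[name] > 1]


# Explicit names the proposed code would assign when SINGLE_ATTENTION_VECTOR
# is True: the Dense layer and the Lambda layer both ask for 'attention_vec'.
proposed_names = ['attention_vec', 'attention_vec']
duplicates = find_duplicate_names(proposed_names)
print(duplicates)
```

Giving each layer its own name, or naming only one layer 'attention_vec', avoids the clash.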
@philipperemy Right. However, I suppose you can say that the attention layer is a_probs, because that is the layer being multiplied by the inputs. So you can refactor it to look like this:
a = Dense(TIME_STEPS, activation='softmax')(a)
if SINGLE_ATTENTION_VECTOR:
    a = Lambda(lambda x: K.mean(x, axis=1), name='dim_reduction')(a)
    # name belongs in the layer constructor, not in the call on the tensor
    a = RepeatVector(input_dim, name='time_repeat')(a)
a_probs = Permute((2, 1), name='attention_vec')(a)
output_attention_mul = merge([inputs, a_probs], name='attention_mul', mode='mul')
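To make the shapes in this refactor concrete, here is a rough NumPy sketch of what the block computes. It is not Keras code and is not taken from the repository: the random matrix `W` stands in for the Dense layer's learned kernel (bias omitted), and `softmax`, `attention_3d_block_np`, and the sizes are assumptions made for illustration.

```python
import numpy as np

TIME_STEPS = 20   # assumed size, matching the shapes discussed in this thread
INPUT_DIM = 2     # assumed size


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)


def attention_3d_block_np(inputs, W, single_attention_vector=False):
    """NumPy stand-in for the attention block (hypothetical helper).

    inputs: (batch, TIME_STEPS, INPUT_DIM)
    W:      (TIME_STEPS, TIME_STEPS), stands in for the Dense kernel
    """
    input_dim = inputs.shape[2]
    a = np.transpose(inputs, (0, 2, 1))   # Permute((2, 1)): (batch, input_dim, time)
    a = softmax(a @ W, axis=-1)           # Dense(TIME_STEPS, softmax) on the last axis
    if single_attention_vector:
        a = np.mean(a, axis=1)            # 'dim_reduction': (batch, time)
        # 'time_repeat': copy the single vector once per input dimension
        a = np.repeat(a[:, np.newaxis, :], input_dim, axis=1)
    a_probs = np.transpose(a, (0, 2, 1))  # 'attention_vec': (batch, time, input_dim)
    return inputs * a_probs, a_probs      # 'attention_mul'


rng = np.random.default_rng(0)
inputs = rng.normal(size=(1, TIME_STEPS, INPUT_DIM))
W = rng.normal(size=(TIME_STEPS, TIME_STEPS))
output, a_probs = attention_3d_block_np(inputs, W, single_attention_vector=True)
print(output.shape, a_probs.shape)
```

In this sketch the weights in `a_probs` sum to 1 along the time axis for every input dimension, and with `single_attention_vector=True` every input dimension shares the same weights.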
OK, this looks good to me! The only thing is that attention_vec.shape will change from (1, 2, 20) to (1, 20, 2), where 20 is the number of time steps and 2 is the number of input dims. So we have to change the axis on which we aggregate from 1 to 2, simply because we want to display the vector along the time axis.
attention_vector = np.mean(
    get_activations(
        m,
        testing_inputs_1,
        print_shape_only=True,
        layer_name='attention_vec')[0],
    axis=2).squeeze()
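The axis change can be checked with plain NumPy. In the sketch below a random array stands in for the layer activations (an assumption; no model is run, and the variable names are illustrative). It shows that averaging the permuted (1, 20, 2) array over axis 2 gives the same per-time-step vector as averaging the original (1, 2, 20) array over axis 1.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the old activations: (batch, input_dims, time_steps) = (1, 2, 20).
old_activations = rng.random((1, 2, 20))
# After the final Permute((2, 1)) the same values are laid out as (1, 20, 2).
new_activations = np.transpose(old_activations, (0, 2, 1))

old_vector = np.mean(old_activations, axis=1).squeeze()  # average over input dims
new_vector = np.mean(new_activations, axis=2).squeeze()  # same average, new axis

print(old_vector.shape, new_vector.shape)
```

Both results have one entry per time step, which is what the attention plot needs.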
Let me know if it looks good to you.
PR: #4
Thanks!