Comments (9)

hamelsmu commented on July 24, 2024

@philipperemy

Also, you don't need this line of code:
https://github.com/philipperemy/keras-attention-mechanism/blob/master/attention_lstm.py#L25

You can pass name='name' to any layer.

from keras-attention.

philipperemy commented on July 24, 2024

@hamelsmu Yes, the Reshape layer is redundant and adds nothing to the model (everything is done by the Permute layer).

It's there mainly to enforce the correct shape. The output shape of the Permute layer is reported as (?, ?), and adding the Reshape layer makes the real shapes explicit (they are static and known at compile time). I wanted to reflect this idea of static shapes (vs. dynamic shapes).

philipperemy commented on July 24, 2024

Thanks for your feedback! Highly appreciated!

a = Dense(TIME_STEPS, activation='softmax', name='attention_vec')(a)
if SINGLE_ATTENTION_VECTOR:
    a = Lambda(lambda x: K.mean(x, axis=1), name='attention_vec')(a)  # this is the attention vector!
    a = RepeatVector(input_dim)(a)

Is this what you meant? Removing the Else clause and adding name='attention_vec' before the If?

hamelsmu commented on July 24, 2024

Yeah, that's right.

philipperemy commented on July 24, 2024

It would not work here, because two different layers would then be defined with the same name:

RuntimeError: The name "attention_vec" is used 2 times in the model. All layer names should be unique.
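
To make the constraint concrete, here is a minimal pure-Python sketch of a uniqueness check over layer names. It is not Keras code: `find_duplicate_names` is a hypothetical helper, and `'repeat_vector'` is a made-up stand-in for whatever name the unnamed RepeatVector layer would receive.

```python
from collections import Counter


def find_duplicate_names(layer_names):
    """Return {name: count} for every name that appears more than once."""
    counts = Counter(layer_names)
    return {name: n for name, n in counts.items() if n > 1}


# Names requested by the snippet above when SINGLE_ATTENTION_VECTOR is True:
# both the Dense layer and the Lambda layer ask for 'attention_vec'.
# 'repeat_vector' is a hypothetical placeholder for the unnamed layer.
names = ['attention_vec', 'attention_vec', 'repeat_vector']

duplicates = find_duplicate_names(names)
for name, n in duplicates.items():
    print('The name "{}" is used {} times.'.format(name, n))
```

Giving each layer its own name, or naming only one of them, avoids the clash.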

hamelsmu commented on July 24, 2024

@philipperemy Right. However, I suppose you could say that the attention layer is a_probs, because that is the layer that is multiplied by the inputs. So you can refactor it to look like this:

    a = Dense(TIME_STEPS, activation='softmax')(a)
    if SINGLE_ATTENTION_VECTOR:
        a = Lambda(lambda x: K.mean(x, axis=1), name='dim_reduction')(a) 
        a = RepeatVector(input_dim, name='time_repeat')(a)

    a_probs = Permute((2, 1), name='attention_vec')(a)
    output_attention_mul = merge([inputs, a_probs], name='attention_mul', mode='mul')
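
For readers who want to follow the shapes, the block above can be mirrored in plain NumPy. This is only a sketch under assumptions: the input is taken to be permuted to (batch, input_dim, time_steps) before the Dense layer (that step is not shown in this thread), and the weights `W` and `b` are random stand-ins for the trained Dense parameters. `attention_mul` and `softmax` are illustrative helpers, not part of the repository.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def attention_mul(inputs, W, b, single_attention_vector=True):
    """inputs: (batch, time_steps, input_dim); W: (time_steps, time_steps)."""
    input_dim = inputs.shape[2]
    # Assumed initial Permute((2, 1)): (batch, input_dim, time_steps).
    a = np.transpose(inputs, (0, 2, 1))
    # Dense(TIME_STEPS, activation='softmax') acts on the last axis.
    a = softmax(a @ W + b, axis=-1)
    if single_attention_vector:
        # 'dim_reduction': average over input dims -> (batch, time_steps).
        a = a.mean(axis=1)
        # 'time_repeat': copy the vector once per input dim.
        a = np.repeat(a[:, None, :], input_dim, axis=1)
    # 'attention_vec': back to (batch, time_steps, input_dim).
    a_probs = np.transpose(a, (0, 2, 1))
    # 'attention_mul': element-wise product with the inputs.
    return a_probs, inputs * a_probs


TIME_STEPS, INPUT_DIM = 20, 2
rng = np.random.default_rng(0)
inputs = rng.normal(size=(1, TIME_STEPS, INPUT_DIM))
W = rng.normal(size=(TIME_STEPS, TIME_STEPS))  # stand-in weights
b = np.zeros(TIME_STEPS)                       # stand-in bias

a_probs, output_attention_mul = attention_mul(inputs, W, b)
print(a_probs.shape, output_attention_mul.shape)
```

With a single attention vector, every input dim receives the same weights, and the weights sum to 1 along the time axis.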

philipperemy commented on July 24, 2024

OK, looks good to me! The only thing is that attention_vec.shape will change from (1, 2, 20) to (1, 20, 2), where 20 is the number of time steps and 2 the number of input dims. So we have to change the axis on which we aggregate from 1 to 2, simply because we want to display the vector along the time axis.

attention_vector = np.mean(
    get_activations(
        m,
        testing_inputs_1,
        print_shape_only=True,
        layer_name='attention_vec')[0],
    axis=2).squeeze()
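
The axis change can be checked in isolation with NumPy. In this sketch the activations are random stand-ins with the shapes quoted above, not real model output: averaging a (1, 2, 20) array over axis 1 should give the same time-axis vector as averaging its permuted (1, 20, 2) version over axis 2.

```python
import numpy as np

TIME_STEPS, INPUT_DIM = 20, 2
rng = np.random.default_rng(0)

# Stand-in for the old activations: (1, INPUT_DIM, TIME_STEPS).
old_activations = rng.random((1, INPUT_DIM, TIME_STEPS))
# Once the Permute layer is the named one, the last two axes are swapped.
new_activations = np.transpose(old_activations, (0, 2, 1))

# Old code aggregated over axis 1, new code aggregates over axis 2.
old_vector = np.mean(old_activations, axis=1).squeeze()
new_vector = np.mean(new_activations, axis=2).squeeze()

print(old_vector.shape, new_vector.shape)
print(np.allclose(old_vector, new_vector))
```

Both vectors have one entry per time step, so the plot of the attention vector is unaffected by the refactoring.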

philipperemy commented on July 24, 2024

Let me know if it looks good to you:

PR: #4

hamelsmu commented on July 24, 2024

Thanks!
