
Comments (4)

realzza commented on June 9, 2024

Hello, in the file transformer-multibranch-v2, the TransformerEncoderLayer class contains the following code:

if args.encoder_branch_type is None:  # default=None????
    self.self_attn = MultiheadAttention(
        self.embed_dim, args.encoder_attention_heads,
        dropout=args.attention_dropout, self_attention=True,
    )
else:
    layers = []
    embed_dims = []
    heads = []
    num_types = len(args.encoder_branch_type)

I just wonder: does args.encoder_branch_type equal True?

Hi, args.encoder_branch_type is a list containing the encoder branch types defined in your training YAML file.
In my case, I set encoder_branch_type in the training YAML as encoder-branch-type: [attn:1:32:4, dynamic:default:32:4], where 32 is the embedding dimension and 4 is the number of attention heads.
Hope this helps!
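To make the format of each list entry concrete, here is a small sketch of how a branch-type string such as "attn:1:32:4" decomposes into its four colon-separated fields. The function name and the dict keys are my own labels for illustration, not identifiers from the lite-transformer codebase:

```python
def parse_branch_type(entry):
    """Split a branch-type string like 'attn:1:32:4' into its four fields."""
    layer_type, kernel, embed_dim, num_heads = entry.split(':')
    return {
        'layer_type': layer_type,      # 'attn', 'dynamic', or 'lightweight'
        'kernel_size': kernel,         # an int, or 'default' to use the per-layer kernel list
        'embed_dim': int(embed_dim),   # embedding dimension of this branch
        'num_heads': int(num_heads),   # number of attention/conv heads
    }

branches = [parse_branch_type(e) for e in ['attn:1:32:4', 'dynamic:default:32:4']]
```

So the answer to the original question: the value is never a boolean, it is a list of such strings (or None to fall back to plain multi-head attention).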

from lite-transformer.

sanwei111 commented on June 9, 2024


Thanks! What's the meaning of [attn:1:32:4, dynamic:default:32:4]? Could you show some details about the list?


realzza commented on June 9, 2024


As I mentioned in my last reply, args.encoder_branch_type is not a boolean value; it is a list recording the branch types of your encoder. As for 32 and 4, they are the embed_dim and num_heads parameters used when initializing the MultiheadAttention and DynamicConv modules.

encoder-branch-type: [attn:1:248:4, dynamic:default:248:4]

You can find more details on these two parameters in the get_layer method of the TransformerEncoderLayer module:
def get_layer(self, args, index, out_dim, num_heads, layer_type):
    kernel_size = layer_type.split(':')[1]
    if kernel_size == 'default':
        kernel_size = args.encoder_kernel_size_list[index]
    else:
        kernel_size = int(kernel_size)
    padding_l = kernel_size // 2 if kernel_size % 2 == 1 else ((kernel_size - 1) // 2, kernel_size // 2)
    if 'lightweight' in layer_type:
        layer = LightweightConv(
            out_dim, kernel_size, padding_l=padding_l, weight_softmax=args.weight_softmax,
            num_heads=num_heads, weight_dropout=args.weight_dropout,
            with_linear=args.conv_linear,
        )
    elif 'dynamic' in layer_type:
        layer = DynamicConv(
            out_dim, kernel_size, padding_l=padding_l,
            weight_softmax=args.weight_softmax, num_heads=num_heads,
            weight_dropout=args.weight_dropout, with_linear=args.conv_linear,
            glu=args.encoder_glu,
        )
    elif 'attn' in layer_type:
        layer = MultiheadAttention(
            out_dim, num_heads,
            dropout=args.attention_dropout, self_attention=True,
        )
    else:
        raise NotImplementedError
    return layer
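The kernel_size/padding_l logic at the top of get_layer is easy to misread, so here it is isolated as a standalone sketch (the function name resolve_padding is mine, not the project's). Odd kernels pad symmetrically with a single integer; even kernels return a left/right pair that pads one element less on the left:

```python
def resolve_padding(kernel_size):
    """Mirror the padding_l expression from get_layer for a given kernel size."""
    if kernel_size % 2 == 1:
        # odd kernel: symmetric padding, e.g. kernel 3 -> pad 1 on each side
        return kernel_size // 2
    # even kernel: asymmetric (left, right) pair, e.g. kernel 4 -> (1, 2)
    return ((kernel_size - 1) // 2, kernel_size // 2)
```

Note that a kernel spec of 'default' means get_layer looks the size up in args.encoder_kernel_size_list by layer index before this computation runs.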

You can find more details about the MultiheadAttention module in:
class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
                 encoder_decoder_attention=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim


sanwei111 commented on June 9, 2024


Thanks a lot!
One more question. As shown below,

for layer_type in args.encoder_branch_type:
    embed_dims.append(int(layer_type.split(':')[2]))
    heads.append(int(layer_type.split(':')[3]))
    layers.append(self.get_layer(args, index, embed_dims[-1], heads[-1], layer_type))
self.self_attn = MultiBranch(layers, embed_dims)

The above code appears in the encoder layer class. As you said, I set args.encoder_branch_type to [attn:1:160:4, lightweight:default:160:4], but it leads to some errors. How should I understand this?
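One common cause of errors with a config like this (an assumption on my part; the thread does not confirm the actual traceback): since MultiBranch receives the per-branch embed_dims, the branches appear to split the model embedding across themselves, so the branch dimensions would need to sum to the encoder embedding dimension. Under that assumption, two 160-dim branches require encoder_embed_dim == 320. A hypothetical sanity check:

```python
def check_branch_dims(branch_types, encoder_embed_dim):
    """Hypothetical check: per-branch embed dims must sum to the model embed dim,
    assuming MultiBranch splits the encoder embedding across its branches."""
    embed_dims = [int(t.split(':')[2]) for t in branch_types]
    total = sum(embed_dims)
    assert total == encoder_embed_dim, (
        f"branch dims {embed_dims} sum to {total}, expected {encoder_embed_dim}"
    )
    return embed_dims

# e.g. two 160-dim branches only fit a 320-dim encoder embedding
check_branch_dims(['attn:1:160:4', 'lightweight:default:160:4'], 320)
```

If your encoder_embed_dim is something else (e.g. 256), that mismatch alone would explain a shape error when the branches are concatenated.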


