
Comments (6)

CyberZHG commented on August 21, 2024

#7 Sentence Embedding

GlobalMaxPool1D doesn't support masking. The following modification handles masked inputs:

import keras
from keras import backend as K

class MaskedGlobalMaxPool1D(keras.layers.Layer):

    def __init__(self, **kwargs):
        super(MaskedGlobalMaxPool1D, self).__init__(**kwargs)
        self.supports_masking = True

    def compute_mask(self, inputs, mask=None):
        # The output is a single vector per sample, so no mask is passed on.
        return None

    def compute_output_shape(self, input_shape):
        return input_shape[:-2] + (input_shape[-1],)

    def call(self, inputs, mask=None):
        if mask is not None:
            # Push masked (padding) positions far below any real activation
            # so they can never win the max.
            mask = K.cast(mask, K.floatx())
            inputs -= K.expand_dims((1.0 - mask) * 1e6, axis=-1)
        return K.max(inputs, axis=-2)
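
Not part of the original snippet, but a quick toy check (using an Embedding layer with mask_zero=True as an assumed stand-in for the BERT output) that padded positions don't affect the pooled maximum:

import numpy as np

inp = keras.layers.Input(shape=(10,))
emb = keras.layers.Embedding(input_dim=100, output_dim=8, mask_zero=True)(inp)
out = MaskedGlobalMaxPool1D(name='Pooling')(emb)
toy = keras.models.Model(inputs=inp, outputs=out)

x = np.array([[5, 7, 3, 0, 0, 0, 0, 0, 0, 0]])  # zeros are masked padding
print(toy.predict(x).shape)  # (1, 8): one pooled vector per sample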

I've added a demo for sentence embedding with pooling:

from keras_bert import load_trained_model_from_checkpoint
import codecs
import numpy as np

model = load_trained_model_from_checkpoint(config_path, checkpoint_path)
pool_layer = MaskedGlobalMaxPool1D(name='Pooling')(model.output)
model = keras.models.Model(inputs=model.inputs, outputs=pool_layer)
model.summary(line_length=120)

tokens = ['[CLS]', '语', '言', '模', '型', '[SEP]']

# Build the token dictionary from the pretrained vocabulary file.
token_dict = {}
with codecs.open(dict_path, 'r', 'utf8') as reader:
    for line in reader:
        token = line.strip()
        token_dict[token] = len(token_dict)

token_input = np.asarray([[token_dict[token] for token in tokens] + [0] * (512 - len(tokens))])
seg_input = np.asarray([[0] * len(tokens) + [0] * (512 - len(tokens))])

print('Inputs:', token_input[0][:len(tokens)])
predicts = model.predict([token_input, seg_input])[0]
print('Pooled:', predicts.tolist()[:5])
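
As a quick sanity check of the pooled vectors (my own addition, not part of the demo), a small helper that encodes a sentence and compares two sentences with cosine similarity; the second sentence and the character-level tokenization are illustrative assumptions:

def encode(sentence, seq_len=512):
    # Character-level tokens, matching the Chinese vocabulary used above.
    tokens = ['[CLS]'] + list(sentence) + ['[SEP]']
    token_ids = [token_dict.get(t, token_dict['[UNK]']) for t in tokens]
    token_ids += [0] * (seq_len - len(token_ids))
    seg_ids = [0] * seq_len
    return model.predict([np.asarray([token_ids]), np.asarray([seg_ids])])[0]

vec_a = encode('语言模型')
vec_b = encode('深度学习')
cos = np.dot(vec_a, vec_b) / (np.linalg.norm(vec_a) * np.linalg.norm(vec_b))
print('Cosine similarity:', cos)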


CyberZHG commented on August 21, 2024

I forgot to return a None mask in MaskedGlobalMaxPool1D. I've fixed it and made a release.


BerenLuthien commented on August 21, 2024

It looks like the author provided a demo:

inputs, output_layer = get_model(
    ...,             # model hyperparameters
    training=False,  # the input layers and output layer are returned if training is False
)
# output_layer is the last feature-extraction layer (the last transformer)

Then any classifier can be added on top of this output_layer (which holds the embeddings), such as an LSTM or logistic regression. Make sure training=False.
You may have to rebuild the token dictionary, since your dataset may not be exactly like MRPC.
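
A minimal sketch of how this could be wired up, reusing MaskedGlobalMaxPool1D from above; the hyperparameter values below are placeholders, not values from this thread:

from keras_bert import get_model

inputs, output_layer = get_model(
    token_num=30000,        # placeholder: size of your token dictionary
    head_num=12,
    transformer_num=12,
    embed_dim=768,
    feed_forward_dim=3072,
    seq_len=512,
    pos_num=512,
    training=False,         # return (inputs, last transformer output)
)

pooled = MaskedGlobalMaxPool1D(name='Pooling')(output_layer)
outputs = keras.layers.Dense(units=1, activation='sigmoid')(pooled)  # logistic-regression head
clf = keras.models.Model(inputs=inputs, outputs=outputs)
clf.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc'])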


njordsir2 commented on August 21, 2024

@CyberZHG Oh sweet! Will check this out. Going through the BERT paper cleared up my masking questions.
In the meantime I got the job done with the official tensorflow-hub module.


BerenLuthien commented on August 21, 2024

Thanks. MaskedGlobalMaxPool1D itself works well in the demo you gave, but it doesn't seem to work when a classification layer (such as Dense) is added on top of it.
This code gives an error:

from keras.layers import Dense
from keras.models import Model
from keras_bert import load_trained_model_from_checkpoint

model = load_trained_model_from_checkpoint('bert_config.json', 'bert_model.ckpt')

def get_custermized_model(model):
    pool_layer = MaskedGlobalMaxPool1D(name='Pooling')(model.output)
    x = pool_layer
    x = Dense(units=1, activation='sigmoid')(x)
    print(model.inputs[0])
    custermized_model = Model(inputs=model.inputs, outputs=x)
    custermized_model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc'])
    return custermized_model

custermized_model = get_custermized_model(model=model)
history = custermized_model.fit(x=X_train, y=train_labels, epochs=1, validation_split=0.3)

InvalidArgumentError: Incompatible shapes: [32] vs. [32,512]
[[{{node metrics_4/acc/mul}} = Mul[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"](metrics_4/acc/Mean, metrics_4/acc/Cast_1)]]
[[{{node metrics_4/acc/Mean_2/_1745}} = _Recvclient_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_4979_metrics_4/acc/Mean_2", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]]

Would you help by:

  • adding a Dense layer on top of MaskedGlobalMaxPool1D and checking how it works? Thanks

PS:
I noticed that the token_input in your demo has shape (1, 512), and the above error happens when I feed a batch of data with shape (64, 512), where 64 is the batch size.
However, if I add a dimension and feed (64, 1, 512) to the model, it complains about input shape errors.
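
For comparison, the input format the demo above uses is a list of two arrays (token ids and segment ids), each of shape (batch_size, 512), with one label per example; a sketch with dummy arrays:

batch_size, seq_len = 64, 512
token_ids = np.zeros((batch_size, seq_len), dtype='int32')    # padded token ids
segment_ids = np.zeros((batch_size, seq_len), dtype='int32')  # zeros for single-sentence input
labels = np.zeros((batch_size,), dtype='float32')             # one label per example

history = custermized_model.fit(x=[token_ids, segment_ids], y=labels,
                                epochs=1, validation_split=0.3)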


weizhenzhao commented on August 21, 2024

@CyberZHG
Hi CyberZHG,

I notice two ways of loading the model in the code above:

(1) model = load_trained_model_from_checkpoint(config_path, checkpoint_path)

(2) model = load_trained_model_from_checkpoint(config_path, checkpoint_path, training=True, seq_len=seq_len)

If I build a classifier with a BiLSTM on top of that, does that mean fine-tuning?

Thanks,
weizhen
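
PS: a minimal sketch of what I mean by option (1) with a BiLSTM head; the layer sizes are placeholders:

bert = load_trained_model_from_checkpoint(config_path, checkpoint_path)  # option (1)
x = keras.layers.Bidirectional(keras.layers.LSTM(units=64))(bert.output)
output = keras.layers.Dense(units=1, activation='sigmoid')(x)
clf = keras.models.Model(inputs=bert.inputs, outputs=output)
# Whether this counts as fine-tuning presumably depends on whether the loaded
# BERT layers are trainable, or only this new head.
clf.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc'])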

