Comments (5)
I'll close this issue for now. Thank you guys for your help!
from keras-attention-mechanism.
I've met the same problem when using the .h5 format.
Traceback (most recent call last):
  File "/home/cedar/kzheng/MAP569_AMF_Challenge/resume_training.py", line 97, in <module>
    model = keras.models.load_model(MODEL_PATH)
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py", line 182, in load_model
    return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/saving/hdf5_format.py", line 177, in load_model_from_hdf5
    model = model_config_lib.model_from_config(model_config,
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/saving/model_config.py", line 55, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/layers/serialization.py", line 171, in deserialize
    return generic_utils.deserialize_keras_object(
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 354, in deserialize_keras_object
    return cls.from_config(
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/engine/sequential.py", line 488, in from_config
    model.add(layer)
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/training/tracking/base.py", line 457, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/engine/sequential.py", line 221, in add
    output_tensor = layer(self.outputs[0])
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 925, in __call__
    return self._functional_construction_call(inputs, args, kwargs,
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1098, in _functional_construction_call
    self._maybe_build(inputs)
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 2643, in _maybe_build
    self.build(input_shapes) # pylint:disable=not-callable
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/utils/tf_utils.py", line 323, in wrapper
    output_shape = fn(instance, input_shape)
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/layers/merge.py", line 663, in build
    raise ValueError('A `Dot` layer should be called '
When I switched to the default SavedModel format to load the model, this is the error message:
Traceback (most recent call last):
  File "/home/cedar/kzheng/MAP569_AMF_Challenge/resume_training.py", line 97, in <module>
    model = keras.models.load_model(MODEL_PATH)
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py", line 187, in load_model
    return saved_model_load.load(filepath, compile, options)
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 120, in load
    model = tf_load.load_internal(
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py", line 632, in load_internal
    loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 194, in __init__
    super(KerasObjectLoader, self).__init__(*args, **kwargs)
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py", line 130, in __init__
    self._load_all()
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 215, in _load_all
    self._layer_nodes = self._load_layers()
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 318, in _load_layers
    layers[node_id] = self._load_layer(proto.user_object, node_id)
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 343, in _load_layer
    obj, setter = revive_custom_object(proto.identifier, metadata)
  File "/home/cedar/kzheng/anaconda3/envs/gt/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 842, in revive_custom_object
    raise ValueError('Unable to restore custom object of type {} currently. '
ValueError: Unable to restore custom object of type _tf_keras_metric currently. Please make sure that the layer implements `get_config`and `from_config` when saving. In addition, please use the `custom_objects` arg when calling `load_model()`
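For the HDF5 path, the error message itself points at the usual fix: implement get_config on the custom layer and pass custom_objects to load_model(). Here is a minimal, self-contained sketch of that mechanism; the Scale layer below is an illustrative stand-in, not this repo's Attention layer, and the file name is made up:

```python
import numpy as np
import tensorflow as tf

# Stand-in custom layer (illustrative). It implements get_config so the
# serialized config can be rebuilt on load.
class Scale(tf.keras.layers.Layer):
    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        config = super().get_config()
        config.update({'factor': self.factor})
        return config

inputs = tf.keras.Input(shape=(3,))
outputs = Scale(factor=3.0)(inputs)
model = tf.keras.Model(inputs, outputs)
model.save('scale_model.h5')  # HDF5 format

# The key fix: tell load_model how to resolve the custom class name.
restored = tf.keras.models.load_model(
    'scale_model.h5', custom_objects={'Scale': Scale})
pred = restored.predict(np.ones((1, 3)))  # each output is 1 * 3.0
```

Without the custom_objects argument, the load fails with an "Unknown layer" error because Keras cannot map the class name stored in the file back to Python code.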
Yes, use the functional API and it will work. There are some problems with Sequential.
# Functional API version; assumes `pip install attention`
# (this repo's package) for the Attention layer.
import numpy as np
from tensorflow.keras.layers import Input, LSTM, Dropout, Dense
from tensorflow.keras.models import Model, load_model
from attention import Attention

seq_length = 10                           # illustrative values
x_val = np.random.rand(8, seq_length, 1)

model_input = Input(shape=(seq_length, 1))
x = LSTM(100, return_sequences=True)(model_input)
x = Attention(name='attention_weight')(x)
x = Dropout(0.2)(x)
x = Dense(1, activation='linear')(x)
model = Model(model_input, x)
model.compile(loss='mse', optimizer='adam')

# test save/reload model; the custom layer must be registered
# via custom_objects when loading from HDF5.
pred1 = model.predict(x_val)
model.save('test_model.h5')
model_h5 = load_model('test_model.h5',
                      custom_objects={'Attention': Attention})
pred2 = model_h5.predict(x_val)
np.testing.assert_almost_equal(pred1, pred2)
A workaround when using the Sequential model: instead of saving the whole model, call save_weights on the model with the .h5 format, then rebuild the model and call load_weights. This works for me. Otherwise, the functional API works as well.
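A minimal sketch of that weights-only workaround. To keep the snippet self-contained, standard layers stand in for this repo's Attention layer; the same pattern applies with the real layer, because only the weights, not the architecture, are serialized. All names and sizes here are illustrative:

```python
import numpy as np
import tensorflow as tf

seq_length = 10  # illustrative

def build_model():
    # The architecture lives in code; rebuild it identically each time,
    # since save_weights persists only the weight tensors.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(seq_length, 1)),
        tf.keras.layers.LSTM(100),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(1, activation='linear'),
    ])
    model.compile(loss='mse', optimizer='adam')
    return model

x_val = np.random.rand(4, seq_length, 1)

model = build_model()
pred1 = model.predict(x_val)
model.save_weights('seq.weights.h5')   # weights only, HDF5 file

model2 = build_model()                 # rebuild, then restore weights
model2.load_weights('seq.weights.h5')
pred2 = model2.predict(x_val)
np.testing.assert_almost_equal(pred1, pred2)  # identical predictions
```

Because no layer configs are serialized, this sidesteps the get_config/from_config requirements entirely; the trade-off is that you must keep the model-building code in sync with the saved weights.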