Comments (5)
Estimators have a predict method that you can use to produce the embedding. The relevant part in model_fn is here.
from tensorflow-triplet-loss.
I tried to predict like this:
estimator = tf.estimator.Estimator(model_fn, params=params, model_dir=args.model_dir)
img = cv2.imread(img_path)
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = img.astype(np.float32)
input_fn = tf.estimator.inputs.numpy_input_fn(
    x=img,
    y=None,
    batch_size=1,
    num_epochs=1,
    shuffle=False,
    num_threads=1)
emb = estimator.predict(input_fn)
next(emb)
But I get the following error when I call next(emb):
InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [3,3,3,32] rhs shape= [3,3,1,32]
[[Node: save/Assign_2 = Assign[T=DT_FLOAT, _class=["loc:@model/block_1/conv2d/kernel"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](model/block_1/conv2d/kernel, save/RestoreV2/_11)]]
The img you pass needs to be a batch of images with shape (N, 224, 224, 3), I think:
img = np.expand_dims(img, 0)
input_fn = ...
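A minimal sketch of what that batch dimension looks like, using a zero-filled array as a stand-in for the resized image (the placeholder data is an assumption, not the original picture):

```python
import numpy as np

# Hypothetical stand-in for the resized RGB image from the snippet above.
img = np.zeros((224, 224, 3), dtype=np.float32)

# Add a leading batch axis so the array has shape (N, H, W, C) with N = 1.
batch = np.expand_dims(img, 0)

print(img.shape)    # (224, 224, 3)
print(batch.shape)  # (1, 224, 224, 3)
```

Without the extra axis, the first dimension of the array (the 224 image rows) would be read as the sample axis.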
I thought about that but the error is the same. I am confused by the suggested dimensions (3, 3, 3, 32). What do they mean?
Here is the full trace:
Caused by op u'save/Assign_2', defined at:
File "test_network.py", line 49, in <module>
next(emb)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 499, in predict
hooks=all_hooks) as mon_sess:
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 795, in __init__
stop_grace_period_secs=stop_grace_period_secs)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 518, in __init__
self._sess = _RecoverableSession(self._coordinated_creator)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 981, in __init__
_WrappedSession.__init__(self, self._create_session())
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 986, in _create_session
return self._sess_creator.create_session()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 675, in create_session
self.tf_sess = self._session_creator.create_session()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 437, in create_session
self._scaffold.finalize()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 212, in finalize
self._saver = training_saver._get_saver_or_default() # pylint: disable=protected-access
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 884, in _get_saver_or_default
saver = Saver(sharded=True, allow_empty=True)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1311, in __init__
self.build()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1320, in build
self._build(self._filename, build_save=True, build_restore=True)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1357, in _build
build_save=build_save, build_restore=build_restore)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 803, in _build_internal
restore_sequentially, reshape)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 501, in _AddShardedRestoreOps
name="restore_shard"))
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 470, in _AddRestoreOps
assign_ops.append(saveable.restore(saveable_tensors, shapes))
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 162, in restore
self.op.get_shape().is_fully_defined())
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/state_ops.py", line 281, in assign
validate_shape=validate_shape)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_state_ops.py", line 61, in assign
use_locking=use_locking, name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 3290, in create_op
op_def=op_def)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1654, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [3,3,3,32] rhs shape= [3,3,1,32]
[[Node: save/Assign_2 = Assign[T=DT_FLOAT, _class=["loc:@model/block_1/conv2d/kernel"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](model/block_1/conv2d/kernel, save/RestoreV2/_11)]]
I accidentally passed the wrong model folder. Sorry.
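For anyone who hits the same message: the two shapes in the error appear to be convolution kernel shapes. The sketch below assumes TensorFlow's usual conv2d kernel layout of [kernel_height, kernel_width, in_channels, out_channels], and assumes that "lhs" is the variable in the graph being built and "rhs" is the value restored from the checkpoint. The helper name compare_kernel_shapes is hypothetical and not part of the repository.

```python
def compare_kernel_shapes(graph_shape, checkpoint_shape):
    """Return (axis_name, graph_value, checkpoint_value) for each differing axis.

    Assumes both shapes are 4-D conv kernels laid out as
    [kernel_height, kernel_width, in_channels, out_channels].
    """
    axis_names = ("kernel_height", "kernel_width", "in_channels", "out_channels")
    if len(graph_shape) != 4 or len(checkpoint_shape) != 4:
        raise ValueError("expected two 4-D conv kernel shapes")
    return [
        (name, g, c)
        for name, g, c in zip(axis_names, graph_shape, checkpoint_shape)
        if g != c
    ]


# Shapes copied from the error message: lhs (graph) and rhs (checkpoint).
mismatches = compare_kernel_shapes([3, 3, 3, 32], [3, 3, 1, 32])
for name, g, c in mismatches:
    print("{}: graph expects {}, checkpoint has {}".format(name, g, c))
```

Under those assumptions, only the input-channel axis differs: the graph was built for 3-channel (RGB) input, while the checkpoint in the model folder was saved from a model with 1-channel input. That is consistent with pointing model_dir at a different model.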