Comments (5)

omoindrot commented on May 27, 2024

Estimators have a predict method that you can use to produce the embedding.

The relevant part in model_fn is here.
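Here is a rough sketch of how to consume that output. It assumes the PREDICT branch of `model_fn` returns a dict with an `embeddings` key; that key name is an assumption, so check your own `model_fn`. Under that assumption, `predict` yields one dict per example. The hypothetical helper `collect_embeddings` stacks those dicts into a single array. Fake prediction dicts stand in for a real Estimator, so the sketch runs without TensorFlow.

```python
import numpy as np

def collect_embeddings(predictions, key="embeddings"):
    """Stack per-example prediction dicts into one (N, D) array.

    `predictions` is any iterable of dicts, e.g. the generator returned by
    Estimator.predict(); `key` is assumed to match the name used in
    model_fn's PREDICT branch (hypothetical here).
    """
    return np.stack([np.asarray(p[key]) for p in predictions])

# Stand-in predictions: two examples with 64-dimensional embeddings
fake_predictions = ({"embeddings": np.full(64, i, dtype=np.float32)} for i in range(2))
embeddings = collect_embeddings(fake_predictions)
print(embeddings.shape)  # (2, 64)
```

With a real estimator, the call would be `collect_embeddings(estimator.predict(input_fn))`. Keep in mind that `predict` is a generator, so it only produces results once iterated.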

from tensorflow-triplet-loss.

FSet89 commented on May 27, 2024

I tried to predict like this:

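    import cv2
    import numpy as np
    import tensorflow as tf

    # model_fn, params, args and img_path are defined elsewhere in the script (not shown)
    estimator = tf.estimator.Estimator(model_fn, params=params, model_dir=args.model_dir)

    # Load the image, resize it to 224x224 and convert OpenCV's BGR channel order to RGB
    img = cv2.imread(img_path)
    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32)

    input_fn = tf.estimator.inputs.numpy_input_fn(
        x=img,
        y=None,
        batch_size=1,
        num_epochs=1,
        shuffle=False,
        num_threads=1)

    # predict() returns a generator; the session (and checkpoint restore) starts on the first next()
    emb = estimator.predict(input_fn)
    next(emb)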

But I get the following error when I call next(emb):

InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [3,3,3,32] rhs shape= [3,3,1,32]
[[Node: save/Assign_2 = Assign[T=DT_FLOAT, _class=["loc:@model/block_1/conv2d/kernel"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](model/block_1/conv2d/kernel, save/RestoreV2/_11)]]

omoindrot commented on May 27, 2024

I think the img you pass needs to be a batch of images with shape (N, 224, 224, 3):

img = np.expand_dims(img, 0)

input_fn = ...
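The batch axis matters because `numpy_input_fn`, as far as I know, treats the first axis of `x` as the example axis. An unbatched (224, 224, 3) image would therefore be split into batches of pixel rows rather than whole images. The sketch below mimics a `batch_size=1` slice with plain NumPy indexing and uses a zero array as a stand-in for the `cv2` image.

```python
import numpy as np

# Stand-in for the preprocessed cv2 image: height x width x RGB channels
img = np.zeros((224, 224, 3), dtype=np.float32)

# Without a batch axis, one "example" along axis 0 is a single row of pixels
print(img[0:1].shape)    # (1, 224, 3)

# Adding a leading batch axis makes each example a full image
batch = np.expand_dims(img, 0)
print(batch.shape)       # (1, 224, 224, 3)
print(batch[0:1].shape)  # (1, 224, 224, 3)

# Several images can be batched at once with np.stack
batch3 = np.stack([img, img, img])
print(batch3.shape)      # (3, 224, 224, 3)
```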

FSet89 commented on May 27, 2024

I thought about that, but the error is the same. I am confused by the dimensions in the error message, (3, 3, 3, 32). What do they mean?

Here is the full trace:

Caused by op u'save/Assign_2', defined at:
  File "test_network.py", line 49, in <module>
    next(emb)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 499, in predict
    hooks=all_hooks) as mon_sess:
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 795, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 518, in __init__
    self._sess = _RecoverableSession(self._coordinated_creator)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 981, in __init__
    _WrappedSession.__init__(self, self._create_session())
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 986, in _create_session
    return self._sess_creator.create_session()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 675, in create_session
    self.tf_sess = self._session_creator.create_session()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 437, in create_session
    self._scaffold.finalize()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 212, in finalize
    self._saver = training_saver._get_saver_or_default()  # pylint: disable=protected-access
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 884, in _get_saver_or_default
    saver = Saver(sharded=True, allow_empty=True)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1311, in __init__
    self.build()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1320, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1357, in _build
    build_save=build_save, build_restore=build_restore)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 803, in _build_internal
    restore_sequentially, reshape)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 501, in _AddShardedRestoreOps
    name="restore_shard"))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 470, in _AddRestoreOps
    assign_ops.append(saveable.restore(saveable_tensors, shapes))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 162, in restore
    self.op.get_shape().is_fully_defined())
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/state_ops.py", line 281, in assign
    validate_shape=validate_shape)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_state_ops.py", line 61, in assign
    use_locking=use_locking, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 3290, in create_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1654, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [3,3,3,32] rhs shape= [3,3,1,32]
	 [[Node: save/Assign_2 = Assign[T=DT_FLOAT, _class=["loc:@model/block_1/conv2d/kernel"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](model/block_1/conv2d/kernel, save/RestoreV2/_11)]]
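Here is what the dimensions mean. TensorFlow stores a 2-D convolution kernel as [kernel_height, kernel_width, in_channels, out_channels]. In the `Assign` node, the lhs is the variable in the graph being built: a 3x3 convolution with 32 filters over 3-channel (RGB) input. The rhs is the value restored from the checkpoint: the same layer over 1-channel input. The checkpoint in `model_dir` was therefore written by a model trained on single-channel images. The helpers below are hypothetical, plain Python, and only label which axis differs.

```python
# TensorFlow's layout for Conv2D kernels
KERNEL_AXES = ("kernel_height", "kernel_width", "in_channels", "out_channels")

def conv2d_kernel_shape(kernel_size, in_channels, filters):
    """Shape of a square Conv2D kernel in TensorFlow's [h, w, in, out] layout."""
    return (kernel_size, kernel_size, in_channels, filters)

def explain_mismatch(graph_shape, checkpoint_shape):
    """List the kernel axes whose sizes differ between graph and checkpoint."""
    return [
        f"{axis}: graph={g}, checkpoint={c}"
        for axis, g, c in zip(KERNEL_AXES, graph_shape, checkpoint_shape)
        if g != c
    ]

# Shapes from the error message: lhs = graph variable, rhs = checkpoint value
lhs = conv2d_kernel_shape(3, in_channels=3, filters=32)  # RGB input
rhs = conv2d_kernel_shape(3, in_channels=1, filters=32)  # single-channel input
print(lhs, rhs)                    # (3, 3, 3, 32) (3, 3, 1, 32)
print(explain_mismatch(lhs, rhs))  # ['in_channels: graph=3, checkpoint=1']
```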

FSet89 commented on May 27, 2024

I accidentally passed the wrong model folder. Sorry!
