
Comments (8)

FSet89 commented on May 27, 2024

Just in case someone needs it, I solved it by replacing
images = features
with
images = features['input']
in model_fn.py


omoindrot commented on May 27, 2024

I'm not sure what the issue is; I don't think it is related to this project in particular.

You should ask on Stack Overflow.


batrlatom commented on May 27, 2024

@FSetragno Could you post some more details and the exact code snippets of your solution here?
I am trying to export it too, but without any success. I changed the build_model code in model_fn to use MobileNet as the CNN:

def build_model(is_training, images, params):
    import tensorflow_hub as hub
    module = hub.Module("https://tfhub.dev/google/imagenet/mobilenet_v2_035_224/feature_vector/2")
    tf_model = module(images)
    with tf.variable_scope('fc_1'):
        tf_model = tf.layers.dense(tf_model, params.embedding_size)
    return tf_model

And I tried to make my own input_fn functions:

def serving_input_receiver_fn():
    feature_spec = {}
    feature_spec['features'] = tf.placeholder(tf.float32, shape=[224, 224, 3], name='features')
    serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(feature_spec)
    return serving_input_fn


FSet89 commented on May 27, 2024

This is the code I use to export the model. Please note that you have to modify model_fn.py as mentioned above.

estimator = tf.estimator.Estimator(model_fn, params=params, model_dir=args.model_dir)
features = {'input': tf.placeholder(tf.float32, shape=(1, 224, 224, 3), name="input")}
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(features, 1)
exported_model_path = estimator.export_savedmodel(args.model_dir, input_fn)
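
For reference, a minimal sketch of how the exported SavedModel could then be queried (assuming TF 1.x, the default 'serving_default' signature, and the 'input' / 'embeddings' names used above):

import numpy as np
import tensorflow as tf

# export_savedmodel may return a bytes path in TF 1.x
export_dir = exported_model_path.decode() if isinstance(exported_model_path, bytes) else exported_model_path

# Load the SavedModel and run a single dummy image through it
predict_fn = tf.contrib.predictor.from_saved_model(export_dir)
dummy_image = np.zeros((1, 224, 224, 3), dtype=np.float32)
outputs = predict_fn({'input': dummy_image})
print(outputs['embeddings'].shape)  # (1, embedding_size)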


batrlatom commented on May 27, 2024

I tried to modify the model_fn code by placing this line in it:

images = features['input']

but I am getting an error:

File "/home/tom/Devel/AI/rclvenv/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 489, in _slice_helper
end.append(s + 1)
TypeError: must be str, not int

The code of model_fn itself is:


def model_fn(features, labels, mode, params):
    """Model function for tf.estimator

    Args:
        features: input batch of images
        labels: labels of the images
        mode: can be one of tf.estimator.ModeKeys.{TRAIN, EVAL, PREDICT}
        params: contains hyperparameters of the model (ex: `params.learning_rate`)

    Returns:
        model_spec: tf.estimator.EstimatorSpec object
    """
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    images = features['input']

    images = tf.reshape(images, [-1, params.image_size, params.image_size, 3])
    assert images.shape[1:] == [params.image_size, params.image_size, 3], "{}".format(images.shape)

    # -----------------------------------------------------------
    # MODEL: define the layers of the model
    with tf.variable_scope('model'):
        # Compute the embeddings with the model
        embeddings = build_model(is_training, images, params)
        embeddings = tf.nn.l2_normalize(embeddings, axis=1)
    embedding_mean_norm = tf.reduce_mean(tf.norm(embeddings, axis=1))
    tf.summary.scalar("embedding_mean_norm", embedding_mean_norm)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {'embeddings': embeddings}
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    labels = tf.cast(labels, tf.int64)

    # Define triplet loss
    if params.triplet_strategy == "batch_all":
        loss, fraction = batch_all_triplet_loss(labels, embeddings, margin=params.margin,
                                                squared=params.squared)
    elif params.triplet_strategy == "batch_hard":
        loss = batch_hard_triplet_loss(labels, embeddings, margin=params.margin,
                                       squared=params.squared)
    else:
        raise ValueError("Triplet strategy not recognized: {}".format(params.triplet_strategy))

    # -----------------------------------------------------------
    # METRICS AND SUMMARIES
    # Metrics for evaluation using tf.metrics (average over whole dataset)
    # TODO: some other metrics like rank-1 accuracy?
    with tf.variable_scope("metrics"):
        eval_metric_ops = {"embedding_mean_norm": tf.metrics.mean(embedding_mean_norm)}

        if params.triplet_strategy == "batch_all":
            eval_metric_ops['fraction_positive_triplets'] = tf.metrics.mean(fraction)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=eval_metric_ops)


    # Summaries for training
    tf.summary.scalar('loss', loss)
    if params.triplet_strategy == "batch_all":
        tf.summary.scalar('fraction_positive_triplets', fraction)

    tf.summary.image('train_image', images, max_outputs=1)

    # Define training step that minimizes the loss with the Adam optimizer
    optimizer = tf.train.AdamOptimizer(params.learning_rate)
    global_step = tf.train.get_global_step()
    if params.use_batch_norm:
        # Add a dependency to update the moving mean and variance for batch normalization
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            train_op = optimizer.minimize(loss, global_step=global_step)
    else:
        train_op = optimizer.minimize(loss, global_step=global_step)

    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
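
A minimal sketch that keeps model_fn working both at training time (when features is the raw image batch Tensor from the training input_fn) and at export time (when features is the dict produced by the serving input receiver), assuming the placeholder key 'input' as above:

    # Sketch: accept either a plain image Tensor (train/eval input_fn)
    # or a dict from a raw serving input receiver keyed by 'input'.
    if isinstance(features, dict):
        images = features['input']
    else:
        images = features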




batrlatom commented on May 27, 2024

I also tried to replace
images = features['input']

with

images = features.get('input')

which looks like it works, but introduces a problem with export outputs.

ValueError: export_outputs must be a dict and not<class 'NoneType'>

Did you encounter this problem?


batrlatom commented on May 27, 2024

Could you please also post the full model_fn here? I really need to solve this.


batrlatom commented on May 27, 2024

OK, so I was able to manage it this way. Note that your solution is for Python 2:


    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {'embeddings': embeddings}
        export_outputs = {'embeddings': tf.estimator.export.PredictOutput(predictions['embeddings'])}

        for op in tf.get_default_graph().get_operations():
            print(str(op.name))

        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions, export_outputs=export_outputs)
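
A possible variant (a sketch, assuming TF 1.x) is to key export_outputs with the default serving signature key, so clients do not have to pass an explicit signature name when querying the SavedModel:

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {'embeddings': embeddings}
        # 'serving_default' signature, so clients can omit signature_name
        export_outputs = {
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                tf.estimator.export.PredictOutput(predictions)
        }
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions,
                                          export_outputs=export_outputs)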


