Comments (29)

omoindrot avatar omoindrot commented on May 27, 2024 7

If you are working with images stored in jpg files for instance, you can apply tf.data.experimental.choose_from_datasets only on the filenames and labels (which should be very fast), and then load the images from these filenames.

This would be like:

num_labels = 4000
num_classes_per_batch = 4
num_images_per_class = 8

image_dirs = ["data/class_{:04d}".format(i) for i in range(num_labels)]

# Create the list of datasets creating filenames
datasets = [tf.data.Dataset.list_files("{}/*.jpg".format(image_dir)) for image_dir in image_dirs]

def generator():
    while True:
        # Sample the labels that will compose the batch
        labels = np.random.choice(range(num_labels),
                                  num_classes_per_batch,
                                  replace=False)
        for label in labels:
            for _ in range(num_images_per_class):
                yield label

choice_dataset = tf.data.Dataset.from_generator(generator, tf.int64)
dataset = tf.contrib.data.choose_from_datasets(datasets, choice_dataset)

# Now you read the image content
def load_image(filename):
    ...
    return image, label

dataset = dataset.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)

batch_size = num_classes_per_batch * num_images_per_class
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(None)
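
For reference, here is a minimal sketch of what load_image could look like. It is an assumption, not the code used in the repository: it supposes TensorFlow 2.x, "/"-separated paths, and that the label can be parsed from the class_XXXX directory name used in image_dirs above.

```python
import tensorflow as tf


def load_image(filename):
    """Hypothetical loader: returns (float32 image, int32 label) for one JPEG path."""
    # Assumption: paths look like ".../class_0042/some_name.jpg"
    parts = tf.strings.split(filename, "/")
    class_dir = parts[-2]                         # e.g. "class_0042"
    digits = tf.strings.substr(class_dir, 6, 4)   # skip "class_", keep 4 digits
    label = tf.strings.to_number(digits, out_type=tf.int32)

    # Read and decode the file, then scale pixel values to [0, 1]
    image = tf.io.decode_jpeg(tf.io.read_file(filename), channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    return image, label
```

Since only file names flow through choose_from_datasets, the expensive decoding happens once per selected image inside dataset.map.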

from tensorflow-triplet-loss.

omoindrot avatar omoindrot commented on May 27, 2024 6

My code above is very slow because of the dataset.filter(...) used to build the datasets.

The filter method will go through all examples until it finds one with the correct label, so if you have 1,000 labels, this will be 1,000 times slower.

The solution would be to create the datasets (one per label) in a different way.
For instance, if you have filenames (containing images) and labels, you can create one list of filenames per label:

num_labels = 1000
datasets = []
for label in range(num_labels):
    # Get the filenames for this label
    filenames_per_label = ...
    dataset = tf.data.Dataset.from_tensor_slices((filenames_per_label,
                                                  [label] * len(filenames_per_label)))
    datasets.append(dataset)
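
The filenames_per_label line is left open above. One way to get those lists without any filter is to group the (filename, label) pairs in a single pass, sketched here in plain Python. The helper name group_by_label and the sample file names are made up for illustration.

```python
from collections import defaultdict


def group_by_label(filenames, labels):
    """Group filenames by label in one pass over the data."""
    per_label = defaultdict(list)
    for filename, label in zip(filenames, labels):
        per_label[label].append(filename)
    return dict(per_label)


# Made-up example data
filenames = ["a.jpg", "b.jpg", "c.jpg", "d.jpg"]
labels = [0, 1, 0, 2]
filenames_per_label = group_by_label(filenames, labels)
```

Each list can then be passed to tf.data.Dataset.from_tensor_slices as in the loop above. The grouping reads the data once, whereas one filter per label scans it once for every label.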

By the way, a better way to do what I did before is to use the new tf.contrib.data.choose_from_datasets (or tf.data.experimental.choose_from_datasets since v1.12):

num_labels = 10
num_classes_per_batch = 4
num_images_per_class = 8

# Create the list of datasets as you like
datasets = ...

def generator():
    while True:
        # Sample the labels that will compose the batch
        labels = np.random.choice(range(num_labels),
                                  num_classes_per_batch,
                                  replace=False)
        for label in labels:
            for _ in range(num_images_per_class):
                yield label

choice_dataset = tf.data.Dataset.from_generator(generator, tf.int64)
dataset = tf.contrib.data.choose_from_datasets(datasets, choice_dataset)

batch_size = num_classes_per_batch * num_images_per_class
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(None)
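
To see what the generator feeds into choose_from_datasets, here is a small standalone check in plain Python. It uses random.sample in place of np.random.choice, and the name label_stream is mine. Every run of batch_size labels should contain num_classes_per_batch distinct classes, each repeated num_images_per_class times.

```python
import random
from collections import Counter
from itertools import islice

num_labels = 10
num_classes_per_batch = 4
num_images_per_class = 8


def label_stream():
    """Infinite stream of labels, grouped so that each batch is class-balanced."""
    while True:
        # Sample distinct classes for the next batch
        labels = random.sample(range(num_labels), num_classes_per_batch)
        for label in labels:
            for _ in range(num_images_per_class):
                yield label


batch_size = num_classes_per_batch * num_images_per_class
first_batch = list(islice(label_stream(), batch_size))
counts = Counter(first_batch)  # class -> number of occurrences in the batch
```

The balance only holds when batch boundaries line up with the generator's groups, which is why batch_size is set to exactly num_classes_per_batch * num_images_per_class.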

omoindrot avatar omoindrot commented on May 27, 2024 4

@fursovia : something like that would work

from tensorflow.contrib.data.python.ops.interleave_ops import DirectedInterleaveDataset

import model.mnist_dataset as mnist_dataset


# Define the data pipeline
mnist = mnist_dataset.train(args.data_dir)

datasets = [mnist.filter(lambda img, lab: tf.equal(lab, i)) for i in range(params.num_labels)]

def generator():
    while True:
        # Sample the labels that will compose the batch
        labels = np.random.choice(range(params.num_labels),
                                  params.num_classes_per_batch,
                                  replace=False)
        for label in labels:
            for _ in range(params.num_images_per_class):
                yield label

selector = tf.data.Dataset.from_generator(generator, tf.int64)
dataset = DirectedInterleaveDataset(selector, datasets)

batch_size = params.num_classes_per_batch * params.num_images_per_class
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(1)

You would need the nightly build from tensorflow:

pip install tf-nightly

This will contain the DirectedInterleaveDataset. However it is not in the public interface so we still need to import it directly with from tensorflow.contrib.data.python.ops.interleave_ops import DirectedInterleaveDataset.

paweller avatar paweller commented on May 27, 2024 1

Hello everyone,

first of all thank you for the initial input on balanced batches @omoindrot.

Unfortunately, as I am working with Keras and NumPy input data, I was not able to use omoindrot's solution. So I dug deeper into the topic and found another GitHub repository by @soroushj showing how to implement "A Keras-compatible generator for creating balanced batches". However, it does not feature any num_classes_per_batch and/or num_samples_per_class functionality.

So I took it as a starting point and extended it with the functionality mentioned above. It became a Keras-compatible balanced batch generator suited for triplet loss applications. As it is built using the keras.utils.Sequence object, the generator is multiprocessing-aware and can be shuffled. It was tested on the Omniglot dataset with the Vinyals splits (according to this GitHub repository) and yielded a fairly well-balanced class distribution (standard deviation of six) across all batches used during training (75 epochs with 42 batches per epoch). Further information and the source code can be found here. I am by no means a coding expert, so please do not hesitate to contribute.

Thank you!

fursovia avatar fursovia commented on May 27, 2024

This will be so helpful, thanks! Do you have any approximate timeline for when it could happen?

omoindrot avatar omoindrot commented on May 27, 2024

I'm starting a new job soon so I'm not sure how much time I'll have.
You can maybe try to build a working solution in a fork to see how that works?

TengliEd avatar TengliEd commented on May 27, 2024

Nice test @omoindrot. But I still have no idea how to apply DirectedInterleaveDataset to the raw msceleb1m dataset. I asked and commented on your answer on stackoverflow. By the way, I don't use metric learning but arcface loss.

omoindrot avatar omoindrot commented on May 27, 2024

Hi @TengliEd, if you are using the arcface loss I think you don't need to have these balanced batches. Correct me if I'm wrong but you should be able to train on a normal random batch of data, like with softmax.

vzxxbacq avatar vzxxbacq commented on May 27, 2024

Hello @omoindrot. Have you tested which method is faster: sampling from multiple files, or from a single file with filter?

andropar avatar andropar commented on May 27, 2024

Hi @omoindrot, I implemented your proposed solution, but batch generation is extremely slow for a big number of classes. Any ideas why this is and how to circumvent it?

maffos avatar maffos commented on May 27, 2024

Hey @omoindrot, I have tried the solution using tf.data.experimental.choose_from_datasets. However, my process gets killed when I try to train my network. I think it might be because the list with all the datasets exceeds my working memory. I have ~4000 classes with ~200 instances per class in my dataset. Do you maybe know of any other way?

batrlatom avatar batrlatom commented on May 27, 2024

Will this work also on batch_hard? Or do you have any suggestion how to make batch_hard work with thousands of classes?

omoindrot avatar omoindrot commented on May 27, 2024

@batrlatom : I would say yes, this is only the data pipeline so it should work the same for batch_hard and batch_all.

christk1 avatar christk1 commented on May 27, 2024

@omoindrot if I have a list of labels, e.g. [1, 1, 1, 4, 5, 3, 2, 2], and then use choose_from_datasets like in your example, will it select random images from each label?

omoindrot avatar omoindrot commented on May 27, 2024

It will, if the datasets are shuffled beforehand.

TengliEd avatar TengliEd commented on May 27, 2024

@omoindrot since choose_from_datasets can randomly select an element from each dataset in datasets, do we still need to shuffle each dataset beforehand?

omoindrot avatar omoindrot commented on May 27, 2024

choose_from_datasets will pick the first element of the datasets[idx] where idx is the next index returned by choice_dataset.

It's as if you had 10 piles of plates (one pile = one dataset).
Someone (choice_dataset) tells you which pile to take a plate from. But you will take the plate at the top by default, so if you want shuffled plates, you need to shuffle each dataset beforehand.

For instance:

datasets = [tf.data.Dataset.list_files("{}/*.jpg".format(image_dir)) for image_dir in image_dirs]
datasets = [dataset.shuffle(buffer_size) for dataset in datasets]
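
The pile-of-plates picture can be mimicked in plain Python. This is only an analogy of the behaviour described above, not how TensorFlow implements it, and choose_from_piles is a made-up helper.

```python
def choose_from_piles(piles, choices):
    """Yield the next item of piles[idx] for each idx in choices.

    Assumes every pile holds enough items for the choices requested.
    """
    iterators = [iter(pile) for pile in piles]
    for idx in choices:
        yield next(iterators[idx])


piles = [["a0", "a1", "a2"], ["b0", "b1", "b2"]]
choices = [0, 0, 1, 0, 1]
picked = list(choose_from_piles(piles, choices))
# picked == ["a0", "a1", "b0", "a2", "b1"]
```

Items come out in each pile's own order, so any randomness has to come from the piles themselves. As far as I know, tf.data.Dataset.list_files already shuffles the file names by default (shuffle=True), so piles built with it may look random even without the extra dataset.shuffle(buffer_size) call.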

TengliEd avatar TengliEd commented on May 27, 2024

@omoindrot My experiment shows otherwise: it did not take the plate at the top, but took a random one from each pile.

from tensorflow-triplet-loss.

TengliEd avatar TengliEd commented on May 27, 2024

@omoindrot Your triplet preparation code worked with num_labels=20000. However, with num_labels=40000, the error below occurred. Does this mean the method cannot make triplets for a large number of classes?
[image: screenshot of the error, not preserved]

sseveran avatar sseveran commented on May 27, 2024

@TengliEd I hit the same issue with ~7300 datasets. I have opened an issue to track this in tensorflow tensorflow/tensorflow#29753.

You can disable optimization for tf.data using options.experimental_optimization.apply_default_optimizations = False
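
In code, that setting is applied through tf.data.Options and with_options. A minimal sketch, assuming a TensorFlow version where experimental_optimization.apply_default_optimizations is available; the tiny range dataset only stands in for the real pipeline.

```python
import tensorflow as tf

# Placeholder pipeline; in practice this is the dataset built with choose_from_datasets
dataset = tf.data.Dataset.range(4)

# Turn off the default tf.data graph optimizations for this pipeline
options = tf.data.Options()
options.experimental_optimization.apply_default_optimizations = False
dataset = dataset.with_options(options)
```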

cyrusvahidi avatar cyrusvahidi commented on May 27, 2024

If you are working with images stored in jpg files for instance, you can apply tf.data.experimental.choose_from_datasets only on the filenames and labels (which should be very fast), and then load the images from these filenames.

This would be like:

num_labels = 4000
num_classes_per_batch = 4
num_images_per_class = 8

image_dirs = ["data/class_{:04d}".format(i) for i in range(num_labels)]

# Create the list of datasets creating filenames
datasets = [tf.data.Dataset.list_files("{}/*.jpg".format(image_dir)) for image_dir in image_dirs]

def generator():
    while True:
        # Sample the labels that will compose the batch
        labels = np.random.choice(range(num_labels),
                                  num_classes_per_batch,
                                  replace=False)
        for label in labels:
            for _ in range(num_images_per_class):
                yield label

choice_dataset = tf.data.Dataset.from_generator(generator, tf.int64)
dataset = tf.contrib.data.choose_from_datasets(datasets, choice_dataset)

# Now you read the image content
def load_image(filename):
    ...
    return image, label

dataset = dataset.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)

batch_size = num_classes_per_batch * num_images_per_class
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(None)

Would you place this code in train_input_fn? Would the selected batch be repeated throughout the epoch this way?

Interestingly, if I select one batch with this code and repeat it for the epoch my loss converges below the margin. However, if I generate balanced batches, using the whole dataset for an epoch, the loss converges at the margin.

omoindrot avatar omoindrot commented on May 27, 2024

Would you place this code in train_input_fn? Would the selected batch be repeated throughout the epoch this way?

Yes. The batches generated are random so there is no "repeat" needed.

Interestingly, if I select one batch with this code and repeat it for the epoch my loss converges below the margin. However, if I generate balanced batches, using the whole dataset for an epoch, the loss converges at the margin.

This is because overfitting on one batch is easy, and the loss will converge to 0.
When working on the full dataset, you may have other convergence issues that could be solved by lowering the learning rate, changing other hyperparameters or pretraining the network first on a softmax loss.

connorlbark avatar connorlbark commented on May 27, 2024

So, I am currently using a method of creating balanced batches via creating a tensorflow dataset from a generator of 20 examples from a single class, shuffling with those examples, unbatching, then batching again with a batch size of 64. It's been an effective (and simple) way to create a balanced dataset (with the rare edge case that it doesn't overlap favorably), but I still have been unable to train effectively. It will always converge to the embeddings being zero.

I have tried changing my embedding size to equal the number of classes in my dataset and pre-train the model on softmax, which I can get to ~85% accuracy easily. This will reduce the starting triplet loss significantly, but it will still ultimately fail to converge.

I've tried many different hyperparameters, including extremely small learning rates (1e-7), but that will just make it collapse slower. Perhaps I should try even lower?

Any ideas? I'm at a loss (no pun intended)

omoindrot avatar omoindrot commented on May 27, 2024

@porgull : you can try to overfit a very small dataset (one triplet to begin with, then a bit more) and make sure that the loss converges to 0 on the training set.
This should help catch some bugs.

Otherwise I would check the data and make sure that the generated triplets look correct.
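
One check that fits here, as a sketch: compute the triplet loss by hand for a single triplet. It assumes the usual form max(d(a, p) - d(a, n) + margin, 0) with squared Euclidean distance, and the function names and numbers are illustrative. If all embeddings collapse to the same point, both distances are 0 and the loss equals the margin, which is consistent with the "loss converges at the margin" symptom reported in this thread.

```python
def squared_distance(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def triplet_loss(anchor, positive, negative, margin=0.5):
    """Hinge-style triplet loss for one (anchor, positive, negative) triplet."""
    d_ap = squared_distance(anchor, positive)
    d_an = squared_distance(anchor, negative)
    return max(d_ap - d_an + margin, 0.0)


# Well-separated triplet: the negative is much farther away than the positive
separated = triplet_loss([0.0, 0.0], [0.0, 0.1], [1.0, 1.0])  # 0.0

# Collapsed embeddings: every vector is identical, so the loss is the margin
collapsed = triplet_loss([0.0, 0.0], [0.0, 0.0], [0.0, 0.0])  # 0.5
```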

saravanabalagi avatar saravanabalagi commented on May 27, 2024

Yes. The batches generated are random so there is no "repeat" needed.

The batches are generated at random and the generator will keep giving random labels infinitely. However the individual datasets (dataset created per class) won't; they yield images sequentially and since there's no repeat for them, they will eventually stop after the last image is yielded. This will generate data exactly and only for one epoch, until all images in each of these datasets are consumed. It won't run any further. Playground here

You will have to do this for it to repeat forever. But then there is no guarantee that a particular image appears only once per epoch.

# Create the list of datasets creating filenames
datasets = [tf.data.Dataset.list_files("{}/*.jpg".format(image_dir)).repeat() for ....]
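
In plain Python terms, .repeat() plays the role of itertools.cycle on each per-class list. A small analogy (not TensorFlow code, and the file names are made up) showing that the stream no longer runs out but that items recur:

```python
from itertools import cycle, islice

files_class_0 = ["a0.jpg", "a1.jpg"]

once = iter(files_class_0)      # like a dataset without repeat()
forever = cycle(files_class_0)  # like dataset.repeat()

# Ask for 5 items from each stream
taken_once = list(islice(once, 5))        # stops after the 2 available items
taken_forever = list(islice(forever, 5))  # keeps going, reusing the same files
```

In tf.data, placing shuffle(buffer_size) before repeat() should give a different order on each pass, since shuffle reshuffles on every iteration by default.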

omoindrot avatar omoindrot commented on May 27, 2024

@saravanabalagi : good point, the original datasets need to yield samples infinitely.

It's then up to you how you want to control the amount of data coming from each dataset, and whether you want to oversample some datasets.

kasri-mids avatar kasri-mids commented on May 27, 2024

So, I am currently using a method of creating balanced batches via creating a tensorflow dataset from a generator of 20 examples from a single class, shuffling with those examples, unbatching, then batching again with a batch size of 64. It's been an effective (and simple) way to create a balanced dataset (with the rare edge case that it doesn't overlap favorably), but I still have been unable to train effectively. It will always converge to the embeddings being zero.

I have tried changing my embedding size to equal the number of classes in my dataset and pre-train the model on softmax, which I can get to ~85% accuracy easily. This will reduce the starting triplet loss significantly, but it will still ultimately fail to converge.

I've tried many different hyperparameters, including extremely small learning rates (1e-7), but that will just make it collapse slower. Perhaps I should try even lower?

Any ideas? I'm at a loss (no pun intended)

@porgull Did you get this resolved? I am facing the same issue...Thanks!

majdirabia avatar majdirabia commented on May 27, 2024

Hi,
Has anyone tried this and run into an issue when loading a file from its filename?
I have .npy files and get this error:

TypeError: expected str, bytes or os.PathLike object, not Tensor

I'm quite lost here, as I tried to solve it by creating a wrapper around tf.py_func.

Code :

    def get_data_from_filename(filename):
        npdata = np.load(filename)
        return npdata, int(filename.split('_')[1])

    def get_data_wrapper(filename):
        features, labels_in = tf.py_function(
            get_data_from_filename, [filename], (tf.float32, tf.int32))
        return tf.data.Dataset.from_tensor_slices((features, labels_in))

majdirabia avatar majdirabia commented on May 27, 2024

Hi,

Could anyone help me? I'm still stuck with file loading, as it expects path strings rather than Tensors.
If someone can show me how they implemented their load_image() function, it would give me a better idea of how to adapt it to my use case of .npy files.

Cheers,
Majdi
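
A possible fix, offered as a sketch rather than a tested answer for this exact setup: inside tf.py_function the argument arrives as an eager tensor, so it has to be turned into a Python string with .numpy().decode() before np.load can use it. The sketch assumes TensorFlow 2.x and file names like sample_3_0.npy, where the second underscore-separated field is the label (taken from the split('_')[1] in the code above). The function names are illustrative.

```python
import os

import numpy as np
import tensorflow as tf


def _load_npy(filename):
    """Runs eagerly inside tf.py_function; filename is a scalar string tensor."""
    path = filename.numpy().decode("utf-8")  # tensor -> bytes -> str
    features = np.load(path).astype(np.float32)
    # Assumption: the label is the second "_"-separated field of the base name
    label = int(os.path.basename(path).split("_")[1])
    return features, np.int32(label)


def load_npy(filename):
    """Map function: file name tensor -> (float32 features, int32 label)."""
    features, label = tf.py_function(
        _load_npy, [filename], (tf.float32, tf.int32))
    return features, label


# Hypothetical usage on a dataset of file names:
# dataset = dataset.map(load_npy)
```

Tensors coming out of tf.py_function have unknown static shapes, so a set_shape call on features may be needed before batching. Unlike the wrapper above, this version returns tensors for use with map instead of building a new Dataset per file.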
