Tensorflow-input-pipeline

This is an input pipeline for TensorFlow. It uses the Dataset API (tf.data) and is designed for semantic segmentation datasets.

In my experience, getting your own data into TensorFlow for deep learning and machine learning problems is... well... a problem. This code aims to simplify that step and get your deep learning projects up and running. It is simple and readable, so you can easily edit and extend it for your own projects.

Augmentation Examples:

The following shows the same image loaded several times through the pipeline. Note the different augmentations (brightness, contrast, saturation, cropping and flipping); the masks receive the matching crop and flip. The example image is taken from the PASCAL VOC dataset.

(Screenshot from 2018-10-16: the same image and its mask under several different augmentations.)
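The key rule behind mask-aware augmentation is that geometric changes (cropping, flipping) must be applied identically to the image and its mask, while photometric changes (brightness, contrast, saturation) touch only the image, since they would corrupt the class colours in the mask. The sketch below illustrates that rule in NumPy. It is not the DataLoader's actual code: augment_pair is a hypothetical helper, the jitter ranges are assumed values, and resizing the crop back to image_size is omitted.

```python
import numpy as np


def augment_pair(image, mask, crop_percent=0.8, rng=None):
    """Hypothetical helper: one random crop/flip shared by image and mask,
    plus photometric jitter applied to the image only.

    image: (H, W, 3) uint8 array; mask: (H, W) or (H, W, C) array.
    Returns a float32 image in [0, 1] and the transformed mask.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = image.shape[:2]

    # Random crop covering crop_percent of each side, same window for both.
    ch, cw = int(h * crop_percent), int(w * crop_percent)
    top = int(rng.integers(0, h - ch + 1))
    left = int(rng.integers(0, w - cw + 1))
    image = image[top:top + ch, left:left + cw]
    mask = mask[top:top + ch, left:left + cw]

    # Random horizontal flip: applied to both arrays or to neither.
    if rng.random() < 0.5:
        image = image[:, ::-1]
        mask = mask[:, ::-1]

    # Photometric jitter (assumed ranges): image only, so mask labels stay valid.
    img = image.astype(np.float32) / 255.0
    img = img + rng.uniform(-0.1, 0.1)                   # brightness
    mean = img.mean()
    img = (img - mean) * rng.uniform(0.8, 1.2) + mean    # contrast
    gray = img.mean(axis=-1, keepdims=True)              # rough luminance
    img = gray + (img - gray) * rng.uniform(0.8, 1.2)    # saturation
    return np.clip(img, 0.0, 1.0), mask
```

Drawing the crop offsets and the flip decision once and reusing them for both arrays is what keeps every mask pixel aligned with its image pixel.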

Example use:

import matplotlib.pyplot as plt
import tensorflow as tf
from dataloader import DataLoader
import numpy as np
import os

plt.ioff()

IMAGE_DIR_PATH = 'data/training/images'
MASK_DIR_PATH = 'data/training/masks'
BATCH_SIZE = 4  # example value; pick one that fits your GPU/CPU memory

# PASCAL VOC colour map: row i is the RGB colour used for class i in the masks
pascal_palette = np.array([
        [  0,   0,   0],
        [128,   0,   0],
        [  0, 128,   0],
        [128, 128,   0],
        [  0,   0, 128],
        [128,   0, 128],
        [  0, 128, 128],
        [128, 128, 128],
        [ 64,   0,   0],
        [192,   0,   0],
        [ 64, 128,   0],
        [192, 128,   0],
        [ 64,   0, 128],
        [192,   0, 128],
        [ 64, 128, 128],
        [192, 128, 128],
        [  0,  64,   0],
        [128,  64,   0],
        [  0, 192,   0],
        [128, 192,   0],
        [  0,  64, 128]], dtype=np.uint8)

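# Create sorted lists of paths. os.listdir() returns files in arbitrary order,
# so sorting keeps image_paths[i] and mask_paths[i] pointing at the same sample
# (provided image and mask file names share the same numbering).
image_paths = sorted(os.path.join(IMAGE_DIR_PATH, x) for x in os.listdir(IMAGE_DIR_PATH) if x.endswith('.png'))
mask_paths = sorted(os.path.join(MASK_DIR_PATH, x) for x in os.listdir(MASK_DIR_PATH) if x.endswith('.png'))

# e.g. image_paths[0] = 'data/training/images/image_0.png'
# and  mask_paths[0] = 'data/training/masks/mask_0.png'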

# Initialize the dataloader object
dataset = DataLoader(image_paths=image_paths,
                     mask_paths=mask_paths,
                     image_size=[256, 256],
                     crop_percent=0.8,
                     channels=[3, 3],
                     palette=pascal_palette,
                     seed=47)

# Parse the images and masks, and return the data in batches, augmented optionally.
data, init_op = dataset.data_batch(augment=True, 
                                   shuffle=True,
                                   one_hot_encode=True,
                                   batch_size=BATCH_SIZE,
                                   num_threads=4,
                                   buffer=60)


with tf.Session() as sess:  # TensorFlow 1.x graph-mode API
  # Initialize the dataset iterator
  sess.run(init_op)
  # Evaluate the tensors to get one augmented batch as NumPy arrays
  aug_image, aug_mask = sess.run(data)

  # From here you can build a feed dict and train your models, or feed
  # the tf tensors in `data` directly into your model to avoid a feed dict.
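The palette and the one_hot_encode=True option imply that each mask is a colour-coded RGB image that gets turned into one channel per class. How DataLoader does this internally is not shown here, so the following is only a minimal NumPy sketch of one way to perform that conversion; rgb_mask_to_one_hot is a hypothetical helper, not part of the dataloader module. Note that PASCAL VOC ground truth also draws object boundaries in a colour (224, 224, 192) that is not in the 21-row palette; this sketch simply gives such pixels an all-zero vector.

```python
import numpy as np


def rgb_mask_to_one_hot(mask_rgb, palette):
    """Hypothetical helper: convert an (H, W, 3) colour-coded mask into an
    (H, W, K) one-hot array, where K = len(palette) and channel k is 1 wherever
    the pixel colour equals palette[k]. Pixels whose colour is not in the
    palette get an all-zero vector."""
    # Broadcast (H, W, 1, 3) against (1, 1, K, 3), then require all 3 channels to match.
    matches = np.all(mask_rgb[:, :, None, :] == palette[None, None, :, :], axis=-1)
    return matches.astype(np.float32)
```

With the full pascal_palette, the result has 21 channels per pixel, one per PASCAL VOC class.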

Note

This code is meant as a guide for anyone stuck writing functions to load their own data into TensorFlow. Most ML problems follow the same skeleton as this example: load the images and labels (here, each label is itself an image), preprocess the loaded data, batch it, and return an iterator over the batches. This pipeline is built for semantic segmentation and can also apply augmentations to the images.
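The load, preprocess, batch, iterate skeleton described above can be sketched in plain Python and NumPy, independent of TensorFlow, to make each stage explicit. batch_iterator, load_fn and preprocess_fn are hypothetical names chosen for this illustration; in the actual pipeline the Dataset API performs these stages inside the TensorFlow graph, with parallel loading and a shuffle buffer rather than a full in-memory shuffle.

```python
import random

import numpy as np


def batch_iterator(pairs, load_fn, preprocess_fn, batch_size, shuffle=True, seed=None):
    """Hypothetical sketch of the load -> preprocess -> batch -> iterate skeleton.

    pairs: list of (image_path, mask_path) tuples.
    load_fn(image_path, mask_path) -> (image, mask) arrays.
    preprocess_fn(image, mask) -> (image, mask), e.g. resizing or augmentation.
    Yields (images, masks) stacked along a new leading batch axis.
    """
    order = list(range(len(pairs)))
    if shuffle:
        random.Random(seed).shuffle(order)

    images, masks = [], []
    for idx in order:
        image, mask = load_fn(*pairs[idx])          # 1. load one sample
        image, mask = preprocess_fn(image, mask)    # 2. preprocess / augment
        images.append(image)
        masks.append(mask)
        if len(images) == batch_size:               # 3. batch
            yield np.stack(images), np.stack(masks)  # 4. hand out one batch
            images, masks = [], []
    if images:  # final, smaller batch
        yield np.stack(images), np.stack(masks)
```

Because each image and its mask are loaded and preprocessed together, they stay paired through shuffling and batching.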

Contributing

Ideas for extending this are welcome. If you would like to contribute:

  1. Clone the repo.
  2. Create your own branch.
  3. Make your changes.
  4. Commit and make a pull request.

Contributors

hasnainraz