
jax-dataloader's Introduction

Dataloader for JAX


Overview

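jax_dataloader brings a pytorch-like dataloader API to jax. It supports four kinds of datasets (jdl.Dataset, pytorch Dataset, huggingface datasets, and tensorflow datasets) and three backends for loading batches ("jax", "pytorch", and "tensorflow").

A minimal jax-dataloader example: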

import jax_dataloader as jdl

dataloader = jdl.DataLoader(
    dataset, # Can be a jdl.Dataset or pytorch or huggingface or tensorflow dataset
    backend='jax', # Use 'jax' backend for loading data
    batch_size=32, # Batch size 
    shuffle=True, # Shuffle the dataloader every iteration or not
    drop_last=False, # Drop the last batch or not
)

batch = next(iter(dataloader)) # iterate next batch
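To make the `batch_size`, `shuffle`, and `drop_last` arguments above concrete, here is a minimal pure-Python sketch of how such options typically split a dataset into batches. The helper `batch_indices` and its `seed` parameter are hypothetical names used for illustration only; this is not how jax_dataloader is implemented internally.

```python
import random

def batch_indices(n, batch_size, shuffle=False, drop_last=False, seed=0):
    """Return a list of index lists, one per batch (illustrative helper only)."""
    indices = list(range(n))
    if shuffle:
        # Reorder the samples before splitting them into batches
        random.Random(seed).shuffle(indices)
    batches = [indices[i:i + batch_size] for i in range(0, n, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        # Discard the final, incomplete batch
        batches.pop()
    return batches

print(len(batch_indices(100, 32)))                  # 4 batches: 32, 32, 32, 4
print(len(batch_indices(100, 32, drop_last=True)))  # 3 batches of 32
```

With 100 samples and a batch size of 32, keeping the last batch yields a final batch of only 4 samples, which is why `drop_last` matters when downstream code expects a fixed batch shape.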

Installation

The latest jax-dataloader release can directly be installed from PyPI:

pip install jax-dataloader

or install directly from the repository:

pip install git+https://github.com/BirkhoffG/jax-dataloader.git

Note

We keep jax-dataloader's dependencies to a minimum: installing it only pulls in jax and plum-dispatch (for backend dispatching). If you wish to use the pytorch, huggingface datasets, or tensorflow integrations, we highly recommend installing those dependencies manually.

You can also run pip install jax-dataloader[all] to install everything (not recommended).
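Because the integrations are optional, it can help to check which of them are importable before picking a backend. The following is a small sketch using only the standard library; the helper name `available_integrations` is hypothetical, and the module names are assumed to be the usual import names (`torch`, `datasets`, `tensorflow`).

```python
from importlib.util import find_spec

# Import names of the optional integrations mentioned above (assumed)
OPTIONAL_MODULES = ("torch", "datasets", "tensorflow")

def available_integrations(modules=OPTIONAL_MODULES):
    """Map each module name to True if it can be found in this environment."""
    return {name: find_spec(name) is not None for name in modules}

print(available_integrations())
```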

Usage

jax_dataloader.core.DataLoader follows an API similar to the pytorch dataloader.

  • The dataset should be an object of the subclass of jax_dataloader.core.Dataset or torch.utils.data.Dataset or (the huggingface) datasets.Dataset or tf.data.Dataset.
  • The backend should be one of "jax" or "pytorch" or "tensorflow". This argument specifies which backend dataloader to load batches.

Note that not every dataset is compatible with every backend. See the compatibility table below:

                 jdl.Dataset   torch_data.Dataset   tf.data.Dataset   datasets.Dataset
"jax"            yes           no                   no                yes
"pytorch"        yes           yes                  no                yes
"tensorflow"     yes           no                   yes               yes

The jax_dataloader.core.ArrayDataset is an easy way to wrap multiple jax.numpy arrays into one Dataset. For example, we can create an ArrayDataset as follows:

import jax.numpy as jnp

# Create features `X` and labels `y`
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)

This arr_ds can be loaded by every backend.

# Create a `DataLoader` from the `ArrayDataset` via jax backend
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(arr_ds, 'tensorflow', batch_size=5, shuffle=True)
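To illustrate the idea behind wrapping several arrays into one dataset, here is a minimal sketch of an array-wrapping class. `TinyArrayDataset` is a hypothetical stand-in written with plain Python lists; it is not the library's ArrayDataset and only shows the indexing contract one would expect (each index returns the matching element of every wrapped array).

```python
class TinyArrayDataset:
    """Hypothetical stand-in: wraps several equally long sequences."""

    def __init__(self, *arrays):
        if not arrays:
            raise ValueError("at least one array is required")
        if len({len(a) for a in arrays}) != 1:
            raise ValueError("all arrays must have the same length")
        self.arrays = arrays

    def __len__(self):
        return len(self.arrays[0])

    def __getitem__(self, index):
        # Indexing returns the matching element of every wrapped array
        return tuple(a[index] for a in self.arrays)

# Same values as jnp.arange(100).reshape(10, 10) and jnp.arange(10)
X = [[i * 10 + j for j in range(10)] for i in range(10)]
y = list(range(10))
ds = TinyArrayDataset(X, y)
print(len(ds))   # 10
print(ds[3][1])  # 3
```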

Using Huggingface Datasets

The huggingface datasets library is a modern library for downloading, pre-processing, and sharing datasets. jax_dataloader supports passing huggingface datasets directly.

from datasets import load_dataset

For example, we load the "squad" dataset from datasets:

hf_ds = load_dataset("squad")

Then, we can use jax_dataloader to load batches of hf_ds.

# Create a `DataLoader` from the `datasets.Dataset` via jax backend
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(hf_ds['train'], 'tensorflow', batch_size=5, shuffle=True)

Using Pytorch Datasets

The pytorch Dataset and its ecosystem (e.g., torchvision, torchtext, torchaudio) support many built-in datasets. jax_dataloader supports passing a pytorch Dataset directly.

Note

Unfortunately, the pytorch Dataset only works with backend="pytorch". See the example below.

from torchvision.datasets import MNIST
import numpy as np

We load the MNIST dataset from torchvision. The transform argument converts each image to a numpy.array.

pt_ds = MNIST('/tmp/mnist/', download=True, transform=lambda x: np.array(x, dtype=float), train=False)

This pt_ds can only be loaded via "pytorch" dataloaders.

dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)

Using Tensorflow Datasets

jax_dataloader supports directly passing the tensorflow datasets.

import tensorflow_datasets as tfds
import tensorflow as tf

For instance, we can load the MNIST dataset from tensorflow_datasets:

tf_ds = tfds.load('mnist', split='test', as_supervised=True)

and use jax_dataloader for iterating the dataset.

dataloader = jdl.DataLoader(tf_ds, 'tensorflow', batch_size=5, shuffle=True)


jax-dataloader's Issues

Error if torch is not installed

if isinstance(dataset, torch_data.Dataset) and backend != "pytorch":

The line above fails with an AttributeError if torch isn't installed. I am using the huggingface dataloader.

  File "/root/.virtualenvs/mnist/lib/python3.8/site-packages/jax_dataloader/core.py", line 367, in _dispatch_dataset_and_backend
    if isinstance(dataset, torch_data.Dataset) and backend != "pytorch":
AttributeError: 'NoneType' object has no attribute 'Dataset'

JaxDataLoader is not a valid jax type

Example

import jax
import jax_dataloader.core as jdl
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST

train_ds = MNIST('/tmp/mnist/', download=True, transform=ToTensor())
X, Y = train_ds.data.numpy(), train_ds.targets.numpy()
ds = jdl.Dataset(X, Y)
train_loader = jdl.DataLoaderJax(ds, batch_size=5, shuffle=False, drop_last=False)

@jax.jit
def foo(train_loader):
    train_iter = iter(train_loader)
    Xtr, Ytr = next(train_iter)
    print(Xtr.shape, Ytr.shape, Xtr[0].shape, Ytr[0].shape) # (5, 28, 28) (5,) (28, 28) ()

foo(train_loader) # fails

produces

Argument '<jax_dataloader.core.DataLoaderJax object at 0x7f891a1fcb20>' of type <class 'jax_dataloader.core.DataLoaderJax'> is not a valid JAX type.

Slow Loading speed

Hey, just want to say that I have been a huge fan of your work. I use jax_dataloader quite often in my workflow, but I have noticed that the dataloader slows down my training loop by 5 seconds every 10 epochs (compared to tensorflow's numpy iterator). I am not using prefetch on the numpy iterator, so do let me know if you have anticipated this before.

P.S. My training loop is quite ordinary (but my models are rn being developed by an internal library, so I can't disclose the code), so I know that the dataloading speed is the problem.
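One way to isolate data loading cost from the rest of a training loop is to time a full pass over the loader with no model code involved. The sketch below does this for any iterable; `time_iteration` and `fake_loader` are hypothetical names, and the generator merely stands in for a real dataloader.

```python
import time

def time_iteration(iterable):
    """Return (number of items, seconds spent) for one full pass over `iterable`."""
    start = time.perf_counter()
    count = 0
    for _ in iterable:
        count += 1
    return count, time.perf_counter() - start

# Stand-in for a dataloader: 100 batches of 32 integers
fake_loader = (list(range(32)) for _ in range(100))
n_batches, seconds = time_iteration(fake_loader)
print(n_batches, f"{seconds:.6f}s")
```

Running the same measurement on two loaders over the same data gives a like-for-like comparison that does not depend on the model.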

Typing issues when importing library

Hey! I was trying out this package in a new, minimal python environment and got a typing-related error when importing the package.

Environment:

  • Python 3.11.8
    packages:
beartype==0.18.0
jax==0.4.25
jax-dataloader==0.1.0
jaxlib==0.4.25
markdown-it-py==3.0.0
mdurl==0.1.2
ml-dtypes==0.4.0
numpy==1.26.4
opt-einsum==3.3.0
plum-dispatch==2.3.3
Pygments==2.17.2
rich==13.7.1
scipy==1.13.

Steps to reproduce:

import jax_dataloader

Stack trace:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/jax_dataloader/__init__.py", line 4, in <module>
    from .core import *
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/jax_dataloader/core.py", line 8, in <module>
    from .loaders import *
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/jax_dataloader/loaders/__init__.py", line 2, in <module>
    from .jax import *
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/jax_dataloader/loaders/jax.py", line 88, in <module>
    class DataLoaderJAX(BaseDataLoader):
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/jax_dataloader/loaders/jax.py", line 90, in DataLoaderJAX
    @typecheck
     ^^^^^^^^^
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/beartype/_decor/decorcache.py", line 77, in beartype
    return beartype_object(obj, conf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/beartype/_decor/decorcore.py", line 87, in beartype_object
    _beartype_object_fatal(obj, conf=conf, **kwargs)
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/beartype/_decor/decorcore.py", line 136, in _beartype_object_fatal
    beartype_nontype(obj, **kwargs)  # type: ignore[return-value]
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/beartype/_decor/_decornontype.py", line 174, in beartype_nontype
    return beartype_func(obj, **kwargs)  # type: ignore[return-value]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/beartype/_decor/_decornontype.py", line 239, in beartype_func
    func_wrapper_code = generate_code(bear_call)
                        ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/beartype/_decor/wrap/wrapmain.py", line 118, in generate_code
    code_check_params = _code_check_args(bear_call)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/beartype/_decor/wrap/_wrapargs.py", line 309, in code_check_args
    reraise_exception_placeholder(
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/beartype/_util/error/utilerrraise.py", line 138, in reraise_exception_placeholder
    raise exception.with_traceback(exception.__traceback__)
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/beartype/_decor/wrap/_wrapargs.py", line 262, in code_check_args
    ) = make_code_raiser_func_pith_check(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/beartype/_util/cache/utilcachecall.py", line 250, in _callable_cached
    raise exception
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/beartype/_util/cache/utilcachecall.py", line 242, in _callable_cached
    return_value = args_flat_to_return_value[args_flat] = func(
                                                          ^^^^^
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/beartype/_check/checkmake.py", line 311, in make_code_raiser_func_pith_check
    ) = make_check_expr(hint, conf, cls_stack)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/beartype/_util/cache/utilcachecall.py", line 250, in _callable_cached
    raise exception
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/beartype/_util/cache/utilcachecall.py", line 242, in _callable_cached
    return_value = args_flat_to_return_value[args_flat] = func(
                                                          ^^^^^
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/beartype/_check/code/codemake.py", line 1752, in make_check_expr
    hint_child_placeholder=_enqueue_hint_child(
                           ^^^^^^^^^^^^^^^^^^^^
  File "/home/elias/.virtualenvs/jax/lib/python3.11/site-packages/beartype/_check/code/codemake.py", line 526, in _enqueue_hint_child
    pith_child_expr is pith_curr_assign_expr
AssertionError: Method jax_dataloader.loaders.jax.DataLoaderJAX.__init__() parameter "dataset" type hint typing.Annotated[NoneType, Is[lambda _: hf_datasets is not None]] child pith expression '__beartype_pith_0' duplicates current pith assignment expression '__beartype_pith_0'.

Custom Jax dataset does not work

Hi,

I'm trying to create a base Dataset (in the pytorch style, but from the default jax_dataloader) but the dataloader crashes if torch is not installed.
This is due to the function _dispatch_dataset_and_backend that has the following condition:
if isinstance(dataset, torch_data.Dataset)... but as torch_data = None, this crashes.
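The failure described here follows from calling isinstance against an attribute of None. A minimal sketch of a guarded check is shown below; it assumes the optional module is bound to None when the import fails, as the issue describes. The helpers `is_torch_dataset` and `check_backend` are hypothetical and are not the library's actual fix.

```python
try:
    import torch.utils.data as torch_data
except ImportError:  # torch is an optional dependency
    torch_data = None

def is_torch_dataset(dataset):
    """Return True only when torch is installed and `dataset` is a torch Dataset."""
    return torch_data is not None and isinstance(dataset, torch_data.Dataset)

def check_backend(dataset, backend):
    # Hypothetical helper mirroring the condition quoted in the issue
    if is_torch_dataset(dataset) and backend != "pytorch":
        raise ValueError("A pytorch Dataset requires backend='pytorch'.")
    return backend

print(check_backend([1, 2, 3], "jax"))  # a plain list is never a torch Dataset
```

Because the `torch_data is not None` test short-circuits, the isinstance call is never reached when torch is missing.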

does not seem to convert pytorch datasets correctly

Consider this example

import jax_dataloader.core as jdl
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
train_ds = MNIST('/tmp/mnist/', download=True, transform=ToTensor())

train_loader = jdl.DataLoader(train_ds, 'pytorch', batch_size=5, shuffle=False, drop_last=False)
train_iter = iter(train_loader)
nbatches = len(train_iter)
for b in range(1): #nbatches):
    Xtr, Ytr = next(train_iter)
    print(b, Xtr.shape, Ytr.shape, type(Xtr), type(Ytr))
    print(type(Xtr[0]), Xtr[0].shape, Ytr[0].shape)

This returns

0 (5,) (5,) <class 'numpy.ndarray'> <class 'numpy.ndarray'>
<class 'torch.Tensor'> torch.Size([1, 28, 28]) ()

so the elements within the batch are still pytorch tensors, not numpy arrays.
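The (5,) batch shape suggests the batch is an array of per-sample objects rather than one stacked numeric array. A possible workaround, in the spirit of the numpy transform used in the README's MNIST example, is to convert each sample to numpy before stacking. The sketch below uses numpy stand-in samples; `numpy_collate` is a hypothetical helper, and applying it to CPU torch tensors assumes that np.asarray can convert them.

```python
import numpy as np

def numpy_collate(samples):
    """Stack a list of (x, y) samples into two numpy arrays."""
    xs, ys = zip(*samples)
    # Convert each element first, so the stacked result is a regular numeric
    # array rather than an object array of per-sample containers
    X = np.stack([np.asarray(x) for x in xs])
    Y = np.asarray(ys)
    return X, Y

# Stand-in samples shaped like MNIST items after ToTensor(): (1, 28, 28) image, int label
samples = [(np.zeros((1, 28, 28), dtype=np.float32), i) for i in range(5)]
X, Y = numpy_collate(samples)
print(X.shape, Y.shape)  # (5, 1, 28, 28) (5,)
```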
