Barbar💈

Progress bar for deep learning training iterations.

Quick glance

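from barbar import Bar
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# root, epochs, MLP, and train_step are placeholders you define yourself.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ToTensor() is needed so the DataLoader can collate images into batches.
mnist_train = datasets.MNIST(root=root,
                             download=True,
                             train=True,
                             transform=transforms.ToTensor())
train_dataloader = DataLoader(mnist_train,
                              batch_size=100,
                              shuffle=True)

model = MLP().to(device)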

for epoch in range(epochs):
    print('Epoch: {}'.format(epoch+1))

    # Wrap the DataLoader in Bar to print progress while iterating.
    for idx, (x, t) in enumerate(Bar(train_dataloader)):
        x, t = x.to(device), t.to(device)
        loss, preds = train_step(x, t)

This prints a progress bar for each epoch:

Epoch: 1
60000/60000: [===============================>] - ETA 0.0s
Epoch: 2
28100/60000: [==============>.................] - ETA 4.1s
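The line format in the sample output can be approximated in a few lines of plain Python. This is a minimal sketch and not Barbar's actual implementation: the names `render_bar` and `estimate_eta`, the 32-character width, and the mean-batch-time ETA are all assumptions chosen to resemble the output shown.

```python
def render_bar(done, total, eta_seconds, width=32):
    """Return one progress line in the 'done/total: [===>...] - ETA x.xs' style.

    Hypothetical helper; `total` must be a positive number of samples.
    """
    fraction = min(done / total, 1.0)
    # Reserve one character for the '>' head so the bar never exceeds `width`.
    filled = min(int(width * fraction), width - 1)
    bar = "=" * filled + ">" + "." * (width - filled - 1)
    return f"{done}/{total}: [{bar}] - ETA {eta_seconds:.1f}s"


def estimate_eta(batch_times, batches_left):
    """Estimate remaining seconds from the mean duration of finished batches."""
    if not batch_times:
        return 0.0
    return sum(batch_times) / len(batch_times) * batches_left


# 281 finished batches of 100 samples, 319 batches still to go.
print(render_bar(28100, 60000, estimate_eta([0.0128] * 281, 319)))
```

To redraw the bar in place rather than on a new line, a loop could print each line with `end='\r'` and emit a final newline when the epoch finishes.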

Barbar works best with the PyTorch DataLoader, but it also works with a custom data loader. A minimal example can be written as follows:

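import numpy as np
from sklearn import datasets  # scikit-learn's datasets, not torchvision's
from sklearn.utils import shuffle

class CustomDataLoader(object):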
    def __init__(self, dataset,
                 batch_size=100,
                 shuffle=False,
                 random_state=None):
        self.dataset = list(zip(dataset[0], dataset[1]))
        self.batch_size = batch_size
        self.shuffle = shuffle
        if random_state is None:
            random_state = np.random.RandomState(1234)
        self.random_state = random_state
        self._idx = 0
        self._reset()

    def __len__(self):
        N = len(self.dataset)
        b = self.batch_size
        return N // b + bool(N % b)

    def __iter__(self):
        return self

    def __next__(self):
        if self._idx >= len(self.dataset):
            self._reset()
            raise StopIteration()

        x, y = \
            zip(*self.dataset[self._idx:(self._idx + self.batch_size)])

        # x = torch.Tensor(x)
        # y = torch.LongTensor(y)

        self._idx += self.batch_size

        return x, y

    def _reset(self):
        if self.shuffle:
            self.dataset = shuffle(self.dataset,
                                   random_state=self.random_state)
        self._idx = 0

# as_frame=False returns NumPy arrays, so rows can be zipped into (x, y) pairs.
mnist = datasets.fetch_openml('mnist_784', version=1, as_frame=False)
x, y = mnist.data.astype(np.float32), mnist.target.astype(np.int32)
x = x / 255.
x_train = x[:60000]
y_train = y[:60000]

train_dataloader = CustomDataLoader((x_train, y_train),
                                    batch_size=100,
                                    shuffle=True)
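The same loader can be written more compactly by making `__iter__` a generator, which removes the manual `_idx` and `_reset` bookkeeping. The sketch below is an alternative under stated assumptions, not part of Barbar: `GeneratorDataLoader` is a hypothetical name, it uses the standard library's `random` instead of scikit-learn's `shuffle`, and it keeps the `dataset` and `batch_size` attributes and `__len__` because the loader above exposes them.

```python
import math
import random


class GeneratorDataLoader:
    """Batch iterator over (x, y) pairs; hypothetical name, not part of Barbar."""

    def __init__(self, dataset, batch_size=100, shuffle=False, seed=1234):
        self.dataset = list(zip(dataset[0], dataset[1]))
        self.batch_size = batch_size
        self.shuffle = shuffle
        self._rng = random.Random(seed)

    def __len__(self):
        # Number of batches, counting a final partial batch.
        return math.ceil(len(self.dataset) / self.batch_size)

    def __iter__(self):
        # Reshuffle at the start of every pass over the data.
        if self.shuffle:
            self._rng.shuffle(self.dataset)
        for start in range(0, len(self.dataset), self.batch_size):
            x, y = zip(*self.dataset[start:start + self.batch_size])
            yield x, y


# Toy data: 250 samples, so the last batch is a partial one.
loader = GeneratorDataLoader((list(range(250)), [i % 10 for i in range(250)]),
                             batch_size=100, shuffle=True)
for epoch in range(2):
    sizes = [len(x) for x, _ in loader]
    print(f"Epoch {epoch + 1}: {len(loader)} batches, sizes {sizes}")
```

Unlike the class above, this version shuffles when a pass begins rather than when it ends. Wrapping it as `Bar(loader)` should behave like the Quick glance example, assuming Bar relies only on the attributes and iteration protocol shown here.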

Installation

  • Install Barbar from PyPI (recommended):

pip install barbar

  • Alternatively, install Barbar from the GitHub source:

First, clone Barbar using git:

git clone https://github.com/yusugomori/barbar.git

Then, cd to the Barbar folder and run the install command:

cd barbar
sudo python setup.py install
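After either install route, a quick check from Python can confirm that the package is visible to the interpreter you intend to use. This is a small sketch using only the standard library (Python 3.8+); `installed_version` is a hypothetical helper name, and it assumes the distribution is registered under the name `barbar`, as in the pip command above.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


print(installed_version("barbar"))
```

A result of `None` usually means the install went into a different Python environment than the one running the check.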
