
gqnlib (Work in progress)

Generative Query Network (GQN) implemented in PyTorch.

Requirements

  • Python == 3.7
  • PyTorch == 1.5.0

Requirements for example code

  • torchvision == 0.6.0
  • tqdm == 4.46.0
  • tensorflow == 2.2.0
  • tensorboardX == 2.0
  • matplotlib == 3.2.1
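The requirements are pinned to exact versions, so checking what is actually installed can save debugging time. The sketch below is not part of gqnlib. It compares installed distributions against the pins above using `importlib.metadata` (Python 3.8+, or the `importlib-metadata` backport on Python 3.7). It also ignores local build tags such as `+cu101` on PyTorch wheels.

```python
"""Compare installed package versions against the pins in this README (a sketch)."""
try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:  # Python 3.7: requires `pip install importlib-metadata`
    from importlib_metadata import PackageNotFoundError, version

# Pins copied from the two Requirements lists (PyPI distribution names).
PINS = {
    "torch": "1.5.0",
    "torchvision": "0.6.0",
    "tqdm": "4.46.0",
    "tensorflow": "2.2.0",
    "tensorboardX": "2.0",
    "matplotlib": "3.2.1",
}


def find_mismatches(pins, get_version=version):
    """Return {name: (expected, installed_or_None)} for every unmet pin."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = get_version(name)
        except PackageNotFoundError:
            installed = None
        # Drop a local build tag such as "+cu101" before comparing.
        if installed is None or installed.split("+", 1)[0] != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for name, (expected, installed) in find_mismatches(PINS).items():
        print(f"{name}: expected {expected}, found {installed or 'not installed'}")
```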

How to use

Set up the environment

Clone the repository.

git clone https://github.com/rnagumo/gqnlib.git
cd gqnlib

Install the package in a virtual environment.

python3 -m venv .venv
source .venv/bin/activate
pip3 install --upgrade pip
pip3 install .

Alternatively, use Docker with the NVIDIA Container Toolkit. Running containers with GPUs (the --gpus flag) requires Docker 19.03 or later.

docker build -t gqnlib .
docker run --gpus all -it gqnlib bash

Install the additional requirements for the example code.

pip3 install tqdm==4.46.0 tensorflow==2.2.0 tensorboardX==2.0 matplotlib==3.2.1 torchvision==0.6.0

Prepare dataset

The datasets are provided by DeepMind as the GQN datasets and the SLIM dataset.

The following command downloads the specified dataset and converts the TFRecord files into gzipped PyTorch files. The script uses the gsutil command, which must be installed beforehand (see the gsutil installation instructions).

Caution: This process takes a very long time. For example, converting shepard_metzler_5_parts, the smallest dataset, takes 2 to 3 hours on my PC with 32 GB of memory.

Caution: This process creates very large files. For example, the original shepard_metzler_5_parts dataset contains 900 files (17 GB) for train and 100 files (5 GB) for test, and the converted dataset contains 2,100 files (47 GB) for train and 400 files (12 GB) for test.
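Before starting a multi-hour download, it may be worth checking the two prerequisites mentioned above: gsutil on the PATH and enough free disk space. This is a hypothetical helper, not something shipped with gqnlib. The sizes are the shepard_metzler_5_parts figures quoted above. Whether bin/download_scene.sh deletes the original TFRecords after conversion is not stated, so by default the sketch budgets for both (17 + 5 + 47 + 12 = 81 GB).

```python
"""Hypothetical pre-flight check before running bin/download_scene.sh."""
import shutil

# shepard_metzler_5_parts sizes quoted in this README, in GB.
ORIGINAL_GB = {"train": 17, "test": 5}
CONVERTED_GB = {"train": 47, "test": 12}


def required_gb(keep_original=True):
    """Space for the converted files, plus the originals if they are kept."""
    total = sum(CONVERTED_GB.values())
    if keep_original:
        total += sum(ORIGINAL_GB.values())
    return total


def preflight(path=".", keep_original=True):
    """Return a list of problems; an empty list means both checks passed."""
    problems = []
    if shutil.which("gsutil") is None:
        problems.append("gsutil not found on PATH")
    # bytes -> GiB; reading the README's GB as GiB errs on the conservative side
    free_gb = shutil.disk_usage(path).free / 1024**3
    need = required_gb(keep_original)
    if free_gb < need:
        problems.append(f"only {free_gb:.0f} GB free, about {need} GB needed")
    return problems


if __name__ == "__main__":
    for problem in preflight():
        print("WARNING:", problem)
```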

bash bin/download_scene.sh shepard_metzler_5_parts

Run experiment

Run training. bin/train.sh contains the necessary settings. Training takes a very long time (10 to 30 hours).

bash bin/train.sh

Example

Training

import pathlib
import torch
import gqnlib

# Prepare dataset and model
root = "./data/shepard_metzler_5_parts_torch/train/"
dataset = gqnlib.SceneDataset(root, 20)
model = gqnlib.GenerativeQueryNetwork()
optimizer = torch.optim.Adam(model.parameters())

model.train()
for batch in dataset:
    for data in batch:
        # Partition each scene into context views and query views
        data = gqnlib.partition_scene(*data)

        # Forward pass; the model returns a dict of losses
        optimizer.zero_grad()
        loss_dict = model(*data)

        # Backward pass and parameter update
        loss = loss_dict["loss"].mean()
        loss.backward()
        optimizer.step()

# Save a checkpoint (parents=True also creates ./logs if it does not exist)
p = pathlib.Path("./logs/tmp")
p.mkdir(parents=True, exist_ok=True)

cp = {"model_state_dict": model.state_dict(),
      "optimizer_state_dict": optimizer.state_dict()}
torch.save(cp, p / "example.pt")
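The nested loop in the training example suggests that each dataset item is a list of minibatches and that partition_scene splits every scene into context and query views. The stdlib-only sketch below illustrates that layout under stated assumptions: the second argument of SceneDataset (20) is taken to be the minibatch size, consistent with the leading 20 in the output shapes in the next example, and the context size is drawn uniformly once per minibatch. The names `to_minibatches` and `partition_batch` are hypothetical; this is not gqnlib's implementation.

```python
"""Hypothetical illustration of minibatching and context/query partitioning."""
import random


def to_minibatches(scenes, batch_size):
    """Split one file's list of scenes into consecutive minibatches."""
    return [scenes[i:i + batch_size] for i in range(0, len(scenes), batch_size)]


def partition_batch(batch, rng, num_query=1):
    """Split every scene in a minibatch into context views and query views.

    Each scene is a list of (image, viewpoint) pairs. One context size is
    drawn per minibatch so that all scenes keep the same shape, as stacking
    them into tensors would require.
    """
    num_views = len(batch[0])
    num_context = rng.randint(1, num_views - num_query)  # inclusive bounds
    contexts, queries = [], []
    for scene in batch:
        # Distinct random views per scene: the first num_context become context
        order = rng.sample(range(num_views), num_context + num_query)
        contexts.append([scene[i] for i in order[:num_context]])
        queries.append([scene[i] for i in order[num_context:]])
    return contexts, queries


# Toy data: 45 scenes with 15 views each; strings stand in for image tensors.
scenes = [[(f"img{s}-{v}", f"vp{s}-{v}") for v in range(15)] for s in range(45)]
rng = random.Random(0)
for batch in to_minibatches(scenes, 20):
    contexts, queries = partition_batch(batch, rng)
    print(len(batch), len(contexts[0]), len(queries[0]))
```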

Use pre-trained model

import torch
import gqnlib

# Load the pre-trained model (map_location lets a GPU checkpoint load on CPU)
model = gqnlib.GenerativeQueryNetwork()
cp = torch.load("./logs/tmp/example.pt", map_location="cpu")
model.load_state_dict(cp["model_state_dict"])
model.eval()

# Data: take one (images, viewpoints) pair, indexed as in the training loop
root = "./data/shepard_metzler_5_parts_torch/train/"
dataset = gqnlib.SceneDataset(root, 20)
images, viewpoints = dataset[0][0]
x_c, v_c, x_q, v_q = gqnlib.partition_scene(images, viewpoints)

# Reconstruct the query images and sample new ones (no gradients needed)
with torch.no_grad():
    recon = model.reconstruct(x_c, v_c, x_q, v_q)
    sample = model.sample(x_c, v_c, v_q)

print(recon.size())  # -> torch.Size([20, 1, 3, 64, 64])
print(sample.size())  # -> torch.Size([20, 1, 3, 64, 64])

Reference

Original papers

Datasets

  • Datasets by DeepMind for GQN (GitHub)
  • Dataset by DeepMind for SLIM (GitHub)

Codes

  • mushoku, chainer-gqn (GitHub)
  • iShohei220, torch-gqn (GitHub)
  • wohlert, generative-query-network-pytorch (GitHub)
  • l3robot, gqn_datasets_translator (GitHub)

Contributors

  • rnagumo
