
Bayesian Flow Networks

This is the official code release for Bayesian Flow Networks by Alex Graves, Rupesh Kumar Srivastava, Timothy Atkinson and Faustino Gomez.

Figure: Overview of the BFN process

Reading Guide

  • model.py contains all the main contributions of the paper. These include definitions of Bayesian Flows for both continuous and discrete data, as well as loss functions for both continuous-time and discrete-time training. See the comments in the base classes in that file for details.
  • probability.py defines the probability distributions used by the models.
  • train.py, test.py and sample.py are scripts for training, testing and sampling (see below for usage).
  • data.py contains utilities related to data loading and processing.
  • networks/ contains implementations of the network architectures used by the models.
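To give a feel for what the Bayesian Flows in model.py are built on, the following stdlib-only sketch implements the two Bayesian update rules described in the paper. For continuous data, a Gaussian input distribution with mean mu and precision rho receives a noisy sample y of accuracy alpha. For discrete data, a categorical distribution theta receives a noisy sample y. The function names are hypothetical; this is not the code in model.py, which operates on batched tensors.

```python
import math


def update_continuous(mu: float, rho: float, y: float, alpha: float) -> tuple[float, float]:
    """Bayesian update of a Gaussian belief N(mu, 1/rho) after observing
    y ~ N(x, 1/alpha): precisions add, and the new mean is the
    precision-weighted average of the old mean and the observation."""
    new_rho = rho + alpha
    new_mu = (mu * rho + y * alpha) / new_rho
    return new_mu, new_rho


def update_discrete(theta: list[float], y: list[float]) -> list[float]:
    """Bayesian update of a categorical belief theta given a noisy sample y:
    the new theta_k is proportional to exp(y_k) * theta_k."""
    weights = [math.exp(y_k) * t_k for y_k, t_k in zip(y, theta)]
    total = sum(weights)
    return [w / total for w in weights]


# Prior N(0, 1), observation y = 2 with accuracy 1: mean moves halfway.
print(update_continuous(0.0, 1.0, 2.0, 1.0))
# Uniform two-class prior, evidence favouring class 0 by a factor of 2.
print(update_discrete([0.5, 0.5], [math.log(2), 0.0]))
```

Accuracy is additive in the continuous case, which is why the paper can describe a whole sequence of updates by a single accuracy schedule.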

Setup

# Create a new conda env with all dependencies including pytorch and CUDA
conda env create -f env.yml
conda activate bfn

# Or, install additional dependencies into an existing pytorch env
pip install accelerate==0.19.0 matplotlib omegaconf rich

# Optional, if you want to enable logging to neptune.ai
pip install neptune 

Training

The models in the paper can be trained using the configs provided in the configs/ directory as follows:

# mnist experiment on 1 GPU
accelerate launch train.py config_file=configs/mnist_discrete.yaml
# cifar10 experiment on 1 GPU (A100)
accelerate launch train.py config_file=configs/cifar10_discretized_256bins.yaml
# text8 experiment on 8 GPUs (A100)
accelerate launch --multi_gpu --num_processes=8 --num_machines=1 --dynamo_backend=no --mixed_precision=fp16 train.py config_file=configs/text8_discrete.yaml 

Testing

Note: Depending on your GPU, you may wish to adjust the batch size used for testing in test.py.

# Optional: Download pretrained checkpoints (make sure you have git-lfs installed: https://git-lfs.com/)
git clone [email protected]:rupspace/pretrained-BFNs
# Compute 784-step loss on MNIST
python test.py seed=1 config_file=./configs/mnist_discrete.yaml load_model=./pretrained-BFNs/mnist_ema.pt n_steps=784 n_repeats=2000
# Compute 10-step loss on CIFAR-10
python test.py seed=1 config_file=./configs/cifar10_discretized_256bins.yaml load_model=./pretrained-BFNs/cifar10_256d_ema.pt n_steps=10 n_repeats=100
# Compute continuous-time loss on text8
python test.py seed=1 config_file=./configs/text8_discrete.yaml load_model=./pretrained-BFNs/text8_ema.pt n_steps=0 n_repeats=1

Important: All computed results are in nats per data dimension. To convert them to bits, divide by ln(2).
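A minimal sketch of that conversion; the function name is illustrative and not part of the repo.

```python
import math


def nats_to_bits(nats_per_dim: float) -> float:
    """Convert a loss in nats per data dimension to bits per data dimension."""
    # 1 nat = 1 / ln(2) bits, so dividing by ln(2) converts nats to bits.
    return nats_per_dim / math.log(2)


print(nats_to_bits(1.0))  # about 1.4427 bits
```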

Sampling

You can sample from a pre-trained model as follows (change options as desired):

# Sample 4 binarized MNIST images using 100 steps
python sample.py seed=1 config_file=./configs/mnist_discrete.yaml load_model=./pretrained-BFNs/mnist_ema.pt samples_shape="[4, 28, 28, 1]" n_steps=100 save_file=./samples_mnist.pt
# Sample 4 CIFAR-10 images modeled as discretized data with 16 bins per channel, using 1000 steps
python sample.py seed=1 config_file=./configs/cifar10_discretized_16bins.yaml load_model=./pretrained-BFNs/cifar10_16d_ema.pt samples_shape="[4, 32, 32, 3]" n_steps=1000 save_file=./samples_cifar.pt
# Sample 2 text8 sequences of length 256 using 100 steps
python sample.py seed=1 config_file=./configs/text8_discrete.yaml load_model=./pretrained-BFNs/text8_ema.pt samples_shape="[2, 256]" n_steps=100 save_file=./samples_text8.pt

The samples are stored as PyTorch tensors in the save_file, and can be visualized by loading them and then using the utilities batch_to_images and batch_to_str in data.py. For example:

# batch_to_images returns a matplotlib Figure object
python -c "import torch; from data import batch_to_images; batch_to_images(torch.load('./samples_mnist.pt')).savefig('mnist.png')"
python -c "import torch; from data import batch_to_images; batch_to_images(torch.load('./samples_cifar.pt')).savefig('cifar.png')"
# batch_to_str returns a list of str
python -c "import torch; from data import batch_to_str; print(batch_to_str(torch.load('./samples_text8.pt')))"

Reproducibility

If a high degree of reproducibility is desired (e.g. during sampling), set the following:

import torch

torch.set_float32_matmul_precision("highest")
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False

Acknowledgements

We are grateful to @Higgcz for generous support with the experiment infrastructure and code release.


bayesian-flow-networks's Issues

Dataloader workers on different GPUs may get the same randomness during multi-process training

Hi, it's me again! I think there may be a problem with how the dataloader reseeds workers in multi-GPU training: workers with the same worker_id on different GPUs will get the same randomness if we seed them the way the repo currently does:

import random

import numpy as np
import torch


def worker_init_function(worker_id: int) -> None:
    """https://pytorch.org/docs/stable/notes/randomness.html#dataloader"""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def get_generator(seed: int):
    g = torch.Generator()
    g.manual_seed(seed)
    return g

One way to avoid this problem is to seed the generator with the specified seed plus the rank, which might look like:

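import torch
import torch.distributed as dist


def get_generator(seed: int):
    # Offset the seed by the process rank so that each GPU's DataLoader draws
    # a different base seed; fall back to rank 0 when not running distributed.
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    g = torch.Generator()
    g.manual_seed(seed + rank)
    return g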

With this approach, we don't even need to set worker_init_fn in the DataLoader: the DataLoaders on different GPUs get different _base_seed values, so every worker on every GPU ends up with its own unique randomness.
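The collision can be illustrated without PyTorch. The stdlib-only simulation below assumes, as PyTorch's DataLoader does, that each worker is seeded with base_seed + worker_id, where base_seed is drawn from the generator passed to the DataLoader. A seeded random.Random stands in for torch.Generator, and all function names are hypothetical.

```python
import random


def draw_base_seed(generator_seed: int) -> int:
    # Stand-in for the DataLoader drawing its base seed from its generator.
    return random.Random(generator_seed).getrandbits(63)


def worker_seeds(seed: int, world_size: int, num_workers: int, offset_by_rank: bool) -> list[int]:
    """Collect the per-worker seeds across all ranks, assuming each worker's
    seed is base_seed + worker_id."""
    seeds = []
    for rank in range(world_size):
        gen_seed = seed + rank if offset_by_rank else seed
        base = draw_base_seed(gen_seed)
        seeds.extend(base + worker_id for worker_id in range(num_workers))
    return seeds


same = worker_seeds(seed=1, world_size=2, num_workers=4, offset_by_rank=False)
fixed = worker_seeds(seed=1, world_size=2, num_workers=4, offset_by_rank=True)
# Without the rank offset, both ranks reuse the same 4 seeds.
print(len(set(same)), len(set(fixed)))
```

Offsetting the generator seed by rank changes only the base seed, so the per-worker offsets inside each rank still keep workers on the same GPU distinct.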

Errors while running test.py

When I try to run

python test.py seed=1 config_file=./configs/mnist_discrete.yaml load_model=./pretrained-BFNs/mnist_ema.pt n_steps=784 n_repeats=2000

to test the pre-trained model, I get this error:
ImportError: cannot import name 'get_generator' from 'utils_train'
(I've already run git clone [email protected]:rupspace/pretrained-BFNs successfully.)

I checked utils_train.py and found that there is no get_generator function in it. However, the function does exist in the earlier commit 834d896:

import torch


def get_generator(seed: int):
    g = torch.Generator()
    g.manual_seed(seed)
    return g

After adding this function to utils_train.py, the error changed to:

UnpicklingError: Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution.Do it only if you get the file from a trusted source. WeightsUnpickler error: Unsupported operand 118

I tried changing weights_only from True to False, but that didn't work either.

Then I switched to my own checkpoint at ./checkpoints/BFN/best/ema_model.pt (trained with your code, of course), keeping weights_only as True, and the problem was solved.

Therefore, there might be some code to fix and models to update. :-)
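One plausible explanation, which is an assumption and not confirmed in the issue: if git-lfs was not active when pretrained-BFNs was cloned, each .pt file is only a small text pointer whose first line starts with `version https://git-lfs.github.com/spec/`. Its first byte is `v`, which is ASCII 118, matching the "Unsupported operand 118" message; since the file is text rather than a checkpoint, neither weights_only setting could load it. A quick stdlib check (the helper name is hypothetical):

```python
# Expected start of a Git LFS pointer file (an assumption based on the Git LFS spec).
LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/"


def is_lfs_pointer(path: str) -> bool:
    """Return True if the file looks like a Git LFS pointer rather than real data."""
    with open(path, "rb") as f:
        head = f.read(len(LFS_POINTER_PREFIX))
    return head == LFS_POINTER_PREFIX


# Hypothetical usage on a downloaded checkpoint:
# if is_lfs_pointer("./pretrained-BFNs/mnist_ema.pt"):
#     print("Checkpoint is an LFS pointer; run `git lfs pull` inside pretrained-BFNs")
```

If the check returns True, running `git lfs install` and then `git lfs pull` inside the cloned directory should fetch the actual weights.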

There may be an error in the calculation of the best validation loss in train.py

Thanks for this excellent work! I think the BFN implementation is really beautiful, in both code structure and style.
But when I took a close look at the training part, I found what may be an error in the calculation of the best validation loss:

best_val_loss = validate(
      cfg=cfg,
      model=model,
      ema_model=ema_model,
      val_dataloader=dataloaders["val"],
      step=step,
      run=run,
      pbar=pbar,
      best_val_loss=best_val_loss,
      checkpoint_root_dir=checkpoint_root_dir,
      accelerator=accelerator,
)

Because validate() always returns the current validation loss, I think it should be changed to something like this:

val_loss = validate(
      cfg=cfg,
      model=model,
      ema_model=ema_model,
      val_dataloader=dataloaders["val"],
      step=step,
      run=run,
      pbar=pbar,
      best_val_loss=best_val_loss,
      checkpoint_root_dir=checkpoint_root_dir,
      accelerator=accelerator,
)
best_val_loss = min(val_loss, best_val_loss)

Am I right? Looking forward to your reply, thanks!
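A toy illustration of the difference, assuming (as the issue states) that validate() returns the current validation loss. The two functions are hypothetical stand-ins for the training loop, not code from train.py.

```python
import math


def track_best_buggy(val_losses: list[float]) -> float:
    best = math.inf
    for loss in val_losses:
        # Mirrors `best_val_loss = validate(...)`: simply keeps the latest loss.
        best = loss
    return best


def track_best_fixed(val_losses: list[float]) -> float:
    best = math.inf
    for loss in val_losses:
        # Mirrors `best_val_loss = min(val_loss, best_val_loss)`: keeps the lowest loss so far.
        best = min(loss, best)
    return best


losses = [3.0, 2.0, 2.5]
print(track_best_buggy(losses), track_best_fixed(losses))
```

With the buggy version, a later, worse validation loss overwrites the best one, so any "is this the best checkpoint?" comparison inside validate() would be made against the wrong threshold.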
