nnaisense / bayesian-flow-networks
This is the official code release for Bayesian Flow Networks.
License: Apache License 2.0
I'd like to know how bfn.gif was generated. Would you be willing to share the code that implements it?
When I try to run
python test.py seed=1 config_file=./configs/mnist_discrete.yaml load_model=./pretrained-BFNs/mnist_ema.pt n_steps=784 n_repeats=2000
to test the pre-trained model, I get this error:
ImportError: cannot import name 'get_generator' from 'utils_train'
(I've already run git clone [email protected]:rupspace/pretrained-BFNs
successfully.)
I checked utils_train.py and found that there is no get_generator in it. However, I do see a get_generator function in the history at commit 834d896:
def get_generator(seed: int):
g = torch.Generator()
g.manual_seed(seed)
return g
After adding this function to utils_train.py, the error changed to:
UnpicklingError: Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution.Do it only if
you get the file from a trusted source. WeightsUnpickler error: Unsupported operand 118
I tried changing weights_only from True to False, but that didn't work either.
Then I switched to my own checkpoint at ./checkpoints/BFN/best/ema_model.pt (trained with your code, of course), kept weights_only set to True, and the problem was solved.
Therefore, there might be some code to fix and models to update. :-)
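For what it's worth, pickle opcode 118 is the byte value of the ASCII character 'v', which is exactly what torch.load would hit if the cloned mnist_ema.pt were actually a Git LFS pointer (a small text file starting with "version https://git-lfs...") rather than the real binary checkpoint. A minimal header check, my own diagnostic sketch rather than anything from the repo:

```python
def looks_like_lfs_pointer(head: bytes) -> bool:
    # Git LFS pointer files are small text files whose first line starts with
    # "version https://git-lfs..."; real zip-based torch checkpoints instead
    # begin with the binary "PK" magic bytes.
    return head.startswith(b"version https://git-lfs")

# Example headers: an LFS pointer vs. a zip-based torch checkpoint.
print(looks_like_lfs_pointer(b"version https://git-lfs.github.com/spec/v1\n"))  # True
print(looks_like_lfs_pointer(b"PK\x03\x04"))  # False
```

If the first bytes of the .pt file turn out to be readable text like this, re-fetching the checkpoint with `git lfs pull` (or downloading the file directly) may resolve the UnpicklingError.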
Hi,
The discrete BFN presented in the paper demonstrates competitive performance on the text8 dataset. However, text8's vocabulary size of just 27 is considerably smaller than what most NLP tasks require. I am curious whether you have experimented with training discrete BFN models on datasets with larger vocabularies. If so, could you share some insights into the model architecture, hyperparameter settings, and the performance achieved?
Thanks!
Thanks for this excellent work! I really like the code implementation of BFN; both the structure and the style are beautiful.
But when I took a close look at the training part, I found a (possible) error in the calculation of the best validation loss:
best_val_loss = validate(
cfg=cfg,
model=model,
ema_model=ema_model,
val_dataloader=dataloaders["val"],
step=step,
run=run,
pbar=pbar,
best_val_loss=best_val_loss,
checkpoint_root_dir=checkpoint_root_dir,
accelerator=accelerator,
)
Because validate() always returns the current validation loss, I think we should change it as below:
val_loss = validate(
cfg=cfg,
model=model,
ema_model=ema_model,
val_dataloader=dataloaders["val"],
step=step,
run=run,
pbar=pbar,
best_val_loss=best_val_loss,
checkpoint_root_dir=checkpoint_root_dir,
accelerator=accelerator,
)
best_val_loss = min(val_loss, best_val_loss)
Am I right? Waiting for your reply, thanks!
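To illustrate the suggested pattern with a self-contained sketch (validate_stub is a hypothetical stand-in for the repo's validate(), which returns the current validation loss):

```python
best_val_loss = float("inf")

def validate_stub(step: int) -> float:
    # Hypothetical stand-in for validate(): returns the current validation loss.
    return [0.9, 0.7, 0.8][step]

for step in range(3):
    val_loss = validate_stub(step)
    # Keep the running minimum instead of overwriting it with the latest loss.
    best_val_loss = min(val_loss, best_val_loss)

print(best_val_loss)  # 0.7
```

Without the min(), best_val_loss would end at 0.8 here, the last loss rather than the best one.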
Is there a pre-trained model here?
Hi, it's me again! I think there may be a problem with the dataloader reseeding workers in multi-GPU training: workers with the same worker_id on different GPUs will get the same randomness if we seed the way the repo does:
bayesian-flow-networks/utils_train.py
Line 60 in 896ea20
def worker_init_function(worker_id: int) -> None:
"""https://pytorch.org/docs/stable/notes/randomness.html#dataloader"""
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
bayesian-flow-networks/utils_train.py
Line 67 in 896ea20
def get_generator(seed: int):
g = torch.Generator()
g.manual_seed(seed)
return g
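To make the collision concrete: PyTorch's DataLoader seeds worker i with base_seed + i, and torch.initial_seed() inside the worker returns that value. If two GPU processes start from the same base seed, workers with equal worker_id derive identical seeds, as this sketch (my own illustration, not repo code) shows:

```python
def worker_seed_for(base_seed: int, worker_id: int) -> int:
    # DataLoader seeds worker i with base_seed + i; inside the worker,
    # torch.initial_seed() % 2**32 reproduces this value.
    return (base_seed + worker_id) % 2**32

# Same base seed on two "GPUs" -> worker 0 collides across processes.
gpu0_worker0 = worker_seed_for(1234, 0)
gpu1_worker0 = worker_seed_for(1234, 0)
print(gpu0_worker0 == gpu1_worker0)  # True
```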
One way to avoid this problem is to seed the generator with both the specified seed and the rank, which might look like:
def get_generator(seed: int):
    import torch.distributed as dist

    # Offset the seed by the process rank so each GPU gets its own stream;
    # fall back to rank 0 when not running distributed.
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    g = torch.Generator()
    g.manual_seed(seed + rank)
    return g
Following this approach, we don't even have to set worker_init_fn in the dataloader; different GPUs will get different _base_seed values in their dataloaders, so each worker on each GPU ends up with its own unique randomness.
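A minimal sketch of the rank-offset idea, simulating two ranks in one process instead of calling dist.get_rank() so it runs without a distributed setup (get_rank_generator is a hypothetical name to avoid clashing with the proposal above):

```python
import torch

def get_rank_generator(seed: int, rank: int) -> torch.Generator:
    # Offset the base seed by the process rank so each GPU's dataloader
    # draws a distinct _base_seed for its workers.
    g = torch.Generator()
    g.manual_seed(seed + rank)
    return g

g0 = get_rank_generator(42, rank=0)
g1 = get_rank_generator(42, rank=1)
print(g0.initial_seed(), g1.initial_seed())  # 42 43
```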