yiyixuxu / denoising-diffusion-flax
Implementing the Denoising Diffusion Probabilistic Model in Flax
License: Apache License 2.0
Hi, thanks for the great work!
There is an assertion error when checking the dataset dtype, which is confusing because, as far as I can tell, it should fail for everyone.
Possibly a version issue (maybe some versions of jax recognise tf dtypes as jnp dtypes?).
AssertionError                            Traceback (most recent call last)
/home/amawi/projects/denoising-diffusion-flax/denoising_diffusion_flax/ddpm_flax_oxford102_end_to_end.ipynb Cell 5 in <cell line: 2>()
      1 work_dir = './fashion_mnist'
----> 2 state = train.train(my_config, work_dir)

File ~/projects/denoising-diffusion-flax/denoising_diffusion_flax/train.py:436, in train(config, workdir, wandb_artifact)
    434 rng, *train_step_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
    435 train_step_rng = jnp.asarray(train_step_rng)
--> 436 state, metrics = p_train_step(train_step_rng, state, batch)
    437 for h in hooks:
    438     h(step)

    [... skipping hidden 17 frame]

File ~/projects/denoising-diffusion-flax/denoising_diffusion_flax/train.py:252, in p_loss(rng, state, batch, ddpm_params, loss_fn, self_condition, is_pred_x0, pmap_axis)
    248 def p_loss(rng, state, batch, ddpm_params, loss_fn, self_condition=False, is_pred_x0=False, pmap_axis='batch'):
    249
    250     # run the forward diffusion process to generate noisy image x_t at timestep t
    251     x = batch['image']
--> 252     assert x.dtype in [jnp.float32, jnp.float64]
    254     # create batched timesteps: t with shape (B,)
    255     B, H, W, C = x.shape

AssertionError:
get_dataset is shown below with my fix applied (the original lines are commented out):
def get_dataset(rng, config):
    if config.data.batch_size % jax.device_count() > 0:
        raise ValueError('Batch size must be divisible by the number of devices')
    batch_size = config.data.batch_size // jax.process_count()
    platform = jax.local_devices()[0].platform
    if config.training.half_precision:
        if platform == 'tpu':
            # input_dtype = tf.bfloat16
            input_dtype = jnp.bfloat16
        else:
            # input_dtype = tf.float16
            input_dtype = jnp.float16
    else:
        # input_dtype = tf.float32
        input_dtype = jnp.float32
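The assertion in p_loss accepts only NumPy/JAX float dtypes, and jnp.float32 is just an alias for np.float32, so any batch whose images arrive with a different dtype trips it. A minimal sketch of that check in plain NumPy, using a hypothetical uint8 image batch (not data from the repo), illustrates the failure mode and the cast that avoids it:

```python
import numpy as np

# p_loss asserts: x.dtype in [jnp.float32, jnp.float64].
# Since jnp.float32 aliases np.float32, the check reduces to:
def passes_dtype_check(x):
    return x.dtype in [np.float32, np.float64]

raw = np.zeros((4, 32, 32, 3), dtype=np.uint8)  # hypothetical un-cast image batch
ok = raw.astype(np.float32) / 255.0             # cast before entering p_loss

print(passes_dtype_check(raw))  # False -> would raise the AssertionError
print(passes_dtype_check(ok))   # True
```

A tf.Tensor with dtype tf.float32 fails the same membership test, which is consistent with the jnp-vs-tf dtype mismatch suspected above.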
For anyone reading: I'm running jax 0.3.21 on CUDA (not TPU).
Why set lr = 2e-4 for the oxford102 flowers dataset? I've tried denoising-diffusion-pytorch and my own implementation, denoising-diffusion-mindspore: the loss hovers around 0.4 and the sampled images are always noisy.
Is the weight-initialization method different between PyTorch and JAX? I use the training config below, which samples better images:
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,           # number of steps
    sampling_timesteps = 250,   # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
    loss_type = 'l1'            # L1 or L2
)

trainer = Trainer(
    diffusion,
    path,
    train_batch_size = 16,
    train_lr = 8e-5,
    train_num_steps = 700000,        # total training steps
    gradient_accumulate_every = 2,   # gradient accumulation steps
    ema_decay = 0.995,               # exponential moving average decay
    amp_level = 'O1',                # turn on mixed precision
)
trainer.train()
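Since the two setups differ in both learning rate and batch handling, it may help to compare them on a common footing before blaming initialization. A hedged sketch in plain Python: the helper and the flax-side numbers are illustrative assumptions, not values taken from either repo's code.

```python
# Gradient accumulation multiplies the batch the optimizer actually sees,
# so lr should be compared against this effective batch size.
def effective_batch(train_batch_size, gradient_accumulate_every):
    return train_batch_size * gradient_accumulate_every

flax_cfg = {"lr": 2e-4, "batch": 32, "accum": 1}   # assumed, for illustration only
other_cfg = {"lr": 8e-5, "batch": 16, "accum": 2}  # from the config above

for name, c in [("flax (assumed)", flax_cfg), ("pytorch/mindspore", other_cfg)]:
    eb = effective_batch(c["batch"], c["accum"])
    print(name, "effective batch:", eb, "lr per sample:", c["lr"] / eb)
```

If the effective batch sizes match but the learning rates differ by ~2.5x, the gap in sample quality could come from the lr alone rather than from PyTorch-vs-JAX weight initialization.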