Comments (4)
@lucidrains I've just finished looking over the EMA library. So, if I got it right, this is what happens:
- During `Trainer.__init__`, you initialize the `EMA`. This creates a deep copy of the model, which is then passed to the `FidEvaluator`
- At testing time, you copy the parameters to the model stored in the `EMA`
- You call `.eval()` for the `FidEvaluator` model (which is actually the EMA model), which means that the original model isn't affected and remains in training mode
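The steps above can be illustrated with a minimal sketch. It is not the library's code: the small `nn.Sequential` model is a hypothetical stand-in, and a plain `copy.deepcopy` stands in for whatever the `EMA` wrapper does internally. It only shows that switching a deep copy to eval mode leaves the original in training mode.

```python
import copy

import torch
from torch import nn

# Hypothetical stand-in for the diffusion model; dropout makes train/eval differ.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
model.train()

# Assumption: the EMA wrapper holds an independent deep copy of the model.
ema_model = copy.deepcopy(model)

# Putting the copy in eval mode (as the evaluator would) does not touch the original.
ema_model.eval()

print(model.training)      # True
print(ema_model.training)  # False

# The copy owns its own parameter tensors, so no storage is shared with the original.
shared = any(
    p.data_ptr() == q.data_ptr()
    for p, q in zip(model.parameters(), ema_model.parameters())
)
print(shared)  # False
```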
If this is correct, I'll go ahead and close the issue. If there isn't any downside, I'll also probably edit the script so that it calls `.eval()` and `.train()` manually (I need to compute some other metrics besides the FID).
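One way to do the manual `.eval()` / `.train()` toggling is a small context manager that restores whatever mode the model was in before. This is only a sketch: `eval_mode` is a hypothetical helper, not part of the repository, and the model is again a toy stand-in.

```python
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def eval_mode(module: nn.Module):
    """Temporarily put `module` in eval mode, then restore its previous mode."""
    was_training = module.training
    module.eval()
    try:
        yield module
    finally:
        module.train(was_training)


# Hypothetical stand-in for the model being trained.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
model.train()

with eval_mode(model), torch.no_grad():
    x = torch.ones(2, 4)
    # With dropout disabled, two forward passes give identical outputs.
    deterministic = torch.equal(model(x), model(x))

print(deterministic)   # True
print(model.training)  # True: training mode is restored after the block
```

Because the restore happens in `finally`, the model goes back to training mode even if the metric computation raises.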
Out of curiosity, what is the rationale for keeping a deep-copied EMA model and manually updating the parameters? Is it for the sake of thread safety / accelerator black magic?
from denoising-diffusion-pytorch.
@samuelemarro hey Samuele, thanks for raising this issue
i think this is taken care of here? but i just realized it may not set it back to train mode, which may be necessary given the dropouts just added
let me quickly do that
@samuelemarro oh actually, it is fine, since the `FidEvaluator` is working off the `EMA` model
@samuelemarro i'm not sure where the practice of keeping an EMA of generator came from, but i first encountered it in StyleGAN2 from Tero Karras years ago. ever since, many papers use this technique and it has become mainstream
the rationale is that the smoothed parameter updates yield a model with better quality
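For intuition, the smoothing can be sketched in plain Python on a single toy weight. The update rule `decay * ema + (1 - decay) * current` is the usual exponential moving average. The `decay` value and the oscillating trajectory are made up for illustration, and `ema_update` is a hypothetical helper rather than the library's API.

```python
def ema_update(ema_params, params, decay=0.9):
    """Return the updated EMA: decay * ema + (1 - decay) * current."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]


# Toy "training" trajectory for a single weight that oscillates around 1.0.
trajectory = [1.5, 0.5, 1.5, 0.5, 1.5, 0.5, 1.5, 0.5]

ema = [trajectory[0]]  # initialise the EMA from the first weights, like a copy
ema_history = []
for w in trajectory:
    ema = ema_update(ema, [w])
    ema_history.append(ema[0])


def max_step(values):
    """Largest change between consecutive values."""
    return max(abs(b - a) for a, b in zip(values, values[1:]))


# The raw weight jumps by a full unit each step; the EMA moves far less.
print(max_step(trajectory), max_step(ema_history))
```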