fine-tune-models's Issues
Are there any "artifact" issues during finetuning?
I've implemented PatchGAN and StyleGAN losses for decoder-only finetuning, but I still get artifacts in the reconstructions even after 50k steps.
My loss function is: 0.1 × LPIPS + 0.1 × PatchGAN (with adaptive weight) / 0.1 × StyleGAN (with 1e9 gradient penalty) + MSE loss.
The trainable parts are post_quant_conv, the decoder, and the GAN (discriminator).
The images are reconstructed without EMA. Here are some failure cases.
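As a point of comparison, here is a minimal sketch of how such a combined loss is often weighted. This is not the poster's actual code: the "adaptive weight" is assumed to follow the common VQGAN-style recipe, where the GAN term is rescaled by the ratio of the reconstruction-loss gradient norm to the GAN-loss gradient norm at the decoder's last layer. All names and values below are illustrative.

```python
# Hypothetical sketch of the loss weighting described in the issue (assumed,
# not taken from the repo). The gradient norms would normally be computed
# w.r.t. the decoder's last layer; here they are passed in as plain floats.

def adaptive_gan_weight(rec_grad_norm: float, gan_grad_norm: float,
                        eps: float = 1e-4, clip: float = 1e4) -> float:
    """VQGAN-style adaptive weight: match the GAN term's gradient
    magnitude to the reconstruction term's gradient magnitude."""
    w = rec_grad_norm / (gan_grad_norm + eps)
    return max(0.0, min(w, clip))

def total_loss(mse: float, lpips: float, gan: float,
               rec_grad_norm: float, gan_grad_norm: float) -> float:
    # MSE + 0.1 * LPIPS + 0.1 * adaptive_weight * GAN, as stated above.
    d_weight = adaptive_gan_weight(rec_grad_norm, gan_grad_norm)
    return mse + 0.1 * lpips + 0.1 * d_weight * gan

print(total_loss(mse=0.02, lpips=0.5, gan=1.0,
                 rec_grad_norm=2.0, gan_grad_norm=4.0))
```

If the adaptive weight is computed this way, a discriminator that dominates early training can still inject high-frequency artifacts; a common mitigation is delaying the GAN term for the first N steps.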
Could not find parameter named "kernel" in scope "/encoder/mid_block/attentions_0/query"
Hi, upon loading this up with any dataset and the 1.4 VAE converted to JAX, I get this warning on loading:
The checkpoint D:\gitprojects\finetuningvae\backup\stable-diffusion-v1-5-jax\vae\ is missing required keys: {('decoder', 'mid_block', 'attentions_0', 'value', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'query', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'key', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'proj_attn', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'query', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'value', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'query', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'key', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'key', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'value', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'proj_attn', 'kernel'), ('encoder', 'mid_block', 'attentions_0', 'query', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'value', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'proj_attn', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'proj_attn', 'kernel'), ('encoder', 'mid_block', 'attentions_0', 'key', 'kernel')}. Make sure to call model.init_weights to initialize the missing weights.
(It's the 1.4 VAE; I just copied it over.)
This presumably causes the training to fail with this error:
    y = fun(self, *args, **kwargs)
  File "C:\Users\crowl\AppData\Local\Programs\Python\Python310\lib\site-packages\stable_diffusion_jax\modeling_vae.py", line 287, in __call__
    hidden_states = attn(hidden_states)
  File "C:\Users\crowl\AppData\Local\Programs\Python\Python310\lib\site-packages\flax\linen\module.py", line 418, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "C:\Users\crowl\AppData\Local\Programs\Python\Python310\lib\site-packages\flax\linen\module.py", line 854, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "C:\Users\crowl\AppData\Local\Programs\Python\Python310\lib\site-packages\stable_diffusion_jax\modeling_vae.py", line 150, in __call__
    query = self.query(hidden_states)
  File "C:\Users\crowl\AppData\Local\Programs\Python\Python310\lib\site-packages\flax\linen\module.py", line 418, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "C:\Users\crowl\AppData\Local\Programs\Python\Python310\lib\site-packages\flax\linen\module.py", line 854, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "C:\Users\crowl\AppData\Local\Programs\Python\Python310\lib\site-packages\flax\linen\linear.py", line 196, in __call__
    kernel = self.param('kernel',
  File "C:\Users\crowl\AppData\Local\Programs\Python\Python310\lib\site-packages\flax\linen\module.py", line 1263, in param
    v = self.scope.param(name, init_fn, *init_args, unbox=unbox)
  File "C:\Users\crowl\AppData\Local\Programs\Python\Python310\lib\site-packages\flax\core\scope.py", line 842, in param
    raise errors.ScopeParamNotFoundError(name, self.path_text)
jax._src.traceback_util.UnfilteredStackTrace: flax.errors.ScopeParamNotFoundError: Could not find parameter named "kernel" in scope "/encoder/mid_block/attentions_0/query". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamNotFoundError)
This happens regardless of what I start the training with; I'm at a bit of a loss, really.
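What the warning is reporting can be reproduced in miniature without Flax: the loader flattens the model's expected parameter tree into key-path tuples and diffs them against the key paths found in the checkpoint; any expected path absent from the checkpoint is reported as missing, and later triggers `ScopeParamNotFoundError` when the module tries to read it. The dict contents below are invented for illustration.

```python
# Minimal sketch of the missing-key check implied by the warning above.
# Flax stores parameters as nested dicts; flattening to path tuples and
# taking a set difference reproduces the reported {('encoder', ...)} set.

def flatten_keys(tree, prefix=()):
    """Yield every leaf path in a nested params dict as a tuple of names."""
    for name, value in tree.items():
        if isinstance(value, dict):
            yield from flatten_keys(value, prefix + (name,))
        else:
            yield prefix + (name,)

# Toy stand-ins: what the model expects vs. what the checkpoint contains.
expected = {"encoder": {"mid_block": {"attentions_0": {
    "query": {"kernel": 0, "bias": 0}}}}}
checkpoint = {"encoder": {"mid_block": {"attentions_0": {
    "query": {"bias": 0}}}}}

missing = set(flatten_keys(expected)) - set(flatten_keys(checkpoint))
print(missing)  # the attention 'kernel' path is absent, as in the warning
```

Since every missing key here lives under `mid_block/attentions_0`, the mismatch is plausibly a naming change in the attention layers between the checkpoint's converter and the model definition, rather than a truncated file.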
Converting from Flax back to PyTorch?
I got the training working, and it works really well for fine-tuning out some common "glitches" and mis-reconstructions in photos and textures. However, I'd really like to convert the weights back into a PyTorch-usable format.
I've looked at the convert_diffusers_to_jax script from patil-suraj's repo, but reverse-engineering it into an inverse function doesn't look feasible. Some of the changes it applies to the model appear to be irreversible without reconstructing the model's key structure from scratch.
Diffusers now has its own pt->flax and flax->pt conversions, but I couldn't get its Flax model to work with the script. There seem to be too many discrepancies in both the methods and the model keys, and my knowledge is too limited to figure out how to adapt the fine-tuning script to it. Diffusers' AutoencoderKL also wraps its output in yet another output class, which complicates things further.
So... bottom line: is there something I'm missing?
Is there some simple solution I've overlooked?
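For the mechanical part of the conversion, a rough sketch of the usual Flax-to-PyTorch parameter mapping is below. This is not the diffusers conversion script and ignores any repo-specific key renames (which is the hard part the issue describes); it only shows the layout rules that are stable across frameworks: `kernel` becomes `weight`, conv kernels go from Flax's (H, W, in, out) to PyTorch's (out, in, H, W), and dense kernels are transposed. The example tree is invented.

```python
import numpy as np

# Hedged sketch of a Flax -> PyTorch state-dict conversion (layout only;
# model-specific key renames are NOT handled here).

def flax_to_pt_state_dict(params, prefix=""):
    state = {}
    for name, value in params.items():
        key = f"{prefix}.{name}" if prefix else name
        if isinstance(value, dict):
            state.update(flax_to_pt_state_dict(value, key))
            continue
        arr = np.asarray(value)
        if name == "kernel":
            key = key[: -len("kernel")] + "weight"
            if arr.ndim == 4:        # conv: HWIO -> OIHW
                arr = arr.transpose(3, 2, 0, 1)
            elif arr.ndim == 2:      # dense: (in, out) -> (out, in)
                arr = arr.T
        state[key] = arr
    return state

# Invented toy param tree standing in for a converted VAE checkpoint.
flax_params = {"decoder": {"conv_in": {"kernel": np.zeros((3, 3, 4, 128)),
                                       "bias": np.zeros(128)}}}
pt = flax_to_pt_state_dict(flax_params)
print(pt["decoder.conv_in.weight"].shape)  # (128, 4, 3, 3)
```

The irreversible-looking steps in convert_diffusers_to_jax are most likely these renames plus attention-layer restructuring, so recovering them generally does mean rebuilding the target key structure from the PyTorch model, as the issue suspects.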
About dataset format
Hi, I've looked at your run_finetune_vqgan.py code and the Kaggle dataset, but I can't work out the expected format of "danbooru_image_paths_ds.json". Could you describe its format?
VAE finetuning
Does this work with normal PyTorch Stable Diffusion, and not just stable-diffusion-jax?