Comments (7)
Since you converted the models to a different framework, could there be an issue in the conversion process?
There are various debug experiments you could try to make sure the conversion worked correctly. For example, try training with a very high aging loss lambda (say, 100) and see whether you are able to generate images of various ages. Of course, the image reconstructions won't be good, but this should help you verify that your age classifier is fine.
Another experiment would be to overfit on, say, 10 images.
Small experiments like these can help you quickly identify gaps between your implementation and the official one.
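As a concrete starting point, here is a minimal sketch of the first experiment. It assumes the flag names of scripts/train.py (they match the train options printed later in this thread); the exp_dir path and step count are placeholders.

```python
# Launch a short debug run with an exaggerated aging lambda so the aging loss
# dominates; if the age classifier transferred correctly, the generated ages
# should vary with the target age even though reconstructions degrade.
import subprocess

subprocess.run([
    "python", "scripts/train.py",
    "--dataset_type=ffhq_aging",
    "--exp_dir=experiments/debug_high_aging_lambda",  # placeholder path
    "--aging_lambda=100",
    "--target_age=uniform_random",
    "--max_steps=5000",  # a short run is enough to see whether ages vary
], check=True)
```

For the overfitting experiment, the same command with a 10-image dataset and the default lambdas should quickly drive the reconstruction losses toward zero if the pipeline is wired correctly.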
Hope this helps.
Hi @S-HuaBomb,
I feel like there is something strange going on in the early stages of your training. If everything works correctly, in the early stages of training, you should see much better reconstructions of the input images since we are using a pre-trained pSp encoder to learn the initial latent embedding of the input.
Also, your results at the later stages of training seem a bit strange to me.
Any chance you want to send the command you ran so I can check it out?
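For context, here is a rough sketch of why this is expected; the function and module names are placeholders, not SAM's actual API.

```python
# With start_from_encoded_w_plus, the aging encoder only predicts a residual on
# top of a pre-trained pSp inversion of the input, so at step 0 the output is
# already roughly pSp's reconstruction of the input face.
import torch

def sam_forward(x_rgb, age_map, psp_encoder, aging_encoder, stylegan_decoder):
    with torch.no_grad():
        w_init = psp_encoder(x_rgb)  # frozen pSp: faithful inversion from the start
    # input_nc = 4: the RGB image concatenated with a constant target-age channel
    delta = aging_encoder(torch.cat([x_rgb, age_map], dim=1))
    return stylegan_decoder(w_init + delta)
```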
Oh, I'm sorry.
I forgot to tell you that the results I got above were not obtained by directly running your PyTorch source code and the pre-trained models. I used Baidu's deep learning framework PaddlePaddle to rewrite all the networks, converted your pre-trained models to Paddle format, and then used them to initialize the network parameters. There are some differences in how Paddle and PyTorch store network parameters.
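In case it helps anyone doing a similar port, the transfer was along these lines; the checkpoint names and the key-matching heuristic below are illustrative, not the exact script.

```python
# Sketch of a PyTorch -> Paddle weight transfer.
import paddle
import torch

torch_ckpt = torch.load("sam_ffhq_aging.pt", map_location="cpu")["state_dict"]
paddle_ckpt = {}
for name, tensor in torch_ckpt.items():
    array = tensor.detach().cpu().numpy()
    # One concrete Paddle/PyTorch difference: paddle.nn.Linear stores its weight
    # as [in_features, out_features], the transpose of torch.nn.Linear, so fully
    # connected weights must be transposed (illustrative key match below).
    if name.endswith("linear.weight") and array.ndim == 2:
        array = array.T
    paddle_ckpt[name] = array
paddle.save(paddle_ckpt, "sam_ffhq_aging.pdparams")
```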
I think this process led to these strange problems, especially the randomness of the images generated at the early steps. But I'm surprised that training has now run to 30,000 steps and the generated faces look very good:
In particular, male face generation is much better than earlier:
For comparison, this is one result from training with your PyTorch source code earlier (initialized with your pre-trained SAM model). There, the generated face at the beginning is indeed very similar to the input, unlike the strange behavior of my Paddle version.
Btw, the results of continuing training are exciting. It didn't take many steps to reach the current quality, and I believe initializing with your pre-trained parameters greatly facilitates training. Some metrics of the training so far are as follows:
```
Metrics for train, step 28750
loss_id_real = 0.5581539869308472
id_improve_real = -0.5194830099741617
loss_l2_real = 0.10265877842903137
loss_lpips_real = 0.3880242705345154
loss_lpips_crop = 0.24058037996292114
loss_l2_crop = 0.08031028509140015
loss_w_norm_real = 12.686344146728516
loss_aging_real = 0.0015396953094750643
loss_real = 0.38906392455101013
loss_id_cycle = 0.6107007265090942
id_improve_cycle = -0.5718449528018633
loss_l2_cycle = 0.11893878877162933
loss_lpips_cycle = 0.4076683223247528
loss_w_norm_cycle = 12.188891410827637
loss_aging_cycle = 0.0013609920861199498
loss_cycle = 0.42397958040237427
loss = 0.8130435049533844
Metrics for train, step 29000
loss_id_real = 0.6209578514099121
id_improve_real = -0.5855490267276764
loss_l2_real = 0.08174904435873032
loss_lpips_real = 0.30977702140808105
loss_lpips_crop = 0.22773930430412292
loss_l2_crop = 0.07081114500761032
loss_w_norm_real = 13.020597457885742
loss_aging_real = 0.006176636088639498
loss_real = 0.38440173864364624
loss_id_cycle = 0.670578122138977
id_improve_cycle = -0.6313480387131373
loss_l2_cycle = 0.10392758250236511
loss_lpips_cycle = 0.334219753742218
loss_w_norm_cycle = 12.625384330749512
loss_aging_cycle = 0.003899508621543646
loss_cycle = 0.41654089093208313
loss = 0.8009426295757294
Metrics for train, step 29250
loss_id_real = 0.0
id_improve_real = -0.602919747432073
loss_l2_real = 0.08749698847532272
loss_lpips_real = 0.345823734998703
loss_lpips_crop = 0.24862270057201385
loss_l2_crop = 0.07119642198085785
loss_w_norm_real = 12.440343856811523
loss_aging_real = 0.11909914016723633
loss_real = 0.9098708033561707
loss_id_cycle = 0.6522703170776367
id_improve_cycle = -0.6157533973455429
loss_l2_cycle = 0.10205188393592834
loss_lpips_cycle = 0.36336588859558105
loss_w_norm_cycle = 11.9584379196167
loss_aging_cycle = 0.006273109465837479
loss_cycle = 0.4386044144630432
loss = 1.3484752178192139
Metrics for train, step 29500
loss_id_real = 0.0
id_improve_real = -0.5774414390325546
loss_l2_real = 0.10120463371276855
loss_lpips_real = 0.3618144690990448
loss_lpips_crop = 0.2744097113609314
loss_l2_crop = 0.08724726736545563
loss_w_norm_real = 12.781248092651367
loss_aging_real = 0.10382138937711716
loss_real = 0.8643555045127869
loss_id_cycle = 0.6659125685691833
id_improve_cycle = -0.6285987198352814
loss_l2_cycle = 0.1224149614572525
loss_lpips_cycle = 0.3821250796318054
loss_w_norm_cycle = 12.295019149780273
loss_aging_cycle = 0.004325758665800095
loss_cycle = 0.4704045057296753
loss = 1.3347600102424622
Metrics for train, step 29750
loss_id_real = 0.0
id_improve_real = -0.6456222732861837
loss_l2_real = 0.08446688950061798
loss_lpips_real = 0.34283968806266785
loss_lpips_crop = 0.21868062019348145
loss_l2_crop = 0.06268683075904846
loss_w_norm_real = 11.760087966918945
loss_aging_real = 0.11295045912265778
loss_real = 0.8554226756095886
loss_id_cycle = 0.7188147902488708
id_improve_cycle = -0.6828048142294089
loss_l2_cycle = 0.09350767731666565
loss_lpips_cycle = 0.3628307282924652
loss_w_norm_cycle = 11.585875511169434
loss_aging_cycle = 0.003500542603433132
loss_cycle = 0.4008687734603882
loss = 1.2562914490699768
```
Train options are set as follows:
```json
{
"aging_lambda": 5,
"batch_size": 6,
"board_interval": 250,
"cycle_lambda": 1,
"dataset_type": "ffhq_aging",
"exp_dir": "code_source/SAM_P/exp_dir",
"id_lambda": 0.1,
"image_interval": 1000,
"input_nc": 4,
"l2_lambda": 0.25,
"l2_lambda_aging": 0.25,
"l2_lambda_crop": 1,
"label_nc": 0,
"latent_avg": "code_source/SAM_P/pretrained_models/latent_avg.pdparams",
"learning_rate": 0.0001,
"lpips_lambda": 0.1,
"lpips_lambda_aging": 0.1,
"lpips_lambda_crop": 0.6,
"max_steps": 500000,
"optim_name": "adam",
"output_size": 1024,
"pretrained_psp_path": "code_source/SAM_P/pretrained_models/psp_ffhq_encoder.pdparams",
"psp_encoder": "code_source/SAM_P/pretrained_models/psp_encoder.pdparams",
"psp_ffhq_encoder": "code_source/SAM_P/pretrained_models/psp_ffhq_encoder.pdparams",
"sam_decoder": "code_source/SAM_P/pretrained_models/sam_decoder.pdparams",
"sam_encoder": "code_source/SAM_P/pretrained_models/sam_encoder.pdparams",
"sam_psp3_encoder": "code_source/SAM_P/pretrained_models/sam_psp3_encoder.pdparams",
"save_interval": 5000,
"start_from_encoded_w_plus": true,
"start_from_latent_avg": true,
"stylegan_decoder": "code_source/SAM_P/pretrained_models/stylegan2.pdparams",
"target_age": "uniform_random",
"test_batch_size": 4,
"test_workers": 2,
"train_decoder": false,
"train_from_sam_ckpt": true,
"use_weighted_id_loss": true,
"val_interval": 2500,
"w_norm_lambda": 0.005,
"workers": 2
}
```
Sorry for the confusion. Based on the above results, do you think I need to train longer? Looking forward to your reply.
It looks like your results are improving, but you should continue training. As I mentioned, we trained for about 60,000 steps, and you're only at 30,000 if I understood correctly.
One thing I noticed is that your ID similarities are quite low. The results you posted using the pre-trained SAM model reach similarities of 0.6 - 0.7 while you are reaching similarities in the 0.2 range.
Since you changed frameworks it is hard for me to pinpoint what could be the problem, but I would try to verify that the pSp encoder is transferred correctly and is used correctly during training.
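A quick way to check this, as a rough sketch: run the same input through both encoders and compare outputs numerically. Here torch_encoder and paddle_encoder stand in for the two loaded models; the loading code is omitted.

```python
# Parity check between the original PyTorch pSp encoder and its Paddle port.
import numpy as np
import paddle
import torch

x = np.random.randn(1, 3, 256, 256).astype("float32")  # pSp takes a 3-channel image

torch_encoder.eval()
with torch.no_grad():
    y_torch = torch_encoder(torch.from_numpy(x)).cpu().numpy()

paddle_encoder.eval()
with paddle.no_grad():
    y_paddle = paddle_encoder(paddle.to_tensor(x)).numpy()

# Tiny float32 discrepancies are normal; large ones point to a bad transfer.
print("max abs diff:", np.abs(y_torch - y_paddle).max())
```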
I will upload the code of the Paddle version to GitHub in the next couple of days. The ID similarities are now rising, currently around 0.4.
But now there is a problem: no matter what the target ages are during training, the age of the generated face stays close to the input age rather than the target age (in fact, most of the generated faces are young people!). Does the aging loss have a lower optimization priority?
Thank you very much for your help. We found the cause of the missing age transformation effect: the VGG network defined in dex_vgg.py uses a dict to save the output of each layer during the forward pass. In our Paddle version, this prevented the loss from being backpropagated, so I only needed to replace out[] with x. And now it works!
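For anyone hitting the same issue, the pattern looked roughly like this; the class below is an illustrative stand-in, not the real dex_vgg.py network.

```python
import paddle.nn as nn
import paddle.nn.functional as F

class TinyVGG(nn.Layer):
    """Illustrative stand-in for the dex_vgg.py forward pattern."""
    def __init__(self):
        super().__init__()
        self.conv1_1 = nn.Conv2D(3, 64, 3, padding=1)
        self.conv1_2 = nn.Conv2D(64, 64, 3, padding=1)

    def forward(self, x):
        # Before (ported as-is): each activation was stashed in a dict, which
        # in our Paddle port stopped the loss from backpropagating:
        #   out = {}
        #   out['r11'] = F.relu(self.conv1_1(x))
        #   out['r12'] = F.relu(self.conv1_2(out['r11']))
        #   return out
        # After: chain everything through a single tensor instead.
        x = F.relu(self.conv1_1(x))
        x = F.relu(self.conv1_2(x))
        return x
```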
Here is the repository of the Paddle version: paddle-SAM.
We were surprised to find that the SAM reproduced with Paddle achieves almost the same results as your source code!
Really cool to see that you were able to reproduce the results using a different framework. Awesome work!