

yuval-alaluf commented on June 15, 2024

Since you converted the models to a different framework, could there be an issue in the conversion process?
There are various debugging experiments you could run to make sure the conversion worked. For example, try training with a very high aging loss lambda (say, 100) and see whether you can generate images of various ages. The reconstructions won't be good, of course, but this should help you verify that your age classifier is fine.
Another experiment would be to overfit on, say, 10 images.
Small checks like these can help you quickly identify gaps between your implementation and the official one.
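The two sanity checks suggested above can be set up as small variations of an existing run. A minimal sketch, assuming the options are a plain dict like the JSON options posted later in this thread; make_debug_options and pick_overfit_subset are hypothetical helpers, and how a training script would actually consume the subset is not covered here.

```python
import copy
import random


def make_debug_options(base_opts, aging_lambda=100.0):
    """Return a copy of the training options with a very high aging weight.

    Hypothetical helper: with aging dominating the loss, generated ages should
    follow the target age even though reconstructions get worse, which helps
    confirm that the age classifier is wired up correctly.
    """
    opts = copy.deepcopy(base_opts)  # leave the real options untouched
    opts["aging_lambda"] = aging_lambda
    return opts


def pick_overfit_subset(image_paths, n=10, seed=0):
    """Pick a fixed, reproducible subset of images to overfit on (hypothetical helper)."""
    paths = sorted(image_paths)  # sort so the subset does not depend on listing order
    rng = random.Random(seed)
    return rng.sample(paths, min(n, len(paths)))


base = {"aging_lambda": 5, "batch_size": 6}
debug = make_debug_options(base)
print(debug["aging_lambda"], base["aging_lambda"])  # 100.0 5
subset = pick_overfit_subset([f"img_{i:04d}.png" for i in range(100)])
print(len(subset))  # 10
```

For the overfitting check, the training losses on the chosen images should fall much faster and lower than in a normal run; if they plateau early, something in the ported pipeline is a likely suspect.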
Hope this helps.

from sam.

yuval-alaluf commented on June 15, 2024

Hi @S-HuaBomb ,
I feel like there is something strange going on in the early stages of your training. If everything works correctly, in the early stages of training, you should see much better reconstructions of the input images since we are using a pre-trained pSp encoder to learn the initial latent embedding of the input.
Also, your results at the later stages of training seem a bit strange to me.
Any chance you want to send the command you ran so I can check it out?

S-HuaBomb commented on June 15, 2024

Oh, I'm sorry.

I forgot to mention that the results above were not obtained by running your PyTorch source code and pre-trained models directly. I used Baidu's deep learning framework PaddlePaddle to rewrite all of the networks, converted your pre-trained models to Paddle format, and then used them to initialize the network parameters. There are some differences between Paddle and PyTorch in how network parameters are defined.
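One known source of such differences: Paddle's nn.Linear stores its weight as [in_features, out_features], while PyTorch's stores [out_features, in_features], and Paddle's batch-norm buffers are named _mean and _variance instead of running_mean and running_var (PyTorch's num_batches_tracked has no Paddle counterpart). A minimal, dependency-free sketch of a converter under those assumptions, with tensors as nested lists; convert_state_dict and its linear_prefixes argument are hypothetical. Only list modules that are genuine nn.Linear layers in the port: custom layers that own their weight (such as StyleGAN2-style equalized layers) keep whatever layout the port chose.

```python
def transpose(matrix):
    """Transpose a 2-D nested list."""
    return [list(row) for row in zip(*matrix)]


# PyTorch batch-norm buffer suffix -> Paddle buffer suffix
BN_RENAMES = {"running_mean": "_mean", "running_var": "_variance"}


def convert_state_dict(torch_state, linear_prefixes):
    """Map a PyTorch-style state dict (name -> nested lists) to Paddle naming/layout.

    linear_prefixes: names of modules that are plain nn.Linear layers in the
    Paddle port (hypothetical argument); only their 2-D weights are transposed.
    """
    paddle_state = {}
    for name, value in torch_state.items():
        prefix, _, suffix = name.rpartition(".")
        if suffix == "num_batches_tracked":
            continue  # Paddle batch norm has no such buffer
        if suffix in BN_RENAMES:
            name = f"{prefix}.{BN_RENAMES[suffix]}"
        elif suffix == "weight" and prefix in linear_prefixes:
            value = transpose(value)  # [out, in] -> [in, out]
        paddle_state[name] = value
    return paddle_state


torch_state = {
    "fc.weight": [[1, 2, 3], [4, 5, 6]],  # PyTorch layout: [out=2, in=3]
    "fc.bias": [0, 0],
    "bn.running_mean": [0.1, 0.2],
    "bn.running_var": [1.0, 1.0],
    "bn.num_batches_tracked": 7,
}
paddle_state = convert_state_dict(torch_state, linear_prefixes={"fc"})
print(paddle_state["fc.weight"])  # [[1, 4], [2, 5], [3, 6]]
print(sorted(paddle_state))  # ['bn._mean', 'bn._variance', 'fc.bias', 'fc.weight']
```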

I think this conversion process caused these strange problems, especially the random-looking images generated in the early steps. But to my surprise, now that training has reached 30,000 steps, the generated faces look very good:
[Images: generated faces after about 30,000 training steps]

In particular, male face generation is much better than before:
[Image: male face results]

This is a result from an earlier training run with your PyTorch source code (initialized with your pre-trained SAM model). As expected, the faces generated at the beginning are very similar to the input, unlike the strange behavior of my Paddle version.
[Image: result from training with the original PyTorch code]

By the way, the results of continued training are exciting. It doesn't take many steps to reach the current quality; I believe initializing with your pre-trained parameters speeds up training considerably. Here are some of the training metrics so far:

Metrics for train, step 28750
	loss_id_real =  0.5581539869308472
	id_improve_real =  -0.5194830099741617
	loss_l2_real =  0.10265877842903137
	loss_lpips_real =  0.3880242705345154
	loss_lpips_crop =  0.24058037996292114
	loss_l2_crop =  0.08031028509140015
	loss_w_norm_real =  12.686344146728516
	loss_aging_real =  0.0015396953094750643
	loss_real =  0.38906392455101013
	loss_id_cycle =  0.6107007265090942
	id_improve_cycle =  -0.5718449528018633
	loss_l2_cycle =  0.11893878877162933
	loss_lpips_cycle =  0.4076683223247528
	loss_w_norm_cycle =  12.188891410827637
	loss_aging_cycle =  0.0013609920861199498
	loss_cycle =  0.42397958040237427
	loss =  0.8130435049533844
Metrics for train, step 29000
	loss_id_real =  0.6209578514099121
	id_improve_real =  -0.5855490267276764
	loss_l2_real =  0.08174904435873032
	loss_lpips_real =  0.30977702140808105
	loss_lpips_crop =  0.22773930430412292
	loss_l2_crop =  0.07081114500761032
	loss_w_norm_real =  13.020597457885742
	loss_aging_real =  0.006176636088639498
	loss_real =  0.38440173864364624
	loss_id_cycle =  0.670578122138977
	id_improve_cycle =  -0.6313480387131373
	loss_l2_cycle =  0.10392758250236511
	loss_lpips_cycle =  0.334219753742218
	loss_w_norm_cycle =  12.625384330749512
	loss_aging_cycle =  0.003899508621543646
	loss_cycle =  0.41654089093208313
	loss =  0.8009426295757294
Metrics for train, step 29250
	loss_id_real =  0.0
	id_improve_real =  -0.602919747432073
	loss_l2_real =  0.08749698847532272
	loss_lpips_real =  0.345823734998703
	loss_lpips_crop =  0.24862270057201385
	loss_l2_crop =  0.07119642198085785
	loss_w_norm_real =  12.440343856811523
	loss_aging_real =  0.11909914016723633
	loss_real =  0.9098708033561707
	loss_id_cycle =  0.6522703170776367
	id_improve_cycle =  -0.6157533973455429
	loss_l2_cycle =  0.10205188393592834
	loss_lpips_cycle =  0.36336588859558105
	loss_w_norm_cycle =  11.9584379196167
	loss_aging_cycle =  0.006273109465837479
	loss_cycle =  0.4386044144630432
	loss =  1.3484752178192139
Metrics for train, step 29500
	loss_id_real =  0.0
	id_improve_real =  -0.5774414390325546
	loss_l2_real =  0.10120463371276855
	loss_lpips_real =  0.3618144690990448
	loss_lpips_crop =  0.2744097113609314
	loss_l2_crop =  0.08724726736545563
	loss_w_norm_real =  12.781248092651367
	loss_aging_real =  0.10382138937711716
	loss_real =  0.8643555045127869
	loss_id_cycle =  0.6659125685691833
	id_improve_cycle =  -0.6285987198352814
	loss_l2_cycle =  0.1224149614572525
	loss_lpips_cycle =  0.3821250796318054
	loss_w_norm_cycle =  12.295019149780273
	loss_aging_cycle =  0.004325758665800095
	loss_cycle =  0.4704045057296753
	loss =  1.3347600102424622
Metrics for train, step 29750
	loss_id_real =  0.0
	id_improve_real =  -0.6456222732861837
	loss_l2_real =  0.08446688950061798
	loss_lpips_real =  0.34283968806266785
	loss_lpips_crop =  0.21868062019348145
	loss_l2_crop =  0.06268683075904846
	loss_w_norm_real =  11.760087966918945
	loss_aging_real =  0.11295045912265778
	loss_real =  0.8554226756095886
	loss_id_cycle =  0.7188147902488708
	id_improve_cycle =  -0.6828048142294089
	loss_l2_cycle =  0.09350767731666565
	loss_lpips_cycle =  0.3628307282924652
	loss_w_norm_cycle =  11.585875511169434
	loss_aging_cycle =  0.003500542603433132
	loss_cycle =  0.4008687734603882
	loss =  1.2562914490699768

The training options are set as follows:

{
    "aging_lambda": 5,
    "batch_size": 6,
    "board_interval": 250,
    "cycle_lambda": 1,
    "dataset_type": "ffhq_aging",
    "exp_dir": "code_source/SAM_P/exp_dir",
    "id_lambda": 0.1,
    "image_interval": 1000,
    "input_nc": 4,
    "l2_lambda": 0.25,
    "l2_lambda_aging": 0.25,
    "l2_lambda_crop": 1,
    "label_nc": 0,
    "latent_avg": "code_source/SAM_P/pretrained_models/latent_avg.pdparams",
    "learning_rate": 0.0001,
    "lpips_lambda": 0.1,
    "lpips_lambda_aging": 0.1,
    "lpips_lambda_crop": 0.6,
    "max_steps": 500000,
    "optim_name": "adam",
    "output_size": 1024,
    "pretrained_psp_path": "code_source/SAM_P/pretrained_models/psp_ffhq_encoder.pdparams",
    "psp_encoder": "code_source/SAM_P/pretrained_models/psp_encoder.pdparams",
    "psp_ffhq_encoder": "code_source/SAM_P/pretrained_models/psp_ffhq_encoder.pdparams",
    "sam_decoder": "code_source/SAM_P/pretrained_models/sam_decoder.pdparams",
    "sam_encoder": "code_source/SAM_P/pretrained_models/sam_encoder.pdparams",
    "sam_psp3_encoder": "code_source/SAM_P/pretrained_models/sam_psp3_encoder.pdparams",
    "save_interval": 5000,
    "start_from_encoded_w_plus": true,
    "start_from_latent_avg": true,
    "stylegan_decoder": "code_source/SAM_P/pretrained_models/stylegan2.pdparams",
    "target_age": "uniform_random",
    "test_batch_size": 4,
    "test_workers": 2,
    "train_decoder": false,
    "train_from_sam_ckpt": true,
    "use_weighted_id_loss": true,
    "val_interval": 2500,
    "w_norm_lambda": 0.005,
    "workers": 2
}
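Because this run depends on several converted checkpoints, it can help to check the options before launching. A minimal sketch, assuming the options above are saved as a JSON file; missing_checkpoints and load_options are hypothetical helpers that list every .pdparams path in the options that does not exist on disk.

```python
import json
import os


def load_options(path):
    """Read the training options from a JSON file."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def missing_checkpoints(opts, suffix=".pdparams"):
    """Return (option_name, path) pairs whose checkpoint file does not exist."""
    missing = []
    for key, value in sorted(opts.items()):
        if isinstance(value, str) and value.endswith(suffix) and not os.path.isfile(value):
            missing.append((key, value))
    return missing
```

For example, missing_checkpoints(load_options("opt.json")) (file name hypothetical) would flag a mistyped sam_encoder or latent_avg path before any training time is spent.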

Sorry for the confusion. Based on the results above, do you think I need to train longer? Looking forward to your reply. 😄

yuval-alaluf commented on June 15, 2024

It looks like your results are improving, but you should continue training. As I mentioned, we trained for about 60,000 steps, and if I understood correctly you're only at 30,000.
One thing I noticed is that your ID similarities are quite low. The results you posted using the pre-trained SAM model reach similarities of 0.6-0.7, while yours are in the 0.2 range.
Since you changed frameworks, it is hard for me to pinpoint the problem, but I would verify that the pSp encoder was transferred correctly and is used correctly during training.
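A quick way to check the transfer is to run the same preprocessed image through both the original PyTorch pSp encoder and the Paddle port (both in eval mode) and compare the resulting latent codes. A minimal sketch, assuming both outputs have been exported as nested Python lists; compare_outputs and the 1e-4 tolerance are illustrative choices, not values from this thread.

```python
import math


def flatten(x):
    """Flatten arbitrarily nested lists/tuples of numbers into a flat list of floats."""
    if isinstance(x, (list, tuple)):
        out = []
        for item in x:
            out.extend(flatten(item))
        return out
    return [float(x)]


def compare_outputs(ref, test, atol=1e-4):
    """Compare two framework outputs for the same input.

    Returns (ok, max_abs_diff); raises if the element counts differ.
    Any NaN in the difference counts as a failure.
    """
    a, b = flatten(ref), flatten(test)
    if len(a) != len(b):
        raise ValueError(f"size mismatch: {len(a)} vs {len(b)}")
    diffs = [abs(x - y) for x, y in zip(a, b)]
    if any(math.isnan(d) for d in diffs):
        return False, float("nan")
    max_diff = max(diffs, default=0.0)
    return max_diff <= atol, max_diff


torch_w = [[0.10, -0.20], [0.30, 0.40]]  # e.g. latent from the PyTorch encoder
paddle_w = [[0.10, -0.20], [0.30, 0.40005]]  # same input through the Paddle port
print(compare_outputs(torch_w, paddle_w))
```

Running this layer by layer (comparing intermediate outputs, not just the final latent) narrows down which converted module diverges first.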

S-HuaBomb commented on June 15, 2024

I will upload the Paddle version of the code to GitHub in the next couple of days. The ID similarities are now rising, to about 0.4.
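For reference, in pSp-style ID losses the ID similarity is the cosine similarity between face-recognition embeddings of the input and the generated image, and the unweighted ID loss is one minus that similarity, so a similarity of about 0.4 corresponds to an unweighted ID loss of about 0.6. A minimal sketch with plain lists, assuming the embeddings have already been extracted by the face-recognition network; the function names are illustrative, and SAM's weighted ID loss additionally scales this per sample.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    if norm == 0:
        raise ValueError("zero-length embedding")
    return [x / norm for x in v]


def id_similarity(emb_a, emb_b):
    """Cosine similarity between two face embeddings (1.0 = same direction)."""
    a, b = l2_normalize(emb_a), l2_normalize(emb_b)
    return sum(x * y for x, y in zip(a, b))


def id_loss(emb_input, emb_output):
    """Unweighted ID loss: 1 - similarity, so lower is better."""
    return 1.0 - id_similarity(emb_input, emb_output)


print(round(id_similarity([1.0, 0.0], [1.0, 1.0]), 4))  # 0.7071
```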

But now there is a problem: no matter what the target age is during training, the age of the generated face stays close to the input age rather than the target age (in fact, most of the generated faces are young people). Is the aging loss given lower priority during optimization?

[Image: generated faces staying close to the input age]
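One way to look at the "priority" question is to compare each term's weighted contribution to the total rather than its raw value. A minimal sketch, assuming the total is a plain weighted sum of the logged terms using the lambdas from the options above; SAM's actual loss (for example its weighted ID loss) is computed differently, so this will not reproduce the logged loss_real exactly. The function name and the term-to-lambda mapping are illustrative.

```python
def weighted_contributions(losses, lambdas):
    """Return each term's weighted value and its share of the weighted total.

    losses:  raw logged values, e.g. {"loss_aging_real": 0.11, ...}
    lambdas: weight per term, e.g. {"loss_aging_real": 5, ...}
    Terms without a lambda are ignored.
    """
    weighted = {k: lambdas[k] * v for k, v in losses.items() if k in lambdas}
    total = sum(weighted.values())
    shares = {k: (w / total if total else 0.0) for k, w in weighted.items()}
    return weighted, shares


# Raw values from the step 29750 log above (rounded); ID term omitted because
# SAM weights it per sample.
losses = {
    "loss_l2_real": 0.0845,
    "loss_lpips_real": 0.3428,
    "loss_lpips_crop": 0.2187,
    "loss_l2_crop": 0.0627,
    "loss_w_norm_real": 11.7601,
    "loss_aging_real": 0.1130,
}
# Lambdas taken from the training options (assumed one-to-one mapping).
lambdas = {
    "loss_l2_real": 0.25,
    "loss_lpips_real": 0.1,
    "loss_lpips_crop": 0.6,
    "loss_l2_crop": 1,
    "loss_w_norm_real": 0.005,
    "loss_aging_real": 5,
}
weighted, shares = weighted_contributions(losses, lambdas)
for name in sorted(shares, key=shares.get, reverse=True):
    print(f"{name:18s} weighted={weighted[name]:.4f} share={shares[name]:.1%}")
```

With these numbers the aging term is the largest weighted contribution, so if generated ages still do not move, the cause may be that the aging loss's gradient is not reaching the encoder rather than that its weight is too low.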

S-HuaBomb commented on June 15, 2024

Thank you very much for your help. We have found the cause of the missing age transformation:

The VGG network defined in dex_vgg.py uses a dict to store the output of each layer during the forward pass. In the Paddle version, this prevented the loss from being back-propagated, so I only needed to replace out[] with x. Now it works!
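The change described above can be illustrated with a toy forward pass. This is a hypothetical, framework-free sketch: it only shows the shape of the refactor (storing every intermediate in a dict versus threading a single x through the layers) and does not reproduce Paddle's gradient behavior, which the thread does not explain in detail. The layer names and stub layers are made up, not taken from dex_vgg.py.

```python
from typing import Callable, Dict, List, Tuple

Layer = Callable[[float], float]


def forward_with_dict(layers: List[Tuple[str, Layer]], x: float) -> float:
    """Original pattern: store each layer's output in a dict keyed by layer name."""
    out: Dict[str, float] = {}
    prev = x
    for name, layer in layers:
        out[name] = layer(prev)
        prev = out[name]
    return prev


def forward_threaded(layers: List[Tuple[str, Layer]], x: float) -> float:
    """Refactored pattern: reassign x after every layer, no intermediate dict."""
    for _name, layer in layers:
        x = layer(x)
    return x


def conv_stub(v: float) -> float:
    return 2.0 * v + 1.0  # stand-in for a conv layer


def relu(v: float) -> float:
    return max(v, 0.0)


def head_stub(v: float) -> float:
    return 0.5 * v  # stand-in for the final classifier layer


layers = [("conv1_1", conv_stub), ("relu1_1", relu), ("fc", head_stub)]
print(forward_with_dict(layers, 3.0), forward_threaded(layers, 3.0))  # 3.5 3.5
```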

Here is the repository for the Paddle version: paddle-SAM.

We were surprised to find that SAM reproduced in Paddle achieves almost the same results as your source code! 😏 👍 👍 👍

yuval-alaluf commented on June 15, 2024

Really cool to see that you were able to reproduce the results using a different framework. Awesome work!
