Giter VIP home page Giter VIP logo

spade_colorization's People

Contributors

deepkyu avatar donghwanjang avatar mingyuliutw avatar taesungp avatar tcwang0509 avatar thisisisaac avatar

Stargazers

 avatar  avatar

Watchers

 avatar  avatar

spade_colorization's Issues

malformed confidence map

1. 바둑판 conf map

epoch077_iter077_conf_map
epoch035_iter035_conf_map

Confmap 이 바둑판 처럼 나오는거


2. warped image 도 바둑판.

epoch009_iter009_warped_img

correlation matrix 연산 혹은 적용하는 부분에서 문제가 생겼을 수 있음

Attention map overlap 방식 재고

현재 grad_cam에서 사용되는 방식처럼 image위에 attention을 overlay할 때,

denormalize(image) # 0~1 로 denormalize
x=image+attention
x=x/max(x)

normalize(x)

과정으로 수행해주는데 attention_map output 값이 깨지는 것처럼 나옴
scaling의 문제인지, LAB 이미지를 RGB로 저장해서 생기는 문제인지 확인 필요

Decoder에게 refine을 맡기는것이 불가능하다

우리는 여태까지 Decoder가 refine을 알아서 잘 해줄것이라 생각했는데, Decoder는 refine을 하는게 불가능하다!
왜냐?! Decoder는 L을 모른다. warped 된 gamma/beta 값이 맞는지 아닌지 판단할 L이 없다! 못하는게 당연하다.
(한글 너무 렉걸려서 영어로 쓸게요) 알고보니 grammarly를 켜놔서 그런거였네요... 이미 쓴건 그냥 두겠습니다ㅋㅋㅜ

Even in the previous examples(Exemplar-based colorizations), all of them used target L images as an input for the refinement network!

image
image

VGG Input 및 Condition 수정 필요

2가지 부분에 대해 이야기해보려고 합니다.

  1. reference & target input에 대한 condition 혼용
    if is_ref:
    x = x[:, 0, :, :].unsqueeze(1).expand(-1, 3, -1, -1)
    vgg_feature = self.vgg_ref(x, corr_feature=True)
    else:
    if self.opt.ref_type == 'l' and x.size()[1] == 1:
    x = x.expand(-1, 3, -1, -1)
    vgg_feature = self.vgg_tar(x, corr_feature=True)

현재 master에서는 VGGFeatureExtractor 내에 condition이 위와 같이 잡혀 있습니다. is_ref=True (reference image)일 때, L map을 떼서 channel-wise tiling을 하여 vgg에 넣고, is_ref=False(target image)일 때, 그대로 vgg에 넣습니다. 각 condition에 사용되는 vgg는 알맞으나, 직전 condition이 수정되어야 할 것 같습니다.

if is_ref: 
    #if self.opt.ref_type == 'l' and x.size()[1] == 1: 
    #    x = x.expand(-1, 3, -1, -1) 
    vgg_feature = self.vgg_ref(x, corr_feature=True) 
else: 
    x = x[:, 0, :, :].unsqueeze(1).repeat(-1, 3, -1, -1)
    _x의 1, 2 channel을 0으로 채우기_
    vgg_feature = self.vgg_tar(x, corr_feature=True) 
  1. fake_image, attention, conf_map = self.netG(target_L, reference_LAB, z=z)
    assert (not compute_kld_loss) or self.opt.use_vae, \
    "You cannot compute KLD loss if opt.use_vae == False"

    netG에서 input되는 target_L과 reference_LAB가 이후에 vgg에 들어가기에 앞서, RGB format으로 변환되는 부분이 없습니다.
    def forward(self, data, mode):
    reference_LAB = data["reference_LAB"]
    target_LAB = data["target_LAB"]
    target_RGB = data["target_image"]
    target_L, target_AB = self.parse_LAB(target_LAB)
    if mode == 'generator':
    g_loss, generated, attention, conf_map, fid = self.compute_generator_loss(
    target_L, target_LAB, target_RGB, reference_LAB,
    is_reconstructing=data["is_reconstructing"], get_fid=data["get_fid"])

    현재 target_L과 reference_LAB는 LAB format으로 들어가는 것 같은데 확인 부탁드립니다.

VGG19 feature extractor 바꿔보기

Input 바꿔보기:

  • 현재: LAB pretrained VGG19 에 reference, target 이 LLL 들어감
  • LAB pretrained VGG19 에 reference 는 LAB, target 은 L00

Gradient update 하게 해주기:

  • 현재: frozen
  • unfreeze 해서 gradient 가 흐르게 하기. 이때 checkpoint 에 save / load 할때도 같이 읽어오게 하기

correspondence subnet 을 독립적으로 트레이닝

현재 판면된 바로는 SPADE + decoder 는 그나마 잘 작동을 하지만 atention 을 생성하는 non-local block 은 개똥망이다. 그리고 전혀 학습이 안되는 것 같다. 그 문제를 해결하기 위해 warped 에 reconstruction loss 를 직접적으로 주는걸 해보자 한다.

방법:

  1. target 과 ref 를 같은 사진을 사용한다. 이유는 우리가 groundtruth attention 을 알기 때문이다. 똑같은 이미지라면 모든 attention 이 같은 포지션의 픽셀을 복붙해야한다. 이 이점을 활용.
  2. target 과 ref 에 다양한 augmentation 을 진행한다. target 에 있는 모든 픽셀이 ref 의 subset 이면 문제 없을듯.

Augmentation 방법:

  1. flip: vertical & horizontal flip.
  2. crop
  3. scaling: target 의 일부분에만 줌 인.
  4. L 에 brightness 를 높게 & 낮게
  5. target L 에 gaussian noise 추가
  6. rotation L
  7. translation

Visualization 문제

  1. tensor2im 부분에서 denormalize 하는 코드가 그냥 1더해주고 2로 나눠주는 식으로 돼있는데, 이걸 우리 쓰는 mean/std로 바꾸고, 여기에 L 이미지도 들어오고 해서 분기처리를 디버깅 해봐야할듯

  2. attention map batch 단위로 볼 수 있도록 고도화

  3. ref_LAB / ref_image(RGB) save 방식 분기화

LAB tensor range problem

LAB tensor가 0255 -> 01 로 매핑되는데, 이 때
AB를 0~1로 scale한 값을 실제로 찍어주면 아래 그림에서 새그림 부분처럼 된다.
image

즉 0.5 지점에서 discontinuity가 발생한다.
이러한 점은 loss 계산시에 큰 문제가 될 수 있다.

  1. 0과 1의 색 값은 실제로는 거의 동등하나, 수치적으로는 큰 차이가 나기 때문에, gradient 업데이트시 불안정 할 수 있다.
  2. 마찬가지로 0.5 경계에서의 값은 거의 동일하지만 인간의 눈으로 봤을 때는 정말 큰 차이가 나기 때문에, 네트워크는 로스가 충분히 작기 때문에 괜찮다고 내버려 둬도, 실제 사람이 봤을 때는 굉장히 이상한 결과물이 나올 수 있다.

이에 따라 새로운 mapping function이 필요할 것으로 보인다.

Loss lambda 찾기

video colorization 논문에선:
kind:weight (current value) -> weighted value

  • feat: 없음 (10~15) -> ?
  • perceptual: 0.001 (0.5) -> 0.0005
  • contextual: 0.2 (3~2) -> 0.5
  • smoothness: 5 (?) -> ?
  • adversarial: 0.2 (0.5~0.4) -> 0.1
  • L1 (reconstruction): 2 (0.004 ~ 0.006) -> 0.016

Batch 1 이상이면 터짐

batchSize 8 로 했을때 메세지

  File "/home/minds/isaac/SPADE_Colorization/trainers/pix2pix_trainer.py", line 44, in run_generator_one_step
    g_losses, generated, attention, conf_map = self.pix2pix_model(data, mode='generator')
  File "/home/minds/.virtualenvs/spade/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/minds/.virtualenvs/spade/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/minds/.virtualenvs/spade/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/minds/isaac/SPADE_Colorization/models/pix2pix_model.py", line 81, in forward
    target_L, target_LAB, reference_LAB, is_reconstructing=data["is_reconstructing"])
  File "/home/minds/isaac/SPADE_Colorization/models/pix2pix_model.p
wandb: Waiting for W&B process to finish, PID 31076
y", line 176, in compute_generator_loss
    fake_rgb_np = lab_deloader(fake_LAB.detach().cpu().float().numpy().squeeze(0).transpose(1, 2, 0),
ValueError: cannot select an axis to squeeze out which has size not equal to one

@DongHwanJang 형 이거 혹시 봐줄 수 있으려나?

난 지금 디버깅용 정보 뽑는거 해보는중이야

D에 들어가는 input을 lab에서 rgb로만 바꿨는데도 VGG loss가 이상하다....??

bash python train.py --use_reconstruction_loss --gpu_ids 0 --use_contextual_loss --save_epoch_freq 100 --niter 100 --niter_decay 100 --pair_file /DATA1/hksong/imagenet/pairs/single_class_bass.txt --name d_rgb_lambda_one --batchSize 10 --use_wandb --tf_log --lambda_feat 1 --lambda_vgg 1 --lambda_smooth 1 --lambda_recon 1 --lambda_context 1 --lambda_kld 1
위와 같이 lambda를 전부 1로 맞춰서 기존과 동일하게 하더라도 VGG loss의 범위가 바뀐다.... 왜지? 뭔가 잘못 흘러들어가는게 있는 듯 하다. 확인 필요

Discriminator 고도화

image
둘다 0을 찍을 수 있는 loss인데 (찍어야 하고)
D_real 조차도 수렴을 너무 못한다. 왜..?

Mid-sized dataset 구축

singe image 와 entire dataset 사이에 중간 사이즈에 dataset 을 만드는게 좋을 것 같음. entire dataset 은 제대로 트레인하는데 너무 오래걸림.

confmap gradient 유무 확인

  1. conf_map을 refinement network에 넣을 때도 gradient를 흘리게 했는지. gradient가 흐른다면 attention_map에 괜히 악영향을 주는 것 아닌가...? refinement network로부터 받는 gradient는 warping에 관한 것이 아니라, refinement network에서 결과가 더 잘나오게 해주는 gradient 일테니까

Warped image useless

Screen Shot 2020-01-31 at 18 13 43 PM

1 class (약 1200 장)에 대해 42000번 iteration 후. reconstruction 인 케이스이고 색은 잘 나오는데 warped 된 이미지가 여전히 안나오는걸 볼 수 있다. 우리가 색으로 warping 시켜서 디버깅할때 보는게 맞는지도 모르겠다. 가장 크리티컬하게 gamma 값도 0.1 이하다. warping 을 거의 안쓰고 이미지를 잘 뽑아낸다. @DongHwanJang 어떻게 생각해?

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.