donghwanjang / spade_colorization Goto Github PK
View Code? Open in Web Editor NEWLicense: Other
License: Other
각 이미지 및 attention_map
4748768 에 적용된 lambda 를 backward 직전에 계산해주는 변화 후 RGB D 에서 학습이 엄청 안됬음. 아마 gradient 곱셈 자체가 잘못 됬을거라고 @DongHwanJang 이 추정함.
RGB grayscale 에 만 트레인된 VGG 하나를 사용.
현재 grad_cam에서 사용되는 방식처럼 image위에 attention을 overlay할 때,
denormalize(image) # 0~1 로 denormalize
x=image+attention
x=x/max(x)normalize(x)
과정으로 수행해주는데 attention_map output 값이 깨지는 것처럼 나옴
scaling의 문제인지, LAB 이미지를 RGB로 저장해서 생기는 문제인지 확인 필요
그 직후 self.conv_concate(x) 한 모습
SPADE_Colorization/models/networks/architecture.py
Lines 317 to 318 in 8ad7a5d
학습할 수록 극단적인 형태로 진행된다.
self.conv_concate은 kernel_size=1 인데 왜 격자무늬가 나타나는 거지...?? 어떻게 나타나는 거지?
우리는 여태까지 Decoder가 refine을 알아서 잘 해줄것이라 생각했는데, Decoder는 refine을 하는게 불가능하다!
왜냐?! Decoder는 L을 모른다. warped 된 gamma/beta 값이 맞는지 아닌지 판단할 L이 없다! 못하는게 당연하다.
(한글 너무 렉걸려서 영어로 쓸게요) 알고보니 grammarly를 켜놔서 그런거였네요... 이미 쓴건 그냥 두겠습니다ㅋㅋㅜ
Even in the previous examples(Exemplar-based colorizations), all of them used target L images as an input for the refinement network!
2가지 부분에 대해 이야기해보려고 합니다.
reference & target input
에 대한 condition 혼용SPADE_Colorization/models/networks/architecture.py
Lines 335 to 341 in fa09727
현재 master에서는 VGGFeatureExtractor 내에 condition이 위와 같이 잡혀 있습니다. is_ref=True
(reference image)일 때, L map을 떼서 channel-wise tiling을 하여 vgg에 넣고, is_ref=False
(target image)일 때, 그대로 vgg에 넣습니다. 각 condition에 사용되는 vgg는 알맞으나, 직전 condition이 수정되어야 할 것 같습니다.
if is_ref:
#if self.opt.ref_type == 'l' and x.size()[1] == 1:
# x = x.expand(-1, 3, -1, -1)
vgg_feature = self.vgg_ref(x, corr_feature=True)
else:
x = x[:, 0, :, :].unsqueeze(1).repeat(-1, 3, -1, -1)
_x의 1, 2 channel을 0으로 채우기_
vgg_feature = self.vgg_tar(x, corr_feature=True)
SPADE_Colorization/models/pix2pix_model.py
Lines 266 to 268 in fa09727
SPADE_Colorization/models/pix2pix_model.py
Lines 80 to 89 in fa09727
Input 바꿔보기:
Gradient update 하게 해주기:
현재 판면된 바로는 SPADE + decoder 는 그나마 잘 작동을 하지만 atention 을 생성하는 non-local block 은 개똥망이다. 그리고 전혀 학습이 안되는 것 같다. 그 문제를 해결하기 위해 warped 에 reconstruction loss 를 직접적으로 주는걸 해보자 한다.
방법:
Augmentation 방법:
tensor2im 부분에서 denormalize 하는 코드가 그냥 1더해주고 2로 나눠주는 식으로 돼있는데, 이걸 우리 쓰는 mean/std로 바꾸고, 여기에 L 이미지도 들어오고 해서 분기처리를 디버깅 해봐야할듯
attention map batch 단위로 볼 수 있도록 고도화
ref_LAB / ref_image(RGB) save 방식 분기화
디버깅 목적으로 weights, bias, gamma, beta를 tensorfoard에 Histogram 형식으로 띄우게 하기.
requires_grad 로 인해서 weight가 과연 다른게 업데이트 되는가?
LAB tensor가 0255 -> 01 로 매핑되는데, 이 때
AB를 0~1로 scale한 값을 실제로 찍어주면 아래 그림에서 새그림 부분처럼 된다.
즉 0.5 지점에서 discontinuity가 발생한다.
이러한 점은 loss 계산시에 큰 문제가 될 수 있다.
이에 따라 새로운 mapping function이 필요할 것으로 보인다.
현재 perceptual loss 에 RGB pretrained VGG 사용하지만 인풋을 줄때 RGB 로 바꾸지 않음. 여기서 @DongHwanJang 이 새로 작성한 LAB -> RGB 함수를 사용해서 synthesized LAB 를 RGB 로 바꾸고 perceptual loss 태워주기
해당 attention이 어느 포인트의 것을 보고 있는 것인지 판별하기 어렵다
video colorization 논문에선:
kind
:weight (current value
) -> weighted value
feat
: 없음 (10~15
) -> ?
perceptual
: 0.001 (0.5
) -> 0.0005
contextual
: 0.2 (3~2
) -> 0.5
smoothness
: 5 (?
) -> ?
adversarial
: 0.2 (0.5~0.4
) -> 0.1
L1 (reconstruction)
: 2 (0.004 ~ 0.006
) -> 0.016
batchSize 8 로 했을때 메세지
File "/home/minds/isaac/SPADE_Colorization/trainers/pix2pix_trainer.py", line 44, in run_generator_one_step
g_losses, generated, attention, conf_map = self.pix2pix_model(data, mode='generator')
File "/home/minds/.virtualenvs/spade/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
result = self.forward(*input, **kwargs)
File "/home/minds/.virtualenvs/spade/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
return self.module(*inputs[0], **kwargs[0])
File "/home/minds/.virtualenvs/spade/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
result = self.forward(*input, **kwargs)
File "/home/minds/isaac/SPADE_Colorization/models/pix2pix_model.py", line 81, in forward
target_L, target_LAB, reference_LAB, is_reconstructing=data["is_reconstructing"])
File "/home/minds/isaac/SPADE_Colorization/models/pix2pix_model.p
wandb: Waiting for W&B process to finish, PID 31076
y", line 176, in compute_generator_loss
fake_rgb_np = lab_deloader(fake_LAB.detach().cpu().float().numpy().squeeze(0).transpose(1, 2, 0),
ValueError: cannot select an axis to squeeze out which has size not equal to one
@DongHwanJang 형 이거 혹시 봐줄 수 있으려나?
난 지금 디버깅용 정보 뽑는거 해보는중이야
bash python train.py --use_reconstruction_loss --gpu_ids 0 --use_contextual_loss --save_epoch_freq 100 --niter 100 --niter_decay 100 --pair_file /DATA1/hksong/imagenet/pairs/single_class_bass.txt --name d_rgb_lambda_one --batchSize 10 --use_wandb --tf_log --lambda_feat 1 --lambda_vgg 1 --lambda_smooth 1 --lambda_recon 1 --lambda_context 1 --lambda_kld 1
위와 같이 lambda를 전부 1로 맞춰서 기존과 동일하게 하더라도 VGG loss의 범위가 바뀐다.... 왜지? 뭔가 잘못 흘러들어가는게 있는 듯 하다. 확인 필요
singe image 와 entire dataset 사이에 중간 사이즈에 dataset 을 만드는게 좋을 것 같음. entire dataset 은 제대로 트레인하는데 너무 오래걸림.
Why does attention maps have checker board pattern..?
1 class (약 1200 장)에 대해 42000번 iteration 후. reconstruction 인 케이스이고 색은 잘 나오는데 warped 된 이미지가 여전히 안나오는걸 볼 수 있다. 우리가 색으로 warping 시켜서 디버깅할때 보는게 맞는지도 모르겠다. 가장 크리티컬하게 gamma 값도 0.1 이하다. warping 을 거의 안쓰고 이미지를 잘 뽑아낸다. @DongHwanJang 어떻게 생각해?
model save/load 가능하게 바꾸자
A declarative, efficient, and flexible JavaScript library for building user interfaces.
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
An Open Source Machine Learning Framework for Everyone
The Web framework for perfectionists with deadlines.
A PHP framework for web artisans
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
Some thing interesting about web. New door for the world.
A server is a program made to process requests and deliver data to clients.
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
Some thing interesting about visualization, use data art
Some thing interesting about game, make everyone happy.
We are working to build community through open source technology. NB: members must have two-factor auth.
Open source projects and samples from Microsoft.
Google ❤️ Open Source for everyone.
Alibaba Open Source for everyone
Data-Driven Documents codes.
China tencent open source team.