
pytorch_adain's Introduction

Pytorch_Adain_from_scratch

Unofficial PyTorch implementation of "Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization" [Huang+, ICCV 2017].

The author's original Torch implementation can be found here.

Other implementations such as Pytorch_implementation_using_pretrained_torch_model or Chainer_implementation are also available. I have learned a lot from them, and in this repository I attempt a pure PyTorch implementation from scratch. A pre-trained model is provided so you can generate your own image given a content image and a style image. You can also download the training dataset, or prepare your own dataset, to train the model from scratch.
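
For reference, the core AdaIN operation from the paper aligns the channel-wise mean and standard deviation of the content features with those of the style features. Below is a minimal sketch (the function name adain matches the one used in this repository's forward pass, but the implementation here may differ in details such as the epsilon):

    import torch

    def adain(content_features, style_features, eps=1e-5):
        # channel-wise statistics over the spatial dimensions of (N, C, H, W) features
        c_mean = content_features.mean(dim=(2, 3), keepdim=True)
        c_std = content_features.std(dim=(2, 3), keepdim=True) + eps
        s_mean = style_features.mean(dim=(2, 3), keepdim=True)
        s_std = style_features.std(dim=(2, 3), keepdim=True) + eps
        # normalize the content features, then rescale/shift with the style statistics
        return s_std * (content_features - c_mean) / c_std + s_mean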

I have also written a brief Qiita blog post; you can check it here.

If you have any questions, please feel free to contact me. (English/Japanese/Chinese are all OK!)

Notice

I have also proposed a structure-emphasized multimodal style transfer (SEMST); feel free to use it here.


Requirements

  • Python 3.7
  • PyTorch 1.0+
  • TorchVision
  • Pillow
  • scikit-image (skimage)
  • tqdm

An Anaconda environment is recommended!

(optional)

  • GPU environment for training
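
A minimal environment setup might look like this (package names are illustrative; the repository does not pin exact versions):

    conda create -n adain python=3.7
    conda activate adain
    pip install torch torchvision pillow scikit-image tqdm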

Usage


test

  1. Clone this repository

    git clone https://github.com/irasin/Pytorch_Adain_from_scratch
    cd Pytorch_Adain_from_scratch
  2. Prepare your content image and style image. Some samples are provided in the content and style directories, so you can try them easily.

  3. Download the pretrained model here

  4. Generate the output image. A transferred output image, a content_output pair image, and an NST_demo_like image will be generated.

    python test.py -c content_image_path -s style_image_path
    usage: test.py [-h] 
                   [--content CONTENT] 
                   [--style STYLE]
                   [--output_name OUTPUT_NAME] 
                   [--alpha ALPHA] 
                   [--gpu GPU]
                   [--model_state_path MODEL_STATE_PATH]
    
    
    

    If output_name is not given, the combination of the content image name and the style image name will be used.
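
    The alpha flag controls the degree of stylization. In the paper this is a feature-level interpolation between the content features and the AdaIN output; a sketch of the idea, assuming this repository follows the same convention (adain and decoder stand for the feature transform and the trained decoder network):

        # alpha = 0.0 reproduces the content image, alpha = 1.0 gives full stylization
        t = adain(content_features, style_features)
        t = alpha * t + (1.0 - alpha) * content_features
        out = decoder(t)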


train

  1. Download COCO (as the content dataset) and WikiArt (as the style dataset), unzip them, and rename the directories as content and style respectively (recommended).

  2. Modify the arguments in train.py, such as the dataset directories, epoch, and learning_rate, or add your own training code.

  3. Train the model using a GPU.

  4. python train.py
    usage: train.py [-h] 
                    [--batch_size BATCH_SIZE] 
                    [--epoch EPOCH]
                    [--gpu GPU]
                    [--learning_rate LEARNING_RATE]
                    [--snapshot_interval SNAPSHOT_INTERVAL]
                    [--train_content_dir TRAIN_CONTENT_DIR]
                    [--train_style_dir TRAIN_STYLE_DIR]
                    [--test_content_dir TEST_CONTENT_DIR]
                    [--test_style_dir TEST_STYLE_DIR] 
                    [--save_dir SAVE_DIR]
                    [--reuse REUSE]
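
    For example, a typical invocation might look like this (the argument values below are illustrative, not defaults confirmed by the repository):

        python train.py --batch_size 8 --epoch 20 --gpu 0 --train_content_dir content --train_style_dir style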
    

Result

Some results are shown below.

[Result images: stylized outputs for a variety of content/style pairs]

References

  • X. Huang and S. Belongie, "Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization," in ICCV, 2017.

pytorch_adain's Issues

test

Hello, thanks for your code. Could you share your trained model so that I can test my own content images? Thank you!

Do you know the effect of scales of inputs?

Hi!

I am really interested in this concept, and I think you are a real expert 🕶️.

As you can see, there are some small textures in the dog, but I think bigger textures would look better (look at the style image of the cat).
(attached: WechatIMG25_oilpainting_style_transfer_demo)

Have you experimented with how the scale of the inputs affects the final generated images?
(attached image)
Like this ⬆️

Thank you so much! I am trying this as well, but it may take longer.

There is an error when slicing the model for Encoder

There is an error when the model is being sliced for the encoder.

    /tmp/ipykernel_10779/972606360.py in fit(epochs, optimizer, model_state_dir, iters, learning_rate_decay)
          7     for (i, (content, style)) in progress_bar(enumerate(zip(content_dl, style_dl), 1)):
          8         #style = next(style_iter)
    ----> 9         loss = model(content, style)
         10         loss_list.append(loss.item())
         11         #adjust_learning_rate(optim, iters, learning_rate_decay)

    ~/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
       1192     if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1193             or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1194         return forward_call(*input, **kwargs)
       1195     # Do not call functions when jit is used
       1196     full_backward_hooks, non_full_backward_hooks = [], []

    /tmp/ipykernel_10779/887166169.py in forward(self, content_images, style_images, alpha, lam)
         27
         28     def forward(self, content_images, style_images, alpha=1.0, lam=10):
    ---> 29         content_features = self.vgg_encoder(content_images, output_last_feature=True)
         30         style_features = self.vgg_encoder(style_images, output_last_feature=True)
         31         t = adain(content_features, style_features)

    ~/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
       1192     if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1193             or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1194         return forward_call(*input, **kwargs)
       1195     # Do not call functions when jit is used
       1196     full_backward_hooks, non_full_backward_hooks = [], []

    /tmp/ipykernel_10779/2727537595.py in forward(self, images, output_last_feature)
         11
         12     def forward(self, images, output_last_feature=False):
    ---> 13         h1 = self.slice1(images)
         14         h2 = self.slice2(h1)
         15         h3 = self.slice3(h2)

    ~/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
       1192     if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1193             or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1194         return forward_call(*input, **kwargs)
       1195     # Do not call functions when jit is used
       1196     full_backward_hooks, non_full_backward_hooks = [], []

    ~/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/torch/nn/modules/container.py in forward(self, input)
        202     def forward(self, input):
        203         for module in self:
    --> 204             input = module(input)
        205         return input
        206

    ~/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
       1192     if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1193             or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1194         return forward_call(*input, **kwargs)
       1195     # Do not call functions when jit is used
       1196     full_backward_hooks, non_full_backward_hooks = [], []

    ~/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/torch/nn/modules/conv.py in forward(self, input)
        461
        462     def forward(self, input: Tensor) -> Tensor:
    --> 463         return self._conv_forward(input, self.weight, self.bias)
        464
        465 class Conv3d(_ConvNd):

    ~/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
        457                             weight, bias, self.stride,
        458                             _pair(0), self.dilation, self.groups)
    --> 459         return F.conv2d(input, weight, bias, self.stride,
        460                         self.padding, self.dilation, self.groups)
        461

    TypeError: conv2d() received an invalid combination of arguments - got (list, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
     * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
       didn't match because some of the arguments have invalid types: (!list!, !Parameter!, !Parameter!, !tuple!, !tuple!, !tuple!, int)
     * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
       didn't match because some of the arguments have invalid types: (!list!, !Parameter!, !Parameter!, !tuple!, !tuple!, !tuple!, int)
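
The TypeError shows that conv2d received a Python list instead of a Tensor, i.e. the batch reaching the encoder was never collated into a tensor. A possible fix (an assumption, since the custom fit loop above is not part of this repository) is to stack the batch before the forward call:

    import torch

    # hypothetical fix: make sure the dataloader batches are Tensors, not lists
    if isinstance(content, list):
        content = torch.stack(content)
    if isinstance(style, list):
        style = torch.stack(style)
    loss = model(content, style)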

RuntimeError

    F:\arbitrary\Pytorch_Adain_from_scratch-master\Pytorch_Adain_from_scratch-master>python test.py -c f:\arbitrary\Pytorch_Adain_from_scratch-master\Pytorch_Adain_from_scratch-master\content\tree.jpg -s f:\arbitrary\Pytorch_Adain_from_scratch-master\Pytorch_Adain_from_scratch-master\style\news1.jpg
    Traceback (most recent call last):
      File "test.py", line 84, in <module>
        main()
      File "test.py", line 64, in main
        res = torch.cat([c_denorm, out_denorm], dim=0)
    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 459 and 456 in dimension 2 at C:\w\1\s\windows\pytorch\aten\src\TH/generic/THTensor.cpp:689

Content image: (attached: tree)
Style image: (attached: 1)
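
A likely cause (an assumption, not confirmed by the author): the VGG encoder downsamples the input, so when an image dimension is not a multiple of the network's total stride, the decoder's output can differ from the content image by a few pixels, and torch.cat then fails. One workaround is to resize the output to the content size before concatenating:

    import torch
    import torch.nn.functional as F

    # hypothetical workaround: match spatial sizes before torch.cat
    out_denorm = F.interpolate(out_denorm, size=c_denorm.shape[2:],
                               mode='bilinear', align_corners=False)
    res = torch.cat([c_denorm, out_denorm], dim=0)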

Model isn't learning

I trained the model for 20 epochs, but it doesn't seem to learn anything. Can you help me fix it?
(attached: 1_epoch_2000_iteration)

About unpaired image-to-image translation using AdaIN

Hi @irasin, I am very interested in this image-to-image translation without a Cycle Consistency Loss. I am wondering whether this method can be used for unpaired image-to-image translation? I think it might be achievable by changing datasets.py.

About the denorm when saving the images

Hello, irasin. Thank you very much for your excellent PyTorch project; I am very interested in it. However, I am puzzled by the denorm operation during training:

    with torch.no_grad():
        out = model.generate(content, style)
        content = denorm(content, device)
        style = denorm(style, device)
        out = denorm(out, device)

        res = torch.cat([content, style, out], dim=0)
        res = res.to('cpu')
        save_image(res, f'{image_dir}/{e}epoch{i}_iteration.png', nrow=args.batch_size)

I found that out is not satisfactory without the denorm operation, but I can't understand why denorm is used here. When should we add this operation? Thank you for your project, and I look forward to your help. Have a good weekend!
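
For context, here is a sketch of what denorm typically does in VGG-based style-transfer code (assuming the standard ImageNet preprocessing; the actual implementation in this repository may differ): input images are normalized with the ImageNet mean and std before being fed to the VGG encoder, so denorm inverts that transform to bring pixel values back into [0, 1] before save_image:

    import torch

    def denorm(tensor, device):
        # invert the (assumed) ImageNet Normalize transform applied at load time
        mean = torch.tensor([0.485, 0.456, 0.406], device=device).reshape(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=device).reshape(1, 3, 1, 1)
        return torch.clamp(tensor * std + mean, 0, 1)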
