

SketchKnitter

Python 3.6 | PyTorch 1.6 | MIT License

This repository contains the PyTorch implementation of SketchKnitter: Vectorized Sketch Generation with Diffusion Models (ICLR 2023, Spotlight).

Authors: Qiang Wang, Haoge Deng, Yonggang Qi, Da Li, Yi-Zhe Song. Beijing University of Posts and Telecommunications, Samsung AI Centre Cambridge, University of Surrey.

Abstract: We show that vectorized sketch generation can be identified as a reversal of the stroke deformation process. This relationship was established by means of a diffusion model that learns data distributions over the stroke-point locations and pen states of real human sketches. Given randomly scattered stroke-points, sketch generation becomes a process of deformation-based denoising, where the generator rectifies the positions of stroke points at each timestep to converge at a recognizable sketch. A key innovation was to embed recognizability into the reverse time diffusion process. It was observed that the estimated noise during the reversal process is strongly correlated with sketch classification accuracy. An auxiliary recurrent neural network (RNN) was consequently used to quantify recognizability during data sampling. It follows that, based on the recognizability scores, a sampling shortcut function can also be devised that renders better quality sketches with fewer sampling steps. Finally, it is shown that the model can be easily extended to a conditional generation framework, where, given incomplete and unfaithful sketches, it yields one that is more visually appealing and more recognizable.
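The forward "stroke deformation" described above can be illustrated with the standard DDPM forward process applied to stroke-point coordinates. The sketch below is a minimal NumPy illustration, not the repository's code: it assumes a linear β schedule with the DDPM endpoints (1e-4 to 0.02), treats a sketch as an (N, 2) array of point coordinates, and ignores pen states; the helper names `linear_betas` and `deform` are hypothetical.

```python
import numpy as np

def linear_betas(num_steps, beta_start=1e-4, beta_end=0.02):
    """Noise variances spaced linearly over the diffusion steps (DDPM endpoints)."""
    return np.linspace(beta_start, beta_end, num_steps)

def deform(points, t, betas, rng):
    """Draw x_t ~ q(x_t | x_0) for an (N, 2) array of stroke-point coordinates."""
    alpha_bar = np.cumprod(1.0 - betas)[t]          # fraction of signal kept at step t
    noise = rng.standard_normal(points.shape)
    return np.sqrt(alpha_bar) * points + np.sqrt(1.0 - alpha_bar) * noise

rng = np.random.default_rng(0)
betas = linear_betas(1000)
angles = np.linspace(0.0, 2.0 * np.pi, 96)
x0 = np.stack([np.cos(angles), np.sin(angles)], axis=1)   # a clean "sketch": a circle
x_early = deform(x0, 10, betas, rng)    # points barely displaced
x_late = deform(x0, 999, betas, rng)    # points are almost pure Gaussian scatter
```

The learned reverse process runs in the opposite direction, moving scattered points back toward a recognizable sketch.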

Fig. 1

Datasets

Please download the datasets from the official QuickDraw website. The paper uses the following classes: moon, airplane, fish, umbrella, train, spider, shoe, apple, lion, and bus; you can replace them with any other category. Each category is stored in its own file and contains 70,000 training, 2,500 validation, and 2,500 test examples.
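The sketch-rnn versions of the QuickDraw files (such as `moon.npz`) are NumPy archives holding `train`, `valid`, and `test` object arrays, where each element is a variable-length sketch in stroke-3 format (dx, dy, pen_lift). A minimal loading sketch, assuming that layout; `load_quickdraw` and the toy archive are hypothetical stand-ins for a downloaded file.

```python
import os
import tempfile
import numpy as np

def load_quickdraw(path):
    """Return the train/valid/test splits of a sketch-rnn style .npz file."""
    # The splits are object arrays of (n_points, 3) arrays, so unpickling must be
    # allowed; latin1 keeps archives written under Python 2 readable.
    data = np.load(path, encoding="latin1", allow_pickle=True)
    return data["train"], data["valid"], data["test"]

def ragged(*sketches):
    """Pack sketches of different lengths into a 1-D object array."""
    arr = np.empty(len(sketches), dtype=object)
    for i, s in enumerate(sketches):
        arr[i] = np.asarray(s, dtype=np.int16)
    return arr

# A tiny stand-in archive with the same layout as the real files.
path = os.path.join(tempfile.mkdtemp(), "toy.npz")
np.savez(path,
         train=ragged([[0, 0, 0], [5, 0, 1]], [[1, 1, 0], [0, 4, 0], [2, 2, 1]]),
         valid=ragged([[3, 3, 1]]),
         test=ragged([[2, 0, 1]]))
train, valid, test = load_quickdraw(path)
print(len(train), train[1].shape)  # 2 (3, 3)
```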

Besides QuickDraw, the model can be trained on any dataset, as long as the dataset is organized in vector format and packaged into an .npz file. With smaller datasets, watch out for over-fitting. To create your own dataset, you can follow the official SketchRNN tutorial.
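To package your own vector sketches, each one must be expressed in the same stroke-3 format: per-point offsets (dx, dy) from the previous point, plus a flag set to 1 on the last point of each stroke. A minimal sketch modeled on sketch-rnn's `lines_to_strokes` utility, assuming each input sketch is a list of strokes given as absolute (x, y) polylines; the function names and the file name `my_class.npz` are hypothetical, and the demo reuses sketches across splits only for brevity.

```python
import os
import tempfile
import numpy as np

def polylines_to_stroke3(strokes):
    """Convert a list of absolute (x, y) polylines to an (n, 3) stroke-3 array."""
    rows = []
    for stroke in strokes:
        for i, (x, y) in enumerate(stroke):
            pen_lift = 1 if i == len(stroke) - 1 else 0   # 1 marks the end of a stroke
            rows.append([x, y, pen_lift])
    arr = np.array(rows, dtype=np.float32)
    # Absolute positions -> offsets; the first point stays relative to (0, 0).
    arr[1:, :2] = arr[1:, :2] - arr[:-1, :2]
    return arr

def to_object_array(sketches):
    """Pack ragged stroke-3 arrays into a 1-D object array, as in the QuickDraw files."""
    out = np.empty(len(sketches), dtype=object)
    for i, strokes in enumerate(sketches):
        out[i] = polylines_to_stroke3(strokes)
    return out

sketches = [
    [[(0, 0), (10, 0), (10, 10)], [(20, 20), (25, 30)]],   # two strokes
    [[(5, 5), (6, 7)]],                                     # one stroke
]
data = to_object_array(sketches)
path = os.path.join(tempfile.mkdtemp(), "my_class.npz")
np.savez(path, train=data, valid=data[:1], test=data[1:])
```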

Installation

The requirements of this repo can be found in requirements.txt.

conda create -n sketchknitter python=3.7
conda activate sketchknitter
pip install -r requirements.txt

Train and Inference

Hyperparameters

Here is the full list of model options:

lr,                   # learning rate.
log_dir,              # path for saving logs.
dropout,              # dropout rate.
use_fp16,             # whether to use mixed-precision training.
ema_rate,             # comma-separated list of EMA decay rates.
category,             # list of category names to train on.
data_dir,             # path to the datasets.
use_ddim,             # whether to sample with DDIM (True) or DDPM (False).
save_path,            # path for saving the vector results.
pen_break,            # empirical threshold for deciding a stroke break (pen lift).
image_size,           # maximum number of stroke points per sketch.
model_path,           # path to the trained model checkpoint.
class_cond,           # whether to use class-conditional guidance.
batch_size,           # batch size for training.
emb_channels,         # number of embedding channels in the UNet.
num_channels,         # number of channels in the UNet backbone.
out_channels,         # number of output channels of the UNet.
save_interval,        # interval between model checkpoint saves.
noise_schedule,       # noise schedule of the forward process (linear by default).
num_res_blocks,       # number of residual blocks in the UNet backbone.
diffusion_steps,      # number of diffusion steps in the forward process.
schedule_sampler,     # sampler that schedules the training timesteps.
fp16_scale_growth,    # loss-scale growth rate for mixed-precision training.
use_scale_shift_norm, # whether to use scale-shift normalization.
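The `ema_rate` option presumably keeps one exponential moving average (EMA) of the weights per listed rate; the checkpoint name `ema_0.9999_050000.pt` that appears in the issues suggests one file is saved per rate. A minimal NumPy sketch of parsing the list and applying the usual update ema ← rate · ema + (1 − rate) · param; representing parameters as NumPy arrays and the helper names are assumptions for illustration, not the repository's code.

```python
import numpy as np

def parse_ema_rates(ema_rate):
    """Turn '0.9999,0.5' (or a single float) into a list of decay rates."""
    if isinstance(ema_rate, float):
        return [ema_rate]
    return [float(x) for x in ema_rate.split(",")]

def update_ema(ema_params, params, rate):
    """In place: ema <- rate * ema + (1 - rate) * param, for each parameter."""
    for ema, p in zip(ema_params, params):
        ema *= rate
        ema += (1.0 - rate) * p

params = [np.ones(3)]
rates = parse_ema_rates("0.9999,0.5")
# One EMA copy of the parameters per decay rate, starting from the current values.
ema_copies = {r: [p.copy() for p in params] for r in rates}

params[0][:] = 3.0            # pretend an optimizer step changed the weights
for r in rates:
    update_ema(ema_copies[r], params, r)
print(ema_copies[0.5][0])     # [2. 2. 2.]
```

A rate close to 1 (such as 0.9999) averages over many steps and changes slowly, which is why such checkpoints are usually preferred for sampling.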

Train Example Usage:

bash train.sh

Inference Example Usage:

bash sample.sh

Visualization and Evaluation

The results produced by sample.py are stored in .npz format and can be used directly to compute quantitative metrics such as FID, IS, and GS. The results can also be visualized for qualitative experiments; for the visualization method, refer to Google Brain's work.
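For qualitative visualization, a stroke-3 sketch can be turned back into absolute polylines, one per stroke, which any plotting library can draw. A minimal sketch modeled on sketch-rnn's `strokes_to_lines` utility; it assumes the (dx, dy, pen_lift) layout of the QuickDraw files and is not code from this repository.

```python
import numpy as np

def stroke3_to_polylines(sketch):
    """Split an (n, 3) stroke-3 sketch into absolute (k, 2) polylines, one per stroke."""
    sketch = np.asarray(sketch, dtype=np.float64)
    points = np.cumsum(sketch[:, :2], axis=0)       # offsets -> absolute positions
    ends = np.flatnonzero(sketch[:, 2] == 1) + 1    # split just after each pen lift
    return [p for p in np.split(points, ends) if len(p) > 0]

lines = stroke3_to_polylines(
    [[0, 0, 0], [10, 0, 0], [0, 10, 1], [10, 10, 0], [5, 10, 1]]
)
# lines[0] -> [[0, 0], [10, 0], [10, 10]]; lines[1] -> [[20, 20], [25, 30]]
```

Each returned array can be passed to, for example, matplotlib's `plt.plot(line[:, 0], line[:, 1])`; depending on the coordinate convention, the y axis may need to be inverted.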

To calculate FID and IS, refer to the official code. The Geometry Score (GS) can be computed directly on vector-format data; please see its official website for instructions.
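FID compares the mean and covariance of feature activations from real and generated samples through the Fréchet distance between two Gaussians, d² = ‖μ₁ − μ₂‖² + Tr(Σ₁ + Σ₂ − 2(Σ₁Σ₂)^{1/2}). Reported numbers should come from the official code; the NumPy sketch below only illustrates the formula. It computes the trace term through the symmetric matrix Σ₁^{1/2} Σ₂ Σ₁^{1/2}, which has the same eigenvalues as Σ₁Σ₂. The helper names are hypothetical, and the feature network that produces the activations is not included.

```python
import numpy as np

def _sqrt_psd(m):
    """Principal square root of a symmetric positive semi-definite matrix."""
    vals, vecs = np.linalg.eigh(m)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T

def stats(features):
    """Mean and covariance of an (n_samples, n_features) activation matrix."""
    return features.mean(axis=0), np.cov(features, rowvar=False)

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)) for two Gaussians."""
    diff = mu1 - mu2
    s1_half = _sqrt_psd(sigma1)
    # S1^(1/2) S2 S1^(1/2) is symmetric PSD and has the same eigenvalues as S1 S2.
    cross = _sqrt_psd(s1_half @ sigma2 @ s1_half)
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(cross))

rng = np.random.default_rng(0)
real = rng.standard_normal((2000, 4))
fake = real + 1.0   # same covariance, mean shifted by 1 in each of 4 dimensions
print(frechet_distance(*stats(real), *stats(fake)))   # approximately 4.0
```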

Results

| Method | FID↓ | GS↓ | Prec↑ | Rec↑ |
| SketchPix2seq | 13.3 | 7.0 | 0.40 | 0.79 |
| SketchHealer | 10.3 | 5.9 | 0.45 | 0.81 |
| SketchRNN | 10.8 | 5.4 | 0.44 | 0.82 |
| Diff-HW | 13.3 | 6.8 | 0.42 | 0.81 |
| SketchODE | 11.5 | 9.4 | 0.48 | 0.74 |
| Ours (full 1000 steps) | 6.9 | 3.4 | 0.52 | 0.88 |
| Ours (r-Shortcut, S=30) | 7.4 | 3.9 | 0.47 | 0.87 |
| Ours (Linear-DDIMs, S=30) | 11.9 | 6.4 | 0.38 | 0.81 |
| Ours (Quadratic-DDIMs, S=30) | 12.3 | 6.6 | 0.41 | 0.79 |
| Ours (Abs) | 20.7 | 12.1 | 0.18 | 0.55 |
| Ours (Point-Shuffle) | 9.5 | 5.3 | 0.35 | 0.72 |
| Ours (Stroke-Shuffle) | 8.2 | 3.8 | 0.36 | 0.74 |

Fig. 4

Only some of the results are listed here; for more detailed results, please see our paper and supplementary materials.

License

This project is released under the MIT License.

Citation

If you find this repository useful for your research, please cite:

@inproceedings{wangsketchknitter,
  title={SketchKnitter: Vectorized Sketch Generation with Diffusion Models},
  author={Wang, Qiang and Deng, Haoge and Qi, Yonggang and Li, Da and Song, Yi-Zhe},
  booktitle={The Eleventh International Conference on Learning Representations},
  year={2023}
}

Acknowledgements

Contact

If you have any questions about the code, please contact [email protected]


Contributors

lmagoncalo, wangqiang9


sketchknitter's Issues

I can't find the .losses code

When I checked the code, I couldn't find .losses. Please let me know the details.
gaussian_diffusion.py, line 8:
from .losses import normal_kl, discretized_gaussian_log_likelihood
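For reference, in OpenAI's improved-diffusion codebase (which the option names listed above suggest this repository builds on), `losses.py` defines `normal_kl` as the KL divergence between two diagonal Gaussians given by mean and log-variance. A NumPy transcription of that formula is sketched below as an assumption about what the missing module contains; `discretized_gaussian_log_likelihood` is omitted because in improved-diffusion it assumes 8-bit image data scaled to [-1, 1].

```python
import numpy as np

def normal_kl(mean1, logvar1, mean2, logvar2):
    """Elementwise KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) )."""
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + np.exp(logvar1 - logvar2)
        + (mean1 - mean2) ** 2 * np.exp(-logvar2)
    )

print(normal_kl(0.0, 0.0, 1.0, 0.0))  # 0.5: unit Gaussians whose means differ by 1
```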

Some questions about samples

Thank you very much for the code; I have some questions I would like to ask.
sample.py does not save any samples anywhere:

sample_all = bin_pen(sample_all, args.pen_break)
print(f"sample all {sample_all} is saved!")

This code seems to just print the whole vector?

Reproducing paper results

I'm trying to reproduce the results presented in the paper; however, using the configurations provided in the scripts (train.sh and sample.sh), I cannot achieve them. Here is an example of the results I obtained training on a single dataset:

generated_2

They appear to be composed of dashes instead of continuous lines.
Is it possible to provide the configurations used for training and sampling to obtain the results in the paper?

Question about draw_sketch.py

Hi there!
First, thank you for your advice about sample_all in sample.py! However, I noticed that in your dataset, e.g. moon.npz, the size of data.train (as shown in line 80 of draw_sketch.py) is 70000×3, while the size I get for sample_all.npz (from line 73 of sample.py) is 16×96×4.
Could you please explain how you save sample_all into a proper npz file? I've tried
np.savez(args.save_path + "sample0.npz", train=sample_all) but it didn't work with draw_sketch.
Thank you again for your help!
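One possible way to bridge the two formats is to convert each of the B sketches in `sample_all` into its own (N, 3) stroke-3 array and save them as an object array under the `train` key, matching the layout of the dataset files. The sketch below is hypothetical: from `th.cat((sample, pen_state), 2)` we only know that channels 0 and 1 come from `sample` and channels 2 and 3 from `pen_state`. It further assumes that channels 0 and 1 are (dx, dy) offsets and that `pen_channel` holds a pen-lift score compared against `pen_break`; the repository documents none of this. A CPU torch tensor can be converted first with `sample_all.numpy()`.

```python
import os
import tempfile
import numpy as np

def samples_to_stroke3_npz(sample_all, path, pen_channel=2, pen_break=0.5):
    """Save a (B, N, 4) sample batch as a sketch-rnn style .npz (hypothetical helper).

    ASSUMPTIONS, not confirmed by the repository: channels 0-1 are (dx, dy)
    offsets (if they are absolute positions, convert them to offsets first),
    and `pen_channel` holds a pen-lift score where values above `pen_break`
    end the current stroke.
    """
    sample_all = np.asarray(sample_all, dtype=np.float32)
    sketches = np.empty(len(sample_all), dtype=object)
    for i, s in enumerate(sample_all):
        pen = (s[:, pen_channel] > pen_break).astype(np.float32)
        sketches[i] = np.stack([s[:, 0], s[:, 1], pen], axis=1)   # (N, 3)
    np.savez(path, train=sketches)     # same key the dataset files use
    return sketches

# Demonstration with random numbers shaped like the 16 x 96 x 4 batch above.
fake = np.random.default_rng(0).random((16, 96, 4))
path = os.path.join(tempfile.mkdtemp(), "sample0.npz")
sketches = samples_to_stroke3_npz(fake, path)
```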

Seeking Guidance for Reproducing SketchKnitter Results on QuickDraw Airplane Class

Thank you for sharing your compelling and meticulously constructed work. I attempted to reproduce the results of SketchKnitter using only the airplane class from the QuickDraw dataset. However, the quality of the generated airplane sketches did not align with my expectations, and the quantitative results were not as favorable as I had anticipated. I employed the code and dataset from your GitHub repository, executing the following commands to train SketchKnitter:

python train.py --data_dir [/path/to/datasets] \
                --lr 1e-4 \
                --batch_size 512 \
                --use_fp16 False \
                --log_dir [/path/to/log] \
                --diffusion_steps 1000 \
                --noise_schedule linear \
                --image_size 96 \
                --num_channels 96 \
                --num_res_blocks 3

I also sampled the sketches and evaluated the quality using your code. The results are as follows:

  • Inception Score: 2.460366725921631
  • FID: 61.48920809233371
  • sFID: 49.6528623255079
  • Precision: 0.5416
  • Recall: 0.37

I suspect there may be some issues with my approach. However, I haven't made any modifications to your original code. Could you please provide any guidance or suggestions on potential adjustments I could make to improve the performance? I appreciate your time and help in this matter. Have a wonderful day, and I look forward to your response.

Question about sample.py

Thank you so much for your excellent work!
I can't find how to save "sample_all" in .npz format and draw sketches from it. In sample.py, the main function stops at these two statements:
sample_all = th.cat((sample, pen_state), 2).cpu()
sample_all = bin_pen(sample_all, args.pen_break)
So could you please explain how to save "sample_all" in .npz format and get a sketch result?

Question about N

When I set N to a smaller number, I ran into the following problem:

ValueError: could not broadcast input array from shape (59,3) into shape (24,3)

I tried the following modified code, but the results were very bad:
len_seq = min(self.Nmax, self.sketches_normed[idx].shape[0])
step = self.sketches_normed[idx].shape[0] // len_seq + 1
sampledata = self.sketches_normed[idx][::step]
sketch[:len(sampledata), :] = sampledata
I see that your paper reports results for N=24. How did you achieve this?
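A likely cause of the poor results is that slicing stroke-3 offsets with `[::step]` discards offsets, so the remaining points no longer add up to the original positions, and pen-lift flags on skipped rows are lost. Also, when a sketch already has at most Nmax points, `step` evaluates to 2, so half of the points are dropped needlessly. A hedged alternative, not the authors' method: convert to absolute positions, keep at most N points (stroke endpoints first, when possible), then convert back to offsets. The function name `downsample_stroke3` is hypothetical.

```python
import numpy as np

def downsample_stroke3(sketch, n_max):
    """Keep at most n_max points of a stroke-3 sketch without breaking geometry."""
    sketch = np.asarray(sketch, dtype=np.float64)
    n = len(sketch)
    if n <= n_max:
        return sketch.copy()                       # nothing to drop
    points = np.cumsum(sketch[:, :2], axis=0)      # offsets -> absolute positions
    pen = sketch[:, 2]
    # Prefer to keep every stroke end (pen lift) and the final point.
    keep = set(np.flatnonzero(pen == 1).tolist()) | {n - 1}
    budget = n_max - len(keep)
    if budget > 0:                                 # fill the rest evenly
        others = np.array([i for i in range(n) if i not in keep])
        picks = np.linspace(0, len(others) - 1, budget).round().astype(int)
        keep |= set(others[picks].tolist())
    idx = np.array(sorted(keep))
    if len(idx) > n_max:                           # too many strokes: thin them too
        idx = idx[np.linspace(0, len(idx) - 1, n_max).round().astype(int)]
    out = np.zeros((len(idx), 3))
    # Offsets between the kept points; the first is relative to (0, 0) as before.
    out[:, :2] = np.diff(points[idx], axis=0, prepend=np.zeros((1, 2)))
    out[:, 2] = pen[idx]
    return out

rng = np.random.default_rng(0)
sketch = np.zeros((59, 3))
sketch[:, :2] = rng.integers(-5, 6, size=(59, 2))
sketch[[10, 30, 58], 2] = 1                        # three strokes
short = downsample_stroke3(sketch, 24)
print(short.shape)  # (24, 3)
```

Every kept point lies on the original trajectory, so the sketch is coarser but not distorted; the result still has to be zero-padded to N rows as before.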

custom dataset

I want to train on my own RGB sketch dataset, but I don't think the sketch-rnn repo explains how to do that.
Is there a way to convert an RGB dataset into your dataset format?

Code for rectifying bad sketches.

I'm trying to use SketchKnitter to rectify bad sketches, but I can't find the corresponding code (including how to use the NLR module). Can you please upload the code for conditional generation?

Cannot use the ddim_sample_loop function

TypeError: ddim_sample_loop() missing 3 required positional arguments: 'data', 'raster', and 'loss'
I used the following sampling command:
python sample.py --model_path ./log/ema_0.9999_050000.pt --pen_break 0.1 --save_path ./ --use_ddim True --log_dir ./log/ --diffusion_steps 100 --noise_schedule linear --image_size 96 --num_channels 96 --num_res_blocks 3

Pretrained model.

Can I please get the model you're using in sample.py?
I'm getting this error:
File "sample.py", line 45, in main
dist_util.load_state_dict(args.model_path, map_location="cpu")
File "/Users/filzahamjad/anaconda3/envs/sketchknitter/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1672, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for UNetModel:
Missing key(s) in state_dict: "time_embed.0.weight", "time_embed.0.bias", "time_embed.2.weight", "time_embed.2.bias", "input_blocks.0.0.weight", "input_blocks.0.0.bias", "input_blocks.1.0.in_layers.0.weight", "input_blocks.1.0.in_layers.0.bias", "input_blocks.1.0.in_layers.2.weight", "input_blocks.1.0.in_layers.2.bias", "input_blocks.1.0.emb_layers.1.weight", "input_blocks.1.0.emb_layers.1.bias", "input_blocks.1.0.out_layers.0.weight", "input_blocks.1.0.out_layers.0.bias", "input_blocks.1.0.out_layers.3.weight", "input_blocks.1.0.out_layers.3.bias", "input_blocks.2.0.in_layers.0.weight", "input_blocks.2.0.in_layers.0.bias", "input_blocks.2.0.in_layers.2.weight", "input_blocks.2.0.in_layers.2.bias", "input_blocks.2.0.emb_layers.1.weight", "input_blocks.2.0.emb_layers.1.bias", "input_blocks.2.0.out_layers.0.weight", "input_blocks.2.0.out_layers.0.bias", "input_blocks.2.0.out_layers.3.weight", "input_blocks.2.0.out_layers.3.bias", "input_blocks.3.0.in_layers.0.weight", "input_blocks.3.0.in_layers.0.bias", "input_blocks.3.0.in_layers.2.weight", "input_blocks.3.0.in_layers.2.bias", "input_blocks.3.0.emb_layers.1.weight", "input_blocks.3.0.emb_layers.1.bias", "input_blocks.3.0.out_layers.0.weight", "input_blocks.3.0.out_layers.0.bias", "input_blocks.3.0.out_layers.3.weight", "input_blocks.3.0.out_layers.3.bias", "input_blocks.4.0.op.weight", "input_blocks.4.0.op.bias", "input_blocks.5.0.in_layers.0.weight", "input_blocks.5.0.in_layers.0.bias", "input_blocks.5.0.in_layers.2.weight", "input_blocks.5.0.in_layers.2.bias", "input_blocks.5.0.emb_layers.1.weight", "input_blocks.5.0.emb_layers.1.bias", "input_blocks.5.0.out_layers.0.weight", "input_blocks.5.0.out_layers.0.bias", "input_blocks.5.0.out_layers.3.weight", "input_blocks.5.0.out_layers.3.bias", "input_blocks.5.0.skip_connection.weight", "input_blocks.5.0.skip_connection.bias", "input_blocks.6.0.in_layers.0.weight", "input_blocks.6.0.in_layers.0.bias", "input_blocks.6.0.in_layers.2.weight", 
"input_blocks.6.0.in_layers.2.bias", "input_blocks.6.0.emb_layers.1.weight", "input_blocks.6.0.emb_layers.1.bias", "input_blocks.6.0.out_layers.0.weight", "input_blocks.6.0.out_layers.0.bias", "input_blocks.6.0.out_layers.3.weight", "input_blocks.6.0.out_layers.3.bias", "input_blocks.7.0.in_layers.0.weight", "input_blocks.7.0.in_layers.0.bias", "input_blocks.7.0.in_layers.2.weight", "input_blocks.7.0.in_layers.2.bias", "input_blocks.7.0.emb_layers.1.weight", "input_blocks.7.0.emb_layers.1.bias", "input_blocks.7.0.out_layers.0.weight", "input_blocks.7.0.out_layers.0.bias", "input_blocks.7.0.out_layers.3.weight", "input_blocks.7.0.out_layers.3.bias", "input_blocks.8.0.op.weight", "input_blocks.8.0.op.bias", "input_blocks.9.0.in_layers.0.weight", "input_blocks.9.0.in_layers.0.bias", "input_blocks.9.0.in_layers.2.weight", "input_blocks.9.0.in_layers.2.bias", "input_blocks.9.0.emb_layers.1.weight", "input_blocks.9.0.emb_layers.1.bias", "input_blocks.9.0.out_layers.0.weight", "input_blocks.9.0.out_layers.0.bias", "input_blocks.9.0.out_layers.3.weight", "input_blocks.9.0.out_layers.3.bias", "input_blocks.9.0.skip_connection.weight", "input_blocks.9.0.skip_connection.bias", "input_blocks.10.0.in_layers.0.weight", "input_blocks.10.0.in_layers.0.bias", "input_blocks.10.0.in_layers.2.weight", "input_blocks.10.0.in_layers.2.bias", "input_blocks.10.0.emb_layers.1.weight", "input_blocks.10.0.emb_layers.1.bias", "input_blocks.10.0.out_layers.0.weight", "input_blocks.10.0.out_layers.0.bias", "input_blocks.10.0.out_layers.3.weight", "input_blocks.10.0.out_layers.3.bias", "input_blocks.11.0.in_layers.0.weight", "input_blocks.11.0.in_layers.0.bias", "input_blocks.11.0.in_layers.2.weight", "input_blocks.11.0.in_layers.2.bias", "input_blocks.11.0.emb_layers.1.weight", "input_blocks.11.0.emb_layers.1.bias", "input_blocks.11.0.out_layers.0.weight", "input_blocks.11.0.out_layers.0.bias", "input_blocks.11.0.out_layers.3.weight", "input_blocks.11.0.out_layers.3.bias", 
"input_blocks.12.0.op.weight", "input_blocks.12.0.op.bias", "input_blocks.13.0.in_layers.0.weight", "input_blocks.13.0.in_layers.0.bias", "input_blocks.13.0.in_layers.2.weight", "input_blocks.13.0.in_layers.2.bias", "input_blocks.13.0.emb_layers.1.weight", "input_blocks.13.0.emb_layers.1.bias", "input_blocks.13.0.out_layers.0.weight", "input_blocks.13.0.out_layers.0.bias", "input_blocks.13.0.out_layers.3.weight", "input_blocks.13.0.out_layers.3.bias", "input_blocks.13.0.skip_connection.weight", "input_blocks.13.0.skip_connection.bias", "input_blocks.14.0.in_layers.0.weight", "input_blocks.14.0.in_layers.0.bias", "input_blocks.14.0.in_layers.2.weight", "input_blocks.14.0.in_layers.2.bias", "input_blocks.14.0.emb_layers.1.weight", "input_blocks.14.0.emb_layers.1.bias", "input_blocks.14.0.out_layers.0.weight", "input_blocks.14.0.out_layers.0.bias", "input_blocks.14.0.out_layers.3.weight", "input_blocks.14.0.out_layers.3.bias", "input_blocks.15.0.in_layers.0.weight", "input_blocks.15.0.in_layers.0.bias", "input_blocks.15.0.in_layers.2.weight", "input_blocks.15.0.in_layers.2.bias", "input_blocks.15.0.emb_layers.1.weight", "input_blocks.15.0.emb_layers.1.bias", "input_blocks.15.0.out_layers.0.weight", "input_blocks.15.0.out_layers.0.bias", "input_blocks.15.0.out_layers.3.weight", "input_blocks.15.0.out_layers.3.bias", "middle_block.0.in_layers.0.weight", "middle_block.0.in_layers.0.bias", "middle_block.0.in_layers.2.weight", "middle_block.0.in_layers.2.bias", "middle_block.0.emb_layers.1.weight", "middle_block.0.emb_layers.1.bias", "middle_block.0.out_layers.0.weight", "middle_block.0.out_layers.0.bias", "middle_block.0.out_layers.3.weight", "middle_block.0.out_layers.3.bias", "middle_block.1.norm.weight", "middle_block.1.norm.bias", "middle_block.1.qkv.weight", "middle_block.1.qkv.bias", "middle_block.1.proj_out.weight", "middle_block.1.proj_out.bias", "middle_block.2.in_layers.0.weight", "middle_block.2.in_layers.0.bias", "middle_block.2.in_layers.2.weight", 
"middle_block.2.in_layers.2.bias", "middle_block.2.emb_layers.1.weight", "middle_block.2.emb_layers.1.bias", "middle_block.2.out_layers.0.weight", "middle_block.2.out_layers.0.bias", "middle_block.2.out_layers.3.weight", "middle_block.2.out_layers.3.bias", "output_blocks.0.0.in_layers.0.weight", "output_blocks.0.0.in_layers.0.bias", "output_blocks.0.0.in_layers.2.weight", "output_blocks.0.0.in_layers.2.bias", "output_blocks.0.0.emb_layers.1.weight", "output_blocks.0.0.emb_layers.1.bias", "output_blocks.0.0.out_layers.0.weight", "output_blocks.0.0.out_layers.0.bias", "output_blocks.0.0.out_layers.3.weight", "output_blocks.0.0.out_layers.3.bias", "output_blocks.0.0.skip_connection.weight", "output_blocks.0.0.skip_connection.bias", "output_blocks.1.0.in_layers.0.weight", "output_blocks.1.0.in_layers.0.bias", "output_blocks.1.0.in_layers.2.weight", "output_blocks.1.0.in_layers.2.bias", "output_blocks.1.0.emb_layers.1.weight", "output_blocks.1.0.emb_layers.1.bias", "output_blocks.1.0.out_layers.0.weight", "output_blocks.1.0.out_layers.0.bias", "output_blocks.1.0.out_layers.3.weight", "output_blocks.1.0.out_layers.3.bias", "output_blocks.1.0.skip_connection.weight", "output_blocks.1.0.skip_connection.bias", "output_blocks.2.0.in_layers.0.weight", "output_blocks.2.0.in_layers.0.bias", "output_blocks.2.0.in_layers.2.weight", "output_blocks.2.0.in_layers.2.bias", "output_blocks.2.0.emb_layers.1.weight", "output_blocks.2.0.emb_layers.1.bias", "output_blocks.2.0.out_layers.0.weight", "output_blocks.2.0.out_layers.0.bias", "output_blocks.2.0.out_layers.3.weight", "output_blocks.2.0.out_layers.3.bias", "output_blocks.2.0.skip_connection.weight", "output_blocks.2.0.skip_connection.bias", "output_blocks.3.0.in_layers.0.weight", "output_blocks.3.0.in_layers.0.bias", "output_blocks.3.0.in_layers.2.weight", "output_blocks.3.0.in_layers.2.bias", "output_blocks.3.0.emb_layers.1.weight", "output_blocks.3.0.emb_layers.1.bias", "output_blocks.3.0.out_layers.0.weight", 
"output_blocks.3.0.out_layers.0.bias", "output_blocks.3.0.out_layers.3.weight", "output_blocks.3.0.out_layers.3.bias", "output_blocks.3.0.skip_connection.weight", "output_blocks.3.0.skip_connection.bias", "output_blocks.3.1.conv.weight", "output_blocks.3.1.conv.bias", "output_blocks.4.0.in_layers.0.weight", "output_blocks.4.0.in_layers.0.bias", "output_blocks.4.0.in_layers.2.weight", "output_blocks.4.0.in_layers.2.bias", "output_blocks.4.0.emb_layers.1.weight", "output_blocks.4.0.emb_layers.1.bias", "output_blocks.4.0.out_layers.0.weight", "output_blocks.4.0.out_layers.0.bias", "output_blocks.4.0.out_layers.3.weight", "output_blocks.4.0.out_layers.3.bias", "output_blocks.4.0.skip_connection.weight", "output_blocks.4.0.skip_connection.bias", "output_blocks.5.0.in_layers.0.weight", "output_blocks.5.0.in_layers.0.bias", "output_blocks.5.0.in_layers.2.weight", "output_blocks.5.0.in_layers.2.bias", "output_blocks.5.0.emb_layers.1.weight", "output_blocks.5.0.emb_layers.1.bias", "output_blocks.5.0.out_layers.0.weight", "output_blocks.5.0.out_layers.0.bias", "output_blocks.5.0.out_layers.3.weight", "output_blocks.5.0.out_layers.3.bias", "output_blocks.5.0.skip_connection.weight", "output_blocks.5.0.skip_connection.bias", "output_blocks.6.0.in_layers.0.weight", "output_blocks.6.0.in_layers.0.bias", "output_blocks.6.0.in_layers.2.weight", "output_blocks.6.0.in_layers.2.bias", "output_blocks.6.0.emb_layers.1.weight", "output_blocks.6.0.emb_layers.1.bias", "output_blocks.6.0.out_layers.0.weight", "output_blocks.6.0.out_layers.0.bias", "output_blocks.6.0.out_layers.3.weight", "output_blocks.6.0.out_layers.3.bias", "output_blocks.6.0.skip_connection.weight", "output_blocks.6.0.skip_connection.bias", "output_blocks.7.0.in_layers.0.weight", "output_blocks.7.0.in_layers.0.bias", "output_blocks.7.0.in_layers.2.weight", "output_blocks.7.0.in_layers.2.bias", "output_blocks.7.0.emb_layers.1.weight", "output_blocks.7.0.emb_layers.1.bias", "output_blocks.7.0.out_layers.0.weight", 
"output_blocks.7.0.out_layers.0.bias", "output_blocks.7.0.out_layers.3.weight", "output_blocks.7.0.out_layers.3.bias", "output_blocks.7.0.skip_connection.weight", "output_blocks.7.0.skip_connection.bias", "output_blocks.7.1.conv.weight", "output_blocks.7.1.conv.bias", "output_blocks.8.0.in_layers.0.weight", "output_blocks.8.0.in_layers.0.bias", "output_blocks.8.0.in_layers.2.weight", "output_blocks.8.0.in_layers.2.bias", "output_blocks.8.0.emb_layers.1.weight", "output_blocks.8.0.emb_layers.1.bias", "output_blocks.8.0.out_layers.0.weight", "output_blocks.8.0.out_layers.0.bias", "output_blocks.8.0.out_layers.3.weight", "output_blocks.8.0.out_layers.3.bias", "output_blocks.8.0.skip_connection.weight", "output_blocks.8.0.skip_connection.bias", "output_blocks.9.0.in_layers.0.weight", "output_blocks.9.0.in_layers.0.bias", "output_blocks.9.0.in_layers.2.weight", "output_blocks.9.0.in_layers.2.bias", "output_blocks.9.0.emb_layers.1.weight", "output_blocks.9.0.emb_layers.1.bias", "output_blocks.9.0.out_layers.0.weight", "output_blocks.9.0.out_layers.0.bias", "output_blocks.9.0.out_layers.3.weight", "output_blocks.9.0.out_layers.3.bias", "output_blocks.9.0.skip_connection.weight", "output_blocks.9.0.skip_connection.bias", "output_blocks.10.0.in_layers.0.weight", "output_blocks.10.0.in_layers.0.bias", "output_blocks.10.0.in_layers.2.weight", "output_blocks.10.0.in_layers.2.bias", "output_blocks.10.0.emb_layers.1.weight", "output_blocks.10.0.emb_layers.1.bias", "output_blocks.10.0.out_layers.0.weight", "output_blocks.10.0.out_layers.0.bias", "output_blocks.10.0.out_layers.3.weight", "output_blocks.10.0.out_layers.3.bias", "output_blocks.10.0.skip_connection.weight", "output_blocks.10.0.skip_connection.bias", "output_blocks.11.0.in_layers.0.weight", "output_blocks.11.0.in_layers.0.bias", "output_blocks.11.0.in_layers.2.weight", "output_blocks.11.0.in_layers.2.bias", "output_blocks.11.0.emb_layers.1.weight", "output_blocks.11.0.emb_layers.1.bias", 
"output_blocks.11.0.out_layers.0.weight", "output_blocks.11.0.out_layers.0.bias", "output_blocks.11.0.out_layers.3.weight", "output_blocks.11.0.out_layers.3.bias", "output_blocks.11.0.skip_connection.weight", "output_blocks.11.0.skip_connection.bias", "output_blocks.11.1.conv.weight", "output_blocks.11.1.conv.bias", "output_blocks.12.0.in_layers.0.weight", "output_blocks.12.0.in_layers.0.bias", "output_blocks.12.0.in_layers.2.weight", "output_blocks.12.0.in_layers.2.bias", "output_blocks.12.0.emb_layers.1.weight", "output_blocks.12.0.emb_layers.1.bias", "output_blocks.12.0.out_layers.0.weight", "output_blocks.12.0.out_layers.0.bias", "output_blocks.12.0.out_layers.3.weight", "output_blocks.12.0.out_layers.3.bias", "output_blocks.12.0.skip_connection.weight", "output_blocks.12.0.skip_connection.bias", "output_blocks.13.0.in_layers.0.weight", "output_blocks.13.0.in_layers.0.bias", "output_blocks.13.0.in_layers.2.weight", "output_blocks.13.0.in_layers.2.bias", "output_blocks.13.0.emb_layers.1.weight", "output_blocks.13.0.emb_layers.1.bias", "output_blocks.13.0.out_layers.0.weight", "output_blocks.13.0.out_layers.0.bias", "output_blocks.13.0.out_layers.3.weight", "output_blocks.13.0.out_layers.3.bias", "output_blocks.13.0.skip_connection.weight", "output_blocks.13.0.skip_connection.bias", "output_blocks.14.0.in_layers.0.weight", "output_blocks.14.0.in_layers.0.bias", "output_blocks.14.0.in_layers.2.weight", "output_blocks.14.0.in_layers.2.bias", "output_blocks.14.0.emb_layers.1.weight", "output_blocks.14.0.emb_layers.1.bias", "output_blocks.14.0.out_layers.0.weight", "output_blocks.14.0.out_layers.0.bias", "output_blocks.14.0.out_layers.3.weight", "output_blocks.14.0.out_layers.3.bias", "output_blocks.14.0.skip_connection.weight", "output_blocks.14.0.skip_connection.bias", "output_blocks.15.0.in_layers.0.weight", "output_blocks.15.0.in_layers.0.bias", "output_blocks.15.0.in_layers.2.weight", "output_blocks.15.0.in_layers.2.bias", 
"output_blocks.15.0.emb_layers.1.weight", "output_blocks.15.0.emb_layers.1.bias", "output_blocks.15.0.out_layers.0.weight", "output_blocks.15.0.out_layers.0.bias", "output_blocks.15.0.out_layers.3.weight", "output_blocks.15.0.out_layers.3.bias", "output_blocks.15.0.skip_connection.weight", "output_blocks.15.0.skip_connection.bias", "out.0.weight", "out.0.bias", "out.2.weight", "out.2.bias", "pen_state_out.0.weight", "pen_state_out.0.bias".
Unexpected key(s) in state_dict: "state", "param_groups".
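The unexpected keys `state` and `param_groups` are exactly the top-level keys of a PyTorch optimizer `state_dict()`, which suggests the file passed to `--model_path` was an optimizer checkpoint rather than model weights. Training loops in the improved-diffusion style write `model*.pt`, `ema_<rate>_*.pt`, and `opt*.pt` files, and the command in the DDIM issue above uses an `ema_...pt` file. A small, hypothetical helper to check what a loaded checkpoint contains before calling `load_state_dict`:

```python
def describe_checkpoint(state):
    """Guess which kind of PyTorch state_dict a loaded checkpoint holds."""
    keys = set(state)
    if {"state", "param_groups"} <= keys:
        return "optimizer"   # layout of torch.optim.Optimizer.state_dict()
    if any(isinstance(k, str) and k.endswith((".weight", ".bias")) for k in keys):
        return "model"       # nn.Module.state_dict() maps parameter names to tensors
    return "unknown"

# Real use (needs PyTorch):
#   import torch
#   print(describe_checkpoint(torch.load(args.model_path, map_location="cpu")))
print(describe_checkpoint({"state": {}, "param_groups": []}))   # optimizer
print(describe_checkpoint({"time_embed.0.weight": None}))       # model
```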

The sampling code does not stop

The sampling part has:
while len(all_images) * args.batch_size < args.num_samples:
However, all_images is never updated, so the code runs forever.
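The loop only terminates if each generated batch is appended to `all_images`. A minimal sketch of the intended pattern, with a hypothetical `sample_batch` standing in for the diffusion sampler; in OpenAI's improved-diffusion sampling script, batches are first gathered from all distributed ranks before being appended, which is omitted here.

```python
import numpy as np

_rng = np.random.default_rng(0)

def sample_batch(batch_size, seq_len=96):
    """Hypothetical stand-in for the diffusion sampler: (batch, seq_len, 4)."""
    return _rng.standard_normal((batch_size, seq_len, 4))

def collect_samples(num_samples, batch_size):
    all_images = []
    while len(all_images) * batch_size < num_samples:
        all_images.append(sample_batch(batch_size))  # the update the loop needs
    arr = np.concatenate(all_images, axis=0)
    return arr[:num_samples]                          # drop the last batch's overshoot

samples = collect_samples(num_samples=50, batch_size=16)
print(samples.shape)  # (50, 96, 4)
```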

Some questions about model training and sampling

With num_res_blocks=4 and num_heads=8, my loss is around 0.08 and no longer decreases, so I'm wondering whether I can stop training. However, the FID is 30 (I used the apple class of the QuickDraw dataset). What can I do to lower the FID?

| grad_norm | 0.114 |
| loss | 0.0757 |
| loss_q0 | 0.227 |
| loss_q1 | 0.0749 |
| loss_q2 | 0.0142 |
| loss_q3 | 0.00181 |
| mse | 0.0747 |
| mse_q0 | 0.226 |
| mse_q1 | 0.074 |
| mse_q2 | 0.0132 |
| mse_q3 | 0.000843 |
| pen_state | 0.0965 |
| pen_state_q0 | 0.0964 |
| pen_state_q1 | 0.0966 |
| pen_state_q2 | 0.0966 |
| pen_state_q3 | 0.0965 |

evaluator

The evaluator needs some instructions on how to use it.

I'm testing the evaluator but the read_activations function does not accept the reference dataset.
python evaluations/evaluator.py datasets/airplane.npz datasets/airplane.npz

The above command throws the below error:
ValueError: missing arr_0 in npz file

Can you also advise on the right format for samples? It's not very clear from the code.
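The error indicates that the evaluator reads an array stored under the key `arr_0`, which is the name NumPy gives the first positional array passed to `np.savez`. Dataset files such as `airplane.npz` store their splits under named keys (`train`, `valid`, `test`), so they cannot be passed unchanged. OpenAI's improved-diffusion sampling script, for instance, saves its samples with a positional `np.savez(out_path, arr)`. A sketch of re-saving an array under `arr_0`; what the array must contain (for example rasterized images rather than stroke-3 vectors) depends on the evaluator and is not confirmed here, so the NHWC image batch below is only a placeholder.

```python
import os
import tempfile
import numpy as np

def save_for_evaluator(array, path):
    """Store `array` as the first positional entry, which NumPy names arr_0."""
    np.savez(path, array)

path = os.path.join(tempfile.mkdtemp(), "samples.npz")
save_for_evaluator(np.zeros((4, 64, 64, 3), dtype=np.uint8), path)  # placeholder images
with np.load(path) as data:
    print(data.files)  # ['arr_0']
```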
