
diad's People

Contributors

lewandofskee


diad's Issues

How to change the number of epochs and other parameters?

Dear authors,
Thank you for your interesting work.
You state that you train for 1000 epochs (paper, page 6). If I want to change the number of epochs or other parameters, how can I do that? Should I refer to the ControlNet implementation?

Could you please comment on this?

Thank you!
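Assuming train.py builds a pytorch_lightning.Trainer (as the trainer call quoted in the out-of-memory issue suggests), the epoch count is governed by the Trainer's max_epochs argument. The sketch below shows one way to expose it on the command line. The helper name and the --max-epochs/--val-every flags are hypothetical, not options train.py is known to accept; max_epochs and check_val_every_n_epoch are real Trainer arguments.

```python
import argparse


def build_trainer_kwargs(argv=None):
    """Hypothetical helper: collect Trainer settings from the command line.

    The flag names are illustrative only; they are not flags that
    DiAD's train.py is known to accept.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-epochs", type=int, default=1000)
    parser.add_argument("--val-every", type=int, default=25)
    args = parser.parse_args(argv)
    # These dictionary keys are real pytorch_lightning.Trainer arguments.
    return {
        "max_epochs": args.max_epochs,
        "check_val_every_n_epoch": args.val_every,
    }


if __name__ == "__main__":
    kwargs = build_trainer_kwargs(["--max-epochs", "300"])
    print(kwargs)
    # In train.py one would then pass them on: pl.Trainer(**kwargs, ...)
```

Keeping the values in a plain dict makes it easy to merge them with the Trainer arguments already present in the script.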

Out of memory

On an NVIDIA GeForce RTX 3090 with the batch size set to 2, train.py cannot run the model. I modified the code to trainer = pl.Trainer(gpus=[2,3,4], precision=32, callbacks=[logger, ckpt_callback_val_loss], accumulate_grad_batches=4, check_val_every_n_epoch=25), but it still runs out of memory. Where might the problem be?
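A quick arithmetic check of the settings above, assuming Lightning replicates the model on every listed GPU (DDP-style data parallelism): more GPUs and accumulate_grad_batches raise the effective batch size, but each GPU still holds a full model copy and runs forward/backward on its own per-GPU batch, so peak memory per GPU does not drop. The usual levers are a smaller per-GPU batch or mixed precision (precision=16 in Lightning 1.x). The function name is illustrative.

```python
def effective_batch_size(per_gpu_batch, num_gpus, accumulate_grad_batches=1):
    """Number of samples whose gradients are combined into one optimizer step."""
    return per_gpu_batch * num_gpus * accumulate_grad_batches


# Settings from the issue: batch size 2, gpus=[2, 3, 4], accumulate_grad_batches=4.
per_gpu_batch = 2
print("effective batch:", effective_batch_size(per_gpu_batch, 3, 4))  # 24
# Peak activation memory on each GPU is still driven by per_gpu_batch (2):
# every GPU keeps a full model replica and runs forward/backward on its own
# 2 samples, so adding GPUs or accumulation steps does not lower it.
print("samples per forward pass on each GPU:", per_gpu_batch)
```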

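A small question about Equation 6

Hello, c_i appears in Equation 6 of the paper. What exactly is c_i in this paper? Is it the output that the SG network passes to the denoising network?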

A question about fine-tuning the autoencoder

Hi,
When I fine-tune the autoencoder on my own dataset, I find that the total number of steps is divided by 8 because of the config setting here.

I'd like to know how many steps you used for fine-tuning, because I found that in this config it takes 50k (5w) steps before the additional LPIPSWithDiscriminator loss takes effect.

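Why choose RD4AD as the baseline?

RD4AD is a single-class AD method. Why did you choose it as the baseline for the multi-class setting rather than some other single-class method?
Did you also try other single-class methods and leave them out because they performed poorly?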

CNNs module

[image]

Hello, I could not find this CNNs module in the code. Could you explain it for me? Thank you very much.

ModuleNotFoundError: No module named 'taming.modules.losses'?

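First question: at runtime this import error is raised, and I found that the taming folder in the source code does not contain these files. How can I fix this?
Second question: in step 3, fine-tuning the autoencoders, I downloaded the pretrained model kl-f8.zip, which is supposed to be moved to ./models/autoencoders.ckpt. Should I unzip kl-f8.zip, rename the extracted file to autoencoders.ckpt, and move it into ./models?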
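A minimal sketch of the unzip-and-rename step asked about in the second question, assuming kl-f8.zip contains a single .ckpt file (the member name is not checked against the real archive). The function name is hypothetical; the default destination mirrors the ./models/autoencoders.ckpt path quoted in the question.

```python
import shutil
import zipfile
from pathlib import Path


def install_autoencoder_ckpt(zip_path, dest="./models/autoencoders.ckpt"):
    """Extract the first *.ckpt member of zip_path and save it as dest.

    Assumes the archive holds one checkpoint file; raises if none is found.
    """
    dest = Path(dest)
    dest.parent.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as zf:
        members = [n for n in zf.namelist() if n.endswith(".ckpt")]
        if not members:
            raise FileNotFoundError(f"no .ckpt file inside {zip_path}")
        # Stream the member straight to the target name, no temp directory needed.
        with zf.open(members[0]) as src, open(dest, "wb") as dst:
            shutil.copyfileobj(src, dst)
    return dest
```

Usage: install_autoencoder_ckpt("kl-f8.zip") writes the extracted checkpoint to ./models/autoencoders.ckpt.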

Code error

[screenshot of the error]

Running python finetune_autoencoder.py raises the error shown above.
Could you offer some suggestions?

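Question about the paper's contribution

The paper's contribution is adding a Semantic-Guided Network whose features are injected into three decoder layers of the UNet. How is this different from simply adding compute to improve accuracy? The loss is still the unmodified mean squared error loss. There doesn't seem to be any real novelty.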

how

How long does it take to fine-tune the VAE on VisA? I am trying to fine-tune the VAE for the multi-class setting, but one epoch takes two hours on an A6000. Is that normal?

Checkpoints

Thanks for your nice work.
Could you please release the checkpoints to reproduce the results?

When I run python build_model.py I get the error below. It seems to be caused by openai/clip-vit-large-patch14 failing to download. Could you help me solve it?

Traceback (most recent call last):
File "/homec/ssli/DiAD/build_model.py", line 27, in <module>
model = create_model(config_path='/homec/ssli/DiAD/models/diad.yaml')
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/homec/ssli/DiAD/sgn/model.py", line 26, in create_model
model = instantiate_from_config(config.model).cpu()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/homec/ssli/DiAD/ldm/util.py", line 79, in instantiate_from_config
return get_obj_from_str(config["target"])(**config.get("params", dict()))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/homec/ssli/DiAD/sgn/sgn.py", line 369, in __init__
super().__init__(*args, **kwargs)
File "/homec/ssli/DiAD/ldm/models/diffusion/ddpm.py", line 603, in __init__
self.instantiate_cond_stage(cond_stage_config)
File "/homec/ssli/DiAD/ldm/models/diffusion/ddpm.py", line 670, in instantiate_cond_stage
model = instantiate_from_config(config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/homec/ssli/DiAD/ldm/util.py", line 79, in instantiate_from_config
return get_obj_from_str(config["target"])(**config.get("params", dict()))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/homec/ssli/DiAD/ldm/modules/encoders/modules.py", line 99, in __init__
self.tokenizer = CLIPTokenizer.from_pretrained(version)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/homec/ssli/lib/python3.11/site-packages/transformers/tokenization_utils_base.py", line 2073, in from_pretrained
raise EnvironmentError(
OSError: Can't load tokenizer for 'openai/clip-vit-large-patch14'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure 'openai/clip-vit-large-patch14' is the correct path to a directory containing all relevant files for a CLIPTokenizer tokenizer.
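The traceback shows CLIPTokenizer.from_pretrained(version) in ldm/modules/encoders/modules.py. In transformers, from_pretrained also accepts a path to a local directory, so one workaround is to download openai/clip-vit-large-patch14 on a machine with internet access, copy the folder over, and pass its path as version. The helper below is a hypothetical sketch that prefers such a local copy; the list of required files is an assumption about what a complete copy contains.

```python
import os

HUB_ID = "openai/clip-vit-large-patch14"
# Files a local copy is assumed to need: tokenizer vocab/merges plus model config.
REQUIRED_FILES = ("vocab.json", "merges.txt", "config.json")


def resolve_clip_source(local_dir, hub_id=HUB_ID):
    """Return local_dir if it looks like a complete local copy, else hub_id."""
    if local_dir and all(
        os.path.isfile(os.path.join(local_dir, name)) for name in REQUIRED_FILES
    ):
        return local_dir
    return hub_id


# Usage (hypothetical path), e.g. where the CLIP embedder receives `version`:
# version = resolve_clip_source("./models/clip-vit-large-patch14")
# tokenizer = CLIPTokenizer.from_pretrained(version)
```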

test.py error

Running python test.py --resume_path ./val_ckpt/epoch=299-step=22799.ckpt shows:
Traceback (most recent call last):
File "test.py", line 49, in <module>
dataset = MVTecDataset('test')
TypeError: __init__() missing 1 required positional argument: 'root'

The results of ddpm and ldm

I have a question: are your DDPM and LDM results from models re-trained on these anomaly detection datasets? I don't see this stated in the paper. Thank you!

Connection to huggingface.co timed out

Traceback (most recent call last):
File "/root/miniconda3/lib/python3.8/site-packages/urllib3/connection.py", line 169, in _new_conn
conn = connection.create_connection(
File "/root/miniconda3/lib/python3.8/site-packages/urllib3/util/connection.py", line 96, in create_connection
raise err
File "/root/miniconda3/lib/python3.8/site-packages/urllib3/util/connection.py", line 86, in create_connection
sock.connect(sa)
socket.timeout: timed out

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/root/miniconda3/lib/python3.8/site-packages/urllib3/connectionpool.py", line 699, in urlopen
httplib_response = self._make_request(
File "/root/miniconda3/lib/python3.8/site-packages/urllib3/connectionpool.py", line 382, in _make_request
self._validate_conn(conn)
File "/root/miniconda3/lib/python3.8/site-packages/urllib3/connectionpool.py", line 1010, in _validate_conn
conn.connect()
File "/root/miniconda3/lib/python3.8/site-packages/urllib3/connection.py", line 353, in connect
conn = self._new_conn()
File "/root/miniconda3/lib/python3.8/site-packages/urllib3/connection.py", line 174, in _new_conn
raise ConnectTimeoutError(
urllib3.exceptions.ConnectTimeoutError: (<urllib3.connection.HTTPSConnection object at 0x7fd5cfa11430>, 'Connection to huggingface.co timed out. (connect timeout=10)')

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/root/miniconda3/lib/python3.8/site-packages/requests/adapters.py", line 439, in send
resp = conn.urlopen(
File "/root/miniconda3/lib/python3.8/site-packages/urllib3/connectionpool.py", line 755, in urlopen
retries = retries.increment(
File "/root/miniconda3/lib/python3.8/site-packages/urllib3/util/retry.py", line 574, in increment
raise MaxRetryError(_pool, url, error or ResponseError(cause))
urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /timm/resnet50.a1_in1k/resolve/main/pytorch_model.bin (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7fd5cfa11430>, 'Connection to huggingface.co timed out. (connect timeout=10)'))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/root/miniconda3/lib/python3.8/site-packages/huggingface_hub/file_download.py", line 1238, in hf_hub_download
metadata = get_hf_file_metadata(
File "/root/miniconda3/lib/python3.8/site-packages/huggingface_hub/utils/_validators.py", line 118, in _inner_fn
return fn(*args, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/huggingface_hub/file_download.py", line 1631, in get_hf_file_metadata
r = _request_wrapper(
File "/root/miniconda3/lib/python3.8/site-packages/huggingface_hub/file_download.py", line 385, in _request_wrapper
response = _request_wrapper(
File "/root/miniconda3/lib/python3.8/site-packages/huggingface_hub/file_download.py", line 408, in _request_wrapper
response = get_session().request(method=method, url=url, **params)
File "/root/miniconda3/lib/python3.8/site-packages/requests/sessions.py", line 542, in request
resp = self.send(prep, **send_kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/requests/sessions.py", line 655, in send
r = adapter.send(request, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/huggingface_hub/utils/_http.py", line 67, in send
return super().send(request, *args, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/requests/adapters.py", line 504, in send
raise ConnectTimeout(e, request=request)
requests.exceptions.ConnectTimeout: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /timm/resnet50.a1_in1k/resolve/main/pytorch_model.bin (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7fd5cfa11430>, 'Connection to huggingface.co timed out. (connect timeout=10)'))"), '(Request ID: b3552eef-0753-455a-a7b7-8dd61af412af)')

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "train.py", line 47, in <module>
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)
File "/root/miniconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in fit
self._call_and_handle_interrupt(
File "/root/miniconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/root/miniconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1193, in _run
self._dispatch()
File "/root/miniconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1272, in _dispatch
self.training_type_plugin.start_training(self)
File "/root/miniconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/root/miniconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1282, in run_stage
return self._run_train()
File "/root/miniconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1304, in _run_train
self._run_sanity_check(self.lightning_module)
File "/root/miniconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1368, in _run_sanity_check
self._evaluation_loop.run()
File "/root/miniconda3/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 140, in run
self.on_run_start(*args, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 96, in on_run_start
self._on_evaluation_epoch_start()
File "/root/miniconda3/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 214, in _on_evaluation_epoch_start
self.trainer.call_hook("on_validation_epoch_start", *args, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1483, in call_hook
output = model_fx(*args, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/root/DiAD-main/ldm/models/diffusion/ddpm.py", line 484, in on_validation_epoch_start
pretrained_model = timm.create_model("resnet50", pretrained=True, features_only=True)
File "/root/miniconda3/lib/python3.8/site-packages/timm/models/_factory.py", line 117, in create_model
model = create_fn(
File "/root/miniconda3/lib/python3.8/site-packages/timm/models/resnet.py", line 1322, in resnet50
return _create_resnet('resnet50', pretrained, **dict(model_args, **kwargs))
File "/root/miniconda3/lib/python3.8/site-packages/timm/models/resnet.py", line 584, in _create_resnet
return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/timm/models/_builder.py", line 397, in build_model_with_cfg
load_pretrained(
File "/root/miniconda3/lib/python3.8/site-packages/timm/models/_builder.py", line 190, in load_pretrained
state_dict = load_state_dict_from_hf(pretrained_loc)
File "/root/miniconda3/lib/python3.8/site-packages/timm/models/_hub.py", line 188, in load_state_dict_from_hf
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
File "/root/miniconda3/lib/python3.8/site-packages/huggingface_hub/utils/_validators.py", line 118, in _inner_fn
return fn(*args, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/huggingface_hub/file_download.py", line 1371, in hf_hub_download
raise LocalEntryNotFoundError(
huggingface_hub.utils._errors.LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.
How can I solve this problem?

Rec?

I copied your code, but the reconstructed (Rec) images I get are still broken, not repaired as well as in the paper. Why? Is it because my dataset has only 1000 images, which isn't enough to train the model?

Controlnet?

Why not compare with ControlNet? The proposed model is based on ControlNet; even the core code comes from it.

Fine-tuning the autoencoder

In build_model.py, is path_input the autoencoders.ckpt produced by finetune_autoencoder.py, renamed to mvtecad_fs.ckpt?

    # Stable Diffusion v1.5
    input_path = './models/v1-5-pruned.ckpt'
    # Output DiAD model
    output_path = './models/diad.ckpt'
    # Finetuned autoencoders
    path_input = './models/mvtecad_fs.ckpt'

Sanity Checking DataLoader

[image]
[image]

Why are only two images validated during training? Only 000.jpg and 001.jpg of the bottle class are validated. Is that right?
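The "Sanity Checking DataLoader" stage is PyTorch Lightning's pre-training sanity check, which runs num_sanity_val_steps validation batches (default 2; 0 disables it, -1 runs the whole loader). Assuming the test loader uses batch size 1 and is not shuffled (an assumption about DiAD's test.json and loader), that is exactly the first two listed images. The helper is a hypothetical illustration of the count.

```python
import math


def sanity_check_images(num_images, batch_size, num_sanity_val_steps=2):
    """Images seen by Lightning's sanity check before training starts.

    num_sanity_val_steps=-1 runs the whole validation loader; 0 disables it.
    """
    num_batches = math.ceil(num_images / batch_size)
    if num_sanity_val_steps < 0:
        steps = num_batches
    else:
        steps = min(num_sanity_val_steps, num_batches)
    # Only the loader's final batch can be short, so cap at the dataset size.
    return min(steps * batch_size, num_images)
```

The full validation set is still evaluated at the regular validation epochs; the sanity check only verifies that the loop runs.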

Code open source issues

Hello! I am very interested in your work. May I ask when the code will be open-sourced?

Correction

In finetune_autoencoder.py, the line train_dataset = MVTecDataset('train') is missing the second (path) argument. Should the second argument be 'training/MVTec-AD/mvtec_anomaly_detection', i.e., should the line read train_dataset = MVTecDataset('train', 'training/MVTec-AD/mvtec_anomaly_detection')?

Batch size warning and NaN problem

1.
UserWarning: Trying to infer the batch_size from an ambiguous collection. The batch size we found is 12. To avoid any miscalculations, use self.log(..., batch_size=batch_size).
This batch size keeps changing. I tried a few things but could not get rid of this warning.
2.
[2024-01-05 19:31:47,424][ eval_helper.py][line: 333][ INFO]
| clsname | pixel | max |
| bottle | 0.66969 | nan |
| mean | 0.66969 | nan |
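Why is max NaN here, and what does max mean? (I am using the MVTec-AD dataset.)
3. How large is the model after training, and roughly how long does inference on one image take on an RTX 3060?
I haven't finished training yet, and because of points 1 and 2 I'm not sure training will succeed, but I'd like to know the per-image inference time.
Could you help answer these? Thank you very much!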
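One plausible cause of the NaN, assuming "max" is an image-level AUROC computed from the maximum of each image's anomaly map (the eval_helper.py log format resembles UniAD's evaluation code, which reports pixel, max and mean this way; this is an assumption): AUROC is undefined when all evaluated images share the same label, for example two defective images, while pixel-level AUROC still works because each defect mask contains both normal and anomalous pixels. The sketch computes AUROC as a pairwise ranking probability and returns NaN in that undefined case.

```python
import math


def auroc(labels, scores):
    """AUROC as P(score of random positive > score of random negative), ties count 1/2.

    Returns NaN when labels contain only one class, where AUROC is undefined.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        return math.nan
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Image-level score = max of each anomaly map (assumed meaning of "max").
anomaly_maps = [[0.1, 0.9, 0.3], [0.2, 0.4, 0.8]]  # two toy "images"
image_scores = [max(m) for m in anomaly_maps]      # [0.9, 0.8]
print(auroc([1, 1], image_scores))  # both defective -> nan
print(auroc([1, 0], image_scores))  # mixed labels -> 1.0
```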

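Training on my own dataset

After training on MVTec-AD, I prepared my own dataset (road surfaces), converted it to the MVTec-AD format, wrote train.json and test.json, changed the dataset path, and deleted npz_result. But after 300 epochs of training, all the xxx-sample.jpg files generated in log_image are images from the previous dataset, MVTec-AD.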
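One thing worth ruling out is that the meta files still list MVTec-AD entries. The hypothetical helper below counts entries per class in a JSON-lines meta file; it assumes one JSON object per line with a "clsname" key, which is an assumption about the file format and should be adjusted to the keys the dataloader actually reads.

```python
import json
from collections import Counter


def summarize_meta(json_path):
    """Count entries per clsname in a JSON-lines meta file.

    Assumes one JSON object per line with a "clsname" key; blank lines are skipped.
    """
    counts = Counter()
    with open(json_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                counts[json.loads(line).get("clsname", "<missing>")] += 1
    return counts
```

If print(summarize_meta("train.json")) shows MVTec-AD categories such as bottle, the file still points at the old dataset.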

Multi-class in MVTec-AD why use one-class?

Dear authors,
Thank you for your interesting work.
I see the evaluation results on the MVTec-AD dataset. You seem to evaluate in the one-class setting (a separate model for each class) rather than the multi-class setting (all classes with a single trained model).
Could you please comment on this?

Thank you!

The input images don't seem to be placed on the GPU

Traceback (most recent call last):
File "train.py", line 53, in <module>
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in fit
self._call_and_handle_interrupt(
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1193, in _run
self._dispatch()
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1272, in _dispatch
self.training_type_plugin.start_training(self)
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1282, in run_stage
return self._run_train()
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1304, in _run_train
self._run_sanity_check(self.lightning_module)
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1368, in _run_sanity_check
self._evaluation_loop.run()
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 109, in advance
dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 123, in advance
output = self._evaluation_step(batch, batch_idx, dataloader_idx)
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 215, in _evaluation_step
output = self.trainer.accelerator.validation_step(step_kwargs)
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 236, in validation_step
return self.training_type_plugin.validation_step(*step_kwargs.values())
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 219, in validation_step
return self.model.validation_step(*args, **kwargs)
File "/media/ext_disk/xunchangjie/DiAD-main/ldm/models/diffusion/ddpm.py", line 472, in validation_step
input_features = self.pretrained_model(input_img)
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/timm/models/_features.py", line 275, in forward
return list(self._collect(x).values())
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/timm/models/_features.py", line 231, in _collect
x = module(x)
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 463, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/home/xunchangjie/anaconda3/envs/diad/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
Above is the error message. I tried many methods but could not get the input onto CUDA. Is there a way to do this?

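Questions about the code

(1) The MVTec-AD dataset is located in './training/MVTec-AD/mvtec_anomaly_detection/', but finetune_autoencoder.py uses data_path='./data/mvtecad', so running it fails with an empty-object error. After I changed it to data_path='./training/MVTec-AD/mvtec_anomaly_detection/', it runs.
(2) Where can I change the number of training epochs?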

Env problem

We noticed that environment.yaml specifies pytorch=1.12.1, while xformers==0.0.18 requires pytorch=2.0.
Which version of PyTorch should I choose?

About the "multi-class setting" mentioned in the paper

Thank you for your selfless sharing!
I saw the results in Tables 1 and 3 of your paper, which mention a "multi-class setting". What does it mean? Are MVTec's 15 categories mixed together for training and then mixed together for testing? RD4AD seems to have only "One-Class Novelty Detection" and no "multi-class setting".
Looking forward to your reply

RuntimeError: Error(s) in loading state_dict for FeatureListNet: Unexpected key(s) in state_dict: "fc.weight", "fc.bias".

Traceback (most recent call last):
  File "train.py", line 50, in <module>
    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in fit
    self._call_and_handle_interrupt(
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1193, in _run
    self._dispatch()
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1272, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1282, in run_stage
    return self._run_train()
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1304, in _run_train
    self._run_sanity_check(self.lightning_module)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1368, in _run_sanity_check
    self._evaluation_loop.run()
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 140, in run
    self.on_run_start(*args, **kwargs)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 96, in on_run_start
    self._on_evaluation_epoch_start()
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 214, in _on_evaluation_epoch_start
    self.trainer.call_hook("on_validation_epoch_start", *args, **kwargs)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1483, in call_hook
    output = model_fx(*args, **kwargs)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/DiAD/ldm/models/diffusion/ddpm.py", line 484, in on_validation_epoch_start
    pretrained_model = timm.create_model("resnet50", pretrained=False, features_only=True,checkpoint_path="models/resnet50_a1_0-14fe96d1.pth")
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/timm/models/factory.py", line 74, in create_model
    load_checkpoint(model, checkpoint_path)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/timm/models/helpers.py", line 75, in load_checkpoint
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for FeatureListNet:
        Unexpected key(s) in state_dict: "fc.weight", "fc.bias". 

After I run build_model.py, why does train.py fail with RuntimeError: Error(s) in loading state_dict ... Unexpected key(s) in state_dict: "fc.weight", "fc.bias"? I checked and set strict to False:

model.load_state_dict(x, False)

but isn't the original code already False?
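In the traceback above, the failing call is timm's own load_checkpoint (reached through timm.create_model(..., checkpoint_path=...)), not the repository's model.load_state_dict(x, False); in the timm version shown, that helper's strict argument defaults to True. The checkpoint is a full ResNet-50 classifier, and timm's features_only wrapper (FeatureListNet) has no fc layer, so fc.weight and fc.bias are unexpected. One workaround, sketched here and not the repository's fix, is to load the weights yourself and drop the classifier entries first. The helper name is hypothetical.

```python
def strip_classifier_keys(state_dict, prefixes=("fc.",)):
    """Return a copy of state_dict without classifier-head entries.

    FeatureListNet (timm's features_only wrapper) has no `fc` layer, so the
    fc.weight / fc.bias entries of a full ResNet-50 checkpoint are
    "unexpected keys" under strict loading.
    """
    return {k: v for k, v in state_dict.items() if not k.startswith(prefixes)}


# Typical use with torch/timm (not executed here; path taken from the traceback):
# sd = torch.load("models/resnet50_a1_0-14fe96d1.pth", map_location="cpu")
# model = timm.create_model("resnet50", pretrained=False, features_only=True)
# model.load_state_dict(strip_classifier_keys(sd))
```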

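A typo in the paper?

Hello!
Is there a typo in Equation 7 of the paper? Was M_SG(E_SG) written as M_SG(E_SD)?
Also, at the blue Add operation below SGDB4 in Figure 2, is the output of SDEB4 also added, together with SGDB4, to the output of SDDB4? This does not seem to be described in the paper.
I look forward to your answer. Thank you!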

When I ran test.py, I was told that I had not passed the root path, and after I added it, the path was still wrong. Please help me check it.

Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 768 and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 768 and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 768 and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 768 and using 8 heads.
Loaded model config from [models/diad.yaml]
Loaded state_dict from [./val_ckpt/16719.ckpt]
Traceback (most recent call last):
File "test.py", line 53, in <module>
for input in dataloader:
File "/root/miniconda3/envs/diad/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 634, in __next__
data = self._next_data()
File "/root/miniconda3/envs/diad/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1346, in _next_data
return self._process_data(data)
File "/root/miniconda3/envs/diad/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1372, in _process_data
data.reraise()
File "/root/miniconda3/envs/diad/lib/python3.8/site-packages/torch/_utils.py", line 644, in reraise
raise exception
cv2.error: Caught error in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/root/miniconda3/envs/diad/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
data = fetcher.fetch(index)
File "/root/miniconda3/envs/diad/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/root/miniconda3/envs/diad/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/workspace/mvtecad_dataloader.py", line 65, in __getitem__
source = cv2.cvtColor(source, 4)
cv2.error: OpenCV(4.5.1) /tmp/pip-req-build-jhawztrk/opencv/modules/imgproc/src/color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'
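cv2.imread returns None instead of raising when it cannot read a file, and cv2.cvtColor then fails with the !_src.empty() assertion seen above. This usually means the path built from the root argument and a meta-file entry points at a file that does not exist. The hypothetical helper below checks every listed path before running test.py; it assumes a JSON-lines meta file with a "filename" key, and the key names are assumptions to be matched to what the dataloader reads.

```python
import json
import os


def find_missing_images(json_path, root, keys=("filename",)):
    """List root-joined paths from a JSON-lines meta file that do not exist on disk.

    The key names are assumptions; pass the keys the dataloader actually uses.
    """
    missing = []
    with open(json_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            entry = json.loads(line)
            for key in keys:
                rel = entry.get(key)
                if rel and not os.path.isfile(os.path.join(root, rel)):
                    missing.append(os.path.join(root, rel))
    return missing
```

An empty result means every listed image exists under root; otherwise the first few entries usually reveal whether root is one directory level off.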
