kaiyangzhou / dassl.pytorch
A PyTorch toolbox for domain generalization, domain adaptation and semi-supervised learning.
License: MIT License
RuntimeError: CUDA out of memory. Tried to allocate 50.00 MiB (GPU 0; 11.91 GiB total capacity; 11.04 GiB already allocated; 43.62 MiB free; 11.10 GiB reserved in total by PyTorch)
I used two 12 GB Titan Xp GPUs to run the code. I wonder whether the problem is in my own code or in my devices.
Thanks.
Hi there, when installing this repository with pip install -r requirements.txt (where the requirements.txt file contains git+https://github.com/KaiyangZhou/Dassl.pytorch.git), the import numpy as np in setup.py throws a ModuleNotFoundError, because numpy is not yet installed at that moment. Is it possible to remove the import numpy as np statement and def numpy_include(): ... from setup.py so that this repository can be installed automatically with pip?
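One way to do this, sketched under the assumption that the helper is still needed at all, is to defer the import into the function body so that pip can execute setup.py before numpy is present. The function name numpy_include follows the post; the body is illustrative. An alternative is to declare numpy as a build dependency under [build-system] requires in a pyproject.toml (PEP 518), so that pip installs it before running setup.py.

```python
import importlib


def numpy_include():
    """Return numpy's C header directory.

    numpy is imported inside the function, so merely loading setup.py no
    longer requires numpy to be installed; only calling this helper does.
    """
    np = importlib.import_module("numpy")
    return np.get_include()
```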
Hi, @KaiyangZhou. This is a very good toolbox. Just out of curiosity, will there be an implementation of CTAugment in the future?
Hello, I am having trouble training FixMatch with custom datasets.
I got this error:
input_x2 = batch_x["img2"]
KeyError: 'img2'
I set K_TRANSFORMS to 2 in the config/trainer files, but when DatasetWrapper.__getitem__() is called, it returns only one image.
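The error suggests the trainer expects one dictionary entry per transform (img for the first, img2 for the second), which is a different thing from repeating a single transform K times. The framework-agnostic sketch below illustrates that distinction; MultiViewWrapper and its key scheme are hypothetical and are not Dassl's DatasetWrapper, so check how your trainer builds its list of transforms before relying on this reading.

```python
from typing import Any, Callable, Dict, List, Sequence


class MultiViewWrapper:
    """Hypothetical dataset wrapper returning several augmented views per item.

    transforms: one callable per output key ("img", "img2", "img3", ...).
    k: how many times each transform is applied; k > 1 yields a list of views
       under the same key instead of creating new keys.
    """

    def __init__(self, items: Sequence[Any], labels: Sequence[int],
                 transforms: List[Callable[[Any], Any]], k: int = 1):
        assert len(items) == len(labels)
        self.items = items
        self.labels = labels
        self.transforms = transforms
        self.k = k

    def __len__(self) -> int:
        return len(self.items)

    def _apply(self, tfm: Callable[[Any], Any], x: Any) -> Any:
        views = [tfm(x) for _ in range(self.k)]
        return views[0] if self.k == 1 else views

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        x = self.items[idx]
        out: Dict[str, Any] = {"label": self.labels[idx], "index": idx}
        for i, tfm in enumerate(self.transforms):
            # First transform -> "img", second -> "img2", and so on.
            key = "img" if i == 0 else f"img{i + 1}"
            out[key] = self._apply(tfm, x)
        return out
```

With a single transform, no img2 key is ever produced, however large k is.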
Hi there!
In the Evaluation on Heterogeneous DG part of your paper, you evaluate the approach on the cross-dataset person re-identification (re-ID) task.
Could you share your training parameters for this task, including the optimizer, network, and data-preprocessing settings?
Thank you very much!
I want to train MME on my custom dataset. Could you guide me through it?
Dassl.pytorch/dassl/evaluation/evaluator.py
Lines 42 to 46 in 08acfb3
I think we should also reset self._y_true and self._y_pred in the reset function.
Otherwise, they keep accumulating predictions from earlier evaluations, and we'll get a wrong confusion matrix.
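A minimal sketch of the suggested fix, assuming an evaluator that appends per-sample labels and predictions on every batch: reset() clears those lists together with the running counts. SimpleClassificationEvaluator and its _correct/_total counters are hypothetical; only _y_true and _y_pred come from the post. The confusion matrix is kept as a sparse (true, predicted) -> count dict to avoid a scikit-learn dependency.

```python
from collections import Counter
from typing import Dict, List


class SimpleClassificationEvaluator:
    """Hypothetical evaluator that accumulates predictions across batches."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Clear every accumulator, including the per-sample lists used for
        # the confusion matrix; otherwise results leak between evaluations.
        self._correct = 0
        self._total = 0
        self._y_true: List[int] = []
        self._y_pred: List[int] = []

    def process(self, preds: List[int], labels: List[int]) -> None:
        for p, y in zip(preds, labels):
            self._correct += int(p == y)
            self._total += 1
            self._y_true.append(y)
            self._y_pred.append(p)

    def evaluate(self) -> Dict[str, object]:
        acc = 100.0 * self._correct / max(self._total, 1)
        # Sparse confusion matrix: (true label, predicted label) -> count.
        cmat = Counter(zip(self._y_true, self._y_pred))
        return {"accuracy": acc, "confusion": dict(cmat)}
```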
Hi @KaiyangZhou, thanks for implementing this excellent framework. I was wondering whether there is any way to add support for gradient accumulation.
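In plain PyTorch, gradient accumulation amounts to calling backward() on several mini-batches before a single optimizer.step(), scaling each loss by the number of accumulated steps. The sketch below assumes a generic model, optimizer, and mean-reduced loss; train_with_accumulation is a hypothetical helper, and how it would map onto a particular Dassl trainer is left open.

```python
import torch
import torch.nn as nn


def train_with_accumulation(model, optimizer, loss_fn, batches, accum_steps):
    """Step the optimizer once every `accum_steps` mini-batches.

    Each mini-batch loss is divided by accum_steps so the summed gradients
    equal the gradient of the average loss over the effective batch
    (assuming equal-sized mini-batches and a mean-reduced loss).
    """
    model.train()
    optimizer.zero_grad()
    for step, (x, y) in enumerate(batches, start=1):
        loss = loss_fn(model(x), y) / accum_steps
        loss.backward()  # gradients add up in .grad across mini-batches
        if step % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
```

With equal-sized mini-batches this reproduces the full-batch update. Trailing mini-batches that do not fill a complete group are never applied, so the number of batches should be a multiple of accum_steps. Layers such as BatchNorm still see only the small mini-batch, so accumulation is not exactly equivalent for them.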
There is only a fixmatch folder in 'Dassl.pytorch/configs/trainers/ssl/'.
When I choose 'fixmatch/cifar10.yaml' as the config file and run this:
CUDA_VISIBLE_DEVICES=2 python tools/train.py --root 'ssl/data/' --trainer MixMatch --dataset-config-file configs/datasets/ssl/cifar10.yaml --config-file configs/trainers/ssl/fixmatch/cifar10.yaml --output-dir output/mixmatch
I get this error: assert cfg.DATALOADER.K_TRANSFORMS > 1
(because cfg.DATALOADER.K_TRANSFORMS == 1 by default).
Could you provide a MixMatch config file? Thanks a lot :)
In /tools/train.py, line 89 calls "trainer.train()", for example with trainer = 'DDAIG'.
Where is "trainer.train()" defined? Thank you very much.
Maybe I am confused by the code in "dassl/engine/build.py".
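A common pattern in trainer frameworks is that a method-specific class overrides only the per-iteration update, while train() is inherited from a base class. The toy classes below (TrainerBase with a simplified loop, MyDGTrainer) are hypothetical and much simpler than Dassl's. The __qualname__ and __mro__ lookups at the end work on any Python object, including the trainer returned by build_trainer, and show which class actually defines train().

```python
class TrainerBase:
    """Hypothetical base class owning the generic epoch loop."""

    def train(self, start_epoch, max_epoch):
        self.log = []
        for epoch in range(start_epoch, max_epoch):
            self.run_epoch(epoch)
        return self.log

    def run_epoch(self, epoch):
        self.forward_backward(epoch)

    def forward_backward(self, epoch):
        raise NotImplementedError


class MyDGTrainer(TrainerBase):
    """Hypothetical method-specific trainer: only overrides forward_backward."""

    def forward_backward(self, epoch):
        self.log.append(f"epoch {epoch}: method-specific update")


trainer = MyDGTrainer()
# Ask Python which class actually defines train():
print(type(trainer).train.__qualname__)  # TrainerBase.train
# The method resolution order lists the classes searched, in order:
print([cls.__name__ for cls in type(trainer).__mro__])
```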
Hi,
Thanks for your code and this is a great work.
I have read the paper, and as described in Section 3, the loss function for domain-specific expert learning is a cross-entropy loss.
I guess the loss of domain adaptive ensemble learning is implemented as follows:
https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/engine/dg/daeldg.py#L109
Is this a standard cross-entropy loss function? Why not use nn.CrossEntropyLoss()?
And why take the mean after computing the cross-entropy of each data sample?
(-label_i * torch.log(pred_i + 1e-5)).sum(1).mean()
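For what it's worth, the expression is the cross-entropy written out by hand: .sum(1) sums over classes for each sample, and .mean() averages over the batch, which is the same reduction nn.CrossEntropyLoss uses by default. The hand-written form takes probabilities rather than logits, and it also accepts soft (non-one-hot) labels, which nn.CrossEntropyLoss supports as probability targets only since PyTorch 1.10. The check below, with made-up tensors, compares both on one-hot labels; the 1e-5 term makes them agree only approximately.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)                # batch of 4 samples, 3 classes
target = torch.tensor([0, 2, 1, 2])
label = F.one_hot(target, num_classes=3).float()
pred = F.softmax(logits, dim=1)           # probabilities, as in the snippet

# Hand-written version: sum over classes, then average over the batch.
manual = (-label * torch.log(pred + 1e-5)).sum(1).mean()
# Built-in version: takes raw logits and integer class indices.
builtin = F.cross_entropy(logits, target)

print(manual.item(), builtin.item())
```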
Hi Kaiyang, I installed Dassl following the installation steps you provided. After completing all the steps, running the CLIP-Adapter code threw the following error:
Traceback (most recent call last):
  File "/home1/pan-internship-6/.conda/envs/dassl_coop/lib/python3.8/site-packages/sklearn/__check_build/__init__.py", line 48, in <module>
    from ._check_build import check_build # noqa
ImportError: dlopen: cannot load any more object with static TLS

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "train.py", line 10, in <module>
    from dassl.engine import build_trainer
  File "/home1/pan-internship-6/Dassl.pytorch/dassl/engine/__init__.py", line 2, in <module>
    from .trainer import TrainerX, TrainerXU, TrainerBase, SimpleTrainer, SimpleNet # isort:skip
  File "/home1/pan-internship-6/Dassl.pytorch/dassl/engine/trainer.py", line 19, in <module>
    from dassl.evaluation import build_evaluator
  File "/home1/pan-internship-6/Dassl.pytorch/dassl/evaluation/__init__.py", line 3, in <module>
    from .evaluator import EvaluatorBase, Classification
  File "/home1/pan-internship-6/Dassl.pytorch/dassl/evaluation/evaluator.py", line 1, in <module>
    from sklearn.metrics import f1_score,confusion_matrix
  File "/home1/pan-internship-6/.conda/envs/dassl_coop/lib/python3.8/site-packages/sklearn/__init__.py", line 81, in <module>
    from . import __check_build # noqa: F401
  File "/home1/pan-internship-6/.conda/envs/dassl_coop/lib/python3.8/site-packages/sklearn/__check_build/__init__.py", line 50, in <module>
    raise_build_error(e)
  File "/home1/pan-internship-6/.conda/envs/dassl_coop/lib/python3.8/site-packages/sklearn/__check_build/__init__.py", line 31, in raise_build_error
    raise ImportError(
ImportError: dlopen: cannot load any more object with static TLS
How can I fix this error?
Hi, thanks for sharing your code.
I want to use your DAEL model, but can DAEL support the multi-source setting where the source domains have different categories (such as the category-shift problem addressed in "Deep Cocktail Network: Multi-source Unsupervised Domain Adaptation with Category Shift")?
Thank you very much!
Hello! I want to use your DAEL's weak and strong transforms, but I haven't found the weak and strong transforms in your code.
Dear Kaiyang
I really appreciate this open-source domain generalization framework. It is amazing.
I'm currently working on replicating the results and extending my work on the current framework. Would you mind sharing the parameters for each baseline to replicate the results reported in your AAAI 2020 paper (DDAIG)? When I tried to replicate the baseline results of CrossGrad and DomainMix, the results were much worse than in the paper. I guess this may be caused by parameter tuning, because I'm currently using the default settings of the framework (I didn't change a single line of the framework). The config files only cover DDAIG, DAELDG, and Vanilla (all three are very good).
On the other hand, could you please post instructions on how to run MixStyle and EFDM with the framework?
Thank you very much!
Thanks for your code.
I tried to reproduce the Vanilla model (ResNet-18) on the Office-Home dataset and got 47.2% on Clipart, well below the 49.4% reported in your paper.
Could you share your training parameters for this task, including the optimizer, network, and data-preprocessing settings?
Loading evaluator: Classification
Traceback (most recent call last):
  File "tools/train.py", line 191, in <module>
    main(args)
  File "tools/train.py", line 110, in main
    trainer.load_model(args.model_dir, epoch=args.load_epoch)
  File "d:\coop-main\dassl.pytorch-master\dassl\engine\trainer.py", line 199, in load_model
    f"Load {checkpoint} to {name} (epoch={epoch}, val_result={val_result:.1f})"
TypeError: unsupported format string passed to NoneType.__format__
Terminal input: python tools/train.py --root datasets/da/ --trainer SourceOnly --dataset-config-file configs/datasets/da/visda17.yaml --config-file configs/trainers/da/source_only/visda17.yaml --output-dir output/office31_test --source-domains real --target-domains real --eval-only --model-dir output/office31 --load-epoch 2
This happened when I used a saved .pth file to test on the dataset. Could you please help?
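The traceback indicates that the loaded checkpoint stores val_result=None (for example, a checkpoint saved without a validation score), and the format spec :.1f cannot be applied to None. A defensive version of that log message is sketched below; format_load_message is a hypothetical helper, and patching line 199 of trainer.py along these lines is a suggested workaround, not an official fix.

```python
from typing import Optional


def format_load_message(checkpoint: str, name: str, epoch: int,
                        val_result: Optional[float]) -> str:
    # Only apply the numeric format spec when a validation score exists.
    val_str = "N/A" if val_result is None else f"{val_result:.1f}"
    return f"Load {checkpoint} to {name} (epoch={epoch}, val_result={val_str})"
```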
I used 4 GPUs to train DDAIG, but got a "CUDA out of memory" error.
It seems the code only supports single-GPU training.
Hello Kaiyang,
thank you for sharing the codes.
Is there any guide to running benchmarks such as MME, MCD, etc.?
Many thanks for your reply.
I received an email saying the current code cannot reproduce the results of DDAIG on PACS. I haven't run DDAIG using Dassl so I'm not sure if there is an issue.
I've attached the original log files, which contain the library versions, the environment settings, and the exact parameters used in the paper. Hope this helps. Please check this google drive link. As DDAIG was done in early 2019, I was using torch=0.4.1 and numpy=1.14.5 at that time. Not sure if this will cause an issue. If there is really a reproduction issue, it's also possible that something went wrong when I transferred DDAIG's code to this public Dassl repo (I'll double-check this).
Please note that DDAIG was named ddap in the log files. Some parameter names differ from Dassl's because the original code was an early version of Dassl, but they should be easy to understand.
I'll find time and resources to run DDAIG using this code (please bear with me).
You are the author of L2A-OT, which is an impressive work and a strong benchmark. Why wasn't it implemented in this repo?
Hello, Kaiyang. I'm trying to use your Dassl for some medical tasks. Due to the limits of my hardware, I have to train the network on two distributed GPUs. I'm using Horovod for this, but I noticed the statement "A drawback of Dassl is that it doesn't (yet? hmm) support distributed multi-GPU training". Is that still true when using Horovod? I'm looking forward to your answer, because preparing the environment takes time.
Excellent work!
I cannot find the download link for 'SYN' in Digits-DG, whereas MNIST, MNIST-M, and SVHN can be downloaded easily.
Impressive work. I can't wait to get the code. When will it be released?
Hi,
I tried ADDA on miniDomainNet, but the accuracy on the target test set decreased from 0.29 to 0.01. I previously tried it on Office-31 and it worked fine. Do you know what is wrong?
I used the following scripts:
python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains real \
--target-domains sketch \
--dataset-config-file configs/datasets/da/mini_domainnet.yaml \
--config-file configs/trainers/da/source_only/mini_domainnet.yaml \
--output-dir output/source_only_minidomainnet
python train.py --trainer ADDA --source-domains real --target-domain sketch \
--dataset-config-file configs/datasets/da/mini_domainnet.yaml \
--config-file configs/trainers/da/source_only/mini_domainnet.yaml \
--output-dir output/adda_minidomainnet \
--init-weights output/source_only_minidomainnet/model/model.pth.tar-60
We have improved the implementation of MixStyle to make it more flexible.
Recall that MixStyle has two versions: random mixing and cross-domain mixing. The former randomly shuffles along the batch dimension, while the latter mixes the first half of a batch with the second half.
After merging MixStyle2 into MixStyle, the two versions are now managed by a new variable called self.mix, which takes either 'random' or 'crossdomain', corresponding to the two versions respectively. This variable can be set during initialization, e.g., self.mixstyle = MixStyle(mix='random'). It can also be changed on the fly. For instance, if you want to apply random mixing at the current step, simply do model.apply(random_mixstyle), or model.apply(crossdomain_mixstyle) if you prefer cross-domain mixing.
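To make the two modes concrete, here is a compact MixStyle-style layer written from the description above: per-instance channel means and standard deviations are mixed with a Beta-distributed weight, and mix decides whether each sample's partner comes from a random permutation or from the other half of the batch. The defaults (p=0.5, alpha=0.1, eps=1e-6) and the helper bodies are assumptions for this sketch and may differ from the released code.

```python
import random

import torch
import torch.nn as nn


class MixStyle(nn.Module):
    """Sketch of a MixStyle layer with 'random' and 'crossdomain' mixing."""

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.mix = mix
        self._activated = True

    def forward(self, x):  # x: (B, C, H, W)
        if not self.training or not self._activated or random.random() > self.p:
            return x
        B = x.size(0)
        # Per-instance, per-channel style statistics (treated as constants).
        mu = x.mean(dim=[2, 3], keepdim=True).detach()
        sig = (x.var(dim=[2, 3], keepdim=True) + self.eps).sqrt().detach()
        x_normed = (x - mu) / sig

        lmda = self.beta.sample((B, 1, 1, 1)).to(x.device)
        if self.mix == "random":
            perm = torch.randperm(B)
        elif self.mix == "crossdomain":
            # Assumes the two halves of the batch come from different domains:
            # each sample in one half is paired with a sample from the other.
            perm = torch.arange(B - 1, -1, -1)  # reversed indices
            perm_b, perm_a = perm.chunk(2)
            perm_b = perm_b[torch.randperm(perm_b.size(0))]
            perm_a = perm_a[torch.randperm(perm_a.size(0))]
            perm = torch.cat([perm_b, perm_a], 0)
        else:
            raise ValueError(f"unknown mix mode: {self.mix}")

        mu_mix = mu * lmda + mu[perm] * (1 - lmda)
        sig_mix = sig * lmda + sig[perm] * (1 - lmda)
        return x_normed * sig_mix + mu_mix


def random_mixstyle(m):
    if isinstance(m, MixStyle):
        m.mix = "random"


def crossdomain_mixstyle(m):
    if isinstance(m, MixStyle):
        m.mix = "crossdomain"
```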
We have also added new context managers to control MixStyle in the forward pass. Say your model has MixStyle layers that were initially activated and you would like to deactivate them at a certain point; you can do:
# print(MixStyle._activated): True
with run_without_mixstyle(model):
    # print(MixStyle._activated): False
    output = model(input)
# print(MixStyle._activated): True
Otherwise, if you want to use MixStyle layers that were initially deactivated, you can do:
# print(MixStyle._activated): False
with run_with_mixstyle(model):
    # print(MixStyle._activated): True
    output = model(input)
# print(MixStyle._activated): False
You can also change self.mix while using run_with_mixstyle, e.g.,
# print(MixStyle._activated): False
# print(MixStyle.mix): random
with run_with_mixstyle(model, mix='crossdomain'):
    # print(MixStyle._activated): True
    # print(MixStyle.mix): crossdomain
    output = model(input)
# print(MixStyle._activated): False
# print(MixStyle.mix): crossdomain
But note that the change in self.mix made during run_with_mixstyle is permanent unless you manually use model.apply(random_mixstyle) or model.apply(crossdomain_mixstyle) to modify the variable.
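The context managers can be approximated with contextlib.contextmanager and nn.Module.apply. In this sketch, StyleLayer is a hypothetical stand-in for a MixStyle layer exposing _activated and mix; the function names follow the text, but their bodies are assumptions. Mirroring the behavior described above, each manager restores a fixed state on exit (activated after run_without_mixstyle, deactivated after run_with_mixstyle), and a mix override is not rolled back.

```python
from contextlib import contextmanager

import torch.nn as nn


class StyleLayer(nn.Module):
    """Hypothetical stand-in for a MixStyle layer: identity plus two flags."""

    def __init__(self, activated=True, mix="random"):
        super().__init__()
        self._activated = activated
        self.mix = mix

    def forward(self, x):
        return x


def _set_activation(flag):
    def fn(m):
        if isinstance(m, StyleLayer):
            m._activated = flag
    return fn


@contextmanager
def run_without_mixstyle(model):
    model.apply(_set_activation(False))
    try:
        yield
    finally:
        model.apply(_set_activation(True))  # assumes layers started activated


@contextmanager
def run_with_mixstyle(model, mix=None):
    if mix is not None:
        def set_mix(m):
            if isinstance(m, StyleLayer):
                m.mix = mix  # permanent: not restored on exit
        model.apply(set_mix)
    model.apply(_set_activation(True))
    try:
        yield
    finally:
        model.apply(_set_activation(False))  # assumes layers started deactivated
```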
Hello,
I replicated the dataset folder structure and tried running DAEL on digit5. However, it cannot find the mnist files, and seems to be searching for the wrong files. Upon closer inspection the load_mnist function in /data/datasets/da/digit5.py seems to expect different files from those provided in the links in DATASETS.md. Is it possible that this file is outdated with respect to the readme?
Edit: I had not run the dataset creation script. My apologies.
Hi, thanks for sharing your code.
I tried trainer=SelfEnsembling with source_domain=mnist and target_domain=mnist_m, expecting around 95% accuracy on the test split of the target domain.
But I wasn't able to get more than 65% accuracy.
Could you please check whether I am missing any important parameters? I run it like this:
python tools/train.py \
--backbone resnet18 \
--root "datasets" \
--trainer SelfEnsembling \
--source-domains "mnist" \
--target-domains "mnist_m" \
--output-dir "$job_dir" \
--dataset-config-file "configs/datasets/da/digit5.yaml" \
DATALOADER.K_TRANSFORMS 2 \
DATALOADER.TRAIN_X.BATCH_SIZE 128 \
DATALOADER.NUM_WORKERS 10 \
TRAINER.SE.EMA_ALPHA 0.999 \
OPTIM.LR 0.0003 \
OPTIM.MAX_EPOCH 200
I also did a small hyper-parameter sweep and tried different LR=(3e-3 3e-4 3e-5) and EMA_ALPHA=(0.99 0.999 0.9999), but I didn't find a combination with a better score.
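For context on TRAINER.SE.EMA_ALPHA: self-ensembling keeps a teacher whose weights are an exponential moving average of the student's. A minimal sketch of that update is below; ema_update is a hypothetical name, and the sketch makes no claim about how the SelfEnsembling trainer handles BatchNorm buffers.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def ema_update(teacher: nn.Module, student: nn.Module, alpha: float) -> None:
    """teacher <- alpha * teacher + (1 - alpha) * student, parameter-wise."""
    for t, s in zip(teacher.parameters(), student.parameters()):
        t.mul_(alpha).add_(s, alpha=1 - alpha)


student = nn.Linear(2, 1)
teacher = copy.deepcopy(student)
for p in teacher.parameters():
    p.requires_grad_(False)  # the teacher is never trained by backprop
```

As a rule of thumb, the teacher averages over roughly the last 1/(1 - alpha) updates (about 1,000 steps for 0.999), so larger values make it lag further behind the student, which matters when there are few iterations per epoch.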
Someone suggested we should add the Terra Incognita dataset, a wildlife animal classification dataset used in the DomainBed paper. I went through the images of Terra Incognita (the four locations chosen by DomainBed) and found that the objects of interest, i.e., animals, are often small relative to the whole image, partially visible, and not centered. I feel that using this dataset to evaluate image classifiers won't help track progress, as the image quality isn't good enough.
I'm very interested in your implementation of weak and strong augmentation for unlabeled data, but I can't find the dataloader and transform modules in your code.
Hi Kaiyang
Could you please point me to how I can implement a two-view dataloader for training with the SimCLR loss? Basically, I want the train dataloader to return two views (augmentations) of the same image each time it is called.
Thanks!
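A common way to get two views is to wrap the augmentation in a callable that applies it twice, so every sample yields two independently augmented copies. TwoViewTransform, TwoViewDataset, and the additive-noise augmentation below are hypothetical stand-ins on toy data; wiring this into Dassl's data manager is not covered.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TwoViewTransform:
    """Apply the same random augmentation twice to produce two views."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        return self.base_transform(x), self.base_transform(x)


class TwoViewDataset(Dataset):
    def __init__(self, data: torch.Tensor, transform):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        view1, view2 = self.transform(self.data[idx])
        return view1, view2


def add_noise(x):
    # Toy stand-in for an image augmentation pipeline.
    return x + 0.1 * torch.randn_like(x)


data = torch.randn(10, 3, 8, 8)
loader = DataLoader(TwoViewDataset(data, TwoViewTransform(add_noise)),
                    batch_size=4, shuffle=True)
```

Row i of the first batch and row i of the second batch form the positive pair for the SimCLR loss, while all other rows act as negatives.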
Some trainer config files are missing.
I would be very grateful if you could share these files, especially the one for Mean Teacher.
Following the steps in the README, I still get errors when running the demo.
The dataset path may be wrong, but I have checked it carefully, and the "file not found" errors persist.
As you can see, the files exist in "/home/wyk/dataset/office31/amazon/images/ruler".
How to run a demo on dassl.pytorch? Looking forward to your reply!
I noticed in #12 (comment) that the DDAIG (AAAI 2020) method uses different hyperparameters on different domains, which violates the original setting of domain generalization (the test domain is unseen).
Tuning the hyperparameters for each target domain indirectly uses the unseen target to improve performance.
The real improvement of the DDAIG method is therefore unclear, and the comparison with other SOTA results is unfair!
Hi, thanks for sharing the nice codes.
I have trouble getting a reasonable accuracy for M3SDA on DomainNet.
With the command below, I got 0.55% accuracy (error: 99.45%).
What's wrong with this?
python tools/train.py --root /database --trainer M3SDA --source-domains clipart --target-domains infograph --dataset-config-file configs/datasets/da/domainnet.yaml --config-file configs/trainers/da/m3sda/domainnet.yaml --output-dir output/M3SDA_CI_DOMAINNET
Sorry to disturb you.
In my domain generalization task, I found a mistake in https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/data/datasets/dg/vlcs.py#L40
Because the test data should include all images in the target domain, it should be written as follows:
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "full")
test += self._read_data(cfg.DATASET.TARGET_DOMAINS, "test")
This code is very good work, but while studying it I found some problems, and I hope the author can help me solve them.
Hi,
Is there any plan to introduce the semi-supervised domain adaptation data loaders in the code? There are UDA and SSL loaders, but for SSDA, we might need different target data loaders during the training, which, as far as I know, cannot be directly used from the codebase.
If an implementation of SSDA loaders would help, I can make a PR, as I have already worked on it.
First, thanks for your work; it's a nice project in the DA and DG fields. I'm using your adabn.py to apply the AdaBN method to my model. I read the code and found it pretty simple, and I think something may be missing.
From my understanding, AdaBN replaces the BN mean and variance from the training stage with those of the target domain, but I only found a function that resets the running stats in adabn.py.
I'm looking forward to your reply, thanks!
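Resetting the running statistics is only the first step of AdaBN as usually described: the statistics are then re-estimated by forwarding target-domain data through the network in training mode without gradient updates, and the model is evaluated afterwards. Whether adabn.py performs the second step elsewhere (for example, inside its training loop) isn't visible from the post. A minimal sketch, assuming a plain nn.Module and an iterable of target batches; adapt_bn is a hypothetical name.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


@torch.no_grad()
def adapt_bn(model: nn.Module, target_batches) -> nn.Module:
    """AdaBN-style re-estimation of BatchNorm statistics on target data."""
    for m in model.modules():
        if isinstance(m, BN_TYPES):
            m.reset_running_stats()
            m.momentum = None  # cumulative average over all target batches
    model.train()  # BN layers update running stats only in train mode
    for x in target_batches:
        model(x)
    model.eval()
    return model
```

Setting momentum=None makes each BatchNorm layer keep a cumulative rather than exponential average. The sketch leaves momentum at None afterwards, and any dropout active in train mode also affects the collected statistics.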
Hi there,
Thanks for your code. There is an issue when testing the saved model.
When I run testing right after training, the result is good. However, when I load the saved model and run testing again, the result is very low.
How can I get consistent test results?
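Two common causes of this symptom in PyTorch, offered as general possibilities rather than a diagnosis of this case, are evaluating in training mode and restoring only part of the state (e.g., parameters but not BatchNorm running statistics). The sketch saves and reloads a full state_dict, which includes buffers, switches both models to eval(), and checks that their outputs match; the tiny model and the temporary path are illustrative.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    return nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))


torch.manual_seed(0)
model = make_model()
model.train()
model(torch.randn(32, 4))  # updates the BN running stats (buffers)

path = os.path.join(tempfile.mkdtemp(), "model.pth")
torch.save(model.state_dict(), path)  # parameters AND buffers

restored = make_model()
restored.load_state_dict(torch.load(path, map_location="cpu"))

model.eval()
restored.eval()  # use running stats, not batch stats, at test time
x = torch.randn(5, 4)
with torch.no_grad():
    same = torch.allclose(model(x), restored(x))
```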
Hi everyone,
It seems that PAWS is state-of-the-art in semi-supervised learning: https://github.com/facebookresearch/suncet https://arxiv.org/abs/2104.13963. It would be interesting to add it to the framework.
I created a PR to illustrate the idea.
Thanks!
Hi Kaiyang,
Thanks for this great repo! I am interested in using dassl to build several baselines for my own dataset, which is a set of tabular (vector) data from various domains. Do you have suggestions/tips/recommendations for adding our own dataset?
Thanks a lot!
Thank you for your great work!
I have a question I would like to ask.
Is it possible to assign different learning rates to two different parts of one model?
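In plain PyTorch this is done with optimizer parameter groups, where each group may carry its own lr. The backbone/head split and the learning-rate values below are illustrative assumptions; hooking this into a Dassl trainer's optimizer construction is not covered here.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 8)  # e.g. a pretrained feature extractor
        self.head = nn.Linear(8, 3)       # e.g. a newly added classifier

    def forward(self, x):
        return self.head(torch.relu(self.backbone(x)))


net = Net()
optimizer = torch.optim.SGD(
    [
        {"params": net.backbone.parameters(), "lr": 1e-3},
        {"params": net.head.parameters()},  # falls back to the default lr
    ],
    lr=1e-2,
    momentum=0.9,
)
```

Groups without an explicit lr use the default passed to the optimizer, and schedulers such as StepLR scale every group by the same factor, so the ratio between the two rates is preserved.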
Dassl.pytorch/dassl/engine/dg/daeldg.py
Lines 124 to 131 in 5e83fdc
Hi, I think there is something wrong with the loss in daeldg.py.
Why is 'loss' assigned to 0? This would make the supervised loss ineffective.
Here is the error I got.
KeyError: 'Object name "CLIP_Adapter" does not exist in "TRAINER" registry'
The new trainer file is here:
https://github.com/gaopengcuhk/CLIP-Adapter/blob/main/clip_adapter.py
So does the registry automatically detect the trainer if I use @TRAINER_REGISTRY.register()?
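Decorator-based registration only runs when the module containing the class is imported, so if nothing imports clip_adapter.py before the trainer is looked up by name, the registry never sees it. The toy Registry below is a simplified, hypothetical stand-in, not Dassl's implementation, but it reproduces the effect.

```python
class Registry:
    """Minimal name -> class registry (simplified, hypothetical)."""

    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def register(self):
        def deco(cls):
            self._obj_map[cls.__name__] = cls
            return cls
        return deco

    def get(self, name):
        if name not in self._obj_map:
            raise KeyError(
                f'Object name "{name}" does not exist in "{self._name}" registry'
            )
        return self._obj_map[name]


TRAINER_REGISTRY = Registry("TRAINER")

# A lookup before the decorated class is defined (i.e. before its module is
# imported) fails in the same way as the error above.
try:
    TRAINER_REGISTRY.get("CLIP_Adapter")
except KeyError as e:
    print(e)


@TRAINER_REGISTRY.register()  # runs when this code is executed / imported
class CLIP_Adapter:
    pass
```

In practice this usually means importing the module that defines the new trainer (for example, from the training script) before the trainer is built.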
Thanks for sharing the code.
However, even with your latest config file for PACS, I still cannot reproduce the results in the paper.
I repeated the experiments 5-6 times for each domain. The accuracies (%) are shown in the following table (one column per target domain; empty cells are missing runs):
art | sketch | photo | cartoon |
---|---|---|---|
76.91 | 72.23 | | |
79.35 | 74.39 | 93.05 | 73.17 |
79.64 | 74.13 | 94.91 | 74.32 |
82.71 | 71.61 | 92.81 | 73.12 |
80.03 | 75.41 | 94.01 | 72.14 |
82.37 | 75.1 | 94.01 | 72.1 |
Hope to get your reply.
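To summarize the table, the per-domain mean and sample standard deviation can be computed from the values above (empty cells omitted):

```python
from statistics import mean, stdev

# Accuracies (%) copied from the table above; missing cells are omitted.
runs = {
    "art": [76.91, 79.35, 79.64, 82.71, 80.03, 82.37],
    "sketch": [72.23, 74.39, 74.13, 71.61, 75.41, 75.1],
    "photo": [93.05, 94.91, 92.81, 94.01, 94.01],
    "cartoon": [73.17, 74.32, 73.12, 72.14, 72.1],
}

summary = {d: (mean(v), stdev(v)) for d, v in runs.items()}
for domain, (m, s) in summary.items():
    print(f"{domain:8s} {m:.2f} +/- {s:.2f} (n={len(runs[domain])})")
```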