
tvt's Introduction

updates (03/10/2022)

  1. Add the environment requirements to reproduce the results.

  2. Add the attention visualization code. An example is as follows, where att_visual.txt contains image paths:

python3 visualize.py --dataset office --name dw --num_classes 31 --image_path att_visual.txt --img_size 256

More details can be found in the Attention Map Visualization section below.

updates (03/08/2022)

Add the source-only code. An example on the Office-31 dataset is as follows, where dslr is the source domain and webcam is the target domain:

python3 train.py --train_batch_size 64 --dataset office --name dw_source_only --train_list data/office/dslr_list.txt --test_list data/office/webcam_list.txt --num_classes 31 --model_type ViT-B_16 --pretrained_dir checkpoint/ViT-B_16.npz --num_steps 5000 --img_size 256

Environment (Python 3.8.12)

# Install Anaconda (https://docs.anaconda.com/anaconda/install/linux/)
wget https://repo.anaconda.com/archive/Anaconda3-2021.11-Linux-x86_64.sh
bash Anaconda3-2021.11-Linux-x86_64.sh

# Install required packages
conda install pytorch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 -c pytorch
pip install tqdm==4.50.2
pip install tensorboard==2.8.0
# apex 0.1
conda install -c conda-forge nvidia-apex
pip install scipy==1.5.2
pip install ml-collections==0.1.0
pip install scikit-learn==0.23.2

Pretrained ViT

Download the pretrained ViT models and put them in checkpoint/.

TVT with ViT-B_16 (ImageNet-21K) performs slightly better than TVT with ViT-B_16 (ImageNet).
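If helpful, a minimal download sketch in Python. The URL below is the commonly used Google Cloud Storage location for the original ViT checkpoints (google-research/vision_transformer) and is an assumption on my part, not a link taken from this README:

import os
import urllib.request

# Assumed public location of the ImageNet-21K ViT-B_16 checkpoint; verify before relying on it.
url = "https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz"
os.makedirs("checkpoint", exist_ok=True)
# Save it where the training commands expect it (--pretrained_dir checkpoint/ViT-B_16.npz).
urllib.request.urlretrieve(url, "checkpoint/ViT-B_16.npz")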

Datasets:

  • Download data and replace the current data/

  • Download images from Office-31, Office-Home, VisDA-2017 and put them under data/. For example, images of Office-31 should be located at data/office/domain_adaptation_images/

Training:

All commands can be found in script.txt. An example:

python3 main.py --train_batch_size 64 --dataset office --name wa \
--source_list data/office/webcam_list.txt --target_list data/office/amazon_list.txt \
--test_list data/office/amazon_list.txt --num_classes 31 --model_type ViT-B_16 \
--pretrained_dir checkpoint/ViT-B_16.npz --num_steps 5000 --img_size 256 \
--beta 0.1 --gamma 0.01 --use_im --theta 0.1

Attention Map Visualization:

python3 visualize.py --dataset office --name wa --num_classes 31 --image_path att_visual.txt --img_size 256

The code will automatically use the best model in wa to visualize the attention maps of the images in att_visual.txt. att_visual.txt contains the image paths you want to visualize, for example:

/data/office/domain_adaptation_images/dslr/images/calculator/frame_0001.jpg 5
/data/office/domain_adaptation_images/dslr/images/calculator/frame_0002.jpg 5
/data/office/domain_adaptation_images/dslr/images/calculator/frame_0003.jpg 5
/data/office/domain_adaptation_images/dslr/images/calculator/frame_0004.jpg 5
/data/office/domain_adaptation_images/dslr/images/calculator/frame_0005.jpg 5
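Since each line is just "<image path> <label>", the list file can be generated with a short script. A minimal sketch (the folder and label value below are illustrative, taken from the example lines above):

import os

image_dir = "data/office/domain_adaptation_images/dslr/images/calculator"  # example class folder
label = 5                                                                   # example class index

# Write one "<image path> <label>" line per .jpg file in the folder.
with open("att_visual.txt", "w") as f:
    for name in sorted(os.listdir(image_dir)):
        if name.endswith(".jpg"):
            f.write("%s %d\n" % (os.path.join(image_dir, name), label))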

Citation:

@article{yang2021tvt,
  title={TVT: Transferable Vision Transformer for Unsupervised Domain Adaptation},
  author={Yang, Jinyu and Liu, Jingjing and Xu, Ning and Huang, Junzhou},
  journal={arXiv preprint arXiv:2108.05988},
  year={2021}
}

Our code is largely borrowed from CDAN and ViT-pytorch.

tvt's People

Contributors

viyjy


tvt's Issues

Question about loss function.

Hi.

Thank you for sharing the code.

I have a question about the total loss function used in this code.

As far as I know, most GAN implementations backpropagate the generator and discriminator losses separately. However, you backpropagate them at once by combining them into one loss. (As I understand it, loss_ad_local and loss_ad_global correspond to the discriminator and generator losses.)

Does calling backward on the combined loss have the same effect as calling backward on each loss separately?

Thank you.
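A minimal sketch of the gradient question above, assuming both losses flow into the same parameters as in this repository (it does not cover GAN-style training where the generator and discriminator are updated by different optimizers on detached inputs):

import torch

w = torch.randn(3, requires_grad=True)

# Backward each loss separately: gradients accumulate into w.grad.
((w ** 2).sum()).backward()
((3.0 * w).sum()).backward()
g_separate = w.grad.clone()

# Backward the combined loss once.
w.grad = None
loss = (w ** 2).sum() + (3.0 * w).sum()
loss.backward()

print(torch.allclose(g_separate, w.grad))  # True: same accumulated gradients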

training time

Your work is great. What I want to ask is: how long does it take to train the model?

Some questions about TAM and DCM

Hi, thank you very much for publishing the code. I've found your paper really interesting and was trying to figure out whether it can be adapted to my own model (I'm using basically the same configuration as yours).
The paper highlights two main components, DCM and TAM, and I was wondering whether these two modules can be extracted from the code.

  1. Looking at the code, I gather that DCM is just a classification head performing a classification task on the domain, as in DANN (I also see a GRL). Am I correct?

  2. I am having a harder time finding what is called TAM. Is it something built inside the blocks? I see that the main change is that the adversarial net is used inside each block to generate a loss_ad. Is this what you call the patch-level discriminator in the paper? Is anything else needed to create the so-called TAM?

Thank you very much again for sharing the code
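For context, a gradient reversal layer of the kind used in DANN-style domain classification heads can be sketched as follows (a minimal illustration of the usual GRL pattern, not a copy of this repository's implementation):

import torch

class GradReverse(torch.autograd.Function):
    # Identity in the forward pass; the gradient is negated (and scaled) in the backward pass.
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.alpha * grad_output, None

def grad_reverse(x, alpha=1.0):
    return GradReverse.apply(x, alpha)

# Usage: domain_logits = domain_classifier(grad_reverse(features, alpha))
# The domain loss then trains the classifier while pushing the backbone toward domain-invariant features.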

The results about vanilla ViT with adversarial adaptation.

Hi, thanks for releasing the open-source project of transformer-based UDA! May I ask you some questions?

I tried to reproduce the Baseline result, which denotes vanilla ViT with adversarial adaptation. Differently from your setup, I fixed the batch size to 32 and the input size to 224 x 224. Accordingly, I obtained results on the Office-31 W→A and D→A settings.
[1] Safe Self-Refinement for transformer-based Domain Adaptation (https://arxiv.org/pdf/2204.07683.pdf)
Obviously, [1] uses the timm library and gets a very strong baseline. Have you thought about switching to timm to improve performance? It's hard to compare because everyone implements the baseline in different ways and the results vary greatly.

Additionally, I had a problem when training the baseline mentioned above.
Model performance rises and then deteriorates rapidly. The training parameters are as follows: --beta 0.1 --gamma 0. --theta 0.
Have you ever encountered such a problem? Can you give me some advice?

Code cannot achieve the results reported in the paper

I ran some scripts in script.txt, but got results far from those reported in the paper.
Here are some results:
Office-home
Pr->Cl:52.37
Cl->Pr:78.64
Cl->Ar:72.39
Office-31
AD:94.98
AW:94.21
DW:76.78
Visda17
Train->Val:80.78

P.S.: apex didn't work well due to an installation error, but I think it has no effect on the results (it only affects GPU memory and gradient accumulation).
(I didn't find a resolution for this error: "fused_weight_gradient_mlp_cuda module not found. gradient accumulation fusion with weight gradient computation disabled.")

Errors encountered when running the code, SOS!!!

The error message is as follows:
test setup failed
args = <module 'main' from 'C:\Users\Hongyu\Desktop\TVT-main\main.py'>

    def setup(args):
        # Prepare model
>       config = CONFIGS[args.model_type]
E       AttributeError: module 'main' has no attribute 'model_type'

main.py:74: AttributeError

About the feature visualization in your paper

Hello, I'm very happy to have read your article. I have some questions; if you have time, could you help me with them?

How is the feature visualization in your article implemented? What format is the data in?
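One common way to produce this kind of feature visualization, assuming the figures are t-SNE embeddings of per-image features (an assumption; the features.npy and domains.npy files below are hypothetical outputs of your own extraction script):

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Hypothetical inputs: (N, D) features (e.g. the [CLS] token per image) and a 0/1 source/target flag.
feats = np.load("features.npy")
domains = np.load("domains.npy")

# Project to 2-D and plot source vs. target points.
emb = TSNE(n_components=2, init="pca", random_state=0).fit_transform(feats)
plt.scatter(emb[domains == 0, 0], emb[domains == 0, 1], s=5, label="source")
plt.scatter(emb[domains == 1, 0], emb[domains == 1, 1], s=5, label="target")
plt.legend()
plt.savefig("tsne.png", dpi=200)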

About code in modeling.py

Hello!
After reading your paper and code, I have a question about the following code in modeling.py. Can you please explain it?

    if posi_emb is not None:
        eps = 1e-10
        batch_size = key_layer.size(0)
        patch = key_layer
        ad_out, loss_ad = lossZoo.adv_local(patch[:,:,1:], ad_net, is_source)
        entropy = - ad_out * torch.log2(ad_out + eps) - (1.0 - ad_out) * torch.log2(1.0 - ad_out + eps)  # standard entropy H
        entropy = torch.cat((torch.ones(batch_size, self.num_attention_heads, 1).to(hidden_states.device).float(), entropy), 2)  # [1, T(Kpatch)]
        trans_ability = entropy if self.vis else None   # [B*12*197]
        entropy = entropy.view(batch_size, self.num_attention_heads, 1, -1)
        attention_probs = torch.cat((attention_probs[:,:,0,:].unsqueeze(2) * entropy, attention_probs[:,:,1:,:]), 2)

In the last line, why do you multiply by the entropy only for query index 0 (the class token), rather than for all query positions?
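For readers following the snippet above, a minimal shape-level sketch of what that last line does (tensor sizes are illustrative, not taken from the repository):

import torch

B, H, T = 2, 12, 197                       # batch, heads, tokens (1 [CLS] + 196 patches)
attention_probs = torch.rand(B, H, T, T)   # dummy attention map
entropy = torch.rand(B, H, 1, T)           # one weight per key position

# Only the [CLS] query row (index 0 along dim 2) is rescaled by the weights;
# all other query rows are concatenated back unchanged.
cls_row = attention_probs[:, :, 0, :].unsqueeze(2) * entropy
attention_probs = torch.cat((cls_row, attention_probs[:, :, 1:, :]), dim=2)
print(attention_probs.shape)  # torch.Size([2, 12, 197, 197])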

Some problems in the code/modeling.py

Hello, I encountered something I didn't understand while reading the code. If you have time, could you help me with it?
At line 99 of code/modeling.py:

    if posi_emb is not None:
        eps = 1e-10
        batch_size = key_layer.size(0)
        patch = key_layer
        ad_out, loss_ad = lossZoo.adv_local(patch[:,:,1:], ad_net, is_source)
        entropy = - ad_out * torch.log2(ad_out + eps) - (1.0 - ad_out) * torch.log2(1.0 - ad_out + eps)
        entropy = torch.cat((torch.ones(batch_size, self.num_attention_heads, 1).to(hidden_states.device).float(), entropy), 2)
        trans_ability = entropy if self.vis else None  # [B*12*197]
        entropy = entropy.view(batch_size, self.num_attention_heads, 1, -1)
        attention_probs = torch.cat((attention_probs[:,:,0,:].unsqueeze(2) * entropy, attention_probs[:,:,1:,:]), 2)

What I don't understand is: is this where the adversarial discriminator loss is applied? Which part of ViT did you improve?

No module named 'models.modeling_resnet'

I tried to run your code but got the error in the title. I found there is no file named 'modeling_resnet' (the code has 'from .modeling_resnet import ResNetV2'). Is a file missing?
Thanks.

how to load the .bin file for Attention map visualization

Hi, sorry to bother you. I'm having trouble saving the training model file and loading the model. I can't load the files I trained and saved in the output folder, _checkpoint.bin and _checkpoint_adv.bin. I want to load the trained model to do some visualization experiments, but I am confused about the .bin file. I can load ViT-B_16.npz normally for Attention map visualization. Can you share how you loaded the .bin file for Attention map visualization?

Here is the code that saves the model in your train.py:

def save_model(args, model, is_adv=False):
    model_to_save = model.module if hasattr(model, 'module') else model
    if not is_adv:
        model_checkpoint = os.path.join(args.output_dir, args.dataset, "%s_checkpoint.bin" % args.name)
    else:
        model_checkpoint = os.path.join(args.output_dir, args.dataset, "%s_checkpoint_adv.bin" % args.name)
    torch.save(model_to_save.state_dict(), model_checkpoint)
    logger.info("Saved model checkpoint to [DIR: %s]", os.path.join(args.output_dir, args.dataset))

I'm a newbie, so I'm very sorry to bother you with this question.
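Since the checkpoint above is written with torch.save(model_to_save.state_dict(), ...), one way to load it back is with load_state_dict. A minimal sketch, assuming the model is rebuilt with the same configuration and arguments used in train.py (module and class names follow the ViT-pytorch layout this repo borrows from; the exact constructor signature and checkpoint path may differ):

import torch
from models.modeling import CONFIGS, VisionTransformer

# Rebuild the model as it was built for training; the arguments below are illustrative.
config = CONFIGS["ViT-B_16"]
model = VisionTransformer(config, img_size=256, num_classes=31, vis=True)

# The .bin file stores a state_dict (see save_model above), so load it with load_state_dict,
# not with the .npz loader used for the original ViT checkpoint.
state_dict = torch.load("output/office/wa_checkpoint.bin", map_location="cpu")  # example path
model.load_state_dict(state_dict)
model.eval()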
