uta-smile / TVT

Code of TVT: Transferable Vision Transformer for Unsupervised Domain Adaptation, WACV 2023

License: MIT License
Hi, thank you very much for publishing the code. I found your paper really interesting and I am trying to figure out whether it can be adapted to my own model (I'm using basically the same configuration as yours).
The paper highlights two main components, DCM and TAM, and I was wondering whether these two modules can be extracted from the code.
Looking at the code, I gather that DCM is just a classification head performing a classification task on the domain, as in DANN (I also see the GRL). Am I correct?
I am having a harder time finding what is called TAM. Is it something built inside the blocks? I see that the main change is that the adversarial net is used inside each block to generate a loss_ad. Is this what the paper calls the patch-level discriminator? Is anything else needed to create the so-called TAM?
Thank you very much again for sharing the code.
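For reference, here is a minimal DANN-style sketch of what DCM appears to resemble under that reading (my own illustration, not the authors' exact module; all names here are hypothetical): a gradient reversal layer (GRL) followed by a small binary domain classifier.

```python
# Minimal DANN-style sketch: GRL + binary domain classifier head.
# This illustrates the reading above; it is not TVT's actual DCM code.
import torch
import torch.nn as nn

class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the gradient sign so the feature extractor learns
        # domain-confusing features while the head learns to discriminate.
        return -ctx.alpha * grad_output, None

class DomainHead(nn.Module):
    """Hypothetical binary source/target classifier on reversed features."""
    def __init__(self, dim=768, alpha=1.0):
        super().__init__()
        self.alpha = alpha
        self.net = nn.Sequential(nn.Linear(dim, 256), nn.ReLU(), nn.Linear(256, 1))

    def forward(self, feats):
        reversed_feats = GradReverse.apply(feats, self.alpha)
        return torch.sigmoid(self.net(reversed_feats))
```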
Hi author, I ported your code to my domain, and after training for a while I noticed that the loss on the source domain kept changing, but the accuracy on the target domain surprisingly stayed exactly the same. Have you ever encountered this situation on your side?
Thanks for sharing the code. Would you be willing to share the code for reproducing the source-only results?
I have followed your code; on VisDA-2017 the accuracy is only 70.76.
Hello, I'm very happy to have read your article. I have some questions; if you have time, could you help me with them?
How did you implement the feature visualization in your article? What format is the data?
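The thread has no answer here, but feature visualizations in UDA papers are commonly t-SNE plots of extracted features; a minimal sketch under that assumption (feats and domain_labels below are random stand-ins, not the paper's data):

```python
# Hypothetical t-SNE visualization of [CLS] features; the arrays below are
# random stand-ins for features extracted from source and target images.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

feats = np.random.randn(200, 768).astype(np.float32)  # [N, D] feature matrix
domain_labels = np.repeat([0, 1], 100)                # 0 = source, 1 = target

emb = TSNE(n_components=2, init="pca", perplexity=30).fit_transform(feats)
plt.scatter(emb[:, 0], emb[:, 1], c=domain_labels, cmap="coolwarm", s=8)
plt.title("t-SNE of source vs. target features")
plt.savefig("tsne.png", dpi=200)
```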
Hello!
After reading your paper and code, I have a question about the following code in modeling.py. Can you please explain it?
```python
if posi_emb is not None:
    eps = 1e-10
    batch_size = key_layer.size(0)
    patch = key_layer
    ad_out, loss_ad = lossZoo.adv_local(patch[:, :, 1:], ad_net, is_source)
    entropy = - ad_out * torch.log2(ad_out + eps) - (1.0 - ad_out) * torch.log2(1.0 - ad_out + eps)  # standard entropy H
    entropy = torch.cat((torch.ones(batch_size, self.num_attention_heads, 1).to(hidden_states.device).float(), entropy), 2)  # [1, T(Kpatch)]
    trans_ability = entropy if self.vis else None  # [B, 12, 197]
    entropy = entropy.view(batch_size, self.num_attention_heads, 1, -1)
    attention_probs = torch.cat((attention_probs[:, :, 0, :].unsqueeze(2) * entropy, attention_probs[:, :, 1:, :]), 2)
```
In the last line, why do you multiply by entropy only at index 0 of dimension 2, not across the whole dimension?
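For what it's worth, a toy-shape sketch of how that last line reads (shapes assumed from a standard ViT-B/16: 197 tokens, 12 heads): attention_probs is [B, heads, 197, 197], and index 0 on dimension 2 is the query row of the [CLS] token, so only the [CLS] token's attention over all tokens is re-weighted by the entropy, while patch-to-patch rows pass through unchanged.

```python
# Toy-shape sketch: only the [CLS] query row (index 0 on dimension 2)
# is multiplied by the per-token entropy weights.
import torch

B, H, T = 2, 12, 197                                 # batch, heads, tokens (1 CLS + 196 patches)
attention_probs = torch.softmax(torch.randn(B, H, T, T), dim=-1)
entropy = torch.rand(B, H, 1, T)                     # transferability weights; CLS position set to 1.0

cls_row = attention_probs[:, :, 0, :].unsqueeze(2) * entropy   # re-weight CLS -> all tokens
reweighted = torch.cat((cls_row, attention_probs[:, :, 1:, :]), dim=2)
assert reweighted.shape == (B, H, T, T)              # patch-to-patch rows are untouched
```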
I ran some scripts from 'script.txt', but got results far from those reported in the paper.
Here are some of them:
Office-Home
Pr->Cl: 52.37
Cl->Pr: 78.64
Cl->Ar: 72.39
Office-31
A->D: 94.98
A->W: 94.21
D->W: 76.78
VisDA-2017
Train->Val: 80.78
P.S. apex didn't work due to an installation error, but I don't think that affects the results (it only affects GPU memory and gradient accumulation). I didn't find a resolution for this error: "fused_weight_gradient_mlp_cuda module not found. gradient accumulation fusion with weight gradient computation disabled."
Hi, sorry to bother you. I'm having trouble loading the model files saved during training. I can't load the files I trained and saved in the output folder, _checkpoint.bin and _checkpoint_adv.bin. I want to load the trained model to do some visualization experiments, but I am confused by the .bin files. I can load ViT-B_16.npz normally for attention-map visualization. Can you share how you loaded the .bin file for attention-map visualization?
Here is the code that saves the model in your train.py:
```python
def save_model(args, model, is_adv=False):
    model_to_save = model.module if hasattr(model, 'module') else model
    if not is_adv:
        model_checkpoint = os.path.join(args.output_dir, args.dataset, "%s_checkpoint.bin" % args.name)
    else:
        model_checkpoint = os.path.join(args.output_dir, args.dataset, "%s_checkpoint_adv.bin" % args.name)
    torch.save(model_to_save.state_dict(), model_checkpoint)
    logger.info("Saved model checkpoint to [DIR: %s]", os.path.join(args.output_dir, args.dataset))
```
I'm a newbie, so I'm very sorry to bother you with this question.
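Not an official answer, but since save_model above stores a plain state_dict, the .bin file should load back with load_state_dict (unlike ViT-B_16.npz, which contains raw numpy weights loaded through a separate path). A self-contained round-trip sketch with a stand-in model:

```python
# Round-trip sketch: save a state_dict the way save_model does, then restore it.
# nn.Linear is a stand-in; the real model must be built with the same
# architecture/config used at training time before calling load_state_dict.
import torch
import torch.nn as nn

net = nn.Linear(768, 65)
torch.save(net.state_dict(), "demo_checkpoint.bin")            # what save_model does

restored = nn.Linear(768, 65)
restored.load_state_dict(torch.load("demo_checkpoint.bin", map_location="cpu"))
restored.eval()                                                # ready for visualization passes
```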
Hi, thanks for releasing this open-source transformer-based UDA project! May I ask you some questions?
I tried to reproduce the Baseline result, which denotes vanilla ViT with adversarial adaptation. Differing from your setup, I fixed the batch size at 32 and the input size at 224 × 224. Accordingly, I got results on the Office-31 settings W→A and D→A.
[1] Safe Self-Refinement for Transformer-based Domain Adaptation (https://arxiv.org/pdf/2204.07683.pdf)
Obviously, [1] uses the timm library and gets a very strong baseline. Have you thought about switching to timm to improve performance? It's hard to compare because everyone implements the baseline differently and the results vary greatly. :sweat_smile:
Additionally, I had a problem training the baseline mentioned above.
Model performance rises and then deteriorates rapidly. The training parameters are: --beta 0.1 --gamma 0. --theta 0.
Have you ever encountered such a problem? Can you give me some advice?
Hello, I encountered something I didn't understand while reading the code. If you have time, could you point it out for me?
At line 99 of modeling.py, there is:
```python
if posi_emb is not None:
    eps = 1e-10
    batch_size = key_layer.size(0)
    patch = key_layer
    ad_out, loss_ad = lossZoo.adv_local(patch[:, :, 1:], ad_net, is_source)
    entropy = - ad_out * torch.log2(ad_out + eps) - (1.0 - ad_out) * torch.log2(1.0 - ad_out + eps)
    entropy = torch.cat((torch.ones(batch_size, self.num_attention_heads, 1).to(hidden_states.device).float(), entropy), 2)
    trans_ability = entropy if self.vis else None  # [B, 12, 197]
    entropy = entropy.view(batch_size, self.num_attention_heads, 1, -1)
    attention_probs = torch.cat((attention_probs[:, :, 0, :].unsqueeze(2) * entropy, attention_probs[:, :, 1:, :]), 2)
```
What I don't understand is: is this where the adversarial discrimination loss is applied? Which part of ViT did you improve here?
I tried to run your code but got the error in the title, and I found there is no file named 'modeling_resnet' (the code contains 'from .modeling_resnet import ResNetV2'). Is a file missing?
Thanks.
The error message is as follows:

```
test setup failed
args = <module 'main' from 'C:\Users\Hongyu\Desktop\TVT-main\main.py'>

    def setup(args):
        # Prepare model
>       config = CONFIGS[args.model_type]
E       AttributeError: module 'main' has no attribute 'model_type'

main.py:74: AttributeError
```
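A guess at the cause (my own diagnosis, not confirmed in the thread): this traceback looks like pytest collecting main.py. pytest's nose-compatibility layer treats a module-level function named setup as a setup hook and calls it with the module object, which would explain why args is <module 'main'> instead of parsed arguments. A minimal reproduction under that assumption:

```python
# Save as main.py and run `pytest main.py`: the module-level setup(args) is
# treated as a setup hook and called with the module object, which has no
# model_type attribute. Running `python main.py` with real argparse
# arguments avoids the collision entirely.
CONFIGS = {"ViT-B_16": "placeholder-config"}   # stand-in for the real config table

def setup(args):
    # Prepare model
    config = CONFIGS[args.model_type]          # AttributeError when args is a module
    return config
```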
Your work is great. What I want to ask is: how long does it take to train the model?
Hi.
Thank you for sharing the code.
I have a question about the total loss function used in this code.
As far as I know, most GAN codebases backpropagate the generator and discriminator losses separately. However, you backpropagate them at once by combining them into one loss. (As I understand it, loss_ad_local and loss_ad_global correspond to the discriminator and generator losses, respectively.)
Does backpropagating the combined loss have the same effect as backpropagating each loss separately?
Thank you.
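Not the authors, but on the math side: gradients are linear, so calling backward on a sum accumulates exactly the same gradients as calling backward on each term over the same parameters. The difference from typical GAN code is that GANs alternate two optimizers over disjoint parameters, whereas with a gradient reversal layer the sign flip is baked into the graph, so a single combined backward suffices. A quick numerical check:

```python
# Check that (l1 + l2).backward() matches l1.backward(); l2.backward().
import torch

x = torch.randn(5, requires_grad=True)
(x.pow(2).sum() + x.sin().sum()).backward()
g_combined = x.grad.clone()

x.grad = None
x.pow(2).sum().backward()
x.sin().sum().backward()   # gradients accumulate into x.grad
assert torch.allclose(g_combined, x.grad)
```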