mlfoundations / task_vectors
Editing Models with Task Arithmetic
Hi @gabrielilharco, when I download the ViT-L-14/Cars/finetuned.pt checkpoint from google drive and try to load the checkpoint, I am getting an error that the checkpoint is corrupted. All the other checkpoints work fine. I have tried downloading it multiple times and it is still not working. Would it be possible for you to reupload ViT-L-14/Cars/finetuned.pt checkpoint?
Hi @gabrielilharco, I was trying to run some experiments on the validation dataset. I learned from the other issues that I need to use [dataset]Val to select the validation split. When I try this, the code starts creating a new head for the validation split of the dataset. Is this the expected behavior? I thought a single classifier would work for all splits. Am I missing something here?
Hi @gabrielilharco,
thanks for the great work!
I have a question about the classification heads used in your experiments and available here. How exactly did you train them? Looking at the code, I can see that you manually construct a zero-shot classifier from embeddings of class names inserted into various templates. Importantly, you use the pretrained OpenCLIP model to compute the embeddings for the classifier, not the model finetuned for a particular task. Am I right about that?
What is the reason for that? To me, the most natural way to obtain the classifier would be to take model.classification_head after finetuning the task-specific model (here). That classification head is aligned with the finetuned model, while the zero-shot head is aligned with the pretrained model, so the head from finetuning seems more suitable. Did you consider such an approach?
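For concreteness, the construction being asked about (embed each class name in a set of prompt templates with the pretrained text encoder, average, normalize) can be sketched as follows. This is an illustrative sketch, not the repo's actual code: `TEMPLATES` and the `encode_text` callable are stand-ins for the repo's per-dataset template sets and OpenCLIP's text encoder.

```python
import numpy as np

# Illustrative prompt templates; the repo uses its own per-dataset template sets.
TEMPLATES = ["a photo of a {}.", "a low resolution photo of a {}."]

def build_zero_shot_head(class_names, encode_text):
    """Build a zero-shot classifier: one row per class.

    encode_text: callable mapping a prompt string to a 1-D embedding.
    Returns a (num_classes, embed_dim) weight matrix usable as a linear head.
    """
    rows = []
    for name in class_names:
        # Embed the class name under every template, normalize each embedding.
        embs = np.stack([encode_text(t.format(name)) for t in TEMPLATES])
        embs /= np.linalg.norm(embs, axis=-1, keepdims=True)
        # Average over templates, then renormalize the mean embedding.
        mean = embs.mean(axis=0)
        rows.append(mean / np.linalg.norm(mean))
    return np.stack(rows)
```

Because the rows come from the frozen text encoder, the same head works with any image encoder that shares the pretrained model's embedding space, which may be why a single zero-shot head is used rather than the per-task finetuned heads.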
When I try to load this checkpoint with torch.load, it raises:
RuntimeError: Invalid magic number; corrupt file?
The md5sum of the downloaded file is 2c3b6a4f39e1b62def9ccf7288099e64.
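For what it's worth, a quick way to tell a flaky transfer apart from a corrupted upload is to compare digests across downloads before calling torch.load; a stdlib-only sketch:

```python
import hashlib

def md5sum(path, chunk_size=1 << 20):
    """Return the MD5 hex digest of a file, read in 1 MiB chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

If the digest changes between two downloads of the same file, the transfer itself is unreliable; if it is stable but the file still fails to load, the uploaded artifact is likely the problem.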
Could you please update the corresponding code?
While trying to load the file using Python’s pickle module, I encountered an _pickle.UnpicklingError, stating that persistent IDs in protocol 0 must be ASCII strings. Here is the exact error message:
_pickle.UnpicklingError: persistent IDs in protocol 0 must be ASCII strings
I attempted to resolve the issue by employing various methods, including utilizing the persistent_load parameter with pickle.Unpickler and trying to load the file in different environments, but unfortunately, all efforts have been in vain.
Request for Assistance:
Given the circumstances, I was hoping you could provide some insights or guidance on the following points:
Creation Environment: Could you share details about the environment in which the file was created, including the Python and PyTorch versions used?
Persistent IDs: Any information or context regarding the persistent IDs encountered in the file would be immensely helpful.
Loading Method: If there is a specific method or procedure to correctly load the file, could you please share it with me?
Additional Details: Any other details or specifications about the file that you think might assist in resolving the issue would be greatly appreciated.
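For context on the error itself: in pickle protocol 0, out-of-band references are stored as ASCII persistent IDs, and loading such a stream requires a persistent_load hook. A self-contained sketch of the mechanism as described in the standard pickle documentation (the `EXTERNAL` store and its key names are illustrative only, not anything from this repo's files):

```python
import io
import pickle

# Stand-in for out-of-band storage keyed by ASCII persistent IDs.
EXTERNAL = {"tensor-0": [1.0, 2.0, 3.0]}

class RefPickler(pickle.Pickler):
    def persistent_id(self, obj):
        # Pickle only an ASCII key for externally stored objects.
        for key, value in EXTERNAL.items():
            if obj is value:
                return key
        return None  # everything else is pickled normally

class RefUnpickler(pickle.Unpickler):
    def persistent_load(self, pid):
        # Called once per persistent ID found in the stream.
        return EXTERNAL[pid]

buf = io.BytesIO()
RefPickler(buf, protocol=0).dump({"weights": EXTERNAL["tensor-0"]})
buf.seek(0)
restored = RefUnpickler(buf).load()  # would fail without persistent_load
```

The error message you saw suggests the stream contains a persistent ID that is not an ASCII string, which usually means the file was not written by plain pickle at all (for example, it may need torch.load rather than pickle.load, since PyTorch uses persistent IDs internally for tensor storage).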
Hi @gabrielilharco, I see in Appendix D.2 that you trained a multitask checkpoint on the eight vision tasks (SUN397, Cars, RESISC45, EuroSAT, SVHN, GTSRB, MNIST, DTD) for the learning-via-addition experiments, and that you report a multitask normalized performance of 99.4. Did you share the raw numbers for each task somewhere? If you could share these checkpoints for ViT-B-32 and ViT-L-14, that would be really helpful.
Thanks in advance,
Prateek
It's unclear to me what I should assign as the value for args.data_location. The README says args.data_location = '/path/to/data', but I'm not sure which folder that means.
Hi, awesome work!
I'm trying to reproduce your results but I cannot find the split definitions you use for DTD, EuroSAT and SUN397. Would you mind pointing me to the right resources to download the versions of these datasets compatible with your code?
Thanks a lot!
Hi @gabrielilharco ,
Thank you for your exciting work! I tried to replicate the results using the code from README.md. It showed an error when running:
# Create the task vector
task_vector = TaskVector(pretrained_checkpoint, finetuned_checkpoint)
The error is:
AttributeError: Can't get attribute 'VisualTransformer' on <module 'open_clip.model' from '/srv/home/<user_name>/anaconda3/envs/task-vectors/lib/python3.10/site-packages/open_clip/model.py'>
The error is from loading the model with trained weights: pretrained_state_dict = torch.load(pretrained_checkpoint).state_dict().
Could you help me with this? Thanks!
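In case it helps with debugging: the arithmetic behind a task vector is just an elementwise difference of parameters, which can be sketched over plain state-dict-like dicts (plain floats here instead of tensors; this mirrors, but is not, the repo's TaskVector class):

```python
def make_task_vector(pretrained, finetuned):
    """tau = theta_finetuned - theta_pretrained, computed per parameter key."""
    return {key: finetuned[key] - pretrained[key] for key in pretrained}

def apply_task_vector(pretrained, tau, scaling=1.0):
    """theta_new = theta_pretrained + scaling * tau (the paper's lambda)."""
    return {key: pretrained[key] + scaling * tau[key] for key in pretrained}
```

The AttributeError above indicates the checkpoints pickle whole model objects, so unpickling needs the exact VisualTransformer class from the open_clip version used at save time; pinning that open_clip version is likely the simplest fix, since only the state dicts are needed to form the task vector.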
Dear authors,
I am trying to reproduce your work based on this repo, and I have run into a problem. It seems that ImageNet is not downloaded automatically by your code. Which ImageNet did you use, ILSVRC 2012? And do any other changes need to be applied to the datasets?
Best regards,
Hongduan
I am trying to use task vectors for GPT2 models.
In the TaskVector class, when the task vector is created, there is a condition that ignores state-dict keys with dtype uint8. Because of this condition, when I call the apply_to() method of a task vector instance with a GPT-2 model checkpoint, I get the following warnings:
Warning: key transformer.h.0.attn.bias is present in the pretrained state dict but not in the task vector
Warning: key transformer.h.1.attn.bias is present in the pretrained state dict but not in the task vector
[... same warning for transformer.h.2 through transformer.h.10 ...]
Warning: key transformer.h.11.attn.bias is present in the pretrained state dict but not in the task vector
Is there a reason why state-dict keys with dtype torch.uint8 are ignored? When that condition is removed, the code runs without any errors. Please suggest what the best thing to do here would be.
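One plausible rationale for the condition: GPT-2's attn.bias entries are registered causal-mask buffers (uint8 in the transformers version that produced these warnings), not learned weights, so subtracting them is not meaningful arithmetic. A sketch of one way to handle this, using numpy arrays in place of torch tensors (again a sketch, not the repo's actual implementation): skip non-float entries when building the vector, then silently copy them from the base model when applying it.

```python
import numpy as np

def make_task_vector(pretrained, finetuned):
    """Difference only float parameters; skip integer buffers like attn.bias."""
    return {
        key: finetuned[key] - pretrained[key]
        for key in pretrained
        if np.issubdtype(pretrained[key].dtype, np.floating)
    }

def apply_task_vector(pretrained, tau, scaling=1.0):
    """Copy skipped buffers straight from the base model instead of warning."""
    return {
        key: pretrained[key] + scaling * tau[key] if key in tau else pretrained[key]
        for key in pretrained
    }
```

Since the causal mask is identical in the pretrained and finetuned models, carrying it over unchanged should be harmless, which matches your observation that removing the condition also runs without errors.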