
wise-ft's People

Contributors

gabrielilharco, mitchellnw, mmatena


wise-ft's Issues

OSError (undefined symbol) running wise_ft.py

Hi, I'm getting an OSError when trying to run the interpolation.

I want to interpolate ViT-L-14-336px.pt with my fine-tuned.pt model but can't solve this issue. Any ideas?

I ran the code below to create the env (no errors or warnings):

conda env create
conda activate wiseft

cd wise-ft
export PYTHONPATH="$PYTHONPATH:$PWD"

And the code to interpolate:

python wise_ft.py \
    --load=/home/user/.cache/clip/ViT-L-14-336px.pt,/home/user/model_checkpoint/ft_01_6ep_lr2e6.pt \
    --results-db=results.jsonl \
    --save=models/wiseft \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

error:

Traceback (most recent call last):
  File "wise_ft.py", line 5, in <module>
    import torch
  File "/home/user/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/__init__.py", line 189, in <module>
    _load_global_deps()
  File "/home/user/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/__init__.py", line 142, in _load_global_deps
    ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
  File "/home/user/miniconda3/envs/wiseft/lib/python3.6/ctypes/__init__.py", line 348, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: /home/user/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/lib/../../../../libcublas.so.11: undefined symbol: free_gemm_select, version libcublasLt.so.11
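An `undefined symbol ... version libcublasLt.so.11` error when loading `libcublas.so.11` often means the loader found a `libcublas` and a `libcublasLt` from different CUDA builds (for example, one from conda and one from a pip wheel). That is an assumed cause, not one the issue confirms. The minimal sketch below lists every cuBLAS shared object under an environment prefix so that duplicates in different directories stand out. `find_cublas_libs` is a hypothetical helper, and it deliberately does not import `torch`, because that import is the one that fails.

```python
from pathlib import Path


def find_cublas_libs(prefix):
    """Return sorted paths of libcublas*/libcublasLt* shared objects under prefix.

    Hypothetical diagnostic helper: it only lists files and never loads them.
    """
    root = Path(prefix)
    # "libcublas*.so*" matches both libcublas.so.11 and libcublasLt.so.11
    return sorted(str(p) for p in root.rglob("libcublas*.so*"))
```

Run it inside the activated environment, for example `print(*find_cublas_libs(sys.prefix), sep="\n")` after `import sys`. If the two libraries come from different locations or versions, that points to the mismatch.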

Replicating few-shot results

Table 7 of the paper shows that WiSE-FT with a linear classifier and the ViT-B/16 backbone reaches 73% accuracy on 16-shot ImageNet. The paper mentions a learning rate of 10e-5 and training for 10 epochs, but even with this information I cannot replicate the result. Could you share the exact command, or the additional hyperparameters (e.g. batch size, number of warmup steps), needed to reproduce it?
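A k-shot training set is usually built by drawing k labeled examples per class with a fixed seed. The sketch below shows one common way to do this. The paper's exact sampling procedure and seed are not given in this issue, so `sample_k_shot` and its seeding are assumptions, and a split built this way may still give numbers that differ from Table 7.

```python
import random
from collections import defaultdict


def sample_k_shot(samples, k, seed=0):
    """Pick k (path, label) pairs per class from `samples`.

    Hypothetical helper; raises ValueError if a class has fewer than k examples.
    """
    by_class = defaultdict(list)
    for path, label in samples:
        by_class[label].append((path, label))

    rng = random.Random(seed)  # local RNG so the split is reproducible
    subset = []
    for label in sorted(by_class):
        items = by_class[label]
        if len(items) < k:
            raise ValueError(f"class {label} has only {len(items)} examples")
        subset.extend(rng.sample(items, k))
    return subset
```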

Where to find pre-trained model weights

Hi, newbie here. I am trying to fine-tune this model, which was uploaded to Hugging Face: https://huggingface.co/laion/CLIP-ViT-L-14-laion2B-s32B-b82K

I want to fine-tune it on my custom dataset.

In the example below, the checkpoints to load are .pt files. Where can I find these checkpoints for the pre-trained model at the link above?

python src/wise_ft.py   \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --load=models/zeroshot.pt,models/finetuned.pt  \
    --results-db=results.jsonl  \
    --save=models/wiseft  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

Side question: why do I need to pass the finetuned.pt checkpoint for fine-tuning? Won't I be missing the fine-tuned weights before I start fine-tuning on my custom dataset?

Question about Table 2 in the paper

[image]

How can we get the results in the figure above? Do we need to design text prompts for each task and use them to initialize the classification head?

I tried adding a classification head with randomly initialized weights, but got poor results for WiSE-FT.
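For context, CLIP-style zero-shot heads are typically built by embedding every class name under every prompt template, normalizing each embedding, averaging over templates, and renormalizing. Each resulting vector becomes one row of the linear head. Whether Table 2 used exactly this procedure for every task is not stated here. The sketch below uses plain lists and a stand-in `encode_text` callable (hypothetical; in practice this would be a CLIP text encoder) so that the logic is visible.

```python
import math


def _normalize(vec):
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def build_zeroshot_head(classnames, templates, encode_text):
    """Return one unit-norm weight vector per class.

    `templates` are callables such as `lambda c: f"a photo of a {c}."`;
    `encode_text` maps a string to a list of floats (assumed stand-in).
    """
    head = []
    for name in classnames:
        embs = [_normalize(encode_text(t(name))) for t in templates]
        # average the normalized prompt embeddings, then renormalize
        mean = [sum(col) / len(embs) for col in zip(*embs)]
        head.append(_normalize(mean))
    return head


def class_scores(image_emb, head):
    """Cosine similarity of an image embedding with each class row."""
    img = _normalize(image_emb)
    return [sum(a * b for a, b in zip(img, row)) for row in head]
```

Compared with a randomly initialized head, a head built this way starts from the zero-shot model's predictions, which is what makes the zero-shot endpoint of the interpolation meaningful.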

Baseline curve for effective robustness for WILDS

Hi,

Thanks for the amazing work! Could you share some information on fitting the baseline curve, i.e., the list of standard ImageNet models' ID vs. OOD accuracies and the fitted coefficients (w and b) for FMoW and iWildCam? Thanks in advance!

-K
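The baseline in effective-robustness plots is usually a linear fit of OOD accuracy against ID accuracy after transforming both axes (logit scaling is a common choice). A model's effective robustness is then its OOD accuracy minus the baseline's prediction at its ID accuracy. Neither the transform nor the model list the authors used for FMoW and iWildCam is given here, so the sketch below assumes logit axes and ordinary least squares. `fit_baseline` and `effective_robustness` are hypothetical helpers.

```python
import math


def logit(p):
    return math.log(p / (1 - p))


def sigmoid(z):
    return 1 / (1 + math.exp(-z))


def fit_baseline(id_accs, ood_accs):
    """Least-squares fit of logit(ood) = w * logit(id) + b; returns (w, b)."""
    xs = [logit(a) for a in id_accs]
    ys = [logit(a) for a in ood_accs]
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    w = sxy / sxx
    return w, my - w * mx


def effective_robustness(id_acc, ood_acc, w, b):
    """OOD accuracy above what the baseline predicts at this ID accuracy."""
    return ood_acc - sigmoid(w * logit(id_acc) + b)
```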

Possibility to save ensemble as a full model?

If I understand the method correctly, it's a mix between model A and model B at some ratio C.

Instead of adding code to ensemble the weights in every downstream application, is it possible to ensemble once and save the premixed model?
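WiSE-FT interpolates weights: theta = (1 - alpha) * theta_zeroshot + alpha * theta_finetuned. The mixed model is therefore an ordinary set of weights that can be computed once and saved. The minimal sketch below assumes that both checkpoints expose state dicts with identical keys. `interpolate_state_dicts` is a hypothetical helper, not a function from the repository. It only relies on `*` and `+`, so it works on PyTorch tensors as well as on the plain floats used in the test.

```python
def interpolate_state_dicts(theta_0, theta_1, alpha):
    """Return (1 - alpha) * theta_0 + alpha * theta_1, key by key.

    theta_0: zero-shot weights, theta_1: fine-tuned weights (same keys).
    alpha = 0 gives the zero-shot model, alpha = 1 the fine-tuned one.
    """
    if set(theta_0) != set(theta_1):
        raise ValueError("checkpoints have different parameter names")
    return {
        key: (1 - alpha) * theta_0[key] + alpha * theta_1[key]
        for key in theta_0
    }
```

With PyTorch, you can load the result into a model of the same architecture with `model.load_state_dict(mixed)` and write it out once with `torch.save(model.state_dict(), path)`. Downstream code then loads it like any other checkpoint.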

How to find out exactly which labels are used to calculate the logits?

Hi, I have some questions about running your fine-tuning code.
In wiseft/src/models/finetune.py, line 83, the logits are calculated by logits = model(inputs), where inputs = batch[input_key].cuda(). I chose end-to-end fine-tuning, so input_key is image.
When I print the shape of each logit in logits, it is a 1000-dimensional tensor. Does that mean my image is compared against 1000 labels?
However, I have no idea how to find out exactly which labels are used to calculate the logits.
I traced it back to wise/src/models/modeling.py, line 72: my inputs are run through image_encoder, and the output (logits) is calculated by calling classification_head. However, I still don't know which labels are used in this process:

def forward(self, inputs):
    if self.process_images:
        inputs = self.image_encoder(inputs)
    outputs = self.classification_head(inputs)
    return outputs
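The classification head is a linear layer with one output per class, so logit i corresponds to the i-th entry of the class-name list the head was built from (for ImageNet, 1000 names in a fixed order). The minimal sketch below decodes predictions, assuming you have that ordered list (for example, a dataset's `classnames` attribute) and that the head was built by iterating over it in order. `top_k_labels` is a hypothetical helper that takes a plain list of scores, so call `.tolist()` on a tensor row first.

```python
def top_k_labels(scores, classnames, k=5):
    """Return the k (classname, score) pairs with the highest scores.

    scores[i] is assumed to be the logit for classnames[i].
    """
    if len(scores) != len(classnames):
        raise ValueError("expected one score per class name")
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(classnames[i], scores[i]) for i in order[:k]]
```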

Custom Dataset Class Usage

Hello, I am planning to fine-tune a classifier on my dataset and have created a class for it:

import os
from PIL import Image  # __getitem__ calls Image.open, so import Image itself
import torch
import numpy as np
import torchvision
from torchvision import transforms

# define class names
classnames = ['real', 'fake']
# Define the labels and their corresponding integer values
label_dict = {name: i for i, name in enumerate(classnames)}

class ImageFolderDataset(torch.utils.data.Dataset):
    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform

        # Initialize the lists to store the image paths and labels
        self.image_paths = []
        self.labels = []

        # Loop over the subfolders and their contents
        for label_name in classnames:
            label_path = os.path.join(self.folder_path, label_name)
            for filename in os.listdir(label_path):
                # Create the full path to the image file
                image_path = os.path.join(label_path, filename)
                # Add the image path and label to their respective lists
                self.image_paths.append(image_path)
                self.labels.append(label_dict[label_name])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the image from disk
        image_path = self.image_paths[idx]
        image = Image.open(image_path)
        if self.transform is not None:
            image = self.transform(image)

        # Retrieve the label for this image
        label = self.labels[idx]

        return image, label
    
class ForenSynths:
    def __init__(self, preprocess,
                 location=os.path.expanduser('~/ForenSynths/biggan'),
                 batch_size=128,
                 num_workers=16,
                 classnames=None):

        ################# training #################
        # ImageFolderDataset takes folder_path, not root
        self.train_dataset = ImageFolderDataset(folder_path=location, transform=preprocess)

        # finetune() reads dataset.train_loader, so use that attribute name
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )

        ################# testing #################
        # NOTE: this reads the same folder as the training split
        self.test_dataset = ImageFolderDataset(folder_path=location, transform=preprocess)

        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        # fall back to the module-level class names when none are passed
        self.classnames = classnames if classnames is not None else list(label_dict)
from src.templates.utils import append_proper_article

forensynths_template = [
    lambda c: f"a {c} photo.",
    lambda c: f"this is {c}.",
    lambda c: f"a {c} image is shown.",
    lambda c: f"a {c} image is displayed.",
    lambda c: f"The image presented is a {c} image.",
    lambda c: f"The image presented is {c}.",
    lambda c: f"The depicted image is {c}.",
    lambda c: f"A picture is showcased, which can be described as {c}.",
]

Based on the instructions, I should run the command like this:

python src/wise_ft.py   \
    --train-dataset=ForenSynths\
    --epochs=10  \
    --lr=0.00003  \
    --batch-size=32 \
    --cache-dir=cache  \
    --model=RN50  \
    --eval-datasets=ForenSynths  \
    --classnames= ['real', 'fake'] \
    --template=forensynths_template  \
    --results-db=results.jsonl  \
    --save=models/wiseft/ForenSynths\
    --data-location=~/ForenSynths/biggan\
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

About fine-tuning

Hi, good work here. I am following the steps to fine-tune CLIP. I downloaded two of the datasets used in your example and simplified the script like this:

python src/wise_ft.py \
    --train-dataset=ImageNetR \
    --epochs=10 \
    --lr=0.00003 \
    --batch-size=32 \
    --cache-dir=cache \
    --model=ViT-B/32 \
    --eval-datasets=ImageNetR,ImageNetA \
    --template=openai_imagenet_template \
    --results-db=results.jsonl \
    --save=models/wiseft/ViTB32 \
    --data-location=~/data \
    --alpha 0 0.5 0.9

Then I got the following error. I checked the code and found that the dataset has no train_loader attribute. Is that because the code has been updated, or is something else wrong? Could you give me some hints? Thanks.

Traceback (most recent call last):
  File "/Users/happymind/local_dev/wise-ft/src/wise_ft.py", line 104, in <module>
    wise_ft(args)
  File "/Users/happymind/local_dev/wise-ft/src/wise_ft.py", line 61, in wise_ft
    finetuned_checkpoint = finetune(args)
                           ^^^^^^^^^^^^^^
  File "/Users/happymind/local_dev/wise-ft/src/models/finetune.py", line 50, in finetune
    num_batches = len(dataset.train_loader)
                      ^^^^^^^^^^^^^^^^^^^^
AttributeError: 'ImageNetR' object has no attribute 'train_loader'. Did you mean: 'test_loader'?
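The traceback shows that `finetune` calls `len(dataset.train_loader)`, while `ImageNetR` only offers `test_loader`. This suggests that ImageNetR is set up as an evaluation-only dataset and cannot be passed to `--train-dataset` as-is. A small guard like the one below (hypothetical, not part of the repository) would turn this into a clearer error. The two dummy classes only mimic the attribute layout seen in the traceback.

```python
def require_train_loader(dataset):
    """Return dataset.train_loader, or raise a descriptive error if it is missing."""
    if not hasattr(dataset, "train_loader"):
        name = type(dataset).__name__
        raise ValueError(
            f"{name} has no train_loader; it looks like an evaluation-only "
            "dataset, so pick a dataset with a training split for --train-dataset"
        )
    return dataset.train_loader


class EvalOnly:  # mimics ImageNetR in the traceback: only test_loader
    test_loader = []


class Trainable:  # dummy dataset with both splits
    train_loader = [1, 2, 3]
    test_loader = []
```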

Zero Shot Classification on my own Dataset

Hello,

I am trying to fine-tune CLIP on my own dataset for zero-shot classification.
My question is: is there a way to load a CSV containing all the file paths and their corresponding labels, or a folder that contains all the images in per-class subfolders?
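For the folder layout, `torchvision.datasets.ImageFolder` already maps one subfolder per class to an integer label. For a CSV, a small reader is enough. The sketch below assumes a header row with `path` and `label` columns (hypothetical column names) and returns the samples plus a sorted class-name list. A `torch.utils.data.Dataset` could then wrap these samples in the same way as the `ImageFolderDataset` in the custom dataset issue. `read_labeled_csv` is a hypothetical helper.

```python
import csv


def read_labeled_csv(csv_path):
    """Read (path, label_index) pairs from a CSV with 'path' and 'label' columns.

    Returns (samples, classnames); label indices follow the sorted class names.
    """
    with open(csv_path, newline="") as f:
        rows = list(csv.DictReader(f))
    classnames = sorted({row["label"] for row in rows})
    index = {name: i for i, name in enumerate(classnames)}
    samples = [(row["path"], index[row["label"]]) for row in rows]
    return samples, classnames
```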

Does fine-tuning only tweak the image encoder?

First of all, thanks for sharing the codebase.
I briefly went through the code, and it seems that you only fine-tune the image encoder. Is that right? If so, have you tried fine-tuning both the image and text encoders?

Finetuning configs for more models

Hi, dear authors.
In this code, you have provided an example for fine-tuning ViT-B/32:

python src/wise_ft.py   \
    --train-dataset=ImageNet  \
    --epochs=10  \
    --lr=0.00003  \
    --batch-size=512  \
    --cache-dir=cache  \
    --model=ViT-B/32  \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --template=openai_imagenet_template  \
    --results-db=results.jsonl  \
    --save=models/wiseft/ViTB32  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

Running it, I get the following final WiSE-FT results at alpha = 0.5:

ImageNet Top-1 accuracy: 0.7554
ImageNetR Top-1 accuracy: 0.7145
ImageNetA Top-1 accuracy: 0.3452
ImageNetSketch Top-1 accuracy: 0.4696
  • Are these results consistent with yours? I could not find official results for ViT-B/32 in the paper, so I want to make sure I ran the code correctly.
  • What hyperparameter configurations should be used for other models, such as ViT-L, ViT-B, etc.?

Training Parameters

Hello

Could you explain what Data(t) and Batch(t) mean when training from scratch with ViT-B/32?

[image]

Fine-tune on your own dataset

Hi,

I was wondering where to start if I want to use this to fine-tune CLIP on my own dataset (a dataset of sketch-text pairs).

Poor performance on ResNet.

Although fine-tuning the ViT models gives good performance, I get poor performance with the ResNet models. How should the CLIP model be fine-tuned when starting from the pre-trained ResNet models? Thanks.

zero-shot model

Hi,

I would like to apply the WiSE-FT method to other tasks or pretrained models (e.g., BERT, GPT). In this context, is the so-called zero-shot model simply the original model without fine-tuning, and are the zero-shot model parameters just the directly loaded pretrained parameters?

Thank you!

ModuleNotFoundError when running wise-ft.py on google colab

Hi, I'm getting a ModuleNotFoundError when running wise-ft on Google Colab. I tried many solutions I found on the internet, but none of them work.
After cloning your repo, I ran this code:

!python src/wise_ft.py   \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --load=models/zeroshot.pt,models/finetuned.pt  \
    --results-db=results.jsonl  \
    --save=models/wiseft  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

And got this error:

Traceback (most recent call last):
  File "/content/wise-ft/src/wise_ft.py", line 7, in <module>
    from src.models.eval import evaluate
ModuleNotFoundError: No module named 'src'
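`No module named 'src'` means the repository root is not on Python's import path, so `from src.models.eval import evaluate` cannot resolve. The environment setup in the first issue handles this with `export PYTHONPATH="$PYTHONPATH:$PWD"` run from inside the repository. In Colab, however, each `!` command runs in its own shell, so a `!export` does not carry over to the next command. The minimal sketch below instead launches the script with the repository root prepended to PYTHONPATH. The `/content/wise-ft` path comes from the traceback, and `env_with_repo_on_path` and `run_wise_ft` are hypothetical helpers.

```python
import os
import subprocess
import sys


def env_with_repo_on_path(repo_root, base_env=None):
    """Copy of the environment with repo_root prepended to PYTHONPATH."""
    env = dict(os.environ if base_env is None else base_env)
    parts = [repo_root]
    if env.get("PYTHONPATH"):
        parts.append(env["PYTHONPATH"])
    env["PYTHONPATH"] = os.pathsep.join(parts)
    return env


def run_wise_ft(repo_root, args):
    """Run src/wise_ft.py from repo_root so that `import src...` resolves."""
    cmd = [sys.executable, os.path.join("src", "wise_ft.py"), *args]
    return subprocess.run(
        cmd, cwd=repo_root, env=env_with_repo_on_path(repo_root), check=True
    )
```

For example, call `run_wise_ft("/content/wise-ft", [...])` with the same flags as the command above. Setting the variable once with the IPython magic `%env PYTHONPATH=/content/wise-ft` before the `!python` line should have the same effect.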
