MOKD: Cross-domain Finetuning for Few-shot Classification via Maximizing Optimized Kernel Dependence

This repository contains the source code for reproducing the results of the ICML 2024 paper:

MOKD: Cross-domain Finetuning for Few-shot Classification via Maximizing Optimized Kernel Dependence.

Authors: Hongduan Tian, Feng Liu, Tongliang Liu, Bo Du, Yiu-ming Cheung, Bo Han.

Introduction

Current work on cross-domain few-shot classification mainly focuses on adapting a simple transformation head on top of a frozen pretrained backbone (e.g., ResNet-18) by optimizing the nearest centroid classifier loss (a.k.a. NCC-based loss). However, an undesirable phenomenon is observed during the adaptation phase: representations of samples from different classes remain highly similar. This high similarity induces uncertainty and can further lead to misclassification of data samples.
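
As a reference point, the NCC-based loss can be sketched in a few lines of PyTorch. The snippet below is only a minimal illustration (cosine similarity to class prototypes with a fixed temperature); the exact similarity measure and temperature used in URL-style adaptation may differ, and all names here are hypothetical.

import torch
import torch.nn.functional as F

def ncc_loss(support_feats, support_labels, query_feats, query_labels, temperature=0.1):
    # Class prototypes: mean of the support embeddings for each class.
    # Assumes episode labels are integers 0..C-1.
    num_classes = int(support_labels.max().item()) + 1
    prototypes = torch.stack(
        [support_feats[support_labels == c].mean(dim=0) for c in range(num_classes)]
    )
    # Logits: cosine similarity between query embeddings and the prototypes.
    logits = F.normalize(query_feats, dim=-1) @ F.normalize(prototypes, dim=-1).t()
    return F.cross_entropy(logits / temperature, query_labels)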

Figure: Heat map of the similarity between support data representations on Omniglot.

To solve this problem, we propose a bi-level optimization framework, maximizing optimized kernel dependence (MOKD), which learns a set of class-specific representations that match the cluster structures indicated by the label information. Specifically, MOKD first optimizes the kernel used in the Hilbert-Schmidt Independence Criterion (HSIC) to obtain an optimized-kernel HSIC whose test power is maximized for precise detection of dependence. Then, the optimized-kernel HSIC is further optimized to simultaneously maximize the dependence between representations and labels while minimizing the dependence among all samples.
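
The dependence measure at the core of MOKD is HSIC. Below is a minimal PyTorch sketch of a biased HSIC estimator with a Gaussian (RBF) kernel and an objective of the kind described above; the fixed bandwidth sigma, the trade-off weight gamma, and the function names are illustrative assumptions, and the actual implementation in hsic_loss.py (including the test-power-based kernel optimization) may differ.

import torch

def rbf_kernel(x, sigma=1.0):
    # Gaussian kernel matrix; sigma is the kind of kernel parameter MOKD optimizes
    # via test power maximization -- here it is simply fixed for illustration.
    d2 = torch.cdist(x, x).pow(2)
    return torch.exp(-d2 / (2 * sigma ** 2))

def hsic(K, L):
    # Biased HSIC estimator: trace(K H L H) / (n - 1)^2, with centering matrix H.
    n = K.size(0)
    H = torch.eye(n, device=K.device) - torch.ones(n, n, device=K.device) / n
    return torch.trace(K @ H @ L @ H) / (n - 1) ** 2

def mokd_style_objective(features, one_hot_labels, sigma=1.0, gamma=0.1):
    # Maximize dependence between representations and labels while penalizing
    # dependence among all samples; gamma is a hypothetical trade-off weight.
    K_z = rbf_kernel(features, sigma)
    K_y = one_hot_labels @ one_hot_labels.t()  # linear kernel on the labels
    return -hsic(K_z, K_y) + gamma * hsic(K_z, K_z)

In MOKD, the kernel parameters are first optimized to maximize the test power of HSIC, and only then is an objective of this form minimized with respect to the transformation head on top of the frozen backbone.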

Preparation

Dependencies

In our experiments, the main dependencies are the following libraries:

Python 3.6 or greater (Ours: Python 3.8)
PyTorch 1.0 or greater (Ours: torch=1.7.1, torchvision=0.8.2)
TensorFlow 1.14 or greater (Ours: TensorFlow=2.10)
tqdm (Ours: 4.64.1)
tabulate (0.8.10)

Dataset

  • Follow the Meta-Dataset repository to prepare the ILSVRC_2012, Omniglot, Aircraft, CU_Birds, Textures (DTD), Quick Draw, Fungi, VGG_Flower, Traffic_Sign and MSCOCO datasets.

  • Follow the CNAPs repository to prepare the MNIST, CIFAR-10 and CIFAR-100 datasets.

Backbone Pretraining

In this paper, we follow URL and use ResNet-18 as the frozen backbone in all experiments. For reproduction, two options are provided:

Train your own backbone. You can train the ResNet-18 backbone from scratch. The pretraining contains two phases: domain-specific pretraining and universal backbone distillation.

To train the single-domain learning backbones (on the 8 seen domains), run:

./scripts/train_resnet18_sdl.sh

Then, distill the model by running:

./scripts/train_resnet18_url.sh

Use the released backbones. The URL repository has released both the universal backbone and the single-domain backbones. For simplicity, you can directly use the released models.

To download the pretrained backbones, you can use gdown (installed via pip install gdown) and execute the following commands in the root directory of this project:

gdown https://drive.google.com/uc?id=1MvUcvQ8OQtoOk1MIiJmK6_G8p4h8cbY9 && md5sum sdl.zip && unzip sdl.zip -d ./saved_results/ && rm sdl.zip  # Single-domain backbones
gdown https://drive.google.com/uc?id=1Dv8TX6iQ-BE2NMpfd0sQmH2q4mShmo1A && md5sum url.zip && unzip url.zip -d ./saved_results/ && rm url.zip  # Universal backbone

These commands download the backbone weights into ./saved_results. If you download the weights manually, please create the ./saved_results directory and place the backbone weights in it.

Evaluate MOKD

To evaluate MOKD, you can run:

./scripts/test_hsic_pa.sh

Specifically, the running command is:

python hsic_loss.py --model.name=url \
                    --model.dir ./url \
                    --data.imgsize=84 \
                    --seed=41 \
                    --test_size=600 \
                    --kernel.type=rbf \
                    --epsilon=1e-5 \
                    --test.type=standard \
                    --experiment_name=mokd_seed41

The hyperparameters can be modified for different experiments:

  • model.name ['sdl', 'url']: sdl means using the single-domain backbone; url means using the universal backbone.
  • model.dir: Path to the backbone weights.
  • seed: The random seed. All our results are averaged over seeds 41-45.
  • kernel.type ['linear', 'rbf', 'imq']: Select different kernels to run MOKD.
  • test.type ['standard', '5shot', '1shot']: Different task modes. standard means vary-way vary-shot tasks; 5shot means vary-way 5-shot tasks; 1shot means 5-way 1-shot tasks.
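
For example, to evaluate vary-way 5-shot tasks with a linear kernel under a different seed, one could run a command like the following (only the flag values documented above are changed; the seed and experiment name are illustrative):

python hsic_loss.py --model.name=url \
                    --model.dir ./url \
                    --data.imgsize=84 \
                    --seed=42 \
                    --test_size=600 \
                    --kernel.type=linear \
                    --epsilon=1e-5 \
                    --test.type=5shot \
                    --experiment_name=mokd_linear_5shot_seed42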

Evaluate Pre-classifier Alignment (PA)

To evaluate Pre-classifier Alignment (PA), the adaptation method used in URL, run:

./scripts/test_resnet18_pa.sh

Acknowledgement

This repository is built mainly upon the following repositories:

[1] Li et al. Universal representation learning from multiple domains for few-shot classification, ICCV 2021.

[2] Triantafillou et al. Meta-dataset: A dataset of datasets for learning to learn from few examples, ICLR 2020.

Citation

@inproceedings{tian2024mokd,
    title={MOKD: Cross-domain Finetuning for Few-shot Classification via Maximizing Optimized Kernel Dependence},
    author={Hongduan Tian and Feng Liu and Tongliang Liu and Bo Du and Yiu-ming Cheung and Bo Han},
    booktitle={International Conference on Machine Learning (ICML)},
    year={2024}
}
