aws-cv-task2vec's Introduction

Task2Vec

This is an implementation of the Task2Vec method described in the paper Task2Vec: Task Embedding for Meta-Learning.

Task2Vec provides vectorial representations of learning tasks (datasets) which can be used to reason about the nature of those tasks and their relations. In particular, it provides a fixed-dimensional embedding of the task that is independent of details such as the number of classes and does not require any understanding of the class label semantics. The distance between embeddings matches our intuition about semantic and taxonomic relations between different visual tasks (e.g., tasks based on classifying different types of plants are similar). The resulting vector can be used to represent a dataset in meta-learning applications and makes it possible, for example, to select the best feature extractor for a task without an expensive brute-force search.

Quick start

To compute an embedding using task2vec, you just need to provide a dataset and a probe network, for example:

from task2vec import Task2Vec
from models import get_model
from datasets import get_dataset

dataset = get_dataset('cifar10')
probe_network = get_model('resnet34', pretrained=True, num_classes=10)
embedding = Task2Vec(probe_network).embed(dataset)

Task2Vec uses the diagonal of the Fisher Information Matrix to compute an embedding of the task. In this implementation we provide two methods, montecarlo and variational. The first is the fastest and is the default, but variational may be more robust in some situations (in particular it is the one used in the paper). You can try it using:

task2vec.embed(dataset, probe_network, method='variational')
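
As a rough illustration of what a Monte Carlo estimate of the diagonal Fisher Information looks like, the sketch below works through a toy logistic-regression model in plain Python. It is not the code in this repository: the toy model, the function names fisher_diagonal_mc and fisher_diagonal_exact, and the n_draws parameter are assumptions made for this example. As in the standard definition of the Fisher Information, labels are sampled from the model's own predictive distribution rather than read from the dataset.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def fisher_diagonal_mc(weights, inputs, n_draws=200, seed=0):
    """Monte Carlo estimate of the diagonal Fisher for p(y=1|x) = sigmoid(w . x).

    Hypothetical helper for illustration; not part of the task2vec package.
    """
    rng = random.Random(seed)
    acc = [0.0] * len(weights)
    count = 0
    for x in inputs:
        p = sigmoid(sum(w * xi for w, xi in zip(weights, x)))
        for _ in range(n_draws):
            # Sample the label from the model, not from the dataset.
            y = 1.0 if rng.random() < p else 0.0
            # The gradient of log p(y|x) with respect to w is (y - p) * x.
            for i, xi in enumerate(x):
                acc[i] += ((y - p) * xi) ** 2
            count += 1
    # Average of squared gradients, one entry per parameter.
    return [a / count for a in acc]


def fisher_diagonal_exact(weights, inputs):
    """Closed form for this toy model: mean over x of p * (1 - p) * x_i ** 2."""
    acc = [0.0] * len(weights)
    for x in inputs:
        p = sigmoid(sum(w * xi for w, xi in zip(weights, x)))
        for i, xi in enumerate(x):
            acc[i] += p * (1.0 - p) * xi ** 2
    return [a / len(inputs) for a in acc]
```

Each entry is an average of squared gradients, so the estimate is non-negative by construction, and with enough draws it approaches the closed-form value for this toy model.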

Now, let's try computing several embeddings and plotting the distance matrix between the tasks:

from task2vec import Task2Vec
from models import get_model
import datasets
import task_similarity

dataset_names = ('mnist', 'cifar10', 'cifar100', 'letters', 'kmnist')
dataset_list = [datasets.__dict__[name]('./data')[0] for name in dataset_names] 

embeddings = []
for name, dataset in zip(dataset_names, dataset_list):
    print(f"Embedding {name}")
    probe_network = get_model('resnet34', pretrained=True, num_classes=int(max(dataset.targets)+1)).cuda()
    embeddings.append(Task2Vec(probe_network, max_samples=1000, skip_layers=6).embed(dataset))
task_similarity.plot_distance_matrix(embeddings, dataset_names)

You can also look at the notebook small_datasets_example.ipynb for a runnable implementation of this code snippet.

Experiments on iNaturalist and CUB

Downloading the data

First, decide where you will store all the data. For example:

export DATA_ROOT=./data

To download CUB-200, from the repository root run:

./scripts/download_cub.sh $DATA_ROOT

To download iNaturalist 2018, from the repository root run:

./scripts/download_inat2018.sh $DATA_ROOT

WARNING: iNaturalist needs ~319 GB of disk space for download and extraction. Consider downloading and extracting it manually following the instructions here.

Computing the embedding of all tasks

To compute the embedding on a single task of CUB + iNat2018, run:

python main.py task2vec.method=montecarlo dataset.root=$DATA_ROOT dataset.name=cub_inat2018  dataset.task_id=$TASK_ID -m

This will use the montecarlo Fisher approximation to compute the embedding of task number $TASK_ID in the CUB + iNat2018 meta-task. The result is stored in a pickle file inside outputs.

To compute all embeddings at once, we can use Hydra's multi-run mode as follows:

python main.py task2vec.method=montecarlo dataset.root=$DATA_ROOT dataset.name=cub_inat2018  dataset.task_id=`seq -s , 0 50` -m

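This will compute the embeddings of tasks 0 through 50 in the CUB + iNat2018 meta-task. Note that the seq command above is inclusive at both ends, so it produces 51 task IDs (0 to 50); pass 0 49 as its arguments for exactly the first 50 tasks. To plot the distance matrix between these tasks, first download all the iconic_taxa image files to ./static/iconic_taxa, and then run: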

python plot_distance_cub_inat.py --data-root $DATA_ROOT ./multirun/montecarlo

The result should look like the following. Note that tasks involving classification of similar life forms (e.g., different types of birds, plants, mammals) cluster together.

Figure: task2vec distance matrix
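
For the multi-run sweep in this section, the comma-separated list of task IDs can also be built in Python, which makes the number of tasks explicit. This is only a convenience sketch; the helper name task_id_sweep is an assumption and is not part of the repository.

```python
def task_id_sweep(num_tasks, start=0):
    """Build a comma-separated list of task IDs for a multi-run override.

    Hypothetical helper: range() excludes its upper bound, so exactly
    num_tasks IDs are produced.
    """
    return ",".join(str(i) for i in range(start, start + num_tasks))


# The string can be passed as the value of dataset.task_id on the command line.
sweep = task_id_sweep(50)
```

Unlike seq with arguments 0 and 50, which includes both endpoints, this produces IDs 0 through 49.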

aws-cv-task2vec's People

Contributors

alexachille, amazon-auto, orchidmajumder

aws-cv-task2vec's Issues

what would a task2vec (cosine) distance of -1 mean?

I feel -1 would mean the task representations/embeddings point in opposite directions, even though the actual vectors are otherwise the same. So are the tasks completely similar up to a single bit of information (the sign), or is it better to interpret this as the two tasks being the same and ignore the direction?

In general I'm trying to figure out how to interpret the negative sign.

I suppose 0 means perpendicular tasks and 1 means the exact same vector, sign included.
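
One way to reason about the question above is to separate cosine similarity, which ranges over [-1, 1], from cosine distance defined as 1 minus the similarity, which ranges over [0, 2]. The sketch below uses plain Python and is only an illustration; which convention the repository's own distance function follows is not shown here, so treat the definitions as assumptions.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between u and v; lies in [-1, 1]."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def cosine_distance(u, v):
    """Assumed convention: 1 - cosine similarity, which lies in [0, 2]."""
    return 1.0 - cosine_similarity(u, v)
```

If an embedding is a diagonal Fisher estimate, as the README describes, its entries are averages of squared gradients and therefore non-negative. For two non-negative vectors the similarity cannot drop below 0, so a value of -1 would suggest that a sign was introduced somewhere outside the embedding itself.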

Missing txt files

Hi, it seems like there are missing txt files containing taxonomical information. Namely taxonomy.txt as referenced in dataset/cub.py and passeriformes.txt as in datasets.py. Can you point me to them or provide them in the repo? Thank you!

speeding up FIM computation

Is there a way to speed up FIM computation?

Would it be OK to not go through the entire loop here? E.g., in the extreme case, doing just one iteration:

for k in range(epochs):
    logging.info(f"\tepoch {k + 1}/{epochs}")
    for i, (data, target) in enumerate(tqdm(data_loader, leave=False, desc="Computing Fisher")):
        data = data.to(device)
        output = self.model(data, start_from=self.skip_layers)
        # The gradients used to compute the FIM need to be for y sampled from
        # the model distribution y ~ p_w(y|x), not for y from the dataset
        if self.bernoulli:
            target = torch.bernoulli(torch.sigmoid(output)).detach()
        else:
            target = torch.multinomial(F.softmax(output, dim=-1), 1).detach().view(-1)
        loss = self.loss_fn(output, target)
        self.model.zero_grad()
        loss.backward()
        # Accumulate squared gradients for every parameter that received one
        for p in self.model.parameters():
            if p.grad is not None:
                p.grad2_acc += p.grad.data ** 2
                p.grad_counter += 1
        break  # stop after the first batch: for debugging faster, otherwise FIM is really slow
    break  # stop after the first epoch: for debugging faster, otherwise FIM is really slow

or are there other better ideas?
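
One alternative to breaking out of the loop after the first batch is to cap the number of samples up front, which is what the max_samples argument in the README examples appears to be for. The sketch below shows the general idea in plain Python by drawing a uniform random subset of dataset indices. The helper name subsample_indices is an assumption and this is not the repository's implementation.

```python
import random


def subsample_indices(dataset_size, max_samples, seed=0):
    """Return sorted indices of a uniform random subset of the dataset.

    Hypothetical helper: if max_samples is None or at least the dataset
    size, every index is kept.
    """
    if max_samples is None or max_samples >= dataset_size:
        return list(range(dataset_size))
    rng = random.Random(seed)
    return sorted(rng.sample(range(dataset_size), max_samples))


# Example: keep 1000 of 60000 items before computing the Fisher estimate.
indices = subsample_indices(60000, 1000)
```

In PyTorch, indices like these could be wrapped with torch.utils.data.Subset(dataset, indices). Compared with stopping after one batch, a random subset draws from the whole dataset, at the cost of a noisier estimate than a full pass.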

citation quote on the readme would help cite repo/project

e.g. from other project

@incollection{NIPS2017_7188,
title = {SVCCA: Singular Vector Canonical Correlation Analysis for Deep Learning Dynamics and Interpretability},
author = {Raghu, Maithra and Gilmer, Justin and Yosinski, Jason and Sohl-Dickstein, Jascha},
booktitle = {Advances in Neural Information Processing Systems 30},
editor = {I. Guyon and U. V. Luxburg and S. Bengio and H. Wallach and R. Fergus and S. Vishwanathan and R. Garnett},
pages = {6076--6085},
year = {2017},
publisher = {Curran Associates, Inc.},
url = {http://papers.nips.cc/paper/7188-svcca-singular-vector-canonical-correlation-analysis-for-deep-learning-dynamics-and-interpretability.pdf}
}

task2vec complexity using normalized/standardized values

I saw the task2vec complexity computed using the L1 norm. But the L1 complexities might not be comparable across data sets. How would you suggest adjusting this metric so that it is directly comparable across data sets?

My suggestion is to divide by the standard deviation of the L1 complexity for that data set, e.g.:

standardized_complexity = avg_complexity(task2vecs_list, l1) / unbiased_std_complexity(task2vecs_list, l1)

thoughts? This does assume normality and n>=30. Histograms showing normality might be useful.


related: https://stats.stackexchange.com/questions/604296/how-does-one-create-comparable-metrics-when-the-original-metrics-are-not-compara
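
The proposal above can be sketched in plain Python as follows. The names avg_complexity and unbiased_std_complexity come from the issue text and are not known functions of this repository, so the helpers below (l1_norm, standardized_complexity) are assumptions made for illustration. The sample standard deviation (n - 1 denominator) is used as the unbiased-variance estimate.

```python
import statistics


def l1_norm(vec):
    """Sum of absolute values of the entries."""
    return sum(abs(v) for v in vec)


def standardized_complexity(embeddings):
    """Mean L1 norm divided by the sample standard deviation of the L1 norms.

    Hypothetical helper: embeddings is a list of equal-length numeric
    vectors, one per task drawn from the same data set.
    """
    norms = [l1_norm(e) for e in embeddings]
    if len(norms) < 2:
        raise ValueError("need at least two embeddings to estimate a standard deviation")
    spread = statistics.stdev(norms)  # uses the n - 1 denominator
    if spread == 0.0:
        raise ValueError("all L1 norms are identical; the ratio is undefined")
    return statistics.mean(norms) / spread
```

Note that this ratio is the reciprocal of the coefficient of variation, so it measures how large the mean complexity is relative to its spread rather than the complexity itself.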

Failed to Get Trivial Embedding

When I try to extract a trivial embedding with the provided function task_similarity.get_trivial_embedding_from(e) where e is an embedding, I get an error message TypeError: 'Embedding' object is not subscriptable. I notice that in task2vec.py, Embedding is a class and does not have keys such as 'layers', which is why I cannot extract the trivial embedding.

My question is:

  1. Is my input of an 'Embedding' object to the function correct?
  2. Is there a way to work around this?

Thanks.
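
For context, the error above is what Python raises when an object that stores its data as attributes is indexed like a dictionary. The sketch below reproduces that with a stand-in class and shows one possible accessor that accepts both styles. The Embedding class here is a hypothetical stand-in with a single hessian attribute, not the class from task2vec.py, and get_field is an assumed helper name.

```python
class Embedding:
    """Hypothetical stand-in that stores its data as attributes."""

    def __init__(self, hessian):
        self.hessian = hessian


def get_field(embedding, name):
    """Read a field from either a dict-style or an attribute-style embedding."""
    if isinstance(embedding, dict):
        return embedding[name]
    return getattr(embedding, name)


e = Embedding([1.0, 2.0])
values = get_field(e, "hessian")  # works, whereas e["hessian"] raises TypeError
```

This only explains the TypeError; whether the function in question can be adapted this way depends on which fields it expects, which is not shown here.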

cannot compute distance matrix for only mnist

I was trying to compute the distance matrix for only mnist but I get the error:

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
Embedding mnist
Traceback (most recent call last):
  File "/Users/brando/anaconda3/envs/metalearning/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/Users/brando/anaconda3/envs/metalearning/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'mnist.<locals>.<lambda>'

is your code supposed to run out of the box?

code:

# %%
from pathlib import Path

import torch

from task2vec import Task2Vec
from models import get_model
import datasets
import task_similarity

# %%

# dataset_names = ('stl10', 'mnist', 'cifar10', 'cifar100', 'letters', 'kmnist')
dataset_names = ('mnist', )
# Change `root` with the directory you want to use to download the datasets
dataset_list = [datasets.__dict__[name](root=Path('~/data').expanduser())[0] for name in dataset_names]

# %%

device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")

embeddings = []
for name, dataset in zip(dataset_names, dataset_list):
    print(f"Embedding {name}")
    probe_network = get_model('resnet34', pretrained=True, num_classes=int(max(dataset.targets) + 1)).to(device)
    embeddings.append(Task2Vec(probe_network, max_samples=1000, skip_layers=6).embed(dataset))
    # embeddings.append(Task2Vec(probe_network, max_samples=100, skip_layers=6).embed(dataset))

# %%

task_similarity.plot_distance_matrix(embeddings, dataset_names)
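
A note on the traceback above: it comes from multiprocessing's spawn start method (popen_spawn_posix), which has to pickle the objects it sends to a new process, and a lambda defined inside a function cannot be pickled. The sketch below reproduces that behaviour with the standard library only. The dictionary-based "dataset" and all function names are assumptions for illustration, not the repository's code.

```python
import pickle


def make_dataset_with_lambda():
    # A lambda created inside a function is a local object and cannot be pickled.
    return {"target_transform": lambda y: int(y)}


def to_int(y):
    """Module-level replacement for the lambda; picklable by reference."""
    return int(y)


def make_picklable_dataset():
    return {"target_transform": to_int}


def can_pickle(obj):
    """Return True if obj survives pickle.dumps, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
```

Two workarounds usually apply in this situation: replace the local lambda with a module-level function, or avoid worker processes altogether (for a PyTorch DataLoader, num_workers=0). Whether Task2Vec exposes the loader options needed for the second route is not shown in this thread.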

Computing Norm of Embedding

In your paper, you mention that the L1 norm of the task embedding should correlate with the task complexity.

Do you have sample code of how to compute the norm of the embedding? I can't find it in the repo.

I tried the following approach but can't reproduce the results that were reported.

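import numpy as np
from task2vec import Task2Vec
from models import get_model

probe_network = get_model('resnet18', pretrained=True, num_classes=no_classes).cuda()
task2vec = Task2Vec(probe_network, max_samples=5_000, skip_layers=6)
t2v_embed = task2vec.embed(dset)
l1_norm = np.linalg.norm(t2v_embed.hessian, ord=1)  # sum of absolute values for a 1-D array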

Thank you very much!

Missing taxonomical txt files?

Hi, it seems like there are missing txt files containing taxonomical information. Namely taxonomy.txt as referenced in dataset/cub.py and passeriformes.txt as in datasets.py. I don't believe that these are part of the standard CUB release. Can you point me to them or provide them in the repo? Thank you!

Missing Files

Hi, do you still have the taxonomy.txt and passeriformes.txt files? Thank you.
