few_shot_meta_learning's Introduction

Few-shot meta-learning

This repository contains PyTorch implementations of several meta-learning algorithms for the few-shot learning problem. The algorithms discussed below include MAML, ABML, VAMPIRE, PLATIPUS, and Prototypical Networks.

Python package requirements

  • PyTorch 1.8.1 or above (version 1.8 introduced "lazy" modules such as torch.nn.LazyLinear, which infer their input size on the first forward pass, similar to the Dense layer in TensorFlow)
  • higher

New updates with functional form of torch module

What does "functional" mean? It is similar to the module torch.nn.functional, where the parameters are passed explicitly instead of being handled implicitly, as they are in a module such as torch.nn.Sequential(). For example:

# conventional with implicitly-handled parameter
y = net(x) # parameters are handled by PyTorch implicitly

# functional form
y = functional_net(x, params=theta) # theta is the parameter

With plain PyTorch, one needs to manually implement the "functional" form of every component of the model of interest via torch.nn.functional. This becomes inconvenient whenever the network architecture changes.
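
As a rough, framework-free illustration of the two calling styles (plain Python rather than PyTorch; ConventionalLinear and functional_linear are hypothetical names that do not exist in this repository), the same linear map can be written both ways:

```python
class ConventionalLinear:
    """Stores its own parameters, much like a torch.nn.Module would."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        # parameters are read implicitly from the object
        return self.weight * x + self.bias


def functional_linear(x, params):
    """Same computation, but the parameters are an explicit argument."""
    weight, bias = params
    return weight * x + bias


net = ConventionalLinear(weight=2.0, bias=1.0)
theta = (2.0, 1.0)

y_conventional = net(3.0)
y_functional = functional_linear(3.0, params=theta)

# swapping in adapted parameters needs no change to the function itself
y_adapted = functional_linear(3.0, params=(2.5, 0.5))
```

The functional style is what makes the inner loop of meta-learning convenient: adapted parameters can be passed in without mutating the original network.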

Fortunately, Facebook Research has developed higher, a library that can convert any "conventional" neural network into its "functional" form so that parameters are handled explicitly. For example:

import higher
import torchvision

# define a network (randomly initialized, no pretrained weights)
resnet18 = torchvision.models.resnet18()

# get its parameters
params = list(resnet18.parameters())

# convert the network to its functional form
f_resnet18 = higher.patch.make_functional(resnet18)

# forward with the functional form, passing the parameters explicitly
# (x1 and x2 are input batches defined elsewhere)
y1 = f_resnet18.forward(x=x1, params=params)

# update the parameters (update_parameter is a placeholder for your own update rule)
new_params = update_parameter(params)

# forward on different data with the new parameters
y2 = f_resnet18.forward(x=x2, params=new_params)

Hence, we only need to load or specify the "conventional" model written in PyTorch without manually re-implementing its "functional" form. A few common models are implemented in CommonModels.py.

Although higher provides convenient APIs to track gradients, it does not allow us to use the "first-order" approximation, which results in higher memory usage and longer training time. I have created a work-around to enable the "first-order" approximation, controlled by setting --first-order=True when running the code.
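
To make the difference concrete, here is a hand-derived sketch on a one-dimensional quadratic problem. It is plain Python; the losses, learning rate, and function names are illustrative assumptions and this is not the repository's work-around. With one inner step, the exact meta-gradient carries an extra factor from differentiating through the update, which the first-order approximation drops:

```python
def inner_step(theta, support_target, inner_lr):
    """One gradient step on the support loss 0.5 * (w - support_target)**2."""
    grad = theta - support_target
    return theta - inner_lr * grad


def meta_gradient(theta, support_target, query_target, inner_lr, first_order):
    """Gradient of the query loss 0.5 * (w' - query_target)**2 w.r.t. theta."""
    w_adapted = inner_step(theta, support_target, inner_lr)
    grad_wrt_adapted = w_adapted - query_target
    if first_order:
        # treat d w' / d theta as 1: no second-order term, no graph to keep
        return grad_wrt_adapted
    # exact chain rule: d w' / d theta = 1 - inner_lr for this quadratic loss
    return grad_wrt_adapted * (1.0 - inner_lr)


exact = meta_gradient(0.0, 1.0, 2.0, inner_lr=0.1, first_order=False)
approx = meta_gradient(0.0, 1.0, 2.0, inner_lr=0.1, first_order=True)
```

In an autograd framework, the first-order variant corresponds to not building the graph of the inner update, which is where the memory saving comes from.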

The majority of the implementation is based on the abstract base class in MLBaseClass.py, and each algorithm is written in a separate class. The main program is specified in main.py. PLATIPUS is slightly different since the algorithm mixes the training and validation subsets, and is hence implemented in a separate file.

Operation mechanism explanation

The implementation is mainly in the abstract base class MLBaseClass.py, with some auxiliary classes and functions in _utils.py. Its operation can be divided into 3 steps:

Step 1: initialize hyper-net and base-net

Recall the nature of the meta-learning as:

θ → w → y ← x,

where θ denotes the parameter of the hyper-net, w is the base-model parameter, and (x, y) is the data.

The implementation is designed to follow this generative process, where the hyper-net will generate the base-net. It can be summarized in the following pseudo-code:

# initialization
base_net = ResNet18() # base-net

# convert the conventional module to its functional form
f_base_net = torch_to_functional_module(module=base_net)

# make hyper-net from the base-net
hyper_net = hyper_net_cls(base_net=base_net)

# the hyper-net generates the parameter of the base-net
base_net_params = hyper_net.forward()

# make prediction
y = f_base_net(x, params=base_net_params)

The choice of hyper_net_cls depends on the algorithm:

  • MAML: the hyper-net is the initialization of the base-net. The generative process is therefore the identity operator, and hyper_net_cls is defined as the class IdentityNet in _utils.py.
  • ABML and VAMPIRE: the base-net parameter is a sample drawn from a diagonal Gaussian distribution parameterized by the meta-parameter. The hyper-net is therefore designed to simulate this sampling process, and hyper_net_cls is the class NormalVariationalNet in _utils.py.
  • Prototypical Networks differ from the above algorithms due to their metric-learning nature. In the implementation, only one network is used as hyper_net, while base_net is set to None.
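
The two kinds of hyper-net can be pictured with a small plain-Python sketch. The class names IdentityHyperNet and GaussianHyperNet are hypothetical stand-ins for illustration and are not the IdentityNet and NormalVariationalNet classes in _utils.py; the log-standard-deviation parameterization is also an assumption:

```python
import math
import random


class IdentityHyperNet:
    """MAML-style: the meta-parameter is itself the base-net parameter."""

    def __init__(self, init_params):
        self.params = list(init_params)

    def forward(self):
        return list(self.params)


class GaussianHyperNet:
    """ABML/VAMPIRE-style: sample base-net parameters from a diagonal Gaussian."""

    def __init__(self, mean, log_std, rng=None):
        self.mean = list(mean)
        self.log_std = list(log_std)
        self.rng = rng or random.Random()

    def forward(self):
        # reparameterization: w = mean + std * eps, with eps ~ N(0, 1)
        return [
            m + math.exp(s) * self.rng.gauss(0.0, 1.0)
            for m, s in zip(self.mean, self.log_std)
        ]


maml_params = IdentityHyperNet([0.5, -1.0]).forward()
sampled_params = GaussianHyperNet(
    mean=[0.5, -1.0], log_std=[-2.0, -2.0], rng=random.Random(0)
).forward()
```

Both classes expose the same forward() interface, which is what lets a shared base class treat the algorithms uniformly.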

Why is the implementation so complicated? It allows us to share the common procedures of many meta-learning algorithms via the abstract base class MLBaseClass. If anything is not clear to you, please open an issue or send me an email. I am happy to discuss how to improve the readability of the code further.

Step 2: task adaptation (often known as inner-loop)

There are 2 sub-functions corresponding to MAML-like algorithms and protonet.

adapt_to_episode - applicable for MAML-like algorithms

The idea is simple:

  1. Generate the parameter(s) of the base-net from the hyper-net
  2. Use the generated base-net parameter(s) to calculate loss on training (also known as support) data
  3. Minimize the loss w.r.t. the parameter of the hyper-net
  4. Return the (task-specific) hyper-net (assigned to f_hyper_net) for that particular task
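
The four steps above can be sketched for a MAML-like algorithm on a toy one-parameter model y = w * x with squared-error loss. This is plain Python with a hand-written gradient; the function name mirrors adapt_to_episode, but the body, the identity hyper-net, and the hyper-parameters are illustrative assumptions rather than the repository's code:

```python
def support_loss_and_grad(w, xs, ys):
    """Mean squared error of y = w * x and its derivative w.r.t. w."""
    n = len(xs)
    loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / n
    grad = sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n
    return loss, grad


def adapt_to_episode(theta, xs, ys, inner_lr=0.05, num_inner_updates=5):
    """Return a task-specific copy of the hyper-net parameter."""
    q = theta  # start from the meta-parameter
    for _ in range(num_inner_updates):
        w = q  # step 1: identity hyper-net generates the base-net parameter
        _, grad = support_loss_and_grad(w, xs, ys)  # step 2: support loss
        q = q - inner_lr * grad  # step 3: descend w.r.t. the hyper-net parameter
    return q  # step 4: task-specific hyper-net


xs_train, ys_train = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]  # true slope is 2
theta = 0.0
q_task = adapt_to_episode(theta, xs_train, ys_train)

loss_before, _ = support_loss_and_grad(theta, xs_train, ys_train)
loss_after, _ = support_loss_and_grad(q_task, xs_train, ys_train)
```

With an identity hyper-net, the gradient with respect to the hyper-net parameter equals the gradient with respect to the base-net parameter, which is why a single derivative suffices here.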

adapt to task by calculating prototypes - applicable for Prototypical Networks

Calculate and return the prototypes in the embedding space
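
As a minimal sketch of this step (plain Python; the two-dimensional embeddings are made-up numbers and the function names are hypothetical), each prototype is the mean of the support embeddings of one class, and a query is assigned to the nearest prototype under squared Euclidean distance:

```python
from collections import defaultdict


def compute_prototypes(embeddings, labels):
    """Average the embedding vectors of each class."""
    grouped = defaultdict(list)
    for vector, label in zip(embeddings, labels):
        grouped[label].append(vector)
    return {
        label: [sum(dim) / len(vectors) for dim in zip(*vectors)]
        for label, vectors in grouped.items()
    }


def predict(query, prototypes):
    """Return the label of the closest prototype (squared Euclidean distance)."""
    def sq_dist(a, b):
        return sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return min(prototypes, key=lambda label: sq_dist(query, prototypes[label]))


support_embeddings = [[0.0, 0.0], [0.2, 0.0], [1.0, 1.0], [1.0, 0.8]]
support_labels = [0, 0, 1, 1]

prototypes = compute_prototypes(support_embeddings, support_labels)
predicted = predict([0.9, 1.0], prototypes)
```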

Step 3: evaluate on validation subset

The task-specific hyper-net, or f_hyper_net in the case of MAML-like algorithms, or the prototypes in the case of prototypical networks, are used to predict the labels of the data in the validation subset.

  • In training, the predicted labels are used to calculate the loss, and the parameter of the hyper-net is updated to minimize that loss.
  • In testing, the predicted labels are used to compute the prediction accuracy.

Note that ABML is slightly different since it also includes the loss made by the task-specific hyper-net on the training subset. In addition, it places a prior on the parameter of the hyper-net. These are implemented in the methods loss_extra() and loss_prior(), respectively.

Data source

Regression

The DataLoader in PyTorch is modified to generate data for multimodal tasks, where each regression task is generated from either a sinusoidal or a linear function. To run with regression, please specify --datasource SineLine as one of the input arguments.
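
A rough sketch of what such a multimodal task generator might look like is given below, in plain Python. The amplitude, phase, slope, and intercept ranges, the input interval, and the function names are assumptions chosen for illustration and may differ from the repository's SineLine data source:

```python
import math
import random


def sample_task(rng):
    """Pick a sinusoidal or a linear target function at random."""
    if rng.random() < 0.5:
        amplitude = rng.uniform(0.1, 5.0)
        phase = rng.uniform(0.0, math.pi)
        return "sine", lambda x: amplitude * math.sin(x + phase)
    slope = rng.uniform(-3.0, 3.0)
    intercept = rng.uniform(-3.0, 3.0)
    return "line", lambda x: slope * x + intercept


def sample_episode(rng, k_shot=5, v_shot=10):
    """Draw k_shot training points and v_shot validation points from one task."""
    kind, f = sample_task(rng)
    xs = [rng.uniform(-5.0, 5.0) for _ in range(k_shot + v_shot)]
    points = [(x, f(x)) for x in xs]
    return kind, points[:k_shot], points[k_shot:]


rng = random.Random(0)
kind, train_points, val_points = sample_episode(rng)
```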

A Jupyter Notebook (visualize_regression.ipynb) to visualize regression results saved in the meta_learning folder is also added.

Classification

Omniglot and mini-ImageNet are the two datasets considered. They are organized in the folder layout expected by torchvision.datasets.ImageFolder:

Dataset
|__alphabet1_character1 (or class1)
|__alphabet2_character2 (or class2)
...
|__alphabetn_characterm (or classz)
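
To check that a folder follows this layout, the class discovery performed by torchvision.datasets.ImageFolder (sorted sub-folder names mapped to consecutive indices) can be mimicked with the standard library. The folder names below are placeholders, and find_classes is a local helper written for this sketch:

```python
import os
import tempfile


def find_classes(root):
    """Map each immediate sub-folder of root to a class index, in sorted order."""
    classes = sorted(
        entry.name for entry in os.scandir(root) if entry.is_dir()
    )
    return {name: index for index, name in enumerate(classes)}


with tempfile.TemporaryDirectory() as root:
    for name in ["alphabet2_character2", "alphabet1_character1"]:
        os.makedirs(os.path.join(root, name))
    class_to_idx = find_classes(root)
```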

You can modify the transformations in main.py to suit your needs regarding image size or image normalization.

The implementation relies on torch.utils.data.DataLoader with a customized sampler in EpisodeSampler.py to generate data for each task. The implementation also supports loading multiple datasets by appending --datasource dataset_name --datasource another_dataset_name to the input arguments.
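
The idea of episodic sampling can be sketched as follows. This is a simplified stand-in written in plain Python, not the repository's EpisodeSampler.py; the function name, arguments, and toy dataset are assumptions:

```python
import random


def sample_episode_indices(labels, n_way, k_shot, v_shot, rng):
    """Pick n_way classes, then k_shot + v_shot example indices per class."""
    by_class = {}
    for index, label in enumerate(labels):
        by_class.setdefault(label, []).append(index)

    chosen_classes = rng.sample(sorted(by_class), n_way)
    train_idx, val_idx = [], []
    for cls in chosen_classes:
        picked = rng.sample(by_class[cls], k_shot + v_shot)
        train_idx.extend(picked[:k_shot])
        val_idx.extend(picked[k_shot:])
    return chosen_classes, train_idx, val_idx


# toy dataset: 4 classes with 6 examples each
labels = [cls for cls in range(4) for _ in range(6)]
classes, train_idx, val_idx = sample_episode_indices(
    labels, n_way=2, k_shot=1, v_shot=3, rng=random.Random(0)
)
```

Index lists like these could be yielded by a sampler object and passed to a DataLoader through its batch_sampler argument, so that each batch corresponds to one task.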

If the original structure of Omniglot (train -> alphabets -> characters) is desired, you might need to append the list of all alphabet names to config['datasource'].
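
One way a repeated --datasource flag can be collected into a list is argparse's append action. Whether main.py is written exactly this way is an assumption; the snippet only shows the mechanism, with placeholder dataset names:

```python
import argparse

parser = argparse.ArgumentParser()
# each occurrence of --datasource appends one more name to the list
parser.add_argument("--datasource", action="append", default=None)

args = parser.parse_args(
    ["--datasource", "dataset_name", "--datasource", "another_dataset_name"]
)
config = vars(args)

# further names (for example alphabet folders) can be appended afterwards
config["datasource"].append("yet_another_name")
```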

Run

To run, copy and paste the command at the beginning of each algorithm script and change the configurable parameters (if needed).

To test, specify which saved model to load via the variable resume_epoch, and replace --train with --test at the end of the command found at the top of main.py.

Tensorboard

Tensorboard is also integrated into the implementation, so you can open it and monitor the training in your favourite browser:

tensorboard --logdir=<your destination folder>

Then open the browser and see the training progress at:

http://localhost:6006/

Final note

If you only need to run MAML and feel that my implementation is complicated, torch-meta is a repository worth looking at. The difference between torch-meta and this repository is that this one extends the implementation to other algorithms, such as VAMPIRE and ABML.

If you find this repository useful, please give it a ⭐ to motivate my work.

In addition, please consider giving a ⭐ to the higher repository developed by Facebook. Without it, we would still suffer from the arduous re-implementation of each model's "functional" form.

few_shot_meta_learning's People

Contributors

cnguyen10

few_shot_meta_learning's Issues

Consultation about the code

Hello, I want to apply BMAML and PLATIPUS to multi-label sequence (1D) classification, and the code is really helpful! The difficulty I face now is that there are many files, and I don't know in what order I should modify the code. Could you please give me some advice and help?

First order approximate typo?

if config['first_order']:

The inputs are always set to q_params, regardless of whether first_order is true or false. Is this a typo?

            if config['first_order']:
                all_grads = torch.autograd.grad(
                    outputs=loss,
                    inputs=q_params,
                    retain_graph=config['train_flag']
                )
            else:
                all_grads = torch.autograd.grad(
                    outputs=loss,
                    inputs=q_params,
                    create_graph=config['train_flag']
                )

getting NaN's in ABML at about epoch 14

Thanks for publishing these implementations. I was running ABML on Omniglot for (20/5, 20/1, 5/5, and 5/1) n-way k-shot learning problems. In all four of the above experiments I start getting NaNs for ABML around epoch 12-14. It's weird that they all fail at the same spot. I have looked through the code carefully and I cannot see anything that might directly cause this, but you may have a better idea since you implemented it...

Any ideas where this is coming from? Here are the flags I was using to train. I may have added some flags related to dataloading, but nothing that interferes with the core.

python main.py \
    --datasource=$DATASET \
    --ds-folder $ROOT \
    --run $RUN \
    --ml-algorithm=abml \
    --num-models=2 \
    --minibatch 16 \
    --no-batchnorm \
    --n-way=20 \
    --k-shot=5 \
    --v-shot=$VSHOT \
    --num-epochs=40 \
    --num-episodes-per-epoch 10000 \
    --resume-epoch=0 \
    --train

NaN loss when training with sine

Hi.
Thank you so much for sharing the code.
I cloned your repo and tried to run abml.py with the sine curve. I get a NaN loss and the code exits.
Please let me know if I need to change any hyperparameters for this task.
Thanks.

Loss function for implementation of BMAML

Hi. Thank you for uploading the code for all these algorithms.

For the implementation of Bayesian MAML (classification problem), you are using a simple cross-entropy loss. But Bayesian MAML uses the chaser loss. Specifically, for each task, I think we have to compute the chaser and leader using SVGD and then update the global parameters using the average of the difference over multiple particles (multiple models). This might be much trickier.

Did I miss something in this implementation? Correct me if I am wrong or I missed something.

thanks,
Deep

Regression code

Hi, thank you for the wonderful code!

Are there any plans to open the code of regression models?

Thanks!
Jihoon

Platipus loss function potentially doesn't match paper

Hi,

thank you for the great implementation of meta learning algorithms.

We are trying to evaluate the PLATIPUS algorithm and we noted that the paper requires a Negative Log Likelihood loss (Page 3, 3 Preliminaries), yet your adaptation step by default is using an MSE loss.

We were wondering if this is an adaptation of the original paper on your part or if we're missing a crucial step where the NLL is calculated from the MSE. Later, the loss we suspect to be an MSE loss is logged to Tensorboard as an NLL again (Platipus.py, 216). This seems especially important since in Platipus.py, 184 the gradient is calculated on this loss function. To our understanding, this gradient might differ significantly from the gradient intended in the paper since it is calculated on a different loss function.

Is this just a trick in the implementation to use the MSE instead of the NLL or are we missing something in the implementation?

Thanks a lot in advance,
Leon

Some questions about this code.

  1. When calculating the calibration, why does it add some noise to the target in regression?
    Like this:
    outputs = outputs + (self.noise_std**2)*output_noises
    Should it instead be
    outputs = outputs + self.noise_std*output_noises

  2. As we all know, MAML can only calculate one value for an input, so how can it produce a reliability diagram like Fig. 2(c) in your paper?

Thank you very much for your kind consideration and I am looking forward to your early reply.

Loss is NaN in PLATIPUS

Hi,

when trying to run Platipus.py with the provided defaults:
python3 main.py --datasource SineLine --ml-algorithm platipus --num-models 4 --first-order --network-architecture FcNet --no-batchnorm --num-ways 1 --k-shot 5 --inner-lr 0.001 --meta-lr 0.001 --num-epochs 100 --resume-epoch 0 --train

the following error occurs:
Platipus.py, line 195: ValueError: Loss is NaN.

Do you have any idea what might cause this? Doesn't seem to be a config issue since both defaults and altered params don't get any other results.

Thanks!
Leon

Test in Platipus model

Hi Cuong,
Thanks for your great work.
I want to ask about the test function in Platipus.py: what is the meaning of eps_generator in the test function?

Question about the initialization of theta0 in abml

If I understand it right, the distribution of theta0 is not a vanilla Normal distribution as written in the application details of the original paper; instead, it is a normal distribution multiplied by a gamma distribution. However, I am confused about how that is represented in your code, as it seems that you just initialize the mean and logsigma of theta0. How did you implement this part? I am struggling to implement a similar algorithm to ABML and really hope you can give some help. Thanks :)

Models not training

Hi,

I like your repository and the code you have implemented. I am facing a couple of issues:

The dataset: I cannot find the dataset in the format you have asked for. I have tried using the dataloader from torchmeta with your code, but the issue is that most of your code, when run, does not go beyond 20.00 accuracy. Do you have any advice on what I could do?

Question about the implementation of VAMPIRE

Hi Cuong,

I really appreciate your work for the VAMPIRE algorithm. There are some questions about the implementation in Vampire2.py.

  1. Why is the KL loss in the inner loop implemented as the KL divergence between q(\theta) and the standard Gaussian, instead of KL(q||p)? I cannot find the corresponding term in the original VAMPIRE paper.
  2. Why does the global update (i.e., validation() function in the code) also need a KL loss?

Thank you.

Potential Problem of the loss function in ABML

Hi Cuong,

I really appreciate your work, especially since this is the only implementation of ABML I could find.

However,
when you calculate the loss for updating the meta-parameters here, it seems that you left out the term shown in the attached image.

I hope you could check this out. Looking forward to your reply.
