clem.pytorch

Continual Learning methods using Episodic Memory

This project provides simple PyTorch-based APIs for continual machine learning methods that use episodic memory. Currently, the following continual learning algorithms are supported:

  • Gradient Episodic Memory (GEM)
  • Averaged GEM (A-GEM)
  • Experience Replay (ER)

Prerequisites

  • Python 3.6
  • PyTorch
  • quadprog

Usage

  • Each supported continual learning method is encapsulated in its own class, and each class supports the following APIs:

    • <learner>.prepare() - sets the optimizer; needs to be called prior to training on a task
    • <learner>.run() - optimizes on a single batch; this is where the continual learning algorithm actually runs
    • <learner>.remember() - adds more data to a FIFO memory buffer; the input data must be a PyTorch Dataset
  • Sample:

    import torch
    from tqdm import tqdm

    from learners import GEM, AGEM, ER

    # model, criterion, device, train_loader, train_data, learning_rate and
    # num_epochs are assumed to be defined as in a standard PyTorch training script
    memory_capacity = 10240
    task_memory_size = 2048
    memory_sample_size = 64

    # instantiate learner
    learner = AGEM(model, criterion, device=device,
                   memory_capacity=memory_capacity, memory_sample_sz=memory_sample_size)

    # assign optimizer to learner
    learner.prepare(optimizer=torch.optim.Adam, lr=learning_rate)

    model.train()
    for ep in tqdm(range(num_epochs)):
        for inputs, labels in train_loader:
            # optimize on a single batch
            learner.run(inputs, labels)

    # save data to episodic memory
    learner.remember(train_data, min_save_sz=task_memory_size)
    
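  • Training on a sequence of tasks: the prepare()/run()/remember() cycle above is simply repeated per task. The sketch below is a minimal, hypothetical example; task_loaders and task_datasets are assumed to be lists holding one DataLoader and one Dataset per task, and are not part of the library.

    # hypothetical sketch: sequential training over multiple tasks
    for task_loader, task_data in zip(task_loaders, task_datasets):
        # (re)set the optimizer before starting a new task
        learner.prepare(optimizer=torch.optim.Adam, lr=learning_rate)

        model.train()
        for ep in range(num_epochs):
            for inputs, labels in task_loader:
                # optimize on a single batch (the continual learning algorithm runs here)
                learner.run(inputs, labels)

        # push a subset of this task's data into the FIFO memory buffer
        learner.remember(task_data, min_save_sz=task_memory_size)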

Experimentation

To test the APIs and to see how the implemented continual learning methods help solve the catastrophic forgetting problem, we test each method against a dataset susceptible to this problem. In particular, we use the MNIST dataset and split the training set into 5 sets of equal size, each with a different class distribution (discussed further below). We treat each split of the training set as a single learning task.

The target for each learning method is to progressively achieve higher accuracy on MNIST as it trains successively on each of the 5 tasks. We use the accuracy after training on the final task as a measure of a method's ability to learn. For comparability, accuracy values are reported on a common test set shared across all methods. We also measure each algorithm's performance in terms of execution time.

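The reported accuracy can be computed with a standard PyTorch evaluation loop over this common test set. The sketch below is illustrative and assumes a hypothetical test_loader built from the MNIST test set; it is not code from this repository.

    import torch

    def evaluate(model, test_loader, device):
        # return the fraction of correctly classified samples on the common test set
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                preds = model(inputs).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        return correct / total
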
Apart from the accuracy of the continual learning algorithms, we also measure the accuracy of "offline"/non-continual training, which serves as the "gold standard" for learning. We also measure the final accuracy of a continual learning setup in which no special algorithm is used; we refer to this as "Naive Continual" learning.

Throughout the experiment, a neural network with a single hidden layer is used, with hand-picked hyperparameter settings. The whole experiment can be run in test.ipynb.

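The exact architecture is not described here beyond the single hidden layer; a comparable model could look like the following sketch, where the hidden size of 100 is an assumption rather than the repository's hand-picked value.

    import torch.nn as nn

    # assumed architecture: a one-hidden-layer MLP over flattened 28x28 MNIST images
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 100),  # hidden size 100 is an illustrative guess
        nn.ReLU(),
        nn.Linear(100, 10),
    )
    criterion = nn.CrossEntropyLoss()
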
Note that this was not meant to be an exhaustive evaluation of continual learning methods, so the results should be taken with a grain of salt. :)

Offline/Non-continual Baseline: 95.80%

For a continual learning setup, we simulate two scenarios:

Case 1: Skewed Splits

In this test, we split the data such that each split (task) is dominated by 2 classes and contains only a few samples of the other 8 classes. In particular, each task gets 90% of all the training samples of its 2 dominant classes, but only 2.5% of the samples of each remaining class. This simulates a scenario where there is a defined set of classes, but the influx of data is uneven among them, resulting in an unbalanced dataset for each learning task.

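One way to construct such splits is to give each task 90% of the samples of its 2 dominant classes and 2.5% of the samples of every other class. The sketch below is an illustrative reconstruction assuming a torchvision-style MNIST dataset with a .targets attribute; it is not the repository's actual splitting code.

    import numpy as np
    from torch.utils.data import Subset

    def skewed_splits(dataset, num_tasks=5, major_frac=0.90):
        # group sample indices by class label (assumes a torchvision-style .targets)
        targets = np.asarray(dataset.targets)
        task_idx = [[] for _ in range(num_tasks)]
        for c in range(10):
            class_idx = np.random.permutation(np.where(targets == c)[0])
            major_task = c // 2  # each task "owns" 2 consecutive classes
            n_major = int(len(class_idx) * major_frac)
            task_idx[major_task].extend(class_idx[:n_major])
            # spread the remaining 10% evenly (2.5% each) over the other tasks
            rest = np.array_split(class_idx[n_major:], num_tasks - 1)
            other_tasks = [t for t in range(num_tasks) if t != major_task]
            for t, chunk in zip(other_tasks, rest):
                task_idx[t].extend(chunk)
        return [Subset(dataset, list(map(int, idx))) for idx in task_idx]
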
Method            Accuracy   Duration (s)
Naive Continual   84.63%     8.89
GEM               95.42%     42.27
A-GEM             89.26%     15.64
ER                93.88%     14.51

Case 2: Class Splits

In contrast to the previous test, here each task uses 100% of the training samples of 2 classes, which also means that each task consists of only 2 classes. This simulates an incremental class-learning problem, where new classes are introduced by new tasks.

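The class splits are simpler: all samples of classes {0, 1} form the first task, {2, 3} the second, and so on. Below is a minimal sketch under the same torchvision-style assumption; again, this is illustrative rather than the repository's code.

    import numpy as np
    from torch.utils.data import Subset

    def class_splits(dataset, classes_per_task=2):
        # all samples of classes {0,1} -> task 1, {2,3} -> task 2, ...
        targets = np.asarray(dataset.targets)
        tasks = []
        for t in range(10 // classes_per_task):
            task_classes = list(range(t * classes_per_task, (t + 1) * classes_per_task))
            idx = np.where(np.isin(targets, task_classes))[0]
            tasks.append(Subset(dataset, idx.tolist()))
        return tasks
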
Method            Accuracy   Duration (s)
Naive Continual   19.38%     9.46
GEM               93.85%     42.50
A-GEM             55.36%     15.58
ER                86.96%     13.99

License

This project is licensed under the MIT License - see the LICENSE file for details
