CREST: Coresets for Stochastic Gradient Descent

This is the official repository for the paper Towards Sustainable Learning: Coresets for Data-efficient Deep Learning (ICML 2023) by Yu Yang, Hao Kang, and Baharan Mirzasoleiman.

Abstract

To improve the efficiency and sustainability of learning deep models, we propose CREST, the first scalable framework with rigorous theoretical guarantees to identify the most valuable examples for training non-convex models, particularly deep networks. To guarantee convergence to a stationary point of a non-convex function, CREST models the non-convex loss as a series of quadratic functions and extracts a coreset for each quadratic sub-region. In addition, to ensure faster convergence of stochastic gradient methods such as (mini-batch) SGD, CREST iteratively extracts multiple mini-batch coresets from larger random subsets of training data, yielding nearly unbiased gradients with small variance. Finally, to further improve scalability and efficiency, CREST identifies examples that have already been learned and excludes them from the coreset selection pipeline. Our extensive experiments on several deep networks trained on vision and NLP datasets, including CIFAR-10, CIFAR-100, TinyImageNet, and SNLI, confirm that CREST speeds up training of deep networks on very large datasets by 1.7x to 2.5x with minimal loss in performance. By analyzing the learning difficulty of the subsets selected by CREST, we show that deep models benefit the most from learning on subsets of increasing difficulty.
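The idea of extracting a weighted mini-batch coreset from a larger random subset can be illustrated with a minimal sketch. It is not the repository's code: it assumes per-example gradient proxies are already available as a NumPy array, draws a random subset, and then greedily picks a small weighted coreset by maximizing a facility-location objective over gradient similarities, in the style of CRAIG. The function name `select_minibatch_coreset` and its parameters are hypothetical.

```python
import numpy as np


def select_minibatch_coreset(grads, subset_size, coreset_size, rng):
    """Hypothetical sketch: CRAIG-style greedy coreset on a random subset.

    grads: (n, d) array of per-example gradient proxies (assumed given).
    Returns (indices into grads, per-coreset-point weights).
    """
    n = grads.shape[0]
    # 1) Draw a larger random subset of the training data.
    subset = rng.choice(n, size=subset_size, replace=False)
    g = grads[subset]
    # Pairwise Euclidean distances between gradient proxies in the subset.
    dists = np.linalg.norm(g[:, None, :] - g[None, :, :], axis=-1)
    # Facility location works on similarities; shift distances so they are >= 0.
    sims = dists.max() - dists
    selected = []
    best = np.zeros(subset_size)  # max similarity of each point to the selected set
    for _ in range(coreset_size):
        # 2) Marginal gain of candidate j: sum_i max(0, sims[i, j] - best[i]).
        gains = np.maximum(sims - best[:, None], 0.0).sum(axis=0)
        gains[selected] = -np.inf  # never pick the same example twice
        j = int(np.argmax(gains))
        selected.append(j)
        best = np.maximum(best, sims[:, j])
    # 3) Weight each coreset point by how many subset points it is closest to.
    assign = np.argmax(sims[:, selected], axis=1)
    weights = np.bincount(assign, minlength=coreset_size)
    return subset[selected], weights


rng = np.random.default_rng(0)
grads = rng.normal(size=(1000, 16))  # synthetic stand-in for gradient proxies
idx, w = select_minibatch_coreset(grads, subset_size=256, coreset_size=32, rng=rng)
print(idx[:5], w.sum())
```

Because each weight counts the subset examples nearest to a coreset point, a weighted loss over the coreset approximates the loss over the whole random subset.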

Installation

This code is tested with Python 3.8.8 and PyTorch 1.9.1 with CUDA 11.5.

To install the required packages, run

pip install -r requirements.txt

Usage

python crest_train.py

--dataset: The dataset to use. (default: cifar10)

  • cifar10: CIFAR-10 dataset
  • cifar100: CIFAR-100 dataset
  • tinyimagenet: TinyImageNet dataset

โš ๏ธ The TinyImageNet dataset is not included in this repository. Please download the dataset from here.

--data_dir: The directory to store the dataset. (default: ./data)

--arch: The model architecture to use. (default: resnet20)

  • resnet20: ResNet-20 model for CIFAR-10
  • resnet18: ResNet-18 model for CIFAR-100
  • resnet50: ResNet-50 model for TinyImageNet

--seed: The random seed to use. (default: 0)

--selection_method: The data selection method to use. (default: random)

--train_frac: The fraction of training steps to use compared to full training. (default: 0.1)
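A hypothetical reconstruction of how the flags above could be declared with Python's `argparse`, using only the defaults and choices listed in this README; the actual `utils/arguments.py` may differ. The `step_budget` helper is also an assumption: it shows one reading of `--train_frac` as a fraction of the optimizer steps of a full run, taking a 200-epoch baseline and batch size 128 as illustrative values.

```python
import argparse
import math


def build_parser():
    # Hypothetical reconstruction of the documented flags; the real
    # utils/arguments.py may define more options or different choices.
    p = argparse.ArgumentParser(description="CREST training (sketch)")
    p.add_argument("--dataset", default="cifar10",
                   choices=["cifar10", "cifar100", "tinyimagenet"])
    p.add_argument("--data_dir", default="./data")
    p.add_argument("--arch", default="resnet20",
                   choices=["resnet20", "resnet18", "resnet50"])
    p.add_argument("--seed", type=int, default=0)
    # Only the default is documented, so no choices are enforced here.
    p.add_argument("--selection_method", default="random")
    p.add_argument("--train_frac", type=float, default=0.1)
    return p


def step_budget(full_epochs, n_train, batch_size, train_frac):
    # Assumed meaning: steps allowed = train_frac x steps of a full run.
    steps_per_epoch = math.ceil(n_train / batch_size)
    return math.ceil(train_frac * full_epochs * steps_per_epoch)


args = build_parser().parse_args(["--dataset", "cifar100", "--arch", "resnet18"])
print(args.dataset, args.arch, args.train_frac)
print(step_budget(full_epochs=200, n_train=50000, batch_size=128,
                  train_frac=args.train_frac))
```

With CIFAR-100's 50,000 training images, 391 steps per epoch, and the assumed 200-epoch baseline, the default `--train_frac 0.1` would allow 7,820 steps.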

Hyperparameter Tuning

To achieve the most cost-efficient training, the following hyperparameters need to be tuned for different datasets and models.

  • --check_thresh_factor: The fraction of the training loss to use as the threshold for coreset selection. (default: 0.1)
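The README does not spell out how `--check_thresh_factor` is applied. One plausible reading, assumed here and not taken from the code, follows the abstract: the loss around the current weights w0 is modeled by a second-order Taylor expansion (with a diagonal Hessian estimate, the kind AdaHessian computes), and a new coreset is selected once the actual loss departs from that quadratic prediction by more than `check_thresh_factor` times the loss at w0. The functions `quadratic_prediction` and `needs_new_coreset` are hypothetical.

```python
import numpy as np


def quadratic_prediction(loss0, grad0, hess_diag0, delta):
    # Second-order Taylor model around w0 with a diagonal Hessian estimate:
    # F(w0) + g . delta + 0.5 * delta^T diag(h) delta
    return loss0 + grad0 @ delta + 0.5 * np.sum(hess_diag0 * delta**2)


def needs_new_coreset(loss_now, loss0, grad0, hess_diag0, delta,
                      check_thresh_factor=0.1):
    # Hypothetical rule: trust the quadratic model while its error stays
    # below check_thresh_factor times the loss at the anchor point w0.
    predicted = quadratic_prediction(loss0, grad0, hess_diag0, delta)
    return abs(loss_now - predicted) > check_thresh_factor * loss0


def toy_loss(w):
    # Toy non-convex-looking loss with exact gradient 2w + 4w^3.
    return np.sum(w**2 + w**4)


w0 = np.ones(3)
grad0 = 2 * w0 + 4 * w0**3
hess0 = 2 + 12 * w0**2
for step in (0.01, 1.0):
    delta = np.full(3, step)
    print(step, needs_new_coreset(toy_loss(w0 + delta), toy_loss(w0),
                                  grad0, hess0, delta))
```

On this toy loss, the small step keeps the quadratic model accurate and no re-selection is triggered, while the large step exceeds the threshold.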

Adding New Datasets

To add a new dataset, add its loading code to datasets/dataset.py. Then add the dataset name to the choices of the --dataset argument in utils/arguments.py.
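A minimal, hypothetical shape for new loading code. It assumes the loader returns a train/test pair of map-style datasets; PyTorch's `DataLoader` accepts any object implementing `__len__` and `__getitem__`, so the sketch needs no PyTorch import. Returning the example index alongside `(x, y)` is an assumption, chosen because coreset selection has to know which examples a batch came from. `InMemoryDataset` and `load_toy` are illustrative names, not functions in `datasets/dataset.py`.

```python
import random


class InMemoryDataset:
    """Hypothetical map-style dataset (usable with torch's DataLoader)."""

    def __init__(self, samples, labels, transform=None):
        if len(samples) != len(labels):
            raise ValueError("samples and labels must have the same length")
        self.samples = samples
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x = self.samples[idx]
        if self.transform is not None:
            x = self.transform(x)
        # Assumed convention: also return the index so a subset trainer
        # can track per-example statistics.
        return x, self.labels[idx], idx


def load_toy(data_dir):
    # Stand-in for real loading code; data_dir is unused in this toy example.
    rng = random.Random(0)
    xs = [[rng.random() for _ in range(4)] for _ in range(100)]
    ys = [rng.randrange(2) for _ in range(100)]
    return InMemoryDataset(xs[:80], ys[:80]), InMemoryDataset(xs[80:], ys[80:])


train_set, test_set = load_toy("./data")
print(len(train_set), len(test_set))
```

The new name (here `toy`) would then be added to the `--dataset` choices as described above.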

Adding New Models

To add a new model, create a new file in the models/ folder containing the model class. For example, models/resnet.py contains the class ResNet20, which implements the ResNet-20 model. Then add the model name to the choices of the --arch argument in utils/arguments.py.
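Keeping the `--arch` choices in sync with the model files by hand is easy to forget. One hypothetical alternative, not how the repository is organized, is a small registry that each model file fills in with a decorator, from which the parser derives its choices. The classes below are stand-ins for `torch.nn.Module` subclasses so the sketch runs without PyTorch; `MODEL_REGISTRY`, `register_model`, and `MyNet` are illustrative names.

```python
import argparse

MODEL_REGISTRY = {}


def register_model(name):
    # Decorator that records a model class under the name used by --arch.
    def wrap(cls):
        if name in MODEL_REGISTRY:
            raise ValueError(f"duplicate model name: {name}")
        MODEL_REGISTRY[name] = cls
        return cls
    return wrap


@register_model("resnet20")
class ResNet20:
    # Stand-in: in the real repository this would be a torch.nn.Module.
    def __init__(self, num_classes=10):
        self.num_classes = num_classes


@register_model("mynet")
class MyNet:
    # A hypothetical new model added by the user.
    def __init__(self, num_classes=10):
        self.num_classes = num_classes


parser = argparse.ArgumentParser()
# Choices come from the registry, so registering a model exposes it on the CLI.
parser.add_argument("--arch", default="resnet20", choices=sorted(MODEL_REGISTRY))
args = parser.parse_args(["--arch", "mynet"])
model = MODEL_REGISTRY[args.arch](num_classes=100)
print(type(model).__name__, model.num_classes)  # MyNet 100
```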

Adding New Data Selection Methods

To add a new data selection method, create a new file in the trainers/ folder containing a subclass of the SubsetTrainer class defined in trainers/subset_trainer.py. For example, trainers/random_trainer.py contains the class RandomTrainer, the trainer for random selection. Then add the method name to the choices of the --selection_method argument in utils/arguments.py.
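The real `SubsetTrainer` interface is not documented here, so the sketch below uses a hypothetical stand-in base class whose only hook is `select_subset(epoch)`, returning the indices to train on. `RandomTrainer` mirrors the idea of random selection; `HighLossTrainer` is an invented example of a new method that keeps the examples with the largest recorded loss. The actual method names in `trainers/subset_trainer.py` should be checked before adapting it.

```python
import random


class SubsetTrainer:
    """Hypothetical stand-in for the base class in trainers/subset_trainer.py."""

    def __init__(self, n_train, subset_size, seed=0):
        self.n_train = n_train
        self.subset_size = subset_size
        self.rng = random.Random(seed)

    def select_subset(self, epoch):
        raise NotImplementedError

    def train_epoch(self, epoch):
        indices = self.select_subset(epoch)
        # A real trainer would run SGD on the selected examples here.
        return indices


class RandomTrainer(SubsetTrainer):
    # Uniform random subset, mirroring the idea of random selection.
    def select_subset(self, epoch):
        return self.rng.sample(range(self.n_train), self.subset_size)


class HighLossTrainer(SubsetTrainer):
    # Hypothetical new method: keep the examples with the largest recorded loss.
    def __init__(self, n_train, subset_size, losses, seed=0):
        super().__init__(n_train, subset_size, seed)
        self.losses = losses

    def select_subset(self, epoch):
        order = sorted(range(self.n_train), key=lambda i: self.losses[i],
                       reverse=True)
        return order[: self.subset_size]


losses = [0.1, 2.0, 0.5, 3.0, 0.2]
print(HighLossTrainer(5, 2, losses).train_epoch(0))  # [3, 1]
```

The selection rule lives entirely in the subclass, so the training loop in the base class stays unchanged when a new method is added.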

Citation

If you find this repository useful, please cite our paper:

@inproceedings{yang2023towards,
  title={Towards Sustainable Learning: Coresets for Data-efficient Deep Learning},
  author={Yang, Yu and Kang, Hao and Mirzasoleiman, Baharan},
  booktitle={Proceedings of the 40th International Conference on Machine Learning},
  year={2023}
}

Acknowledgement

The code is based on Craig and AdaHessian.

Disclaimer and Contact

The current version of the code has been refactored for better readability and extensibility, but it is still in the process of testing and optimization.

If you have any questions or suggestions, please contact Yu Yang ([email protected]).
