
jaxpruner's Introduction

JaxPruner: a concise library for sparsity research

Jaxpruner logo

Paper: arxiv.org/abs/2304.14082

Introduction

JaxPruner is an open-source, JAX-based pruning and sparse training library for machine learning research. It aims to accelerate research on sparse neural networks by providing concise implementations of popular pruning and sparse training algorithms with minimal memory and latency overhead. Algorithms implemented in JaxPruner use a common API and work seamlessly with the popular optimization library Optax, which in turn enables easy integration with existing JAX-based libraries. We demonstrate this ease of integration by providing examples in four different codebases: scenic, t5x, dopamine and fedjax.

We believe a sparsity library in JAX has the potential to accelerate sparsity research, because:

  • The functional nature of JAX makes it easy to modify parameters and masks.
  • JAX is easy to debug.
  • JAX libraries and their use in research are growing.
  • For further motivation, read why DeepMind uses JAX here.

There are exciting developments for accelerating sparsity in neural networks (N:M sparsity, CPU acceleration, activation sparsity), and various libraries aim to enable such acceleration. JaxPruner focuses mainly on accelerating algorithms research for sparsity: we mock sparsity by using binary masks and rely on dense operations to simulate it. In the longer run, we also plan to provide integration with jax.experimental.sparse, aiming to reduce the memory footprint of our models.
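To make the "mock sparsity" idea concrete, here is a minimal, hypothetical sketch (not JaxPruner's internal code): parameters stay dense, and pruned weights are zeroed out by an elementwise binary mask of the same shape.

import jax
import jax.numpy as jnp

# Illustrative sketch only: simulate sparsity with dense tensors and 0/1 masks.
def apply_masks(params, masks):
  # Pruned entries become exact zeros but are still stored (and computed) densely.
  return jax.tree_util.tree_map(lambda p, m: p * m, params, masks)

params = {'w': jnp.arange(6, dtype=jnp.float32).reshape(2, 3)}
masks = {'w': jnp.array([[1., 0., 1.], [0., 1., 0.]])}
sparse_params = apply_masks(params, masks)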

JaxPruner has 3 tenets:

  • Easy Integration: requires minimal changes to use.
  • Research First: provides strong baselines and is easy to modify.
  • Minimal Overhead: runs as fast as the (dense) baseline.

Easy Integration

Research in machine learning is fast-paced, and the huge variety of applications results in a large number of ever-changing codebases. At the same time, the adoption of new research ideas correlates strongly with their ease of use. JaxPruner is therefore designed to be integrated into existing codebases with minimal changes. It achieves this by building on the popular Optax optimization library: the state variables needed for pruning and sparse training algorithms (e.g. masks, counters) are stored together with the optimization state, which makes parallelization and checkpointing easy.

tx, params = _existing_code()
pruner = jaxpruner.MagnitudePruning(**config) # Line 1: Create pruner.
tx = pruner.wrap_optax(tx) # Line 2: Wrap optimizer.
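For a slightly fuller picture, here is a hedged sketch of how the wrapped optimizer slots into an otherwise standard Optax training loop. The toy model, data, and the default MagnitudePruning() arguments are assumptions made for illustration; only wrap_optax comes from the snippet above.

import jax
import jax.numpy as jnp
import optax
import jaxpruner

params = {'w': jnp.zeros((3,)), 'b': jnp.zeros(())}
tx = optax.sgd(learning_rate=0.1)
pruner = jaxpruner.MagnitudePruning()  # assumed default arguments
tx = pruner.wrap_optax(tx)             # masks/counters now live in the optimizer state
opt_state = tx.init(params)

def loss_fn(p, x, y):
  pred = x @ p['w'] + p['b']
  return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(p, state, x, y):
  grads = jax.grad(loss_fn)(p, x, y)
  updates, state = tx.update(grads, state, p)  # params are passed so masks can be refreshed
  return optax.apply_updates(p, updates), state

x, y = jnp.ones((8, 3)), jnp.ones((8,))
params, opt_state = train_step(params, opt_state, x, y)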

Research First

Research projects often require running multiple algorithms and baselines, so they benefit greatly from rapid prototyping. JaxPruner achieves this by committing to a generic API shared among different algorithms, which makes it easy to switch between them. We provide implementations of common baselines and make them easy to modify. A quick overview of these features is given in our colabs.
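As a sketch of what switching between algorithms looks like in practice (field and function names follow the Quickstart colab; treat the exact API and values as assumptions):

import ml_collections
import jaxpruner

def make_updater(algorithm):
  # One config, one entry point; only the algorithm name changes.
  config = ml_collections.ConfigDict()
  config.algorithm = algorithm            # e.g. 'magnitude', 'static_sparse', 'set', 'rigl'
  config.update_freq = 10
  config.update_start_step = 1
  config.update_end_step = 1000
  config.sparsity = 0.8
  config.dist_type = 'erk'
  return jaxpruner.create_updater_from_config(config)

magnitude_updater = make_updater('magnitude')
rigl_updater = make_updater('rigl')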

Minimal Overhead

Sparse training and various pruning recipes require additional operations such as masking. We implement these basic operations so as to minimize the overhead introduced (both memory and compute) and to run as fast as the dense baseline.

pruner = jaxpruner.MagnitudePruning(is_packed=True) # Line 1: Reduces mask overhead.

Installation

You can install JaxPruner with pip directly from source:

pip3 install git+https://github.com/google-research/jaxpruner.git

Alternatively, you can clone the source and run the tests using the run.sh script.

git clone https://github.com/google-research/jaxpruner.git
cd jaxpruner

The following script creates a virtual environment, installs the necessary libraries, and then runs the tests:

bash run.sh

Quickstart

See our Quickstart colab: Quick Start Colab

We also have Deep-Dive and MNIST Pruning colabs.

Baselines

Here we share results from our initial experiments with the implemented baselines.

Model                no_prune   random    magnitude  saliency  global_magnitude  magnitude_ste  static_sparse  set       rigl
ResNet-50            76.67      70.192    75.532     74.93     75.486            73.542         71.344         74.566    74.752
ViT-B/16 (90ep)      74.044     69.756    72.892     72.802    73.598            74.208         64.61          70.982    71.582
ViT-B/16 (300ep)     74.842     73.428    75.734     75.95     75.652            76.128         70.168         75.616    75.64
Fed. MNIST           86.21      83.53     85.74      85.60     86.01             86.16          83.33          84.20     84.64
t5-Base (C4)         2.58399    3.28813   2.95402    3.52233   5.43968           2.7124         3.17343        3.13115   3.12403
DQN-CNN (MsPacman)   2588.82    1435.29   2123.83    -         2322.21           -              1156.69        1723.3    1535.19

Citation

@inproceedings{jaxpruner,
  title={JaxPruner: A concise library for sparsity research},
  author={Joo Hyung Lee and Wonpyo Park and Nicole Mitchell and Jonathan Pilault and Johan S. Obando-Ceron and Han-Byul Kim and Namhoon Lee and Elias Frantar and Yun Long and Amir Yazdanbakhsh and Shivani Agrawal and Suvinay Subramanian and Xin Wang and Sheng-Chun Kao and Xingyao Zhang and Trevor Gale and Aart J. C. Bik and Woohyun Han and Milen Ferev and Zhonglin Han and Hong-Seok Kim and Yann Dauphin and Karolina Dziugaite and Pablo Samuel Castro and Utku Evci},
  year={2023}
}

Disclaimer

This is not an officially supported Google product.

jaxpruner's People

Contributors

evcu, hawkinsp, lenscloth, tink-expo


jaxpruner's Issues

Roadmap and experimental.sparse

Hi everyone!

First, I wanted to say thanks for such an easy-to-use library! I've been using Jaxpruner for some weeks now, resulting in highly pruned convolutional models (mostly with unstructured pruners).

I wanted to ask whether there is a roadmap for the library and whether you have an expected release date for the "integration with jax.experimental.sparse" mentioned in the README.

Thanks in advance

TypeError: Subscripted generics cannot be used with class and instance checks

I am using Python 3.9 and I get the following unit test error when installing:

Traceback (most recent call last):
  File "/home/jpilault/jaxpruner/jaxpruner/sparsity_distributions_test.py", line 63, in testUniformSparsityMapGeneratorWithCustomMap
    result = sparsity_distributions.uniform(
  File "/home/jpilault/jaxpruner/jaxpruner/sparsity_distributions.py", line 72, in uniform
    if isinstance(params, chex.Array):
  File "/usr/lib/python3.9/typing.py", line 703, in __instancecheck__
    return self.__subclasscheck__(type(obj))
  File "/usr/lib/python3.9/typing.py", line 706, in __subclasscheck__
    raise TypeError("Subscripted generics cannot be used with"
TypeError: Subscripted generics cannot be used with class and instance checks
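For context, the error itself is not specific to JaxPruner: on Python 3.9, isinstance() raises this TypeError whenever its second argument is a subscripted typing construct (such as a Union), which is what chex.Array resolves to in this environment according to the traceback. A minimal, stand-alone reproduction (the ArrayLike alias below is purely illustrative):

from typing import Union

import numpy as np

ArrayLike = Union[np.ndarray, float]  # a subscripted generic, used here only for illustration

# Raises: TypeError: Subscripted generics cannot be used with class and instance checks
isinstance(np.zeros(3), ArrayLike)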

Adding layer_norm does not work (optimization)

I am trying to add layer normalization to stabilize the training process. However, I found that the model does not train (performance does not improve; there is no learning signal) after the pruner.wrap_optax() operation. Do you have any suggestions on how to fix this?

Here is the relevant source code:

from typing import Callable, Optional, Sequence

import flax.linen as nn
import jax.numpy as jnp


class MLP(nn.Module):
    hidden_dims: Sequence[int]  # (512, 512)
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    activate_final: bool = False
    layer_normalization: bool = False
    dropout_rate: Optional[float] = None

    @nn.compact
    def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray:
        for i, size in enumerate(self.hidden_dims):
            # default_init() is defined elsewhere in the jaxrl codebase.
            x = nn.Dense(size, kernel_init=default_init())(x)
            if self.layer_normalization:
                x = nn.LayerNorm()(x)  # add layer normalization
            if i + 1 < len(self.hidden_dims) or self.activate_final:
                x = self.activations(x)
                if self.dropout_rate is not None:
                    x = nn.Dropout(rate=self.dropout_rate)(
                        x, deterministic=not training)
        return x

Here is the performance obtained with nn.LayerNorm() (Fig. 1). In controlled experiments, the algorithm achieves high performance without nn.LayerNorm() (Fig. 2). The source code is from jaxrl.

Fig. 1: With layer normalization, there is no learning signal and no performance improvement.

Fig. 2: Without layer normalization, performance improves significantly at first and then deteriorates during later training. This decline can be attributed to the high sparsity level, approximately 95%.

Basic config inquiries

Hi,

When I create a new sparsity_pruner with some basic config as follows:

config.sparsity_config.algorithm = 'rigl'

config.sparsity_config.update_freq = 10
config.sparsity_config.update_end_step = 1000
config.sparsity_config.update_start_step = 1
config.sparsity_config.sparsity = 0.95
config.sparsity_config.dist_type = 'erk'

My question is: where is the update_freq config actually used? Is there an internal counter that tracks training steps for pruning? Is there source code for this? I could not find it.

To the best of my knowledge, after pruner.wrap_optax wraps a specific optimizer, every call to optax.apply_updates increases the pruner's internal step count by 1. Is this true?
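For reference, here is a minimal, hypothetical sketch of the general Optax pattern for keeping a step count inside a wrapped transformation; it is not JaxPruner's actual implementation. Note that in plain Optax the counter advances on each call to the wrapped tx.update, while optax.apply_updates only adds the computed updates to the parameters.

from typing import NamedTuple

import jax.numpy as jnp
import optax

class CountedState(NamedTuple):
  inner_state: optax.OptState
  count: jnp.ndarray

def with_step_count(inner: optax.GradientTransformation, update_freq: int):
  # Wraps an existing transformation and tracks how many update steps have run.
  def init_fn(params):
    return CountedState(inner.init(params), jnp.zeros([], jnp.int32))

  def update_fn(grads, state, params=None):
    updates, inner_state = inner.update(grads, state.inner_state, params)
    count = state.count + 1
    # A pruner-style wrapper could act only when (count % update_freq == 0)
    # and count lies within [update_start_step, update_end_step].
    return updates, CountedState(inner_state, count)

  return optax.GradientTransformation(init_fn, update_fn)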

'mock sparsity'

I'm new to sparsity, so I'm a little confused by this sentence in the README:

"We mock sparsity by using binary masks and use dense operations for simulating sparsity. In the longer run, we also plan to provide integration with the jax.experimental.sparse, aim to reduce the memory footprint of our models."

What is meant by 'we mock sparsity'?

wrapped optimizer with params

Following the quick_start.ipynb notebook (#Modification #1), I create a wrapped optimizer with params info from the model, which triggers a bug that I do not know how to fix.

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/li_jiang/.conda/envs/jl-jax/lib/python3.9/site-packages/jaxpruner/base_updater.py", line 225, in init_fn
    sparse_state = self.init_state(params)
  File "/home/li_jiang/.conda/envs/jl-jax/lib/python3.9/site-packages/jaxpruner/base_updater.py", line 167, in init_state
    target_sparsities = self.sparsity_distribution_fn(params)
  File "/home/li_jiang/.conda/envs/jl-jax/lib/python3.9/site-packages/jaxpruner/sparsity_distributions.py", line 72, in uniform
    if isinstance(params, chex.Array):
  File "/home/li_jiang/.conda/envs/jl-jax/lib/python3.9/typing.py", line 720, in __instancecheck__
    return self.__subclasscheck__(type(obj))
  File "/home/li_jiang/.conda/envs/jl-jax/lib/python3.9/typing.py", line 723, in __subclasscheck__
    raise TypeError("Subscripted generics cannot be used with"
TypeError: Subscripted generics cannot be used with class and instance checks

Source code

# The pruner is created earlier with some default settings.
actor_optimizer = optax.adam(learning_rate=actor_lr)
actor_optimizer = pruner.wrap_optax(actor_optimizer)
actor = Model.create(actor_def,
                     inputs=[actor_key, observations],
                     tx=actor_optimizer)
    @classmethod
    def create(cls,
               model_def: nn.Module,
               inputs: Sequence[jnp.ndarray],
               tx: Optional[optax.GradientTransformation] = None) -> 'Model':
        variables = model_def.init(*inputs)

        _, params = variables.pop('params')
       ###########
       # where the bug is
        if tx is not None:
            opt_state = tx.init(params)
        else:
            opt_state = None
        ###############

        return cls(step=1,
                   apply_fn=model_def.apply,
                   params=params,
                   tx=tx,
                   opt_state=opt_state)

    def __call__(self, *args, **kwargs):
        return self.apply_fn({'params': self.params}, *args, **kwargs)

    def apply_gradient(
            self,
            loss_fn: Optional[Callable[[Params], Any]] = None,
            grads: Optional[Any] = None,
            has_aux: bool = True) -> Union[Tuple['Model', Any], 'Model']:
        assert loss_fn is not None or grads is not None, (
            'Either a loss function or grads must be specified.')
        if grads is None:
            grad_fn = jax.grad(loss_fn, has_aux=has_aux)
            if has_aux:
                grads, aux = grad_fn(self.params)
            else:
                grads = grad_fn(self.params)
        else:
            assert not has_aux, (
                'When grads are provided, expects no aux outputs.')

        updates, new_opt_state = self.tx.update(grads, self.opt_state,
                                                self.params)
        new_params = optax.apply_updates(self.params, updates)

        new_model = self.replace(step=self.step + 1,
                                 params=new_params,
                                 opt_state=new_opt_state)
        if has_aux:
            return new_model, aux
        else:
            return new_model

This code is from jaxrl (jaxrl/agents/sac/sac_learner.py). By the way, since I am really new to JAX and sparse NNs, would you mind providing some guidance (or example code) on constructing a SAC pruning example? It would be highly appreciated.

Versions

jax = 0.4.11
python = 3.9.1
