rishit-dagli / gradient-centralization-tensorflow

Instantly improve the training performance of your TensorFlow models with just 2 lines of code!

Home Page: https://pypi.org/project/gradient-centralization-tf/

License: Apache License 2.0

Python 100.00%
tensorflow tensorflow2 machine-learning gradient-centralization python python3 pip python-package deep-learning neural-network

gradient-centralization-tensorflow's Introduction

Gradient Centralization TensorFlow

This Python package implements Gradient Centralization in TensorFlow, a simple and effective optimization technique for Deep Neural Networks proposed by Yong et al. in the paper Gradient Centralization: A New Optimization Technique for Deep Neural Networks. It can both speed up the training process and improve the final generalization performance of DNNs.
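For reference, the centralization operator from the paper can be summarized as below, where $\mathbf{w}_i \in \mathbb{R}^M$ is the weight vector feeding output unit (or channel) $i$ and $\mathcal{L}$ is the loss. This notation is a summary of the paper, not part of the package's own documentation. Each weight vector's gradient has its own mean subtracted, so every centralized gradient vector has zero mean.

```latex
\Phi_{GC}\left(\nabla_{\mathbf{w}_i}\mathcal{L}\right)
  = \nabla_{\mathbf{w}_i}\mathcal{L} - \mu_{\nabla_{\mathbf{w}_i}\mathcal{L}},
\qquad
\mu_{\nabla_{\mathbf{w}_i}\mathcal{L}} = \frac{1}{M}\sum_{j=1}^{M} \nabla_{w_{i,j}}\mathcal{L}
```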

Installation

Run the following to install:

pip install gradient-centralization-tf

About the Examples

This notebook shows the process of using the gradient-centralization-tf Python package to train on the Fashion MNIST dataset available from tf.keras.datasets. It also compares performance with and without gctf.

This notebook shows the process of using the gradient-centralization-tf Python package to train on the Horses vs Humans dataset by Laurence Moroney. It also compares performance with and without gctf.

Usage

Create a centralized gradients function for a specified optimizer.

Arguments:

  • optimizer: a tf.keras.optimizers.Optimizer object. The optimizer you are using.

Example:

>>> opt = tf.keras.optimizers.Adam(learning_rate=0.1)
>>> opt.get_gradients = gctf.centralized_gradients_for_optimizer(opt)
>>> model.compile(optimizer = opt, ...)

Returns:

A tf.keras.optimizers.Optimizer object.

Computes the centralized gradients.

This function is not meant to be used directly unless you are building a custom optimizer, in which case you can point get_gradients to it. It is a modified version of tf.keras.optimizers.Optimizer.get_gradients.

Arguments:

  • optimizer: a tf.keras.optimizers.Optimizer object. The optimizer you are using.
  • loss: Scalar tensor to minimize.
  • params: List of variables.

Returns:

A gradients tensor.
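The centralization step itself is small. The sketch below reproduces it in NumPy so it can run without TensorFlow; it follows the GC rule from the paper (for tensors of rank > 1, subtract the mean taken over every axis except the last, output axis) and is not the package's own source. The helper names `centralize_gradient` and `centralize_gradients` are hypothetical.

```python
import numpy as np


def centralize_gradient(grad: np.ndarray) -> np.ndarray:
    """Return a gradient-centralized copy of `grad` (hypothetical helper).

    For tensors of rank > 1, e.g. Dense kernels of shape (in, out) or Conv2D
    kernels of shape (kh, kw, in, out), subtract the mean over every axis
    except the last (output) axis, so each output unit's gradient has zero
    mean. Rank-0/1 tensors such as biases are returned unchanged (as a copy).
    """
    if grad.ndim > 1:
        axes = tuple(range(grad.ndim - 1))
        return grad - grad.mean(axis=axes, keepdims=True)
    return grad.copy()


def centralize_gradients(grads):
    """Apply centralize_gradient to every gradient in a list."""
    return [centralize_gradient(g) for g in grads]
```

For a Dense kernel gradient `[[1, 2], [3, 6]]`, the column means are 2 and 4, so the centralized gradient is `[[-1, -2], [1, 2]]`, and each column now sums to zero.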

Pre-built optimizers updated to implement GC.

This module is built specifically for trying out GC. In most cases you would use gctf.centralized_gradients_for_optimizer instead, which is also what this module is built on. It provides all of the optimizers in tf.keras.optimizers, updated for GC, so you can use them directly.

Example:

>>> model.compile(optimizer = gctf.optimizers.adam(learning_rate = 0.01), ...)
>>> model.compile(optimizer = gctf.optimizers.rmsprop(learning_rate = 0.01, rho = 0.91), ...)
>>> model.compile(optimizer = gctf.optimizers.sgd(), ...)

Returns:

A tf.keras.optimizers.Optimizer object.

Developing gctf

To install gradient-centralization-tf along with the tools you need to develop and test it, run the following in your virtualenv:

git clone https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow
# or clone your own fork
cd Gradient-Centralization-TensorFlow

pip install -e .[dev]

Want to Contribute 🙋‍♂️?

Awesome! If you want to contribute to this project, you're always welcome! See Contributing Guidelines. You can also take a look at the open issues to learn more about current or upcoming tasks.

Want to discuss? 💬

Have questions or doubts, or want to share your opinions and views? You're always welcome to start a discussion.

License

Copyright 2020 Rishit Dagli

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

gradient-centralization-tensorflow's People

Contributors

github-actions[bot], ialimustufa, imgbot[bot], rishit-dagli

gradient-centralization-tensorflow's Issues

On Windows, TensorFlow 2.5 raises an error

On Windows 10, in a Miniconda environment, TensorFlow 2.5 raises an error in the centralized_gradients.py file.

The solution is to change
import keras.backend as K
to
import tensorflow.keras.backend as K

Update PyPI classifiers

I am specifically thinking of adding three more categories of PyPI classifiers:

  • Development status
  • Intended Audience
  • Topic

Apart from this, I also think it would be great to add Programming Language :: Python :: 3 :: Only, so that users know this package is intended for Python 3 only.
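A list covering the three proposed categories plus the Python 3-only marker might look like the sketch below. The classifier strings are standard PyPI trove classifiers, but the specific choices (in particular "4 - Beta" as the development status) are assumptions, not decisions taken by the project.

```python
# Hypothetical classifier list for setup.py. The Development Status value is
# an assumption and should reflect the project's actual maturity.
CLASSIFIERS = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "Intended Audience :: Science/Research",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3 :: Only",
    "License :: OSI Approved :: Apache Software License",
]

# In setup.py this list would be passed as:
#     setuptools.setup(..., classifiers=CLASSIFIERS)
```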

Add an "About the examples" section

It would be great to write an "About the examples" section briefly describing what the example notebooks aim to achieve and show.

Custom Optimizer Example

It would be great to include an example showing the use of this package with a custom optimizer built on top of the tf.keras.optimizers.Optimizer class.
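As a starting point, the sketch below shows where GC sits inside a custom optimizer: gradients are centralized first, then the ordinary update is applied. It is a framework-free NumPy toy, not a tf.keras.optimizers.Optimizer subclass; the class name `GCSGD` is hypothetical, and `apply_gradients` only mirrors the (gradient, variable) pairing that Keras optimizers use. A real Keras subclass would need to hook the centralization into whichever method computes or applies gradients in the TensorFlow version being targeted.

```python
import numpy as np


class GCSGD:
    """Toy SGD optimizer with Gradient Centralization (hypothetical, NumPy only)."""

    def __init__(self, learning_rate: float = 0.01):
        self.learning_rate = learning_rate

    @staticmethod
    def _centralize(grad: np.ndarray) -> np.ndarray:
        # Subtract the mean over all axes except the last (output) axis;
        # rank-0/1 gradients such as biases are left unchanged.
        if grad.ndim > 1:
            return grad - grad.mean(axis=tuple(range(grad.ndim - 1)), keepdims=True)
        return grad

    def apply_gradients(self, grads_and_vars):
        # Plain SGD step on the centralized gradient; `var` must be a float
        # NumPy array, which is updated in place.
        for grad, var in grads_and_vars:
            var -= self.learning_rate * self._centralize(grad)
```

One side effect of this design is visible in the update: because each centralized column sums to zero, plain SGD with GC leaves the column sums of a kernel unchanged.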

The results in the mnist example are wrong/misleading

Describe the bug
The results in your colab ipython notebook are misleading: https://colab.research.google.com/github/Rishit-dagli/Gradient-Centralization-TensorFlow/blob/main/examples/gctf_mnist.ipynb

In this example, the model is first trained with a normal Adam optimizer:

model.compile(optimizer = tf.keras.optimizers.Adam(),
              loss = 'sparse_categorical_crossentropy',
              metrics = ['accuracy'])

history_no_gctf = model.fit(training_images, training_labels, epochs=5, callbacks = [time_callback_no_gctf])

And afterwards the same model is recompiled with the gctf.optimizers.adam(). However, recompiling a keras model does not reset the weights. This means that in the first fit call the model is trained and then in the second fit call with the new optimizer the same model is used and of course then the results are better.

This can be fixed, by recreating the model for the second run, by just adding these few lines:

import gctf #import gctf

time_callback_gctf = TimeHistory()

# Model architecture
model = tf.keras.models.Sequential([
                                    tf.keras.layers.Flatten(), 
                                    tf.keras.layers.Dense(512, activation=tf.nn.relu),
                                    tf.keras.layers.Dense(256, activation=tf.nn.relu),
                                    tf.keras.layers.Dense(64, activation=tf.nn.relu),
                                    tf.keras.layers.Dense(512, activation=tf.nn.relu),
                                    tf.keras.layers.Dense(256, activation=tf.nn.relu),
                                    tf.keras.layers.Dense(64, activation=tf.nn.relu), 
                                    tf.keras.layers.Dense(10, activation=tf.nn.softmax)])

model.compile(optimizer = gctf.optimizers.adam(),
              loss = 'sparse_categorical_crossentropy',
              metrics=['accuracy'])

history_gctf = model.fit(training_images, training_labels, epochs=5, callbacks=[time_callback_gctf])

With that fix, however, the results are no better than without gctf:

Type                  Execution time    Accuracy      Loss
------------------  ----------------  ----------  --------
Model without gctf           24.7659    0.88825   0.305801
Model with gctf              24.7881    0.889567  0.30812

Could you please clarify what is happening here? I tried the gctf.optimizers.adam() optimizer in my own research and it didn't change the results at all, and now I see that it doesn't help in the example constructed here either. This makes me question the results of the paper.

To Reproduce
Execute the colab file given in the repository: https://colab.research.google.com/github/Rishit-dagli/Gradient-Centralization-TensorFlow/blob/main/examples/gctf_mnist.ipynb

Expected behavior
The right comparison would be if both models start from a random initialization, not that the second model can start with the already pre-trained weights.
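The fair protocol described above can be sketched without TensorFlow: every run rebuilds its weights from the same seed, so plain SGD and SGD with GC start from an identical, untrained state. This is a toy linear least-squares model in NumPy; `make_initial_weights`, `centralize` and `train` are hypothetical names, and no claim is made about which variant ends up better. In Keras, the corresponding step is rebuilding the model before the second run, as the issue does, or calling `tf.keras.models.clone_model(model)`, which creates new layers with freshly initialized weights instead of reusing the trained ones (seed the initializers if you need identical starting points).

```python
import numpy as np


def make_initial_weights(seed: int, n_in: int, n_out: int) -> np.ndarray:
    # Fresh weights for every run, drawn from the same seed, so both
    # optimizers start from the same untrained initialization.
    rng = np.random.default_rng(seed)
    return rng.normal(scale=0.1, size=(n_in, n_out))


def centralize(grad: np.ndarray) -> np.ndarray:
    # GC for a 2-D kernel: zero-mean each output column of the gradient.
    return grad - grad.mean(axis=0, keepdims=True)


def train(use_gc: bool, X, Y, seed=0, lr=0.1, steps=100) -> float:
    """Train W in X @ W ~ Y with full-batch SGD and return the final loss."""
    W = make_initial_weights(seed, X.shape[1], Y.shape[1])
    for _ in range(steps):
        grad = X.T @ (X @ W - Y) / len(X)  # gradient of the loss below
        if use_gc:
            grad = centralize(grad)
        W -= lr * grad
    # Loss: 0.5 * mean over samples of the squared error summed over outputs.
    return float(0.5 * np.mean(np.sum((X @ W - Y) ** 2, axis=1)))
```

Calling `train(False, X, Y)` and `train(True, X, Y)` on the same data then compares the two optimizers from the same starting point, which is the comparison the issue asks for.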

Looking forward to a swift explanation.

Best,
Max

Wider dependency requirements

The package currently requires tensorflow ~= 2.4.0 and keras ~= 2.4.0 to be installed. This is sometimes problematic for folks who have custom installations of TensorFlow, so a wider requirement could be set up.
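The reason the current pin is so narrow: under PEP 440, the compatible-release specifier `~=2.4.0` means `>=2.4.0, ==2.4.*`, so TensorFlow 2.5 and later are rejected. The sketch below checks this with the third-party `packaging` library; the wider range `>=2.4.0,<3` is a hypothetical example, not the project's chosen requirement.

```python
from packaging.specifiers import SpecifierSet

# Current pin described in the issue: compatible release "~=2.4.0",
# which PEP 440 defines as ">=2.4.0, ==2.4.*".
current = SpecifierSet("~=2.4.0")

# Hypothetical wider requirement (an assumption, for illustration only).
wider = SpecifierSet(">=2.4.0,<3")

for version in ["2.4.1", "2.5.0", "2.8.0"]:
    print(version, version in current, version in wider)
```

Only 2.4.x releases satisfy the current pin, while the wider range also accepts 2.5.0 and 2.8.0.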
