Group Normalization in Keras

A Keras implementation of Group Normalization by Yuxin Wu and Kaiming He.

Useful for fine-tuning large models with smaller batch sizes than those used in research settings (where the batch size is very large because training is spread over multiple GPUs). Similar to Batch Renormalization, but performs significantly better on ImageNet.

[Figure: Group Normalization]

The above image is from the paper. It illustrates the differences between the four normalization techniques in common use: Batch Norm, Layer Norm, Instance Norm, and Group Norm.

As can be seen, GN is independent of the batch size. This is crucial for fine-tuning large models that cannot be retrained with small batch sizes, because Batch Normalization depends on large batch sizes to compute reliable statistics for each batch and to update its moving-average parameters properly.
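
The four methods in the figure differ only in which axes of an (N, C, H, W) tensor the mean and variance are computed over. The NumPy sketch below illustrates this; the function names are illustrative rather than this repository's API, and the learned scale and shift (gamma, beta) are omitted:

```python
import numpy as np

def normalize(x, axes, eps=1e-5):
    # Subtract the mean and divide by the standard deviation taken over `axes`.
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def batch_norm(x):
    # One statistic per channel, shared by the whole batch: reduce over N, H, W.
    return normalize(x, (0, 2, 3))

def layer_norm(x):
    # One statistic per sample: reduce over C, H, W.
    return normalize(x, (1, 2, 3))

def instance_norm(x):
    # One statistic per sample and channel: reduce over H, W.
    return normalize(x, (2, 3))

def group_norm(x, groups):
    # One statistic per sample and per group of C // groups consecutive channels.
    n, c, h, w = x.shape
    xg = x.reshape(n, groups, c // groups, h, w)
    return normalize(xg, (2, 3, 4)).reshape(n, c, h, w)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8, 5, 5))  # (N, C, H, W)
print(np.allclose(group_norm(x, 1), layer_norm(x)))     # True
print(np.allclose(group_norm(x, 8), instance_norm(x)))  # True
```

Setting groups=1 reproduces layer norm and setting groups equal to the number of channels reproduces instance norm, the two extremes discussed in the paper. Only batch norm reduces over axis 0, which is why it is the one that depends on the batch size.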

Usage

Drop-in replacement for the BatchNormalization layer in Keras. The important parameter that differs from BatchNormalization is called groups. It must be set appropriately and is subject to the following constraints:

  1. It must be an integer that evenly divides the number of channels.
  2. 1 <= G <= #channels, where G is the value of groups and #channels is the number of channels in the incoming layer.
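
Both constraints can be checked before building a model. A small sketch; check_groups is a hypothetical helper, not part of this repository:

```python
def check_groups(groups, channels):
    """Hypothetical helper: validate `groups` for a layer with `channels` channels."""
    if not isinstance(groups, int):
        raise ValueError(f"groups must be an integer, got {groups!r}")
    # Checked before the modulo so that groups=0 never reaches `channels % groups`.
    if not 1 <= groups <= channels:
        raise ValueError(f"groups must satisfy 1 <= groups <= {channels}, got {groups}")
    if channels % groups != 0:
        raise ValueError(f"{channels} channels cannot be split into {groups} equal groups")
    return channels // groups  # number of channels in each group

print(check_groups(32, 64))  # 2
```

For example, a 64-channel input accepts groups=32 (2 channels per group) but rejects groups=3.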
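
from keras.layers import Input  # standalone Keras assumed, as in the issues below
from group_norm import GroupNormalization

ip = Input(shape=(...))  # replace ... with your input shape
# Apply the layer to a tensor; the channel count must be divisible by groups.
x = GroupNormalization(groups=32, axis=-1)(ip)
...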

keras-group-normalization's People

Contributors

titu1994

keras-group-normalization's Issues

About axis parameter

Hello titu1994.

I am using the group normalization code in my project.

If I use GN after a layer whose output shape is (batch_size, C, D, H, W), is it okay to set axis=1, since the format is channels-first?
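
To make the question concrete: in GN the axis argument only says which dimension holds the channels; the statistics are then taken over each group of channels together with every other non-batch dimension. The NumPy sketch below is not the repository's implementation (group_norm is a hypothetical name, gamma/beta are omitted, and whether the Keras layer handles 5D input is not checked here). It shows that axis=1 on channels-first 5D data gives the same result as axis=-1 on the same data stored channels-last:

```python
import numpy as np

def group_norm(x, groups, axis=1, eps=1e-5):
    """Sketch: normalize x in groups of channels along `axis` (gamma/beta omitted)."""
    # Bring the channel axis to position 1 so the layout is (N, C, *rest).
    xt = np.moveaxis(x, axis, 1)
    n, c, *rest = xt.shape
    # Split C into (groups, C // groups); this reshape is safe because the
    # channel axis is already directly after the batch axis.
    xg = xt.reshape(n, groups, c // groups, *rest)
    # Reduce over the channels in each group and every remaining axis.
    axes = tuple(range(2, xg.ndim))
    mean = xg.mean(axis=axes, keepdims=True)
    var = xg.var(axis=axes, keepdims=True)
    out = ((xg - mean) / np.sqrt(var + eps)).reshape(xt.shape)
    # Move the channel axis back to where it was in the input.
    return np.moveaxis(out, 1, axis)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8, 3, 4, 4))       # (batch_size, C, D, H, W)
y_first = group_norm(x, groups=4, axis=1)  # channels-first, axis=1
x_last = np.moveaxis(x, 1, -1)             # same data as (batch_size, D, H, W, C)
y_last = group_norm(x_last, groups=4, axis=-1)
print(np.allclose(y_first, np.moveaxis(y_last, -1, 1)))  # True
```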

Working with dynamic input shape

Hi,

When trying to apply the GroupNormalization layer in fully convolutional networks with a dynamic input shape (there is no mathematical constraint requiring a fixed input shape, and Keras usually allows most of its layers to take a dynamic input shape), I found that the layer does not work. This may be due to the K.reshape call, which only allows one dimension (the batch size) to be None.

Do you have a solution in mind for making GroupNormalization work with dynamic input shapes? (Is it even possible to implement it this way, as with BN?)

Regards,


Here is the code I'm trying to run:

i = Input(shape=(None,None,3))
c = Conv2D(64, (3,3), padding="same")(i)
o = GroupNormalization(axis=3)(c)
m = Model(inputs=i,outputs=o)

And the error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape)
    516     try:
--> 517       str_values = [compat.as_bytes(x) for x in proto_values]
    518     except TypeError:

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py in <listcomp>(.0)
    516     try:
--> 517       str_values = [compat.as_bytes(x) for x in proto_values]
    518     except TypeError:

/opt/conda/lib/python3.6/site-packages/tensorflow/python/util/compat.py in as_bytes(bytes_or_text, encoding)
     66     raise TypeError('Expected binary or unicode string, got %r' %
---> 67                     (bytes_or_text,))
     68 

TypeError: Expected binary or unicode string, got -1

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<ipython-input-5-91a4d6819a5e> in <module>()
      1 i = Input(shape=(None,None,3))
      2 c = Conv2D(64, (3,3), padding="same")(i)
----> 3 o = GroupNormalization(axis=3)(c)
      4 m = Model(inputs=i,outputs=o)

/opt/conda/lib/python3.6/site-packages/keras/engine/topology.py in __call__(self, inputs, **kwargs)
    617 
    618             # Actually call the layer, collecting output(s), mask(s), and shape(s).
--> 619             output = self.call(inputs, **kwargs)
    620             output_mask = self.compute_mask(inputs, previous_mask)
    621 

/home/code/keras_shipdetection/layers/group_normalization.py in call(self, inputs, **kwargs)
    135         needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])
    136 
--> 137         inputs = K.reshape(inputs, group_shape)
    138 
    139         mean = K.mean(inputs, axis=group_reduction_axes[2:], keepdims=True)

/opt/conda/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in reshape(x, shape)
   1896         A tensor.
   1897     """
-> 1898     return tf.reshape(x, shape)
   1899 
   1900 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py in reshape(tensor, shape, name)
   6111   if _ctx is None or not _ctx._eager_context.is_eager:
   6112     _, _, _op = _op_def_lib._apply_op_helper(
-> 6113         "Reshape", tensor=tensor, shape=shape, name=name)
   6114     _result = _op.outputs[:]
   6115     _inputs_flat = _op.inputs

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
    511           except TypeError as err:
    512             if dtype is None:
--> 513               raise err
    514             else:
    515               raise TypeError(

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
    508                 dtype=dtype,
    509                 as_ref=input_arg.is_ref,
--> 510                 preferred_dtype=default_dtype)
    511           except TypeError as err:
    512             if dtype is None:

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in internal_convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, ctx)
   1102 
   1103     if ret is None:
-> 1104       ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
   1105 
   1106     if ret is NotImplemented:

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py in _constant_tensor_conversion_function(v, dtype, name, as_ref)
    233                                          as_ref=False):
    234   _ = as_ref
--> 235   return constant(v, dtype=dtype, name=name)
    236 
    237 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py in constant(value, dtype, shape, name, verify_shape)
    212   tensor_value.tensor.CopyFrom(
    213       tensor_util.make_tensor_proto(
--> 214           value, dtype=dtype, shape=shape, verify_shape=verify_shape))
    215   dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
    216   const_tensor = g.create_op(

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape)
    519       raise TypeError("Failed to convert object of type %s to Tensor. "
    520                       "Contents: %s. Consider casting elements to a "
--> 521                       "supported type." % (type(values), values))
    522     tensor_proto.string_val.extend(str_values)
    523     return tensor_proto

TypeError: Failed to convert object of type <class 'list'> to Tensor. Contents: [-1, 32, None, None, 2]. Consider casting elements to a supported type.

I am running TensorFlow 1.8.0 and Keras 2.1.6.
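
The reshape target in the traceback, [-1, 32, None, None, 2], appears to be built from the static shape, in which the unknown height and width are None. One possible direction, sketched here in NumPy for the shape arithmetic only: read the batch and spatial sizes from the tensor at call time (in a TensorFlow graph that would mean tf.shape(inputs) instead of the static shape) and keep channels-last order, so that only the channel dimension, which is always known, has to be split. group_norm_nhwc is a hypothetical name and gamma/beta are omitted:

```python
import numpy as np

def group_norm_nhwc(x, groups, eps=1e-5):
    """Sketch: group norm for channels-last input whose H and W vary per call."""
    # Every leading dimension comes from the actual array at call time; only
    # the channel count is needed to split the last axis into groups.
    *lead, c = x.shape
    xg = x.reshape(*lead, groups, c // groups)
    # Reduce over the spatial axes and the channels within each group,
    # keeping the batch axis (0) and the group axis (-2).
    axes = tuple(range(1, xg.ndim - 2)) + (xg.ndim - 1,)
    mean = xg.mean(axis=axes, keepdims=True)
    var = xg.var(axis=axes, keepdims=True)
    return ((xg - mean) / np.sqrt(var + eps)).reshape(x.shape)

rng = np.random.default_rng(0)
for h, w in [(16, 16), (37, 23)]:  # two spatial sizes, same layer settings
    x = rng.normal(size=(2, h, w, 64))
    y = group_norm_nhwc(x, groups=32)
    print(y.shape)  # (2, 16, 16, 64), then (2, 37, 23, 64)
```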

TypeError

TypeError: Failed to convert object of type <class 'list'> to Tensor. Contents: [-1, 32, None, None, 2]. Consider casting elements to a supported type.

inputs = K.reshape(inputs, group_shape)

Permute dimensions before K.reshape

Hi, @titu1994. Thanks for your hard work.
I have one question after reading your code. When using tf as the backend with axis=-1 and an input of shape [n, h, w, c], the current implementation reshapes the inputs to [n, g, h, w, c//g]. Do we need to permute them to [n, c, h, w] first and then reshape to [n, g, c//g, h, w]?
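
A reshape only reinterprets the existing memory order, so the question can be checked numerically. The NumPy sketch below (arbitrary sizes, not taken from the repository) compares per-group means on an NHWC input for three options: permuting to NCHW and then reshaping to [n, g, c//g, h, w], splitting only the channel axis into [n, h, w, g, c//g], and reshaping directly to [n, g, h, w, c//g]:

```python
import numpy as np

rng = np.random.default_rng(0)
n, h, w, c, g = 2, 4, 4, 8, 4
x = rng.normal(size=(n, h, w, c))  # channels-last input [n, h, w, c]

# (a) Permute to [n, c, h, w], then reshape to [n, g, c//g, h, w].
a = np.transpose(x, (0, 3, 1, 2)).reshape(n, g, c // g, h, w).mean(axis=(2, 3, 4))

# (b) Stay channels-last and split only the channel axis: [n, h, w, g, c//g].
b = x.reshape(n, h, w, g, c // g).mean(axis=(1, 2, 4))

# (c) Reshape [n, h, w, c] directly to [n, g, h, w, c//g] without permuting.
naive = x.reshape(n, g, h, w, c // g).mean(axis=(2, 3, 4))

print(np.allclose(a, b))      # True
print(np.allclose(a, naive))  # False
```

Options (a) and (b) agree because both put consecutive channels into the same group. Option (c) instead groups contiguous slices of the flattened h*w*c block, which mixes all channels, so its statistics are not per channel group. This sketch does not by itself confirm what the repository's call() does.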

Where is the moving_average update?

I found that there is no moving_average_mean/var, which is necessary in the BatchNorm step. Is it OK for group normalization to have no moving-average update?
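
Because GN computes its mean and variance from each sample alone, the same statistics can be recomputed at inference time, so it does not need the moving averages that BN keeps for inference, when batch statistics are unavailable or unreliable. The NumPy sketch below (illustrative function names, gamma/beta omitted) checks this: a sample's GN output is identical whether it is normalized alone or inside a batch, while its BN output changes with the batch:

```python
import numpy as np

def batch_norm_nchw(x, eps=1e-5):
    # Statistics over the batch and spatial axes: each output depends on the whole batch.
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def group_norm_nchw(x, groups, eps=1e-5):
    # Statistics per sample and per group of channels: the batch axis is never reduced.
    n, c, h, w = x.shape
    xg = x.reshape(n, groups, c // groups, h, w)
    mean = xg.mean(axis=(2, 3, 4), keepdims=True)
    var = xg.var(axis=(2, 3, 4), keepdims=True)
    return ((xg - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 16, 4, 4))  # a training batch, (N, C, H, W)
single = x[:1]                      # one sample on its own, as at inference time

print(np.allclose(group_norm_nchw(single, 4), group_norm_nchw(x, 4)[:1]))  # True
print(np.allclose(batch_norm_nchw(single), batch_norm_nchw(x)[:1]))        # False
```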

Is the figure of Batch Normalization correct?

[Figure: original figure of BN]
As far as I can tell, the representation of batch normalization in the original paper is not correct. I am posting the issue here for discussion.
I think batch normalization should be drawn as in the following figure.
[Figure: BN]

The key point is how to calculate the mean and std.
With feature maps of shape (batch_size, channel_number, width, height), the mean could be computed as
mean = X.mean(axis=(0, 2, 3), keepdims=True)
or
mean = X.mean(axis=(0, 1), keepdims=True)

Which one is correct?
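
The two candidates can be told apart by the shape of the statistics they produce. For convolutional feature maps, standard Batch Normalization keeps one mean and one variance per channel (Keras' BatchNormalization, for instance, stores one moving_mean entry per channel), and that is what the first expression computes. A quick NumPy check of the shapes, using random data of arbitrary size:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3, 5, 6))  # (batch_size, channel_number, width, height)

mean_a = X.mean(axis=(0, 2, 3), keepdims=True)  # one value per channel
mean_b = X.mean(axis=(0, 1), keepdims=True)     # one value per spatial position
print(mean_a.shape)  # (1, 3, 1, 1)
print(mean_b.shape)  # (1, 1, 5, 6)
```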
