titu1994 / batchrenormalization
Batch Renormalization algorithm implementation in Keras
License: MIT License
Thanks for your implementation of batch renormalization. I saw that the performance of your batch renorm and batch norm is similar. Did you check the performance with a simple network, as in the paper? The paper reports a large gain for batch renorm over batch norm on a simple network.
I want to use BatchRenorm in Caffe. Could the author provide a C++ version?
Thanks!
I ran your cifar10_brn.py file and I got this error:
InvalidArgumentError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
1627 try:
-> 1628 c_op = c_api.TF_FinishOperation(op_desc)
1629 except errors.InvalidArgumentError as e:
InvalidArgumentError: Shape must be rank 1 but is rank 0 for 'batch_renormalization_4/Reshape_10' (op: 'Reshape') with input shapes: [1,1,1,16], [].
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
in ()
51
52
---> 53 model = create_wide_residual_network(input_dim=init_shape, nb_classes=10, N=2, k=4)
54
55
/content/drive/vguNuke/Jupyter/BatchRenormalization/wrn_renorm.py in create_wide_residual_network(input_dim, nb_classes, N, k, dropout, verbose)
116 ip = Input(shape=input_dim)
117
--> 118 x = initial_conv(ip)
119 nb_conv = 4
120
/content/drive/vguNuke/Jupyter/BatchRenormalization/wrn_renorm.py in initial_conv(input)
12 channel_axis = 1 if K.image_data_format() == "channels_first" else -1
13
---> 14 x = BatchRenormalization(axis=channel_axis, momentum=0.1, epsilon=1e-5, gamma_init='uniform')(x)
15 x = Activation('relu')(x)
16 return x
/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py in __call__(self, inputs, **kwargs)
455 # Actually call the layer,
456 # collecting output(s), mask(s), and shape(s).
--> 457 output = self.call(inputs, **kwargs)
458 output_mask = self.compute_mask(inputs, previous_mask)
459
/content/drive/vguNuke/Jupyter/BatchRenormalization/batch_renorm.py in call(self, x, mask)
195 x, broadcast_running_mean, broadcast_running_std,
196 broadcast_beta, broadcast_gamma,
--> 197 epsilon=self.epsilon)
198
199 # pick the normalized form of x corresponding to the training phase
/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py in batch_normalization(x, mean, var, beta, gamma, axis, epsilon)
1906 # so it may have extra axes with 1, it is not needed and should be removed
1907 if ndim(mean) > 1:
-> 1908 mean = tf.reshape(mean, (-1))
1909 if ndim(var) > 1:
1910 var = tf.reshape(var, (-1))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_array_ops.py in reshape(tensor, shape, name)
6480 if _ctx is None or not _ctx._eager_context.is_eager:
6481 _, _, _op = _op_def_lib._apply_op_helper(
-> 6482 "Reshape", tensor=tensor, shape=shape, name=name)
6483 _result = _op.outputs[:]
6484 _inputs_flat = _op.inputs
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
785 op = g.create_op(op_type_name, inputs, output_types, name=scope,
786 input_types=input_types, attrs=attr_protos,
--> 787 op_def=op_def)
788 return output_structure, op_def.is_stateful, op
789
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
486 'in a future version' if date is None else ('after %s' % date),
487 instructions)
--> 488 return func(*args, **kwargs)
489 return tf_decorator.make_decorator(func, new_func, 'deprecated',
490 _add_deprecated_arg_notice_to_docstring(
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in create_op(failed resolving arguments)
3272 input_types=input_types,
3273 original_op=self._default_original_op,
-> 3274 op_def=op_def)
3275 self._create_op_helper(ret, compute_device=compute_device)
3276 return ret
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in __init__(self, node_def, g, inputs, output_types, control_inputs, input_types, original_op, op_def)
1790 op_def, inputs, node_def.attr)
1791 self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
-> 1792 control_input_ops)
1793
1794 # Initialize self._outputs.
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
1629 except errors.InvalidArgumentError as e:
1630 # Convert to ValueError for backwards compatibility.
-> 1631 raise ValueError(str(e))
1632
1633 return c_op
ValueError: Shape must be rank 1 but is rank 0 for 'batch_renormalization_4/Reshape_10' (op: 'Reshape') with input shapes: [1,1,1,16], [].
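One plausible reading of the traceback above, offered as an assumption rather than a confirmed diagnosis: the backend line `tf.reshape(mean, (-1))` passes `(-1)`, which Python parses as the integer `-1` rather than a one-element tuple, so the shape argument arrives as a rank-0 value (the `[]` in the error message). The sketch below uses only plain Python to show the difference; `shape_rank` is a hypothetical helper for illustration, not part of Keras or TensorFlow.

```python
# Parentheses alone do not make a tuple; the trailing comma does.
scalar_shape = (-1)    # plain int -> a rank-0 shape argument
vector_shape = (-1,)   # one-element tuple -> a rank-1 shape argument


def shape_rank(shape):
    """Hypothetical helper: 0 for a bare int, 1 for a flat sequence."""
    if isinstance(shape, int):
        return 0
    return 1


print(type(scalar_shape).__name__, shape_rank(scalar_shape))
print(type(vector_shape).__name__, shape_rank(vector_shape))
```

If that reading is right, passing `[-1]` or `(-1,)` to the reshape would avoid the error. The traceback itself does not confirm whether another Keras release already does so.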
I ran this command inside a conda environment with keras 2.2 installed:
python cifar10_brn.py
Error message:
/root/devansh/FFL/BatchRenormalization/batch_renorm.py:97: UserWarning: This implementation of BatchRenormalization is inconsistent with the original paper and therefore results may not be similar ! For discussion on the inconsistency of this implementation, refer here : keras-team/keras-contrib#17
warnings.warn('This implementation of BatchRenormalization is inconsistent with the '
Traceback (most recent call last):
File "cifar10_brn.py", line 27, in
model = create_wide_residual_network(input_dim=init_shape, nb_classes=10, N=2, k=4)
File "/root/devansh/FFL/BatchRenormalization/wrn_renorm.py", line 118, in create_wide_residual_network
x = initial_conv(ip)
File "/root/devansh/FFL/BatchRenormalization/wrn_renorm.py", line 14, in initial_conv
x = BatchRenormalization(axis=channel_axis, momentum=0.1, epsilon=1e-5, gamma_initializer='uniform')(x)
File "/root/anaconda3/envs/ffl_dev/lib/python3.6/site-packages/keras/engine/base_layer.py", line 457, in __call__
output = self.call(inputs, **kwargs)
File "/root/devansh/FFL/BatchRenormalization/batch_renorm.py", line 192, in call
r = K.stop_gradient(K.clip(r, 1 / self.r_max, self.r_max))
File "/root/anaconda3/envs/ffl_dev/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 1597, in clip
if max_value is not None and max_value < min_value:
File "/root/anaconda3/envs/ffl_dev/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 653, in __bool__
raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.
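The traceback shows `K.clip` evaluating `max_value < min_value` inside a Python `if`. Assuming `self.r_max` is a backend variable (the traceback does not say so directly), `1 / self.r_max` is a tensor, the comparison yields another tensor, and the `if` forces it through `bool()`. The sketch below reproduces that mechanism in plain Python. `SymbolicValue` and both `bounds_are_swapped*` functions are hypothetical stand-ins, not Keras or TensorFlow APIs.

```python
class SymbolicValue:
    """Hypothetical stand-in for a graph-mode tensor: no concrete truth value."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name

    def __lt__(self, other):
        # Comparing symbolic values yields another symbolic value, not a bool.
        return SymbolicValue(f"({self!r} < {other!r})")

    def __gt__(self, other):
        return SymbolicValue(f"({self!r} > {other!r})")

    def __bool__(self):
        raise TypeError("a symbolic value cannot be used as a Python bool")


def bounds_are_swapped(min_value, max_value):
    """Mimics the failing check: the `if` forces bool() on the comparison."""
    if max_value is not None and max_value < min_value:
        return True
    return False


def bounds_are_swapped_safe(min_value, max_value):
    """Compares bounds eagerly only when both are plain Python numbers."""
    numeric = (int, float)
    if isinstance(min_value, numeric) and isinstance(max_value, numeric):
        return max_value < min_value
    return False  # symbolic bounds: leave any ordering check to the graph


r_max = SymbolicValue("r_max")
r_min = SymbolicValue("1 / r_max")

try:
    bounds_are_swapped(r_min, r_max)
    raised = False
except TypeError:
    raised = True

print("unsafe check raised TypeError:", raised)
print("safe check result:", bounds_are_swapped_safe(r_min, r_max))
```

Under the same assumption, a workaround on the caller's side would be to keep `r_max` as a plain Python float when the bounds are passed to `K.clip`, or to call a clipping op that accepts tensor bounds without comparing them in Python.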
Looks like the code was updated to expect K.moments() to be available, but it's not part of the Keras backend module (I'm using Keras 2.0.5). I'm guessing this code is not maintained anymore and that I should use the version in keras-contrib, right?
Hi,
is this work here somehow related to keras-contrib? See here:
https://github.com/keras-team/keras-contrib
There is also a Batch Renorm implementation (which has known bugs): https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/layers/normalization.py#L155
Thanks
Philip