
Comments (13)

alanchiao commented on May 16, 2024

In general, yes you can.

There are some caveats, e.g. the lack of support for subclassed models and for models nested within other models, as in both examples (tejalal@ and Cospel@). Created #155 in light of this to improve subclassed model support.

from model-optimization.

s36srini commented on May 16, 2024

Firstly, include_top=False means that you are changing the input, so you'll first want to do this:
model = Model(top_model.input, vgg16.output). This combines the input of top_model and its sequential layers, connects them to vgg16 (without its input layer), and leaves the output the same as vgg16's.
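For reference, here is a minimal sketch of the usual Functional API pattern for attaching a custom head to a headless base. The head architecture, the class count, and weights=None are illustrative assumptions, not code from this thread:

```python
import tensorflow as tf

# Headless VGG16 base; weights=None keeps the sketch self-contained.
base = tf.keras.applications.VGG16(weights=None, include_top=False,
                                   input_shape=(224, 224, 3))

# Build the head as its own model so it can be called on the base output.
head = tf.keras.Sequential([
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(256, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])

# One combined graph: the base's input tensor flows through to the head.
model = tf.keras.Model(inputs=base.input, outputs=head(base.output))
```

Because the combined model is one flat graph from base input to head output, downstream tooling sees ordinary layers rather than a nested Model.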

Secondly, by pruning the whole model you don't get to specify which layers you want to prune, and it is only necessary to prune layers that have a high number of trainable parameters. In my code, I only prune the pointwise convolutional layers, as they contain 76% of the model's parameters.
Here's my code for reference:

import re
import numpy as np
import tensorflow as tf
from tensorflow_model_optimization.sparsity import keras as sparsity

mobileNet = tf.keras.applications.MobileNet(weights=None)  # Not ImageNet 2012 trained weights

end_step = np.ceil(1.0 * NUM_TRAIN_SAMPLES / FLAGS.batch_size).astype(np.int32) * EPOCHS

pruning_schedule = sparsity.PolynomialDecay(
                        initial_sparsity=0.0, final_sparsity=0.5,
                        begin_step=0, end_step=end_step, frequency=100)

pruned_model = tf.keras.Sequential()
for layer in mobileNet.layers:
    if re.match(r"conv_pw_\d+$", layer.name):  # pointwise convolutions only
        pruned_model.add(sparsity.prune_low_magnitude(
            layer,
            pruning_schedule,
            block_size=(1, 1)
        ))
    else:
        pruned_model.add(layer)

opt = tf.train.AdamOptimizer()  # TF 1.x optimizer; on TF 2.x use tf.keras.optimizers.Adam()
pruned_model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])

Lastly, I forgot to mention: you will want to use the keras that ships inside TensorFlow, i.e. from tensorflow.python import keras. This is different from import keras, and using standalone Keras without this import will lead to headaches.

s36srini commented on May 16, 2024

I also forgot to mention: you want to initialize the sequential model as tf.keras.Sequential(), not keras.Sequential().

Cospel commented on May 16, 2024

I'm using the tf 2.0.0 library and get the same error:

ValueError: Please initialize `Prune` with a supported layer. Layers should either be a `PrunableLayer` instance, or should be supported by the PruneRegistry. You passed: <class 'tensorflow.python.keras.engine.training.Model'>

My code looks like this:

        model = tf.keras.Sequential(
            [
                tf.keras.applications.MobileNetV2(weights="imagenet", input_shape=(224, 224, 3), include_top=False),
                tf.keras.layers.GlobalAveragePooling2D(),
                tf.keras.layers.Dense(256, activation="relu", name="descriptor"),
                tf.keras.layers.Dense(2, activation="softmax", name="probs"),
            ]
        )

        model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, pruning_schedule=tfmot.sparsity.keras.PolynomialDecay(
                initial_sparsity=0.0, final_sparsity=0.5, begin_step=3, end_step=5
            ))

Cospel commented on May 16, 2024

Thank you @alanchiao. Most models nowadays are subclassed or nested. It would be very useful if we could prune them.

alanchiao commented on May 16, 2024

@nutsiepully, @raziel for visibility

raziel commented on May 16, 2024

We understand the need. The caveat is that going the subclassing route basically diminishes the usability of the Keras abstractions we rely on. Our suggestion, for now, would be to abstract some of the subclass logic into Keras layers and then apply pruning in the same manner as we currently do for the built-in layers.

@nutsiepully wdyt? Do we have an example to point folks to?

alanchiao commented on May 16, 2024

Closing this issue since #155 was created. Will update this thread once #155 is fixed; we'll have almost complete coverage at that point.

nutsiepully commented on May 16, 2024

Sorry, I seem to have missed this issue.

For now, as @raziel suggested, the best approach is to apply pruning on a per-layer basis. You can choose the layers most important to you and prune just those. For parts of your model that are purely custom, you can use the PrunableLayer abstraction to control them.

jiayiliu commented on May 16, 2024

mobileNet = tf.keras.applications.MobileNet(weights=None) # Not ImageNet 2012 trained weights

end_step = np.ceil(1.0 * NUM_TRAIN_SAMPLES / FLAGS.batch_size).astype(np.int32) * EPOCHS

pruning_schedule = sparsity.PolynomialDecay(
                        initial_sparsity=0.0, final_sparsity=0.5,
                        begin_step=0, end_step=end_step, frequency=100)

pruned_model = tf.keras.Sequential()
for layer in mobileNet.layers:
    if re.match(r"conv_pw_\d+$", layer.name):
        pruned_model.add(sparsity.prune_low_magnitude(
            layer,
            pruning_schedule,
            block_size=(1, 1)
        ))
    else:
        pruned_model.add(layer)

Thank you @s36srini for sharing the code. It works well for MobileNet, but it fails for MobileNetV2, because we cannot simply rebuild the model with model.add(): "A merge layer should be called on a list of inputs."

sushruta commented on May 16, 2024

Hello, what's the correct way of getting past the above error when we use model.add(...)?

A merge layer should be called on a list of inputs.

I get the same error when I try to define the layers I need for pruning EfficientNet-B6.


What I do is the following:

  • I load EfficientNet-B6 with ImageNet weights
  • I freeze the first 150 layers as untouchable
  • From the 151st layer onwards, I set layer.trainable to True, check whether each layer is an expand_conv or project_conv, and if so I set it as a target for pruning in the exact same way as described by @s36srini

It gives me an error at the model.add(...) step.

I peeked at the code of EfficientNet and it looks like it's using the Functional API. Could mixing the Sequential API with the Functional API result in errors like these?

NonlinearNimesh commented on May 16, 2024

Hi, I have a trained frozen model. Is it possible to prune it? Any references would be a great help.

Thanks.

gnhearx commented on May 16, 2024

Hi everyone :)
I have a similar issue with pruning nested models: even if I apply the pruning wrappers per layer inside all the nested Functional API models, they don't prune.

Is this expected behaviour for nested models? I would think that if any layer in a model has the wrapper, then it would be pruned when the pruning callback runs during training. Unfortunately, this does not happen. Instead, everything that is not nested (and has pruning wrappers) does prune, and anything inside a nested model does not.

I can also confirm that if I create a model with no nested models at all, then everything I set to prune does in fact prune the way it should.

Side note: my nested model is a pretrained VGG16 from Keras, and I apply pruning wrappers to each layer within the nested model.

If anyone has a solution or workaround, that would seriously be very helpful. Thank you.
