jfrery commented on July 29, 2024

Hi @malickKhurram,

Yes, we don't support GlobalAveragePool for now. We are going to support it soon.

For now, a workaround would be to change the adaptive average pooling in your network to a simple average pooling.

For example, in resnet18 we have

nn.AdaptiveAvgPool2d((1, 1))

this can be changed to

nn.AvgPool2d(kernel_size=7, stride=1, padding=0)

Of course, this change depends on your model architecture and input shape. Adaptive pooling computes the kernel size and stride automatically from the desired output size (here (1, 1)). So you need to find the kernel_size and stride that produce that output size at this specific point in your network, and hardcode these values in a standard average pool.
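The kernel-size reasoning above can be sketched in plain Python. This is a minimal sketch, assuming the input size along each spatial dimension is an exact multiple of the target output size (the case where adaptive pooling uses equal, non-overlapping windows). The helper name `adaptive_to_fixed_pool` is hypothetical and is not part of PyTorch or Concrete ML.

```python
from typing import Tuple


def adaptive_to_fixed_pool(in_size: int, out_size: int) -> Tuple[int, int]:
    """Return (kernel_size, stride) for a fixed average pool that matches
    adaptive average pooling along one spatial dimension.

    Hypothetical helper. Assumes in_size is a multiple of out_size; otherwise
    adaptive pooling uses windows of varying size and no single fixed kernel
    reproduces it.
    """
    if in_size % out_size != 0:
        raise ValueError(
            f"in_size={in_size} is not a multiple of out_size={out_size}"
        )
    # With an exact multiple, the adaptive windows are [i*k, (i+1)*k) with
    # k = in_size // out_size, i.e. equal kernel and stride.
    kernel_size = in_size // out_size
    stride = kernel_size
    return kernel_size, stride


# resnet18 on 224x224 inputs ends with a 7x7 feature map before pooling.
print(adaptive_to_fixed_pool(7, 1))   # (7, 7)
print(adaptive_to_fixed_pool(14, 2))  # (7, 7)
```

With an output size of 1 there is only one window, so the stride has no effect; that is why `nn.AvgPool2d(kernel_size=7, stride=1, padding=0)` also yields a 1x1 output on a 7x7 feature map.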

This should make the compilation pass.

from concrete-ml.

malickKhurram commented on July 29, 2024

Hi @jfrery
Thank you for your quick response.

  • May I know how soon GlobalAveragePool will be supported in Concrete ML?
  • I am using a pre-trained densenet121 model for COVID detection. In fact, I am using the following code and dataset for this learning task:
    https://www.kaggle.com/code/arunrk7/covid-19-detection-pytorch-tutorial
  • Can you please guide me through the steps to compile this pre-trained model with Concrete ML? Or do you have a similar example already converted to Concrete ML?

jfrery commented on July 29, 2024

Hi @malickKhurram,

I can give you some hints on how to work around the GlobalAveragePooling.

You will need to change the densenet model file manually. To do this, copy this file locally: https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py so that you can import the pre-trained densenet from that file instead of from torch hub.

Then you will need to change that line:

out = F.adaptive_avg_pool2d(out, (1, 1))

to

out = F.avg_pool2d(out, kernel_size=(N, M))

Here, you need to find out what N and M are. The best way to do this is to load the original model and print the shape of the out variable just before the pooling is applied. If that shape is (1024, 5, 5), then N = M = 5.
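Besides printing the shape of `out`, the spatial size can be worked out by hand from the layer hyperparameters. The sketch below assumes the torchvision DenseNet layout (7x7 stride-2 stem convolution with padding 3, 3x3 stride-2 max pool with padding 1, dense blocks that keep the spatial size, and three transition layers ending in a 2x2 stride-2 average pool). The helper names are hypothetical, and the input size used in the Kaggle tutorial is not stated in this thread, so confirm the result against the printed shape.

```python
from typing import Tuple


def pool_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output length of a conv/pool layer along one dimension (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def densenet_feature_hw(height: int, width: int) -> Tuple[int, int]:
    """Hypothetical helper: spatial size of the DenseNet `features` output,
    assuming the torchvision layer layout described above."""
    dims = []
    for size in (height, width):
        size = pool_out(size, kernel=7, stride=2, padding=3)  # conv0
        size = pool_out(size, kernel=3, stride=2, padding=1)  # pool0 (max)
        for _ in range(3):  # transition1..3, each ending in a 2x2 avg pool
            size = pool_out(size, kernel=2, stride=2)
        dims.append(size)
    return dims[0], dims[1]


print(densenet_feature_hw(224, 224))  # (7, 7)
print(densenet_feature_hw(160, 160))  # (5, 5)
```

Under these assumptions, a 160x160 input gives the (1024, 5, 5) case mentioned above, and a 224x224 input gives (1024, 7, 7), i.e. N = M = 7.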

Let me know if you can find your way with this. I will turn this issue into a feature request for GlobalAveragePooling so that we can track it.

malickKhurram commented on July 29, 2024

Hi @jfrery
Thank you for your quick response. It helped me solve the above problem; I made the changes in the local file.
But now I am facing the following issue when compiling the ONNX model:

Can you please guide me on this?

Regards

andrei-stoian-zama commented on July 29, 2024

Thanks for the bug report.

It's hard to tell where the error comes from. The line you show uses numpy functions, so it should return a numpy.float rather than a Python float. Could you print the values and types of stats.rmax, stats.rmin, options.n_bits and self.offset just before that line?

Alternatively, could you share code that reproduces the issue?

andrei-stoian-zama commented on July 29, 2024

Let's continue the discussion in #522.

Keeping this issue open until CML supports GlobalAveragePooling.

andrei-stoian-zama commented on July 29, 2024

As a reminder, for an image classification model, please see https://github.com/zama-ai/concrete-ml/tree/main/use_case_examples/cifar/cifar_brevitas_finetuning
