
coatnet-pytorch's People

Contributors

chinhsuanwu


coatnet-pytorch's Issues

About the attention model

[Screenshot 2021-10-25 222552]
Hi,
I noticed that the value of self.relative_bias_table is always all zeros. Doesn't that make the following

    relative_bias = self.relative_bias_table.gather(
        0, self.relative_index.repeat(1, self.heads))

effectively meaningless, since the gathered values are all zero?
Thanks!
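
As a side note, a minimal sketch of why the zeros are not necessarily a problem, assuming relative_bias_table is an nn.Parameter as in the repo: the table is only zero at initialization, and gradients flow through the gather, so its entries are learned during training.

    import torch
    import torch.nn as nn

    ih = iw = 3                              # toy feature-map size
    heads = 2
    n = ih * iw

    # Learnable bias table, zero-initialized (assumed to match the repo's definition)
    relative_bias_table = nn.Parameter(torch.zeros((2 * ih - 1) * (2 * iw - 1), heads))
    # Placeholder index buffer just for this sketch (the repo derives it from relative coordinates)
    relative_index = torch.randint(0, (2 * ih - 1) * (2 * iw - 1), (n * n, 1))

    relative_bias = relative_bias_table.gather(0, relative_index.repeat(1, heads))
    relative_bias.sum().backward()
    print(relative_bias_table.grad.abs().sum() > 0)  # True: the table receives gradients and is trained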

About the # params

Hey, I tried your implementation and found that the calculated number of parameters is a little different from the paper. I'm curious about the reason; could you please help me out?

Take coatnet_0 for example: the calculated result is 17,789,624 (17789624 / 2^20 = 16.97), while the paper reports 25M parameters.

Thanks in advance
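
For reference, a quick way to reproduce such a count (assuming the repo's coatnet.py and its coatnet_0 constructor are importable) is:

    from coatnet import coatnet_0  # assumes the repo's coatnet.py is on the path

    model = coatnet_0()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{n_params} parameters ({n_params / 1e6:.2f}M)")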

About CoAtNet-6 and CoAtNet-7

Hello, first off, I really appreciate your work! How can I build CoAtNet-6 and CoAtNet-7? Is it a matter of adding a sequence of 'C' blocks followed by 'T' blocks at s3?
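
Purely as an illustration of that idea (not the paper's exact CoAtNet-6/7 configuration, and assuming the repo's MBConv and Transformer classes take roughly (inp, oup, image_size, ...) as arguments), a mixed stage could be stacked like this, with image_size being the stage's output resolution and the depths/widths as placeholders:

    import torch.nn as nn
    from coatnet import MBConv, Transformer  # assumes the repo's coatnet.py

    def make_mixed_stage(inp, oup, image_size, n_conv, n_transformer):
        """Hypothetical S3 variant: a few 'C' (MBConv) blocks followed by 'T' (Transformer) blocks."""
        layers = [MBConv(inp, oup, image_size, downsample=True)]
        layers += [MBConv(oup, oup, image_size) for _ in range(n_conv - 1)]
        layers += [Transformer(oup, oup, image_size) for _ in range(n_transformer)]
        return nn.Sequential(*layers)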

Model does not seem to converge

Hi,

I tried to train CoAtNet-0 on Tiny ImageNet from cs231n (200 classes), and the model does not seem to converge.

Could it be that the implementation is not 100% correct? For example, the positional embedding indexing part.
I went through the code, and I think the other components should be correct.

The positional embedding indexing is the one part I can't fully follow. Do you have a reference for the implementation of that part?

About the stochastic depth

Hi, I can't find the code for stochastic depth in your implementation.

I added the stochastic depth code myself and trained a CoAtNet-Tiny on ImageNet-1k, but only got 79.27% top-1.

Have you reproduced the results reported in the paper?
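
For context, stochastic depth is usually implemented as a per-sample residual drop ("DropPath"); a minimal sketch of the standard formulation (not taken from this repo) looks like:

    import torch
    import torch.nn as nn

    class DropPath(nn.Module):
        """Randomly drops the residual branch for a whole sample (stochastic depth)."""
        def __init__(self, drop_prob=0.0):
            super().__init__()
            self.drop_prob = drop_prob

        def forward(self, x):
            if self.drop_prob == 0.0 or not self.training:
                return x
            keep_prob = 1.0 - self.drop_prob
            # One Bernoulli draw per sample, broadcast over the remaining dimensions
            shape = (x.shape[0],) + (1,) * (x.ndim - 1)
            mask = torch.rand(shape, dtype=x.dtype, device=x.device) < keep_prob
            return x * mask / keep_prob  # rescale so the expected output is unchanged

It would typically be applied around each residual branch, e.g. x = x + drop_path(block(x)).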

Inconsistencies in MBConv, with corrected code provided

I've found MBConv to have some computational inconsistencies. The corrected code below works: I've changed the stride of the projection operation (self.proj) to 1 and moved it out of the if downsample statement, initialized the squeeze-and-excite block appropriately (my SqueezeAndExcite block is included here for completeness), and added the channel projection operation on the downsample-is-false branch of MBConv's forward method:

class SqueezeAndExcite(nn.Module):
    def __init__(self, in_channels, expansion=0.25): # keep the reduction fixed
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, int(in_channels * expansion)),
            nn.GELU(),
            nn.Linear(int(in_channels * expansion), in_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avgpool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

class MBConv(nn.Module):
    def __init__(self, inp, oup, expansion, downsample):
        super().__init__()
        self.downsample = downsample
        stride = 1 if not downsample else 2
        hidden_dim = int(expansion * inp)

        if self.downsample:
            # Shortcut branch uses pooling for the spatial downsampling
            self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # 1x1 channel projection used on both branches (stride 1; pooling handles the stride)
        self.proj = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)

        if expansion == 1:
            self.conv = nn.Sequential(
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride, 
                          padding=1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(oup)
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(inp, hidden_dim, kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, 
                          padding=1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                SqueezeAndExcite(hidden_dim, expansion=0.25),
                nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(oup)
            )

        # Pre-normalize the input of the residual branch
        self.conv = PreNorm(norm=nn.BatchNorm2d, model=self.conv, dimension=inp)

    def forward(self, x):
        if self.downsample:
            return self.proj(self.pool(x)) + self.conv(x)
        else:
            return self.proj(x) + self.conv(x)
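
The code above calls a PreNorm wrapper with keyword arguments norm, model, and dimension that isn't shown; a minimal version consistent with that call, plus a quick shape check of the corrected MBConv, might look like this (my own sketch, not part of the original post):

    import torch
    import torch.nn as nn

    class PreNorm(nn.Module):
        """Applies the given normalization before the wrapped module."""
        def __init__(self, norm, model, dimension):
            super().__init__()
            self.norm = norm(dimension)
            self.fn = model

        def forward(self, x):
            return self.fn(self.norm(x))

    # Quick sanity check of the corrected MBConv (shapes only)
    block_down = MBConv(32, 64, expansion=4, downsample=True)
    block_keep = MBConv(64, 64, expansion=4, downsample=False)
    x = torch.randn(2, 32, 56, 56)
    y = block_down(x)   # expected: torch.Size([2, 64, 28, 28])
    z = block_keep(y)   # expected: torch.Size([2, 64, 28, 28])
    print(y.shape, z.shape)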

Dimension mismatch between dots and the relative position bias in the Transformer block

I tried to run CoAtNet without modifying anything:

    def coatnet_0():
        num_blocks = [1, 1, 1, 1, 1]          # L
        channels = [64, 96, 192, 384, 768]    # D
        return CoAtNet((50, 50), 3, num_blocks, channels, num_classes=3)

    img = torch.randn(1, 3, 50, 50)

    net = coatnet_0()
    out = net(img)
    print(out.shape)

The error is that the dots are computed from the downsampled feature map, while the relative position bias is calculated from the original image size. I don't think the attention mechanism itself is miscomputing the dots or the relative position bias; I think the issue is in how attention is used in the Transformer block: the attention there is not using the downsampled size for the relative position bias, but the original image size instead. How can I solve this issue?


RuntimeError                              Traceback (most recent call last)
Cell In[15], line 1
----> 1 out = net(img)
      2 print(out.shape)

Cell In[10], line 25, in CoAtNet.forward(self, x)
     23 x = self.s1(x)
     24 x = self.s2(x)
---> 25 x = self.s3(x)
     26 x = self.s4(x)
     28 x = self.pool(x).view(-1, x.shape[1])

Cell In[9], line 31, in Transformer.forward(self, x)
     29 def forward(self, x):
     30     if self.downsample:
---> 31         x = self.proj(self.pool1(x)) + self.attn(self.pool2(x))
     32     else:
     33         x = x + self.attn(x)

Cell In[4], line 8, in PreNorm.forward(self, x, **kwargs)
      7 def forward(self, x, **kwargs):
----> 8     return self.fn(self.norm(x), **kwargs)

Cell In[8], line 47, in Attention.forward(self, x)
     43 relative_bias = self.relative_bias_table.gather(
     44     0, self.relative_index.repeat(1, self.heads))
     45 relative_bias = rearrange(
     46     relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
---> 47 dots = dots + relative_bias
     49 attn = self.attend(dots)
     50 out = torch.matmul(attn, v)

RuntimeError: The size of tensor a (16) must match the size of tensor b (9) at non-singleton dimension 3

(Repeated intermediate frames in torch/nn/modules/module.py and container.py omitted.)
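
For what it's worth, the mismatch seems to come from the input resolution rather than the attention math: the relative bias grid is precomputed from the image_size passed to the constructor using integer division by powers of two, while the strided convolutions round the spatial size up, so an input like 50 (not a multiple of 32) ends up with feature maps that no longer match the precomputed index grid. A simple workaround (my own suggestion, assuming the repo's stage layout) is to use an input size divisible by 32:

    import torch
    from coatnet import CoAtNet  # assumes the repo's coatnet.py

    num_blocks = [1, 1, 1, 1, 1]
    channels = [64, 96, 192, 384, 768]

    # 64 is a multiple of 32, so image_size // 2**k matches the actual feature-map sizes
    net = CoAtNet((64, 64), 3, num_blocks, channels, num_classes=3)
    out = net(torch.randn(1, 3, 64, 64))
    print(out.shape)  # expected: torch.Size([1, 3])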

An error occurs when an image of size 512 is given as input

Hello. I really appreciate your project.

However, the following error occurs in the Attention class when a 512-sized image is given as input:

    dots = dots + relative_bias
    RuntimeError: The size of tensor a (1024) must match the size of tensor b (196) at non-singleton dimension 3.

Why does this error occur? How should I edit your code if I want to resize the image?

Thank you!
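
A likely explanation, assuming the repo's Attention module precomputes the relative position bias for the image_size given at construction: if the model was built with the default 224x224 input, the stage where attention first appears has a 14x14 grid, i.e. 196 positions, while a 512x512 input yields 32x32 = 1024 positions at that stage, hence the mismatch. One way to handle a different resolution is to construct the model with the resolution you actually plan to use:

    import torch
    from coatnet import CoAtNet  # assumes the repo's coatnet.py

    num_blocks = [2, 2, 3, 5, 2]               # CoAtNet-0 depths as used in the repo
    channels = [64, 96, 192, 384, 768]

    # Build the relative bias tables for the actual input resolution
    net = CoAtNet((512, 512), 3, num_blocks, channels, num_classes=1000)
    out = net(torch.randn(1, 3, 512, 512))
    print(out.shape)  # expected: torch.Size([1, 1000])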

About the w_{i-j}

Thanks for sharing your work.
We want to confirm: are the relative_coords learnable parameters or constants in CoAtNet?

        coords = torch.meshgrid((torch.arange(self.ih), torch.arange(self.iw)))
        coords = torch.flatten(torch.stack(coords), 1)

        relative_coords = coords[:, :, None] - coords[:, None, :]
        relative_coords[0] += self.ih - 1
        relative_coords[1] += self.iw - 1
        relative_coords[0] *= 2 * self.iw - 1
        relative_coords = rearrange(relative_coords, 'c h w -> h w c')
        relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)
        self.register_buffer("relative_index", relative_index)
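
For reference, in the snippet above relative_index is registered with register_buffer, so it is a constant that moves with the model but is not trained; the learnable part is the bias table it indexes into. A minimal way to confirm this, assuming the Attention module is defined as in the repo, would be:

    from coatnet import Attention  # assumes the repo's coatnet.py

    attn = Attention(inp=64, oup=64, image_size=(7, 7), heads=8)
    print([n for n, _ in attn.named_parameters() if "relative" in n])  # expected: ['relative_bias_table']
    print([n for n, _ in attn.named_buffers() if "relative" in n])     # expected: ['relative_index']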

About stochastic depth rate

It seems that the released code does not implement stochastic depth in the CoAtNet module, although it is mentioned in Appendix A.2 of the paper.

Perhaps a nicer way to compute the relative position index

In your Attention class, the relative position matrix calculation can be sped up in the following manner (nn below is ih * iw, the number of spatial positions):

Common to both:

    y, x = torch.meshgrid(torch.arange(ih), torch.arange(iw), indexing='ij')
    y_flat, x_flat = y.flatten(), x.flatten()

Your implementation: 0.01103353500366211 seconds for a 3x3 grid

    rel_y = y_flat.repeat_interleave(nn).view(nn, nn) - y_flat.repeat(nn).view(nn, nn)
    rel_x = x_flat.repeat_interleave(nn).view(nn, nn) - x_flat.repeat(nn).view(nn, nn)
    rel_pos = (rel_y + ih - 1) * (2 * iw - 1) + (rel_x + iw - 1)  # unique index calculation

Suggestion: 0.0007836818695068359 seconds for a 3x3 grid

    rel_y = y_flat.flip(dims=[0]).repeat(nn, 1) - y_flat.flip(dims=[0]).view(-1, 1).repeat(1, nn)
    rel_x = x_flat.flip(dims=[0]).repeat(nn, 1) - x_flat.flip(dims=[0]).view(-1, 1).repeat(1, nn)
    rel_pos = (rel_y + ih - 1) * (2 * iw - 1) + (rel_x + iw - 1)  # unique index calculation
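
A quick check one might run before swapping the computation in, to confirm that both variants produce the same index matrix (my own sketch; nn_ plays the role of nn above):

    import torch

    ih = iw = 3
    nn_ = ih * iw  # number of spatial positions

    y, x = torch.meshgrid(torch.arange(ih), torch.arange(iw), indexing='ij')
    y_flat, x_flat = y.flatten(), x.flatten()

    # Variant A: repeat_interleave / repeat
    rel_y_a = y_flat.repeat_interleave(nn_).view(nn_, nn_) - y_flat.repeat(nn_).view(nn_, nn_)
    rel_x_a = x_flat.repeat_interleave(nn_).view(nn_, nn_) - x_flat.repeat(nn_).view(nn_, nn_)
    rel_pos_a = (rel_y_a + ih - 1) * (2 * iw - 1) + (rel_x_a + iw - 1)

    # Variant B: flip + repeat, as suggested above
    rel_y_b = y_flat.flip(dims=[0]).repeat(nn_, 1) - y_flat.flip(dims=[0]).view(-1, 1).repeat(1, nn_)
    rel_x_b = x_flat.flip(dims=[0]).repeat(nn_, 1) - x_flat.flip(dims=[0]).view(-1, 1).repeat(1, nn_)
    rel_pos_b = (rel_y_b + ih - 1) * (2 * iw - 1) + (rel_x_b + iw - 1)

    print(torch.equal(rel_pos_a, rel_pos_b))  # True means the two variants give identical indices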
