chinhsuanwu / coatnet-pytorch
A PyTorch implementation of "CoAtNet: Marrying Convolution and Attention for All Data Sizes"
Home Page: https://arxiv.org/abs/2106.04803
License: MIT License
Do you have pretrained weights? Thanks.
Hey, I tried your implementation and found that the calculated #param differs from the paper. I am curious about the reason; could you please help me out?
Take coatnet_0 for example: the calculated result is 17789624 (17789624 / 2^20 = 16.97), while the paper reports 25M.
Thanks in advance
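For anyone who wants to reproduce the count above, a minimal sketch of the usual way to count parameters in PyTorch follows. The helper name `count_parameters` and the toy `nn.Linear` model are illustrative assumptions, not part of the repository. Note also that dividing by 2^20 gives mebi-units; if the paper's "M" means 10^6, the same count reads as roughly 17.8M, which is still well below the reported 25M.

```python
import torch.nn as nn


def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """Sum the element counts of a model's parameters (hypothetical helper)."""
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )


# Toy model used only to show the helper; 10 * 5 weights + 5 biases.
toy = nn.Linear(10, 5)
n_params = count_parameters(toy)

# The count reported in the issue, expressed in both unit conventions.
reported = 17789624
in_mebi = reported / 2**20   # the issue's convention
in_millions = reported / 1e6  # decimal millions
print(f"{in_mebi:.2f} (x 2^20), {in_millions:.2f} (x 10^6)")
```

Replacing `toy` with the network built by `coatnet_0()` should reproduce the figure quoted in the issue, provided the same configuration is used.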
very good work!
Hello, first off, really appreciate your work! Now, how can I use coatnet-6 and coatnet-7? Is it adding a sequential of a 'C' block followed by a 'T' block at s3?
Hi,
I tried to train CoAtNet_0 on Tiny ImageNet from cs231n (200 classes). The model does not seem to converge.
Could it be that the implementation is not 100% correct? For example, the positional embedding indexing part.
I went through the code and I think the other components should be correct.
The exception is the positional embedding indexing, which I do not fully understand. Do you have a reference for the implementation of that part?
Are there any pre-trained models for CoAtNet? (e.g: ImageNet [1k, 21k], COCO, ...)
Hi, I can't find the code for stochastic depth in your implementation.
I added stochastic depth myself and trained a CoAtNet-Tiny on ImageNet-1k, but only got 79.27% top-1.
Have you reproduced the results reported in the paper?
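For readers unfamiliar with the technique mentioned above, here is a minimal sketch of stochastic depth as it is commonly written in PyTorch. The class name `DropPath` and the `drop_prob` argument are assumptions for illustration; this is not the commenter's code and not something the repository is known to contain.

```python
import torch
import torch.nn as nn


class DropPath(nn.Module):
    """Per-sample stochastic depth (illustrative sketch, not the repo's code)."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Identity at inference time or when dropping is disabled.
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all other dimensions.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep_prob)
        # Rescale so the expected value of the output matches the input.
        return x * mask / keep_prob


drop_path = DropPath(drop_prob=0.5)
drop_path.eval()
x = torch.ones(8, 4, 2, 2)
eval_out = drop_path(x)   # unchanged in eval mode
drop_path.train()
train_out = drop_path(x)  # each sample is either zeroed or scaled by 1 / keep_prob
```

In a residual block this would typically wrap the residual branch only, for example `x + drop_path(branch(x))`, so that a dropped sample falls back to the identity path.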
I've found the MBConv to have some computational inconsistencies. The following corrected code works, where I've changed the stride of the projection operation (self.proj) and moved it out of the if downsample statement. Further, the squeeze and excite block has been appropriately initialized (I've added my squeeze and excite block too here for completeness). I've also added the channel projection operation on the downsample is false branch of the MBConv forward method:
import torch.nn as nn


class SqueezeAndExcite(nn.Module):
    def __init__(self, in_channels, expansion=0.25):  # keep the reduction fixed
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, int(in_channels * expansion)),
            nn.GELU(),
            nn.Linear(int(in_channels * expansion), in_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze: global average pool to one value per channel.
        y = self.avgpool(x).view(b, c)
        # Excite: per-channel gate in (0, 1).
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class MBConv(nn.Module):
    def __init__(self, inp, oup, expansion, downsample):
        super().__init__()
        self.downsample = downsample
        stride = 1 if not downsample else 2
        hidden_dim = int(expansion * inp)

        if self.downsample:
            self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # The projection is created for both branches and always uses stride 1.
        self.proj = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)

        if expansion == 1:
            self.conv = nn.Sequential(
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride,
                          padding=1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(oup)
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(inp, hidden_dim, kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1,
                          padding=1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                SqueezeAndExcite(hidden_dim, expansion=0.25),
                nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(oup)
            )

        # PreNorm is not shown in this comment; it must accept these keyword arguments.
        self.conv = PreNorm(norm=nn.BatchNorm2d, model=self.conv, dimension=inp)

    def forward(self, x):
        if self.downsample:
            return self.proj(self.pool(x)) + self.conv(x)
        else:
            return self.proj(x) + self.conv(x)
I tried to run CoAtNet without modifying anything:
def coatnet_0():
    num_blocks = [1, 1, 1, 1, 1]  # L
    channels = [64, 96, 192, 384, 768]  # D
    return CoAtNet((50, 50), 3, num_blocks, channels, num_classes=3)

img = torch.randn(1, 3, 50, 50)
net = coatnet_0()
out = net(img)
print(out.shape)
The error is that dots is computed from the downsampled image, while the relative position bias is calculated from the original image size. I do not think the attention mechanism itself is miscalculating the dots and the relative position bias.
I think the issue comes from how attention is implemented in the Transformer block: somehow it does not use the downsampled image size for the relative position bias and instead derives it from the original image.
How can I solve this issue?
RuntimeError Traceback (most recent call last)
Cell In[15], line 1
----> 1 out = net(img)
2 print(out.shape)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
Cell In[10], line 25, in CoAtNet.forward(self, x)
23 x = self.s1(x)
24 x = self.s2(x)
---> 25 x = self.s3(x)
26 x = self.s4(x)
28 x = self.pool(x).view(-1, x.shape[1])
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py:215, in Sequential.forward(self, input)
213 def forward(self, input):
214 for module in self:
--> 215 input = module(input)
216 return input
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
Cell In[9], line 31, in Transformer.forward(self, x)
29 def forward(self, x):
30 if self.downsample:
---> 31 x = self.proj(self.pool1(x)) + self.attn(self.pool2(x))
32 else:
33 x = x + self.attn(x)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py:215, in Sequential.forward(self, input)
213 def forward(self, input):
214 for module in self:
--> 215 input = module(input)
216 return input
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
Cell In[4], line 8, in PreNorm.forward(self, x, **kwargs)
7 def forward(self, x, **kwargs):
----> 8 return self.fn(self.norm(x), **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
Cell In[8], line 47, in Attention.forward(self, x)
43 relative_bias = self.relative_bias_table.gather(
44 0, self.relative_index.repeat(1, self.heads))
45 relative_bias = rearrange(
46 relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
---> 47 dots = dots + relative_bias
49 attn = self.attend(dots)
50 out = torch.matmul(attn, v)
RuntimeError: The size of tensor a (16) must match the size of tensor b (9) at non-singleton dimension 3
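The 16 versus 9 in this error is consistent with a rounding mismatch rather than a wrong attention formula. The sketch below assumes each of the five stages halves the feature map with a kernel-3, stride-2, padding-1 layer (as the MaxPool2d in the MBConv comment above does), and that the model derives each stage's expected size by integer division of the input size. Both are assumptions about the implementation; the function names are hypothetical.

```python
def halve(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Output length of a conv/pool layer with the given hyperparameters."""
    return (size + 2 * padding - kernel) // stride + 1


def stage_sizes(image_size: int, num_stages: int = 5):
    """Return (actual, assumed) per-stage side lengths for a square input."""
    actual, assumed = [], []
    current = image_size
    for stage in range(1, num_stages + 1):
        current = halve(current)                   # what the layers really produce
        actual.append(current)
        assumed.append(image_size // 2 ** stage)   # what integer division predicts
    return actual, assumed


actual_50, assumed_50 = stage_sizes(50)
actual_224, assumed_224 = stage_sizes(224)

# Index 3 is the fourth downsampling step, where the traceback fails (self.s3).
tokens_actual = actual_50[3] ** 2
tokens_assumed = assumed_50[3] ** 2
print(actual_50, assumed_50, tokens_actual, tokens_assumed)
```

Under these assumptions a 50x50 input reaches the failing stage as a 4x4 map (16 tokens) while the bias table was built for 3x3 (9 tokens). An input size divisible by 32, such as 64 or 224, keeps the two columns identical and avoids the mismatch.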
Hello. I really appreciate your project.
However, the following error occurs in the Attention class when a 512-size image is input:
dots = dots + relative_bias
RuntimeError: The size of tensor a (1024) must match the size of tensor b (196) at non-singleton dimension 3.
Line 155 in d3ef1c3
Why does this error occur? How should I edit your code when I want to resize the image?
Thank you!
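The two numbers in this error can be reproduced with simple arithmetic, assuming the failing attention stage runs at 1/16 of the input resolution and the model was constructed for 224x224 images. Both are assumptions inferred from the error message, and `tokens_at_stride` is a hypothetical helper.

```python
def tokens_at_stride(image_size: int, stride: int = 16) -> int:
    """Number of attention tokens for a square image at the given total stride."""
    side = image_size // stride
    return side * side


built_for = tokens_at_stride(224)  # size of the relative bias the model allocated
fed_with = tokens_at_stride(512)   # size of the attention map actually produced
print(built_for, fed_with)
```

If that reading is right, the relative position bias is sized when the model is built, so the image size passed to the constructor (the first argument in a call such as `CoAtNet((50, 50), ...)` from the issue above) has to match the images fed at run time.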
Hello,
I'm training the model from scratch on my custom dataset and the convergence is very slow. So, can you share any pretrained networks if possible?
Thank you for your wonderful work.
However, I printed the parameter count of the network from your code and compared it with the figures on page 8 of the paper;
the two differ considerably.
Hello, first off, really appreciate your work! Unfortunately, I'm getting overfitting on a custom dataset even with coatnet_0. Is there a workaround?
Please! Thank you very much!
Thanks for sharing.
We want to confirm whether relative_coords is a learnable parameter or a constant in CoAtNet:
coords = torch.meshgrid((torch.arange(self.ih), torch.arange(self.iw)))
coords = torch.flatten(torch.stack(coords), 1)
relative_coords = coords[:, :, None] - coords[:, None, :]
relative_coords[0] += self.ih - 1
relative_coords[1] += self.iw - 1
relative_coords[0] *= 2 * self.iw - 1
relative_coords = rearrange(relative_coords, 'c h w -> h w c')
relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)
self.register_buffer("relative_index", relative_index)
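One way to check this directly: tensors passed to `register_buffer` are saved with the module but are not returned by `parameters()`, so an optimizer never updates them. The sketch below rebuilds the index from the snippet above inside a small module. The class name `RelativeBias` is hypothetical, and treating `relative_bias_table` (the name visible in the traceback of another issue) as an `nn.Parameter` is an assumption about the repository.

```python
import torch
import torch.nn as nn


class RelativeBias(nn.Module):
    """Hypothetical container mirroring the snippet above."""

    def __init__(self, ih: int, iw: int, heads: int):
        super().__init__()
        # Assumption: the bias table is the learnable part.
        self.relative_bias_table = nn.Parameter(
            torch.zeros((2 * ih - 1) * (2 * iw - 1), heads))

        coords = torch.meshgrid(torch.arange(ih), torch.arange(iw), indexing="ij")
        coords = torch.flatten(torch.stack(coords), 1)
        relative_coords = coords[:, :, None] - coords[:, None, :]
        relative_coords[0] += ih - 1
        relative_coords[1] += iw - 1
        relative_coords[0] *= 2 * iw - 1
        # Summing over the coordinate axis matches the rearrange + sum(-1) above.
        relative_index = relative_coords.sum(0).flatten().unsqueeze(1)
        # A buffer: stored in state_dict, but never seen by the optimizer.
        self.register_buffer("relative_index", relative_index)


module = RelativeBias(ih=3, iw=3, heads=2)
param_names = [name for name, _ in module.named_parameters()]
buffer_names = [name for name, _ in module.named_buffers()]
print(param_names, buffer_names)
```

So the index built from `relative_coords` is a constant lookup table; only whatever it indexes into can be learned.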
It seems that the released code does not implement stochastic depth in the CoAtNet module, although it is mentioned in Appendix A.2 of the paper.
In your code for the Attention class, the relative position matrix calculation can be sped up in the following manner (nn below appears to denote ih * iw, the number of positions, given the view(nn, nn) calls).
Common to both:
y, x = torch.meshgrid(torch.arange(ih), torch.arange(iw), indexing='ij')
y_flat, x_flat = y.flatten(), x.flatten()
Your implementation: 0.01103353500366211 seconds for a 3x3 matrix
rel_y = y_flat.repeat_interleave(nn).view(nn, nn) - y_flat.repeat(nn).view(nn, nn)
rel_x = x_flat.repeat_interleave(nn).view(nn, nn) - x_flat.repeat(nn).view(nn, nn)
rel_pos = (rel_y + ih - 1) * (2 * iw - 1) + (rel_x + iw - 1) # Unique index calculation
Suggestion: 0.0007836818695068359 seconds for a 3x3 matrix
rel_y = y_flat.flip(dims=[0]).repeat(nn, 1) - y_flat.flip(dims=[0]).view(-1, 1).repeat(1, nn)
rel_x = x_flat.flip(dims=[0]).repeat(nn, 1) - x_flat.flip(dims=[0]).view(-1, 1).repeat(1, nn)
rel_pos = (rel_y + ih - 1) * (2 * iw - 1) + (rel_x + iw - 1) # Unique index calculation
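As a cross-check on snippets like these, the same index table can also be written with plain broadcasting. This is a third formulation added here for illustration, not the one either snippet uses; the function names are hypothetical and no timing claim is made for it. The sketch compares it against the repeat_interleave version to confirm that both produce identical indices.

```python
import torch


def rel_index_interleave(ih: int, iw: int) -> torch.Tensor:
    """Index table built with repeat_interleave / repeat, as in the first snippet."""
    y, x = torch.meshgrid(torch.arange(ih), torch.arange(iw), indexing="ij")
    y_flat, x_flat = y.flatten(), x.flatten()
    n = ih * iw  # number of positions
    rel_y = y_flat.repeat_interleave(n).view(n, n) - y_flat.repeat(n).view(n, n)
    rel_x = x_flat.repeat_interleave(n).view(n, n) - x_flat.repeat(n).view(n, n)
    return (rel_y + ih - 1) * (2 * iw - 1) + (rel_x + iw - 1)


def rel_index_broadcast(ih: int, iw: int) -> torch.Tensor:
    """Same table via broadcasting: entry [i, j] uses position i minus position j."""
    y, x = torch.meshgrid(torch.arange(ih), torch.arange(iw), indexing="ij")
    y_flat, x_flat = y.flatten(), x.flatten()
    rel_y = y_flat[:, None] - y_flat[None, :]
    rel_x = x_flat[:, None] - x_flat[None, :]
    return (rel_y + ih - 1) * (2 * iw - 1) + (rel_x + iw - 1)


same_3x3 = torch.equal(rel_index_interleave(3, 3), rel_index_broadcast(3, 3))
same_4x5 = torch.equal(rel_index_interleave(4, 5), rel_index_broadcast(4, 5))
print(same_3x3, same_4x5)
```

Since the table is built once at construction time, correctness matters more than speed here; an equality check of this kind is a cheap safeguard before swapping implementations.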