raoyongming / gfnet
[NeurIPS 2021] [T-PAMI] Global Filter Networks for Image Classification
Home Page: https://gfnet.ivg-research.xyz/
License: MIT License
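```python
import torch
import torch.nn as nn

class GlobalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # Learnable complex filter stored as (real, imag) pairs: shape (h, w, dim, 2)
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)

    def forward(self, x):
        B, H, W, C = x.shape
        # 2D real FFT over the spatial axes; the W axis shrinks to W // 2 + 1
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        weight = torch.view_as_complex(self.complex_weight)
        # Element-wise product in frequency = global circular convolution in space
        x = x * weight
        # Inverse real FFT back to the original (H, W) spatial size
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
        return x
```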
Thank you very much for your work. I have some questions. What is the meaning of "h=14, w=8" and "s=(H, W), dim=(1, 2)"?
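The two questions above concern tensor shapes, so a quick shape check may help. A minimal sketch, assuming the default GFNet setting of a 224x224 image split into 16x16 patches (a 14x14 token grid, hence h=14) and an example channel count of 384: `rfft2` keeps only 14 // 2 + 1 = 8 frequencies along the last transformed axis (hence w=8), `dim=(1, 2)` selects the H and W axes of a (B, H, W, C) tensor, and `s=(H, W)` tells `irfft2` the original spatial size.

```python
import torch

B, H, W, C = 2, 14, 14, 384          # assumed: 224x224 image, 16x16 patches -> 14x14 tokens
x = torch.randn(B, H, W, C)

X = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')   # FFT over axes 1 (H) and 2 (W) only
print(X.shape)                                      # torch.Size([2, 14, 8, 384])

# s=(H, W) restores the original size; the round trip recovers x
y = torch.fft.irfft2(X, s=(H, W), dim=(1, 2), norm='ortho')
print(y.shape)                                      # torch.Size([2, 14, 14, 384])

# With an odd width, omitting s gives 2 * (7 - 1) = 12 columns instead of 13
z = torch.fft.rfft2(torch.randn(1, 13, 13, 1), dim=(1, 2))
print(torch.fft.irfft2(z, dim=(1, 2)).shape)        # torch.Size([1, 13, 12, 1])
```

The odd-width case shows why `s` is passed explicitly: the half spectrum alone cannot tell whether the original width was 12 or 13.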
GFNet is excellent work. What should I do or consider if I want to use it with inputs of variable resolution?
Thank you very much for your work. However, when I implemented the visualization of the frequency-domain filters, my images looked different from the ones you provided. I would also like to ask how to visualize the filters in the spatial domain.
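As a hedged starting point for both views (not the authors' plotting code): a filter's frequency response can be shown as the magnitude of `complex_weight` for one channel, and its spatial-domain equivalent obtained with an inverse real FFT, since multiplying in frequency equals circular convolution in space. `filter_views` is a hypothetical helper; the centring with `fftshift` and the `'ortho'` scaling are assumptions, and differences in such choices (or in colormap and value range) could explain why plots differ from the paper's figures.

```python
import torch


def filter_views(complex_weight: torch.Tensor, H: int, W: int, channel: int = 0):
    """Return (frequency-domain magnitude, spatial-domain kernel) for one channel.

    complex_weight: real tensor (H, W // 2 + 1, C, 2) holding real/imag parts,
    laid out as in GlobalFilter.
    """
    wc = torch.view_as_complex(complex_weight[..., channel, :].contiguous())  # (H, W//2+1)
    # Only the H axis holds a full set of frequencies, so centre it for display
    freq_mag = torch.fft.fftshift(wc.abs(), dim=0)
    # Inverse real FFT gives the equivalent circular convolution kernel in space
    spatial = torch.fft.irfft2(wc, s=(H, W), norm='ortho')
    # Move the zero offset to the centre of the kernel image
    spatial = torch.fft.fftshift(spatial)
    return freq_mag, spatial
```

Each returned tensor can be displayed with e.g. `matplotlib.pyplot.imshow(t.detach().numpy())`.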
Hello, thanks for your great work!
In your figure and code, there is no skip connection after the global filter layer.
This differs from the original Transformer implementation,
which has two skip connections in a single block (one each for the self-attention layer and the FFN layer).
For example, the original Transformer uses blocks like
x = x + SA(x)
x = x + FFN(x)
whereas the Global Filter Network uses the block below:
x_ = Global_Filter(x)
x = x + FFN(x_)
Is there any reason for adopting the current block architecture?
Thanks,
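To make the comparison in the question concrete, here is a minimal sketch of both block layouts with a generic token mixer. `GFNetStyleBlock` mirrors the single-residual forward pass `x + mlp(norm2(filter(norm1(x))))` used in GFNet's Block; `TwoResidualBlock` is a hypothetical pre-norm Transformer-style variant. DropPath and dropout are omitted, and nothing here says which layout performs better.

```python
import torch
import torch.nn as nn


class GFNetStyleBlock(nn.Module):
    """One residual around the whole mixer + MLP path."""

    def __init__(self, dim, mixer):
        super().__init__()
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.mixer = mixer  # e.g. a global filter; any shape-preserving module
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.mlp(self.norm2(self.mixer(self.norm1(x))))


class TwoResidualBlock(GFNetStyleBlock):
    """Hypothetical variant: separate skips around the mixer and the MLP."""

    def forward(self, x):
        x = x + self.mixer(self.norm1(x))
        return x + self.mlp(self.norm2(x))
```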
Hi, I came across your work and thought it was a very interesting concept. Currently, the network takes fixed input sizes. Is there a way to support flexible input sizes? I realize the main constraint is the following line, where the complex weight is defined at initialization time:
self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
Is there a way to modify this line so that we can have inputs of different sizes?
Thanks
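One possible workaround, sketched under assumptions rather than taken from the repository: keep `complex_weight` at a base size and bilinearly resize it with `torch.nn.functional.interpolate` to the frequency grid of each input, i.e. (H, W // 2 + 1). `ResizableGlobalFilter` is a hypothetical name, and interpolating in the frequency domain only approximates rescaling the spatial filter, so accuracy at resolutions far from the training size would need to be checked.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResizableGlobalFilter(nn.Module):
    """Hypothetical variant: resize the stored filter to each input's frequency grid."""

    def __init__(self, dim, h=14, w=8):
        super().__init__()
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2) * 0.02)

    def forward(self, x):
        B, H, W, C = x.shape
        wf = W // 2 + 1  # rfft2 output width along the last spatial axis
        weight = self.complex_weight
        if weight.shape[:2] != (H, wf):
            h0, w0 = weight.shape[:2]
            # (h, w, C, 2) -> (1, C*2, h, w) so interpolate treats it as an image
            weight = weight.permute(2, 3, 0, 1).reshape(1, C * 2, h0, w0)
            weight = F.interpolate(weight, size=(H, wf), mode='bilinear', align_corners=True)
            # Back to (H, wf, C, 2), contiguous so view_as_complex accepts it
            weight = weight.reshape(C, 2, H, wf).permute(2, 3, 0, 1).contiguous()
        X = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        X = X * torch.view_as_complex(weight)
        return torch.fft.irfft2(X, s=(H, W), dim=(1, 2), norm='ortho')
```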
Why can I only replicate the whole model on different GPUs, rather than use model parallelism that distributes parts of the model across GPUs?
I followed the instructions on the webpage.
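Replicating the whole model on every GPU is what data parallelism (e.g. `torch.nn.parallel.DistributedDataParallel`) does; splitting the model instead requires placing submodules on devices manually. A naive sketch, assuming the model can be expressed as a list of blocks. `TwoStageModel` is a hypothetical name, and there is no pipelining, so only one device is busy at a time.

```python
import torch
import torch.nn as nn


class TwoStageModel(nn.Module):
    """Naive model parallelism: first half of the blocks on dev0, the rest on dev1."""

    def __init__(self, blocks, dev0, dev1):
        super().__init__()
        mid = len(blocks) // 2
        self.dev0, self.dev1 = torch.device(dev0), torch.device(dev1)
        self.stage0 = nn.Sequential(*blocks[:mid]).to(self.dev0)
        self.stage1 = nn.Sequential(*blocks[mid:]).to(self.dev1)

    def forward(self, x):
        x = self.stage0(x.to(self.dev0))
        # Activations are copied between devices at the stage boundary
        return self.stage1(x.to(self.dev1))
```

On a two-GPU machine this would be built with `dev0="cuda:0"` and `dev1="cuda:1"`.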
```python
import math

import torch
import torch.nn as nn
from timm.models.layers import DropPath, Mlp  # or the repo's own Mlp class


class GlobalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
        self.w = w
        self.h = h

    def forward(self, x, spatial_size=None):
        B, N, C = x.shape
        if spatial_size is None:
            a = b = int(math.sqrt(N))
        else:
            a, b = spatial_size
        x = x.view(B, a, b, C)
        x = x.to(torch.float32)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        weight = torch.view_as_complex(self.complex_weight)
        x = x * weight
        # Line highlighted in the question: inverse real FFT back to (a, b)
        x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm='ortho')
        x = x.reshape(B, N, C)
        return x


class Block(nn.Module):
    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, h=14, w=8):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.filter = GlobalFilter(dim, h=h, w=w)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        # Line highlighted in the question: where the error is raised
        x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))
        return x
```
I have a question regarding GFNet.
When applying the inverse FFT and layer norm in class Block, I get an error saying the feature map contains imaginary numbers.
How did you handle this?
The error occurs at x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))
Thank you in advance,
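With `irfft2`, the filter's output is already a real tensor, so complex values reaching `LayerNorm` usually mean something else in the chain differs, for example `ifft2` being used instead of `irfft2`. The sketch below is an assumed diagnostic, not a confirmed cause of this error; it compares the dtypes and shapes of the two inverse transforms.

```python
import torch

x = torch.randn(2, 14, 14, 16)                     # (B, H, W, C), channels-last as in GlobalFilter
X = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')   # complex half spectrum, width 14 // 2 + 1 = 8

# irfft2 returns a real tensor with the original spatial size (s is given)
real_out = torch.fft.irfft2(X, s=(14, 14), dim=(1, 2), norm='ortho')

# ifft2 on the half spectrum stays complex and keeps the truncated width 8
complex_out = torch.fft.ifft2(X, dim=(1, 2), norm='ortho')

print(real_out.dtype, tuple(real_out.shape))        # torch.float32 (2, 14, 14, 16)
print(complex_out.dtype, tuple(complex_out.shape))  # torch.complex64 (2, 14, 8, 16)
```

It may also be worth checking that the filter's (h, w) equals (a, b // 2 + 1) for the token grid actually used, since otherwise `x * weight` fails to broadcast.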
Hi, Rao
Thank you for your great work!
When measuring GFNet's adversarial robustness with FGSM and PGD, could you share the specific settings and hyperparameters?
It would be even better if I could get the code you used!
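The authors' exact settings are not given on this page. For reference, a minimal sketch of the two standard attacks, assuming an L-inf threat model, inputs in [0, 1], and that normalization happens inside `model`. `fgsm_attack`, `pgd_attack`, and any epsilon/step values are illustrative assumptions, not the paper's hyperparameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fgsm_attack(model, images, labels, epsilon):
    """Single-step FGSM: x_adv = clip(x + eps * sign(grad_x loss), 0, 1)."""
    images = images.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(images), labels)
    grad, = torch.autograd.grad(loss, images)
    return (images + epsilon * grad.sign()).clamp(0.0, 1.0).detach()


def pgd_attack(model, images, labels, epsilon, alpha, steps):
    """L-inf PGD with random start, projected back into the epsilon-ball each step."""
    adv = (images + torch.empty_like(images).uniform_(-epsilon, epsilon)).clamp(0, 1).detach()
    for _ in range(steps):
        adv.requires_grad_(True)
        loss = F.cross_entropy(model(adv), labels)
        grad, = torch.autograd.grad(loss, adv)
        adv = adv.detach() + alpha * grad.sign()
        # Project onto [x - eps, x + eps], then onto the valid pixel range
        adv = torch.min(torch.max(adv, images - epsilon), images + epsilon).clamp(0, 1)
    return adv.detach()
```

A typical call would be `pgd_attack(model.eval(), x, y, epsilon=4/255, alpha=1/255, steps=10)`, with placeholder values.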
Can it be trained on multiple GPUs? I didn't see any related code.
Hello sir:
Thank you for your excellent work!
Could you please provide visualization code?
Thank you!
Hi, interesting work!
I wonder how the complexity (FLOPs) of the global filter in Table 1 is calculated.
Because real signals have conjugate-symmetric spectra, we have:
Case 1: taking conjugate symmetry into account.
RFFT: HWD/2 * log2(HW)
Global Filter: HWD/2
IRFFT: HWD/2 * log2(HW)
Thus, the total complexity (FLOPs) of the global filter is HWD * log2(HW) + HWD/2. Is that right?
Case 2: not taking conjugate symmetry into account.
RFFT: HWD * log2(HW)
Global Filter: HWD
IRFFT: HWD * log2(HW)
Thus, the total complexity (FLOPs) of the global filter is 2HWD * log2(HW) + HWD. Is that right?
Which one is correct?
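The two conventions can be compared numerically. A small sketch that simply evaluates the expressions written in the question (the constant factors are the question's, not a verified operation count of any FFT library); `global_filter_flops` is a hypothetical helper.

```python
import math


def global_filter_flops(H, W, D, use_symmetry):
    """Evaluate the two counting conventions from the question.

    use_symmetry=True:  RFFT and IRFFT cost HWD/2 * log2(HW) each, the filter HWD/2.
    use_symmetry=False: FFT and IFFT cost HWD * log2(HW) each, the filter HWD.
    """
    fft = H * W * D * math.log2(H * W)  # one full 2D FFT over D channels
    mult = H * W * D                    # element-wise multiply on the full spectrum
    if use_symmetry:
        return fft / 2 + mult / 2 + fft / 2
    return fft + mult + fft


case1 = global_filter_flops(14, 14, 384, use_symmetry=True)
case2 = global_filter_flops(14, 14, 384, use_symmetry=False)
```

Case 2 is always exactly twice Case 1, so the two conventions differ only by a constant factor and share the same O(HWD log2(HW)) form.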
Hello,
Thank you for your nice work! I am curious about the settings of the head of GFNet for segmentation. Will you share more details? Thank you!
How are FLOPs calculated?
In my experiments, I noticed that they have similar FLOPs, but GFNet-H-B trains 2x slower than GFNet-B. Do you have similar observations, and any ideas on why? Thanks in advance!
@raoyongming Thanks for providing the ImageNet weights. I was trying to load the weights you shared to do some benchmarking analysis; however, when I load them for gfnet-ti or gfnet-xs, I get the following error:
model.load_state_dict(torch.load(weightpath))
RuntimeError: Error(s) in loading state_dict for GFNet:
Missing key(s) in state_dict: "pos_embed", "patch_embed.proj.weight", "patch_embed.proj.bias", "blocks.0.norm1.weight", "blocks.0.norm1.bias", "blocks.0.filter.complex_weight", "blocks.0.norm2.weight", "blocks.0.norm2.bias", "blocks.0.mlp.fc1.weight", "blocks.0.mlp.fc1.bias", "blocks.0.mlp.fc2.weight", "blocks.0.mlp.fc2.bias", "blocks.1.norm1.weight", "blocks.1.norm1.bias", "blocks.1.filter.complex_weight", "blocks.1.norm2.weight", "blocks.1.norm2.bias", "blocks.1.mlp.fc1.weight", "blocks.1.mlp.fc1.bias", "blocks.1.mlp.fc2.weight", "blocks.1.mlp.fc2.bias", "blocks.2.norm1.weight", "blocks.2.norm1.bias", "blocks.2.filter.complex_weight", "blocks.2.norm2.weight", "blocks.2.norm2.bias", "blocks.2.mlp.fc1.weight", "blocks.2.mlp.fc1.bias", "blocks.2.mlp.fc2.weight", "blocks.2.mlp.fc2.bias", "blocks.3.norm1.weight", "blocks.3.norm1.bias", "blocks.3.filter.complex_weight", "blocks.3.norm2.weight", "blocks.3.norm2.bias", "blocks.3.mlp.fc1.weight", "blocks.3.mlp.fc1.bias", "blocks.3.mlp.fc2.weight", "blocks.3.mlp.fc2.bias", "norm.weight", "norm.bias", "head.weight", "head.bias".
Unexpected key(s) in state_dict: "model".
Are these weights for the same architecture as in the code? How can I resolve this?
Thanks
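The `Unexpected key(s) in state_dict: "model"` line suggests the released file is a checkpoint dictionary with the weights stored under a `"model"` key, rather than a bare state_dict. A minimal sketch under that assumption; `load_gfnet_weights` is a hypothetical helper, and the demo uses a stand-in `nn.Linear` instead of GFNet.

```python
import io

import torch
import torch.nn as nn


def load_gfnet_weights(model: nn.Module, f) -> nn.Module:
    """Load weights saved either as a bare state_dict or as {'model': state_dict}."""
    checkpoint = torch.load(f, map_location='cpu')
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        checkpoint = checkpoint['model']  # unwrap the nested state_dict
    model.load_state_dict(checkpoint)
    return model


# Round-trip demo with a stand-in module: save {'model': state_dict}, load it back
src = nn.Linear(4, 2)
buf = io.BytesIO()
torch.save({'model': src.state_dict()}, buf)
buf.seek(0)
dst = load_gfnet_weights(nn.Linear(4, 2), buf)
```

For the real weights, `f` would be the downloaded checkpoint path and `model` the matching GFNet variant.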
Hello, thanks for your nice work!
I wonder what the option fp32fft does. In my experiments, the input and output of the FFT functions are already torch.float32, so I'm not sure why there is an option to convert to fp32. Thanks in advance.
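One plausible reason (an assumption, not confirmed by the authors): under mixed-precision training with autocast, activations can reach the filter as float16, and PyTorch documents half-precision FFTs as CUDA-only and limited to power-of-two sizes, which a 14x14 grid is not. Casting to float32 around the FFT avoids that. A hypothetical helper `fft_filter_fp32` sketching the idea:

```python
import torch


def fft_filter_fp32(x, complex_weight):
    """Run the global filter in float32 and cast back to the input dtype.

    x: (B, H, W, C), possibly float16 under autocast.
    complex_weight: real tensor (H, W // 2 + 1, C, 2) as in GlobalFilter.
    """
    dtype = x.dtype
    B, H, W, C = x.shape
    X = torch.fft.rfft2(x.float(), dim=(1, 2), norm='ortho')
    X = X * torch.view_as_complex(complex_weight.float())
    y = torch.fft.irfft2(X, s=(H, W), dim=(1, 2), norm='ortho')
    return y.to(dtype)  # hand back the dtype the rest of the network expects
```

When the input is already float32, the casts are no-ops, which would match the observation above.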
Hi, Yongming
What image size did you use for training and validation on ADE20K?
I noticed that PVT used 512x512 for training and a different scale for testing. However, as the parameters of the Global Filter are tied to the image size, how do you deal with the scale change?
Thanks in advance.
Hi, thank you for your excellent work. I have some questions about a 3D configuration of GFNet.
I extended your model to a 3D version by introducing a 3D FFT and IFFT for global filter learning, and tested it on 3D data classification (point cloud / volumetric data). However, overfitting occurs (97% accuracy on the training set, 75% on the test set). Could you give some advice on how to train such a model? (I have tried dropout with different ratios.) Thank you~
Thank you for your excellent work.
I'm curious why w is set to w = h // 2 + 1. Is there a reason for it, or did experiments simply show that setting w = h // 2 + 1 works better?
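The choice follows from the real FFT: for a real input of width W, the spectrum is conjugate-symmetric, so `rfft2` stores only the first W // 2 + 1 frequencies along the last axis and the rest are redundant. A short check of this property on a 14x14 input (plain `torch.fft` calls, no GFNet code):

```python
import torch

x = torch.randn(14, 14)
full = torch.fft.fft2(x)   # (14, 14): every frequency
half = torch.fft.rfft2(x)  # (14, 8): only columns 0 .. 14 // 2

# The rfft2 output equals the first W // 2 + 1 columns of the full spectrum
assert torch.allclose(full[:, :8], half, atol=1e-4)

# The remaining columns are redundant: X[k1, k2] == conj(X[-k1 mod H, -k2 mod W])
k1, k2 = 3, 10
assert torch.allclose(full[k1, k2], full[(-k1) % 14, (-k2) % 14].conj(), atol=1e-4)
```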
Hello, I am new to PyTorch and would like to know how exactly the trained model can be used to make predictions on images.
Thank you
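Hello, thank you for your work. I have two questions.
1. I tried your model on a sequence task. My sequence length is 256, and a tensor of shape (B, 256, 512) is fed into your model. I ignored the meaning of the input and simply treated the length-256 sequence as a 2D signal (i.e., H = W = 16), forcing the model to learn it, and strangely, the results were good. How can I apply a 1D FFT directly to the sequence, and how should the code be written?
2. For the 2D FFT, must H and W of the input signal be equal, as in the images in the paper? If they can differ, what changes does the code need?
Sorry to bother you, and I look forward to your reply! Thanks again!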
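For question 1, a 1D variant can apply `torch.fft.rfft` along the sequence axis, with a weight of length N // 2 + 1. A minimal sketch (`GlobalFilter1D` is a hypothetical name, not part of the repository):

```python
import torch
import torch.nn as nn


class GlobalFilter1D(nn.Module):
    """Hypothetical 1D variant: learnable filter over the sequence axis of (B, N, C)."""

    def __init__(self, dim, seq_len):
        super().__init__()
        # rfft over N positions yields N // 2 + 1 frequencies
        self.complex_weight = nn.Parameter(torch.randn(seq_len // 2 + 1, dim, 2) * 0.02)

    def forward(self, x):
        B, N, C = x.shape
        X = torch.fft.rfft(x, dim=1, norm='ortho')             # (B, N // 2 + 1, C), complex
        X = X * torch.view_as_complex(self.complex_weight)     # broadcasts over the batch
        return torch.fft.irfft(X, n=N, dim=1, norm='ortho')    # back to (B, N, C), real
```

For question 2, the 2D `GlobalFilter` at the top of this page does not require H == W: with h = H and w = W // 2 + 1, the `rfft2`/`irfft2` calls work unchanged because `s=(H, W)` is passed explicitly.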
Hello, thanks for your nice work!
I test the FLOPs of GFNet using the fvcore library as following:
```python
import torch
from fvcore.nn import flop_count

# model: a GFNet instance, e.g. gfnet-h-b
model_mode = model.training
model.eval()
fake_input = torch.rand(1, 3, 224, 224)
flops_dict, *_ = flop_count(model, (fake_input,))  # values are reported in GFLOPs
count = sum(flops_dict.values())
model.train(model_mode)
print("fvcore FLOPs: {:.3f} G".format(count))
```
For the gfnet-h-b model, I got 8.547 G, which is slightly higher than the 8.4G reported in the paper.
My concerns are:
How do you get the FLOPs value?
From the fvcore log, I noticed that:
Unsupported operator aten::fft_irfft2 encountered 36 time(s)
I.e., the FLOPs of the fft_irfft2 operator are not taken into account.
Did you consider this operator when calculating the FLOPs?
If not, I think it would be better to consider it because this is the core operator that replaces self-attention.
Please let me know if I missed something.
Thanks.
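One way to include the FFT ops is to register custom handles with fvcore's `FlopCountAnalysis.set_op_handle`. The sketch below assumes a cost model of numel * log2(H * W) per transform (a common textbook estimate, not necessarily what the paper used), assumes the FFT runs over dims (1, 2) of a (B, H, W, C) tensor as in `GlobalFilter`, and reads shapes from the traced TorchScript values via `.type().sizes()`. The op name `aten::fft_rfft2` is assumed by analogy with the logged `aten::fft_irfft2`; `make_fft_handle` and `count_flops_with_fft` are hypothetical names. `FlopCountAnalysis.total()` returns a raw count, whereas `flop_count` reports GFLOPs.

```python
import math
from typing import Any, List, Sequence


def fft_flop_estimate(shape: Sequence[int], dims=(1, 2)) -> float:
    """Assumed cost model: numel * log2(N), N = product of the transformed sizes."""
    numel = math.prod(shape)
    n = math.prod(shape[d] for d in dims)
    return numel * math.log2(n)


def make_fft_handle(use_output: bool):
    """Build an fvcore-style handle (inputs, outputs) -> count.

    Uses the real-valued side of the transform: the input of rfft2 or the output of irfft2.
    """
    def handle(inputs: List[Any], outputs: List[Any]) -> float:
        value = outputs[0] if use_output else inputs[0]
        return fft_flop_estimate(value.type().sizes())  # requires a complete traced shape
    return handle


def count_flops_with_fft(model, inputs):
    from fvcore.nn import FlopCountAnalysis  # imported lazily; requires fvcore
    analysis = FlopCountAnalysis(model, inputs)
    analysis.set_op_handle(
        "aten::fft_rfft2", make_fft_handle(use_output=False),
        "aten::fft_irfft2", make_fft_handle(use_output=True),
    )
    return analysis.total()
```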
I read the paper; it gives clear ideas about how GFNet tackles the computational cost and complexity limitations of self-attention in Vision Transformers. I would like to take it a step further and use it for a segmentation task, using the Global Filter as an attention mechanism in a U-Net model, similar to an Attention Gate.
Inputs:
Outputs: filtered output, gate frequencies, weights frequencies
Function: AttentionFilter(
Pseudocode:
Return:
```python
import torch
import torch.nn as nn
import torch.fft


class GlobalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)

    def forward(self, x):
        B, H, W, C = x.shape  # channels-last input: (B, H, W, C)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')  # (B, H, W // 2 + 1, C), complex
        weight = torch.view_as_complex(self.complex_weight)  # (h, w, C); needs h == H, w == W // 2 + 1
        x = x * weight
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')  # back to (B, H, W, C), real
        return x
```
Could you give a simple test showing how the dimensions change?
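A small dimension test, assuming the filter is inserted into a U-Net on NCHW convolution features: the tensor is permuted to channels-last, filtered, and permuted back, so the output shape equals the input shape. `global_filter_nchw` is a hypothetical helper.

```python
import torch


def global_filter_nchw(x, complex_weight):
    """Apply a GFNet-style global filter to a conv feature map of shape (B, C, H, W).

    complex_weight: real tensor (H, W // 2 + 1, C, 2) holding real/imag parts.
    """
    B, C, H, W = x.shape
    x = x.permute(0, 2, 3, 1)                                     # (B, H, W, C), channels-last
    X = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')              # (B, H, W // 2 + 1, C), complex
    X = X * torch.view_as_complex(complex_weight)                 # weight broadcasts over B
    y = torch.fft.irfft2(X, s=(H, W), dim=(1, 2), norm='ortho')   # (B, H, W, C), real
    return y.permute(0, 3, 1, 2)                                  # back to (B, C, H, W)


feat = torch.randn(2, 64, 32, 32)                # e.g. a U-Net encoder feature map
weight = torch.randn(32, 32 // 2 + 1, 64, 2) * 0.02
out = global_filter_nchw(feat, weight)
```

Setting the weight to 1 + 0j everywhere turns the filter into an identity, which is a handy sanity check.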
A question about the inference speed of the FFT filter:
After adding GlobalFilter, I found that it greatly affects inference speed: the model runs at almost half the speed it did before. Do you have this problem? How can it be solved?
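Before optimizing, it may help to time the filter in isolation against the rest of the block. A minimal latency sketch (`median_latency` is a hypothetical helper); on GPU it synchronizes around each call, since CUDA kernels run asynchronously.

```python
import time

import torch


@torch.no_grad()
def median_latency(module, x, warmup=5, iters=20):
    """Median wall-clock time (seconds) of module(x) over `iters` runs."""
    module.eval()
    for _ in range(warmup):
        module(x)
    times = []
    for _ in range(iters):
        if x.is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        module(x)
        if x.is_cuda:
            torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
        times.append(time.perf_counter() - start)
    return sorted(times)[len(times) // 2]
```

For example, `median_latency(block.filter, x)` and `median_latency(block.mlp, x)` for inputs of the matching shape show whether the FFT path dominates.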
@raoyongming Thanks for your great contribution.
I have input data of shape (B, C, H, W) = (1, 200, 8, 8) with 200 grayscale slices (i.e., channels), and my patch embedding class with patch size (10, 1, 1) converts the data into B, N, C = x.shape, i.e., torch.Size([1, 1280, 96]). That means that after patchifying, my input data becomes [1, 20, 8, 8, 10] (where 10 is the original pixel values of each patch), and the embedding has shape [1, 1280, 96], where 96 is the embedding dim. The data after patchifying looks something like this.
My question is: can your GlobalFilter be changed to compute a 3D global filter?
If yes, how should these changes be applied to the code in terms of h, w_hat, etc.?
Thanks
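A 3D extension could use `torch.fft.rfftn`/`irfftn` over three axes. Only the last transformed axis is halved, so the weight has shape (d, h, w_hat, C, 2) with w_hat = w // 2 + 1; for the shapes above, (d, h, w) = (20, 8, 8) gives w_hat = 5. A minimal sketch, assuming the 1280 tokens are ordered as (d, h, w) in row-major order (`GlobalFilter3D` is a hypothetical name):

```python
import torch
import torch.nn as nn


class GlobalFilter3D(nn.Module):
    """Hypothetical 3D extension: filter over (d, h, w) of a token sequence (B, N, C)."""

    def __init__(self, dim, d, h, w):
        super().__init__()
        self.size = (d, h, w)
        w_hat = w // 2 + 1  # only the last FFT axis is halved by rfftn
        self.complex_weight = nn.Parameter(torch.randn(d, h, w_hat, dim, 2) * 0.02)

    def forward(self, x):
        B, N, C = x.shape                  # e.g. (1, 1280, 96)
        d, h, w = self.size
        x = x.reshape(B, d, h, w, C)       # assumes N == d * h * w in row-major order
        X = torch.fft.rfftn(x, dim=(1, 2, 3), norm='ortho')
        X = X * torch.view_as_complex(self.complex_weight)
        x = torch.fft.irfftn(X, s=(d, h, w), dim=(1, 2, 3), norm='ortho')
        return x.reshape(B, N, C)
```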
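Hello, I want to deploy HorNet, which uses this GFNet filter, on an Nvidia Xavier, but when converting to ONNX I found that torch.fft is not supported. Is there a good solution?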
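One possible workaround for exporters without FFT support is to rewrite the transforms as products with precomputed cosine/sine (DFT) matrices, which use only standard ops such as `einsum`. A minimal sketch of the filter's forward pass in that style; `global_filter_no_fft` is a hypothetical helper, it has not been tested with ONNX export or TensorRT, and its cost grows as O(HW(H + W)C) rather than O(HWC log HW).

```python
import math

import torch


def _dft_tables(n: int, num_freq: int):
    """cos/sin tables of shape (num_freq, n) for the angle 2*pi*k*t/n."""
    k = torch.arange(num_freq, dtype=torch.float32).unsqueeze(1)
    t = torch.arange(n, dtype=torch.float32).unsqueeze(0)
    angle = 2 * math.pi * k * t / n
    return torch.cos(angle), torch.sin(angle)


def global_filter_no_fft(x, weight_real, weight_imag):
    """rfft2 -> complex multiply -> irfft2 (norm='ortho') using only real matmuls.

    x: (B, H, W, C) real. weight_real / weight_imag: (H, W // 2 + 1, C).
    """
    B, H, W, C = x.shape
    wf = W // 2 + 1
    cos_w, sin_w = _dft_tables(W, wf)  # (wf, W)
    cos_h, sin_h = _dft_tables(H, H)   # (H, H)
    scale = 1.0 / math.sqrt(H * W)     # 'ortho' normalisation, once per direction

    # Forward real DFT along W: multiply by e^{-i theta} = cos - i sin
    ar = torch.einsum('kw,bhwc->bhkc', cos_w, x)
    ai = -torch.einsum('kw,bhwc->bhkc', sin_w, x)
    # Forward complex DFT along H: (ar + i ai)(cos - i sin)
    br = (torch.einsum('jh,bhkc->bjkc', cos_h, ar) + torch.einsum('jh,bhkc->bjkc', sin_h, ai)) * scale
    bi = (torch.einsum('jh,bhkc->bjkc', cos_h, ai) - torch.einsum('jh,bhkc->bjkc', sin_h, ar)) * scale

    # Element-wise complex product with the learnable filter
    zr = br * weight_real - bi * weight_imag
    zi = br * weight_imag + bi * weight_real

    # Inverse complex DFT along H: (zr + i zi)(cos + i sin)
    yr = torch.einsum('jh,bjkc->bhkc', cos_h, zr) - torch.einsum('jh,bjkc->bhkc', sin_h, zi)
    yi = torch.einsum('jh,bjkc->bhkc', sin_h, zr) + torch.einsum('jh,bjkc->bhkc', cos_h, zi)

    # Inverse real DFT along W: bins other than DC (and Nyquist for even W) count twice
    coef = torch.full((wf, 1), 2.0)
    coef[0] = 1.0
    if W % 2 == 0:
        coef[-1] = 1.0
    out = torch.einsum('kw,bhkc->bhwc', cos_w * coef, yr) - torch.einsum('kw,bhkc->bhwc', sin_w * coef, yi)
    return out * scale
```

A trained `complex_weight` of shape (h, w, C, 2) would be passed as `weight[..., 0]` and `weight[..., 1]`.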
Hi! Very interesting work!
How are the Params calculated? Do you use profile?
I noticed that you use a script to calculate memory and FLOPs. Could you share the script? Many thanks to the author!
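Parameter counts are usually obtained by summing `numel()` over `model.parameters()`; this may or may not match the authors' script. Note that `complex_weight` stores real and imaginary parts separately, so each complex filter value counts as two parameters. A minimal sketch with a stand-in module holding a weight of GlobalFilter's shape for dim=384, h=14, w=8 (`count_params` and `FilterOnly` are hypothetical names):

```python
import torch
import torch.nn as nn


def count_params(model: nn.Module) -> int:
    """Number of trainable scalar parameters (each real entry counts once)."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


class FilterOnly(nn.Module):
    """Stand-in with GlobalFilter's weight shape: (h, w, dim, 2) = (14, 8, 384, 2)."""

    def __init__(self):
        super().__init__()
        self.complex_weight = nn.Parameter(torch.zeros(14, 8, 384, 2))
```

Here `count_params(FilterOnly())` gives 14 * 8 * 384 * 2 = 86016.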
Thanks for sharing your nice work. I have been following your work since last year, and I found the TPAMI paper, which extends the NeurIPS version.
There seem to be some modifications to your token-mixer design. Do you plan to update the code for the modified version of GFNet?
I really hope to follow up on your new work :).