
gfnet's People

Contributors

raoyongming, wl-zhao

gfnet's Issues

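Code issue

Hello, in the part of the code marked by the red box in the image, you multiply the weight by 0.02. What is the purpose of this?
[image]
One more small question: why not apply a 1D Fourier transform directly along the axis of dimension N and then multiply by the learnable weights?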

parameters

    import torch
    import torch.nn as nn

    class GlobalFilter(nn.Module):
        def __init__(self, dim, h=14, w=8):
            super().__init__()
            # Learnable filter stored as (real, imag) pairs in the last axis
            self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)

        def forward(self, x):
            B, H, W, C = x.shape
            x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
            weight = torch.view_as_complex(self.complex_weight)
            x = x * weight
            x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
            return x

Thank you very much for your work. I have some questions: what do "h=14, w=8" and "s=(H, W), dim=(1, 2)" mean?
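
A small sketch of what these arguments control, assuming the channels-last layout (B, H, W, C) used in the snippet above. `dim=(1, 2)` selects the two spatial axes for the transform, and `s=(H, W)` tells `irfft2` which spatial size to restore. The odd size 7 is chosen here only to make the effect visible. Under the same reading, the defaults `h=14, w=8` would correspond to a 14 x 14 token grid, because `rfft2` keeps `14 // 2 + 1 = 8` columns along the last transformed axis.

```python
import torch

# Channels-last feature map (B, H, W, C); odd spatial size chosen for illustration.
B, H, W, C = 2, 7, 7, 4
x = torch.randn(B, H, W, C)

# dim=(1, 2) transforms the H and W axes; the last of them shrinks to W // 2 + 1.
freq = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
print(freq.shape)    # torch.Size([2, 7, 4, 4])

# Without s, irfft2 assumes an even width 2 * (4 - 1) = 6 and cannot recover W = 7.
no_s = torch.fft.irfft2(freq, dim=(1, 2), norm='ortho')
print(no_s.shape)    # torch.Size([2, 7, 6, 4])

# With s=(H, W) the original spatial size and values are restored.
with_s = torch.fft.irfft2(freq, s=(H, W), dim=(1, 2), norm='ortho')
print(with_s.shape)  # torch.Size([2, 7, 7, 4])
```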

About visualization

Thank you very much for your work. When I implemented the visualization of the frequency-domain filters, the resulting image was different from the one you provided. I would also like to ask how to produce the spatial-domain visualization.

Question about block design

Hello, thanks for your great work!

In your figure and code, there is no skip connection after the global filter layer.

This differs from the original Transformer implementation, which has two skip connections in a single block (one for the self-attention layer and one for the FFN layer).

For example, the original Transformer uses blocks like

x = x + SA(x)
x = x + FFN(x)

whereas the Global Filter Network uses the following block:

x_ = Global_Filter(x)
x = x + FFN(x_)

Is there any reason for adopting the current block architecture?

Thanks,

Flexible input size

Hi, I came across your work and thought it was a very interesting concept. Currently, the network takes in fixed input sizes. But is there a way for there to be flexible input sizes? I realize the main constraint here is the following line where the complex weight is defined during initialization time:
self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
Is there a way to modify this line so that we can have inputs of different sizes?
Thanks
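
One possible workaround, sketched here as an assumption rather than as the authors' method, is to keep the filter at its trained size and resample it to the frequency grid of the incoming feature map. The helper name `resize_complex_weight` is hypothetical. Bilinear interpolation of frequency-domain weights is only a rough heuristic, so results at new resolutions would need to be checked.

```python
import torch
import torch.nn.functional as F

def resize_complex_weight(weight, new_h, new_w):
    """Bilinearly resize a (h, w, dim, 2) filter to (new_h, new_w, dim, 2)."""
    h, w, dim, two = weight.shape
    # Treat (dim, 2) as channels so F.interpolate can resample the (h, w) grid.
    grid = weight.reshape(1, h, w, dim * two).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(new_h, new_w), mode='bilinear', align_corners=True)
    return grid.permute(0, 2, 3, 1).reshape(new_h, new_w, dim, two).contiguous()

weight = torch.randn(14, 8, 64, 2) * 0.02   # filter sized for a 14 x 14 token grid
H, W = 20, 20                                # new token grid (illustrative)
resized = resize_complex_weight(weight, H, W // 2 + 1)

x = torch.randn(1, H, W, 64)
freq = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
out = torch.fft.irfft2(freq * torch.view_as_complex(resized),
                       s=(H, W), dim=(1, 2), norm='ortho')
print(out.shape)  # torch.Size([1, 20, 20, 64])
```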

model parallel training

Why can I only replicate the whole model on each GPU, instead of using model parallelism that distributes parts of the model across the GPUs?

I followed the instructions on the webpage.

GFNet containing imaginary number

    import math
    import torch
    import torch.nn as nn
    # DropPath and Mlp are not defined in this snippet; they are assumed to come from timm.

    class GlobalFilter(nn.Module):
        def __init__(self, dim, h=14, w=8):
            super().__init__()
            self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
            self.w = w
            self.h = h

        def forward(self, x, spatial_size=None):
            B, N, C = x.shape
            if spatial_size is None:
                a = b = int(math.sqrt(N))
            else:
                a, b = spatial_size

            x = x.view(B, a, b, C)

            x = x.to(torch.float32)

            x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
            weight = torch.view_as_complex(self.complex_weight)
            x = x * weight
            x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm='ortho')  # line highlighted by the asker

            x = x.reshape(B, N, C)

            return x

    class Block(nn.Module):

        def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, h=14, w=8):
            super().__init__()
            self.norm1 = norm_layer(dim)
            self.filter = GlobalFilter(dim, h=h, w=w)
            self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
            self.norm2 = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        def forward(self, x):
            x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))  # line highlighted by the asker
            return x

I have a question regarding GFNet.
When the inverse FFT and the layer norm are applied in class Block, I get an error saying that the feature map contains imaginary numbers.

How did you deal with this?
The error is raised at x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))

Thank you in advance,

Question about adversarial robustness

Hi, Rao
Thank you for your great work!
Could you share the specific conditions and hyperparameters used when measuring GFNet's adversarial robustness with FGSM and PGD?
It would be even better if I could get the code you used!

Multi-GPU training

Can the model be trained on multiple GPUs? I did not see any related code.

Visualization Code

Hello sir:
Thank you for your excellent work!
Could you please provide visualization code?
Thank you!

Question about Complexity (FLOPs)

Hi, interesting work!
I wonder how the Complexity (FLOPs) for global filter in Table.1 is calculated.
Because the spectrum of a real signal is conjugate symmetric, I see two ways to count:

Case 1: take the conjugate symmetry into account.
RFFT: HWD/2 * log2(HW)
Global Filter: HWD/2
IRFFT: HWD/2 * log2(HW)

Thus, the total complexity (FLOPs) of the global filter is HWD * log2(HW) + HWD/2. Is that right?

Case 2: ignore the conjugate symmetry.
RFFT: HWD * log2(HW)
Global Filter: HWD
IRFFT: HWD * log2(HW)

Thus, the total complexity (FLOPs) of the global filter is 2HWD * log2(HW) + HWD. Is that right?

Which one is right?
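
To make the two candidates concrete, the sketch below simply evaluates both formulas from the question. The function name and the sizes H = W = 14, D = 384 are assumptions chosen for illustration, and the code does not settle which count Table 1 uses.

```python
import math

def global_filter_flops(H, W, D, use_symmetry=True):
    """Evaluate the estimate from the question: two FFTs plus one element-wise product."""
    factor = 0.5 if use_symmetry else 1.0
    fft = factor * H * W * D * math.log2(H * W)   # one RFFT or one IRFFT
    mul = factor * H * W * D                      # element-wise global filter
    return 2 * fft + mul

H, W, D = 14, 14, 384   # illustrative sizes, not taken from the paper
case1 = global_filter_flops(H, W, D, use_symmetry=True)
case2 = global_filter_flops(H, W, D, use_symmetry=False)
print(f"case 1: {case1 / 1e6:.3f} MFLOPs")
print(f"case 2: {case2 / 1e6:.3f} MFLOPs")
```

By construction the second count is exactly twice the first, so the two cases differ only by a constant factor and not in their asymptotic order.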

FLOPs

How are the FLOPs calculated?

GFNet models pretrained on ImageNet cannot be loaded

@raoyongming Thanks for providing the ImageNet weights. I was trying to load the weights that you have shared to do some analysis for benchmarking, however, when I am loading it for gfnet-ti or gfnet-xs, I am getting the following error:

model.load_state_dict(torch.load(weightpath))

RuntimeError: Error(s) in loading state_dict for GFNet:
        Missing key(s) in state_dict: "pos_embed", "patch_embed.proj.weight", "patch_embed.proj.bias", "blocks.0.norm1.weight", "blocks.0.norm1.bias", "blocks.0.filter.complex_weight", "blocks.0.norm2.weight", "blocks.0.norm2.bias", "blocks.0.mlp.fc1.weight", "blocks.0.mlp.fc1.bias", "blocks.0.mlp.fc2.weight", "blocks.0.mlp.fc2.bias", "blocks.1.norm1.weight", "blocks.1.norm1.bias", "blocks.1.filter.complex_weight", "blocks.1.norm2.weight", "blocks.1.norm2.bias", "blocks.1.mlp.fc1.weight", "blocks.1.mlp.fc1.bias", "blocks.1.mlp.fc2.weight", "blocks.1.mlp.fc2.bias", "blocks.2.norm1.weight", "blocks.2.norm1.bias", "blocks.2.filter.complex_weight", "blocks.2.norm2.weight", "blocks.2.norm2.bias", "blocks.2.mlp.fc1.weight", "blocks.2.mlp.fc1.bias", "blocks.2.mlp.fc2.weight", "blocks.2.mlp.fc2.bias", "blocks.3.norm1.weight", "blocks.3.norm1.bias", "blocks.3.filter.complex_weight", "blocks.3.norm2.weight", "blocks.3.norm2.bias", "blocks.3.mlp.fc1.weight", "blocks.3.mlp.fc1.bias", "blocks.3.mlp.fc2.weight", "blocks.3.mlp.fc2.bias", "norm.weight", "norm.bias", "head.weight", "head.bias". 
        Unexpected key(s) in state_dict: "model".

Are these weights for the same architecture as in the code? How can I resolve this?
Thanks
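
The message `Unexpected key(s) in state_dict: "model"` suggests that the file stores the weights nested under a `"model"` key. A minimal sketch of unwrapping such a checkpoint follows, using a small `nn.Linear` as a stand-in for GFNet and an in-memory buffer in place of the real weight file (both are assumptions for illustration).

```python
import io
import torch
import torch.nn as nn

# Stand-in for GFNet; any nn.Module behaves the same way here.
model = nn.Linear(4, 2)

# Simulate a checkpoint that nests the weights under a "model" key.
buffer = io.BytesIO()
torch.save({"model": model.state_dict()}, buffer)
buffer.seek(0)

checkpoint = torch.load(buffer, map_location="cpu")
# Unwrap the nested dict if present, otherwise use the loaded object as-is.
state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
result = model.load_state_dict(state_dict)
print(result)  # reports missing / unexpected keys, if any
```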

The fp32fft option

Hello, thanks for your nice work!

I wonder what the fp32fft option does. In my experiments the input and output of the FFT function are already torch.float32, so I'm not sure why there is an option for converting to fp32. Thanks in advance.

image size for ADE20K

Hi, Yongming

What image size did you use for training and validation on ADE20K?

I noticed that PVT used 512x512 for training and a different scale for testing. However, as the parameters of Global Filter are related to the image size, how do you deal with the scale change?

Thanks in advance.

Question about 3D configuration

Hi, thank you for your excellent work. I have some questions concerning a 3D configuration of GFNet.
I extended your model to a 3D version by using a 3D FFT and IFFT for global filter learning, and tested it on 3D data classification (point cloud / volumetric data). However, the model over-fits (training accuracy 97%, test accuracy 75%). Could you provide some advice on how to train such a model? (I have tried Dropout with different ratios.) Thank you!

About size h and w

Thank you for your excellent work.

I'm curious why w is set to w = h // 2 + 1. Or did experiments simply show that this setting works better?
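
As far as the FFT itself is concerned, `h // 2 + 1` is the number of columns that `rfft2` returns for a real input of width `h`, since the remaining columns are redundant by conjugate symmetry. The sketch below checks this numerically for a random 14 x 14 real signal. It illustrates the property of the transform and is not a statement about the authors' experiments.

```python
import torch

h = 14
x = torch.randn(h, h)

full = torch.fft.fft2(x)     # (14, 14) complex spectrum
half = torch.fft.rfft2(x)    # (14, 8): only h // 2 + 1 columns are stored
w = h // 2 + 1

# The stored columns equal the first w columns of the full spectrum.
same_front = torch.allclose(full[:, :w], half, atol=1e-3)

# The dropped columns follow from conjugate symmetry: X[u, v] = conj(X[-u, -v]).
u = torch.arange(h)
v = torch.arange(w, h)
mirrored = torch.conj(full[(-u) % h][:, (-v) % h])
same_back = torch.allclose(full[:, w:], mirrored, atol=1e-3)

print(half.shape, same_front, same_back)
```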

Predicting with the model

Hello, I am new to PyTorch and would like to know how exactly the trained model can be used for prediction on images.
Thank you.
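
A minimal sketch of the usual PyTorch inference pattern, under stated assumptions: the tiny `nn.Sequential` is a hypothetical stand-in for the trained GFNet with its checkpoint loaded, and the random tensor stands in for one image that has already been resized and normalized the same way as during training.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in classifier; in practice this would be the trained GFNet.
model = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, 1000))
model.eval()  # switch off dropout / stochastic depth

# Placeholder for one preprocessed image of shape (1, 3, 224, 224).
image = torch.rand(1, 3, 224, 224)

with torch.no_grad():  # no gradients are needed for prediction
    logits = model(image)
    probs = logits.softmax(dim=-1)
    top5_prob, top5_idx = probs.topk(5, dim=-1)

print(top5_idx[0].tolist())  # indices of the five most likely classes
```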

About 1D FFT AND 2D FFT

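Hello, and thank you for your work. I have two questions.
1. I tried your model on a sequence task. My sequence length is 256, so a tensor of shape (B, 256, 512) is fed to your model. I ignored the meaning of the input and simply treated the 256-long sequence as a 2D signal (i.e., H and W are both 16), forcing the model to learn from it. Strangely, the results are very good. How can a 1D FFT be applied to the sequence directly, and how should the code be written?
2. When using the 2D FFT, must the input signal have equal H and W, as with the images in the paper? If they can differ, what changes does the code need?
Sorry to bother you, and I look forward to your reply. Thanks again!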
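
For the first question, one way to filter along the sequence axis directly is sketched below. `GlobalFilter1D` is a hypothetical variant written for this illustration, not code from the repository. It mirrors the 2D layer but uses `rfft` / `irfft` over the token axis of a (B, N, C) tensor.

```python
import torch
import torch.nn as nn

class GlobalFilter1D(nn.Module):
    """Hypothetical 1D variant: filters along the sequence axis of (B, N, C)."""
    def __init__(self, dim, seq_len):
        super().__init__()
        # rfft keeps seq_len // 2 + 1 frequencies for a real-valued sequence.
        self.complex_weight = nn.Parameter(
            torch.randn(seq_len // 2 + 1, dim, 2, dtype=torch.float32) * 0.02)

    def forward(self, x):
        B, N, C = x.shape
        x = torch.fft.rfft(x, dim=1, norm='ortho')
        x = x * torch.view_as_complex(self.complex_weight)
        return torch.fft.irfft(x, n=N, dim=1, norm='ortho')

layer = GlobalFilter1D(dim=512, seq_len=256)
out = layer(torch.randn(2, 256, 512))
print(out.shape)  # torch.Size([2, 256, 512])
```

For the second question, `rfft2` itself does not require H and W to be equal. With a non-square grid, the 2D weight would need shape (H, W // 2 + 1, dim, 2) and the reshape would need the true (H, W) instead of the square root of N.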

FLOPs Concern

Hello, thanks for your nice work!

I tested the FLOPs of GFNet using the fvcore library as follows:

    from fvcore.nn import flop_count
    model_mode = model.training
    model.eval()
    fake_input = torch.rand(1, 3, 224, 224)
    flops_dict, *_ = flop_count(model, fake_input)
    count = sum(flops_dict.values())
    model.train(model_mode)
    print("fvcore FLOPs: {:.3f} G".format(count))

For the gfnet-h-b model, I got 8.547 G, which is slightly higher than the 8.4 G reported in the paper.

My concerns are:

  1. How did you obtain the FLOPs value?

  2. From the fvcore log, I noticed that:

Unsupported operator aten::fft_irfft2 encountered 36 time(s)

I.e., the FLOPs of the fft_irfft2 operator are not taken into account.
I wonder if you consider this operator when calculating the FLOPs?

If not, I think it would be better to consider it because this is the core operator that replaces self-attention.

Please let me know if I missed something.

Thanks.

[Q]: Can the Global Filter Network be used for image segmentation?

I read the paper, and it gives a clear account of how GFNet tackles the computational cost and complexity of self-attention in Vision Transformers. I would like to take a step further and use it for a segmentation task, with the Global Filter acting as an attention mechanism in a U-Net model, similar to an Attention Gate.

Algorithm pseudocode

Inputs: $F_g$ - input feature dimension of the global filter, $F_l$ - input feature dimension of the local filter, $F_{int}$ - intermediate feature dimension, $dim$ - spatial dimensions

Outputs: filtered output, gate frequencies, weights frequencies


Function: AttentionFilter($g$, $x$)

  • Input: $g$ - input global feature map, $x$ - input feature map
  • Output: $out$ - filtered output, $G1_{feq}$ - gate frequencies, $X1_{feq}$ - weights frequencies

Pseudocode:

$G1_{feq} \gets \text{GlobalFilter}(g, F_l, F_{int}, dim)$
$X1_{feq} \gets \text{GlobalFilter}(x, F_g, F_{int}, dim)$
$atten \gets \text{Softmax}\left(\frac{G1_{feq} \odot X1_{feq}}{\sqrt{2\pi\sigma^2}}\right)$
$x1 \gets \text{irfft2}\left(atten, s=(H, W), dim=(1, 2), \text{norm}='ortho'\right)$
$out \gets \text{NormLayer(x1 + x)}$
Return: $out$, $G1_{feq}$, $X1_{feq}$

My question: the idea is to keep the network learning in the frequency domain, similar to a complex-valued NN. What is your opinion of the algorithm I provided? Is there anything I misunderstood or got wrong?

The dimension for Global Filter

    import torch
    import torch.nn as nn
    import torch.fft

    class GlobalFilter(nn.Module):
        def __init__(self, dim, h=14, w=8):
            super().__init__()
            self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)

        def forward(self, x):
            B, H, W, C = x.shape
            x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
            weight = torch.view_as_complex(self.complex_weight)
            x = x * weight
            x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
            return x

Can you give a simple test for the dimensions' change?
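
A simple shape test could look like the sketch below. It repeats the layer with print statements added at each step. The batch size 2 and channel count 384 are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

class GlobalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)

    def forward(self, x):
        B, H, W, C = x.shape
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        print('after rfft2 :', tuple(x.shape), x.dtype)
        weight = torch.view_as_complex(self.complex_weight)
        print('weight      :', tuple(weight.shape), weight.dtype)
        x = x * weight  # broadcasts over the batch axis
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
        print('after irfft2:', tuple(x.shape), x.dtype)
        return x

x = torch.randn(2, 14, 14, 384)   # (B, H, W, C); sizes chosen for illustration
y = GlobalFilter(dim=384)(x)
assert y.shape == x.shape         # the layer preserves the input shape
```

The intermediate spectrum has shape (2, 14, 8, 384) and a complex dtype, while the output returns to the real-valued input shape.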

Question about the inference speed of the FFT filter

After adding GlobalFilter, I found that it greatly slows down model inference: the speed is almost half of what it was before adding it. Do you have this problem? How can it be solved?

GFNet for 3d data: How to change the Global filter

@raoyongming Thanks for your great contribution.
I have input data of shape (B, C, H, W) = (1, 200, 8, 8), with 200 gray-scale slices (i.e., channels). My patch embedding class, with patch size (10, 1, 1), converts the data so that B, N, C = x.shape gives torch.Size([1, 1280, 96]). In other words, after patchifying, the input becomes [1, 20, 8, 8, 10] (10 here being the original pixel values of each patch), and the embedding has shape [1, 1280, 96], where 96 is the embedding dimension. The patchified data looks similar to this.

My question is: can your GlobalFilter be changed to compute the global filter in 3D?

If yes, how should the code be changed in terms of h, w_hat, etc.?
Thanks
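
One way such a 3D extension could be written is sketched below, using `torch.fft.rfftn` / `irfftn` over three axes. `GlobalFilter3D` is a hypothetical class for this illustration. It assumes the 1280 tokens are flattened in (d, h, w) = (20, 8, 8) order, which would need to match the actual patch embedding.

```python
import torch
import torch.nn as nn

class GlobalFilter3D(nn.Module):
    """Hypothetical 3D variant operating on tokens laid out as a (d, h, w) grid."""
    def __init__(self, dim, d=20, h=8, w=8):
        super().__init__()
        self.grid = (d, h, w)
        # rfftn halves only the last transformed axis: w -> w // 2 + 1.
        self.complex_weight = nn.Parameter(
            torch.randn(d, h, w // 2 + 1, dim, 2, dtype=torch.float32) * 0.02)

    def forward(self, x):
        B, N, C = x.shape
        d, h, w = self.grid
        x = x.reshape(B, d, h, w, C)
        x = torch.fft.rfftn(x, dim=(1, 2, 3), norm='ortho')
        x = x * torch.view_as_complex(self.complex_weight)
        x = torch.fft.irfftn(x, s=(d, h, w), dim=(1, 2, 3), norm='ortho')
        return x.reshape(B, N, C)

tokens = torch.randn(1, 1280, 96)   # 1280 = 20 * 8 * 8 tokens, embedding dim 96
layer3d = GlobalFilter3D(dim=96)
out = layer3d(tokens)
print(out.shape)  # torch.Size([1, 1280, 96])
```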

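How to deploy this model?

Hello, I want to deploy HorNet on an Nvidia Xavier, and it uses this GFNet. When converting to ONNX, I ran into torch.fft not being supported. Is there a good workaround?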
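
One workaround sometimes used when an exporter lacks FFT operators is to express the transform as ordinary matrix multiplications with fixed cosine and sine matrices. The sketch below shows the idea for a 1D real-input DFT and checks it against `torch.fft.rfft`. The helper `rfft_as_matmul` is hypothetical, and whether this exports cleanly in a given toolchain is an assumption that would need to be verified.

```python
import math
import torch

def rfft_as_matmul(x):
    """Real-input DFT along the last axis using only real matmuls (no torch.fft)."""
    n = x.shape[-1]
    k = torch.arange(n // 2 + 1, dtype=torch.float32).unsqueeze(1)  # (n//2+1, 1)
    t = torch.arange(n, dtype=torch.float32).unsqueeze(0)           # (1, n)
    angle = 2 * math.pi * k * t / n
    real = x @ torch.cos(angle).T       # sum_t x[t] * cos(2*pi*k*t/n)
    imag = -(x @ torch.sin(angle).T)    # -sum_t x[t] * sin(2*pi*k*t/n)
    return real, imag

x = torch.randn(3, 14)
real, imag = rfft_as_matmul(x)
ref = torch.fft.rfft(x, dim=-1)         # reference, unnormalized
max_err = max((real - ref.real).abs().max().item(),
              (imag - ref.imag).abs().max().item())
print(max_err)
```

A 2D version would apply the same construction along each spatial axis in turn, at the cost of O(n^2) work per axis instead of O(n log n).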

Memory and FLOPs concern?

Hi! very interesting work!

How is the parameter count (Params) calculated? Do you use profile?

I noticed that you use a script to calculate memory and FLOPs. Could you share the script? Many thanks to the authors!

Do you plan to update the code? (ft. TPAMI version)

Thanks for sharing your nice work. I have been following your work since last year, and I found the TPAMI paper, which extends the NeurIPS version.

It seems that there are some modifications to your token-mixer design. Do you plan to update the code for the modified version of GFNet?
I really hope to follow up on your new work :).
