

Exploring Self-attention for Image Recognition

by Hengshuang Zhao, Jiaya Jia, and Vladlen Koltun; details are in the paper.

Introduction

This repository is built for the proposed self-attention network (SAN) and contains the full training and testing code. An implementation of the SA module with optimized CUDA kernels is also included.

Usage

  1. Requirements:

    • Hardware: tested with 8 x Quadro RTX 6000 (24 GB).
    • Software: tested with PyTorch 1.4.0, Python 3.7, CUDA 10.1, CuPy 10.1, tensorboardX.
  2. Clone the repository:

    git clone https://github.com/hszhao/SAN.git
  3. Train:

    • Download and prepare the ImageNet dataset (ILSVRC2012) and symlink its path as follows (alternatively, modify the relevant path specified in the config folder):

      cd SAN
      mkdir -p dataset
      ln -s /path_to_ILSVRC2012_dataset dataset/ILSVRC2012
      
    • Specify the GPUs to use in the config (usually 8 GPUs are adopted) and then start training:

      sh tool/train.sh imagenet san10_pairwise
      
    • If you are using SLURM as the node manager, uncomment the relevant lines in train.sh and then submit the training job:

      sbatch tool/train.sh imagenet san10_pairwise
  4. Test:

    • Download the trained SAN models and put them under the folder specified in the config (or modify the specified paths), and then run testing:

      sh tool/test.sh imagenet san10_pairwise
  5. Visualization:

    • tensorboardX is incorporated for visualizing the training curves:

      tensorboard --logdir=exp/imagenet
  6. Other:

    • Resources: the GoogleDrive LINK contains the shared models.

Performance

Train Parameters: train_gpus(8), batch_size(256), epochs(100), base_lr(0.1), lr_scheduler(cosine), label_smoothing(0.1), momentum(0.9), weight_decay(1e-4).
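
A minimal sketch of how these hyperparameters map onto standard PyTorch components (illustrative only; the repo's own training code may differ, and CrossEntropyLoss(label_smoothing=...) requires PyTorch >= 1.10, newer than the tested 1.4.0):

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 10)  # stand-in for the SAN model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1,  # base_lr
                                momentum=0.9, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)  # epochs
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)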

Overall result:

| Method | top-1 | top-5 | Params | FLOPs |
| --- | --- | --- | --- | --- |
| ResNet26 | 73.6 | 91.7 | 13.7M | 2.4G |
| SAN10-pair. | 74.9 | 92.1 | 10.5M | 2.2G |
| SAN10-patch. | 77.1 | 93.5 | 11.8M | 1.9G |
| ResNet38 | 76.0 | 93.0 | 19.6M | 3.2G |
| SAN15-pair. | 76.6 | 93.1 | 14.1M | 3.0G |
| SAN15-patch. | 78.0 | 93.9 | 16.2M | 2.6G |
| ResNet50 | 76.9 | 93.5 | 25.6M | 4.1G |
| SAN19-pair. | 76.9 | 93.4 | 17.6M | 3.8G |
| SAN19-patch. | 78.2 | 93.9 | 20.5M | 3.3G |

Citation

If you find the code or trained models useful, please consider citing:

@inproceedings{zhao2020san,
  title={Exploring Self-attention for Image Recognition},
  author={Zhao, Hengshuang and Jia, Jiaya and Koltun, Vladlen},
  booktitle={CVPR},
  year={2020}
}


san's Issues

Where is the mapping function described in the paper?

[screenshots of the mapping-function equations from the paper]

But in the code:

SAN/model/san.py, lines 39-42 at commit d88b022:

self.conv_w = nn.Sequential(
    nn.BatchNorm2d(rel_planes * (pow(kernel_size, 2) + 1)), nn.ReLU(inplace=True),
    nn.Conv2d(rel_planes * (pow(kernel_size, 2) + 1), out_planes // share_planes, kernel_size=1, bias=False),
    nn.BatchNorm2d(out_planes // share_planes), nn.ReLU(inplace=True),
    nn.Conv2d(out_planes // share_planes, pow(kernel_size, 2) * out_planes // share_planes, kernel_size=1))

I think self.conv_w is the mapping function, but it differs from the paper:

code: BN -> ReLU -> Conv -> BN -> ReLU -> Conv
paper: Linear -> ReLU -> Linear

Is the performance better with the structure used in the code than with the format in the paper?
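
For illustration, a side-by-side sketch of the two variants in question (channel sizes are hypothetical; on feature maps, a 1x1 Conv2d plays the role of a Linear layer):

    import torch.nn as nn

    mid, out = 64, 16
    # paper: Linear -> ReLU -> Linear (as 1x1 convolutions)
    paper_style = nn.Sequential(
        nn.Conv2d(mid, out, kernel_size=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out, out, kernel_size=1))
    # code: BN -> ReLU -> Conv -> BN -> ReLU -> Conv
    code_style = nn.Sequential(
        nn.BatchNorm2d(mid), nn.ReLU(inplace=True),
        nn.Conv2d(mid, out, kernel_size=1, bias=False),
        nn.BatchNorm2d(out), nn.ReLU(inplace=True),
        nn.Conv2d(out, out, kernel_size=1))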

About the running speed of GPU-parallel aggregation and subtraction

Hi, @hszhao, Thanks for this great work.

I have tested the aggregation and subtraction scripts in the folder lib/sa/functions/ with the following setup:

cuda/10.0.130
cupy-cuda100-7.7.0
python 3.7

I find it takes around 10 minutes to finish. Here is the log:

$ python subtraction_refpad.py
test case passed
567.34s

Is this the same order of runtime you observed? Since the input is really small ([2, 8, 5, 5]), it seems odd that it takes 10 minutes to finish.

If anything else about my setup needs clarification, please let me know.
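
The long runtime is consistent with the test script running torch.autograd.gradcheck (an assumption based on the "test case passed" output): gradcheck re-evaluates the function roughly twice per input element in float64, so most of the time may be the numerical check rather than the op itself. A toy illustration:

    import torch
    from torch.autograd import gradcheck

    f = lambda a, b: a * b  # stand-in for the subtraction/aggregation op
    x = torch.randn(2, 8, 5, 5, dtype=torch.float64, requires_grad=True)
    w = torch.randn(2, 8, 5, 5, dtype=torch.float64, requires_grad=True)
    gradcheck(f, (x, w))    # O(numel) extra function evaluations => slow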

Is there any support for 3D CNNs?

Hi, I am wondering whether there is any support for 3D CNNs using the proposed pairwise and patchwise attention.
In my case, I have multiple input images, each producing a 4D tensor (C, D, H, W). I think your implementation only supports 3D tensors (C, H, W); is that correct?
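
A common workaround, sketched below under the assumption that depth slices can be processed independently (this is not a repo feature), is to fold the depth axis into the batch so the 2D attention runs per slice:

    import torch

    x = torch.randn(4, 16, 10, 32, 32)                        # (N, C, D, H, W)
    n, c, d, h, w = x.shape
    x2d = x.permute(0, 2, 1, 3, 4).reshape(n * d, c, h, w)    # (N*D, C, H, W)
    # ... apply the 2D SAN block to x2d here ...
    x3d = x2d.view(n, d, c, h, w).permute(0, 2, 1, 3, 4)      # back to (N, C, D, H, W)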

Clarification on Aggregation

Hello, I had a question regarding your code. In your patchwise attention model you use these CUDA kernels for aggregation.

I was a bit confused about what exactly it is doing; can you explain its functionality? For example, if I have:

input data: 1x256x1xWH
weights: 1x32x7**2xWH

how does it perform the Hadamard product described in equation 4?

Also, can I generally confirm the following mapping between the SAM module and equations 4 and 5 in your paper:

1. conv1: phi, conv2: psi, conv3: beta
2. delta is simple concatenation
3. conv_w: gamma

Thanks for your help.
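
For readers with the same question, a plain-PyTorch sketch of the aggregation for the no-sharing case (c_w == c_x); in SAN the weights are additionally shared across groups of share_planes channels (which is why 32-channel weights pair with 256-channel values above), and the exact group layout should be checked against the CUDA kernels:

    import torch
    import torch.nn.functional as F

    def aggregation_ref(x, w, kernel_size=3, padding=1):
        # x: values (n, c, h, w); w: footprint weights (n, c, kernel_size**2, h*w)
        n, c, h, wd = x.shape
        cols = F.unfold(x, kernel_size, padding=padding)      # (n, c*k*k, h*w)
        cols = cols.view(n, c, kernel_size ** 2, h * wd)
        out = (w * cols).sum(dim=2)  # Hadamard product, then sum over the footprint
        return out.view(n, c, h, wd)

    x = torch.randn(1, 8, 5, 5)
    w = torch.softmax(torch.randn(1, 8, 9, 25), dim=2)
    y = aggregation_ref(x, w)        # (1, 8, 5, 5)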

AggregationRefpadBackward failed

Hi Zhao,
When I test an image whose height != width, errors occur. In lib/sa/functions/aggregation_refpad.py, with

n, c_x, c_w, in_height, in_width = 2, 8, 4, 24, 44

I get:

RuntimeError: Function AggregationRefpadBackward returned an invalid gradient at index 0 - got [2, 8, 28, 44] but expected shape compatible with [2, 8, 24, 44]

Looking forward to your response!
LIANG

Robustness to Adversarial Attacks

Hi, could you kindly provide the implementation details of how to obtain the results demonstrated in Table 10 of your paper, i.e., the part on robustness to adversarial attacks? I have tried several different targeted PGD methods, but I still found my results quite different from yours.
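
For reference, a minimal targeted-PGD sketch of the kind of attack being discussed (the eps/alpha/steps values are arbitrary placeholders, not the paper's settings):

    import torch
    import torch.nn.functional as F

    def targeted_pgd(model, x, target, eps=8/255, alpha=2/255, steps=10):
        x_adv = x.clone().detach()
        for _ in range(steps):
            x_adv.requires_grad_(True)
            loss = F.cross_entropy(model(x_adv), target)
            grad = torch.autograd.grad(loss, x_adv)[0]
            # Targeted attack: step *down* the loss toward the target class.
            x_adv = (x_adv - alpha * grad.sign()).detach()
            x_adv = x + (x_adv - x).clamp(-eps, eps)   # project into the eps-ball
            x_adv = x_adv.clamp(0, 1)                  # keep a valid image
        return x_adv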

How to load the pretrained model

model = san.san(sa_type=1, layers=[3, 2, 3, 5, 2], kernels=[3, 7, 7, 7, 7], num_classes=1000).cuda()
model.load_state_dict(torch.load('./san15_patchwise/model/model_best.pth'))

RuntimeError: Error(s) in loading state_dict for SAN:
Missing key(s) in state_dict: "conv_in.weight", "bn_in.weight", "bn_in.bias", "bn_in.running_mean", "bn_in.running_var", "conv0.weight", ... (every parameter and buffer of the model is reported missing) ..., "fc.weight", "fc.bias".
Unexpected key(s) in state_dict: "epoch", "state_dict", "optimizer", "scheduler", "top1_val", "top5_val".
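
The unexpected keys indicate the file is a full training checkpoint rather than a bare state_dict. A likely fix, reusing the model from the snippet above (the 'module.' strip is an assumption, needed only if the weights were saved from nn.DataParallel):

    import torch

    checkpoint = torch.load('./san15_patchwise/model/model_best.pth')
    state_dict = checkpoint['state_dict']   # unwrap the training checkpoint
    # Strip a possible 'module.' prefix left by nn.DataParallel (assumption):
    state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)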

Enquiry regarding load_kernel

Hello Dr. Zhao, I was studying your paper and the code shared on this page, but had difficulty understanding the load_kernel() and f(block=..., grid=(...), args=[...], stream=...) parts of the subtraction/subtraction2/aggregation functions. How do they affect the subtraction relation function and the aggregation function, and would the results differ without them?

Any help would be appreciated. Thank you.
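
In broad terms, load_kernel compiles CUDA C source through CuPy and returns a callable kernel, and the f(block=..., grid=..., args=..., stream=...) call launches it on the GPU; without these steps the custom ops would simply not execute. A generic sketch of the mechanism using CuPy's RawKernel API (the repo compiles its kernels through a similar CuPy path):

    import cupy as cp
    import numpy as np

    code = r'''
    extern "C" __global__
    void add_one(float* x, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) x[i] += 1.0f;
    }'''
    kernel = cp.RawKernel(code, 'add_one')       # "load_kernel": compile once, cache
    x = cp.zeros(8, dtype=cp.float32)
    kernel((1,), (8,), (x, np.int32(8)))         # launch with (grid, block, args)
    print(x)                                     # [1. 1. 1. 1. 1. 1. 1. 1.]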

What calculations are performed by the Aggregation function?

Thank you for your great work!

When I run the SAM module with patchwise attention (sa_type=1) for verification, the input tensors of the Aggregation function have the following shapes.

In the paper, it says that the outputs of the streams are aggregated via a Hadamard product.

Could you tell me what operations are performed on tensors of these different shapes to realize the Hadamard product?

Customized subtraction and aggregation implementations are much slower than the PyTorch implementation

Environment:
pytorch 1.5.1
cuda 10.1
test on small input tensor (2,8,5,5)

When using the test method in lib/sa/functions to test the speed, I found that the corresponding implementation using the PyTorch API is much faster than your CUDA code in backward propagation (about 50x faster). The forward times are relatively close, with the customized op slightly faster than the torch API.


So why did you choose to implement the operation yourself?
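
One caveat when comparing: host-side timing without synchronization on such a tiny input (2, 8, 5, 5) mostly measures per-call launch overhead, so the gap may look different at realistic sizes. A generic sketch of a fairer GPU timing (not the repo's test harness):

    import torch

    def bench(fn, *args, iters=100):
        for _ in range(10):          # warm-up, including any compile cost
            fn(*args)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            fn(*args)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters   # milliseconds per call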

Question on Paper

Hi, Dr. Zhao. This is very interesting work. I have one question: have you tried your model with more layers?
