leaplabthu / agent-attention
Official repository of Agent Attention
Thank you for your excellent work! I noticed that the window size in the original Swin-T is 7, whereas in agent_swin it is 56. I am curious about your design choices for the window size and the attention type used at each stage of Agent-Swin-T/S/B. Are there any guiding principles behind these decisions?
Hello, author. If my data has not only h and w but also a d dimension, have you considered this case? Please contact me at [email protected]
When I try to run agentsd with Stable Diffusion v2.1, it generates black images.
I added the agentsd folder to the repository root and added the following lines of code
import agentsd
# i is the index of the current DDIM sampling step
if i == 0:
    agentsd.remove_patch(self.model)
    agentsd.apply_patch(self.model, sx=4, sy=4, ratio=0.4, agent_ratio=0.95, attn_precision="fp32")
elif i == 20:
    agentsd.remove_patch(self.model)
    agentsd.apply_patch(self.model, sx=2, sy=2, ratio=0.4, agent_ratio=0.5, attn_precision="fp32")
to ldm/models/diffusion/ddim.py after line 152, as per the instructions. Without these lines, it works fine.
Thank you for your excellent work. Where in the code is the A from your paper?
I looked in the code below but could not find anything corresponding to A:
class AgentAttention in agent_transformer/models/agent_swin.py
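In the paper, A denotes the agent tokens. In the AgentAttention code that a user quotes further down this page, A corresponds to agent_tokens. It is computed by applying nn.AdaptiveAvgPool2d to Q and is then used in two softmax attentions (agent_attn and q_attn). The numpy sketch below reproduces that data flow in simplified form: a single head, no position biases, and no depthwise-convolution term. It is an illustration only, not the repository's implementation.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def agent_attention(q, k, v, h, w, pool):
    """Simplified single-head agent attention (no biases, no DWC term).

    q, k, v: (h*w, d) tokens on an h x w grid. A is obtained by average-pooling
    Q to a pool x pool grid, so agent_num = pool**2. Assumes h, w divisible by pool.
    """
    d = q.shape[-1]
    scale = d ** -0.5
    # A: mean of Q over non-overlapping (h/pool) x (w/pool) patches
    a = q.reshape(pool, h // pool, pool, w // pool, d).mean(axis=(1, 3)).reshape(pool * pool, d)
    # Agent aggregation: every agent attends to all keys and gathers values
    agent_v = softmax((a * scale) @ k.T) @ v      # (agent_num, d)
    # Agent broadcast: every query attends to the agents
    return softmax((q * scale) @ a.T) @ agent_v   # (h*w, d)
```

Both attention maps are only agent_num wide in one dimension, which is where the linear cost in h*w comes from.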
Hello, I greatly appreciate your awesome work.
It seems that self.attn1 in https://github.com/LeapLabTHU/Agent-Attention/blob/master/agentsd/patch.py#L220 is replaced with AgentAttention, whose forward function has the signature forward(self, x, agent=None, context=None, mask=None). However, in https://github.com/LeapLabTHU/Agent-Attention/blob/master/agentsd/patch.py#L220, encoder_hidden_states and attention_mask are passed to self.attn1. This causes the error forward() got an unexpected keyword argument 'encoder_hidden_states'.
Do you have a solution? Thanks a lot.
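One possible workaround, not an official fix, is to wrap AgentAttention so it accepts the caller's keyword names and forwards them under the names its forward() expects. The class name AgentAttentionKwargsAdapter is hypothetical. The mapping (encoder_hidden_states to context, attention_mask to mask) is an assumption; in particular, it is not confirmed that AgentAttention's mask uses the same format as the caller's attention_mask.

The sketch is plain Python to stay dependency-free. In a real model it should subclass torch.nn.Module, because PyTorch refuses to assign a non-Module object to an attribute such as attn1 that already holds a submodule.

```python
class AgentAttentionKwargsAdapter:
    """Hypothetical wrapper mapping caller keyword names onto AgentAttention's."""

    def __init__(self, attn):
        # attn: callable with forward(x, agent=None, context=None, mask=None)
        self.attn = attn

    def __call__(self, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, agent=None, **unused_kwargs):
        # Rename the two known keywords; any other keyword arguments are dropped.
        return self.attn(hidden_states, agent=agent,
                         context=encoder_hidden_states, mask=attention_mask)
```

For self-attention (attn1) the caller usually passes encoder_hidden_states=None, so context stays None.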
The model's input size is (1, 4, 128, 128, 128): batch size 1, 4 channels, and 128 for each of h, w and d. I read Appendix A on the agent bias, where you say each position offset is composed of three parameters, B1 = B′1c + B′1r + B′1b: a column bias B1c ∈ R^(n×1×w), a row bias B1r ∈ R^(n×h×1) and a block bias B1b ∈ R^(n×h0×w0). Now that there is an extra dimension d, how should I modify the agent bias? Should I add a new parameter, such as B1d alongside B1c, B1r and B1b? Could you give me some advice? Looking forward to your reply.
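One possible 3D extension, which the paper does not describe, adds a depth term B1d of shape n×d×1×1. The row and column terms gain a singleton depth axis, and the block bias becomes a coarse n×d0×h0×w0 grid upsampled to d×h×w. The function name, the block sizes and the use of random values in place of learnable parameters are all assumptions.

The 2D code on this page upsamples the block bias bilinearly. In PyTorch the 3D counterpart would be nn.functional.interpolate(..., mode='trilinear') on a 5D tensor; the sketch uses nearest-neighbour repetition to stay in numpy.

Note the scale involved. At full resolution (128³ = 2,097,152 tokens), a dense bias for 49 agents already has about 1.03e8 entries per head, so a windowed or downsampled grid may be needed.

```python
import numpy as np


def agent_bias_3d(n, d, h, w, block=(2, 2, 2), seed=0):
    """Hypothetical 3D agent bias: one (d*h*w)-long bias row per agent.

    Sums four terms that broadcast to (n, d, h, w):
      depth bias  (n, d, 1, 1), the new term for the extra axis
      row bias    (n, 1, h, 1)
      column bias (n, 1, 1, w)
      block bias  (n, d0, h0, w0), upsampled to (d, h, w)
    Random values stand in for what would be learnable parameters.
    """
    rng = np.random.default_rng(seed)
    d0, h0, w0 = block
    depth_b = rng.normal(0.0, 0.02, size=(n, d, 1, 1))
    row_b = rng.normal(0.0, 0.02, size=(n, 1, h, 1))
    col_b = rng.normal(0.0, 0.02, size=(n, 1, 1, w))
    block_b = rng.normal(0.0, 0.02, size=(n, d0, h0, w0))
    # Nearest-neighbour upsampling; assumes d, h, w are multiples of d0, h0, w0.
    up = (block_b.repeat(d // d0, axis=1)
                 .repeat(h // h0, axis=2)
                 .repeat(w // w0, axis=3))
    bias = depth_b + row_b + col_b + up  # broadcasts to (n, d, h, w)
    return bias.reshape(n, d * h * w)
```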
Thank you for your inspiring work. I'm eager to apply ideas like agent attention to train prevailing auto-regressive models like GPT. However, using pooling on Q to get the A matrix will cause information leakage when training auto-regressive models using teacher forcing. I haven't found related discussions in your paper. Is there any straightforward extension or variation of agent attention to adapt it to auto-regressive models?
sys.platform: linux
Python: 3.7.13 (default, Mar 29 2022, 02:18:16) [GCC 7.5.0]
CUDA available: True
GPU 0,1,2,3: NVIDIA TITAN X (Pascal)
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 10.1, V10.1.24
GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
PyTorch: 1.12.1
PyTorch compiling details: PyTorch built with:
2023-12-31 21:14:43,557 - mmdet - INFO - Distributed training: False
2023-12-31 21:14:44,477 - mmdet - INFO - Config:
model = dict(
type='RetinaNet',
backbone=dict(
type='AgentPVT',
img_size=224,
patch_size=4,
in_chans=3,
num_classes=6,
embed_dims=[64, 128, 320, 512],
num_heads=[1, 2, 5, 8],
mlp_ratios=[8, 8, 4, 4],
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.2,
depths=[3, 4, 18, 3],
sr_ratios=[8, 4, 2, 1],
agent_sr_ratios='1111',
num_stages=4,
agent_num=[9, 16, 49, 49],
downstream_agent_shapes=[(12, 12), (16, 16), (28, 28), (28, 28)],
kernel_size=3,
attn_type='AAAA',
scale=-0.5,
init_cfg=dict(type='Pretrained', checkpoint=None)),
neck=dict(
type='FPN',
in_channels=[64, 128, 320, 512],
out_channels=256,
start_level=1,
add_extra_convs='on_input',
num_outs=5),
bbox_head=dict(
type='RetinaHead',
num_classes=6,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
octave_base_scale=4,
scales_per_octave=3,
ratios=[0.5, 1.0, 2.0],
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0.0, 0.0, 0.0, 0.0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
train_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))
dataset_type = 'CocoDataset'
data_root = '/home/z/zky/Cosistentteacher/ConsistentTeacher-main/data/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(
type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]
data = dict(
train_dataloader=dict(
samples_per_gpu=2, workers_per_gpu=10, pin_memory=True),
train=dict(
type='CocoDataset',
ann_file=
'/home/z/zky/Cosistentteacher/ConsistentTeacher-main/data/annotations/instances_train2017.json',
img_prefix=
'/home/z/zky/Cosistentteacher/ConsistentTeacher-main/data/train2017/',
pipeline=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(
type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]),
val=dict(
type='CocoDataset',
ann_file=
'/home/z/zky/Cosistentteacher/ConsistentTeacher-main/data/annotations/instances_val2017.json',
img_prefix=
'/home/z/zky/Cosistentteacher/ConsistentTeacher-main/data/val2017/',
pipeline=[
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]),
test=dict(
type='CocoDataset',
ann_file=
'/home/z/zky/Cosistentteacher/ConsistentTeacher-main/data/annotations/instances_val2017.json',
img_prefix=
'/home/z/zky/Cosistentteacher/ConsistentTeacher-main/data/val2017/',
pipeline=[
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]))
evaluation = dict(interval=1, metric='bbox')
optimizer = dict(type='AdamW', lr=0.0001, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[8, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)
checkpoint_config = dict(interval=1)
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
custom_hooks = [dict(type='NumClassCheckHook')]
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
opencv_num_threads = 0
mp_start_method = 'fork'
auto_scale_lr = dict(enable=False, base_batch_size=16)
pretrained = None
lr = 0.0001
work_dir = './work_dirs/agent_pvt_m_rtn_1x_12-16-28-28'
auto_resume = False
gpu_ids = [0]
2023-12-31 21:14:44,477 - mmdet - INFO - Set random seed to 2100000934, deterministic: False
Traceback (most recent call last):
  File "/home/z/anaconda3/envs/agent_detection/lib/python3.7/site-packages/mmcv/utils/registry.py", line 69, in build_from_cfg
    return obj_cls(**args)
  File "/home/z/zky/Agent-Attention-master/downstream/detection/mmdet/models/backbones/agent_pvt.py", line 300, in __init__
    for j in range(depths[i])])
  File "/home/z/zky/Agent-Attention-master/downstream/detection/mmdet/models/backbones/agent_pvt.py", line 300, in <listcomp>
    for j in range(depths[i])])
  File "/home/z/zky/Agent-Attention-master/downstream/detection/mmdet/models/backbones/agent_pvt.py", line 217, in __init__
    agent_num=agent_num, downstream_agent_shape=downstream_agent_shape, kernel_size=kernel_size, scale=scale)
  File "/home/z/zky/Agent-Attention-master/downstream/detection/mmdet/models/backbones/agent_pvt.py", line 129, in __init__
    print('Agent Attention sr{} v{} n{} k{} scale{} reso{}'.format(sr_ratio, agent_num, kernel_size, scale, window_size))
IndexError: tuple index out of range
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/home/z/anaconda3/envs/agent_detection/lib/python3.7/site-packages/mmcv/utils/registry.py", line 69, in build_from_cfg
    return obj_cls(**args)
  File "/home/z/zky/Agent-Attention-master/downstream/detection/mmdet/models/detectors/retinanet.py", line 19, in __init__
    test_cfg, pretrained, init_cfg)
  File "/home/z/zky/Agent-Attention-master/downstream/detection/mmdet/models/detectors/single_stage.py", line 32, in __init__
    self.backbone = build_backbone(backbone)
  File "/home/z/zky/Agent-Attention-master/downstream/detection/mmdet/models/builder.py", line 20, in build_backbone
    return BACKBONES.build(cfg)
  File "/home/z/anaconda3/envs/agent_detection/lib/python3.7/site-packages/mmcv/utils/registry.py", line 237, in build
    return self.build_func(*args, **kwargs, registry=self)
  File "/home/z/anaconda3/envs/agent_detection/lib/python3.7/site-packages/mmcv/cnn/builder.py", line 27, in build_model_from_cfg
    return build_from_cfg(cfg, registry, default_args)
  File "/home/z/anaconda3/envs/agent_detection/lib/python3.7/site-packages/mmcv/utils/registry.py", line 72, in build_from_cfg
    raise type(e)(f'{obj_cls.__name__}: {e}')
IndexError: AgentPVT: tuple index out of range
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/home/z/zky/Agent-Attention-master/downstream/detection/tools/debug_train.py", line 244, in <module>
    main()
  File "/home/z/zky/Agent-Attention-master/downstream/detection/tools/debug_train.py", line 215, in main
    test_cfg=cfg.get('test_cfg'))
  File "/home/z/zky/Agent-Attention-master/downstream/detection/mmdet/models/builder.py", line 59, in build_detector
    cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
  File "/home/z/anaconda3/envs/agent_detection/lib/python3.7/site-packages/mmcv/utils/registry.py", line 237, in build
    return self.build_func(*args, **kwargs, registry=self)
  File "/home/z/anaconda3/envs/agent_detection/lib/python3.7/site-packages/mmcv/cnn/builder.py", line 27, in build_model_from_cfg
    return build_from_cfg(cfg, registry, default_args)
  File "/home/z/anaconda3/envs/agent_detection/lib/python3.7/site-packages/mmcv/utils/registry.py", line 72, in build_from_cfg
    raise type(e)(f'{obj_cls.__name__}: {e}')
IndexError: RetinaNet: AgentPVT: tuple index out of range
Process finished with exit code 1
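Judging only from the line shown in the traceback, the format string at agent_pvt.py line 129 has six {} placeholders (sr, v, n, k, scale, reso) but receives only five arguments. str.format raises IndexError when a positional placeholder has no matching argument. The snippet reproduces this with illustrative values; the helper names are hypothetical.

Which value the v field was meant to show cannot be determined from the traceback. The fix, adding the missing argument or dropping that placeholder, is therefore left open.

```python
LOG_FORMAT = 'Agent Attention sr{} v{} n{} k{} scale{} reso{}'  # copied from the traceback


def format_agent_log(*args):
    """Apply the format string from the traceback to arbitrary arguments."""
    return LOG_FORMAT.format(*args)


def fails_with_index_error(*args):
    """True if formatting raises IndexError (too few arguments for the placeholders)."""
    try:
        format_agent_log(*args)
    except IndexError:
        return True
    return False
```

For example, fails_with_index_error(1, 9, 3, -0.5, (56, 56)) is True for five arguments. Supplying six values formats normally.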
Hi, I would like to use your model in a DDPM or DDIM pipeline. Is that possible?
Could you tell me which code or files are essential to add to them?
Thank you for your outstanding work! I wanted to inquire if you are familiar with the concept of anchored stripe attention discussed in the paper titled "Efficient and Explicit Modelling of Image Hierarchies for Image Restoration." It appears that there are striking similarities between these two attention mechanisms. Could you elucidate the key distinctions between them?
Thanks for your excellent work. I noticed that AgentSD performs better than ToMeSD and Stable Diffusion v1.5. Could you provide the code for calculating FID, Time (s/im), and Memory (GB/im)?
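The authors' evaluation scripts are not shown on this page. As a rough illustration of the Time (s/im) metric only, a timing helper could look like the sketch below. Here generate is a hypothetical callable that produces one image per prompt.

On a GPU, pass sync=torch.cuda.synchronize so queued kernels finish before the clock is read. For peak memory, PyTorch provides torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated().

```python
import time


def seconds_per_image(generate, prompts, warmup=1, sync=None):
    """Average wall-clock seconds per image for generate(prompt), one image per call.

    generate: hypothetical callable wrapping a sampling pipeline.
    sync: optional zero-argument callable run before each clock read,
          e.g. torch.cuda.synchronize when work runs asynchronously on a GPU.
    """
    timed = prompts[warmup:]
    if not timed:
        raise ValueError("need more prompts than warm-up runs")
    for prompt in prompts[:warmup]:
        generate(prompt)  # warm-up calls are not timed
    if sync is not None:
        sync()
    start = time.perf_counter()
    for prompt in timed:
        generate(prompt)
    if sync is not None:
        sync()
    return (time.perf_counter() - start) / len(timed)
```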
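Hello, and thank you very much for sharing such interesting work.
While reproducing your work, I found a small issue in agent_pvt.py. If self.sr_ratio > 1, then after line 134 the downsampled k and v should have 1/self.sr_ratio² as many tokens as q. Yet lines 144-146 reshape q, k and v with the same dimensions, so the reshape of k and v may be wrong. I noticed that you use sr_ratio=sr_ratios[i] if attn_type[i] == 'B' else int(agent_sr_ratios[i]) to force sr_ratio=1 for agent attention. How should the case sr_ratio ≠ 1 be handled? Or does agent attention not support downsampling k and v?
Thank you for your reply.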
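A shape-only numpy sketch of the point raised above (not the repository's code). If each tensor is split into heads using its own token count, K can have h*w/sr² tokens while Q keeps h*w, because only the agent-to-key product touches the reduced length. Any agent-to-key bias sized for h*w keys, and the depthwise convolution that reshapes V to h×w, would still need adapting.

reduce_tokens uses average pooling as a hypothetical stand-in for PVT's strided-convolution reduction. The function names are illustrative.

```python
import numpy as np


def split_heads(x, num_heads):
    """(tokens, c) -> (num_heads, tokens, head_dim), using x's OWN token count."""
    tokens, c = x.shape
    return x.reshape(tokens, num_heads, c // num_heads).transpose(1, 0, 2)


def reduce_tokens(x, h, w, sr):
    """Hypothetical sr x sr average pooling of (h*w, c) tokens to (h*w/sr**2, c).

    Assumes h and w are divisible by sr.
    """
    c = x.shape[-1]
    return x.reshape(h // sr, sr, w // sr, sr, c).mean(axis=(1, 3)).reshape(-1, c)


def agent_and_query_scores(q, k, agent_tokens, h, w, sr, num_heads):
    """Unscaled score matrices when keys are spatially reduced.

    Returns agent-to-key scores (heads, agent_num, h*w/sr**2) and
    query-to-agent scores (heads, h*w, agent_num).
    """
    q_h = split_heads(q, num_heads)                           # (heads, h*w, hd)
    k_h = split_heads(reduce_tokens(k, h, w, sr), num_heads)  # (heads, h*w/sr**2, hd)
    a_h = split_heads(agent_tokens, num_heads)                # (heads, agent_num, hd)
    return a_h @ k_h.transpose(0, 2, 1), q_h @ a_h.transpose(0, 2, 1)
```

Reshaping the reduced K with Q's token count, as a shared reshape would, fails because the element counts no longer match.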
Hi, my Stable Diffusion network is built with the TensorFlow framework, using this code: https://github.com/keras-team/keras-cv/blob/master/keras_cv/models/stable_diffusion/stable_diffusion.py. I want to add your agent attention to this network, but it reports an error when I do. This is apparently because your example, although also a Stable Diffusion model, is built with PyTorch. How can I add it?
Hello, I have several questions after reading your code.
1. What does agent_num mean? I did not find a clear definition in the paper.
2. When defining the biases, why do their dimensions depend on agent_num? If I change agent_num, I get a dimension mismatch. However, your paper includes comparisons with different values of agent_num.
I hope to receive your response.
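agent_num is the number of agent tokens, i.e. n in the paper's notation (the number of rows of A). In the AgentAttention code a user quotes further down this page, it fixes both the pooled grid and every bias shape. Changing it therefore changes the parameter shapes, so weights trained with one agent_num cannot be loaded into a model built with another without reshaping or re-initialising those biases.

The sketch below lists the shapes that code implies; the helper name is hypothetical. It also makes explicit that agent_num must be a perfect square, since forward() reshapes the pool_size × pool_size map into agent_num tokens.

```python
def agent_shapes(agent_num, num_heads, window_size=(7, 7)):
    """Pool and bias shapes implied by the quoted AgentAttention.__init__."""
    pool_size = int(agent_num ** 0.5)  # same expression as in the quoted code
    if pool_size * pool_size != agent_num:
        # forward() reshapes the pooled pool_size x pool_size map into agent_num tokens
        raise ValueError(f"agent_num={agent_num} must be a perfect square")
    wh, ww = window_size
    return {
        "pool_output": (pool_size, pool_size),
        "an_bias": (num_heads, agent_num, 7, 7),
        "na_bias": (num_heads, agent_num, 7, 7),
        "ah_bias": (1, num_heads, agent_num, wh, 1),
        "aw_bias": (1, num_heads, agent_num, 1, ww),
        "ha_bias": (1, num_heads, wh, 1, agent_num),
        "wa_bias": (1, num_heads, 1, ww, agent_num),
    }
```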
I managed to eke out only one 768x1024 image before it started demanding 9 GB of RAM when saving the image.
Thanks for the information.
I have an additional question.
What do remove_patch and apply_patch do? And what do sx, sy, and ratio mean?
Is agentsd your model?
agentsd.remove_patch(self.model)
agentsd.apply_patch(model, sx=4, sy=4, ratio=0.4, agent_ratio=0.95)
Actually, I would like to apply your agent attention module to DDIM from the guided diffusion model.
Thanks,
jungmin
Hello,
I've been exploring your work on Cross Transformer, and I'm intrigued by the potential of integrating Agent Attention into this architecture. Agent Attention, as a method to balance computational efficiency and representation power, seems like it could complement the Cross Transformer's design quite well.
I'm particularly interested in understanding how Agent Attention might be incorporated into the Cross Transformer framework. Specifically, my questions are:
Any insights or suggestions you could provide would be greatly appreciated. I believe such an integration could offer a promising direction for further research and application.
Thank you for your time and for the impactful work you've shared with the community.
Best regards,
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms --precision full
stable-diffusion-v1-4
Error:
  File "scripts/txt2img.py", line 353, in <module>
    main()
  File "scripts/txt2img.py", line 303, in main
    samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
  File "/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/code/stable-diffusion-main/ldm/models/diffusion/plms.py", line 97, in sample
    samples, intermediates = self.plms_sampling(conditioning, size,
  File "/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/code/stable-diffusion-main/ldm/models/diffusion/plms.py", line 159, in plms_sampling
    outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
  File "/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/code/stable-diffusion-main/ldm/models/diffusion/plms.py", line 225, in p_sample_plms
    e_t = get_model_output(x, t)
  File "/code/stable-diffusion-main/ldm/models/diffusion/plms.py", line 192, in get_model_output
    e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
  File "/code/stable-diffusion-main/ldm/models/diffusion/ddpm.py", line 987, in apply_model
    x_recon = self.model(x_noisy, t, **cond)
  File "/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/code/stable-diffusion-main/ldm/models/diffusion/ddpm.py", line 1410, in forward
    out = self.diffusion_model(x, t, context=cc)
  File "/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1128, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/code/stable-diffusion-main/ldm/modules/diffusionmodules/openaimodel.py", line 731, in forward
    h = module(h, emb, context)
  File "/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/code/stable-diffusion-main/ldm/modules/diffusionmodules/openaimodel.py", line 85, in forward
    x = layer(x, context)
  File "/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/code/stable-diffusion-main/ldm/modules/attention.py", line 258, in forward
    x = block(x, context=context)
  File "/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/code/stable-diffusion-main/ldm/modules/attention.py", line 209, in forward
    return checkpoint(self.forward, (x, context), self.parameters(), self.checkpoint)
  File "/code/stable-diffusion-main/ldm/modules/diffusionmodules/util.py", line 114, in checkpoint
    return CheckpointFunction.apply(func, len(inputs), *args)
  File "/code/stable-diffusion-main/ldm/modules/diffusionmodules/util.py", line 127, in forward
    output_tensors = ctx.run_function(*ctx.input_tensors)
  File "/code/stable-diffusion-main/ldm/models/diffusion/agentsd/patch.py", line 66, in forward
    feature, agent = m_a(y)
  File "/code/stable-diffusion-main/ldm/models/diffusion/agentsd/merge.py", line 118, in merge
    dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
TypeError: scatter_reduce(): argument 'reduce' (position 3) must be str, not Tensor
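The error message points to a likely cause: the form Tensor.scatter_reduce(dim, index, src, reduce) used in merge.py arrived in PyTorch 1.12. In PyTorch 1.11 the signature was scatter_reduce(dim, index, reduce, *, output_size=None), so the src tensor lands in the third slot, reduce. That matches "argument 'reduce' (position 3) must be str, not Tensor".

A small version check, with a hypothetical helper name, that could be run as has_scatter_reduce_src(torch.__version__):

```python
def has_scatter_reduce_src(torch_version: str) -> bool:
    """True for PyTorch >= 1.12, where scatter_reduce(dim, index, src, reduce) exists.

    Expects a string such as torch.__version__, e.g. "1.11.0" or "2.0.1+cu117".
    """
    release = torch_version.split("+")[0]  # drop a local build tag like "+cu117"
    major, minor = (int(part) for part in release.split(".")[:2])
    return (major, minor) >= (1, 12)
```

If it returns False, upgrading PyTorch to 1.12 or later is the first thing to try.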
Hi, could you share the part of your code that connects to guided diffusion? I would like to adapt your code to guided diffusion rather than Stable Diffusion. I think this should be possible, right? Could you help me map q, k, v and the agent tokens onto guided diffusion?
Thanks,
Thank you for your excellent work. Sorry to bother you with one more question. I replaced the attention in Swin Transformer with your AgentAttention:
import torch
import torch.nn as nn
from mmcv.cnn.utils.weight_init import trunc_normal_  # assumed import source for trunc_normal_
class AgentAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): Height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        shift_size (int, optional): Shift size of the window. Default: 0
        agent_num (int, optional): Number of agent tokens; must be a perfect square. Default: 49
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
                 shift_size=0, agent_num=49, **kwargs):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5  # honour qk_scale as documented above
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)
        self.shift_size = shift_size
        self.agent_num = agent_num
        self.dwc = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=(3, 3), padding=1, groups=dim)
        self.an_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))
        self.na_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))
        self.ah_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, window_size[0], 1))
        self.aw_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, 1, window_size[1]))
        self.ha_bias = nn.Parameter(torch.zeros(1, num_heads, window_size[0], 1, agent_num))
        self.wa_bias = nn.Parameter(torch.zeros(1, num_heads, 1, window_size[1], agent_num))
        trunc_normal_(self.an_bias, std=.02)
        trunc_normal_(self.na_bias, std=.02)
        trunc_normal_(self.ah_bias, std=.02)
        trunc_normal_(self.aw_bias, std=.02)
        trunc_normal_(self.ha_bias, std=.02)
        trunc_normal_(self.wa_bias, std=.02)
        pool_size = int(agent_num ** 0.5)
        self.pool = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size))
    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
                  (note: mask is accepted but not used below)
        """
        b, n, c = x.shape
        h = int(n ** 0.5)  # assumes square windows
        w = int(n ** 0.5)
        num_heads = self.num_heads
        head_dim = c // num_heads
        qkv = self.qkv(x).reshape(b, n, 3, c).permute(2, 0, 1, 3)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        # q, k, v: b, n, c
        # agent tokens: q average-pooled to agent_num tokens per window
        agent_tokens = self.pool(q.reshape(b, h, w, c).permute(0, 3, 1, 2)).reshape(b, c, -1).permute(0, 2, 1)
        q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        k = k.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        v = v.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        agent_tokens = agent_tokens.reshape(b, self.agent_num, num_heads, head_dim).permute(0, 2, 1, 3)
        # agent aggregation: agents attend to the keys and gather values
        position_bias1 = nn.functional.interpolate(self.an_bias, size=self.window_size, mode='bilinear')
        position_bias1 = position_bias1.reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)
        position_bias2 = (self.ah_bias + self.aw_bias).reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)
        position_bias = position_bias1 + position_bias2
        agent_attn = self.softmax((agent_tokens * self.scale) @ k.transpose(-2, -1) + position_bias)
        agent_attn = self.attn_drop(agent_attn)
        agent_v = agent_attn @ v
        # agent broadcast: queries attend to the agents
        agent_bias1 = nn.functional.interpolate(self.na_bias, size=self.window_size, mode='bilinear')
        agent_bias1 = agent_bias1.reshape(1, num_heads, self.agent_num, -1).permute(0, 1, 3, 2).repeat(b, 1, 1, 1)
        agent_bias2 = (self.ha_bias + self.wa_bias).reshape(1, num_heads, -1, self.agent_num).repeat(b, 1, 1, 1)
        agent_bias = agent_bias1 + agent_bias2
        q_attn = self.softmax((q * self.scale) @ agent_tokens.transpose(-2, -1) + agent_bias)
        q_attn = self.attn_drop(q_attn)
        x = q_attn @ agent_v
        x = x.transpose(1, 2).reshape(b, n, c)
        # depthwise convolution on v, added to the attention output
        v = v.transpose(1, 2).reshape(b, h, w, c).permute(0, 3, 1, 2)
        x = x + self.dwc(v).permute(0, 2, 3, 1).reshape(b, n, c)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
I used the pretrained weights: /home/class1/work/modify/G/checkpoints/swin_tiny_patch4_window7_224.pth
https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth
The dataset is in COCO format. After switching to your AgentAttention, the following error occurred:
python-BaseException
Traceback (most recent call last):
  File "/home/class1/.pycharm_helpers/pydev/pydevd.py", line 1491, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/home/class1/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/class1/work/modify/G/tools/train.py", line 276, in <module>
    main()
  File "/home/class1/work/modify/G/tools/train.py", line 239, in main
    model.init_weights()
  File "/home/class1/.conda/envs/mm100/lib/python3.7/site-packages/mmcv/runner/base_module.py", line 117, in init_weights
    m.init_weights()
  File "/home/class1/work/modify/G/mmdet/models/backbones/swin_test.py", line 1296, in init_weights
    table_current = self.state_dict()[table_key]
KeyError: 'stages.0.blocks.0.attn.w_msa.relative_position_bias_table'
Do you know how to fix this? Thank you!
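The traceback suggests that init_weights in swin_test.py takes the relative_position_bias_table keys from the pretrained checkpoint and looks each one up in self.state_dict(). With AgentAttention in place the model no longer has those tables, so the lookup raises KeyError.

A minimal sketch, assuming the loop in init_weights can iterate over a filtered key list instead (the helper name is hypothetical): keep only keys the current model also has. The Swin checkpoint contains no weights for the new agent biases or the depthwise convolution, so those keep their fresh initialisation either way.

```python
def interpolatable_table_keys(pretrained_state, model_state):
    """relative_position_bias_table keys present in BOTH state dicts.

    After W-MSA is replaced by AgentAttention, the model no longer owns these
    tables, so checkpoint-only keys are skipped instead of raising KeyError.
    """
    return [k for k in pretrained_state
            if 'relative_position_bias_table' in k and k in model_state]
```

Inside init_weights this would be called as interpolatable_table_keys(state_dict, self.state_dict()), with self.state_dict() computed once rather than on every iteration.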
/home/algorithms/Agent-Attention/downstream/mmcv/mmcv/cnn/bricks/transformer.py:33: UserWarning: Fail to import MultiScaleDeformableAttention from mmcv.ops.multi_scale_deform_attn, You should install mmcv-full if you need this module.
  warnings.warn('Fail to import MultiScaleDeformableAttention from '
Traceback (most recent call last):
  File "tools/test.py", line 17, in <module>
    from mmseg.apis import multi_gpu_test, single_gpu_test
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/apis/__init__.py", line 2, in <module>
    from .inference import inference_segmentor, init_segmentor, show_result_pyplot
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/apis/inference.py", line 9, in <module>
    from mmseg.models import build_segmentor
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/models/__init__.py", line 2, in <module>
    from .backbones import *  # noqa: F401,F403
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/models/backbones/__init__.py", line 7, in <module>
    from .fast_scnn import FastSCNN
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/models/backbones/fast_scnn.py", line 7, in <module>
    from mmseg.models.decode_heads.psp_head import PPM
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/models/decode_heads/__init__.py", line 2, in <module>
    from .ann_head import ANNHead
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/models/decode_heads/ann_head.py", line 8, in <module>
    from .decode_head import BaseDecodeHead
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/models/decode_heads/decode_head.py", line 12, in <module>
    from ..losses import accuracy
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/models/losses/__init__.py", line 6, in <module>
    from .focal_loss import FocalLoss
  File "/home/algorithms/Agent-Attention/downstream/segmentation/mmseg/models/losses/focal_loss.py", line 6, in <module>
    from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
  File "/home/algorithms/Agent-Attention/downstream/mmcv/mmcv/ops/__init__.py", line 2, in <module>
    from .active_rotated_filter import active_rotated_filter
  File "/home/algorithms/Agent-Attention/downstream/mmcv/mmcv/ops/active_rotated_filter.py", line 10, in <module>
    ext_module = ext_loader.load_ext(
  File "/home/algorithms/Agent-Attention/downstream/mmcv/mmcv/utils/ext_loader.py", line 13, in load_ext
    ext = importlib.import_module('mmcv.' + name)
  File "/root/miniconda3/envs/agent_segmentation/lib/python3.8/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
ModuleNotFoundError: No module named 'mmcv._ext'
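The last line says the compiled extension mmcv._ext cannot be found, and the warning at the top already points to mmcv-full. The mmcv tree under downstream/mmcv is imported from source, so one likely explanation is that its compiled ops were never built. For mmcv 1.x that is typically done by running MMCV_WITH_OPS=1 pip install -e . inside that directory, though the exact steps depend on the mmcv version.

A small check, with a hypothetical helper name, of whether the extension is importable in the current environment:

```python
import importlib.util


def has_compiled_ops(package="mmcv"):
    """True if `package` is importable and ships its compiled `_ext` module."""
    try:
        # find_spec imports the parent package, then looks for the submodule
        return importlib.util.find_spec(f"{package}._ext") is not None
    except ModuleNotFoundError:
        return False  # the package itself is not installed
```

Running has_compiled_ops() in the agent_segmentation environment should return True once the ops are built.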
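Hello, author:
I had the pleasure of reading this paper today, and I think it has great potential. I am currently working on image fusion and hope you can help with a few questions.
1. In the line agent_tokens = self.pool(q[:, 1:, :].reshape(b, h, w, c).permute(0, 3, 1, 2)).reshape(b, c, -1).permute(0, 2, 1), Q is sliced. After slicing, can it still be reshaped to (b, h, w, c)?
2. A bias is added to the agent features, but in position_bias = torch.cat([self.ac_bias.repeat(b, 1, 1, 1), position_bias], dim=-1) the offsets are concatenated. Wouldn't that produce an extra dimension?
3. If the bias is omitted, will the results differ much?
4. On what basis is the bias divided into its parts?
Looking forward to your reply.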
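The slice q[:, 1:, :] suggests that the first token is a class token, so n = 1 + h·w. Assuming that, the shapes in questions 1 and 2 work out:

- After dropping that token, the remaining h·w tokens do reshape to (b, h, w, c).
- torch.cat along dim=-1 lengthens the last axis by one, adding a bias column for the class-token key, rather than adding a new dimension.

The numpy sketch below checks both shape computations. The shape (1, num_heads, agent_num, 1) for ac_bias is an assumption inferred from the concatenation, not taken from the repository, and the function names are hypothetical.

```python
import numpy as np


def pool_agents_skip_cls(q, h, w, pool):
    """q: (b, 1 + h*w, c), token 0 assumed to be a class token.

    Drops token 0, then average-pools the h x w patch grid to pool x pool agents
    (assumes h and w are divisible by pool). Returns (b, pool*pool, c).
    """
    b, n, c = q.shape
    if n != 1 + h * w:
        raise ValueError(f"expected 1 + {h}*{w} tokens, got {n}")
    patches = q[:, 1:, :].reshape(b, pool, h // pool, pool, w // pool, c)
    return patches.mean(axis=(2, 4)).reshape(b, pool * pool, c)


def bias_with_cls_column(ac_bias, position_bias):
    """ac_bias: (1, heads, agents, 1); position_bias: (b, heads, agents, h*w).

    Mirrors torch.cat([ac_bias.repeat(b, 1, 1, 1), position_bias], dim=-1):
    the result is (b, heads, agents, 1 + h*w), one bias per key token.
    """
    b = position_bias.shape[0]
    return np.concatenate([np.repeat(ac_bias, b, axis=0), position_bias], axis=-1)
```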
May I ask how this can be applied to a Siamese tracking network, where the input images have different sizes and the backbone shares weights between them? I noticed that the agent_num and window parameters depend on the input image size. How can I set them so that they work for different input sizes at the same time?
Are there pretrained weights available for download?