有一个问题想请教一下：原始 CenterNet 的输出有三个分支，分别是 heatmap（W*H*C）、offset（W*H*2）和 size（W*H*2），而你这里新增了一个 seg_feat 分支。这个分支是怎么添加的，能介绍一下吗？能否告知具体在代码的哪一处？这里 seg_feat 的尺寸是什么样的？又是如何为每个中心点分配一个 mask 的？难道是像 offset 和 size 一样，直接预测一个 W*H*W*H 的 seg_feat 吗？
def forward(self, seg_feat, conv_weight, mask, ind, target):
    """Dice loss of per-object masks predicted by dynamically generated convs.

    For each object, a parameter vector is gathered from ``conv_weight`` at
    the object's flattened spatial index ``ind``. The vector is split into
    the weights and biases of three 1x1 convolutions. Those convolutions are
    run as one grouped convolution (``groups=num_obj``) over ``seg_feat``
    concatenated with two relative-coordinate channels. The sigmoid output
    is compared with ``target`` through ``dice_loss``.

    Args:
        seg_feat: shared segmentation features. Dims -2/-1 are (h, w), and
            ``seg_feat[i]`` must have ``self.feat_channel`` channels for the
            grouped convs below to line up. Presumably
            (B, feat_channel, h, w); confirm against the producing head.
        conv_weight: per-location conv-parameter map. It is gathered at
            ``ind`` via ``_tranpose_and_gather_feat``. Each gathered vector
            (last dim) must have length
            (C + 2) * C + C + C * C + C + C + 1, with C = ``self.feat_channel``.
        mask: indexed as ``mask[i, :num_obj]``. It multiplies both the
            prediction and the target, so zero entries remove that object
            from the loss.
        ind: integer flattened indices ``y * w + x`` into the h*w grid,
            indexed as ``ind[i, :num_obj]``.
        target: ground-truth masks. ``target[i].size(0)`` gives the number
            of objects used for image ``i``.

    Returns:
        The per-image ``dice_loss`` summed over the batch and divided by the
        batch size.
    """
    num_ch = self.feat_channel
    batch_size = seg_feat.size(0)
    h, w = seg_feat.size(-2), seg_feat.size(-1)
    weight = _tranpose_and_gather_feat(conv_weight, ind)

    # Floor division is required: ``ind / w`` returns a float tensor in
    # PyTorch >= 1.6, which gives fractional row indices. ``ind`` is a
    # non-negative index, so floor and truncation agree.
    x, y = ind % w, ind // w
    x_range = torch.arange(w).float().to(device=seg_feat.device)
    y_range = torch.arange(h).float().to(device=seg_feat.device)
    y_grid, x_grid = torch.meshgrid([y_range, x_range])

    # Layout of each per-object parameter vector:
    # conv1 weight/bias, conv2 weight/bias, conv3 weight/bias.
    split_sizes = [(num_ch + 2) * num_ch, num_ch,
                   num_ch ** 2, num_ch,
                   num_ch, 1]

    mask_loss = 0.
    for i in range(batch_size):
        num_obj = target[i].size(0)
        if num_obj == 0:
            # A grouped conv with groups=0 is invalid; nothing to supervise.
            continue

        conv1w, conv1b, conv2w, conv2b, conv3w, conv3b = torch.split(
            weight[i, :num_obj], split_sizes, dim=-1)

        # Object centres shaped (num_obj, 1, 1, 1) to broadcast against the
        # (1, 1, h, w) grid, giving (num_obj, 1, h, w) offset maps.
        y_ctr = y[i, :num_obj].float().reshape(-1, 1, 1, 1)
        x_ctr = x[i, :num_obj].float().reshape(-1, 1, 1, 1)
        # Offsets are scaled by the fixed constant 128 (kept from the
        # original code; its origin is not visible here).
        y_rel_coord = (y_grid[None, None] - y_ctr) / 128.
        x_rel_coord = (x_grid[None, None] - x_ctr) / 128.

        # Give each object its own copy of the features plus coordinates,
        # then fold objects into channels. A single grouped conv
        # (groups=num_obj) then applies every object's own filters.
        feat = seg_feat[i][None].repeat([num_obj, 1, 1, 1])
        feat = torch.cat([feat, x_rel_coord, y_rel_coord], dim=1).view(1, -1, h, w)

        feat = F.conv2d(feat,
                        conv1w.contiguous().view(-1, num_ch + 2, 1, 1),
                        conv1b.contiguous().flatten(),
                        groups=num_obj).relu()
        feat = F.conv2d(feat,
                        conv2w.contiguous().view(-1, num_ch, 1, 1),
                        conv2b.contiguous().flatten(),
                        groups=num_obj).relu()
        # Use squeeze(0), not squeeze(): only drop the leading batch dim, so
        # the result is always (num_obj, h, w), even when num_obj, h or w is 1.
        feat = F.conv2d(feat,
                        conv3w.contiguous().view(-1, num_ch, 1, 1),
                        conv3b.contiguous().flatten(),
                        groups=num_obj).sigmoid().squeeze(0)

        true_mask = mask[i, :num_obj, None, None].float()
        mask_loss += dice_loss(feat * true_mask, target[i] * true_mask)
    return mask_loss / batch_size