A few-shot segmentation method that uses ResNet as the encoder and UNetV3 as the decoder. UNetV3 comes from Pytorch-UNet-2; ResNet comes from torchvision.models.
The method is similar to *Integrative Few-Shot Learning for Classification and Segmentation*:
- Feed the query image, support image, and support mask to the model.
- Extract the layer 1 (L1), layer 2 (L2), layer 3 (L3), and layer 4 (L4) feature maps from ResNet for both the query (Q) and support (S) images.
- Compute the cosine similarities between pixels of the support and query feature maps (L1, L2, L3, L4). Example: take one pixel location from L1 of Q and compute the cosine similarity between that pixel and every pixel of L1 of S; repeat the process for all pixels in Q.
- The resulting cosine similarity matrix C has shape (Batch Size, Height*Width, Height, Width), i.e. one channel per support pixel. There are four such matrices in total, one each for L1, L2, L3, and L4.
- Concatenate C1, C2, C3, and C4 along the channel dimension to obtain Cc.
- Flatten the support mask and multiply it with Cc. This suppresses the effect of unwanted pixels (support pixels outside the mask area).
- Apply a 2D 1x1 convolution to Cc to reduce the Height*Width feature maps to a predetermined number N of feature maps (output: B, N, H, W).
- Feed this feature map to the UNet.
- Take the generated mask and compute BCELoss between it and the ground-truth mask of the query image.