SAM+RS: ClassWise-SAM-Adapter: Fine-tuning SAM for Semantic Segmentation in the SAR Domain

ClassWise-SAM-Adapter (CWSAM) adapts SAM to land-cover classification on spaceborne synthetic aperture radar (SAR) imagery. The proposed CWSAM freezes most of SAM's parameters and performs parameter-efficient fine-tuning with lightweight adapters, and it adds a classwise mask decoder to carry out the semantic segmentation task. This adaptation enables effective land-cover classification of SAR images while balancing accuracy and computational cost. In addition, a task-specific input module injects low-frequency information of the SAR image through MLP-based layers, further improving performance. Extensive experiments show that CWSAM outperforms conventional state-of-the-art semantic segmentation algorithms while using fewer computational resources, highlighting the potential of foundation models such as SAM for specific downstream tasks in the SAR domain. Source code: https://github.com/xypu98/CWSAM

1) The proposed CWSAM introduces an end-to-end architecture with lightweight adapters for efficient parameter fine-tuning of a large model, transferring SAM from the natural-image domain to the SAR domain. Building on a visual foundation model, it achieves reliable performance on land-cover classification with SAR imagery.

2) The classwise mask decoder of CWSAM adds classification capability to the originally class-agnostic SAM for fine-grained semantic segmentation downstream tasks. For SAR images, it identifies multiple land-cover categories at the pixel level.

3) A task-specific input module injects low-frequency SAR information, obtained via a 2-D fast Fourier transform of the image, to provide sufficient semantic information about land-cover features and enhance segmentation performance.

Comprehensive experiments show that CWSAM surpasses a range of existing semantic segmentation algorithms while keeping computational cost modest. By accounting for the feature differences between SAR and natural images, it leverages the strengths of a foundation model to improve SAR image segmentation.

Adapted Vision Transformer-based image encoder

The original architecture of SAM's Vision Transformer image encoder is kept unchanged, and all of its original parameters are frozen during training. Inspired by AdaptMLP in AdaptFormer, only a few simple and efficient adapters are inserted into each Vision Transformer block.

Adapters (Fig. 3): each Transformer block consists of two sub-blocks, a multi-head self-attention layer and an MLP layer; adapter modules with newly initialized parameters are injected into both sub-blocks for training.

All adapters share the same structure: a down-projection fully connected layer (Down), a ReLU activation, and an up-projection fully connected layer (Up).
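As a rough illustration of this Down FC → ReLU → Up FC structure, here is a minimal bottleneck-adapter sketch in PyTorch. The bottleneck width, scaling factor, residual placement, and the parameter-freezing snippet are assumptions for illustration, not the exact CWSAM settings.

import torch
from torch import nn


class Adapter(nn.Module):
    """Bottleneck adapter: Down FC -> ReLU -> Up FC with a residual path."""

    def __init__(self, dim: int = 768, bottleneck: int = 64, scale: float = 1.0):
        super().__init__()
        self.down = nn.Linear(dim, bottleneck)  # down-projection
        self.act = nn.ReLU(inplace=True)
        self.up = nn.Linear(bottleneck, dim)    # up-projection back to dim
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only these few parameters are trained; the surrounding ViT block
        # (attention / MLP sub-blocks) stays frozen.
        return x + self.scale * self.up(self.act(self.down(x)))


# Parameter-efficient fine-tuning: freeze everything except the adapters.
# (The name filter "adapter" is illustrative, not the repo's actual naming.)
# for name, p in image_encoder.named_parameters():
#     p.requires_grad = "adapter" in name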

Classwise mask decoder and loss function

To assign one of multiple categories to each pixel, a classwise mask decoder is built on top of SAM's original mask decoder architecture, adding a category prediction head to perform land-cover classification. This work ignores the dynamic multi-mask prediction head and outputs only a single mask, which is used to compute the training loss.
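Since only that single mask is supervised, a minimal sketch of the training step could look as follows, assuming a plain per-pixel cross-entropy over the classwise logits (the paper's actual loss formulation may differ):

import torch
import torch.nn.functional as F


def segmentation_loss(masks: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """masks: (B, 1, num_classes, H, W) from the classwise mask decoder with
    multimask_output=False; labels: (B, H', W') integer land-cover class ids."""
    logits = masks.squeeze(1)  # (B, num_classes, H, W)
    if logits.shape[-2:] != labels.shape[-2:]:
        # the decoder outputs 256x256 masks; resize logits to the label resolution
        logits = F.interpolate(logits, size=labels.shape[-2:],
                               mode="bilinear", align_corners=False)
    return F.cross_entropy(logits, labels)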

Architecture: the framework of the classwise mask decoder is shown in Fig. 4. First, two Two-Way Transformer blocks extract features from all embeddings, including the image embedding and the prompt embeddings. Several attention mechanisms are used: self-attention over the prompt embeddings, and cross-attention between image and prompt embeddings in both directions (image-to-prompt and prompt-to-image). The upscaling convolution block of SAM's mask decoder then prepares the final binary mask output. However, this original design, which only produces binary masks, cannot classify pixels as required for semantic segmentation. Therefore, a lightweight classification convolution module is added to generate classwise masks for all categories.

 

(Left: the decoder in this paper; right: the original SAM decoder. The red parts on the left mark the changes: one added arrow and one added module.)

On top of the output upscaling convolution block of SAM's mask decoder, an additional lightweight upscaling convolution module is added, and all of its randomly initialized parameters are trained from scratch to produce the final multi-class mask prediction. This upscaling module consists of a deconvolution layer, a ReLU activation, and a convolution block; it enlarges the semantic features and generates the final mask prediction with N class channels (see cls_upscaling in the code listing below).

Feature enhancement module: in addition, since pixel-level classification for semantic segmentation requires rich semantic information in the feature maps, a feature enhancement module is designed in the proposed classwise mask decoder. A skip connection between the image embedding and the upscaling convolution blocks, running parallel to the Two-Way Transformer blocks, concatenates the source image embedding with the Two-Way Transformer output fed into the upscaling convolution block, doubling the channel dimension of the upscaled feature map. The pre-prediction mask features are then multiplied (dot product) with the tokens from the MLP layers to produce the predicted multi-class mask, assigning a specific category to every pixel of the image. (Compared with SAM, this essentially amounts to one small added module plus a few extra arrows feeding in the image embedding.)

import math
from typing import List, Tuple, Type

import torch
from torch import nn

# Helper modules: LayerNorm2d and the small MLP head come from SAM's modeling
# code (segment_anything/modeling); trunc_normal_ is timm's truncated-normal
# initializer. The exact import paths may differ in the CWSAM repository.
from segment_anything.modeling.common import LayerNorm2d
from segment_anything.modeling.mask_decoder import MLP
from timm.models.layers import trunc_normal_


class MaskDecoder(nn.Module):
    def __init__(
            self,
            *,
            transformer_dim: int,
            transformer: nn.Module,
            num_multimask_outputs: int = 4,
            activation: Type[nn.Module] = nn.GELU,
            iou_head_depth: int = 3,
            iou_head_hidden_dim: int = 256,
            num_classes: int = 1
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Arguments:
          transformer_dim (int): the channel dimension of the transformer
          transformer (nn.Module): the transformer used to predict masks
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks
          activation (nn.Module): the type of activation to use when
            upscaling masks
          iou_head_depth (int): the depth of the MLP used to predict
            mask quality
          iou_head_hidden_dim (int): the hidden dimension of the MLP
            used to predict mask quality
          num_classes (int): the number of semantic classes predicted
            for each mask (CWSAM classwise output)
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_classes = num_classes

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)  # (1,256)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)  # (4,256)

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLP(
            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
        )

        # self.output_upscaling_adaptor = nn.Sequential(
        #     nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
        #     LayerNorm2d(transformer_dim // 4),
        #     activation(),
        #     nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
        #     activation(),
        # )

        self.cls_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.Conv2d(transformer_dim // 4, transformer_dim * self.num_classes // 8, kernel_size=7, stride=2,
                      padding=3),
            activation(),
        )
        # self.cls_tokens_embeddings = nn.ModuleList(
        #     [nn.Embedding(4,transformer_dim)  #4-num_cls
        #      for i in range(self.num_mask_tokens)
        #      ]
        # )
        # self.cls_tokens = nn.Embedding(4,transformer_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(
            self,
            image_embeddings: torch.Tensor,
            image_pe: torch.Tensor,
            sparse_prompt_embeddings: torch.Tensor,
            dense_prompt_embeddings: torch.Tensor,
            multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
        """
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :, :]  # (B, num_masks, num_classes, H, W)
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred

    def predict_masks(
            self,
            image_embeddings: torch.Tensor,  # (1,256,64,64)
            image_pe: torch.Tensor,  # (1,256,64,64)
            sparse_prompt_embeddings: torch.Tensor,  # (1,0,256)
            dense_prompt_embeddings: torch.Tensor,  # (1,256,64,64)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)  # (5,256)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)  # (1,5,256)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)  # (1,5,256)

        # Expand per-image data in batch direction to be per-mask
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)  # (1,256,64,64)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)  # (1,256,64,64)
        b, c, h, w = src.shape  # 1,256,64,64
        src_feature = src  # keep pre-transformer embedding for the feature-enhancement skip connection

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)  # hs (1,5,256) src (1,4096,256)
        iou_token_out = hs[:, 0, :]  # (1,256)
        mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :]  # (1,4,256)

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)  # (1,256,64,64)
        upscaled_embedding = self.output_upscaling(src)  # (1,32,256,256)
        upscaled_embedding_src = self.output_upscaling(src_feature)  # (1,32,256,256)
        # upscaled_embedding_adaptor = self.output_upscaling_adaptor(src)  #(1,32,256,256)
        hyper_in_list: List[torch.Tensor] = []  # 4* (1,32)
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
        hyper_in = torch.stack(hyper_in_list, dim=1)  # (1,4,32)
        b, c, h, w = upscaled_embedding.shape  # (1,32,256,256)

        ##### add MLP by pxy 230706 ######

        # scr_mlp_module = MLP(256,256,256,3)

        # Feature enhancement: concatenate the upscaled pre-transformer image
        # embedding with the upscaled transformer output (channels doubled).
        upscaled_embedding_concat = torch.cat([upscaled_embedding, upscaled_embedding_src], dim=1)

        # Classwise upscaling expands to num_classes groups of mask channels.
        cls_upscaled_embedding = self.cls_upscaling(upscaled_embedding_concat)

        # cls_upscaled_embedding = self.cls_upscaling(upscaled_embedding)
        masks = (hyper_in @ cls_upscaled_embedding.view(b, c, self.num_classes * h * w)).view(b, self.num_mask_tokens,
                                                                                              -1, h, w)  # (1, 4, num_classes, 256, 256)

        # masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) #(1,4,256,256)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)  # (1,4)

        return masks, iou_pred
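
For reference, a hedged smoke test of the decoder above with dummy tensors (real inputs come from SAM's image and prompt encoders; TwoWayTransformer is taken from the segment_anything package, num_multimask_outputs=3 matches the inline shape comments, and num_classes=5 is an arbitrary example):

import torch
from segment_anything.modeling.transformer import TwoWayTransformer

decoder = MaskDecoder(
    transformer_dim=256,
    transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
    num_multimask_outputs=3,
    num_classes=5,  # e.g. 5 land-cover categories
)

image_embeddings = torch.randn(1, 256, 64, 64)
image_pe = torch.randn(1, 256, 64, 64)
sparse_prompts = torch.zeros(1, 0, 256)     # no point/box prompts
dense_prompts = torch.randn(1, 256, 64, 64)

masks, iou_pred = decoder(image_embeddings, image_pe, sparse_prompts,
                          dense_prompts, multimask_output=False)
print(masks.shape)     # (1, 1, 5, 256, 256): per-class logits for the single mask
print(iou_pred.shape)  # (1, 1)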

Task-specific input of low-frequency SAR features

Although fine-tuning the adapter parameters in the image encoder transfers SAM from natural images to the SAR domain, the differences between natural optical signals and microwave signals continue to affect model performance. Feature extraction in SAM's Vision Transformer image encoder still lacks SAR-domain information, because the frozen parameters lose part of the semantic information. Therefore, to supply additional semantic information of SAR images for the downstream segmentation task, an MLP-based architecture is built in parallel with the ViT image encoder to preserve meaningful low-frequency content of the SAR image, such as surface texture and the pixel brightness that depends on signal reflectivity.

The low-frequency component of the input image is obtained through a 2-D fast Fourier transform, low-pass filtering, and an inverse fast Fourier transform… (details omitted) …
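A minimal sketch of this 2-D FFT → low-pass filter → inverse FFT pipeline in PyTorch might look as follows; the cutoff ratio and the rectangular filter shape are illustrative assumptions, not the paper's exact settings.

import torch


def low_frequency_component(img: torch.Tensor, cutoff_ratio: float = 0.1) -> torch.Tensor:
    """Keep only the low-frequency content of an image tensor (B, C, H, W).

    cutoff_ratio is an illustrative hyperparameter: the half-width of the
    preserved band around the spectrum centre, as a fraction of H and W.
    """
    _, _, H, W = img.shape
    freq = torch.fft.fftshift(torch.fft.fft2(img), dim=(-2, -1))  # centre the spectrum

    # Rectangular low-pass mask around the spectrum centre.
    cy, cx = H // 2, W // 2
    ry, rx = max(1, int(H * cutoff_ratio)), max(1, int(W * cutoff_ratio))
    mask = torch.zeros(1, 1, H, W, device=img.device, dtype=freq.dtype)
    mask[..., cy - ry:cy + ry, cx - rx:cx + rx] = 1.0

    filtered = torch.fft.ifft2(torch.fft.ifftshift(freq * mask, dim=(-2, -1)))
    return filtered.real  # low-frequency image, same shape as the input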

As shown in Fig. 5, the patch embedding feature of the ViT encoder's input image is low-dimensional, and the low-frequency feature is added to it. Multiple MLP blocks then process this feature, one for each Vision Transformer block, and finally a parameter-shared MLP block performs the feature fusion between the task-specific input module and the original SAM image encoder. (The design feels a bit like trial-and-error "alchemy".)
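Reading this description as a parallel branch with one MLP per ViT block plus a shared fusion MLP, a heavily hedged sketch of how it might be wired is given below; the class name, dimensions, and the exact fusion point are assumptions, not the released implementation.

import torch
from torch import nn


class TaskSpecificInput(nn.Module):
    """Hypothetical sketch of the task-specific low-frequency input branch."""

    def __init__(self, embed_dim: int = 768, depth: int = 12, hidden: int = 192):
        super().__init__()
        # one small MLP per Vision Transformer block ...
        self.per_block_mlps = nn.ModuleList([
            nn.Sequential(nn.Linear(embed_dim, hidden), nn.GELU(),
                          nn.Linear(hidden, embed_dim))
            for _ in range(depth)
        ])
        # ... plus a parameter-shared MLP used when fusing into every block
        self.shared_mlp = nn.Sequential(nn.Linear(embed_dim, hidden), nn.GELU(),
                                        nn.Linear(hidden, embed_dim))

    def forward(self, patch_embed: torch.Tensor, low_freq_embed: torch.Tensor):
        # patch_embed, low_freq_embed: (B, num_tokens, embed_dim)
        x = patch_embed + low_freq_embed  # inject low-frequency SAR information
        # per-block features, to be added to the corresponding ViT block input
        return [self.shared_mlp(mlp(x)) for mlp in self.per_block_mlps]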