[SAM] Code Review

This post walks through the SAM (Segment Anything Model) code.

Github : https://github.com/facebookresearch/segment-anything/tree/main

Image Encoder - Generating the image embedding

  • Converts the input image into an image embedding
  • Input x : image, shape [bs, 3, 1024, 1024]

 

[1] Build the patch map

x = self.patch_embed(x)             # [b, h, w, c]

 

  • PatchEmbed() : image -> patch map
def forward(self, x: torch.Tensor) -> torch.Tensor:
    x = self.proj(x)  # Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    # B C H W -> B H W C
    x = x.permute(0, 2, 3, 1)
    return x
  • Produces a patch map in which each position corresponds to one 16 x 16 patch of the image
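As a quick check of the patch-map shape, the grid size follows from the standard convolution output-size formula. This is a plain-Python sketch; `conv_out_size` is a helper name introduced here for illustration and is not part of the SAM code.

```python
def conv_out_size(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output length of a convolution along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# Conv2d(3, 768, kernel_size=16, stride=16) applied to a 1024 x 1024 image
grid = conv_out_size(1024, kernel=16, stride=16)
patch_map_shape = ("bs", grid, grid, 768)  # after permute(0, 2, 3, 1): [b, h, w, c]
print(patch_map_shape)  # ('bs', 64, 64, 768)
```

Because the kernel size equals the stride, the patches do not overlap and the grid is simply 1024 / 16 = 64 per side.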

[2] Pass through the ViT encoder

if self.pos_embed is not None:      # add the positional embedding
    x = x + self.pos_embed

for blk in self.blocks:
    x = blk(x)

x = self.neck(x.permute(0, 3, 1, 2))    # [b, c, h, w]
  • self.blocks : the ViT encoder blocks

Prompt Encoder

  • Maps prompts to embeddings
  • Role of a prompt : specifies which object in the image to segment
forward : performs the actual prompt embedding
  • points : point coordinates (x, y) and labels (0: background, 1: foreground), shapes [bs, nP, 2] and [bs, nP]
  • boxes : boxes, shape [bs, nB, 4]
  • masks : masks
1. Compute the sparse embedding - points, boxes
# Start from an empty embedding
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())         # [bs, 0, 256]

# (1) Embed the points
if points is not None:
    coords, labels = points
    point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))              # [bs, nPoints (+1 pad point when there is no box), 256]
    sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)

# (2) Embed the boxes
if boxes is not None:
    box_embeddings = self._embed_boxes(boxes)                                               # 2 corner tokens per box
    sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)               # [bs, nPoints + 2, 256] for one box per batch entry

  - Points are embedded by calling self._embed_points
  - Boxes are embedded by calling self._embed_boxes
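The number of sparse tokens depends on which prompts are present, so it helps to count them explicitly. A minimal sketch of that bookkeeping, assuming at most one box per batch entry; `num_sparse_tokens` is a hypothetical helper, and `n_points == 0` stands for `points is None`.

```python
def num_sparse_tokens(n_points: int, has_box: bool) -> int:
    """Number of sparse tokens per batch entry, following the forward logic above."""
    tokens = 0
    if n_points > 0:
        tokens += n_points
        if not has_box:
            tokens += 1  # the pad point is added only when no box is given
    if has_box:
        tokens += 2  # a box is embedded as its two corner points
    return tokens


print(num_sparse_tokens(3, has_box=False))  # 4
print(num_sparse_tokens(3, has_box=True))   # 5
print(num_sparse_tokens(0, has_box=True))   # 2
```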

2. Compute the dense embedding - mask
if masks is not None:
    dense_embeddings = self._embed_masks(masks)
else:
    # No mask prompt: broadcast the learned "no mask" embedding over the whole grid
    dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
        bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
    )

 - Masks are embedded by calling self._embed_masks
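The no-mask branch is just a broadcast of a single learned vector over every grid position. The sketch below mimics the reshape/expand with NumPy (used here as a light stand-in for torch); the random `no_mask_weight` is a placeholder for `self.no_mask_embed.weight`, and the sizes are assumed from the shapes used in this post.

```python
import numpy as np

bs, embed_dim = 2, 256
grid_h, grid_w = 64, 64  # image_embedding_size

# Stand-in for no_mask_embed.weight, which has shape [1, embed_dim]
no_mask_weight = np.random.randn(1, embed_dim).astype(np.float32)

dense = np.broadcast_to(
    no_mask_weight.reshape(1, -1, 1, 1), (bs, embed_dim, grid_h, grid_w)
)
print(dense.shape)  # (2, 256, 64, 64)

# Every batch entry and spatial location holds the same embedding vector
print(np.array_equal(dense[0, :, 0, 0], dense[1, :, 63, 63]))  # True
```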

 

Sparse prompt : Points, Boxes, Text

  • Points embedding
def _embed_points(  # position of each point + a learned embedding that distinguishes foreground from background
    self,
    points: torch.Tensor,
    labels: torch.Tensor,
    pad: bool,      # pad=True when there are no boxes, pad=False when boxes are given
) -> torch.Tensor:
    """Embeds point prompts."""
    points = points + 0.5  # Shift to center of pixel

    # Without boxes -> append one pad point (coord: [0, 0], label: -1)
    if pad:
        padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)  # [bs, 1, 2]
        padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)     # [bs, 1]
        points = torch.cat([points, padding_point], dim=1)
        labels = torch.cat([labels, padding_label], dim=1)

    # Positional encoding of the coordinates
    point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)  # [bs, nP (+1 if padded), embed_dim]
    # Add the learned embedding that matches each label
    point_embedding[labels == -1] = 0.0                                                 # pad point: drop the positional encoding
    point_embedding[labels == -1] += self.not_a_point_embed.weight
    point_embedding[labels == 0] += self.point_embeddings[0].weight                     # 0 (background): add the negative-point embedding
    point_embedding[labels == 1] += self.point_embeddings[1].weight                     # 1 (foreground): add the positive-point embedding
    return point_embedding
self.pe_layer.forward_with_coords - maps coordinates to positional encodings
def forward_with_coords(
    self, coords_input: torch.Tensor, image_size: Tuple[int, int]
) -> torch.Tensor:
    """Positionally encode points that are not normalized to [0,1]."""
    coords = coords_input.clone()
    coords[:, :, 0] = coords[:, :, 0] / image_size[1]  # x / width
    coords[:, :, 1] = coords[:, :, 1] / image_size[0]  # y / height
    return self._pe_encoding(coords.to(torch.float))  # B x N x C
  • RETURN : [bs, nP + $\alpha$, embed_dim], where $\alpha$ is 1 when a pad point was added and 0 otherwise
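The body of `_pe_encoding` is not shown in this post. The sketch below assumes it is a random Fourier-feature encoding (project the normalized coordinates with a fixed Gaussian matrix, then take sine and cosine), written in NumPy as a stand-in for torch. `gaussian_matrix`, `num_pos_feats` and the sample points are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
num_pos_feats = 128  # output dimension is 2 * num_pos_feats = 256
gaussian_matrix = rng.standard_normal((2, num_pos_feats))  # stand-in for the layer's random projection


def pe_encoding(coords: np.ndarray) -> np.ndarray:
    """Encode coordinates in [0, 1] with random Fourier features."""
    coords = 2.0 * coords - 1.0          # [0, 1] -> [-1, 1]
    coords = coords @ gaussian_matrix    # [..., 2] -> [..., num_pos_feats]
    coords = 2.0 * np.pi * coords
    return np.concatenate([np.sin(coords), np.cos(coords)], axis=-1)


def forward_with_coords(points: np.ndarray, image_size: tuple) -> np.ndarray:
    """Normalize pixel coordinates (x, y) by (width, height), then encode."""
    h, w = image_size
    normalized = points.astype(np.float64).copy()
    normalized[..., 0] /= w
    normalized[..., 1] /= h
    return pe_encoding(normalized)


points = np.array([[[512.5, 256.5], [100.5, 900.5]]])  # [bs=1, nP=2, 2]
pe = forward_with_coords(points, (1024, 1024))
print(pe.shape)  # (1, 2, 256)
```

The label-specific embeddings (`not_a_point_embed`, `point_embeddings[0]`, `point_embeddings[1]`) are then added on top of this purely positional part.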
  • Boxes embedding
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
    """Embeds box prompts."""
    boxes = boxes + 0.5  # Shift to center of pixel

    coords = boxes.reshape(-1, 2, 2)                                  # each box -> 2 corner points (top-left, bottom-right)
    corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
    corner_embedding[:, 0, :] += self.point_embeddings[2].weight      # top-left corner embedding
    corner_embedding[:, 1, :] += self.point_embeddings[3].weight      # bottom-right corner embedding
    return corner_embedding                                           # [bs * nB(1), 2, 256]
  • RETURN : [bs * nB(1), 2, embed_dim]
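The `reshape(-1, 2, 2)` step turns each box into two points, which is why a box behaves like a pair of specially labelled point prompts. A small plain-Python sketch, assuming boxes are given as (x1, y1, x2, y2); `box_to_corners` is a hypothetical helper and the box values are made up.

```python
def box_to_corners(box):
    """Split an (x1, y1, x2, y2) box into its two corner points."""
    x1, y1, x2, y2 = box
    return [(x1, y1), (x2, y2)]


boxes = [(10.0, 20.0, 200.0, 300.0), (0.0, 0.0, 50.0, 60.0)]  # bs = 2, one box each
# Shift to pixel centers first, as _embed_boxes does, then split into corners
corners = [box_to_corners(tuple(v + 0.5 for v in b)) for b in boxes]
print(corners[0])  # [(10.5, 20.5), (200.5, 300.5)]
print(len(corners), len(corners[0]))  # 2 2
```

Each corner is then positionally encoded like a point, and `point_embeddings[2]` / `point_embeddings[3]` mark which corner it is.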

Dense prompt : Mask

  • Down-samples the mask
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:    # conv layers down-sample the mask by 4x (two stride-2 convs)
    """Embeds mask inputs."""
    mask_embedding = self.mask_downscaling(masks)           # down-sampling
    return mask_embedding                                   # [bs, 1, H, W] -> [bs, 256 (embed_dim), H/4, W/4]

down-sampling layer

Sequential(
  (0): Conv2d(1, 4, kernel_size=(2, 2), stride=(2, 2))
  (1): LayerNorm2d()
  (2): GELU(approximate='none')
  (3): Conv2d(4, 16, kernel_size=(2, 2), stride=(2, 2))
  (4): LayerNorm2d()
  (5): GELU(approximate='none')
  (6): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))
)

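⇒ Produces a dense embedding of shape [bs, embed_dim, h/16, w/16], where h and w are the input image size (for a 1024 x 1024 image the mask prompt is 256 x 256, so the 4x down-sampling gives 64 x 64)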
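To see where the 1/16 figure comes from, the spatial size can be traced through the three Conv2d layers listed above. A plain-Python sketch; the 256 x 256 mask prompt size is an assumption (4x the 64 x 64 image-embedding grid).

```python
# (in_channels, out_channels, kernel, stride) of the three Conv2d layers listed above
convs = [(1, 4, 2, 2), (4, 16, 2, 2), (16, 256, 1, 1)]

mask_size = 256      # assumed low-resolution mask prompt: [bs, 1, 256, 256]
image_size = 1024    # input image side length used earlier in this post

size = mask_size
for _, _, kernel, stride in convs:
    size = (size - kernel) // stride + 1  # unpadded convolution

print(size)                 # 64
print(mask_size // size)    # 4   (down-sampling relative to the mask)
print(image_size // size)   # 16  (down-sampling relative to the image)
```

So the layers themselves shrink the mask by 4x, and the result lines up with the 64 x 64 image embedding so the two can be added element-wise.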

 

 


Mask Decoder : Generating masks

MaskDecoder

  • init
  • Create the output token embeddings - IoU Token & Mask Tokens
self.iou_token = nn.Embedding(1, transformer_dim)                               # (1, 256)                       - one 256-dim embedding
self.num_mask_tokens = num_multimask_outputs + 1                                # 3 (whole, part, subpart) + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)          # (4 (number of masks), 256)     - four 256-dim embeddings
  • Define the Transformer - the two-way transformer used by the decoder
  • Define the upscaling network and the MLPs
    • Upscaling of the image embedding : self.output_upscaling
    • MLPs for the mask tokens : self.output_hypernetworks_mlps
    • MLP for the IoU token : self.iou_prediction_head
  • forward : generates the masks
def forward(
    self,
    image_embeddings: torch.Tensor,
    image_pe: torch.Tensor,
    sparse_prompt_embeddings: torch.Tensor,
    dense_prompt_embeddings: torch.Tensor,
    multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # (1) Run the prediction
    masks, iou_pred = self.predict_masks(
        image_embeddings=image_embeddings,
        image_pe=image_pe,
        sparse_prompt_embeddings=sparse_prompt_embeddings,
        dense_prompt_embeddings=dense_prompt_embeddings,
    )

    # (2) Select the correct mask or masks for output
    if multimask_output:    # ambiguity-aware: return several masks (whole, part, subpart)
        mask_slice = slice(1, None)     # -> [bs, 3 (masks), 256, 256]
    else:                   # return only the output of the first mask token (the single-mask output)
        mask_slice = slice(0, 1)        # -> [bs, 1, 256, 256]
    masks = masks[:, mask_slice, :, :]
    iou_pred = iou_pred[:, mask_slice]

    # Prepare output
    return masks, iou_pred
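  1. Predict the masks and IoU scores - calls self.predict_masks()
  2. multi mask
    • True) return several masks (whole, part, subpart) to handle ambiguous prompts
    • False) return a single mask, the output of mask token 0 (it is selected by index, not by the highest IoU score)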
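The slicing in forward is easy to misread, so here is a plain-Python sketch of what each branch keeps. The token names and IoU values are made up for illustration; `select` is a hypothetical helper that applies the same `slice` objects to a list.

```python
mask_names = ["single", "whole", "part", "subpart"]  # illustrative labels for the 4 mask tokens
iou_pred = [0.91, 0.88, 0.95, 0.70]                  # made-up scores, one per mask token


def select(items, multimask_output: bool):
    """Apply the same slice that MaskDecoder.forward applies along the mask axis."""
    mask_slice = slice(1, None) if multimask_output else slice(0, 1)
    return items[mask_slice]


print(select(mask_names, True))    # ['whole', 'part', 'subpart']
print(select(mask_names, False))   # ['single']
print(select(iou_pred, False))     # [0.91]  (index 0, even though 0.95 is higher)
```

Picking the best of the three multimask outputs by predicted IoU, if wanted, is something the caller does with the returned `iou_pred`.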
  • predict_masks : predicts the masks and IoU scores

[1] Prepare the tokens : Output Tokens + Sparse prompt embeddings

  1. output_tokens = IoU Token + Mask Tokens
    • IoU Token : [bs, 1, 256]
    • Mask Tokens : [bs, 4 (number of masks), 256]
  2. Sparse prompt embeddings : [bs, nSparse, 256] (nSparse = nP + 1 for points only, nP + 2 with one box)
    • Tokens = IoU Token + Mask Tokens + sparse embeddings : [bs, 5 + nSparse, 256]
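The token preparation can be sketched as two concatenations: first the learned output tokens, then the sparse prompt embeddings. NumPy is used as a stand-in for torch, and the constant arrays are placeholders for the learned weights so the token order is easy to check.

```python
import numpy as np

bs, dim, n_sparse = 2, 256, 3
iou_token = np.zeros((1, dim))        # stand-in for self.iou_token.weight
mask_tokens = np.ones((4, dim))       # stand-in for self.mask_tokens.weight
sparse = np.full((bs, n_sparse, dim), 2.0)   # stand-in for the sparse prompt embeddings

output_tokens = np.concatenate([iou_token, mask_tokens], axis=0)     # [5, 256]
output_tokens = np.broadcast_to(output_tokens, (bs, 5, dim))         # [bs, 5, 256]
tokens = np.concatenate([output_tokens, sparse], axis=1)             # [bs, 5 + n_sparse, 256]

print(tokens.shape)  # (2, 8, 256)
```

The order matters later: index 0 is read back as the IoU token and indices 1 to 4 as the mask tokens.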

[2] Transformer

  • Computes attention between the tokens and the image embedding
hs, src = self.transformer(src, pos_src, tokens)

 ⇒ Returns the updated tokens (hs) and the updated image embedding (src)

[3] Generate the masks

(1) Upsample the image embedding

src = src.transpose(1, 2).view(b, c, h, w)        # back to a 2D map              : [bs, 256, 64, 64]
upscaled_embedding = self.output_upscaling(src)   # upsample the image embedding  : [bs, 32, 256, 256]
  • Reshape the image embedding into a 2D map → pass it through the upscaling layers
  • [bs, embed_dim, h/16, w/16] → [bs, embed_dim/8, h/4, w/4]
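The layers inside `output_upscaling` are not listed in this post. The sketch below assumes two stride-2 transposed convolutions with kernel size 2 (256 → 64 → 32 channels), which is consistent with the [bs, 256, 64, 64] → [bs, 32, 256, 256] shapes above; `conv_transpose_out_size` is a hypothetical helper.

```python
def conv_transpose_out_size(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output length of a transposed convolution along one axis (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


# Assumed layers: (out_channels, kernel, stride)
layers = [(64, 2, 2), (32, 2, 2)]

size, channels = 64, 256
for out_channels, kernel, stride in layers:
    size = conv_transpose_out_size(size, kernel, stride)
    channels = out_channels

print(channels, size, size)  # 32 256 256
```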

(2) Pass the mask tokens through their MLPs

mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]                                      # pick the mask tokens out of the tokens: [bs, 4 (num_masks), 256]

# Each mask token goes through its own MLP
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
    hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))           # [bs, 256] -> [bs, 32], once per mask token
hyper_in = torch.stack(hyper_in_list, dim=1)                                                    # [bs, 4 (n_mask_token), 32 (dim)]
  1. Extract the mask tokens from the Transformer output - mask_tokens_out
  2. Pass each mask token through its own MLP - self.output_hypernetworks_mlps[i]

(3) Predict the masks - using the mask tokens

b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)                     # dot product per mask : [bs, 4 (n_mask_token), 256*256] => [bs, 4 (n_mask_token), 256, 256]
  • mask_tokens_after_mlp ([bs, 4, 32]) @ flattened upscaled_embedding ([bs, 32, 256 * 256])
        ⇒ one mask logit map per mask token ([bs, 4, 256 * 256])
  • Final masks : [bs, 4 (num_mask_tokens), 256, 256]
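The batched matrix product amounts to one dot product per (mask token, pixel) pair. A small NumPy sketch (stand-in for torch, with random placeholder tensors and a reduced 8 x 8 map instead of 256 x 256) that checks this against a manual dot product:

```python
import numpy as np

rng = np.random.default_rng(0)
bs, n_tokens, c, h, w = 2, 4, 32, 8, 8   # small h, w to keep the example light

hyper_in = rng.standard_normal((bs, n_tokens, c))          # mask tokens after their MLPs
upscaled_embedding = rng.standard_normal((bs, c, h, w))    # upsampled image embedding

masks = (hyper_in @ upscaled_embedding.reshape(bs, c, h * w)).reshape(bs, -1, h, w)
print(masks.shape)  # (2, 4, 8, 8)

# Each mask pixel is the dot product of one token vector with that pixel's feature vector
manual = np.dot(hyper_in[0, 1], upscaled_embedding[0, :, 3, 5])
print(np.isclose(masks[0, 1, 3, 5], manual))  # True
```

In other words, each mask token acts as the weights of a 1 x 1 classifier over the upsampled feature map.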

[4] Generate the IoU scores

  • Role of the IoU scores : predict the quality of each mask
  • One IoU score is predicted per mask token
iou_token_out = hs[:, 0, :]                         # pick the IoU token out of the tokens: [bs, 256]
iou_pred = self.iou_prediction_head(iou_token_out)  # [bs, 4]
  • [bs, trans_dim] → [bs, num_mask_tokens]
  • iou_prediction_head is an MLP

TwoWayTransformer - the Transformer used by the decoder

[1] Run the Transformer blocks (2 of them)

# B x C x H x W -> B x (HW) x C == B x N_image_tokens x C
bs, c, h, w = image_embedding.shape
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)   # image embedding   - [bs, 256, 64, 64] -> [bs, 64*64, 256]
image_pe = image_pe.flatten(2).permute(0, 2, 1)                 # image PE          - [bs, 256, 64, 64] -> [bs, 64*64, 256]

# Prepare queries
queries = point_embedding   # [bs,  7,      256]  (e.g. 5 output tokens + 2 sparse tokens)
keys = image_embedding      # [bs,  64*64,  256]

# Apply transformer blocks and final layernorm - pass through the 2 transformer blocks
for layer in self.layers:
    queries, keys = layer(
        queries=queries,
        keys=keys,
        query_pe=point_embedding,
        key_pe=image_pe,
    )
  • self.layers : TwoWayAttentionBlock

[2] Final attention layer - token to image attn

q = queries + point_embedding
k = keys + image_pe
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm_final_attn(queries)     # norm(attn(token, img) + token)
  • query : tokens, key/value : image embedding
  • self.final_attn_token_to_image : Attention

TwoWayAttentionBlock - the block that actually performs the cross attention

[1] self attn. : self-attention over the tokens

if self.skip_first_layer_pe:
    queries = self.self_attn(q=queries, k=queries, v=queries) # attn(tokens, tokens)
else:
    q = queries + query_pe
    attn_out = self.self_attn(q=q, k=q, v=queries)
    queries = queries + attn_out

queries = self.norm1(queries)
  • query/key/value  : tokens

[2] token to image attn.

# Cross attention block, tokens attending to image embedding
q = queries + query_pe                                              # (2) token to image attn.
k = keys + key_pe                                                   # lets the tokens pick up information from the image
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)         # q: tokens, k, v: image embedding
queries = queries + attn_out
queries = self.norm2(queries)
  • query : tokens, key/value : image embedding
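To make the q/k/v roles concrete, here is a simplified NumPy sketch of the token-to-image step. It is not SAM's Attention module: it assumes a single head with no learned projections and a LayerNorm without learned scale and shift, and all tensors are random placeholders with tiny sizes.

```python
import numpy as np


def attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Single-head scaled dot-product attention over the last two axes."""
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])   # [bs, Nq, Nk]
    scores = scores - scores.max(axis=-1, keepdims=True)         # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)      # softmax over the keys
    return weights @ v                                           # [bs, Nq, C]


def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """LayerNorm over the channel axis, without learned scale and shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
bs, n_tokens, n_pixels, dim = 1, 7, 16, 8   # tiny sizes; SAM uses 64*64 pixels and 256 channels
queries = rng.standard_normal((bs, n_tokens, dim))
keys = rng.standard_normal((bs, n_pixels, dim))
query_pe = rng.standard_normal((bs, n_tokens, dim))
key_pe = rng.standard_normal((bs, n_pixels, dim))

# Positions are added to q and k, but the values stay position-free
q = queries + query_pe
k = keys + key_pe
queries = layer_norm(queries + attention(q=q, k=k, v=keys))
print(queries.shape)  # (1, 7, 8)
```

The image-to-token step uses the same pattern with the roles swapped: the image embedding supplies the queries and the tokens supply the keys and values.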

[3] MLP : updates each token along its channel dimension

# MLP block
mlp_out = self.mlp(queries)                                         # (3) mlp: updates each token along its channel dimension
queries = queries + mlp_out
queries = self.norm3(queries)
  • Structure of the MLP (self.mlp)
    MLPBlock(
      (lin1): Linear(in_features=256, out_features=2048, bias=True)
      (lin2): Linear(in_features=2048, out_features=256, bias=True)
      (act): ReLU()
    )
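Unlike attention, this MLP mixes channels within one token and never mixes tokens with each other. A minimal NumPy sketch of the 256 → 2048 → 256 block with the residual connection; the randomly initialized weights are placeholders, not SAM's parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, hidden = 256, 2048
w1, b1 = rng.standard_normal((dim, hidden)) * 0.02, np.zeros(hidden)   # lin1
w2, b2 = rng.standard_normal((hidden, dim)) * 0.02, np.zeros(dim)      # lin2


def mlp(x: np.ndarray) -> np.ndarray:
    """lin2(ReLU(lin1(x))), applied along the last (channel) axis."""
    return np.maximum(x @ w1 + b1, 0.0) @ w2 + b2


tokens = rng.standard_normal((1, 7, dim))
out = tokens + mlp(tokens)   # residual connection; norm3 would follow
print(out.shape)  # (1, 7, 256)

# Running a single token through the MLP gives the same result as its row in the batch
print(np.allclose(mlp(tokens)[0, 3], mlp(tokens[0, 3])))  # True
```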

[4] image to token attn.

# Cross attention block, image embedding attending to tokens
q = queries + query_pe                                              # (4) image to token attn.
k = keys + key_pe                                                   # lets the image embedding pick up information from the tokens
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)      # q: image embedding, k, v: tokens
keys = keys + attn_out
keys = self.norm4(keys)
  • query : image embedding, key/value : tokens