Computer Vision
[SAM] Code Review
In this post I walk through the SAM (Segment Anything Model) code.
GitHub: https://github.com/facebookresearch/segment-anything/tree/main
Image Encoder - generates the image embedding
- Converts the input image into an image embedding

| input | description | shape |
|---|---|---|
| x | image | [bs, 3, 1024, 1024] |
[1] Build the patch map
x = self.patch_embed(x) # [b, h, w, c]
- PatchEmbed() : image -> patch map
def forward(self, x: torch.Tensor) -> torch.Tensor :
x = self.proj(x) # Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
- Produces a patch map made of 16 x 16 patches (a shape sketch is below)
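A minimal sketch of this step, assuming the ViT-B configuration (16 x 16 patches, 768-dim embedding); it is only a shape check, not the repo code itself.

import torch
import torch.nn as nn

proj = nn.Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))  # same role as self.proj above

x = torch.randn(1, 3, 1024, 1024)        # [bs, 3, 1024, 1024]
patches = proj(x)                        # [bs, 768, 64, 64] - 1024 / 16 = 64 patches per side
patches = patches.permute(0, 2, 3, 1)    # [bs, 64, 64, 768] - B H W C, as in PatchEmbed.forward
print(patches.shape)                     # torch.Size([1, 64, 64, 768])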
[2] Pass through the ViT encoder
if self.pos_embed is not None: # add absolute positional encoding
    x = x + self.pos_embed
for blk in self.blocks:
    x = blk(x)
x = self.neck(x.permute(0, 3, 1, 2)) # [b, c, h, w]
- self.blocks : the ViT encoder blocks; self.neck then projects the output back to [b, c, h, w] (shape sketch below)
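A rough shape sketch of the encoder tail, with a single 1x1 conv standing in for self.neck (the repo's neck also contains a 3x3 conv and LayerNorm2d layers; only the shapes matter here).

import torch
import torch.nn as nn

neck = nn.Conv2d(768, 256, kernel_size=1)      # stand-in: projects the ViT dim 768 down to 256

x = torch.randn(1, 64, 64, 768)                # output of self.blocks, [b, h, w, c]
image_embedding = neck(x.permute(0, 3, 1, 2))  # back to [b, c, h, w] = [1, 256, 64, 64]
print(image_embedding.shape)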
Prompt Encoder
- Maps prompts to embeddings
- Role of a prompt : specify what should be segmented in the image
forward : runs the actual prompt embedding
1. Compute the sparse embeddings - points, boxes
| prompt | description | shape |
|---|---|---|
| points | point coordinates (x, y) & labels (0: background, 1: foreground) | [bs, nP, 2] & [bs, nP] |
| boxes | boxes | [bs, nB, 4] |
| masks | masks | - |
# define an empty sparse embedding
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())  # [bs, 0, 256]
# (1) embed the points
if points is not None:
    coords, labels = points
    point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))  # [bs, (nPoints + 1), 256]
    sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
# (2) embed the boxes
if boxes is not None:
    box_embeddings = self._embed_boxes(boxes)
    sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)  # [bs, nSparse, 256]
- Points are embedded via self._embed_points
- Boxes are embedded via self._embed_boxes
2. Compute the dense embedding - mask
if masks is not None:
    dense_embeddings = self._embed_masks(masks)
else:
    dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
        bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
    )
- Masks are embedded via self._embed_masks (a usage sketch of the full forward pass is below)
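A hedged usage sketch of the whole forward pass for a single foreground click. The constructor values follow the ones used when SAM is built (embed_dim=256, 64 x 64 embedding grid, 1024 x 1024 input); the import path assumes the repo's package layout.

from segment_anything.modeling import PromptEncoder  # assumes the repo's package layout
import torch

prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)

coords = torch.tensor([[[512.0, 512.0]]])   # [bs=1, nP=1, 2] - one click in input-image coordinates
labels = torch.tensor([[1]])                # [bs=1, nP=1]    - 1 = foreground

sparse, dense = prompt_encoder(points=(coords, labels), boxes=None, masks=None)
# sparse: [1, 2, 256]      - the click plus one pad point (added because boxes is None)
# dense : [1, 256, 64, 64] - no_mask_embed broadcast over the embedding grid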
Sparse prompts : points, boxes, text
- Points embedding
def _embed_points( # position of each point + an embedding learned to tell foreground from background
self,
points: torch.Tensor,
labels: torch.Tensor,
pad: bool, # pad=True if there are no boxes, pad=False if boxes are given
) -> torch.Tensor:
"""Embeds point prompts."""
points = points + 0.5 # Shift to center of pixel
# if there are no boxes -> append one pad point (coord: [0, 0], label: -1)
if pad:
    padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) # [bs, 1, 2]
    padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) # [bs, 1]
    points = torch.cat([points, padding_point], dim=1)
    labels = torch.cat([labels, padding_label], dim=1)
# map the coords to positional encodings
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) # [bs, nP + (1 if pad else 0), embed_dim]
# add the label-specific embedding
point_embedding[labels == -1] = 0.0 # pad points: zero out the positional encoding
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight # 0 (background) : add the negative point embedding
point_embedding[labels == 1] += self.point_embeddings[1].weight # 1 (foreground) : add the positive point embedding
return point_embedding
self.pe_layer.forward_with_coords - maps the coords to positional encodings
def forward_with_coords(
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
) -> torch.Tensor:
"""Positionally encode points that are not normalized to [0,1]."""
coords = coords_input.clone()
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
return self._pe_encoding(coords.to(torch.float)) # B x N x C
- RETURN : [bs, nP + α, embed_dim] (α = 1 if a pad point was added, else 0); a sketch of the underlying positional encoding follows.
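For reference, self._pe_encoding is a random-Fourier-feature positional encoding. A standalone sketch of the idea (the random matrix and sizes here are illustrative, not the trained buffers from the repo):

import math
import torch

num_pos_feats = 128                              # embed_dim // 2
gaussian_matrix = torch.randn(2, num_pos_feats)  # fixed random projection of (x, y)

def pe_encoding(coords: torch.Tensor) -> torch.Tensor:
    # coords normalized to [0, 1], shape [..., 2]
    coords = 2 * coords - 1                      # map to [-1, 1]
    coords = coords @ gaussian_matrix            # [..., num_pos_feats]
    coords = 2 * math.pi * coords
    return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)  # [..., 2 * num_pos_feats] = [..., 256]

print(pe_encoding(torch.rand(1, 3, 2)).shape)    # torch.Size([1, 3, 256])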
- Boxes embedding
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel
coords = boxes.reshape(-1, 2, 2)
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
return corner_embedding # [bs*nB, 2, 256]
- RETURN : [bs * nB, 2, embed_dim] (here nB = 1); a tiny corner-reshape check is below.
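The reshape above simply splits each box into its two corner points; a quick check:

import torch

boxes = torch.tensor([[[100.0, 150.0, 400.0, 500.0]]])  # [bs=1, nB=1, 4] = (x1, y1, x2, y2)
corners = (boxes + 0.5).reshape(-1, 2, 2)                # [bs*nB, 2, 2] - top-left and bottom-right
print(corners)
# tensor([[[100.5000, 150.5000],
#          [400.5000, 500.5000]]])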
Dense prompt : mask
- Performs down-sampling (a shape check follows the layer listing below)
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: #* common segmentation trick: down-sample the mask with conv layers to the image-embedding resolution
"""Embeds mask inputs."""
mask_embedding = self.mask_downscaling(masks) # run the down-sampling stack
return mask_embedding # [bs, 1, h, w] -> [bs, 256(embed_dim), h/4, w/4]
down-sampling layer
Sequential(
(0): Conv2d(1, 4, kernel_size=(2, 2), stride=(2, 2))
(1): LayerNorm2d()
(2): GELU(approximate='none')
(3): Conv2d(4, 16, kernel_size=(2, 2), stride=(2, 2))
(4): LayerNorm2d()
(5): GELU(approximate='none')
(6): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))
)
⇒ Produces a dense embedding of shape [bs, embed_dim, 64, 64], matching the image embedding grid (1/16 of the 1024 x 1024 input).
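A quick shape check of the stack above, rebuilt as a stand-in (GroupNorm is used only as a placeholder for the repo's LayerNorm2d, and the 256 x 256 mask input size is assumed from the 4x relation to the 64 x 64 embedding grid):

import torch
import torch.nn as nn

mask_downscaling = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=2, stride=2), nn.GroupNorm(1, 4), nn.GELU(),
    nn.Conv2d(4, 16, kernel_size=2, stride=2), nn.GroupNorm(1, 16), nn.GELU(),
    nn.Conv2d(16, 256, kernel_size=1),
)

masks = torch.randn(1, 1, 256, 256)   # mask prompt, assumed at 4x the embedding resolution
dense = mask_downscaling(masks)       # two stride-2 convs -> 1/4 of the spatial size
print(dense.shape)                    # torch.Size([1, 256, 64, 64])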
Mask Decoder : mask generation
MaskDecoder
- init
- Creates the output token embeddings - IoU token & mask tokens
self.iou_token = nn.Embedding(1, transformer_dim) # (1, 256) - one 256-dim embedding
self.num_mask_tokens = num_multimask_outputs + 1 # 3 (whole, part, subpart) + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) # (4 (num masks), 256) - four 256-dim embeddings
- Defines the transformer - the two-way transformer used inside the decoder
- Defines the output heads
- Upscaling layers for the image embedding : self.output_upscaling
- Per-mask-token MLPs : self.output_hypernetworks_mlps
- MLP for the IoU token : self.iou_prediction_head (a hedged sketch of these heads is below)
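A hedged sketch of what these heads look like for transformer_dim=256 and num_mask_tokens=4; layer counts, norms, and activations here are illustrative, not copied from the repo.

import torch.nn as nn

transformer_dim, num_mask_tokens = 256, 4

# upscales the image embedding 64x64 -> 256x256 while shrinking channels 256 -> 32
output_upscaling = nn.Sequential(
    nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
    nn.GELU(),
    nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
    nn.GELU(),
)

# one small MLP per mask token: 256 -> 32, so it can dot-product with the upscaled embedding
output_hypernetworks_mlps = nn.ModuleList(
    [nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, transformer_dim // 8))
     for _ in range(num_mask_tokens)]
)

# IoU head: one quality score per mask token
iou_prediction_head = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, num_mask_tokens))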
- forward : generates the masks
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
# (1) run the prediction
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
)
# (2) Select the correct mask or masks for output
if multimask_output: # ambiguity-aware: output multiple masks (whole, part, subpart)
    mask_slice = slice(1, None) # -> [bs, 3(masks), 256, 256]
else: # output only the dedicated single-mask prediction (the first mask token)
    mask_slice = slice(0, 1) # -> [bs, 1, 256, 256]
masks = masks[:, mask_slice, :, :]
iou_pred = iou_pred[:, mask_slice]
# Prepare output
return masks, iou_pred
- Predicts masks and IoU scores - calls self.predict_masks()
- multimask_output
- True) multiple masks (whole, part, subpart) for ambiguity awareness
- False) only the dedicated single-mask prediction (index 0) is kept (see the slicing illustration below)
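A tiny illustration of the slicing (shapes assumed: 4 predicted masks at 256 x 256):

import torch

masks = torch.randn(1, 4, 256, 256)
iou_pred = torch.rand(1, 4)

multimask_output = True
mask_slice = slice(1, None) if multimask_output else slice(0, 1)
print(masks[:, mask_slice].shape)     # [1, 3, 256, 256] when True, [1, 1, 256, 256] when False
print(iou_pred[:, mask_slice].shape)  # [1, 3] or [1, 1]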
- predict_masks : predicts the masks and IoU scores
[1] Prepare the tokens : output tokens + sparse prompt embeddings
- output_tokens = IoU token + mask tokens
- IoU token : [bs, 256]
- Mask tokens : [bs, 256] x 4 (num masks) = [bs, 4, 256]
- Sparse prompt embeddings : [bs, nSparse, 256] (nP point embeddings, +1 pad point if no boxes, +2 corner embeddings per box)
- tokens = IoU token + mask tokens + sparse embeddings : [bs, 5 + nSparse, 256] (assembly sketch below)
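A sketch of the assembly with dummy tensors (assumed: embed_dim=256, one IoU token, four mask tokens, nSparse sparse embeddings); the real code uses self.iou_token.weight and self.mask_tokens.weight here.

import torch

bs, nSparse = 1, 2
iou_token_weight = torch.randn(1, 256)      # stands in for self.iou_token.weight
mask_tokens_weight = torch.randn(4, 256)    # stands in for self.mask_tokens.weight
sparse_prompt_embeddings = torch.randn(bs, nSparse, 256)

output_tokens = torch.cat([iou_token_weight, mask_tokens_weight], dim=0)  # [5, 256]
output_tokens = output_tokens.unsqueeze(0).expand(bs, -1, -1)             # [bs, 5, 256]
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)      # [bs, 5 + nSparse, 256]
print(tokens.shape)                                                       # torch.Size([1, 7, 256])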
[2] Transformer
- Runs attention between the tokens and the image embedding
hs, src = self.transformer(src, pos_src, tokens)
⇒ Returns the updated tokens (hs) and image embedding (src)
[3] Generate the masks
(1) Upsample the image embedding
src = src.transpose(1, 2).view(b, c, h, w) # reshape back to a 2D map : [bs, 256, 64, 64]
upscaled_embedding = self.output_upscaling(src) # img embedding upsampling : [bs, 32, 256, 256]
- Reshape the image embedding back into a 2D map → pass it through the upscaling layers
- [bs, embed_dim, h/16, w/16] → [bs, embed_dim/8, h/4, w/4]
(2) Pass the mask tokens through their MLPs
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] # take the mask tokens ([bs, 4(num_masks), 256]) out of the output tokens
# pass each mask token through its own MLP
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
    hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) # [bs, 256] x (n_mask_token) => [bs, 32] x (n_mask_token)
hyper_in = torch.stack(hyper_in_list, dim=1) # [bs, 4(n_mask_token), 32(dim)]
- Extract the mask tokens from the transformer output - mask_tokens_out
- Each mask token goes through its own MLP - self.output_hypernetworks_mlps[i]
(3) Predict the masks - using the mask tokens
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) # dot product per mask : [bs, 4(n_mask_token), 256*256] => [bs, 4(n_mask_token), 256, 256]
- mask_tokens_after_mlp ([bs, 4, 32]) @ flattened upscaled_embedding ([bs, 32, 256 * 256])
⇒ one mask per mask token ([bs, 4, 256 * 256]) → final masks : [bs, 4(num_mask_tokens), 256, 256] (shape check below)
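Shape check of the batched dot product above (assumed: 4 mask tokens, 32 channels, 256 x 256 upscaled embedding):

import torch

b, c, h, w = 1, 32, 256, 256
hyper_in = torch.randn(b, 4, c)              # per-token weight vectors from the MLPs
upscaled_embedding = torch.randn(b, c, h, w)

masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
print(masks.shape)                           # torch.Size([1, 4, 256, 256])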
[4] Generate the IoU scores
- Role of the IoU scores : predict the quality of each mask
- One IoU score is predicted per mask token
iou_token_out = hs[:, 0, :] # take the IoU token ([bs, 256]) out of the output tokens
iou_pred = self.iou_prediction_head(iou_token_out) # [bs, 4]
- [bs, trans_dim] → [bs, num_mask_tokens]
- iou_prediction_head : implemented as an MLP
TwoWayTransformer - the transformer used inside the decoder
[1] Pass through the two transformer blocks
# B x C x H x W -> B x (HW) x C == B x N_image_tokens x C
bs, c, h, w = image_embedding.shape
image_embedding = image_embedding.flatten(2).permute(0, 2, 1) # image embedding - [bs, 256, 64, 64] -> [bs, 64*64, 256]
image_pe = image_pe.flatten(2).permute(0, 2, 1) # image PE - [bs, 256, 64, 64] -> [bs, 64*64, 256]
# Prepare queries
queries = point_embedding # [bs, nTokens, 256] - e.g. 7 = 5 output tokens + 2 sparse prompt embeddings
keys = image_embedding # [bs, 64*64, 256]
#* Apply transformer blocks and final layernorm - pass through the 2 transformer blocks
for layer in self.layers:
    queries, keys = layer(
        queries=queries,
        keys=keys,
        query_pe=point_embedding,
        key_pe=image_pe,
    )
- self.layers : TwoWayAttentionBlock
[2] Final attention layer - token-to-image attention
q = queries + point_embedding
k = keys + image_pe
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm_final_attn(queries) # norm(attn(token, img) + token)
- query : tokens, key/value : image embedding
- self.final_attn_token_to_image : Attention
TwoWayAttentionBlock - the block that actually performs the cross attention
[1] self attn. : self attention over the tokens
if self.skip_first_layer_pe:
    queries = self.self_attn(q=queries, k=queries, v=queries) # attn(tokens, tokens)
else:
    q = queries + query_pe
    attn_out = self.self_attn(q=q, k=q, v=queries)
    queries = queries + attn_out
queries = self.norm1(queries)
- query/key/value : tokens
[2] token to image attn.
# Cross attention block, tokens attending to image embedding
q = queries + query_pe #* (2) token to image attn.
k = keys + key_pe # gives the tokens context from the image
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) # q:tokens, k,v:image embedding
queries = queries + attn_out
queries = self.norm2(queries)
- query : tokens, key/value : image embedding
[3] MLP : channel-wise update for each token
# MLP block
mlp_out = self.mlp(queries) #* (3) MLP : channel-wise update for each token
queries = queries + mlp_out
queries = self.norm3(queries)
- Structure of the MLP (self.mlp)
MLPBlock(
(lin1): Linear(in_features=256, out_features=2048, bias=True)
(lin2): Linear(in_features=2048, out_features=256, bias=True)
(act): ReLU()
)
[4] image to token attn.
# Cross attention block, image embedding attending to tokens
q = queries + query_pe #* (4) image to token attn.
k = keys + key_pe # gives the image embedding context from the tokens
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) # q: image embedding, k,v : token
keys = keys + attn_out
keys = self.norm4(keys)
- query : image embedding, key/value : tokens
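Putting the four steps together, a hedged sketch of one block, with a single nn.MultiheadAttention reused as a stand-in for the repo's separate Attention modules (the real block uses a distinct attention layer per step, with downsampled heads for the cross attentions).

import torch
import torch.nn as nn

dim = 256
attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)  # stand-in, shared across steps
mlp = nn.Sequential(nn.Linear(dim, 2048), nn.ReLU(), nn.Linear(2048, dim))
norm1, norm2, norm3, norm4 = nn.LayerNorm(dim), nn.LayerNorm(dim), nn.LayerNorm(dim), nn.LayerNorm(dim)

def two_way_block(queries, keys, query_pe, key_pe):
    # (1) self attention over the tokens
    q = queries + query_pe
    queries = norm1(queries + attn(q, q, queries)[0])
    # (2) token -> image cross attention
    q, k = queries + query_pe, keys + key_pe
    queries = norm2(queries + attn(q, k, keys)[0])
    # (3) MLP on the tokens
    queries = norm3(queries + mlp(queries))
    # (4) image -> token cross attention (query/key roles swapped)
    q, k = queries + query_pe, keys + key_pe
    keys = norm4(keys + attn(k, q, queries)[0])
    return queries, keys

tokens = torch.randn(1, 7, dim)       # queries: output tokens + sparse prompt embeddings
img = torch.randn(1, 64 * 64, dim)    # keys: flattened image embedding
q_out, k_out = two_way_block(tokens, img, torch.randn_like(tokens), torch.randn_like(img))
print(q_out.shape, k_out.shape)       # torch.Size([1, 7, 256]) torch.Size([1, 4096, 256])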