How to Efficiently Adapt Large Segmentation Model(SAM) to Medical Images(2023)
요즘 SAM을 활용한 모델들이 많이 나오는데, medical 영역에서도 많이 시도되고 있는 것 같다. 오늘은 그 중 하나를 리뷰해보려 한다.
paper : https://arxiv.org/abs/2306.13731
github : https://github.com/xhu248/AutoSAM/tree/main
GitHub - xhu248/AutoSAM: finetuning SAM with non-promptable decoder on medical images
finetuning SAM with non-promptable decoder on medical images - GitHub - xhu248/AutoSAM: finetuning SAM with non-promptable decoder on medical images
github.com
### Overview
Segment Anything Model(SAM)은 natural 이미지에서의 generality를 입증했지만, medical에서는 실패함.
실패 이유
1. Large difference in appearance : modality(CT, MRI, ..)와 term of color, contrast 등에서 natural 이미지와 차이가 존재함.
2. Blurred boundaries of target objects : 탐지할 객체들이 굉장~~히 희미끼끼끼 하다(사람눈으로도 구별하기 힘든 이미지가 많음)
효율적으로 medical 영역에 활용할 방법이 있지 않을까??
SAM의 image encoder를 freeze시키고, 가벼운 prediction head를 만들어 finetune 시키자!
1. SAM의 image encoder는 freeze (SAM의 크기의 대부분은 이 encoder에서 오기 때문에)
2. prompt encoder를 지우고, prompt-free prediction head를 설계
<prompt 삭제 이유>
1. multi-class에 prompt를 제공하는 것은 time-consuming
2. segmentation 결과가 prompt quality에 너무 의존적임(medical에서는 정확한 prompt를 얻는것은 힘듬)
- auxiliary embeddings에서 prompt token은 제거 => prompt-free
- predicion head는 3가지 타입을 설계
1. ViT prediction head
- input으로 image embeddings + auxilary decoder(prompt token 제외한 mask token, iou token) 사용
- multi-class segmentation을 위해, class 개수만큼 input embedding을 복사 : input x (num of class)
> 각 연산은 병렬로 처리되므로, 연산량 문제 무시 가능
2. CNN prediction head
- UNet 구조를 따라, k개의 stage로 구성 (각 stage는 conv layer와 transposed conv layer로 구성)
- 실제 실험을 통해, 4개의 stage로 구성했을때 가장 놀은 성능을 보여주었다.
3. Linear prediction head(단순한 classificatio head)
- SAM enocder에서 나온 feature representation이 semantic information을 잘 담고 있는지 확인하기 위해 설계
Future work
1. 더 많은 medical dataset에서 generalization을 입증
2. 한정된 학습 이미지 개수에서의 SAM 성능 향상(label-efficient adaptation)
3. 더 복잡한 prediction head 구조 구상