[논문정리] Collaborative Learning of Semi-Supervised Segmentation and Classification for Medical Images
2021. 5. 17. 14:40ㆍPapers/Segmentation
Collaborative Learning of Semi-Supervised Segmentation and Classification for Medical Images
Zhou, Y., He, X., Huang, L., Liu, L., Zhu, F., Cui, S., & Shao, L. (2019). Collaborative learning of semi-supervised segmentation and classification for medical images. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 2079-2088).
Semi Supervised Learning
Unsupervised Learning 에서 사용되는 Pseudo(가짜) label이 학습에 도움이 되는가??
Labeled data는 의료 이미지와 같은 부분에 있어서 전문 지식이 필요하기 때문에 만드는 과정이 어려우며 특히 Segmentaion과 같은 문제에 있어서 굉장히 시간 소요적이며, 의료적 전문 지식이 필요하다.
위와 같은 이유에 의해서 Unlabled data는 labeld data보다 얻기 쉽기 때문에 위와 같은 Unsuperved Learning을 활용한 상황이 많이 존재함.
단순히 Unsupervised Learning만을 사용할 경우에는 모델의 정확도가 굉장히 낮은 상황이 발생하기 때문에 이에 대한 보완점으로 Semi-supervised Learning이 제안됨.
위의 식을 보면 알 수 있듯이, $L^s$는 labeled data에 대한 loss이며, $L^u$는 pseudo data에 대한 loss이다. 여기서 $\alpha(t)$의 값을 조절함으로써 pseudo label에 대한 영향력을 결정할 수 있다.
4가지 병변에 대한 결과 Pre-trained & Semi-supervised & 정답에 대한 차이
Introduction
질병 등급과 pixel-wise lesion segmentation(병변 분할) 두가지 문제를 다룬다.
최종적인 목표는 두가지 케이스에 대해서 모두 높은 성능을 얻는 것이지만, 두가지 문제는 일반적으로 독립적으로 연구되어왔다.
그러나, 정확한 병변 검출은 질병 등급을 분류하는데 큰 기여를 할 수 있으며, 반대로 분류는 분할 성능에 이점을 가지고 있다.
의료 이미지의 경우 정답을 만드는 과정에 있어서 시간 소요적이면서도 전문가적 지식을 필요로 하기 때문에 굉장히 어려운 작업이다.
따라서 이러한 의료 이미지 데이터를 학습하는 과정을 일반적인 이미지 데이터 학습 방식을 그대로 사용할 경우 부적합하다.
그러나, 비지도학습만을 활용할 경우에는 제한된 성능에 의해 적합하지 않다.
따라서 본 논문에서는 Semi-supervised learning(준지도학습)을 활용하여 제한된 labeled 데이터와 방대한 양의 unlabeld 데이터를 동시에 활용함으로써 분류와 분할 성능을 동시에 올리는 것을 목표로 한다.
해당 논문에서는 당뇨망막병증(Diabetic retinopathy)이미지에 대해서 병변 분할과 질병 등급을 분류하기 위한 협동적인 학습 방법을 제안한다.
질병의 등급은 5 단계로 나눠지며, Normal(정상), mild(경증), moderate(보통), severe non-proliferative(증식하지않는 중증), proliferative(증식하는 중증)으로 나눠질 수 있다.
거대한 양의 pixel-level의 병변 정답을 구하는 것은 어렵기 때문에, 준지도학습 분할 방법이 제안되었으며, 해당 방식을 활용하여 분류와 분할을 동시에 최적화할 수 있었다.
1. 우선 pixel-level lesion anntations(segmentation annotations)데이터를 가지고 분할 모델을 지도학습을 활용하여 pre-train(사전 학습)을 수행한다.
- 사전학습한 모델을 활용하여 image-level annotations(classification annotation)을 pseudo labeling과 classification을 할 수 있도록 돕는 역할을 한다.
2. 그후, 거대한 양의 질병 등급 labels만을 가지는 이미지들을 약한 병변 위치들을 만들기 위해 사전 학습된 분할 모델을 통과시킨다.
- 사전 학습된 segmentation model을 통과시킨 mask의 경우 굉장히 적은 데이터로 지도학습을 통해 만들어진 모델이기 때문에 성능이 좋은 모델은 아직 아니기 때문에 weak mask를 추출한다고 표현한다. 이를 해결하기 위해 아래의 방법을 활용한다.
3. 분할 모델을 통과하여 예측된 masks과 함께 원본 이미지를 병변 주의력 분류 모델을 학습하기 위해 입력으로 넣는다.
- 해당 분류 모델을 추후 더 강력한 pseudo label 을 만들기 위해 사용된다. 또한, 이를 기반으로 분할 모델의 성능을 높이는 역할을 수행할 수 있다.
4. 분류 모델은 질병 등급(분류) 성능과 출력의 병변 주의 맵들은 pseudo masks를 세분화하기 위해 사용되며 이는 분할 모델을 fine-tune(미세조정)하는데 사용된다.
본 논문에서 말한 두 개의 모델이 상호보완적으로 성능 향상에 도움을 주는 모델을 개발함.
Contributions
1. -multi-lesion(다중 병변) mask generator(생성기(모델))은 pixel-wise segmentation을 위해 제안되었다. 때문에 극도로 제한된 학습 데이터(학습 및 테스트 이미지가 총 81장으로 굉장히 제한된 데이터이다.) 때문에, U-shape(UNet)네트워크 기반으로 Xception-module(depth-wise convolution and point-wise convolution)으로 설계하였고 학습을 위해 정답이 있는 segmentation loss와 정답이 없는(pseudo labeling) adversarial loss(적대적 손실)을 통합한 목적 함수를 활용한다.
2. image-level annotations data(only classification annotations data)를 위해, 클래스 정보만을 가진 약한 데이터에 적용시키는 질병 맵들을 자동적으로 예측할 수 있는 병변 주의 모델을 사용한다. 그 예측된 맵들은 이전 분할 모델을 미세조정하는데 사용되며 해당 학습은 전체 데이터를 활용하여 준지도학습 방식을 활용하여 진행한다.
3. 이 병변 분할과 질병 등급 업무들은 end-to-end 모델로 최적화된다. 거대한 양의 클래스-정답 데이터는 분할 성능에 이점이 있다. 반면에 강화된 분할 모델은 분류 정확도를 높일 수 있다.광범위한 dusrn alc tlfgjaemffh DR(당뇨망막병증)에서 효과적인 성능을 보임을 확인했다.
Proposed Methods - Problem Formulation
첫번째, 굉장히 적은 데이터 $X^P$를 입력으로 활영하여 multi-lesion mask generator를 지도학습을 활용하여 학습한다.
사전학습이 완료된 후, 보유하고 있는 많은 양의 $X^I$를 해당 생성기에 통과시킨다.
- 생성기에 통과시킬 경우 weak mask가 생성될 것이다.
Adversarial training loss(적대적 학습 손실) 를 최적화하기 위한 Discriminator(판별기) 두개의 종류의 데이터를 구별하기 위해 설계되었다.
두번째, $X^I$와 초기에 예측된 병번 맵들은 lesion attention model(병변 주의 모델)을 학습하기 위해 적용되며, 이는 오직 질병 등급 정답을 가진 것만 고용하여 활용한다.
병변 주의 등급 모델은 분류 성능을 향상시킨다.
더욱이, 생성된 병변 주의 맵들은 pseudo(가짜) masks로 사용되며 분할 생성기를 거대한 양의 정답이 없는 데이터를 활용하여 준지도 방식을 활용하여 세분화한다.
위의 수식을 보면 알 수 있듯이 $L_{Adv}$에 해당하는 loss는 흔히 GAN에서 활용되는 adversarial loss이며, real image는 정답을 가지는(81개의 데이터)에 대한 이미지이며, fake image는 image-level annotations을 가지는 거대한 양의 이미지들이다. 해당 이미지들은 0에 수렴하도록 학습함으로써 성능을 높인다.
$L_{CE}$는 흔히 우리가 접할 수 있는 cross entropy이며 classification loss에 해당한다.
위의 그림은 병변 주의 질병 등급 모델에 해당하는 구조이며, 우측의 수식에 따라 output이 결정된다.
그림에서는 4개의 multi-lesion이 모두 concat되는것처럼 보일 수 있지만, 실제로는 가장 최상단의 파란색 convolutional layer를 통과한 원본 이미지와 segmentation generator를 통과하여 추출된 4개의 병변과 하나씩 concat하여 $f^{low_{att}}$를 생성한다. 마찬가지로 $\alpha_l$또한 각각의 병변마다 연산되는 과정을 수행한다.
우측 multi-lesion attentive features들은 원본 이미지와 weak multi-lesion마다 합쳐진 데이터이며 해당 데이터를 동일한 가중치를 가지는 classification 모델을 통해 vector값을 만든 후 classification은 4개지 병변에 대해서 concat한 후 Fully connected layer를 통해 예측하게 된다.
해당 결과를 활용하여 Global context vector를 만드는데 활용되며 우리가 segmentation에 활용할 pseudo labeling을 완성하는데 활용한다.
만약 이러한 과정이 없을 경우 segmentation generator와 loss를 구하기 위해 활용될 정답 데이터가 같은 상황이 발생할 수 있기 때문에 해당 정답을 generator에서 나온 값을 활용하며 추가적으로 높은 성능의 classification model을 활용함으로써 보다 강한 label을 만드는 과정을 수행함으로써 성능을 올리는 것을 보여준다.
위에서 활용할 수 있는 방법론을 모두 활용했을 때 가장 성능이 높음을 위의 결과를 보면 알 수 있다.
Conclusion
본 논문은 준비지도학습을 통하여 병변 분할과 질병 등급을 의료 이미지에서 협동적으로 학습하는 방법을 제안한다.
병변 마스크는 분류 모델에 활용되며 등급 정확도를 올리며, 반대로 병변 주의 모델은 클래스 라벨들을 활용하여 분할 결과에 이점이 있다.