[Paper Review] Remixmatch: Semi-supervised learning with distribution matching and augmentation anchoring

2024. 1. 26. 12:41Papers/Semi-supervised Learning

Remixmatch: Semi-supervised learning with distribution matching and augmentation anchoring

Berthelot, D., Carlini, N., Cubuk, E. D., Kurakin, A., Sohn, K., Zhang, H., & Raffel, C. (2019). Remixmatch: Semi-supervised learning with distribution alignment and augmentation anchoring. arXiv preprint arXiv:1911.09785.

본 논문은 MixMatch에서 발전된 방법으로, 기본적인 토대를 MixMatch의 방법론을 따른다. 하지만 추가적인 방법론 들이 꽤 많이 존재한다. 본 포스팅에서는 그 부분들에 대해서 보다 자세히 다룰 예정이다.

해당 방법은 FixMatch와 같이 Semi-supervised Learning에서 가장 많이 활용되고 있는 backbone algorithm이다. 현재 논문들에서도 benchmark 결과를 볼 때, FixMatch로 했을때와 ReMixMatch로 했을 때 두가지 모두 포함시키기 때문에 이번에 정확하게 이해하고 넘어가면 추후 관련 논문 작업을 할 때 큰 도움이 될 것이라 생각한다.


Overall

위의 두 그림이 본 논문을 가장 잘 나타내는 그림이다. 추가적으로 loss term이 MixMatch에 비해 추가가 되는데 그건 아래에서 자세히 다루도록 하겠따.

Figure 1에서 가장 중요한 부분은 Label Guessing 부분에서의 새로운 방법론이 추가가 된 것이다. 노란색 분포와 초록색 분포 정보가 있는데, 노란색 분포 정보는 실제 정답 라벨의 분포를 의미하며, 초록색 분포는 모델이 예측해왔던 unlabeled sample에 대한 분포의 평균이다. 이를 통해 하고 싶은 것을 간단하게 요약하자면, 실제 labeled training samples의 분포와 predicted unlabeled training samples의 분포를 일치시키고 싶어하는 것이다. 만약 일치하지 않은 경우 예측한 값을 결과론적으로 추후 평균을 취했을 때 일치하는 방향으로 조금 '수정'해주는 것이 Figure 1에서 보여주는 그림이다. 보다 자세한 내용은 아래에서 다루도록 하겠다.

Figure 2는 MixMatch와 가장 근본적으로 다른 부분을 얘기하는데, 본 논문에서부터 'Strong augmentation'개념이 활용되기 시작한다. (FixMatch에서 활용되는 것과 유사)

해당 부분에서 augmentation strategy를 결정하기위해 'AutoAugment'방법을 활용하였으며, 이를 통해 Strong augmentation image의 예측과 weak augmentation image의 예측을 최대한 유사하게 하려고 하는 것이 ReMixMatch에서의 consistency loss이다.


Algorithm

왼쪽의 알고리즘이 기존 'MixMatch' 논문에서 사용된 알고리즘이며, 오른쪽 알고리즘이 제안된 'ReMixMatch'의 알고리즘이다. 실제 MixUp을 하는 상황에서의 차이점이 빨간색 부분이 추가된 것이다.

Pseudo-label (Sharpening)을 만들기 위해서 기존에는 단순하게 augmented된 이미지의 평균을 pseudo-label로 활용하는데, ReMixMatch에서는 pseudo-label을 만드는데에 있어서 'K'번의 평균을 구하는 것이 아닌 weak augmentation 이미지의 예측을 구하고 해당 예측에 대해서 'Distribution Alignment(DA)'를 활용하여 pseudo-label의 미세조정을 진행한다.(데이터의 분포와 모델이 예측한 분포를 기반으로 최대한 두 분포를 유사하게 만들기 위해서)

그 후 해당 pseudo-label은 같은 이미지로부터 만들어진 strong augmentation image들의 pseudo-label로 활용된다. 해당 부분이 Figure 2에서 지향하는 바를 위해 사용되는 것이다.


Distribution Alignment for label guessing

위의 그림이 ReMixMatch에서 제안하는 Distribution Alignment 방법론이다. 기존 MixMatch에서는 K번의 augmentation image의 예측 평균을 통해 Label guessing읋 한것에 반해, ReMixMatch에서는 하나의 weakly augmented image의 예측값에, 학습 데이터 분포와 모델 예측 분포 두가지를 활용하여 refinement를 수행한다. 보다 자세히 다루기 위해 예제를 아래에 첨부하도록 하겠다.

지금 현재 모델의 예측이 0.5라고 칭하고, 라벨링된 학습 데이터의 분포 $p(y)$가 0.5 그리고 라벨링되지 않은 학습 데이터의 모델의 예측 분포의 평균 $\tilde{p}(y)$가 0.5일 경우에는, 두 분포가 동일하기 때문에 현재 분포를 수정할 필요가 없다고 판단하고 예측값 그대로 활용하게 된다.

만약, 두번째 케이스처럼 예측 분포가 더 낮다면, 예측된 확률 값이 낮을 확률이 높다고 판단하고 이를 distribution alignment를 통해 1.25배 상승시킨다. 그렇다면 해당 배치가 끝난 후, 다시 평균을 구할 경우 전반적으로 예측 분포가 증가하는 양상을 보이게 될것이며, 이를 통해 두 분포를 최대한 유사시킬려고 한다.

반대의 경우에는 예측값을 낮추는 상황으로 만든다.

결과론적으로 해당 방법론을 통해서 label guessing을 진행하게 된다.


Loss function

loss term이 MixMatch에 비해 추가된 부분이 존재한다.

각 loss function을 순서대로 정리하면 아래와 같다.

1. labeld sample과 unlabeled sample에 대해서 MixUp을 수행하여 증강한 이미지들에 대한 (pseudo-label이 labeled sample이 더 강하게 들어가 있음) cross-entropy 손실함수

2. unlabeled sample과 labeld sample에 대해서 MixUp을 수행하여 증강한 이미지들에 대한 (pseudo-label이 unlabeled sample이 더 강하게 들어가 있음) cross-entropy 손실함수

3. Mixup없이 Strongly augmented unlabeled sample에 대한 consistency loss 항

4. rotation을 (0,90,180,270)을 예측하여 backbone 성능을 높이고자하는 loss 항

이렇게 4가지가 존재한다. 최종적으로 해당 손실함수를 활용하여 모델 성능을 높인다.


Results & Ablation Study

위의 그림들은 방법론의 결과와 ablation study에서 활용된 테이블이다.

결과론적으로 제안된 방법이 성능이 높은을 Table 1에서 보여준다.

추가적으로 Strong augmentation의 횟수, rotation loss의 유무 (4번 loss term의 유무), mixup이 사용되지 않은 unlabeled sample에 대한 loss (3번 loss term의 유무), distribution alignment의 유무 (label guessing에서), L2 unlabeled loss를 cross-entropy 대신 사용했을 때(mixup이 활용된)-, augmentation의 차이를 보여준다.


본 포스터에서는 ReMixMatch에 대한 간단한 정리를 진행했다. Semi-supervised에서 워낙 많이 활용되고 있기에 꼭 한번 읽어보면 좋을 논문이라 생각된다.