[논문읽기] A Simple Framework for Contrastive Learning of Visual Representations(SimCLR)
2022. 3. 4. 17:23ㆍPapers/Contrastive Learning
A Simple Framework for Contrastive Learning of Visual Representations
Chen, T., Kornblith, S., Norouzi, M., & Hinton, G. (2020, November). A simple framework for contrastive learning of visual representations. In International conference on machine learning (pp. 1597-1607). PMLR.
Main idea
- 본 논문 SimCLR에서 기존 unsupervised contrastive learning 방법들과 가장 큰 차이점은 memory bank를 활용하지 않는 점이 존재한다.
- Memory bank method를 활용했을 경우 memory issue문제가 컸던 점이 존재한다.
- 두개의 서로 다른 Data augmentation을 통하여 positive sample을 생성하며, 두개의 sample에 대해서는 pull하도록 학습하고, 그 외의 sample에 대해서는 push를 통하여 모델 표현 학습을 최적화 시킨다.
- feature extract를 통하여 추출한 feature 값을 그대로 metric smiliarity를 계산하는 것이 아닌 normalization을 통하여 값을 변화시켜줌으로써 보다 높은 성능을 얻어냈다.
- 추후, MoCo 업데이트 버전에서 해당 방식을 채용한 것으로 알고 있음.
- 적절한 temperature parameter τ와 위에 언급한 normalized embeddings이 학습에 도움을 준다.
- Memory bank를 활용하지 않기 때문에 충분한 negative sample을 활용하기 위해서 batch size를 크게 설정하여 학습을 수행하는 차별성이 존재한다.
- 다른 후속 논문에서 무조건 적으로 negative sample이 많아야 contrastive learning 학습에 도움이 많이 되는가에 대한 논문도 있었는데, 해당 논문은 추후 다뤄보도록 하겠습니다.

SimCLR의 전체적인 흐름은 위의 그림으로 모두 설명할 수 있다.
SimCLR은 총 3개의 부분으로 구성되어 있는데, Data augmetation module, feature extract module, 그리고 normalization module로 구성되어 있다. 각각의 module이 하는 역할은 다음과 같다.
- Data augmentation module : positive sample과 negative sample을 생성하기 위하여 활용되는 모듈이다. 선택된 anchor image 기반으로 서로 다른 두개의 data augmentation을 적용하여 ~xi,~xj를 생성한다. 두개의 sample은 positive sample로서 동작하며, 추후 feature extract module과 normalization module을 통과한 metric에 대해서 pull을 통하여 값을 최대화 하도록 학습시킨다.
- Feature extract module : 해당 모듈은 일반적으로 활용되는 ResNet과 같은 모델의 feature extract부분만 활용하는 모듈로, 마지막의 classifier layer를 제거하여 활용한다. 본 논문에서는 총 2개의 backbone model을 feature exttract module로 활용하는데, ResNet-50과 ResNet-200을 채택하여 실험을 진행하였다. 본 ResNet 모델 두개는 최종적으로 classifier layer에 들어가기 전 Global Average Pooling을 통과한 후의 feature shape은 (batch,2048)과 같다.
- Normalization module : 이전 연구들에서는 해당 모듈을 활용하지 않고 feature extract module을 통해 나온 feature space값을 바로 비교하였다. 실험적으로 본 논문에서는 normalization module을 활용하여 차원을 128까지 압축시키면서 값을 normalization을 수행하는 것이 성능적인 면에서 좋다는 결과를 얻어 다음과 같은 모듈을 활용하였다. 실제로 타 후속 논문들도 normalization module을 많이 채택하였다.
Loss fucntion
