Supervised Contrastive Learning-Khosla- Paper Review


Paper Title:Supervised Contrastive Learning

Authors: Prannay Khosla, Piotr Teterwak, Chen Wang, Aaron Sarna, Yonglong Tian, Phillip Isola, Aaron Maschinot, Ce Liu, Dilip Krishnan

Venue34th Conference on Neural Information Processing Systems (NeurIPS 2020), Vancouver, Canada.

URLhttps://arxiv.org/pdf/2004.11362v5.pdf

Problem: Contrastive learning banyak digunakan pada self-supervised learning. Pada prakteknya sebagian besar model tidak bekerja lebih baik pada dataset besar seperti ImageNet, terutama dari sisi cross-entropy loss

Contribution

  1. Mengembangkan pendekatan self-supervised batch contrastive menjadi fully-supervised
  2. Mengajukan sebuah loss untuk supervised learning yang terinspirasi dari self-supervised dengan memanfaatkan informasi label. 
  3. Mengajukan ekstensi baru dari contrastive loss function yang menggunakan banyak positive per anchor
  4. Menunjukan bahwa loss dari sistem usulan menghasilkan akurasi top-1 dari beberapa dataset berbeda, dan lebih kuat terhadap korupsi natural
  5. Menunjukan secara analisis bahwa gradient dari loss function usulan mendorodng dari hard positive dan hard negative
  6. Menunjukan secara empiris bahwa loss usulan lebih tidak sensitive daripada cross-entropy pada range hyperparameter

Method/solution

  1. Cluster poin yang masuk pada kelas yang sama, ditarik ke embedding space, dan mendorong cluster sampel dari kelas berbeda
  2. Melakukan analisa 2 versi dari Supervised contrastive (SupCon) loss 
  3. Pendekatan supervised contrastive menggunakan soft-nearest neighbour loss, yang ditingkatkan dengan normalisasi embeddings dan mengganti Euclidean distance dengan inner product. Selain itu ditambahkan dengan augmentasi data, sebuah contrastive head disposable dan 2 tahap training (contrastive diikuti dengan cross-entropy) dan mengganti form dari loss function 
  4. Embedding ternormalisasi dari kelas yang sama, ditarik mendekat dibandingkan embedding dari kelas berbeda
  5. Mendapat input batch data, kemudian dilakukan augmentasi data 2x untuk mendapat 2 kopi dari batch
  6. Kedua kopi tersebut diforward melalui jaringan encoder untuk mendapatkan embedding normalisasi dengan dimensi 2048
  7. Melalui training, representasi ini dipropagasi melalui sebuah jaringan projeksi yang diabaikan pada waktu inference
  8. Supervised contrastive loss dihitung pada output dari projection network
  9. Untuk menggunakan model training untuk klasifikasi, dilatih sebuah klasifikasi linear dengan cross-entropy loss
  10. Komponen utama dari framework usulan adalah: module augmentasi data; encoder network dan projection network
  11. Pada setiap input x, digenerate 2 random augmentasi x’ masing-masing mewaliki view yang berbeda terhadap data dan terdiri dari beberapa subset informasi pada sampel asli
  12. Encoder network melakukan pemetaan x pada representasi vector r. Kedua sampel yang telah diaugmentasi secara terpisah dimasukan pada sebuah encoder yang sama. Yang menghasilkan sepasang vector representasi. R dinormalisasidengan unit hypershere pada Rde
  13. Jaringan projection memetakan r ke vector z. projeksi yang digunakan adalah multi-layer perceptron dengan sebuah hidden layer dengan ukuran 2048 dan vector output dengan ukuran Dp=128 atau hanya sebuah layer linear single dengan ukuran Dp=128
  14. Kemudian dilakukan normalisasi dari output dari jaringan ini pada unit hypershere, menggunakan inner product untuk mengukur distance dari projection space
  15. Loss yang digunakan memiliki properti: generalisasi pada number positive, contrastive power meningkat dengan banyak negative, memiliki kemampuan untuk melakukan hard positive/negative mining
  16. SupCon loss diuki dengan mengukur akurasi klasifikasi pada beberapa benchmark seperti CIFAR-10 dan CIFAR-100 dan ImageNet
  17. Kemudian model ImageNet dibenchmark juga untuk mengetahui common image corruptions, dan mengetahui perubahan performa dengan perubahan hyperparameter dan pengurangan data
  18. Untuk Encoder network, diuji dengan 3 arsitektur encoder yang umum yaitu ResNet-50, ResNet-101 dan ResNet-200
  19. Final Pooling layer menggunakan normalized activation (De-2048)
  20. Implementasi diuji pada 4 modul data augmentasio: autoAugment; randAugment, SimAugment dan Stacked RandAugment
  21. ResNetMelakukan pengujian alternatif memory based. Dengan ImageNet pada memory size 8192 dengan ukuran storage 128-dimensi vector, ukuran batch-size 256 dan SGD optimizer, pada 8 Nvidia V100 GPUs
  22. Melakukan pengujian cross-entropy ResNet-50 baseline dengan ukuran batchsize 12.288
  23. Menggunakan daaset ImageNet-C untuk benchmark pengukuran performa model pada korupsi natural, dibandingkan dengan mCE (Mean Corruption Error) dan Relative Mean Corruption Error Metric
  24. Menguji stabilitas hyperparameter dengan mengubah augmentasi, optimizer dan learning rate satu persatu dan mencari kombinasi yang terbaik.
  25. Perubahan Augmentasi dilakukan dengan RandAugment, AutoAugment, SimAugment, Stacked Rand Augmet; Perubahan optimizer dengan LARS, SGD with momentum dan RMS props
  26. Mengujia learned representation untuk fine-tuning pada 12 natural image dataset.
  27. Training dilakukan dengan 700 epoch pada pretraining untuk ResNet-200 dan 250 epochs untuk model yang lebih kecil
  28. Melatih model dengan batch size sampai 6144. Untuk ResNet-50 diuji sampai batch size 6144 dan ResNet-200 dengan batch-size 4096
  29. Menggunakan temperature = 0,1

Main result

  1. Pengujian pada ResNet-200 diperoleh akurasi 81,4% (top-1) pada dataset ImageNet. 0,8% lebih baik dari state-of-the-art arsitektur ini
  2. Menghasilkan performa lebih baik pada cross-entropy pada dataset lain dan 2 ResNet Variant.
  3. AutoAugment menghasilkan performa yang terbaik pada ResNet-50 pada SupCon dan cross Entropy dengan akurasi 78,7%
  4. Stacked RandAugment menghasilkan performa terbaik untuk ResNet-200 untuk kedua loss functions
  5. Menghasilkan performa sedikit lebih baik dibandingkan CutMix, yang merupakan state-of-the-art pada strategi data augmentasi
  6. Menghasilkan akurasi 79,1 pada pengujian alternatif memory based dengan ResNet-50. 
  7. Akurasi 77,5% pada pengujian cross-entropy ResNet-50
  8. Pada pengujian penambahan training epoch pada cross-entropy sampai 1400, akurasi turun menjadi 77%
  9. Pada pengujian N-Pair loss dengan batchsize 6144 mendapatkan akurasi 57,4% pada ImageNet
  10. Pada pengujian natural corruption model usulan memiliki nilai mCE lebih rendah pada corruption berbeda, menunjukan robustness.
  11. Model usulan menghasilkan degradasi akurasi yang lebih rendah pada peningkatan korupsi
  12. Hasil pengujian stabilitas hyperparameter menunjukan nilainya konstan top-1 akurasi
  13. Sistem usulan memiliki contrastive loss yang setara dengan cross-entropu dan self-supervised pada transfer learning ketika ditrain pada arsitektur yang sama
  14. Pada ResNet50 fungsi akurasi menunjukan 200 epoch sudah mencukupi
  15. Hasil pengujian menunjukan batch size 2048 sudah mencukupi
  16. Performa terbaik untuk ImageNet menggunakan LARS untuk pre-training dan RMSProp untuk training layer linear
  17. Untuk CIFAR1- dan CIFAR 100, SGD menghasilkan performa terbaik

Limitation:

  1. Tidak melakukan training linear classifier Bersama dengan encoder dan projection network
  2. N-Pair loss masih rendah

Note: 

  1. Contrastive learning telah menjadi state-of-the-art pada unsupervised training pada model deep image
  2. Pendekatan batch contrastive modern melampaui tradisional contrastive loss seperti triplet, max-margin dan N-pair los
  3. Cross-entropy loss adalah fungsi loss paling banyak digunakan pada supervised dari model deep classification
  4. Perkembangan contrastive learning mendorong perkembangan self-supervised learning
  5. Contrastive learning bekerja dengan menarik sebuah anchor dan sebuah sampel positive ke embedding space dan memisahkan anchor dari sampel negative
  6. Karena tidak ada label tersedia, sebuah pasangan positive terdiri dari augmentasi data dari samepl, dan pasangan negatif dibentuk oleh anchor dan secara random memilih sampel dari minibatch
  7. Koneksi dibuat dari contrastive loss dari maximization dari informasi mutual antara view data-data yang berbeda
  8. Kebaruan teknis adalah dengan mempertimbangkan banyak positive dari anchor sebagai tambahan pada banyak negative; berbeda dengan self-supervised contrastive learning yang hanya menggunakan single positive
  9. Positive diambil dari sampel pada kelas yang sama dengan anchor, tidak dari augmentasi data dari anchor. 
  10. Walaupun terlihat sebagai extensi sederhana dari SSL, namun tidak mudah untuk mensetting loss function dengan baik. Ada 2 alternatif yang dipelajari
  11. Loss pada model ini dapat dilihat sebagai sebuah generalisasi dari triplet dan N-Pair los
  12. Triplet hanya menggunakan 1 positif dan 1 negative sampel per anchor
  13. N-Pair menggunakan 1 positive dan banyak negative
  14. Banyak positive dan banyak negative pada setiap anchor menghasilkan performa state-of-the art tanpa perlu mining hard negative, yang susah untuk detuning
  15. Model ini adalah contrastive loss pertama yang menghasilkan performa lebih baik daripada cross entropy pada tugas klasifikasi besar
  16. Metode ini menghasilkan sebuah loss function yang dapat digunakan pada self-supervised atau supervised
  17. SupCon mudah diimplementasi dan stabil untuk di training
  18. Naive extension menghasilkan performa lebih buruk dibandingkan sistem usulan
  19. Cross-entropy loss adalah powerfull loss function untuk train deep network; setiap kelas diassigned sebuah target (biasanya 1-hot) vector. Namun tidak jelas kenapa target label tersebut adalah yang optimal, dan banyak penelitian telah mencoba mengidentifikasi target label vector yang lebih baik
  20. Kekurangan cross entropy loss diantaranya sensitivitas label noisy, adanya adversarial examples dan poor margin
  21. Loss alternativ telah diajukan, tapi yang terbaik adalah mengubah reference label distribution seperti label smoothing, data augmentasi seperti mixup dan cutmix dan knowledge distillation
  22. SSL berbasis model deep learning banyak digunakan pada natural language. 
  23. Pada domain image, pendekatan pixel prediksi digunakan untuk belajar embedding.
  24. Metode ini digunakan untuk memprediksi bagian yang hilang dari sinyal input.
  25. Pendekatan yang lebih efektif adalah mengganti sebuah desne per-pixel predictive loss, dengan sebuah loss di lower dimensional representation space.
  26. State-of-the-art family model untuk SSL menggunakan paradigma yang dibawah istilah contrastive learning
  27. Loss pada penelitian tersebut terinspirasi oleh noise contrastive estimation atau N-pair loss
  28. Loss diterapkan pada layer terakhir dari sebuah deep network
  29. Pada pengujian embedding dari layer sebelumnya digunakan untuk downstream transfer task, fine tuning atau direct retrieval task.
  30. Terkait dengan contrastive learning adalah family dari loss berbasis metric distance learning atau triplets
  31. Loss tersebut banyak digunakan utk supervised, dimana label digunakan untuk memandu pemilihan positive dan negative pairs
  32. Yang membedakan triplet loss dan contrastive loss adalah jumlah pasangan positive dan negative pada setiap data poin
  33. Triplet loss menggunakan 1 positive dan 1 pasangan negative per anchor.
  34. Pada setingan supervised metric, hamper selalu dibutuhkan hard-negative mining untuk performa yang baik
  35. SSL contrastive loss hanya menggunakan 1 positive pair utk setiap anchor sampel, memilih antara co-occurrence atau data augmentation
  36. Perbedaannya adalah banyak negative pair digunakan untuk setiap anchor. Yang dipilih secara random menggunakan weak knowledge seperti patches dari gambar lain atau frame dari video random lainnya. Dengan asumsi bahwa pendekatan ini menghasiklan probability false negative paling rendah
  37. Loss formulation yang dekat dengan usulan adalah entangle representasi pada intermediate layer dengan melakukan maximize loss
  38. Metoda yang paling mirip adalah Compact clustering via label propagation (CCLP) regularizer
  39. CCLP focus pada semi-supervised, pada fully supervised regularizer mengurangi hamper sama dengan loss formulation usulan
  40. Perbedaannya adalah normalisasi yang diusulkan adalah dengan embedding ke unit sphere, tuning parameter temperatur dan augmentasi yang lebih kuat
  41. Deep Neural network tidak robust terhadap data yang out of distribution atau korupsi natural seperti noise, blur dan kompresi JPEG

Silahkan tuliskan tanggapan, kritik maupun saran