Training SwAV – Ishan Misra


Saya lanjutkan bahasan tentang kuliah Ishan Misra tentang Self-supervised. Pada tulisan sebelumnya telah dibahas tentang soft-assignment. Kali ini kita lanjutkan bahasan tentang bagaimana proses training di SwAV (swapping Assignment between views). Swav ini salah satu metode clustering.

Arsitektur SwAV

2 potongan (crop) dari sebuah gambar diteruskan (feed-forward) melalui jaringan f(theta) untuk menghitung embeddingnya, yang biru. Kemudian dihitung optimal transport dengan sinkhorn-knopp dan didapat kode assignmentnya ke masing2 prototipe. Kemudian solver menghitung prediksi. Kita coba melakukan prediksi code no 2 dari embedding no 1. Begitu juga sebaliknya code no 1 dari embedding no 1.

Idenya adalah bila kedua gambar ini related, dan invariance terhadap data augmentation, maka kita bisa melakukan prediksi code no 2 dari gambar no 1. Karena keduanya harusnya berada dalam satu grup yang sama atau cluster yang sama. Gradient kemudian dibackpropagasikan ke embedding dan prototipenya. Sehingga model ini dapat diupdate secara online.

Keuntungan dari SwAV:

  1. Tidak memerlukan explicit negatives sehingga tidak ada contrastive learning
  2. Metoda Optimal transport mencegah trivial solutions.
  3. Lebih cepat konvergen dibandingkan contrastive learning.
    • Space kode membutuhkan konstrain lebih banyak dan embedding tidak dibandingkan secara langsung.
  4. Membutuhkan komputasi yang lebih sedikit.
  5. Membutuhkan jumlah GPU yang lebih sedikit (4-8) .

Prototipe adalah sekumpulan kelas yang berisi gambar secara random, bisa juga disebut bag of embedding. Pada setiap forward pass kita mengambil setiap embedding dari input f theta, kemudian kita menghitung similarity dari setiap prototipe. Misalnya kita punya b embedding dan k prototipe, maka kita menghitung b/k means type. Kemudian kita menggunakan algoritma optimal transport, untuk memastikan kode kita terbagi ke k prototipe secara merata. Kemudian kita bisa melakukan backpropagate.

Metoda ini berbeda dengan algoritma k-means yang dapat menghasilkan trivial solution. Sampai disini dulu, insyaallah akan saya lanjutkan pada tulisan berikutnya.

Kode dan model bisa dilihat di :

https://github.com/facebookresearch/swav

Materi kuliahnya:

https://atcold.github.io/NYU-DLSP21/en/week10/10-1/

slidenya ada disini:

https://drive.google.com/file/d/1BQlWMVesOcioW69RCKWCjp6280Q42W9q/edit

Videonya:


Silahkan tuliskan tanggapan, kritik maupun saran