Semi-supervised Deep Learning by Metric Embedding
1. どんなもの?
少ないラベル付きデータとラベルなしデータを元に距離埋め込み (neighbor embedding) を学習する、半教師あり学習を提案
2. 先行研究と比べてどこがすごいの?
従来のニューラルネットワークの学習では学習データに対してラベルを推定する枠組みであったが、ラベル付きの学習データが少ない場合に容易に過学習を引き起こす。 本研究ではラベル付きの学習データに対して距離埋め込み (neighbor embedding) を推定する枠組みを導入することで、ラベルなしデータも含めて学習を行い精度を向上させた。
3. 技術や手法の”キモ”はどこにある?
- Embedding同士の距離比較による学習 (neighbor embedding)
- 学習データ バッチサイズ分サンプリング
- クラスのラベル付きデータ を各クラス1サンプルずつサンプリング
- 学習データをembeddingした と ラベル付きデータをembeddingした に対してそれぞれ
L2 norm squared
の逆数を計算 - 計算した距離の逆数に対して
softmax
値を計算
softmax
値と学習データの教師ラベルとのcross entropy
を計算する。
- 半教師あり学習への応用
- ラベルなしデータ群 からサンプリングして以下を計算
4. どうやって有効だと検証した?
MNISTおよびCIFAR10に対して先行研究のモデル(EmbedCNN, SWWAE, Ladder network, Conv-CatGAN / Spike-and-Slab Sparse Coding, View-Invariant k-means, Exampler-CNN, Ladder network, Conv-CatGan, Improved GAN)と提案手法の比較を行っている。
MNISTに対しては各クラス100枚ずつにのみ教師ラベルを付与し、CIFAR10に対しては各クラス400枚ずつにのみ教師ラベルを付与し実験を行っている。
学習時にdata agumentationは行わず、テスト時には出力したembeddingに対してk-NNを用いてk={1, 3, 5}の場合の予測結果をaveragingしている。
5. 議論はあるか?
MNISTデータセットに対して、実験で使用したモデルに2次元のembeddingを出力する全結合層を追加し、可視化を行った結果である。色付きの点は教師ラベルありのサンプルであり、グレーの点は教師ラベルなしのサンプルである。 ラベルありデータは1つのクラスタを形成しており、ラベルなしデータは大部分において各クラスタに属するような形で分布していることがわかる。
6. 次に読むべき論文はあるか?
- EmbedCNNについて
- SWWAEについて
- Ladder networkについて
- Conv-CatGANについて
- Spike-and-Slab Sparse Codingについて
- View-Invatiant k-meansについて
- Exampler-CNNについて
- ImprovedGanについて