はじめに
こんにちは、中村です。バイセルでは商品の査定を簡略化・自動化するためのAIを開発しています。 先日、「学習に使えるクラスラベルと運用時に予測すべきクラスラベルの粒度が異なる」という問題に直面しました。 本記事ではこの問題への対処を検証実験を交えて議論しています。
背景
粗いラベルを学習したモデルに細かい分類問題を解かせる状況を考えます。ここで言う「粗い」・「細かい」とはクラスラベルの粒度に対する表現です。鳥の仲間を例に挙げると、「ペンギン」や「フクロウ」といった大きなカテゴリを粗いラベルとしたとき、細かいラベルは「ケープペンギン」や「フンボルトペンギン」など下位の分類に該当します。
細かいラベルでの分類が目標であるときの最も単純なアプローチは、最初から細かいラベルを学習させることです。しかし、実際には学習に必要な教師ラベルを十分に用意できない場合が存在します。一般に、細かいラベルは粗いラベルよりも種類が豊富であり、かつ区別に専門的な知識を要求するため、収集コストが高くなる傾向にあります。 一方で、粗いラベルであれば比較的簡単に集めることができます。もし粗いラベルを学習したモデルでも細かい分類問題が解けるのであれば、ラベルの収集コストを節約できます。
バイセルの商品マスタであるPromasには細かい粒度で分類が登録されています。例えばお酒カテゴリに絞ると、銘柄や等級などを区別するために5,000種類以上のラベルが存在します。粗いラベルだけで細かい分類モデルを構築する技術は、このように細分化されたマスタに対しても安価に分類モデルを提供できるという点で、大きな意義を持っているように見えます。
基本的なアプローチ
整理すると、十分な量の粗いラベルと少量の細かいラベルを使って細かい分類を達成することが目標となります。
粗いラベルを愚直に学習させたモデルでも特徴抽出は可能です。しかし、 粗いラベルの分類だけで得られた表現は、より細かい分類に必要な微細なパターンを区別できない可能性があります。先述の例で言えば、フクロウとペンギンを区別する特徴ではケープペンギンとフンボルトペンギンを区別できません。 ターゲットタスクに有効な特徴量を抽出するための策のひとつは自己教師あり学習です。
- Weakly Supervised Representation Learning with Coarse Labels [Xu+, 2021]
- Grafit: Learning fine-grained image representations with coarse labels [Touvron+, 2021]
- Fine-grained Angular Contrastive Learning with Coarse Labels [Bukchin+, 2021]
上記にそれぞれ細かい違いはあれど、基本的なアーキテクチャはどれも同じです。学習は粗いラベルの分類と自己教師あり学習のマルチタスクになっています。粗いラベルの分類だけでは細かい差異を区別する特徴量が抽出できないと考え、インスタンスを見分けるタスク(=自己教師あり学習)を追加することで特徴量の解像度を補っています。
推論は、学習済みモデルが出力する特徴量を使ってkNN分類器を構築します。クエリの近傍に存在するサンプルについて、それらに紐づくラベルを集計してクエリのラベルを予測します。kNN分類器の構築には細かいラベルを用いるので、この分類器は細かいラベルを予測します。
後の説明のため、ここではkNN分類器を構成するサンプル集合をギャラリーと呼ぶことにします。
損失関数の再考
分類タスクの学習には交差エントロピー損失が広く用いられますが、推論がkNN分類であることを考えると、学習時の損失関数には別の選択肢があります。ここではWu+, 2018で提案されたkNN損失を考えます。
kNN損失
番目の学習画像と粗いラベルを
とします。
はニューラルネットワークで
は温度パラメータです。kNN損失
は以下で定義されます。自身と同じラベルを持つサンプルとのコサイン類似度が大きいほど小さな値を取ります。
kNN損失はembedding空間での距離に従って誤差を評価します。ターゲットタスクが距離ベースの推論であるため、距離ベースの損失関数の採用は自然に見えます。後の実験でkNN損失と交差エントロピー損失を比較します(RQ1)。
ギャラリーの構築
今回扱うデータはロングテールに分布する性質を持っています。そこで、ギャラリー構築の方法を変えながら、ロングテールに対するkNN分類の挙動を調査します。
最も単純な方法はギャラリー構築用のデータをすべて用いる場合です。このとき、元のデータはロングテールであるため、ギャラリーを構成する各クラスのサイズは大きくばらつくことになります。 直感的には、クラスサイズが偏ったギャラリーを構成すると大きなクラスに有利な予測となることが予想されます。
すべてのクラスのサイズが等しくなるよう、ギャラリー構築用のデータをサンプリングして用いる方法も考えられます。この方法は、Few-shot認識*1の文脈においてはAll-way/K-shotタスクと呼ばれています。先の場合とは逆に、小さいクラスに比重を置いたと見做せますが、本当に小さいクラスの分類に貢献するのかは疑問です。
この機会に、ギャラリーを変えたときのkNN分類器の挙動の変化やそのスケール感を調査します(RQ2)。
実験
データセット
実際の査定に用いられたお酒の画像を対象にデータセットを構築しました。お酒の容器や箱を正面から撮影した画像に対して、粗いラベルもしくは細かいラベルが付与されています。細かいラベルは容器のデザインや容量を区別することから、ターゲットタスクはFine-grained分類タスクでもあります。
学習データとテストデータの構成は以下のとおりです。
画像枚数 | ラベル種類数 | ラベル粒度 | |
---|---|---|---|
学習データ | 54,556 | 334 | 粗い |
テストデータ | 2,866 | 674 | 細かい |
データセットはどちらもロングテールになっています。
実装
バックボーンネットワークはResNet18とします。学習は分類タスクと自己教師あり学習のマルチタスクです。それぞれに用いる損失関数は表のとおりです。
手法名 | 分類誤差 | 対照誤差 |
---|---|---|
KNN | kNN損失 | NPairs損失 |
CE | 交差エントロピー損失 | NPairs損失 |
また、以下の2通りのギャラリー構築を試します。
ギャラリー名 | 各クラスのサンプル数 |
---|---|
ギャラリーA | 同数 |
ギャラリーB | テストデータの分布に従う |
ギャラリーAは、テストデータから各クラス2サンプルずつランダムに選択してギャラリーを構成します。ギャラリーBは、テストデータの各クラスを2分割して片方をギャラリーサンプルとします。選択しなかったサンプルから評価用のクエリを選択します。
評価指標
Top−5 Accuracyを報告します。また、クラスの大きさに対する性能の変化を調査するため、テストデータに含まれる各クラスをサンプル数に応じて以下の3グループに分けます.それぞれのグループについてもTop-5 Accuracyを報告します。
- サンプル数が5以下: Few-shotグループ
- サンプル数が6以上10以下: Medium-shotグループ
- サンプル数が11以上: Many-shotグループ
結果
各条件でのTop-5 Accuracyを表に示します。
ギャラリー | 手法 | All | Many | Medium | Few |
---|---|---|---|---|---|
ギャラリーA | KNN | 0.912 | 0.913 | 0.907 | 0.911 |
ギャラリーA | CE | 0.784 | 0.750 | 0.837 | 0.854 |
ギャラリーB | KNN | 0.919 | 0.966 | 0.930 | 0.821 |
ギャラリーB | CE | 0.878 | 0.947 | 0.885 | 0.741 |
RQ1: 交差エントロピー損失とkNN損失の比較
kNN損失は交差エントロピー損失よりも優れたスコアを記録しています。ラベルの予測確率でなくembeddingが下流タスクの入力になるとき、kNN損失は有効な選択肢となりそうです。
細かいラベルの分類問題は入力の微細な違いを見分けるFine-grainedな認識問題になりがちであり、クラス内分散の存在を考慮する必要があります。単一プロトタイプを仮定する交差エントロピー損失*2はクラス内の多様性を考慮できず、この問題との相性が悪かったと考えられます。対して、kNN損失はプロトタイプを仮定しないため、柔軟なアラインメントを実現した可能性があります。
もうひとつ重要な観察として、kNN損失はロングテール分布への耐性が確認されました。ギャラリーBでの実験において、Many-shotグループのスコアにはほとんど差がありません。一方でFew-shotグループのスコアはKNNが5ポイント以上大きく、小さなクラスに対する分類精度が全体のスコアを底上げしています。
RQ2: ギャラリーの構築方法
ギャラリーAのFew-shotグループのスコアはギャラリーBよりも10ポイント程度上回っています。 各クラスのサンプル数を同数とすることで、マイナーなクラスに対して有利に働くギャラリーを構築できることがわかりました。マイナークラスの分類に下駄を履かせたいのであれば、メジャーなクラスを減らすギャラリー構築は検討の余地があります。
各クラスのサンプル数を変えただけではありますが、モデルの挙動は大きく変化することを観測しました。 高品質なモデルを提供するに当たって、運用時の要件に即したギャラリーの設計は無視できない観点です。
まとめ
本記事では学習時と推論時でクラスラベルの粒度が異なる分類問題を紹介し、通常の分類タスクと自己教師あり学習のマルチタスク学習を使った解決を試みました。また、推論がkNN分類問題となることから、学習時の損失関数の選択やkNNギャラリーの設計について議論し、実験で効果を検証しました。
今回紹介した課題は最近でも議論が続いているようです。例えばNi+, 2022では、クラスラベルの階層構造を利用して細かいラベルの分布を明示的にモデル化しています。Fotakis+, 2023では、十分な粗いラベルがあれば効率的に学習が可能であることを理論的枠組から述べています。 総じて問題設定も実践に近く、査定の自動化を目指す上で大きく関連するトピックであると感じました。
最後に、弊社ではAI技術の社会実装をリードするMLエンジニアや研究者を随時募集中です。気になる方は是非ご検討ください。