マルチクラスSCW評価メモ
昨年のICML2012で、オンライン分類器であるSoft Confidence-Weighted Learningが提案された("Exact Soft Confidence-Weighted Learning")。この際に提案されたのは基本的な二値分類だったが、今までのオンライン分類器と同じように、マルチクラスへの拡張も可能である。詳しくは以下の記事を参照していただきたい。
そんなマルチクラスSCWをJubatusに実装してみた。コードは"y-tag/jubatus"のscwブランチにある。このブランチは0.4.2のリリース直後をベースとしている。
ただ単純な分類器の追加だけなら話は簡単なのだが、分類器の設定を保持するclassfier_configクラスは設定値を1種類しか保持できない一方で、SCWは2種類のハイパーパラメータを必要とする。そのため、configクラスを拡張する必要があり、それはコード全体が多少なりとも複雑にする。ここで気になるのは、コードを追加することによって増加する複雑さに、SCWの分類精度は見合うのかどうかということである。そのため、簡単ではあるが、マルチクラスSCWと既存の分類器の比較実験を行うことにした。
評価は"Exact Soft Confidence-Weighted Learning"の5章の実験設定を参考に行った。まず各データセットの学習データを用いて5-foldのcross validationを行い、各学習器のハイパーパラメータを決定した。PA1, PA2, AROW, NHERDのハイパーパラメータは{2^-4, 2^-3, ..., 2^3, 2^4}の範囲で変化させ、一番高いAccuracyを示したものを選択した。CWは元々の\etaの範囲{0.5, 0.55, ..., 0.9, 0.95}に対応した、{0.0, 0.125661, 0.253347, 0.385320, 0.524401, 0.674490, 0.841621, 1.036433, 1.281552, 1.644854}の範囲でハイパーパラメータを変化させた。SCW1とSCW2のハイパーパラメータも同様に、上記の二つの範囲内で探索を行った。ハイパーパラメータの決定後、学習データの順序をランダムにシャッフルして学習を行い、学習時の学習データでの誤分類率と、学習後のテストデータでの誤分類率を計測した。順序のランダムシャッフルは20回行い、平均と標準偏差を結果として得た。
データセットはLIBSVM DATAからマルチクラスのデータとしてnews20, usps, letterを対象とし、2クラスのデータセットとしては、元論文でも用いられていたijcnn1, w7aを選択した。
以下がその結果である。
news20 | mistake rate in train | mistake rate in test | hyper parameters |
---|---|---|---|
pecetpron | 0.31393(0.00195) | 0.23138(0.00577) | - |
PA | 0.23127(0.00271) | 0.16853(0.00270) | - |
PA1 | 0.22574(0.00214) | 0.16142(0.00199) | 0.5 |
PA2 | 0.22827(0.00177) | 0.16489(0.00242) | 0.0625(1) |
CW | 0.21639(0.00208) | 0.14931(0.00197) | 1.644854 |
AROW | 0.22226(0.00191) | 0.15759(0.00216) | 4 |
NHERD | 0.21542(0.00275) | 0.14986(0.00191) | 0.5 |
SCW1 | 0.22248(0.00216) | 0.15662(0.00252) | 0.841621, 0.5 |
SCW2 | 0.22104(0.00190) | 0.15523(0.00215) | 1.644854, 0.5 |
usps | mistake rate in train | mistake rate in test | hyper parameters |
---|---|---|---|
pecetpron | 0.15021(0.00201) | 0.13956(0.02299) | - |
PA | 0.11767(0.00258) | 0.11796(0.01202) | - |
PA1 | 0.11767(0.00258) | 0.11796(0.01202) | 0.0625 |
PA2 | 0.11811(0.00264) | 0.11921(0.01631) | 0.0625(1) |
CW | 0.09960(0.00205) | 0.10750(0.01127) | 0.125661 |
AROW | 0.11154(0.00271) | 0.11131(0.01160) | 0.0625 |
NHERD | 0.13674(0.00562) | 0.12972(0.01099) | 0.125 |
SCW1 | 0.09674(0.00181) | 0.10085(0.00655) | 1.281552, 0.0625 |
SCW2 | 0.10296(0.00193) | 0.10523(0.00967) | 0.253347, 0.125 |
letter | mistake rate in train | mistake rate in test | hyper parameters |
---|---|---|---|
pecetpron | 0.48895(0.00243) | 0.41516(0.02527) | - |
PA | 0.52939(0.00397) | 0.48201(0.03677) | - |
PA1 | 0.41159(0.00217) | 0.32184(0.01077) | 0.0625 |
PA2 | 0.51170(0.00362) | 0.45725(0.02690) | 0.0625(1) |
CW | 0.43403(0.00338) | 0.40431(0.02596) | 0.253347 |
AROW | 0.33970(0.00228) | 0.27292(0.00410) | 16 |
NHERD | 0.34401(0.00502) | 0.30952(0.00437) | 0.25 |
SCW1 | 0.31376(0.00320) | 0.24707(0.00385) | 0.841621, 1 |
SCW2 | 0.32867(0.00262) | 0.26821(0.00366) | 0.841621, 2 |
ijcnn1 | mistake rate in train | mistake rate in test | hyper parameters |
---|---|---|---|
pecetpron | 0.10590(0.00076) | 0.11283(0.03072) | - |
PA | 0.10280(0.00101) | 0.10257(0.02067) | - |
PA1 | 0.08116(0.00061) | 0.08332(0.01250) | 0.25 |
PA2 | 0.09769(0.00090) | 0.09832(0.01657) | 0.0625(1) |
CW | 0.10263(0.00074) | 0.10835(0.02168) | 0.125661 |
AROW | 0.07712(0.00051) | 0.08104(0.00061) | 8 |
NHERD | 0.08129(0.00401) | 0.08608(0.00540) | 8 |
SCW1 | 0.07363(0.00044) | 0.07164(0.00069) | 0.674490, 0.0625 |
SCW2 | 0.07347(0.00049) | 0.07615(0.00086) | 0.674490, 0.0625 |
w7a | mistake rate in train | mistake rate in test | hyper parameters |
---|---|---|---|
pecetpron | 0.11713(0.00073) | 0.11404(0.00885) | - |
PA | 0.10866(0.00066) | 0.10887(0.00365) | - |
PA1 | 0.10404(0.00043) | 0.10369(0.00071) | 0.0625 |
PA2 | 0.10791(0.00069) | 0.10786(0.00324) | 0.0625(1) |
CW | 0.10786(0.00109) | 0.11033(0.00457) | 1.644854 |
AROW | 0.10256(0.00042) | 0.10226(0.00064) | 0.0625 |
NHERD | 0.12052(0.00708) | 0.11465(0.00739) | 0.5 |
SCW1 | 0.10252(0.00030) | 0.10244(0.00017) | 1.644854, 0.0625 |
SCW2 | 0.10371(0.00062) | 0.10262(0.00137) | 1.644854, 0.0625 |
結果を見ると、SCW1, SCW2の精度はなかなか良さそうである。しかしながら、今回の実験設定ではこの二つのハイパーパラメータは他の分類器と異なり2種類あり、そのため探索にも多くの時間がかかったことは注記しておく。ちなみに、PA2のハイパーパラメータは設定ファイルから読み込んでおらず、常に初期値の1のままになっていたことに実験後、気づいた(#302)。そのため、PA2の結果は参考値としてとらえていただきたい。また、CWやSCWで用いた\etaが0.5に対応するハイパーパラメータの設定も意味がなかった気がする。
そんなわけで、ハイパーパラメータのチューニングをそれなりにやった場合にはSCWは良い結果を示すことは確認できた。しかし既に書いたように、他の分類器も同じくらいのコストをかけてチューニングを行って比較しないとフェアではない。ということで、実験設定がいけてなかったので、結論は先延ばしということにしたい。
実験に用いたスクリプトはGistに貼り付けておいた。勢いで書いて、きちんと確認していないのでいろいろと間違っているかもしれない。
- 前準備
- cross validation(ハイパーパラメータ探索)
- 評価
- 結果まとめ