Jubatusでオンラインランク学習
もう二か月以上前のことだが、とあるところで、Jubatusでオンライン分類ができるならペアワイズのランク学習もできそうだという話をした。いろいろあって時間がかかってしまったが、実装と簡単な評価をしたのでまとめておく。
以下の評価で用いた実装は"y-tag/jubatus"のrankingブランチにある。現時点では、このブランチは0.4.3のリリース直後をベースとしている。
今回の実装の参考文献として、2009年のNIPSで発表された以下の論文が挙げられる。
また、以下の資料も非常に参考になる。
データセットとして、LETORのOHSUMED, MQ2007, MQ2008を選択し、OHSUMEDはQueryLevelNormを、MQ2007とMQ2008はSupervised rankingのものを用いた。標準的な評価と同じように、既に5分割されているデータを用いて評価を行った。
評価では、今回実装したrankingと、比較のためにクライアント側でペアを作成してclassifierサーバを用いて学習するものを対象とした。classifierを用いる学習法は2種類行った。一つ目はランダムにクエリを選択し、そのクエリに紐付いたペアをすべて一度に学習するもの。二つ目はランダムにクエリを選択するのは同じだが、更にそのクエリに紐付いたペアの中から一つのみ選択することを繰り返すものである。これ以降では、一つ目をclassifierQueryと呼び、二つめの選択方法は先述した論文でStochastic Pairwise Descent(SPD)という名前がつけられていたため、classifierSPDと呼ぶ。今回実装したrankingとclassifierQueryはペアの差分をとる場所がサーバかクライアントかの違いがあるが、それ以外はほぼ同じものだと思って構わない。三つの手法すべてが、サーバ側ではJubatusのオンライン分類器を用いている。今回は分類器としてPA, CW, AROW, NHERDを選択し、AROW, NHERDのハイパーパラメータ(設定ファイルのregularization_weight)は1、CWでは1.036433を用いた。設定ファイルは以前用いたgenerate_conf.plで生成したものを用いた。事前の実験で、今回実装したrankingでは、データを1パスで学習しただけでは収束していないようだったので、クエリ数の10倍の回数、ランダムにクエリを選択して学習した。一方でclassifierQueryとclassifierSPDは、元論文と同じように10,000のペアを選択するように学習を行った(この数は全体の有効なペア数と比較して少ない)。そのため、先述したようにrankingとclassifierQueryはロジックとしてはほぼ同じものなので、学習するサンプル数の違いによる影響が分かる。その他の詳細な実験設定等は、実験スクリプトを参照していただきたい。
なお、実験中、classifierを用いてtrainをクライアント側から呼び出す際に、例外が投げられることが多々あった。とりあえず何度か試すと問題なくなるようだったので、外部のスクリプトでリトライをかけることで対処した。そのため、classifierのプログラムはどこかおかしいようであることを注記しておく。
評価スクリプトであるEval-Score-3.0.plとEval-Score-4.0.plは、「downhill simplexでNDCG最適化」と「LETOR4.0 Downloads」を元に、それぞれ該当行を以下のように修正して用いた。
193c193 < if ($lnFea =~ m/^(\d+) qid\:([^\s]+).*?\#docid = ([^\s]+)$/) --- > if ($lnFea =~ m/^(\d+) qid\:([^\s]+).*?\#docid = ([^\s]+)/)
216c216 < if ($lnFea =~ m/^(\d+) qid\:([^\s]+).*?\#docid = ([^\s]+) inc = ([^\s]+) prob = ([^\s]+)$/) --- > if ($lnFea =~ m/^(\d+) qid\:([^\s]+).*?\#docid = ([^\s]+) inc = ([^\s]+) prob = ([^\s]+).$/)
結果は以下。比較のためにLETORのBaselineから、OHSUMEDではRankSVMを、MQ2007とMQ2008ではRankSVM-Structを選択した(Baselines on LETOR3.0, LETOR4.0 Baselines)。
OHSUMED NDCG@1 NDCG@2 NDCG@3 NDCG@4 NDCG@5 NDCG@6 NDCG@7 NDCG@8 NDCG@9 NDCG@10 rankingPA 0.3045 0.2940 0.2709 0.2680 0.2585 0.2521 0.2500 0.2451 0.2441 0.2471 rankingCW 0.5326 0.5040 0.4956 0.4811 0.4740 0.4637 0.4569 0.4540 0.4526 0.4499 rankingAROW 0.5670 0.5193 0.4975 0.4819 0.4630 0.4590 0.4524 0.4503 0.4417 0.4392 rankingNHERD 0.4727 0.4725 0.4581 0.4400 0.4227 0.4229 0.4211 0.4206 0.4180 0.4131 classifierQueryPA 0.1571 0.1470 0.1573 0.1489 0.1496 0.1542 0.1578 0.1583 0.1573 0.1608 classifierQueryCW 0.4377 0.4032 0.3980 0.4052 0.3979 0.3943 0.3900 0.3893 0.3866 0.3878 classifierQueryAROW 0.3227 0.2872 0.2954 0.2989 0.2895 0.2886 0.2849 0.2844 0.2858 0.2880 classifierQueryNHERD 0.3782 0.3607 0.3572 0.3497 0.3420 0.3320 0.3343 0.3334 0.3331 0.3291 classifierSPDPA 0.4431 0.4087 0.3886 0.3865 0.3824 0.3782 0.3737 0.3759 0.3708 0.3735 classifierSPDCW 0.5046 0.4689 0.4414 0.4454 0.4420 0.4445 0.4411 0.4365 0.4322 0.4303 classifierSPDAROW 0.5584 0.5147 0.5045 0.4860 0.4749 0.4622 0.4606 0.4572 0.4530 0.4495 classifierSPDNHERD 0.4799 0.4510 0.4200 0.4123 0.4033 0.3992 0.4019 0.4000 0.3980 0.3961 RankSVM 0.4958 0.4331 0.4207 0.424 0.4164 0.4159 0.4133 0.4072 0.4124 0.414
MQ2007 NDCG@1 NDCG@2 NDCG@3 NDCG@4 NDCG@5 NDCG@6 NDCG@7 NDCG@8 NDCG@9 NDCG@10 rankingPA 0.2644 0.2754 0.2814 0.2862 0.2927 0.2988 0.3067 0.3120 0.3182 0.3236 rankingCW 0.3889 0.3930 0.3959 0.4026 0.4086 0.4153 0.4192 0.4229 0.4289 0.4345 rankingAROW 0.4073 0.4053 0.4041 0.4082 0.4145 0.4206 0.4234 0.4295 0.4358 0.4424 rankingNHERD 0.3358 0.3401 0.3474 0.3549 0.3588 0.3644 0.3695 0.3747 0.3802 0.3857 classifierQueryPA 0.2694 0.2813 0.2912 0.2985 0.3041 0.3088 0.3144 0.3195 0.3248 0.3297 classifierQueryCW 0.3795 0.3849 0.3883 0.3958 0.3979 0.4032 0.4085 0.4134 0.4180 0.4230 classifierQueryAROW 0.3873 0.3857 0.3897 0.3939 0.3994 0.4059 0.4112 0.4163 0.4225 0.4287 classifierQueryNHERD 0.3280 0.3404 0.3461 0.3521 0.3564 0.3621 0.3665 0.3734 0.3808 0.3876 classifierSPDPA 0.2641 0.2776 0.2872 0.2945 0.3033 0.3103 0.3176 0.3244 0.3301 0.3360 classifierSPDCW 0.3970 0.4041 0.4060 0.4104 0.4148 0.4188 0.4232 0.4284 0.4338 0.4392 classifierSPDAROW 0.3976 0.4050 0.4070 0.4082 0.4120 0.4187 0.4230 0.4300 0.4359 0.4419 classifierSPDNHERD 0.3297 0.3372 0.3416 0.3495 0.3546 0.3603 0.3656 0.3718 0.3781 0.3853 RankSVM-Struct 0.4096 0.4074 0.4063 0.4084 0.4143 0.4195 0.4252 0.4306 0.4362 0.4439
MQ2008 NDCG@1 NDCG@2 NDCG@3 NDCG@4 NDCG@5 NDCG@6 NDCG@7 NDCG@8 NDCG@9 NDCG@10 rankingPA 0.2970 0.3419 0.3574 0.3700 0.3833 0.4060 0.4212 0.3646 0.1605 0.1639 rankingCW 0.3745 0.3995 0.4240 0.4510 0.4706 0.4839 0.4888 0.4560 0.2212 0.2260 rankingAROW 0.3674 0.4015 0.4222 0.4488 0.4686 0.4818 0.4882 0.4546 0.2226 0.2277 rankingNHERD 0.3478 0.3890 0.4132 0.4353 0.4570 0.4742 0.4814 0.4476 0.2119 0.2167 classifierQueryPA 0.2405 0.2605 0.2831 0.3107 0.3339 0.3529 0.3678 0.3465 0.1492 0.1540 classifierQueryCW 0.3767 0.3933 0.4210 0.4469 0.4674 0.4792 0.4877 0.4545 0.2211 0.2258 classifierQueryAROW 0.3767 0.3998 0.4266 0.4478 0.4694 0.4802 0.4872 0.4553 0.2237 0.2281 classifierQueryNHERD 0.3278 0.3759 0.4017 0.4320 0.4496 0.4636 0.4723 0.4390 0.2140 0.2178 classifierSPDPA 0.2818 0.3176 0.3453 0.3772 0.3976 0.4181 0.4294 0.4008 0.1863 0.1900 classifierSPDCW 0.3656 0.4069 0.4266 0.4499 0.4716 0.4833 0.4937 0.4586 0.2235 0.2275 classifierSPDAROW 0.3694 0.4027 0.4289 0.4526 0.4742 0.4856 0.4925 0.4584 0.2241 0.2294 classifierSPDNHERD 0.2972 0.3403 0.3705 0.3989 0.4236 0.4429 0.4517 0.4189 0.1983 0.2026 RankSVM-Struct 0.3627 0.3985 0.4286 0.4509 0.4695 0.4851 0.4905 0.4564 0.2239 0.2279
OHSUMED P@1 P@2 P@3 P@4 P@5 P@6 P@7 P@8 P@9 P@10 rankingPA 0.4242 0.4195 0.3711 0.3702 0.3473 0.3367 0.3292 0.3188 0.3169 0.3164 rankingCW 0.6515 0.6327 0.6170 0.5950 0.5762 0.5543 0.5373 0.5268 0.5188 0.5086 rankingAROW 0.6801 0.6513 0.6167 0.5923 0.5590 0.5400 0.5264 0.5186 0.4999 0.4905 rankingNHERD 0.6043 0.5950 0.5854 0.5645 0.5273 0.5198 0.5157 0.5045 0.4928 0.4728 classifierQueryPA 0.2645 0.2597 0.2772 0.2526 0.2437 0.2424 0.2427 0.2383 0.2378 0.2423 classifierQueryCW 0.5571 0.5197 0.5196 0.5431 0.5215 0.5055 0.4952 0.4854 0.4755 0.4696 classifierQueryAROW 0.4606 0.4095 0.4335 0.4264 0.4000 0.3885 0.3735 0.3682 0.3693 0.3655 classifierQueryNHERD 0.4922 0.4872 0.4889 0.4752 0.4541 0.4352 0.4312 0.4293 0.4257 0.4164 classifierSPDPA 0.5381 0.5379 0.5065 0.5053 0.4985 0.4817 0.4698 0.4641 0.4484 0.4528 classifierSPDCW 0.6056 0.5909 0.5577 0.5670 0.5519 0.5449 0.5319 0.5221 0.5115 0.4983 classifierSPDAROW 0.6528 0.6468 0.6266 0.5976 0.5820 0.5481 0.5307 0.5248 0.5095 0.4993 classifierSPDNHERD 0.6061 0.5870 0.5304 0.5115 0.4963 0.4892 0.4925 0.4889 0.4830 0.4783 RankSVM 0.5974 0.5494 0.5427 0.5443 0.5319 0.5253 0.5097 0.4933 0.492 0.4864
MQ2007 P@1 P@2 P@3 P@4 P@5 P@6 P@7 P@8 P@9 P@10 rankingPA 0.3262 0.3212 0.3191 0.3154 0.3149 0.3123 0.3128 0.3094 0.3082 0.3053 rankingCW 0.4528 0.4386 0.4254 0.4189 0.4100 0.4027 0.3932 0.3838 0.3787 0.3739 rankingAROW 0.4735 0.4489 0.4323 0.4220 0.4152 0.4089 0.3993 0.3940 0.3895 0.3846 rankingNHERD 0.3960 0.3854 0.3789 0.3746 0.3655 0.3603 0.3560 0.3512 0.3474 0.3438 classifierQueryPA 0.3292 0.3287 0.3322 0.3317 0.3301 0.3250 0.3226 0.3193 0.3158 0.3132 classifierQueryCW 0.4433 0.4288 0.4141 0.4087 0.3974 0.3894 0.3834 0.3776 0.3707 0.3653 classifierQueryAROW 0.4480 0.4256 0.4155 0.4075 0.4017 0.3945 0.3892 0.3834 0.3784 0.3738 classifierQueryNHERD 0.3931 0.3868 0.3812 0.3750 0.3689 0.3640 0.3581 0.3555 0.3527 0.3494 classifierSPDPA 0.3216 0.3196 0.3220 0.3200 0.3222 0.3231 0.3232 0.3219 0.3204 0.3181 classifierSPDCW 0.4628 0.4475 0.4297 0.4186 0.4114 0.4017 0.3929 0.3878 0.3814 0.3754 classifierSPDAROW 0.4628 0.4475 0.4297 0.4186 0.4114 0.4017 0.3929 0.3878 0.3814 0.3754 classifierSPDNHERD 0.3919 0.3845 0.3747 0.3712 0.3641 0.3591 0.3562 0.3527 0.3489 0.3457 RankSVM-Struct 0.4746 0.4496 0.4315 0.4194 0.4135 0.4048 0.3994 0.3931 0.3868 0.3833
MQ2008 P@1 P@2 P@3 P@4 P@5 P@6 P@7 P@8 P@9 P@10 rankingPA 0.3202 0.3023 0.2926 0.2813 0.2732 0.2666 0.2577 0.2435 0.2296 0.2171 rankingCW 0.4426 0.4069 0.3839 0.3670 0.3449 0.3225 0.2981 0.2792 0.2615 0.2462 rankingAROW 0.4311 0.4120 0.3852 0.3641 0.3411 0.3201 0.2994 0.2793 0.2625 0.2476 rankingNHERD 0.4183 0.3967 0.3711 0.3501 0.3337 0.3154 0.2961 0.2776 0.2571 0.2422 classifierQueryPA 0.2983 0.2849 0.2775 0.2728 0.2632 0.2553 0.2483 0.2380 0.2225 0.2095 classifierQueryCW 0.4464 0.4030 0.3865 0.3654 0.3441 0.3199 0.3008 0.2812 0.2612 0.2465 classifierQueryAROW 0.4438 0.4100 0.3878 0.3642 0.3431 0.3191 0.2988 0.2806 0.2634 0.2480 classifierQueryNHERD 0.4043 0.3864 0.3673 0.3527 0.3339 0.3127 0.2943 0.2739 0.2568 0.2409 classifierSPDPA 0.3481 0.3386 0.3218 0.3179 0.3015 0.2893 0.2748 0.2595 0.2433 0.2284 classifierSPDCW 0.4362 0.4094 0.3818 0.3600 0.3424 0.3195 0.3001 0.2796 0.2620 0.2464 classifierSPDAROW 0.4400 0.4145 0.3869 0.3670 0.3456 0.3223 0.3007 0.2809 0.2620 0.2479 classifierSPDNHERD 0.3736 0.3616 0.3482 0.3361 0.3217 0.3078 0.2890 0.2704 0.2540 0.2393 RankSVM-Struct 0.4273 0.4069 0.3903 0.3696 0.3474 0.3265 0.3021 0.2822 0.2647 0.2491
MAP OHSUMED MQ2007 MQ2008 rankingPA 0.3193 0.3710 0.3812 rankingCW 0.4470 0.4569 0.4737 rankingAROW 0.4433 0.4642 0.4687 rankingNHERD 0.4270 0.4169 0.4625 classifierQueryPA 0.2827 0.3773 0.3612 classifierQueryCW 0.4159 0.4479 0.4723 classifierQueryAROW 0.3556 0.4536 0.4715 classifierQueryNHERD 0.3897 0.4187 0.4557 classifierSPDPA 0.4040 0.3781 0.4110 classifierSPDCW 0.4455 0.4611 0.4726 classifierSPDAROW 0.4443 0.4620 0.4770 classifierSPDNHERD 0.4176 0.4212 0.4332 RankSVM/RankSVM-Struct 0.4334 0.4645 0.4696
まず分類器の違いに着目すると、不思議なことに、三つの方法すべてでPAを用いた場合の結果が著しく悪かった。そのため、グラフからはその三つは省いている。評価指標やデータセットによらず、全体的にAROWを用いた場合に良い結果が得られた。AROWと比較するとCWやNHERDを用いた場合はやや劣る結果となった。
三つの手法を比較すると、rankingやclassifierSPDに比べて、classifireQueryがやや劣る結果となった。先述したように、rankingとclassifierQueryはほぼ同じロジックであり、classifierQueryは学習の回数が少ないため、その違いが結果に反映されたのだと考えられる。一方でclassifierSPDはclassifierQueryと同じ学習の回数で良い結果を得られたため、クエリを選択した後にそれに紐付いたペアを選択するSPDの手法の方が収束速度が優れていると言える。一方でrankingの手法でも学習回数を増やせばパフォーマンスが出るようである。
RankSVMやRankSVM-Structと比較すると、OHSUMEDではRankSVMを越える結果も得られ、他の二つのデータセットでもRankSVM-Structに近い結果が得られた。
詳細な結果は以下にある。
今回も実験で使ったスクリプトをGistに貼り付けておいた。
- 前準備
- 評価(classifier1がclassifierQueryに、classifier2がclassfierSPDに対応)
- 結果まとめ