Jubatusでオンラインランク学習

もう二か月以上前のことだが、とあるところで、Jubatusでオンライン分類ができるならペアワイズのランク学習もできそうだという話をした。いろいろあって時間がかかってしまったが、実装と簡単な評価をしたのでまとめておく。

以下の評価で用いた実装は"y-tag/jubatus"のrankingブランチにある。現時点では、このブランチは0.4.3のリリース直後をベースとしている。

今回の実装の参考文献として、2009年のNIPSで発表された以下の論文が挙げられる。

また、以下の資料も非常に参考になる。

データセットとして、LETORのOHSUMED, MQ2007, MQ2008を選択し、OHSUMEDはQueryLevelNormを、MQ2007とMQ2008はSupervised rankingのものを用いた。標準的な評価と同じように、既に5分割されているデータを用いて評価を行った。

評価では、今回実装したrankingと、比較のためにクライアント側でペアを作成してclassifierサーバを用いて学習するものを対象とした。classifierを用いる学習法は2種類行った。一つ目はランダムにクエリを選択し、そのクエリに紐付いたペアをすべて一度に学習するもの。二つ目はランダムにクエリを選択するのは同じだが、更にそのクエリに紐付いたペアの中から一つのみ選択することを繰り返すものである。これ以降では、一つ目をclassifierQueryと呼び、二つめの選択方法は先述した論文でStochastic Pairwise Descent(SPD)という名前がつけられていたため、classifierSPDと呼ぶ。今回実装したrankingとclassifierQueryはペアの差分をとる場所がサーバかクライアントかの違いがあるが、それ以外はほぼ同じものだと思って構わない。三つの手法すべてが、サーバ側ではJubatusのオンライン分類器を用いている。今回は分類器としてPA, CW, AROW, NHERDを選択し、AROW, NHERDのハイパーパラメータ(設定ファイルのregularization_weight)は1、CWでは1.036433を用いた。設定ファイルは以前用いたgenerate_conf.plで生成したものを用いた。事前の実験で、今回実装したrankingでは、データを1パスで学習しただけでは収束していないようだったので、クエリ数の10倍の回数、ランダムにクエリを選択して学習した。一方でclassifierQueryとclassifierSPDは、元論文と同じように10,000のペアを選択するように学習を行った(この数は全体の有効なペア数と比較して少ない)。そのため、先述したようにrankingとclassifierQueryはロジックとしてはほぼ同じものなので、学習するサンプル数の違いによる影響が分かる。その他の詳細な実験設定等は、実験スクリプトを参照していただきたい。

なお、実験中、classifierを用いてtrainをクライアント側から呼び出す際に、例外が投げられることが多々あった。とりあえず何度か試すと問題なくなるようだったので、外部のスクリプトでリトライをかけることで対処した。そのため、classifierのプログラムはどこかおかしいようであることを注記しておく。

評価スクリプトであるEval-Score-3.0.plとEval-Score-4.0.plは、「downhill simplexでNDCG最適化」と「LETOR4.0 Downloads」を元に、それぞれ該当行を以下のように修正して用いた。

193c193
<         if ($lnFea =~ m/^(\d+) qid\:([^\s]+).*?\#docid = ([^\s]+)$/)
---
>         if ($lnFea =~ m/^(\d+) qid\:([^\s]+).*?\#docid = ([^\s]+)/)
216c216
<         if ($lnFea =~ m/^(\d+) qid\:([^\s]+).*?\#docid = ([^\s]+) inc = ([^\s]+) prob = ([^\s]+)$/)
---
>         if ($lnFea =~ m/^(\d+) qid\:([^\s]+).*?\#docid = ([^\s]+) inc = ([^\s]+) prob = ([^\s]+).$/)

結果は以下。比較のためにLETORのBaselineから、OHSUMEDではRankSVMを、MQ2007とMQ2008ではRankSVM-Structを選択した(Baselines on LETOR3.0, LETOR4.0 Baselines)。

OHSUMED               NDCG@1 NDCG@2 NDCG@3 NDCG@4 NDCG@5 NDCG@6 NDCG@7 NDCG@8 NDCG@9 NDCG@10
rankingPA             0.3045 0.2940 0.2709 0.2680 0.2585 0.2521 0.2500 0.2451 0.2441 0.2471
rankingCW             0.5326 0.5040 0.4956 0.4811 0.4740 0.4637 0.4569 0.4540 0.4526 0.4499
rankingAROW           0.5670 0.5193 0.4975 0.4819 0.4630 0.4590 0.4524 0.4503 0.4417 0.4392
rankingNHERD          0.4727 0.4725 0.4581 0.4400 0.4227 0.4229 0.4211 0.4206 0.4180 0.4131
classifierQueryPA     0.1571 0.1470 0.1573 0.1489 0.1496 0.1542 0.1578 0.1583 0.1573 0.1608
classifierQueryCW     0.4377 0.4032 0.3980 0.4052 0.3979 0.3943 0.3900 0.3893 0.3866 0.3878
classifierQueryAROW   0.3227 0.2872 0.2954 0.2989 0.2895 0.2886 0.2849 0.2844 0.2858 0.2880
classifierQueryNHERD  0.3782 0.3607 0.3572 0.3497 0.3420 0.3320 0.3343 0.3334 0.3331 0.3291
classifierSPDPA       0.4431 0.4087 0.3886 0.3865 0.3824 0.3782 0.3737 0.3759 0.3708 0.3735
classifierSPDCW       0.5046 0.4689 0.4414 0.4454 0.4420 0.4445 0.4411 0.4365 0.4322 0.4303
classifierSPDAROW     0.5584 0.5147 0.5045 0.4860 0.4749 0.4622 0.4606 0.4572 0.4530 0.4495
classifierSPDNHERD    0.4799 0.4510 0.4200 0.4123 0.4033 0.3992 0.4019 0.4000 0.3980 0.3961
RankSVM               0.4958 0.4331 0.4207 0.424  0.4164 0.4159 0.4133 0.4072 0.4124 0.414

MQ2007                NDCG@1 NDCG@2 NDCG@3 NDCG@4 NDCG@5 NDCG@6 NDCG@7 NDCG@8 NDCG@9 NDCG@10
rankingPA             0.2644 0.2754 0.2814 0.2862 0.2927 0.2988 0.3067 0.3120 0.3182 0.3236
rankingCW             0.3889 0.3930 0.3959 0.4026 0.4086 0.4153 0.4192 0.4229 0.4289 0.4345
rankingAROW           0.4073 0.4053 0.4041 0.4082 0.4145 0.4206 0.4234 0.4295 0.4358 0.4424
rankingNHERD          0.3358 0.3401 0.3474 0.3549 0.3588 0.3644 0.3695 0.3747 0.3802 0.3857
classifierQueryPA     0.2694 0.2813 0.2912 0.2985 0.3041 0.3088 0.3144 0.3195 0.3248 0.3297
classifierQueryCW     0.3795 0.3849 0.3883 0.3958 0.3979 0.4032 0.4085 0.4134 0.4180 0.4230
classifierQueryAROW   0.3873 0.3857 0.3897 0.3939 0.3994 0.4059 0.4112 0.4163 0.4225 0.4287
classifierQueryNHERD  0.3280 0.3404 0.3461 0.3521 0.3564 0.3621 0.3665 0.3734 0.3808 0.3876
classifierSPDPA       0.2641 0.2776 0.2872 0.2945 0.3033 0.3103 0.3176 0.3244 0.3301 0.3360
classifierSPDCW       0.3970 0.4041 0.4060 0.4104 0.4148 0.4188 0.4232 0.4284 0.4338 0.4392
classifierSPDAROW     0.3976 0.4050 0.4070 0.4082 0.4120 0.4187 0.4230 0.4300 0.4359 0.4419
classifierSPDNHERD    0.3297 0.3372 0.3416 0.3495 0.3546 0.3603 0.3656 0.3718 0.3781 0.3853
RankSVM-Struct        0.4096 0.4074 0.4063 0.4084 0.4143 0.4195 0.4252 0.4306 0.4362 0.4439

MQ2008                NDCG@1 NDCG@2 NDCG@3 NDCG@4 NDCG@5 NDCG@6 NDCG@7 NDCG@8 NDCG@9 NDCG@10
rankingPA             0.2970 0.3419 0.3574 0.3700 0.3833 0.4060 0.4212 0.3646 0.1605 0.1639
rankingCW             0.3745 0.3995 0.4240 0.4510 0.4706 0.4839 0.4888 0.4560 0.2212 0.2260
rankingAROW           0.3674 0.4015 0.4222 0.4488 0.4686 0.4818 0.4882 0.4546 0.2226 0.2277
rankingNHERD          0.3478 0.3890 0.4132 0.4353 0.4570 0.4742 0.4814 0.4476 0.2119 0.2167
classifierQueryPA     0.2405 0.2605 0.2831 0.3107 0.3339 0.3529 0.3678 0.3465 0.1492 0.1540
classifierQueryCW     0.3767 0.3933 0.4210 0.4469 0.4674 0.4792 0.4877 0.4545 0.2211 0.2258
classifierQueryAROW   0.3767 0.3998 0.4266 0.4478 0.4694 0.4802 0.4872 0.4553 0.2237 0.2281
classifierQueryNHERD  0.3278 0.3759 0.4017 0.4320 0.4496 0.4636 0.4723 0.4390 0.2140 0.2178
classifierSPDPA       0.2818 0.3176 0.3453 0.3772 0.3976 0.4181 0.4294 0.4008 0.1863 0.1900
classifierSPDCW       0.3656 0.4069 0.4266 0.4499 0.4716 0.4833 0.4937 0.4586 0.2235 0.2275
classifierSPDAROW     0.3694 0.4027 0.4289 0.4526 0.4742 0.4856 0.4925 0.4584 0.2241 0.2294
classifierSPDNHERD    0.2972 0.3403 0.3705 0.3989 0.4236 0.4429 0.4517 0.4189 0.1983 0.2026
RankSVM-Struct        0.3627 0.3985 0.4286 0.4509 0.4695 0.4851 0.4905 0.4564 0.2239 0.2279

OHSUMED               P@1    P@2    P@3    P@4    P@5    P@6    P@7    P@8    P@9    P@10
rankingPA             0.4242 0.4195 0.3711 0.3702 0.3473 0.3367 0.3292 0.3188 0.3169 0.3164
rankingCW             0.6515 0.6327 0.6170 0.5950 0.5762 0.5543 0.5373 0.5268 0.5188 0.5086
rankingAROW           0.6801 0.6513 0.6167 0.5923 0.5590 0.5400 0.5264 0.5186 0.4999 0.4905
rankingNHERD          0.6043 0.5950 0.5854 0.5645 0.5273 0.5198 0.5157 0.5045 0.4928 0.4728
classifierQueryPA     0.2645 0.2597 0.2772 0.2526 0.2437 0.2424 0.2427 0.2383 0.2378 0.2423
classifierQueryCW     0.5571 0.5197 0.5196 0.5431 0.5215 0.5055 0.4952 0.4854 0.4755 0.4696
classifierQueryAROW   0.4606 0.4095 0.4335 0.4264 0.4000 0.3885 0.3735 0.3682 0.3693 0.3655
classifierQueryNHERD  0.4922 0.4872 0.4889 0.4752 0.4541 0.4352 0.4312 0.4293 0.4257 0.4164
classifierSPDPA       0.5381 0.5379 0.5065 0.5053 0.4985 0.4817 0.4698 0.4641 0.4484 0.4528
classifierSPDCW       0.6056 0.5909 0.5577 0.5670 0.5519 0.5449 0.5319 0.5221 0.5115 0.4983
classifierSPDAROW     0.6528 0.6468 0.6266 0.5976 0.5820 0.5481 0.5307 0.5248 0.5095 0.4993
classifierSPDNHERD    0.6061 0.5870 0.5304 0.5115 0.4963 0.4892 0.4925 0.4889 0.4830 0.4783
RankSVM               0.5974 0.5494 0.5427 0.5443 0.5319 0.5253 0.5097 0.4933 0.492  0.4864

MQ2007                P@1    P@2    P@3    P@4    P@5    P@6    P@7    P@8    P@9    P@10
rankingPA             0.3262 0.3212 0.3191 0.3154 0.3149 0.3123 0.3128 0.3094 0.3082 0.3053
rankingCW             0.4528 0.4386 0.4254 0.4189 0.4100 0.4027 0.3932 0.3838 0.3787 0.3739
rankingAROW           0.4735 0.4489 0.4323 0.4220 0.4152 0.4089 0.3993 0.3940 0.3895 0.3846
rankingNHERD          0.3960 0.3854 0.3789 0.3746 0.3655 0.3603 0.3560 0.3512 0.3474 0.3438
classifierQueryPA     0.3292 0.3287 0.3322 0.3317 0.3301 0.3250 0.3226 0.3193 0.3158 0.3132
classifierQueryCW     0.4433 0.4288 0.4141 0.4087 0.3974 0.3894 0.3834 0.3776 0.3707 0.3653
classifierQueryAROW   0.4480 0.4256 0.4155 0.4075 0.4017 0.3945 0.3892 0.3834 0.3784 0.3738
classifierQueryNHERD  0.3931 0.3868 0.3812 0.3750 0.3689 0.3640 0.3581 0.3555 0.3527 0.3494
classifierSPDPA       0.3216 0.3196 0.3220 0.3200 0.3222 0.3231 0.3232 0.3219 0.3204 0.3181
classifierSPDCW       0.4628 0.4475 0.4297 0.4186 0.4114 0.4017 0.3929 0.3878 0.3814 0.3754
classifierSPDAROW     0.4628 0.4475 0.4297 0.4186 0.4114 0.4017 0.3929 0.3878 0.3814 0.3754
classifierSPDNHERD    0.3919 0.3845 0.3747 0.3712 0.3641 0.3591 0.3562 0.3527 0.3489 0.3457
RankSVM-Struct        0.4746 0.4496 0.4315 0.4194 0.4135 0.4048 0.3994 0.3931 0.3868 0.3833

MQ2008                P@1    P@2    P@3    P@4    P@5    P@6    P@7    P@8    P@9    P@10
rankingPA             0.3202 0.3023 0.2926 0.2813 0.2732 0.2666 0.2577 0.2435 0.2296 0.2171
rankingCW             0.4426 0.4069 0.3839 0.3670 0.3449 0.3225 0.2981 0.2792 0.2615 0.2462
rankingAROW           0.4311 0.4120 0.3852 0.3641 0.3411 0.3201 0.2994 0.2793 0.2625 0.2476
rankingNHERD          0.4183 0.3967 0.3711 0.3501 0.3337 0.3154 0.2961 0.2776 0.2571 0.2422
classifierQueryPA     0.2983 0.2849 0.2775 0.2728 0.2632 0.2553 0.2483 0.2380 0.2225 0.2095
classifierQueryCW     0.4464 0.4030 0.3865 0.3654 0.3441 0.3199 0.3008 0.2812 0.2612 0.2465
classifierQueryAROW   0.4438 0.4100 0.3878 0.3642 0.3431 0.3191 0.2988 0.2806 0.2634 0.2480
classifierQueryNHERD  0.4043 0.3864 0.3673 0.3527 0.3339 0.3127 0.2943 0.2739 0.2568 0.2409
classifierSPDPA       0.3481 0.3386 0.3218 0.3179 0.3015 0.2893 0.2748 0.2595 0.2433 0.2284
classifierSPDCW       0.4362 0.4094 0.3818 0.3600 0.3424 0.3195 0.3001 0.2796 0.2620 0.2464
classifierSPDAROW     0.4400 0.4145 0.3869 0.3670 0.3456 0.3223 0.3007 0.2809 0.2620 0.2479
classifierSPDNHERD    0.3736 0.3616 0.3482 0.3361 0.3217 0.3078 0.2890 0.2704 0.2540 0.2393
RankSVM-Struct        0.4273 0.4069 0.3903 0.3696 0.3474 0.3265 0.3021 0.2822 0.2647 0.2491

MAP                     OHSUMED MQ2007  MQ2008
rankingPA               0.3193  0.3710  0.3812
rankingCW               0.4470  0.4569  0.4737
rankingAROW             0.4433  0.4642  0.4687
rankingNHERD            0.4270  0.4169  0.4625
classifierQueryPA       0.2827  0.3773  0.3612
classifierQueryCW       0.4159  0.4479  0.4723
classifierQueryAROW     0.3556  0.4536  0.4715
classifierQueryNHERD    0.3897  0.4187  0.4557
classifierSPDPA         0.4040  0.3781  0.4110
classifierSPDCW         0.4455  0.4611  0.4726
classifierSPDAROW       0.4443  0.4620  0.4770
classifierSPDNHERD      0.4176  0.4212  0.4332
RankSVM/RankSVM-Struct  0.4334  0.4645  0.4696

まず分類器の違いに着目すると、不思議なことに、三つの方法すべてでPAを用いた場合の結果が著しく悪かった。そのため、グラフからはその三つは省いている。評価指標やデータセットによらず、全体的にAROWを用いた場合に良い結果が得られた。AROWと比較するとCWやNHERDを用いた場合はやや劣る結果となった。

三つの手法を比較すると、rankingやclassifierSPDに比べて、classifireQueryがやや劣る結果となった。先述したように、rankingとclassifierQueryはほぼ同じロジックであり、classifierQueryは学習の回数が少ないため、その違いが結果に反映されたのだと考えられる。一方でclassifierSPDはclassifierQueryと同じ学習の回数で良い結果を得られたため、クエリを選択した後にそれに紐付いたペアを選択するSPDの手法の方が収束速度が優れていると言える。一方でrankingの手法でも学習回数を増やせばパフォーマンスが出るようである。

RankSVMやRankSVM-Structと比較すると、OHSUMEDではRankSVMを越える結果も得られ、他の二つのデータセットでもRankSVM-Structに近い結果が得られた。

詳細な結果は以下にある。

今回も実験で使ったスクリプトをGistに貼り付けておいた。