NP困難な分類問題を代理損失の最小化に帰着させる話
[mathjax] 機械学習の分類問題の中心にある決定境界の決定方法について かなり要領を得た説明を聞いて理解が2段階くらい先に進んだのでまとめてみます。 データが与えられただけの状態から決定境界を決める問題はNP困難ですが 別の問題に帰着させることで解を得る、というのが基本的なアイデアです。 分類の正誤とその度合いを一度に表現できるマージンを定義し、 マージンを使って与えた代理損失を最小にする問題にします。 分類問題を代理損失の最小化に帰着させるのですね。 任意の決定境界を決める問題は線形分類であってもNP困難 2値のラベルA,B付きの2次のデータポイントが与えられたとして、 入力空間(X1-X2)におけるA,Bの分離境界(decision boundary)を求める問題が\"分類\"。 直線で分離境界を書くとして、それを求めるための最も愚直な方法は以下のようなもの。 その分離境界によりデータポイントが正しく分類出来ていれば1をカウントする。 正しく分類出来ていなければ0をカウントする。 全データポイントにおける正答率を求める。 正答率が最大になるような決定境界を求める。 そもそも分離境界は直線でなくても良いのに、あえて直線ですよ、と仮定をしたとしても、 分離境界が完全に自由で、全データに対して正答率を求めないといけない。 上記の問題の計算量は(mathcal{O}(n^3))では済まない。NP困難。 計算できるように改善 分離境界の初期値を決めて、そこから正答率が良くなる方向に少しずつずらしていこうにも、 \"正しく分類されている\"=1,\"分類されていない\" =0 は、少しの変化に影響されない。 正しい=1/正しくない=0、という損失とは別の損失を作って、 その損失を使った別の問題を解くことを、上記を問題を解くことに帰着させる。 決定境界の変化に敏感な損失を作る サンプルサイズが十分大きいとき、1.で作った損失による学習結果が、「正しく分類」「正しくない分類」という損失の学習結果と一致する margin 線形分類において、分離境界(f(x_1,x_2,cdots,x_n)=w_0+w_1x_1+w_2x_2+cdots+w_nx_n)とする。 この多項式と分離の正誤、正誤の度合いは以下のように決まる。 分類の正誤は(f(x_1,x_2,cdots,x_n))の符号が決める。 分類の正誤の度合いは(f(x_1,x_2,cdots,x_n))の絶対値が決める。 (f(x_1,x_2,cdots,x_n))が正の場合、決定境界から近い場所にあるデータポイントは もしかしたら誤って分類してしまったものかもしれない。 決定境界から遠い場所にあるデータポイントは近いものよりは正しく分類しているかもしれない。 同様に(f(x_1,x_2,cdots,x_n))が負の場合、決定境界から近い場所にあるデータポイントは もしかしたら正しい分類かもしれないし、遠いデータポイントはより近いものより間違っている可能性が高い。 この事実を1つの式で表す。 データポイントには出力ラベル(y=pm 1)が付いているものとする。 判別関数を(f(x_1,x_2,cdots,x_n))とする。決定境界は(f(x_1,x_2,cdots,x_n)=0) begin{eqnarray} m = yf(x_1,x_2,cdots,x_n) end{eqnarray} ラベル1を-1と分類した場合(f(x_1,x_2,cdots,x_n)<0)。 同様に-1を1と分類した場合も(f(x_1,x_2,cdots,x_n)<0)。 つまり、誤分類したときにラベルと判別関数の符号が異なり(m0)となる。 ということで、(m)をマージン(margin)と呼ぶ。 サポートベクトル marginが最大になるように各データポイントの中にある決定境界を決めていく。 全てのデータポイントについて距離を計算する必要はなく、決定境界と距離が一番近いデータポイントとの 距離を最大化すれば良いらしい。(それが一番近いかどうかはいずれにせよ距離を求める必要がありそうだけど..) marginが最大になるように決めた決定境界と距離が最も近いデータポイントをサポートベクトルと言うらしい。 マージンを使った損失 最初に戻ると、決定境界の変化に敏感な損失を作ることが目的だった。 マージンが正の方向に大きいほど正しい分類であると言えるし、 マージンが負の方向に大きいほど誤った分類であると言えるけれども、 正しい度合いが高ければ小、誤りの度合いが高ければ大、となる損失を考えることで、 誤った方向に決定境界を修正すれば敏感に値が上昇する損失にすることができる。 (正しい方向に移動しても変わらない。) 横軸にマージン、縦軸に損失を取ったとして、以下のような損失(h(m))を考える。 もちろん、(m = yf(x_1,x_2,cdots,x_n))。 (m=1)より大きいマージンについては損失が0。(m=1)より小さいマージンについて線形に増加する。 (m=1)を境にヒンジの形をしているのでhinge損失という名前が付いてる。 begin{eqnarray} h(m) = max(0,1-m) end{eqnarray}