2クラス分類とベルヌーイ分布【tensorflow】
CPPXのXです。
クラス分類を行う時の誤差関数に交差エントロピー誤差を使用すると思います。
2クラス分類の場合も同様だと思うのですが、その時にベルヌーイ分布が出てきて頭がはてなになったので、何でベルヌーイ分布が出てくるのかまとめておこうと思います。
説明にはtensorflowを使用しています。
確率に関してあまり詳しいことは知らないので、奥まった事は書いていません。
では目次です。
環境
- python 3.6.5
- tensorflow 1.10.0
imports
import tensorflow as tf from tensorflow import distributions as tfd import numpy as np sess = tf.InteractiveSession()
sigmoid_cross_entropy
tensorflowで2クラス分類をシンプルにやるときはtf.losses.sigmoid_cross_entropy
を使って誤差関数を定義すると思います。
label = np.array([0, 1, 0, 1, 1]) # 適当に logits = np.array([-100, 100, -1, 1, 0]) # これも適当に loss = tf.losses.sigmoid_cross_entropy(label, logits) print(loss.eval()) # 誤差値を見てみる # ==> 0.2639341
logitsに対してsigmoidを通した後に交差エントロピーを計算してくれます。
logitsは多くの場合、最後の全結合層の出力ですかね。
交差エントロピー
交差エントロピーの計算式は
\begin{equation} -\sum_{i=0}^n p_i \log q_i \end{equation}
で求まります。
2クラス分類なので、0になる場合()と1になる場合()で考えます。
は実測値、は予測値です。
はの余事象を考えればいいので(も同様)、
\begin{align} p_0 = 1 - p_1\\ q_0 = 1 - q_1 \end{align}
70%(0.7)が起きない確率は30%(1 - 0.7 = 0.3)になりますよね。
なので、2クラス分類の交差エントロピーは
\begin{align} -\sum_{i=0}^n p_i \log q_i &= -(p_0 log q_0 + p_1 log q_1)\\ &= -\left\{(1 - p_1) log (1 - q_1) + p_1 log q_1\right\} \end{align}
tensorflowで書くとこんな感じです。
p1 = label # labelは上で出てきたやつ q1 = tf.sigmoid(logits) # logitsも上で出てきたやつ # 交差エントロピー loss = -( (1 - p1) * tf.log((1 - q1) + 1e-20) + p1 * tf.log(q1 + 1e-20) ) # + 1e-20する事で0がきた時にinfになってしまうのを防ぎます print(loss.eval().mean()) # 交差エントロピーの平均を出しつつ # ==> 0.26393408 # 勘違いしかけたのですが、labelは1になる確率が並んでいるみたいです
そもそも交差エントロピーは確率分布間の距離的なものを計算する式らしいです。
2クラス分類だとこの確率分布がベルヌーイ分布になるので、そこでこいつが出てくるのですね。
ベルヌーイ分布
コインの裏表が出る確率、とかで出てくる確率分布です。
ベルヌーイ分布には確率変数が二つあって(コインの裏, 表など)片方が起きる確率()が定まれば、もう片方が起きる確率()も定まります。
ここで、以下を見てください。
bernoulli = tfd.Bernoulli(logits=logits) p1 = bernoulli.prob(1) # prob(1)は上で言ってるp1(1になる確率)と同じものです print(p1.eval()) # ==> [0. 1. 0.2689414 0.7310586 0.5 ] sig = tf.sigmoid(logits) print(sig.eval()) # ==> [0. 1. 0.26894143 0.7310586 0.5 ] # -- おまけ -- bernoulli = tfd.Bernoulli(probs=0.8) sample = bernoulli.sample(10000) # 定義したベルヌーイ分布から1万個サンプリング ret = sample.eval() print(ret) # 0, 1がいっぱい print(ret.mean(), bernoulli.prob(1).eval()) # 1になってる確率を見てみるとおおよそ0.8くらい
logitsがあった時に、ベルヌーイ分布の何かを通すとsigmoidと同じ計算結果になります。
これなんですが、何ともよくわからず・・・。
調べてみた感じ、ベルヌーイ分布は指数型分布族とかいうやつで、自然パラメータηからベルヌーイ分布を構成しているパラメータ(今回でいうp1)を求められるらしいです。
このをとおいたら、を求める式がsigmoidと同じになるようです。
logitっていうのがそもそもなんですね。
んで、このっていうのがベルヌーイ分布のと一致していて、逆関数がsigmoid。
probはのなのでsigmoidを通した結果と同じになる。
p1 = 0.8 eta = tf.log(p1 / (1 - p1)) print(eta.eval()) # ==> 1.3862944 p1 = tf.sigmoid(eta) print(p1.eval()) # ==> 0.8
自然パラメータって誰だ・・・。
で、ってなんだよ ってなりますよね。
確率分布の確率質量関数を書き換えて
のようになると指数型分布族になります。(上記は分布を決めるパラメータが1個の場合)
このの部分をとしてやるみたいです。
ベルヌーイ分布の確率質量関数は
でした。
xは0か1なので、代入して足してみたら見事1になりますね。
と置いて確率質量関数を式変形してやると
こうなります。
としてやると指数型分布族であることが確認できると思います。
だったので、
という事みたいです。
この辺りを参考にしました。
指数型分布族 | 高校数学の美しい物語
指数型分布族とは?定義と性質をわかりやすく解説 | 全人類がわかる統計学
兎にも角にもこのsigmoidを使って交差エントロピーを計算します。
# p: 正解ラベルのベルヌーイ分布, q: 予測値のベルヌーイ分布 q = tfd.Bernoulli(logits=logits) p1 = label loss = -( (1 - p1) * q.log_prob(0) + # log_probは単純にprobのlogです p1 * q.log_prob(1) ) print(loss.eval().mean()) # sigmoid_cross_entropyと同じ値 # ==> 0.2639341
ここで、pであるlabelの中身について見てみると0か1しか入っていません。
交差エントロピー誤差について考えると、pが0である場所はの結果、pが1である場所はの結果にそれぞれマイナスをつけただけになります。
また、log_probに関してlog_prob([0, 1, 1])とかするとが返ってきます。
# -- おまけ -- bernoulli = tfd.Bernoulli(probs=0.8) p = bernoulli.prob([0, 1, 1]) print(p.eval()) # ==> [0.2 0.8 0.8] log_p = bernoulli.log_prob([0, 1, 1]) print(tf.log(p).eval()) # ==> [-1.609438 -0.22314353 -0.22314353] print(log_p.eval()) # ==> [-1.609438 -0.22314355 -0.22314355]
なので、最終的に交差エントロピーは以下の求めることができます。
q = tfd.Bernoulli(logits=logits) loss = -q.log_prob(label) print(loss.eval().mean()) # sigmoid_cross_entropyと同じ値 # ==> 0.2639341
まとめ
loss = tf.losses.sigmoid_cross_entropy(label, logits)
と
q = tfd.Bernoulli(logits=logits) loss = -q.log_prob(label) loss = tf.reduce_mean(loss)
は一緒みたいです!!
気になって実装を見に行ったらlog_probのなかでsigmoid_cross_entropy_with_logits
が呼ばれていました。
probability/bernoulli.py at master · tensorflow/probability · GitHub
おわり
時々今回のような感じで誤差が求められていて、はて・・・?となっていたものが解消した気がします。
ただ、自然パラメータとか指数型分布族とかちょっと理解が追いつかないです。
今まで雰囲気で機械学習って叫んでいたので、調べれば調べるほど確率が出てきてにっこり笑顔になりました。
高校数学で一番嫌いだった範囲は確率です。