鴨川η

not δ

ユニグラム言語モデルとベイズ推定

MAP推定もほぼ公式どおりなので,馴染みのないベイズ推定の実装(といっても数えるだけで numpy の関数がやってくれる). ユニグラムと言いつつ例はサイコロであるが,単語に置き換えればよい.

テキストにある例をそのまま使うと,このコードは6面サイコロを1回ふって3の目だけが出た時に相当する.

import sys
from collections import Counter
import numpy as np

fname = sys.argv[1]

beta = 2  # param for dirichlet
docs = []
w_freq = Counter()
w_freq[1] = 0
w_freq[2] = 0
w_freq[3] = 1
w_freq[4] = 0
w_freq[5] = 0
w_freq[6] = 0

total_freq = sum(w_freq.values())


param = {}

#  estimation
for k, v in sorted(w_freq.items()):
    param[k] = v + beta


# 推定結果のディリクレ分布から10回サンプル,ディリクレ分布のパラメータが比較的1に近いので一様分布に近い
print(np.random.dirichlet(([t[1] for t in sorted(param.items())]), 10))


# ベイズ予測 p(v|W)
for k, v in sorted(w_freq.items()):
    print(k, ":", (v + beta) / (total_freq + (beta * len(w_freq))))

実行結果

[[ 0.12417395  0.19154408  0.07719698  0.1738157   0.1531057   0.28016358]
 [ 0.24764191  0.07221084  0.24745174  0.05909202  0.1733986   0.20020489]
 [ 0.23744145  0.11019392  0.27679663  0.14449885  0.1761183   0.05495086]
 [ 0.29045251  0.04017403  0.23073892  0.11682754  0.16461681  0.15719019]
 [ 0.35437087  0.06406654  0.24366861  0.11545934  0.04009513  0.18233951]
 [ 0.49584865  0.07876236  0.0594855   0.12073768  0.12014671  0.12501909]
 [ 0.23841069  0.07891781  0.29979036  0.21096755  0.11178245  0.06013113]
 [ 0.09141996  0.05997774  0.39383342  0.05182101  0.19453655  0.20841133]
 [ 0.09320177  0.17077338  0.29225936  0.13199859  0.14587007  0.16589683]
 [ 0.08586467  0.19022226  0.42977082  0.12957443  0.13003424  0.03453358]]
1 : 0.15384615384615385
2 : 0.15384615384615385
3 : 0.23076923076923078
4 : 0.15384615384615385
5 : 0.15384615384615385
6 : 0.15384615384615385

次はもっと試行が増やしたとき,この例は64回サイコロをふって,3の目が19回,それ以外が9回の場合.

import sys
from collections import Counter
import numpy as np

fname = sys.argv[1]

beta = 2  # param for dirichlet
docs = []
w_freq = Counter()
w_freq[1] = 9
w_freq[2] = 9
w_freq[3] = 19
w_freq[4] = 9
w_freq[5] = 9
w_freq[6] = 9


total_freq = sum(w_freq.values())


param = {}

#  estimation
for k, v in sorted(w_freq.items()):
    param[k] = v + beta


# 推定結果のディリクレ分布から10回サンプル,ディリクレ分布のパラメータが比較的1に近いので一様分布に近い
print(np.random.dirichlet(([t[1] for t in sorted(param.items())]), 10))


# ベイズ予測 p(v|W)
for k, v in sorted(w_freq.items()):
    print(k, ":", (v + beta) / (total_freq + (beta * len(w_freq))))

実行結果

[[ 0.13090258  0.13316758  0.26825103  0.15309491  0.15798866  0.15659523]
 [ 0.10465213  0.15160679  0.34082025  0.09529056  0.16828622  0.13934405]
 [ 0.1257165   0.12082959  0.34599945  0.16500097  0.15000467  0.09244882]
 [ 0.16548145  0.16067654  0.28502537  0.12881519  0.11611117  0.14389027]
 [ 0.13807608  0.13542528  0.32553823  0.11926589  0.09503038  0.18666414]
 [ 0.13570627  0.12121744  0.22131254  0.21681907  0.20421259  0.10073209]
 [ 0.13753213  0.12878648  0.25753269  0.14081603  0.18224729  0.15308539]
 [ 0.0916103   0.07901546  0.30365542  0.15803882  0.16105213  0.20662786]
 [ 0.12105008  0.16091482  0.37429778  0.10832703  0.14001683  0.09539345]
 [ 0.15292372  0.22995453  0.26713616  0.11307034  0.08586252  0.15105274]]
1 : 0.14473684210526316
2 : 0.14473684210526316
3 : 0.27631578947368424
4 : 0.14473684210526316
5 : 0.14473684210526316
6 : 0.14473684210526316

numpyの行列の3列目が3の目がでる確率であるが,どのサンプル(1行)においても他の目よりも高い確率になっている.(1回しか振らなかった時よりも高い) また,予測分布の確率も1回のときより高くなる.