aidiary
10/11/2014 - 3:38 AM

『Rによるモンテカルロ法入門の例3.5

『Rによるモンテカルロ法入門の例3.5

import numpy as np
import matplotlib.pyplot as plt
import scipy.integrate
from scipy.stats import norm, expon

# Rによるモンテカルロ法入門の例3.5

a = 4.5

# サンプリング数
N = 500

# 被積分関数
f = norm.pdf
h = lambda x: x > a
y = lambda x: h(x) * f(x)

# scipy.integrateでの積分
I = scipy.integrate.quad(y, -np.inf, np.inf)[0]
print "scipy.integrate:", I

# 通常のモンテカルロ積分の場合
x = norm.rvs(size=N)
I = np.mean(h(x))
print "normal monte carlo integration:", I

# (1) 重点関数として平均が4.5の正規分布を使用した重点サンプリング

g1 = norm(loc=a, scale=1).pdf
x1 = norm(loc=a, scale=1).rvs(size=N)
I = np.mean(f(x1) / g1(x1) * h(x1))
print "importance sampling (norm):", I

# (2) 重点関数として4.5で切り詰められた指数分布を使用した重点サンプリング

def g2(xlist, a):
    """aで切り詰められた指数分布
    TODO: Vectorizationを使った効率的な実装は?"""
    result = []
    for x in xlist:
        if x < a: result.append(0)
        else: result.append(np.exp(-(x - a)))
    return result

# g2は独自定義の関数であるため直接サンプリングできない
# g2はグラフ描画用として使う
# サンプリングはloc=0の指数分布のサンプルにaを加える(右にずらす)ことで代替する
x2 = expon(loc=0, scale=1).rvs(size=N) + a
I = np.mean(f(x2) / g2(x2, a) * h(x2))
print "importance sampling (truncated expon):", I

# グラフ描画
ix = np.arange(-10, 10, 0.01)
plt.plot(ix, f(ix), label="f(x)")
plt.plot(ix, g1(ix), label="g1(x)")
plt.plot(ix, g2(ix, a), label="g2(x)")
plt.plot(ix, expon(loc=a).pdf(ix), label="expon")  # g2と同じ分布になる
plt.legend(loc="best")
plt.show()

# 収束性の評価

plt.subplot(211)
x1 = f(x1) / g1(x1) * h(x1)
estint = np.cumsum(x1) / np.arange(1, N + 1)
esterr = np.sqrt(np.cumsum((x1 - estint) ** 2)) / np.arange(1, N + 1)
plt.plot(estint, color='red', linewidth=2)
plt.plot(estint + 2 * esterr, color='gray')
plt.plot(estint - 2 * esterr, color='gray')
plt.title("convergence (g = norm)")
plt.ylim([0, 8e-6])

plt.subplot(212)
x2 = f(x2) / g2(x2, a) * h(x2)
estint = np.cumsum(x2) / np.arange(1, N + 1)
esterr = np.sqrt(np.cumsum((x2 - estint) ** 2)) / np.arange(1, N + 1)
plt.plot(estint, color='red', linewidth=2)
plt.plot(estint + 2 * esterr, color='gray')
plt.plot(estint - 2 * esterr, color='gray')
plt.title("convergence (g = truncated expon)")
plt.ylim([0, 8e-6])

plt.tight_layout()
plt.show()