python(scikit-learn)で決定木を試したソースコード
x,y,class
-2.121580967,-0.365665506,0
-1.797266776,-1.618523073,0
-0.717394571,-0.738177485,0
-0.830662087,-1.058791442,0
-1.145322845,-0.750618964,0
-1.193923462,-0.606689289,0
-1.510185511,-0.071761198,0
-1.204388261,-0.894366562,0
-1.089685904,-0.957121487,0
0.120336282,-0.822792142,0
-0.613491473,-0.574858212,0
-0.42222762,-1.156157105,0
-1.46093916,-1.569185406,0
-0.979112372,-1.907546145,0
-0.817423506,-1.262125852,0
-0.697717895,-1.693726266,0
-0.90938065,-1.014607998,0
-0.684790594,-0.715484122,0
-1.001101395,-0.677211991,0
-0.773787445,-0.781490351,0
0.679649577,2.009710795,0
1.087990067,1.128682442,0
-0.471220285,0.658977628,0
0.75666077,1.253800712,0
1.107170215,0.942305548,0
0.547427043,1.25471797,0
0.172337119,1.321508052,0
0.75441344,0.577555831,0
1.193582186,1.536259738,0
2.010374491,0.793770601,0
0.712591176,1.993967333,0
1.575928337,0.465069191,0
0.312669319,2.340031348,0
0.474045066,0.882761763,0
1.627177872,0.911410143,0
0.581137661,1.335835314,0
0.513237718,1.339515202,0
2.08397179,0.525077456,0
1.462142301,0.515768791,0
-0.808576845,1.678407316,1
-1.113784281,0.673034693,1
-1.755298852,1.081698325,1
-0.901312814,0.374982319,1
0.468540221,1.0636286,1
-0.657399227,1.340507018,1
-1.444319158,0.84839978,1
0.165544458,1.204151487,1
-1.886536914,0.865973372,1
-1.738291657,1.325458792,1
-1.187196553,0.958926432,1
-1.481173228,0.748988483,1
-1.230384608,1.538903521,1
-1.338228883,0.406475801,1
-1.501449356,0.691336663,1
-0.271862827,0.775127338,1
-0.5807241,1.011103064,1
-1.075546799,1.192832784,1
-1.290304416,1.306105006,1
-1.047113498,0.573244887,1
-1.829345211,1.526520998,1
1.655654668,-0.919084924,1
0.354419891,-1.514486921,1
1.294325506,-1.025373154,1
1.059651738,-0.329919638,1
0.776688583,-0.981650811,1
0.852488155,-1.050404224,1
0.709031861,-0.754367705,1
0.286452122,-0.675478837,1
1.375405012,-2.322473849,1
0.802597412,-1.04669331,1
1.086726869,-0.457528884,1
1.176494782,-1.678981604,1
1.919258297,-0.608970286,1
0.416852043,-1.265871275,1
0.849027825,-0.682913596,1
0.838646307,-0.948194085,1
0.61533752,-0.138612196,1
1.938780666,-1.05737308,1
1.226448812,-0.534711982,1
1.059234006,-0.374870802,1
# -*- coding: utf-8 -*-
import pandas as pd
import numpy as np
from sklearn import tree
# Load the training data: two feature columns (x, y) and a class label column.
df = pd.read_csv('xor_simple.csv')
data_array = df[['x', 'y']].values
class_array = df['class'].values

# Fit a decision tree classifier on the training data.
clf = tree.DecisionTreeClassifier()
clf = clf.fit(data_array, class_array)

# Classify two new points with the fitted tree.
# Judging from the layout of the training data:
#   x=2.0, y=1.0  is expected to be assigned to class 0.
#   x=1.0, y=-0.5 is expected to be assigned to class 1.
result = clf.predict([[2., 1.], [1., -0.5]])
# Use the call form of print so the script parses under Python 3; with a
# single argument the output is identical under Python 2.
print(result)
### Visualize the decision boundary
import matplotlib.pyplot as plt

# Plot parameters
n_classes = 2
plot_colors = "br"  # one matplotlib color code per class, in class order
plot_step = 0.05    # spacing of the evaluation mesh along both axes

# Build a mesh that covers the range of both features, padded by 1 on
# every side so the outermost training points are not on the plot edge.
x_min, x_max = data_array[:, 0].min() - 1, data_array[:, 0].max() + 1
y_min, y_max = data_array[:, 1].min() - 1, data_array[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
                     np.arange(y_min, y_max, plot_step))

# Predict the class at every mesh point, then restore the 2-D grid shape.
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

# Draw the predicted class over the mesh as contour lines.
# NOTE(review): the original comment described a filled contour plot, which
# would be plt.contourf; the existing plt.contour call is kept -- confirm
# which rendering is intended.
cs = plt.contour(xx, yy, Z, cmap=plt.cm.Paired)
plt.xlabel('x')
plt.ylabel('y')

# Overlay the training points, one scatter series per class.
for i, color in zip(range(n_classes), plot_colors):
    # A boolean mask selects the rows of this class and yields 1-D arrays.
    mask = class_array == i
    # The color is given literally via `c`, so no colormap is passed
    # (matplotlib ignores `cmap` in that case). Each series gets its own
    # string label so the legend can tell the classes apart.
    plt.scatter(data_array[mask, 0], data_array[mask, 1], c=color,
                label='class %d' % i)
plt.legend()
plt.axis("tight")
plt.show()
これは、下記のサイトを参考に、Python(scikit-learn)で決定木を試したソースコードです。
http://data-hacker.blogspot.jp/2014/05/pythonscikit-learn.html