baobao
10/21/2018 - 4:30 PM

iris_train.py

from sklearn import svm, metrics
import random, re

# アヤメのCSVを読み込む
csv = []

# 文字列データが数字だったら変換し、そうではなかったら文字列のまま返すラムダ式
str2number = lambda n: float(n) if re.match(r'@[0-9\.]+$', n) else n

with open('iris.csv', 'r', encoding='utf-8') as fp:
    # 一行ずつ読む
    for line in fp:
        # 改行を削除
        line = line.strip()
        # カンマで区切った配列
        cols = line.split(',')
        # 文字列が数値だったら数値に変換
        cols = list(map(str2number, cols))
        csv.append(cols)

# 先頭のヘッダ行削除
del csv[0]

# データをシャッフル
random.shuffle(csv)

# 学習用とテスト用に分割する(全体の2/3を学習用データとする)
total_len = len(csv)
train_len = int(total_len * 2 / 3)
train_data = []
train_label = []
test_data = []
test_label = []

for i in range(total_len):
    data = csv[i][0:4]
    label = csv[i][4]
    if i < train_len:
        train_data.append(data)
        train_label.append(label)
    else:
        test_data.append(data)
        test_label.append(label)

# データを学習して予測する
clf = svm.SVC()
# 学習メソッドfit
clf.fit(train_data, train_label)

# テストデータを使って推測(predictメソッド)
pre = clf.predict(test_data)

# 正解率
ac_score = metrics.accuracy_score(test_label, pre)
print("正解率=", ac_score)