#coding: utf-8
from autoencoder import *
def unpickle(f):
    """Load and return the single object pickled in the file at path ``f``.

    The file is opened in binary mode and is closed even if unpickling
    raises.

    NOTE: unpickling can execute arbitrary code -- only call this on
    trusted files.
    """
    with open(f, 'rb') as fo:
        return cPickle.load(fo)
def load_data(dataset):
    """Load a pickled dataset and wrap its ``'data'`` entry in a Theano shared variable.

    ``dataset`` is the path of a pickle whose top-level object is indexed
    with ``'data'``. That entry is divided by 255.0 (presumably to map 8-bit
    pixel values into [0, 1] -- confirm against the data source) and cast to
    ``theano.config.floatX``.

    Returns the Theano shared variable, created with ``borrow=True``.
    Whether it lives on the GPU depends on Theano's configuration.
    """
    batch = unpickle(dataset)
    scaled = np.array(batch['data']) / 255.0
    # borrow=True lets Theano reuse the freshly converted array instead of copying it.
    return theano.shared(np.asarray(scaled, dtype=theano.config.floatX), borrow=True)
if __name__ == "__main__":
    # Wall-clock timer: time.clock() measured CPU time on Unix and no longer
    # exists in Python >= 3.8.
    import timeit

    # Hyperparameters.
    learning_rate = 0.01
    training_epochs = 100
    batch_size = 20
    n_visible = 3072  # must equal 3 * 32 * 32 for the filter reshape below
    n_hidden = 100

    # Load the training data. Only the training split is used; no evaluation is done.
    train_set_x = load_data('../data/cifar10/data_batch_1')
    # borrow=True avoids copying the whole dataset just to read its shape.
    print(train_set_x.get_value(borrow=True).shape)

    # Number of full minibatches; a trailing partial batch (if any) is skipped.
    n_train_batches = train_set_x.get_value(borrow=True).shape[0] // batch_size

    index = T.lscalar()  # symbolic minibatch index
    x = T.matrix('x')    # symbolic minibatch of training rows

    # Build the model.
    rng = np.random.RandomState(123)
    theano_rng = RandomStreams(rng.randint(2 ** 30))
    autoencoder = Autoencoder(numpy_rng=rng,
                              theano_rng=theano_rng,
                              input=x,
                              n_visible=n_visible,
                              n_hidden=n_hidden)

    # Symbolic cost and parameter-update expressions.
    cost, updates = autoencoder.get_cost_updates(learning_rate=learning_rate)

    # Compiled training step: runs one minibatch (selected by index) and
    # applies the updates.
    train_da = theano.function(
        [index],
        cost,
        updates=updates,
        givens={x: train_set_x[index * batch_size:(index + 1) * batch_size]},
    )

    # Train the model, logging each epoch's mean minibatch cost to cost.txt.
    start_time = timeit.default_timer()
    with open("cost.txt", "w") as fp:
        for epoch in range(training_epochs):
            costs = []
            for batch_index in range(n_train_batches):
                costs.append(train_da(batch_index))
            mean_cost = np.mean(costs)
            print("Training epoch %d, cost %f" % (epoch, mean_cost))
            fp.write("%d\t%f\n" % (epoch, mean_cost))
            fp.flush()  # keep the log current while training is still running
    training_time = timeit.default_timer() - start_time
    print("time: %.2fm" % (training_time / 60.0))

    # Save the trained model state.
    with open('autoencoder_cifar10.pkl', 'wb') as f:
        cPickle.dump(autoencoder.__getstate__(), f, protocol=cPickle.HIGHEST_PROTOCOL)

    # Visualize the learned weights: one image per hidden unit.
    import matplotlib.pyplot as plt
    W = autoencoder.W.get_value().T
    print(W.shape)
    # Rescale all weights jointly to 0..255 so imshow treats them as 8-bit RGB.
    w_min, w_max = W.min(), W.max()
    w_range = w_max - w_min
    if w_range == 0:
        w_range = 1.0  # constant weights: avoid dividing by zero
    W = (W - w_min) / w_range
    W *= 255.0
    W = W.astype(np.uint8)

    grid = int(np.ceil(np.sqrt(n_hidden)))  # 10x10 for 100 hidden units
    for i in range(n_hidden):
        plt.subplot(grid, grid, i + 1)
        # Each row is reshaped channel-first (3, 32, 32); imshow wants (32, 32, 3).
        plt.imshow(W[i].reshape(3, 32, 32).transpose(1, 2, 0), interpolation='nearest')
        plt.axis('off')
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.savefig("autoencoder_filters_cifar10.png")