aidiary
10/28/2015 - 11:50 AM

CNNのフィルタを可視化

CNNのフィルタを可視化

%matplotlib inline
import numpy as np
import cPickle
import matplotlib.pyplot as plt
# Load the trained model and draw each first-layer convolution filter as a
# small image in a grid.
#
# NOTE(review): unpickling can execute arbitrary code; only load "cifar10.pkl"
# when it comes from a trusted source.
with open("cifar10.pkl", "rb") as model_file:
    # The context manager closes the file handle once the model is loaded.
    model = cPickle.load(model_file)

# NOTE(review): assumes model.conv1.W is a NumPy-like array whose axes are
# (n_filters, in_channels, height, width), and that in_channels is a channel
# count imshow accepts (e.g. 3 for RGB) -- confirm against the model.
weights = model.conv1.W
n1, n2, h, w = weights.shape
print("%s %s %s %s" % (n1, n2, h, w))

# Size the grid from the actual number of filters (8 per row) rather than
# assuming exactly 32; with 32 filters this is the same 4x8 layout.
n_cols = 8
n_rows = max(1, (n1 + n_cols - 1) // n_cols)

fig = plt.figure()
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
for pos in range(n1):
    ax = fig.add_subplot(n_rows, n_cols, pos + 1, xticks=[], yticks=[])
    # Move the channel axis last so the filter is laid out as (h, w, channels).
    img = weights[pos].transpose(1, 2, 0)
    # Indexing and transposing yield views of the weight array, so normalise
    # with out-of-place arithmetic; in-place `-=` / `/=` would overwrite the
    # model's weights.
    img = img - img.min()
    peak = img.max()
    # A constant filter has peak == 0 after the shift; skip the division to
    # avoid filling the image with NaN.
    if peak > 0:
        img = img / peak
    print(img)
    ax.imshow(img)