# Visual sanity check: display one training example as an image and print its label.
index = 7  # which training example to inspect

# NOTE(review): assumes each row of train_data holds 784 values (a flattened
# 28x28 image); .reshape raises ValueError otherwise — confirm against the loader.
plt.imshow(train_data[index].reshape(28, 28))

# np.squeeze removes size-1 axes, so a label stored with shape e.g. (1,) prints
# as a bare value rather than a one-element array.
label = np.squeeze(train_labels[index])
print("y = " + str(label))