def normalize(batch_images, axis=(1, 2)):
    """Normalize each image in a batch to zero mean and unit standard deviation.

    Every sample is standardized independently: its own mean is subtracted
    and the result is divided by its own standard deviation.

    Args:
        batch_images: Array of shape ``[batch, y, x]``. Presumably a NumPy
            array (anything exposing ``mean``/``std`` with ``axis`` and
            ``keepdims`` should work). The input is not modified.
        axis: Axes reduced per sample. Defaults to ``(1, 2)``, i.e. the
            spatial axes of a ``[batch, y, x]`` array.

    Returns:
        A new array of the same shape. Samples whose standard deviation is
        exactly zero (constant images) are returned centered (zeros, up to
        floating-point rounding of the mean) instead of NaN.
    """
    # The subtraction allocates a fresh array. The in-place division below
    # therefore never touches the caller's data, and integer inputs end up
    # as float.
    centered = batch_images - batch_images.mean(axis=axis, keepdims=True)
    std = centered.std(axis=axis, keepdims=True)
    # A constant sample has std == 0, and dividing would produce NaN (0/0).
    # Dividing by 1 instead leaves its already-centered values unchanged.
    std[std == 0] = 1
    centered /= std
    return centered