# ktl014
# 5/15/2017 - 12:35 AM
#
# create_LMDB.py How to create lmdb files

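"""Build Caffe LMDB datasets (train / val / test1) from the plankton image tree.

Images found under GENERAL_IMAGES_PATH are labelled for a binary task (one
target class versus every other class), shuffled, and written to LMDB files
as serialized Caffe Datum records.
"""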
from __future__ import print_function
import os
import random
import sys
import lmdb
import caffe
import numpy as np
import glob

__author__ = 'PedroMorgado'

# Root of the image tree scanned by get_fns; it globs 'class*' directories
# directly under this path.
GENERAL_IMAGES_PATH = '/data4/plankton_wi17/plankton/plankton_class/images'
# When True, write_caffe_lmdb skips its "already exists" check, prints
# statistics of the first preprocessed image and exits before storing it.
DEBUG = False


def write_caffe_lmdb(img_fns, labels, lmdb_fn):
    """
    Write images and their labels to an LMDB of serialized Caffe Datum records.

    Every image is loaded with caffe.io.load_image, resized to 256x256,
    rescaled to uint8, reordered from RGB to BGR and transposed to
    channel-first layout. Record i is stored under the key formed by its
    index zero-padded to 8 digits.

    img_fns - sequence of image file paths, indexed by sample number
    labels - numpy array of integer labels; labels.size is the number of
             samples written
    lmdb_fn - path of the LMDB to create

    Raises ValueError if lmdb_fn already exists and DEBUG is False.
    """
    if os.path.exists(lmdb_fn) and not DEBUG:
        raise ValueError(lmdb_fn + ' already exists.')

    nSmpl = labels.size
    # The map size is a byte count and must be an integer; the 1.5 safety
    # factor would otherwise turn it into a float.
    map_size = int(nSmpl*3*256*256*8*1.5)
    env_img = lmdb.open(lmdb_fn, map_size=map_size)
    print('Generating dataset lmdb: '+lmdb_fn)
    try:
        for i in range(nSmpl):
            # Write image datum
            datum = caffe.proto.caffe_pb2.Datum()

            img = caffe.io.load_image(img_fns[i])   # Read image
            img = caffe.io.resize_image(img, np.array([256, 256]))     # Resize to 256
            img = (img*255).astype(np.uint8)        # [0,1]->[0,255]
            img = img[:, :, (2, 1, 0)]              # RGB->BGR
            img = np.transpose(img, (2, 0, 1))      # [X,Y,C]->[C,X,Y]
            if DEBUG:
                # Debug mode only inspects the first preprocessed image.
                print(img.max(), img.min(), img.mean(), img.shape)
                sys.exit(0)

            # Prepare Datum
            datum.channels, datum.height, datum.width = img.shape
            datum.data = img.tobytes()              # tostring() is deprecated
            datum.label = int(labels[i])

            # One write transaction per sample; it commits on leaving the block.
            with env_img.begin(write=True) as txn:
                txn.put('{:08}'.format(i).encode('ascii'), datum.SerializeToString())
            if i % 1000 == 0:
                print('Samples saved:', i, '/', nSmpl)
                sys.stdout.flush()
    finally:
        # Release the environment even if an image fails to load or we exit.
        env_img.close()


def get_fns(key, target_lb, images_path=None):
    """
    Collect shuffled image paths and binary labels for one dataset split.

    The tree is scanned as <images_path>/class<N>/subclass*/<key>/*.png.
    Images of the class whose number equals target_lb get label 0, images of
    every other class get label 1. Summary statistics are printed. For any
    split other than 'train' and 'val' the shuffled "path label" pairs are
    also written to Image_paths_labels_<key>.txt in the working directory.

    key - string denoting train, val, or test
    target_lb - integer denoting the class number of the group (class, order, species, etc) to detect
    images_path - root directory of the image tree; defaults to
                  GENERAL_IMAGES_PATH when None

    Returns (fns, lbs): a list of file paths and a numpy array of labels in
    the same shuffled order.

    Raises ValueError if a class directory name does not end in an integer,
    and OSError if a subclass directory has no <key> subdirectory.
    """
    if images_path is None:
        images_path = GENERAL_IMAGES_PATH
    class_paths_all = glob.glob(os.path.join(images_path, 'class*'))
    fns = []
    lbs = []
    # directory hierarchy:
    # images -> class** -> subclass** -> train/val/test
    # class is family; subclass is specimen
    specimen_count_class0 = 0
    specimen_count_class1 = 0
    target_new_lb = 0
    others_new_lb = 1
    num_class0 = 0
    num_class1 = 0
    for class_path in class_paths_all:
        # get class label: the digits following the 'class' prefix
        class_name = os.path.basename(class_path)
        class_label = int(class_name[len('class'):])
        # get all specimen directories in class
        subclass_paths_all = glob.glob(os.path.join(class_path, 'subclass*'))
        if class_label == target_lb:
            label = target_new_lb
            num_class0 += 1
            specimen_count_class0 += len(subclass_paths_all)
        else:
            label = others_new_lb
            num_class1 += 1
            specimen_count_class1 += len(subclass_paths_all)
        for subclass_path in subclass_paths_all:
            split_dir = os.path.join(subclass_path, key)
            class_fns = [os.path.join(split_dir, fn)
                         for fn in os.listdir(split_dir) if fn.endswith('.png')]
            fns += class_fns
            lbs += [label]*len(class_fns)
    # print stats of files
    print ('# classes:', len(class_paths_all))
    print (key)
    print ('# images:', len(fns))
    print ('# labels:', len(lbs))
    # check if binary separation is successful
    num_class0_fles = lbs.count(target_new_lb)
    num_class1_fles = lbs.count(others_new_lb)
    print ('# target class0', num_class0)
    print ('# other class1', num_class1)
    print ('# specimens class0:', specimen_count_class0)
    print ('# specimens class1:', specimen_count_class1)
    print ('# target class0 files', num_class0_fles)
    print ('# other class1 files', num_class1_fles)
    print ('\n')
    # shuffle; a real list is needed because shuffle works in place and a
    # Python 3 range object is immutable
    index = list(range(len(fns)))
    random.shuffle(index)
    fns = [fns[i] for i in index]
    lbs = np.array([lbs[i] for i in index])
    # write paths and labels to txt files
    if key != 'train' and key != 'val':
        with open('Image_paths_labels_'+key+'.txt', 'w') as path_txt:
            for fn, lb in zip(fns, lbs):
                path_txt.write(fn+' '+str(lb)+'\n')
    return fns, lbs


def main():
    """
    Build the train, val and test1 LMDB files for target class 1.

    File names and labels of every split are gathered first; only then are
    the LMDB files written, in the same split order.
    """
    print(GENERAL_IMAGES_PATH)

    target_lb = 1
    # (split key, output LMDB path); test1 holds the same specimens
    splits = (
        ('train', 'train.LMDB'),
        ('val', 'val.LMDB'),
        ('test1', 'test1.LMDB'),
    )

    # gather file names and labels for all splits before writing anything
    gathered = [(lmdb_fn, get_fns(key, target_lb)) for key, lmdb_fn in splits]

    # write lmdb files
    for lmdb_fn, (split_fns, split_lbs) in gathered:
        write_caffe_lmdb(split_fns, split_lbs, lmdb_fn)


# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    # Author's checklist, kept as a no-op string; presumably the things to
    # adapt before a run (verify with the author).
    '''
    1. GENERAL_PATH
    2. which files to make lmdb for
    3. path hierarchy
    '''
    main()