训练 Caffe 图像分类模型时, 出于效率考虑, 通常不直接从图片列表读取图片, 而是先把训练数据转换为 LMDB 或 HDF5 格式.

LMDB格式的优点:

  • 基于内存映射 I/O (memory-mapped I/O), 数据读取速率更高
  • 对大规模数据集更有效.

HDF5的特点:

  • 易于读取
  • 类似于mat数据,但数据压缩性能更强
  • Caffe 的 HDF5Data 层会把整个 HDF5 文件读进内存, 故单个 HDF5 文件的大小不能超过内存; 可以把数据拆分成多个 HDF5 文件, 再将各子文件的路径逐行写入一个 txt 文件 (作为 HDF5Data 层的 source).
  • I/O速率不如LMDB.

1. LMDB创建

import lmdb
import numpy as np
import caffe

lmdb_file = '/path/to/data_lmdb'
N = 1000
# Prepare data and labels (placeholder zeros; replace with real samples).
X = np.zeros((N, 3, 224, 224), dtype=np.uint8)  # data, (N, C, H, W)
y = np.zeros(N, dtype=np.int64)  # labels, one per sample

# map_size is the upper bound the database is allowed to grow to.
env = lmdb.open(lmdb_file, map_size=int(1e12))

# A write transaction must be committed or nothing is persisted; used as a
# context manager it commits on normal exit and aborts on an exception.
with env.begin(write=True) as txn:
    for i in range(N):
        datum = caffe.proto.caffe_pb2.Datum()
        datum.channels = X.shape[1]
        datum.height = X.shape[2]
        datum.width = X.shape[3]
        datum.data = X[i].tobytes()  # raw uint8 pixels in (C, H, W) order
        datum.label = int(y[i])
        # Equivalent shortcut for the five lines above:
        #   datum = caffe.io.array_to_datum(X[i], int(y[i]))

        # Zero-padded keys make LMDB's lexicographic key order match the
        # sample order.
        str_id = '{:08}'.format(i)
        # LMDB keys and values must be bytes on Python 3.
        txn.put(str_id.encode('ascii'), datum.SerializeToString())

env.close()

2. LMDB读取

import numpy as np
import lmdb
import caffe

# This script only reads, so open the environment read-only.
env = lmdb.open('data_lmdb', readonly=True)
datum = caffe.proto.caffe_pb2.Datum()

with env.begin() as txn:
    for key, value in txn.cursor():
        # value is a serialized Datum; print its size rather than the raw bytes.
        print('{}, {} bytes'.format(key, len(value)))
        datum.ParseFromString(value)

        # Assumes uint8 pixels stored in datum.data (as written by a
        # creation script like the one above); float datums use float_data.
        # frombuffer returns a read-only view; call .copy() if you need to modify it.
        flat_data = np.frombuffer(datum.data, dtype=np.uint8)
        data = flat_data.reshape(datum.channels, datum.height, datum.width)
        # Or: data = caffe.io.datum_to_array(datum)
        label = datum.label

env.close()

3. HDF5创建和读取

方法一:

import h5py
import numpy as np

# Create the HDF5 file; `with` closes it even if a write fails.
imgsData = np.zeros((10, 3, 224, 224))  # images, (N, C, H, W)
labels = np.arange(10)  # one integer label per image
with h5py.File('HDF5_FILE.h5', 'w') as f:  # 'w' truncates any existing file
    f['datas'] = imgsData  # write the image dataset
    f['labels'] = labels  # write the label dataset

# Read the HDF5 file back.
with h5py.File('HDF5_FILE.h5', 'r') as f:
    # Materialize the key names: a keys view is tied to the open file.
    f_keys = list(f.keys())
    # [:] copies the whole dataset into an in-memory NumPy array.
    imgsData = f['datas'][:]
    labels = f['labels'][:]

方法二:

import h5py
import numpy as np

datas = np.random.rand(100, 1000, 1000).astype('float32')
# NOTE(review): Caffe's HDF5Data layer expects all datasets in a file to share
# the same first (sample) dimension; 100 data rows vs. 1 label row will not
# line up there -- confirm the intended label shape.
labels = np.random.rand(1, 1000, 1000).astype('float32')

# Create a new file
with h5py.File('data.h5', 'w') as f:
    f.create_dataset('datas', data=datas)
    f.create_dataset('labels', data=labels)

# Load the hdf5 datasets. f['name'] is only a lazy handle that becomes unusable
# once the file is closed, so read the contents into memory with [:] first.
with h5py.File('data.h5', 'r') as f:
    X = f['datas'][:]
    Y = f['labels'][:]

4. 从图片文件创建 LMDB 数据集 (直接存储编码后的图片字节和文本标签)

"""
a modified version of CRNN torch repository 
https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py 
"""

import fire
import os
import lmdb
import cv2
import numpy as np

def checkImageIsValid(imageBin):
    """Return True if ``imageBin`` decodes to an image with non-zero area.

    Args:
        imageBin: encoded image file bytes (e.g. PNG/JPEG), or None.

    Returns:
        False for None or empty input, for bytes OpenCV cannot decode, or for
        a decoded image with zero height or width; True otherwise.
    """
    # Guard empty input up front: OpenCV may raise on an empty buffer
    # rather than return None.
    if not imageBin:
        return False
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    # Decoding as grayscale is enough to validate the file.
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    # imdecode returns None (instead of raising) for undecodable bytes.
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    return imgH * imgW > 0

def writeCache(env, cache):
    """Store every key/value pair of ``cache`` in ``env``.

    All pairs are written inside one write transaction, opened as a context
    manager so it is finished when the block exits.
    """
    with env.begin(write=True) as transaction:
        for entryKey in cache:
            transaction.put(entryKey, cache[entryKey])

def createDataset(inputPath, gtFile, outputPath, checkValid=True):
    """
    Create LMDB dataset for training and evaluation.

    Each non-empty line of ``gtFile`` holds an image path (relative to
    ``inputPath``) and its label, separated by a tab. For every usable
    sample, the raw image file bytes are stored under ``image-%09d`` and the
    encoded label under ``label-%09d`` (numbered from 1 without gaps); the
    final sample count is stored under ``num-samples``.

    ARGS:
        inputPath  : input folder path where starts imagePath
        outputPath : LMDB output path
        gtFile     : list of image path and label
        checkValid : if true, check the validity of every image

    Missing, invalid or malformed entries are reported and skipped; decoder
    errors are additionally appended to ``error_image_log.txt`` inside
    ``outputPath``.
    """
    os.makedirs(outputPath, exist_ok=True)
    # map_size is the maximum size the database may grow to (1 TiB).
    env = lmdb.open(outputPath, map_size=1099511627776)
    cache = {}
    cnt = 1  # 1-based index of the next sample to be written

    with open(gtFile, 'r', encoding='utf-8') as data:
        datalist = data.readlines()

    nSamples = len(datalist)
    for i, line in enumerate(datalist):
        # Strip '\r' too, so gt files with Windows line endings work.
        line = line.rstrip('\r\n')
        if not line:
            continue  # tolerate blank lines (e.g. a trailing empty line)
        # Split on the first tab only, so a label may itself contain tabs.
        fields = line.split('\t', 1)
        if len(fields) != 2:
            print('line %d is malformed, expected "<imagePath><TAB><label>": %r' % (i, line))
            continue
        imagePath, label = fields
        imagePath = os.path.join(inputPath, imagePath)

        # # only use alphanumeric data (requires `import re`)
        # if re.search('[^a-zA-Z0-9]', label):
        #     continue

        if not os.path.exists(imagePath):
            print('%s does not exist' % imagePath)
            continue
        with open(imagePath, 'rb') as f:
            imageBin = f.read()
        if checkValid:
            try:
                isValid = checkImageIsValid(imageBin)
            except Exception:
                # Unexpected decoder failure: log it and skip this sample
                # instead of aborting the whole conversion.
                print('error occured', i)
                with open(os.path.join(outputPath, 'error_image_log.txt'), 'a') as log:
                    log.write('%s-th image data occured error\n' % str(i))
                continue
            if not isValid:
                print('%s is not a valid image' % imagePath)
                continue

        imageKey = 'image-%09d'.encode() % cnt
        labelKey = 'label-%09d'.encode() % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label.encode()

        # Flush every 1000 samples so `cache` does not grow without bound.
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))
        cnt += 1

    nSamples = cnt - 1
    cache['num-samples'.encode()] = str(nSamples).encode()
    writeCache(env, cache)  # remaining samples plus the sample count
    env.close()
    print('Created dataset with %d samples' % nSamples)


if __name__ == '__main__':
    # Command-line entry point, e.g.:
    #   python create_dataset.py INPUT_PATH GT_FILE OUTPUT_PATH
    fire.Fire(createDataset)

5. 读取上述 LMDB 数据集并逐张显示图片

import lmdb
import numpy as np
import cv2

lmdb_file = "/path/to/lmdb"
# Viewing only needs read access.
lmdb_env = lmdb.open(lmdb_file, readonly=True)

with lmdb_env.begin() as lmdb_txn:
    for key, value in lmdb_txn.cursor():
        # IMREAD_COLOR | IMREAD_ANYDEPTH (== 3): 3-channel, keep the bit depth.
        # Guard empty values: OpenCV may raise on an empty buffer.
        img = (cv2.imdecode(np.frombuffer(value, np.uint8),
                            cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH)
               if value else None)
        if img is None:
            # Not an encoded image, e.g. the 'label-*' / 'num-samples'
            # entries written by createDataset -- nothing to display.
            continue
        label = None
        if key.startswith(b'image-'):
            # The matching label shares the numeric suffix of the image key.
            label = lmdb_txn.get(b'label-' + key[len(b'image-'):])
        print('[INFO]', key, label)
        cv2.imshow("demo", img)
        cv2.waitKey(0)

cv2.destroyAllWindows()
lmdb_env.close()
Last modification: June 3rd, 2021 at 09:41 am