When training image-classification models, Caffe does not read images directly from an image list, for efficiency reasons; the training data are usually packed into LMDB or HDF5 format.
Advantages of LMDB:
- Based on memory-mapped file I/O, so data throughput is higher.
- More efficient for large-scale datasets.
 
Characteristics of HDF5:
- Easy to read.
- Similar to .mat data, but with better compression.
- The file is read entirely into memory, so a single HDF5 file cannot exceed the available memory; a large dataset can be split into several HDF5 files, whose paths are then listed in a txt file (see the sketch after this list).
- I/O throughput is lower than LMDB's.
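
A minimal sketch of the splitting idea mentioned above (the file names train_0.h5 / train_h5_list.txt, the chunk size, and the dataset names 'data' and 'label' are illustrative choices, not prescribed by this post):

import h5py
import numpy as np

# toy data: 100 samples, split into chunks of 25 samples per HDF5 file
data = np.random.rand(100, 3, 224, 224).astype(np.float32)
labels = np.arange(100).astype(np.float32)

chunk = 25
with open('train_h5_list.txt', 'w') as list_file:
    for i in range(0, len(data), chunk):
        h5_path = 'train_%d.h5' % (i // chunk)
        with h5py.File(h5_path, 'w') as f:
            f.create_dataset('data', data=data[i:i + chunk])
            f.create_dataset('label', data=labels[i:i + chunk])
        # one HDF5 path per line; Caffe's HDF5 data layer takes this
        # list file as its source
        list_file.write(h5_path + '\n')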
 
1. Creating an LMDB
import numpy as np
import lmdb
import caffe
lmdb_file = '/path/to/data_lmdb'
N = 1000 
# prepare data and labels
X = np.zeros((N, 3, 224, 224), dtype=np.uint8) # data
y = np.zeros(N, dtype=np.int64) # labels
env = lmdb.open(lmdb_file, map_size=int(1e12))
txn = env.begin(write=True)
for i in range(N):
    datum = caffe.proto.caffe_pb2.Datum()
    datum.channels = X.shape[1]
    datum.height = X.shape[2]
    datum.width = X.shape[3]
    datum.data = X[i].tobytes()  # or .tostring() if numpy < 1.9
    datum.label = int(y[i])
    # the five lines above can also be done in one call: datum = caffe.io.array_to_datum(X[i], int(y[i]))
    str_id = '{:08}'.format(i)
    # LMDB keys must be bytes in Python 3 (a plain str key also works in Python 2)
    txn.put(str_id.encode('ascii'), datum.SerializeToString())

txn.commit()  # the transaction must be committed, otherwise nothing is written
env.close()

2. Reading an LMDB
import numpy as np
import lmdb
import caffe
env = lmdb.open('data_lmdb', readonly=True)
txn = env.begin()
lmdb_cursor = txn.cursor()
datum = caffe.proto.caffe_pb2.Datum()
for key, value in lmdb_cursor:
    print('{}, {}'.format(key, value))
    datum.ParseFromString(value)
    flat_data = np.frombuffer(datum.data, dtype=np.uint8)  # np.fromstring is deprecated
    data = flat_data.reshape(datum.channels, datum.height, datum.width)
    # or: data = caffe.io.datum_to_array(datum)
    label = datum.label

3. Creating and Reading HDF5
Method 1:
import h5py 
import numpy as np  
# create the HDF5 file
imgsData = np.zeros((10, 3, 224, 224))  # images
labels = np.arange(10)                  # labels
f = h5py.File('HDF5_FILE.h5', 'w')      # create an h5 file
f['datas'] = imgsData                   # write the image data
f['labels'] = labels                    # write the label data
f.close()
# read the HDF5 file
f = h5py.File('HDF5_FILE.h5', 'r')      # open the h5 file
f_keys = f.keys()                       # dataset names: ['datas', 'labels']
imgsData = f['datas'][:]
labels = f['labels'][:]
f.close()

Method 2:
import h5py
import numpy as np

datas = np.random.rand(100, 1000, 1000).astype('float32')
labels = np.random.rand(1, 1000, 1000).astype('float32')
 
# Create a new file
f = h5py.File('data.h5', 'w')
f.create_dataset('datas', data=datas)
f.create_dataset('labels', data=labels)
f.close()
 
# Load hdf5 dataset
f = h5py.File('data.h5', 'r')
X = f['datas'][:]   # slice into memory before closing the file
Y = f['labels'][:]
f.close()

4. Creating an LMDB Dataset
"""
a modified version of CRNN torch repository 
https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py 
"""
import fire
import os
import lmdb
import cv2
import numpy as np
def checkImageIsValid(imageBin):
    if imageBin is None:
        return False
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    if img is None:  # decoding failed
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    if imgH * imgW == 0:
        return False
    return True
def writeCache(env, cache):
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            txn.put(k, v)
def createDataset(inputPath, gtFile, outputPath, checkValid=True):
    """
    Create LMDB dataset for training and evaluation.
    ARGS:
        inputPath  : input folder path where starts imagePath
        outputPath : LMDB output path
        gtFile     : list of image path and label
        checkValid : if true, check the validity of every image
    """
    os.makedirs(outputPath, exist_ok=True)
    env = lmdb.open(outputPath, map_size=1099511627776)
    cache = {}
    cnt = 1
    with open(gtFile, 'r', encoding='utf-8') as data:
        datalist = data.readlines()
    nSamples = len(datalist)
    for i in range(nSamples):
        imagePath, label = datalist[i].strip('\n').split('\t')
        imagePath = os.path.join(inputPath, imagePath)
        # # only use alphanumeric data
        # if re.search('[^a-zA-Z0-9]', label):
        #     continue
        if not os.path.exists(imagePath):
            print('%s does not exist' % imagePath)
            continue
        with open(imagePath, 'rb') as f:
            imageBin = f.read()
        if checkValid:
            try:
                if not checkImageIsValid(imageBin):
                    print('%s is not a valid image' % imagePath)
                    continue
            except:
                print('error occurred', i)
                with open(outputPath + '/error_image_log.txt', 'a') as log:
                    log.write('error occurred with %s-th image data\n' % str(i))
                continue
        imageKey = 'image-%09d'.encode() % cnt
        labelKey = 'label-%09d'.encode() % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label.encode()
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))
        cnt += 1
    nSamples = cnt-1
    cache['num-samples'.encode()] = str(nSamples).encode()
    writeCache(env, cache)
    print('Created dataset with %d samples' % nSamples)
if __name__ == '__main__':
    fire.Fire(createDataset)
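
A usage sketch for the script above (the file names gt.txt and create_dataset.py and the image paths are hypothetical): the ground-truth file read by createDataset contains one sample per line, with the image path relative to inputPath and its label separated by a tab.

# build a tiny tab-separated ground-truth file (the image files must
# actually exist under inputPath for the script above to pick them up)
with open('gt.txt', 'w', encoding='utf-8') as f:
    f.write('images/0001.jpg\thello\n')
    f.write('images/0002.jpg\tworld\n')

# assuming the script above is saved as create_dataset.py:
from create_dataset import createDataset
createDataset(inputPath='./', gtFile='gt.txt', outputPath='./data_lmdb')

# or, via python-fire, from the command line:
#   python create_dataset.py ./ gt.txt ./data_lmdb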
5. Reading an LMDB Dataset
import lmdb
import numpy as np
import cv2
lmdb_file = "/path/to/lmdb"
lmdb_env = lmdb.open(lmdb_file)
lmdb_txn = lmdb_env.begin()
lmdb_cursor = lmdb_txn.cursor()
for key, value in lmdb_cursor:
    print('[INFO]', key)
    img = cv2.imdecode(np.frombuffer(value, np.uint8), cv2.IMREAD_COLOR)
    cv2.imshow("demo", img)
    cv2.waitKey(0)
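
The loop above assumes every value in the LMDB is a raw encoded image. For the dataset written in section 4, images and labels live under paired keys ('image-%09d' and 'label-%09d', plus a 'num-samples' counter), so reading it back looks more like the following sketch (the LMDB path is illustrative):

import lmdb
import cv2
import numpy as np

env = lmdb.open('/path/to/data_lmdb', readonly=True, lock=False)
with env.begin() as txn:
    n_samples = int(txn.get('num-samples'.encode()))
    for idx in range(1, n_samples + 1):  # keys are 1-indexed in createDataset
        image_bin = txn.get('image-%09d'.encode() % idx)
        label = txn.get('label-%09d'.encode() % idx).decode('utf-8')
        # decode the raw JPEG/PNG bytes stored by createDataset
        img = cv2.imdecode(np.frombuffer(image_bin, np.uint8), cv2.IMREAD_GRAYSCALE)
        print(idx, label, img.shape)
env.close()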