TensorFlow - Data Reading [repost]: TensorFlow provides several methods for reading image data. This post builds a TFRecord version of the Flowers dataset and fine-tunes an InceptionV1 model on it.

The main data reading methods in TensorFlow are:

  • Reading directly from disk - data is passed to the train_op via feed_dict when the Session runs (see the sketch after this list). This does not scale well to large datasets, because it requires enough GPU memory to hold the training data.
  • Reading from CSV files - not suitable for images.
  • Reading from TFRecord files - images are converted into TFRecord files, a format TensorFlow can read natively, so the original image files do not have to be re-read during training, which gives higher read efficiency. This post focuses on TFRecord for handling large datasets.
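
As a concrete example of the first method, here is a minimal feed_dict sketch (TF 1.x; the image path is a hypothetical placeholder): the raw JPEG bytes are read in Python and fed into a decode op when the Session runs.

import tensorflow as tf

# Placeholder for raw JPEG bytes, decoded inside the graph.
jpeg_ph = tf.placeholder(dtype=tf.string)
image = tf.image.decode_jpeg(jpeg_ph, channels=3)

with tf.Session() as sess:
    # Hypothetical path; in practice, loop over the training file list and feed each batch this way.
    image_data = tf.gfile.FastGFile('/path/to/some_flower.jpg', 'rb').read()
    img = sess.run(image, feed_dict={jpeg_ph: image_data})
    print(img.shape)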

Although creating TFRecord files is less direct than reading data from an HDF5 file (the approach Keras uses), TFRecords work well with TensorFlow's data pipeline tools for image training, such as queue runners, coordinators and supervisors, which help manage the flow of training data.
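
A minimal sketch of such a queue-based pipeline (TF 1.x), assuming one of the TFRecord shards created in section 1; the feature keys match the ones written by the conversion script below:

import tensorflow as tf

# Hypothetical shard name produced in section 1.
filename_queue = tf.train.string_input_producer(['flowers_train_00000-of-00005.tfrecord'])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example, features={
    'image/encoded': tf.FixedLenFeature([], tf.string),
    'image/class/label': tf.FixedLenFeature([], tf.int64),
})
image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
label = features['image/class/label']

with tf.Session() as sess:
    # string_input_producer registers a queue runner; the coordinator manages its threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    img, lbl = sess.run([image, label])  # one decoded image and its label
    coord.request_stop()
    coord.join(threads)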

TensorFlow provides TF-Slim wrappers for creating and reading TFRecord files - slim/datasets.

<h2>1. Creating the TFRecord Flowers Dataset</h2>

Flowers dataset download - Flowers Dataset

The directory structure after extraction (a quick count check follows the tree):

flower_photos/
|----daisy
| -------- *.jpg (633 images)
|----dandelion
| -------- *.jpg (898 images)
|----roses
| -------- *.jpg (641 images)
|----sunflowers
| -------- *.jpg (699 images)
|----tulips
| -------- *.jpg (799 images)
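
A quick sanity check of the per-class counts above (a sketch; it assumes the archive was extracted into flower_photos/ under the current directory):

import os

root = 'flower_photos'
for name in sorted(os.listdir(root)):
    path = os.path.join(root, name)
    if os.path.isdir(path):
        n_jpg = len([f for f in os.listdir(path) if f.lower().endswith('.jpg')])
        print('%s: %d images' % (name, n_jpg))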

TensorFlow provides scripts (in slim/datasets) for converting the Flowers dataset to TFRecord. The first listing below is dataset_utils.py, the shared dataset helpers:

"""
用于数据集下载和转换.
"""
from future import absolute_import
from future import division
from future import print_function

import os
import sys
import tarfile

from six.moves import urllib
import tensorflow as tf

LABELS_FILENAME = 'labels.txt'

def int64_feature(values):
  """Returns a TF-Feature of int64s.
  Args:
    values: A scalar or list of values.
  Returns:
    A TF-Feature.
  """
  if not isinstance(values, (tuple, list)):
    values = [values]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def bytes_feature(values):
  """Returns a TF-Feature of bytes.
  Args:
    values: A string.
  Returns:
    A TF-Feature.
  """
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


def float_feature(values):
  """Returns a TF-Feature of floats.
  Args:
    values: A scalar or list of values.
  Returns:
    A TF-Feature.
  """
  if not isinstance(values, (tuple, list)):
    values = [values]
  return tf.train.Feature(float_list=tf.train.FloatList(value=values))


def image_to_tfexample(image_data, image_format, height, width, class_id):
  return tf.train.Example(features=tf.train.Features(feature={
      'image/encoded': bytes_feature(image_data),
      'image/format': bytes_feature(image_format),
      'image/class/label': int64_feature(class_id),
      'image/height': int64_feature(height),
      'image/width': int64_feature(width),
  }))


def download_and_uncompress_tarball(tarball_url, dataset_dir):
  """Downloads the tarball_url and uncompresses it locally.
  Args:
    tarball_url: The URL of a tarball file.
    dataset_dir: The directory where the temporary files are stored.
  """
  filename = tarball_url.split('/')[-1]
  filepath = os.path.join(dataset_dir, filename)

  def _progress(count, block_size, total_size):
    sys.stdout.write('\r>> Downloading %s %.1f%%' % (
        filename, float(count * block_size) / float(total_size) * 100.0))
    sys.stdout.flush()
  filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
  print()
  statinfo = os.stat(filepath)
  print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  tarfile.open(filepath, 'r:gz').extractall(dataset_dir)


def write_label_file(labels_to_class_names, dataset_dir,
                     filename=LABELS_FILENAME):
  """
  Writes a file with the list of class names.

  Args:
    labels_to_class_names: A map of (integer) labels to class names.
    dataset_dir: The directory in which the labels file should be written.
    filename: The filename where the class names are written.
  """
  labels_filename = os.path.join(dataset_dir, filename)
  with tf.gfile.Open(labels_filename, 'w') as f:
    for label in labels_to_class_names:
      class_name = labels_to_class_names[label]
      f.write('%d:%s\n' % (label, class_name))


def has_labels(dataset_dir, filename=LABELS_FILENAME):
  """Specifies whether or not the dataset directory contains a label map file.
  Args:
    dataset_dir: The directory in which the labels file is found.
    filename: The filename where the class names are written.
  Returns:
    True if the labels file exists and False otherwise.
  """
  return tf.gfile.Exists(os.path.join(dataset_dir, filename))


def read_label_file(dataset_dir, filename=LABELS_FILENAME):
  """Reads the labels file and returns a mapping from ID to class name.
  Args:
    dataset_dir: The directory in which the labels file is found.
    filename: The filename where the class names are written.
  Returns:
    A map from a label (integer) to class name.
  """
  labels_filename = os.path.join(dataset_dir, filename)
  with tf.gfile.Open(labels_filename, 'rb') as f:
    lines = f.read().decode()
  lines = lines.split('\n')
  lines = filter(None, lines)

  labels_to_class_names = {}
  for line in lines:
    index = line.index(':')
    labels_to_class_names[int(line[:index])] = line[index+1:]
  return labels_to_class_names
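
A quick usage sketch for the helpers above: encode one JPEG into a TF-Example with image_to_tfexample and write it out (the image path, size and class id are placeholders):

import tensorflow as tf

image_data = tf.gfile.FastGFile('/path/to/some_flower.jpg', 'rb').read()
example = image_to_tfexample(image_data, b'jpg', height=224, width=224, class_id=0)
with tf.python_io.TFRecordWriter('/tmp/single_example.tfrecord') as tfrecord_writer:
    tfrecord_writer.write(example.SerializeToString())

The conversion script itself, download_and_convert_flowers.py from slim/datasets, is listed next:
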
"""
Flowers 数据集下载和转化为TFRecords 格式(TF-Example protos).

Flowers 数据集的下载,解压,读取数据,创建两个 TFRecord 数据集:训练数据集和测试数据集.
每个数据集是由 TF-Example protocol buffers 构成,每个 TF-Example protocol buffer 包含一张图片和对应的标签.

该脚本大概需要耗时一分钟.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import random
import sys

import tensorflow as tf

from datasets import dataset_utils

# The URL where the Flowers dataset can be downloaded.
_DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'

# The number of images in the validation set.
_NUM_VALIDATION = 350

# Seed for repeatability.
_RANDOM_SEED = 0

# The number of shards per dataset split.
_NUM_SHARDS = 5


class ImageReader(object):
  """
  用于 TensorFlow 图片编码的辅助类
  """

  def __init__(self):
    # 初始化解码decode RGB JPEG 格式数据的函数.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

  def read_image_dims(self, sess, image_data):
    image = self.decode_jpeg(sess, image_data)
    return image.shape[0], image.shape[1]

  def decode_jpeg(self, sess, image_data):
    image = sess.run(self._decode_jpeg,
                     feed_dict={self._decode_jpeg_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image


def _get_filenames_and_classes(dataset_dir):
  """
  返回文件名和类别名列表.

  Args:
    dataset_dir: 包含多个图片子路径的路径.
    class names. 每个图片子路径包含 PNG 或 JPG 编码的图片.
  Returns:
    图片文件列表,相对于 dataset_dir;
    图片子路经列表,表示类比名字.
  """
  flower_root = os.path.join(dataset_dir, 'flower_photos')
  directories = []
  class_names = []
  for filename in os.listdir(flower_root):
    path = os.path.join(flower_root, filename)
    if os.path.isdir(path):
      directories.append(path)
      class_names.append(filename)

  photo_filenames = []
  for directory in directories:
    for filename in os.listdir(directory):
      path = os.path.join(directory, filename)
      photo_filenames.append(path)

  return photo_filenames, sorted(class_names)


def _get_dataset_filename(dataset_dir, split_name, shard_id):
  output_filename = 'flowers_%s_%05d-of-%05d.tfrecord' % (
      split_name, shard_id, _NUM_SHARDS)
  return os.path.join(dataset_dir, output_filename)


def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
  """
  将给定文件名转换为 TFRecord 格式数据集.

  Args:
    split_name: 数据集的名字,train 或 validation.
    filenames: png 或 jpg 图片的绝对路径列表.
    class_names_to_ids: 类别名字(字符串strings) 到类别 ids(整数integers ) 映射的字典.
    dataset_dir: 转换后的 TFRecord 数据集所保存的路径.
  """
  assert split_name in ['train', 'validation']

  num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))

  with tf.Graph().as_default():
    image_reader = ImageReader()

    with tf.Session('') as sess:

      for shard_id in range(_NUM_SHARDS):
        output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)

        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
          start_ndx = shard_id * num_per_shard
          end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
          for i in range(start_ndx, end_ndx):
            sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                i+1, len(filenames), shard_id))
            sys.stdout.flush()

            # Read the raw image data:
            image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
            height, width = image_reader.read_image_dims(sess, image_data)

            class_name = os.path.basename(os.path.dirname(filenames[i]))
            class_id = class_names_to_ids[class_name]

            example = dataset_utils.image_to_tfexample(
                image_data, b'jpg', height, width, class_id)
            tfrecord_writer.write(example.SerializeToString())

  sys.stdout.write('\n')
  sys.stdout.flush()


def _clean_up_temporary_files(dataset_dir):
  """
  删除创建数据集时产生的临时文件.

  Args:
    dataset_dir: 临时文件的路径.
  """
  filename = _DATA_URL.split('/')[-1]
  filepath = os.path.join(dataset_dir, filename)
  tf.gfile.Remove(filepath)

  tmp_dir = os.path.join(dataset_dir, 'flower_photos')
  tf.gfile.DeleteRecursively(tmp_dir)


def _dataset_exists(dataset_dir):
  for split_name in ['train', 'validation']:
    for shard_id in range(_NUM_SHARDS):
      output_filename = _get_dataset_filename(
          dataset_dir, split_name, shard_id)
      if not tf.gfile.Exists(output_filename):
        return False
  return True


def run(dataset_dir):
  """
  运行数据集下载和转换.

  Args:
    dataset_dir: 数据集所在的路径.
  """
  if not tf.gfile.Exists(dataset_dir):
    tf.gfile.MakeDirs(dataset_dir)

  if _dataset_exists(dataset_dir):
    print('Dataset files already exist. Exiting without re-creating them.')
    return

  # If the Flowers dataset has already been downloaded and unpacked, this step can be skipped (the call is commented out here).
  # dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir)

  photo_filenames, class_names = _get_filenames_and_classes(dataset_dir)
  class_names_to_ids = dict(zip(class_names, range(len(class_names))))

  # Split the data into train and validation sets:
  random.seed(_RANDOM_SEED)
  random.shuffle(photo_filenames)
  training_filenames = photo_filenames[_NUM_VALIDATION:]
  validation_filenames = photo_filenames[:_NUM_VALIDATION]

  # First, convert the training and validation sets.
  _convert_dataset('train', training_filenames, class_names_to_ids, dataset_dir)
  _convert_dataset('validation', validation_filenames, class_names_to_ids, dataset_dir)

  # Finally, write the labels file:
  labels_to_class_names = dict(zip(range(len(class_names)), class_names))
  dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

  # This automatically removes the flower_photos.tgz archive and the flower_photos folder.
  _clean_up_temporary_files(dataset_dir)
  print('\nFinished converting the Flowers dataset!')


if __name__ == '__main__':
    dataset_dir = '/path/to/flower_photos/'
    run(dataset_dir)
    print('Done.')

The console output looks like ">> Converting image 3320/3320 shard 4"; after the script finishes, dataset_dir contains 5 training shards, 5 validation shards and a labels.txt file.
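
To double-check the generated shards, here is a small sketch that counts the serialized examples per split (dataset_dir is the same placeholder path as above):

import glob
import os
import tensorflow as tf

dataset_dir = '/path/to/flower_photos/'
for split in ['train', 'validation']:
    shards = glob.glob(os.path.join(dataset_dir, 'flowers_%s_*.tfrecord' % split))
    count = sum(1 for shard in shards
                for _ in tf.python_io.tf_record_iterator(shard))
    print('%s: %d examples in %d shards' % (split, count, len(shards)))

This should report 3320 training examples and 350 validation examples.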

<h2>2. Fine-tuning the InceptionV1 Model</h2>

A related introduction can be found in TensorFlow - TF-Slim 使用总览. Fine-tuning uses three files: the TF-Slim dataset provider below (flowers.py, which train.py imports), a training script and a test script.

  • flowers.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
Provides data for the flowers dataset.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import os
import tensorflow as tf
slim = tf.contrib.slim

import dataset_utils


_FILE_PATTERN = 'flowers_%s_*.tfrecord'

SPLITS_TO_SIZES = {'train': 3320, 'validation': 350}

_NUM_CLASSES = 5

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying size.',
    'label': 'A single integer between 0 and 4',
}


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """
  获取数据集元组,以读取 flowers 数据.
  Gets a dataset tuple with instructions for reading flowers.

  Args:
    split_name: A train/validation split name.
    dataset_dir: 数据集路径.
    file_pattern: The file pattern to use when matching the dataset sources.
      It is assumed that the pattern contains a '%s' string so that the split
      name can be inserted.
    reader: The TensorFlow reader type.

  Returns:
    A Dataset namedtuple.

  Raises:
    ValueError: if split_name is not a valid train/validation split.
  """
  if split_name not in SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  if not file_pattern:
    file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Allowing None in the signature so that dataset_factory can use the default.
  if reader is None:
    reader = tf.TFRecordReader

  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }

  items_to_handlers = {
      'image': slim.tfexample_decoder.Image(),
      'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }

  decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  labels_to_names = None
  if dataset_utils.has_labels(dataset_dir):
    labels_to_names = dataset_utils.read_label_file(dataset_dir)

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=SPLITS_TO_SIZES[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      num_classes=_NUM_CLASSES,
      labels_to_names=labels_to_names)
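
Before wiring the dataset into training, a quick check (a sketch) that the provider decodes examples as expected; the data directory is the same placeholder path used in train.py below:

import tensorflow as tf
slim = tf.contrib.slim

import flowers

with tf.Graph().as_default():
    dataset = flowers.get_split('train', '/path/to/flower/tfrecords')
    data_provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset, common_queue_capacity=32, common_queue_min=1)
    image, label = data_provider.get(['image', 'label'])

    with tf.Session() as sess:
        with slim.queues.QueueRunners(sess):
            np_image, np_label = sess.run([image, label])
            print(np_image.shape, dataset.labels_to_names[np_label])
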
  • train.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
import os

import flowers
from nets import inception
from preprocessing import inception_preprocessing

import tensorflow as tf
import tensorflow.contrib.slim as slim

image_size = inception.inception_v1.default_image_size


flowers_data_dir = '/path/to/flower/tfrecords'
checkpoints_dir = '/path/to/flower/checkpoints'
train_dir = '/path/to/flower/outputs'


def load_batch(dataset, batch_size=32, height=299, width=299, is_training=False):
    """
    加载单个 bacth 的数据.

    Args:
      dataset: 待加载数据.
      batch_size: batch 内图片数量.
      height: 预处理后的每张图片的 height.
      width: 预处理后的每张图片的 width.
      is_training: 当前数据是否处于 training 还是 evaluating.

    Returns:
      images: [batch_size, height, width, 3] 大小的 Tensor, 预处理后的图片样本.
      images_raw: [batch_size, height, width, 3] 大小的 Tensor, 用于可视化的图片样本.
      labels: [batch_size] 大小的 Tensor, 其值范围为 [0,dataset.num_classes].
    """
    data_provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset, common_queue_capacity=32, common_queue_min=8)
    image_raw, label = data_provider.get(['image', 'label'])

    # Inception preprocessing of the image.
    image = inception_preprocessing.preprocess_image(image_raw, height, width, is_training=is_training)

    # Resize the raw image so it can be visualized alongside the preprocessed one.
    image_raw = tf.expand_dims(image_raw, 0)
    image_raw = tf.image.resize_images(image_raw, [height, width])
    image_raw = tf.squeeze(image_raw)

    # Batch it up.
    images, images_raw, labels = tf.train.batch(
        [image, image_raw, label],batch_size=batch_size,
        num_threads=1, capacity=2 * batch_size)

    return images, images_raw, labels


def get_init_fn():
    """
    训练热身函数.
    权重参数初始化.
    """

    checkpoint_exclude_scopes=["InceptionV1/Logits", "InceptionV1/AuxLogits"]  #原输出层
    # finetune 时更改原输出层,初始化权重时,不更新输出层的权重参数
    exclusions = [scope.strip() for scope in checkpoint_exclude_scopes]

    variables_to_restore = []
    for var in slim.get_model_variables():
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                break
        else:
            variables_to_restore.append(var)

    return slim.assign_from_checkpoint_fn(
        os.path.join(checkpoints_dir, 'inception_v1.ckpt'),
        variables_to_restore)



with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)

    dataset = flowers.get_split('train', flowers_data_dir)
    images, _, labels = load_batch(dataset, height=image_size, width=image_size)

    # Create the model, using the default arg scope to configure the batch norm parameters.
    with slim.arg_scope(inception.inception_v1_arg_scope()):
        logits, _ = inception.inception_v1(images, num_classes=dataset.num_classes, is_training=True)

    # Specify the loss function:
    one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)
    slim.losses.softmax_cross_entropy(logits, one_hot_labels)
    total_loss = slim.losses.get_total_loss()

    # Create summaries to visualize the training process:
    tf.summary.scalar('losses/Total Loss', total_loss)

    # Specify the optimizer and create the train op:
    optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
    train_op = slim.learning.create_train_op(total_loss, optimizer)

    # Run the training:
    final_loss = slim.learning.train(train_op,
                                     logdir=train_dir,
                                     log_every_n_steps=10,
                                     init_fn=get_init_fn(),
                                     number_of_steps=3000,
                                     save_summaries_secs=600,
                                     save_interval_secs=1200)

print('Finished training. Last batch loss %f' % final_loss)
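
Training progress (the losses/Total Loss summary created above) can be monitored by pointing TensorBoard at train_dir, e.g. tensorboard --logdir=/path/to/flower/outputs.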

  • test.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.contrib import slim

from nets import inception
import flowers
from preprocessing import inception_preprocessing

# Note: load_batch used below is the helper defined in train.py; copy it into
# this file (or move it into a shared module) so that test.py can run standalone.

image_size = inception.inception_v1.default_image_size
batch_size = 30
flowers_data_dir = '/path/to/flower/tfrecords'
train_dir = '/path/to/flower/outputs'

with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)

    dataset = flowers.get_split('validation', flowers_data_dir)
    images, images_raw, labels = load_batch(dataset, height=image_size, width=image_size)

    # Create the model, use the default arg scope to configure the batch norm parameters.
    with slim.arg_scope(inception.inception_v1_arg_scope()):
        logits, _ = inception.inception_v1(images, num_classes=dataset.num_classes, is_training=True)

    probabilities = tf.nn.softmax(logits)

    checkpoint_path = tf.train.latest_checkpoint(train_dir)
    init_fn = slim.assign_from_checkpoint_fn(checkpoint_path,
                                             slim.get_variables_to_restore())

    with tf.Session() as sess:
        with slim.queues.QueueRunners(sess):
            sess.run(tf.initialize_local_variables())
            init_fn(sess)
            np_probabilities, np_images_raw, np_labels = sess.run([probabilities, images_raw, labels])

            for i in range(batch_size):
                image = np_images_raw[i, :, :, :]
                true_label = np_labels[i]
                predicted_label = np.argmax(np_probabilities[i, :])
                predicted_name = dataset.labels_to_names[predicted_label]
                true_name = dataset.labels_to_names[true_label]

                plt.figure()
                plt.imshow(image.astype(np.uint8))
                plt.title('Ground Truth: [%s], Prediction [%s]' % (true_name, predicted_name))
                plt.axis('off')
                plt.show()

print('Done.')
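
Beyond inspecting individual predictions, a rough accuracy over this single evaluated batch can be computed from the arrays already fetched above; a short sketch to place inside the session block, after the sess.run call:

batch_accuracy = np.mean(np.argmax(np_probabilities, axis=1) == np_labels)
print('Accuracy over this batch of %d images: %.3f' % (batch_size, batch_accuracy))

Note this covers only one batch of 30 validation images; evaluating the full validation split requires looping over all of its batches.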

<h2>3. Related</h2>

[1] - tensorflow之从文件中读取数据(适用场景:大规模数据集,亲测有效~)
[2] - tensorflowxun训练自己的数据集之从tfrecords读取数据
[3] - TensorFlow高效读取数据的方法
