[slim.learning.py]()

TF-Slim code for training models. It provides various functions for model training, including:

[1] - Manipulating gradients.

[2] - Creating a train_op: an operation that computes the loss and applies the gradients.

[3] - A training loop function.

The training loop allows the user to pass in the train_op and runs the optimization according to user-specified arguments.

1. Basic Model Training Workflow

# Load the data and create the model:
images, labels = LoadData(...)
predictions = MyModel(images)

# Define the loss:
slim.losses.log_loss(predictions, labels)
total_loss = slim.losses.get_total_loss()

# Define the optimizer:
optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum)

# Create the train_op:
train_op = slim.learning.create_train_op(total_loss, optimizer)

# Run training.
slim.learning.train(train_op, my_log_dir)

1.1 Creating the train_op

To train a model, TF-Slim's training loop function requires a train_op.

The train_op is an operation that:

[1] - Computes the loss.

[2] - Applies the gradients to update the model weights.

[3] - Returns the value of the loss.

slim.learning.create_train_op is used to create the train_op, e.g.:

# Create the train_op and clip the gradient norms:
train_op = slim.learning.create_train_op(
      total_loss,
      optimizer,
      clip_gradient_norm=4)

# Create the train_op and scale the gradients by providing a map from
# variable name (or variable) to a scaling coefficient:
gradient_multipliers = {'conv0/weights': 1.2,
                        'fc8/weights': 3.4,
                       }
train_op = slim.learning.create_train_op(
      total_loss,
      optimizer,
      gradient_multipliers=gradient_multipliers)

Note on gradient clipping:

tf.clip_by_norm clips an individual gradient tensor to a maximum norm.

The purpose of gradient clipping is to avoid exploding gradients; it works by capping the maximum norm of the gradients.

tf.clip_by_global_norm instead rescales a list of tensors by their joint (global) norm:

tf.clip_by_global_norm(
    t_list,
    clip_norm,
    use_norm=None,
    name=None
)
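As a rough sketch of how the two clipping styles differ in practice (assuming total_loss and optimizer are defined as in the examples above): tf.clip_by_norm caps each gradient tensor independently, while tf.clip_by_global_norm rescales the whole list of gradients so their joint norm stays below the threshold, preserving their relative magnitudes.

import tensorflow as tf

# total_loss and optimizer are assumed to be defined as in the examples above.
grads_and_vars = optimizer.compute_gradients(total_loss)

# Per-tensor clipping: each gradient's norm is capped at 4 independently.
clipped_per_tensor = [(tf.clip_by_norm(g, 4.0), v)
                      for g, v in grads_and_vars if g is not None]

# Global clipping: all gradients are rescaled together so that their
# global norm is at most 4.
grads, variables = zip(*grads_and_vars)
clipped_grads, global_norm = tf.clip_by_global_norm(grads, clip_norm=4.0)
train_op = optimizer.apply_gradients(list(zip(clipped_grads, variables)))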

1.2 Other (Non-Gradient) Updates During Training

Many networks make use of modules such as BatchNorm that require a series of non-gradient updates during training.

slim.learning.create_train_op also allows the user to pass in a list of additional update_ops to run alongside the gradient updates:

train_op = slim.learning.create_train_op(total_loss, optimizer, update_ops)

By default, slim.learning.create_train_op includes all of the update ops that are part of the tf.GraphKeys.UPDATE_OPS collection.

Additionally, TF-Slim's slim.batch_norm function adds the moving mean and moving variance updates to this collection. Consequently, if you use slim.batch_norm, no extra steps are needed to have the moving mean and moving variance updated.
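As an illustration of what this automates, a minimal hand-written equivalent (a sketch, assuming total_loss and optimizer are defined) collects the UPDATE_OPS and makes them a control dependency of the gradient step:

import tensorflow as tf

# Sketch: the moving mean/variance updates added by slim.batch_norm live in
# the UPDATE_OPS collection; create_train_op effectively groups them with the
# gradient step, as below.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    manual_train_op = optimizer.minimize(total_loss)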

However, the user can override the default update ops or add additional update ops to the tf.GraphKeys.UPDATE_OPS collection:

# Force TF-Slim NOT to use ANY update_ops:
train_op = slim.learning.create_train_op(
     total_loss,
     optimizer,
     update_ops=[])

# Use an alternative set of update ops:
train_op = slim.learning.create_train_op(
     total_loss,
     optimizer,
     update_ops=my_other_update_ops)

# Use a set of update ops in addition to the default updates:
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, my_update0)
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, my_update1)

train_op = slim.learning.create_train_op(
     total_loss,
     optimizer)

# Which is the same as:
train_op = slim.learning.create_train_op(
     total_loss,
     optimizer,
     update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS))

1.3 Initializing a Model from a Checkpoint

It is common to want to warm-start training from a pre-trained checkpoint:

...
# Create the train_op
train_op = slim.learning.create_train_op(total_loss, optimizer)

# Create the initial assignment op
checkpoint_path = '/path/to/old_model_checkpoint'
variables_to_restore = slim.get_model_variables()
init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
    checkpoint_path, variables_to_restore)

# Create an initial assignment function.
def InitAssignFn(sess):
    sess.run(init_assign_op, init_feed_dict)

# Run training.
slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn)
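An alternative sketch of essentially the same warm-start uses slim.assign_from_checkpoint_fn, which builds the init_fn directly instead of requiring a separate assign op and feed dict; ignore_missing_vars is useful when the checkpoint does not contain every model variable.

# Sketch: build init_fn directly from the checkpoint.
variables_to_restore = slim.get_model_variables()
init_fn = slim.assign_from_checkpoint_fn(
    '/path/to/old_model_checkpoint',
    variables_to_restore,
    ignore_missing_vars=True)

slim.learning.train(train_op, my_log_dir, init_fn=init_fn)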

1.4 Initializing Model Variables from Values in Memory

You may also want to initialize the model weights from values of an arbitrary source (a text file, a MATLAB file, etc.). While this is technically feasible with plain TensorFlow, it requires the values to be stored as part of the graph, which for large models can be prohibitively large. TF-Slim provides a way to perform the initial assignment without storing the initial values as part of the graph itself, using placeholders and a feed dictionary:

# Create the train_op
train_op = slim.learning.create_train_op(total_loss, optimizer)

# Create the mapping from variable names to values:
var0_initial_value = ReadFromDisk(...)
var1_initial_value = ReadFromDisk(...)

var_names_to_values = {'var0': var0_initial_value,
                       'var1': var1_initial_value,
}
init_assign_op, init_feed_dict = slim.assign_from_values(var_names_to_values)

# Create an initial assignment function.
def InitAssignFn(sess):
    sess.run(init_assign_op, init_feed_dict)

# Run training.
slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn)

2. slim.learning.train()

A function that runs a training loop using a TensorFlow Supervisor (a usage sketch combining several of its arguments follows the argument list below):

# Returns: the value of the loss function after training.
_USE_DEFAULT = 0
def train(train_op,
          logdir,
          train_step_fn=train_step,
          train_step_kwargs=_USE_DEFAULT,
          log_every_n_steps=1,
          graph=None,
          master='',
          is_chief=True,
          global_step=None,
          number_of_steps=None,
          init_op=_USE_DEFAULT,
          init_feed_dict=None,
          local_init_op=_USE_DEFAULT,
          init_fn=None,
          ready_op=_USE_DEFAULT,
          summary_op=_USE_DEFAULT,
          save_summaries_secs=600,
          summary_writer=_USE_DEFAULT,
          startup_delay_steps=0,
          saver=None,
          save_interval_secs=600,
          sync_optimizer=None,
          session_config=None,
          session_wrapper=None,
          trace_every_n_steps=None,
          ignore_live_threads=False)

[1] - train_op - A Tensor that, when executed, computes the gradient updates and returns the loss value.

[2] - logdir - The directory where training logs are written. If None, model checkpoints and summaries will not be written.

[3] - train_step_fn - A callable used to perform a single gradient step. The function must take four arguments: the current session, the train_op Tensor, the global step Tensor, and a dictionary.

[4] - train_step_kwargs - A dictionary that is passed to train_step_fn. By default it contains two Boolean scalar ops, should_stop and should_log.

[5] - log_every_n_steps - How often, in terms of global steps, to log the loss and the global step.

[6] - graph - The graph passed to the Supervisor. If None, the default graph is used.

[7] - master - The address of the TensorFlow master.

[8] - is_chief - Specifies whether or not training is run by the primary replica during replica training.

[9] - global_step - The Tensor representing the global step. If None, training_util.get_or_create_global_step(), i.e. tf.contrib.framework.global_step(), is used.

[10] - number_of_steps - The maximum number of gradient steps to take during training, as measured by global_step: training stops once global_step is greater than number_of_steps. If None, training proceeds indefinitely.

[11] - init_op - The initialization operation. If left to its default value, the session is initialized by calling tf.global_variables_initializer().

[12] - init_feed_dict - A feed dictionary to use when executing the init_op.

[13] - local_init_op - The local initialization operation. If left to its default value, the session is initialized by calling tf.local_variables_initializer() and tf.tables_initializer().

[14] - init_fn - A callable executed after init_op is called. The callable must accept one argument: the session being initialized.

[15] - ready_op - An operation used to check if the model is ready to use. If left to its default value, the session checks for readiness by calling tf.report_uninitialized_variables().

[16] - summary_op - The summary operation.

[17] - save_summaries_secs - How often, in seconds, to save summaries.

[18] - summary_writer - The SummaryWriter to use. Pass None to indicate that no summaries should be written. If unset, a SummaryWriter will be created.

[19] - startup_delay_steps - The number of steps to wait before starting to train. Must be zero if sync_optimizer is supplied.

[20] - saver - The Saver used for saving checkpoints. If None, a default one is created and used.

[21] - save_interval_secs - How often, in seconds, to save the model to logdir.

[22] - sync_optimizer - A tf.train.SyncReplicasOptimizer instance, or a list of them. If provided, gradient updates are synchronous; if None, gradient updates are asynchronous.

[23] - session_config - An instance of tf.ConfigProto used to configure the Session. If None, the default settings are used.

[24] - session_wrapper - A function that takes a tf.Session object as its only argument and returns a wrapped session object with the same methods as the original, or None. If not None, training uses the wrapped object.

[25] - trace_every_n_steps - Produces and saves a Timeline in Chrome trace format every trace_every_n_steps and adds it to the summaries. If None, no trace information is produced or saved.

[26] - ignore_live_threads - If True, ignores threads that remain running after a grace period when stopping the Supervisor, instead of raising a RuntimeError.

A ValueError is raised if:

train_op is empty;

or sync_optimizer is supplied but startup_delay_steps is non-zero;

or number_of_steps is negative;

or trace_every_n_steps is not None but no logdir is provided.
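A usage sketch combining several of the arguments above; the directory and numeric values are illustrative, and train_op is assumed to come from slim.learning.create_train_op as in Section 1.

# Sketch: bounded training with custom logging/summary/checkpoint intervals.
slim.learning.train(
    train_op,
    logdir='/tmp/my_train_dir',
    number_of_steps=10000,      # stop once global_step >= 10000
    log_every_n_steps=100,      # log the loss every 100 global steps
    save_summaries_secs=300,    # write summaries every 5 minutes
    save_interval_secs=600)     # save a checkpoint every 10 minutes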

_USE_DEFAULT = 0


def train(train_op,
          logdir,
          train_step_fn=train_step,
          train_step_kwargs=_USE_DEFAULT,
          log_every_n_steps=1,
          graph=None,
          master='',
          is_chief=True,
          global_step=None,
          number_of_steps=None,
          init_op=_USE_DEFAULT,
          init_feed_dict=None,
          local_init_op=_USE_DEFAULT,
          init_fn=None,
          ready_op=_USE_DEFAULT,
          summary_op=_USE_DEFAULT,
          save_summaries_secs=600,
          summary_writer=_USE_DEFAULT,
          startup_delay_steps=0,
          saver=None,
          save_interval_secs=600,
          sync_optimizer=None,
          session_config=None,
          session_wrapper=None,
          trace_every_n_steps=None,
          ignore_live_threads=False):

  if train_op is None:
    raise ValueError('train_op cannot be None.')

  if logdir is None:
    if summary_op != _USE_DEFAULT:
      raise ValueError('Cannot provide summary_op because logdir=None')
    if saver is not None:
      raise ValueError('Cannot provide saver because logdir=None')
    if trace_every_n_steps is not None:
      raise ValueError('Cannot provide trace_every_n_steps because '
                       'logdir=None')

  if isinstance(sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer):
    sync_optimizer = [sync_optimizer]
  if sync_optimizer is not None and startup_delay_steps > 0:
    raise ValueError(
        'startup_delay_steps must be zero when sync_optimizer is supplied.')

  if number_of_steps is not None and number_of_steps <= 0:
    raise ValueError(
        '`number_of_steps` must be either None or a positive number.')

  graph = graph or ops.get_default_graph()
  with graph.as_default():
    if global_step is None:
      global_step = training_util.get_or_create_global_step()
    saver = saver or tf_saver.Saver()

    if sync_optimizer is not None:
      for opt in sync_optimizer:
        if not isinstance(opt, sync_replicas_optimizer.SyncReplicasOptimizer):
          raise ValueError(
              '`sync_optimizer` must be a tf.train.SyncReplicasOptimizer.')

    with ops.name_scope('init_ops'):
      if init_op == _USE_DEFAULT:
        init_op = variables.global_variables_initializer()

      if ready_op == _USE_DEFAULT:
        ready_op = variables.report_uninitialized_variables()

      if local_init_op == _USE_DEFAULT:
        local_init_op = control_flow_ops.group(
            variables.local_variables_initializer(),
            lookup_ops.tables_initializer())

      if sync_optimizer is not None and isinstance(sync_optimizer, list):
        with ops.control_dependencies([local_init_op] if local_init_op is
                                      not None else []):
          if is_chief:
            local_init_op = control_flow_ops.group(
                *[opt.chief_init_op for opt in sync_optimizer])
          else:
            local_init_op = control_flow_ops.group(
                *[opt.local_step_init_op for opt in sync_optimizer])
        ready_for_local_init_op = control_flow_ops.group(
            *[opt.ready_for_local_init_op for opt in sync_optimizer])
      else:
        ready_for_local_init_op = None

    if summary_op == _USE_DEFAULT:
      summary_op = summary.merge_all()

    if summary_writer == _USE_DEFAULT:
      summary_writer = supervisor.Supervisor.USE_DEFAULT

    if is_chief and sync_optimizer is not None:
      # Need to create these BEFORE the supervisor finalizes the graph:
      init_tokens_op = [opt.get_init_tokens_op() for opt in sync_optimizer]
      chief_queue_runner = [
          opt.get_chief_queue_runner() for opt in sync_optimizer]

    if train_step_kwargs == _USE_DEFAULT:
      with ops.name_scope('train_step'):
        train_step_kwargs = {}

        if number_of_steps:
          should_stop_op = math_ops.greater_equal(global_step, number_of_steps)
        else:
          should_stop_op = constant_op.constant(False)
        train_step_kwargs['should_stop'] = should_stop_op
        if log_every_n_steps > 0:
          train_step_kwargs['should_log'] = math_ops.equal(
              math_ops.mod(global_step, log_every_n_steps), 0)
        if is_chief and trace_every_n_steps is not None:
          train_step_kwargs['should_trace'] = math_ops.equal(
              math_ops.mod(global_step, trace_every_n_steps), 0)
          train_step_kwargs['logdir'] = logdir

  sv = supervisor.Supervisor(
      graph=graph,
      is_chief=is_chief,
      logdir=logdir,
      init_op=init_op,
      init_feed_dict=init_feed_dict,
      local_init_op=local_init_op,
      ready_for_local_init_op=ready_for_local_init_op,
      ready_op=ready_op,
      summary_op=summary_op,
      summary_writer=summary_writer,
      global_step=global_step,
      saver=saver,
      save_summaries_secs=save_summaries_secs,
      save_model_secs=save_interval_secs,
      init_fn=init_fn)

  if summary_writer is not None:
    train_step_kwargs['summary_writer'] = sv.summary_writer

  total_loss = None
  should_retry = True
  while should_retry:
    try:
      should_retry = False
      with sv.managed_session(
          master, start_standard_services=False, config=session_config) as sess:
        logging.info('Starting Session.')
        if session_wrapper is not None:
          logging.info(
              'Wrapping session with wrapper function: %s', session_wrapper)
          sess = session_wrapper(sess)
        if is_chief:
          if logdir:
            sv.start_standard_services(sess)
        elif startup_delay_steps > 0:
           # (use sys.maxsize because sys.maxint doesn't exist in Python 3)
          _wait_for_step(sess, global_step,
                         min(startup_delay_steps, number_of_steps or
                             sys.maxsize))
        threads = sv.start_queue_runners(sess)
        logging.info('Starting Queues.')
        if is_chief and sync_optimizer is not None:
          sv.start_queue_runners(sess, chief_queue_runner)
          sess.run(init_tokens_op)
        try:
          while not sv.should_stop():
            # Run a single training step.
            total_loss, should_stop = train_step_fn(sess, 
                                                    train_op, 
                                                    global_step, 
                                                    train_step_kwargs)
            if should_stop:
              logging.info('Stopping Training.')
              sv.request_stop()
              break
        except errors.OutOfRangeError as e:
          # OutOfRangeError is thrown when epoch limit per
          # tf.train.limit_epochs is reached.
          logging.info('Caught OutOfRangeError. Stopping Training. %s', e)
        if logdir and sv.is_chief:
          logging.info('Finished training! Saving model to disk.')
          sv.saver.save(sess, sv.save_path, global_step=sv.global_step)
          sv.stop(
              threads,
              close_summary_writer=True,
              ignore_live_threads=ignore_live_threads)

    except errors.AbortedError:
      # Always re-run on AbortedError as it indicates a restart of one of the
      # distributed tensorflow servers.
      logging.info('Retrying training!')
      should_retry = True

  return total_loss

2.1 slim.learning.train_step()

def train_step(sess, train_op, global_step, train_step_kwargs):
  """
  函数用于进行一次梯度计算,指定是否停止训练.

  Args:
    sess: 当前 session.
    train_op: 计算梯度的操作`Operation`,并返回 total loss.
    global_step: 表示 global training step 的 Tensor.
    train_step_kwargs: 关键词参数字典.

  Returns:
    total loss 和是否停止训练的布尔值.

  Raises:
    ValueError: if 'should_trace' is in `train_step_kwargs` but `logdir` is not.
  """
  start_time = time.time()

  trace_run_options = None
  run_metadata = None
  if 'should_trace' in train_step_kwargs:
    if 'logdir' not in train_step_kwargs:
      raise ValueError('logdir must be present in train_step_kwargs when '
                       'should_trace is present')
    if sess.run(train_step_kwargs['should_trace']):
      trace_run_options = config_pb2.RunOptions(
          trace_level=config_pb2.RunOptions.FULL_TRACE)
      run_metadata = config_pb2.RunMetadata()

  total_loss, np_global_step = sess.run([train_op, global_step],
                                        options=trace_run_options,
                                        run_metadata=run_metadata)
  time_elapsed = time.time() - start_time

  if run_metadata is not None:
    tl = timeline.Timeline(run_metadata.step_stats)
    trace = tl.generate_chrome_trace_format()
    trace_filename = os.path.join(train_step_kwargs['logdir'],
                                  'tf_trace-%d.json' % np_global_step)
    logging.info('Writing trace to %s', trace_filename)
    file_io.write_string_to_file(trace_filename, trace)
    if 'summary_writer' in train_step_kwargs:
      train_step_kwargs['summary_writer'].add_run_metadata(run_metadata,
                                                           'run_metadata-%d' %
                                                           np_global_step)

  if 'should_log' in train_step_kwargs:
    if sess.run(train_step_kwargs['should_log']):
      logging.info('global step %d: loss = %.4f (%.3f sec/step)',
                   np_global_step, total_loss, time_elapsed)
    
  if 'should_stop' in train_step_kwargs:
    should_stop = sess.run(train_step_kwargs['should_stop'])
  else:
    should_stop = False

  return total_loss, should_stop
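Because train_step_fn is pluggable, a custom step function can wrap the default slim.learning.train_step and add its own behaviour, e.g. extra periodic logging; a minimal sketch (the step interval is illustrative):

# Sketch: delegate to the default train_step and add periodic logging.
def my_train_step(sess, train_op, global_step, train_step_kwargs):
    total_loss, should_stop = slim.learning.train_step(
        sess, train_op, global_step, train_step_kwargs)
    step = sess.run(global_step)
    if step % 1000 == 0:
        print('Step %d finished, loss = %.4f' % (step, total_loss))
    return total_loss, should_stop

slim.learning.train(train_op, my_log_dir, train_step_fn=my_train_step)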

2.2 slim.learning.create_train_op()

Creates an Operation that computes the gradients and returns the loss, e.g.:

total_loss = slim.losses.get_total_loss()
optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
# Create the train_op:
train_op = slim.learning.create_train_op(total_loss, optimizer)

def clip_gradient_norms(gradients_to_variables, max_norm):
  """
  根据给定值对梯度裁剪.
  
  Args:
    gradients_to_variables: 梯度和变量对(元祖)列表
    max_norm: 最大范数值.

  Returns:
    A list of clipped gradient to variable pairs.
  """
  clipped_grads_and_vars = []
  for grad, var in gradients_to_variables:
    if grad is not None:
      if isinstance(grad, ops.IndexedSlices):
        tmp = clip_ops.clip_by_norm(grad.values, max_norm)
        grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
      else:
        grad = clip_ops.clip_by_norm(grad, max_norm)
    clipped_grads_and_vars.append((grad, var))
  return clipped_grads_and_vars


def multiply_gradients(grads_and_vars, gradient_multipliers):
  """
  Multiplies the specified gradients.

  Args:
    grads_and_vars: A list of gradient-to-variable pairs (tuples).
    gradient_multipliers: A map from either `Variables` or `Variable` op names
      to the coefficient by which the associated gradient should be scaled.

  Returns:
    The updated list of gradient to variable pairs.

  Raises:
    ValueError: If `grads_and_vars` is not a list or if `gradient_multipliers`
    is empty or None or if `gradient_multipliers` is not a dictionary.
  """
  if not isinstance(grads_and_vars, list):
    raise ValueError('`grads_and_vars` must be a list.')
  if not gradient_multipliers:
    raise ValueError('`gradient_multipliers` is empty.')
  if not isinstance(gradient_multipliers, dict):
    raise ValueError('`gradient_multipliers` must be a dict.')

  multiplied_grads_and_vars = []
  for grad, var in grads_and_vars:
    if var in gradient_multipliers or var.op.name in gradient_multipliers:
      key = var if var in gradient_multipliers else var.op.name
      if grad is None:
        raise ValueError('Requested multiple of `None` gradient.')

      multiplier = gradient_multipliers[key]
      if not isinstance(multiplier, ops.Tensor):
        multiplier = constant_op.constant(multiplier, dtype=grad.dtype)

      if isinstance(grad, ops.IndexedSlices):
        tmp = grad.values * multiplier
        grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
      else:
        grad *= multiplier
    multiplied_grads_and_vars.append((grad, var))
  return multiplied_grads_and_vars
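These two helpers can also be used on their own with a plain optimizer; doing so by hand mirrors what the transform_grads_fn inside create_train_op (shown next) applies. A sketch, assuming total_loss and optimizer are defined and using an illustrative multiplier map:

# Sketch: scale and clip gradients manually, then apply them.
grads_and_vars = optimizer.compute_gradients(total_loss)
grads_and_vars = slim.learning.multiply_gradients(
    grads_and_vars, {'fc8/weights': 3.4})
grads_and_vars = slim.learning.clip_gradient_norms(grads_and_vars, max_norm=4.0)
train_op = optimizer.apply_gradients(grads_and_vars)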

_USE_GLOBAL_STEP = 0

def create_train_op(total_loss,
                    optimizer,
                    global_step=_USE_GLOBAL_STEP,
                    update_ops=None,
                    variables_to_train=None,
                    clip_gradient_norm=0,
                    summarize_gradients=False,
                    gate_gradients=tf_optimizer.Optimizer.GATE_OP,
                    aggregation_method=None,
                    colocate_gradients_with_ops=False,
                    gradient_multipliers=None,
                    check_numerics=True):

  def transform_grads_fn(grads):
    if gradient_multipliers:
      with ops.name_scope('multiply_grads'):
        grads = multiply_gradients(grads, gradient_multipliers)

    # Clip gradients.
    if clip_gradient_norm > 0:
      with ops.name_scope('clip_grads'):
        grads = clip_gradient_norms(grads, clip_gradient_norm)
    return grads

  return training.create_train_op(
      total_loss=total_loss,
      optimizer=optimizer,
      global_step=global_step,
      update_ops=update_ops,
      variables_to_train=variables_to_train,
      transform_grads_fn=transform_grads_fn,
      summarize_gradients=summarize_gradients,
      gate_gradients=gate_gradients,
      aggregation_method=aggregation_method,
      colocate_gradients_with_ops=colocate_gradients_with_ops,
      check_numerics=check_numerics)

[1] - total_loss - A Tensor representing the total loss.

[2] - optimizer - A tf.Optimizer used to compute the gradients.

[3] - global_step - A Tensor representing the global step variable. If left as the default _USE_GLOBAL_STEP, tf.contrib.framework.global_step() is used.

[4] - update_ops - An optional list of update ops to execute alongside the gradient update. If None, the update ops default to the contents of the tf.GraphKeys.UPDATE_OPS collection; if a list is given but it does not contain all of the ops in that collection, a warning is displayed.

[5] - variables_to_train - An optional list of variables to train. If None, it defaults to all tf.trainable_variables().

[6] - clip_gradient_norm - If greater than 0, the gradients are clipped by this value.

[7] - summarize_gradients - Whether or not to add summaries for each gradient.

[8] - gate_gradients - How to gate the computation of gradients. See tf.Optimizer.

[9] - aggregation_method - Specifies the method used to combine gradient terms. Valid values are defined in the class AggregationMethod.

[10] - colocate_gradients_with_ops - Whether or not to try colocating the gradients with the ops that generated them.

[11] - gradient_multipliers - A dictionary of either Variables or Variable op names to the coefficient by which the associated gradient should be scaled.

[12] - check_numerics - Whether or not to apply check_numerics.
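For instance, a fine-tuning sketch that trains only the final layer and scales its gradients (the 'fc8' scope name and the multiplier are illustrative):

# Sketch: restrict training to the 'fc8' variables and scale their gradients.
fc8_variables = slim.get_variables(scope='fc8')

train_op = slim.learning.create_train_op(
    total_loss,
    optimizer,
    variables_to_train=fc8_variables,
    gradient_multipliers={'fc8/weights': 10.0})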
