AIHGF

TensorFlow - TF-Slim: Restoring a Model from a Checkpoint
2018/07/06

TF-Slim's variables module provides functions for restoring model variables from a checkpoint, such as get_variables_to_restore.
See also: TensorFlow - TF-Slim variables functions

D1 - Reposted from: Partially Restoring a Model in TensorFlow

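import tensorflow as tf

slim = tf.contrib.slim

# Create the variables
v1 = slim.variable(name="v1", ...)
v2 = slim.variable(name="nested/v2", ...)
...

# Get the list of variables to restore. In this example, each of the
# following calls returns the same list, [v2]:
variables_to_restore = slim.get_variables_by_name("v2")  # 1. by name (last component of the variable name)
variables_to_restore = slim.get_variables_by_suffix("2")  # 2. by name suffix
variables_to_restore = slim.get_variables(scope="nested")  # 3. by scope
variables_to_restore = slim.get_variables_to_restore(include=["nested"])  # 4. by include scopes (regexes matched from the start of the name)
variables_to_restore = slim.get_variables_to_restore(exclude=["v1"])  # 5. by exclude scopes (regexes matched from the start of the name)
# e.g. both "vgg/conv6" and "vgg" can be passed in the exclude list

# Create a Saver that restores only the selected variables from the checkpoint.
# Variables not in the list (here v1) are not restored and must be initialized separately.
restorer = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
    # Restore the variables from disk.
    restorer.restore(sess, "/tmp/model.ckpt")
    print("Model restored.")
    # Do some work with the model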
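Methods 4 and 5 treat each entry of include/exclude as a scope. The entry is matched against the variable name with re.match, which behaves like a regular expression anchored at the start of the name (this is how tf.Graph.get_collection filters by scope). The pure-Python sketch below imitates that selection on plain name strings, so you can check the behaviour without building a graph. filter_names is a hypothetical helper, not part of TF-Slim, and it simplifies slim by keeping the original name order and skipping duplicates.

```python
import re
from typing import Iterable, List, Optional


def filter_names(names: Iterable[str],
                 include: Optional[List[str]] = None,
                 exclude: Optional[List[str]] = None) -> List[str]:
    """Select variable names the way include/exclude scopes are applied.

    Hypothetical helper: each scope is a regular expression matched with
    re.match, i.e. anchored at the start of the variable name.
    """
    names = list(names)
    if include is None:
        selected = names
    else:
        # Keep names matched by at least one include scope (original order).
        selected = [n for n in names if any(re.match(s, n) for s in include)]
    if exclude:
        # Drop names matched by any exclude scope.
        selected = [n for n in selected
                    if not any(re.match(s, n) for s in exclude)]
    return selected


if __name__ == "__main__":
    all_vars = ["v1", "nested/v2"]
    print(filter_names(all_vars, include=["nested"]))  # ['nested/v2']
    print(filter_names(all_vars, exclude=["v1"]))      # ['nested/v2']
```

Because this is prefix matching, exclude=["vgg/conv6"] would also drop a hypothetical "vgg/conv60/..." variable. Adding a trailing slash ("vgg/conv6/") limits the match to that one scope.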

D2 - DeepLab checkpoint initialization

From train_utils.py.

import tensorflow as tf

slim = tf.contrib.slim


def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=False):
  """Gets the function initializing model variables from a checkpoint.

  Args:
    train_logdir: Log directory for training.
    tf_initial_checkpoint: TensorFlow checkpoint for initialization.
    initialize_last_layer: Initialize last layer or not.
    last_layers: Last layers of the model. 
    ignore_missing_vars: Ignore missing variables in the checkpoint.

  Returns:
    Initialization function.
  """
  if tf_initial_checkpoint is None:
    tf.logging.info('Not initializing the model from a checkpoint.')
    return None

  if tf.train.latest_checkpoint(train_logdir):
    tf.logging.info('Ignoring initialization; other checkpoint exists')
    return None

  tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

  # Variables that will not be restored.
  exclude_list = ['global_step']
  if not initialize_last_layer:
    exclude_list.extend(last_layers)

  # Restore every global variable except those matched by exclude_list.
  variables_to_restore = slim.get_variables_to_restore(exclude=exclude_list)

  return slim.assign_from_checkpoint_fn(
      tf_initial_checkpoint,
      variables_to_restore,
      ignore_missing_vars=ignore_missing_vars)
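get_model_init_fn makes two decisions before any TensorFlow work happens. First, it decides whether to initialize from tf_initial_checkpoint at all. Second, it decides which scopes to skip. The sketch below pulls both decisions out into plain Python functions so they can be checked in isolation. The helper names are hypothetical. latest_in_logdir stands in for the result of tf.train.latest_checkpoint(train_logdir), which is either a path string or None.

```python
from typing import List, Optional


def should_init_from_checkpoint(tf_initial_checkpoint: Optional[str],
                                latest_in_logdir: Optional[str]) -> bool:
    """True only if an initial checkpoint is given and training has not
    already written its own checkpoint to train_logdir (hypothetical helper)."""
    if tf_initial_checkpoint is None:
        return False
    if latest_in_logdir:
        # A checkpoint from this training run takes precedence.
        return False
    return True


def build_exclude_list(initialize_last_layer: bool,
                       last_layers: List[str]) -> List[str]:
    """Mirror the exclude-list logic of get_model_init_fn (hypothetical helper)."""
    # global_step is never restored, so the step counter starts fresh.
    exclude_list = ['global_step']
    if not initialize_last_layer:
        # Skip the task-specific head, e.g. when the number of classes changes.
        exclude_list.extend(last_layers)
    return exclude_list
```

Skipping initialization when train_logdir already holds a checkpoint means that a restarted job resumes from its own progress. Without that check, it would fall back to the pretrained weights.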

From model.py.

_LOGITS_SCOPE_NAME = 'logits'
_MERGED_LOGITS_SCOPE = 'merged_logits'
_IMAGE_POOLING_SCOPE = 'image_pooling'
_ASPP_SCOPE = 'aspp'
_CONCAT_PROJECTION_SCOPE = 'concat_projection'
_DECODER_SCOPE = 'decoder'


def get_extra_layer_scopes(last_layers_contain_logits_only=False):
  """Gets the scopes for extra layers.

  Args:
    last_layers_contain_logits_only: Boolean, True if only consider logits as
    the last layer (i.e., exclude ASPP module, decoder module and so on)

  Returns:
    A list of scopes for extra layers.
  """
  if last_layers_contain_logits_only:
    return [_LOGITS_SCOPE_NAME]
  else:
    return [
        _LOGITS_SCOPE_NAME,
        _IMAGE_POOLING_SCOPE,
        _ASPP_SCOPE,
        _CONCAT_PROJECTION_SCOPE,
        _DECODER_SCOPE,
    ]
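The list returned by get_extra_layer_scopes is what gets passed as last_layers to get_model_init_fn. The sketch below combines the two pieces to show which variables survive the exclusion when initialize_last_layer is False. It copies the scope constants and get_extra_layer_scopes from above and applies the same start-of-name matching that slim uses. The variable names are hypothetical, since real names depend on the backbone and model settings. Note that the scope 'aspp' also matches names such as 'aspp0/...' because the match is a prefix match.

```python
import re
from typing import List

# Scope names copied from model.py above.
_LOGITS_SCOPE_NAME = 'logits'
_IMAGE_POOLING_SCOPE = 'image_pooling'
_ASPP_SCOPE = 'aspp'
_CONCAT_PROJECTION_SCOPE = 'concat_projection'
_DECODER_SCOPE = 'decoder'


def get_extra_layer_scopes(last_layers_contain_logits_only: bool = False) -> List[str]:
    if last_layers_contain_logits_only:
        return [_LOGITS_SCOPE_NAME]
    return [_LOGITS_SCOPE_NAME, _IMAGE_POOLING_SCOPE, _ASPP_SCOPE,
            _CONCAT_PROJECTION_SCOPE, _DECODER_SCOPE]


def restored_names(var_names: List[str], last_layers: List[str]) -> List[str]:
    """Names kept when global_step and last_layers are excluded by prefix."""
    exclude = ['global_step'] + last_layers
    return [n for n in var_names
            if not any(re.match(s, n) for s in exclude)]


# Hypothetical variable names, for illustration only.
VARS = [
    'global_step',
    'xception_65/entry_flow/conv1_1/weights',
    'aspp0/weights',
    'image_pooling/weights',
    'concat_projection/weights',
    'decoder/feature_projection0/weights',
    'logits/semantic/weights',
]
```

With last_layers_contain_logits_only=True, only the logits are re-initialized, so the ASPP and decoder weights are still loaded from the checkpoint. With the default False, only the backbone is restored.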
Last modification: October 9th, 2018 at 09:31 am
