TensorFlow - Semantic Segmentation: DeepLab API Demo

DeepLab: Deep Labelling for Semantic Image Segmentation. The goal of semantic segmentation is to assign a semantic label, such as person or cat, to every pixel of the input image.

TensorFlow semantic segmentation DeepLab API

TensorFlow DeepLab Model Zoo

The DeepLab series of semantic segmentation papers:

[1] - DeepLabv1 - Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs-ICLR2015
Uses atrous convolution to explicitly control the resolution at which the CNN computes feature maps.
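To make "controlling the resolution" concrete, here is a toy 1-D NumPy sketch (an illustration of the idea, not DeepLab's code). The `rate` parameter spreads the kernel taps apart, so the field of view grows without extra weights, while stride 1 with zero padding keeps the output as long as the input, i.e. no resolution is lost.

```python
import numpy as np

def atrous_conv1d(x, w, rate):
    """Stride-1 1-D atrous (dilated) convolution with zero 'same' padding.

    Computed as cross-correlation, as CNN layers do:
    y[i] = sum_j w[j] * x[i + (j - c) * rate], where c is the center tap.
    """
    x = np.asarray(x, dtype=float)
    w = np.asarray(w, dtype=float)
    k = len(w)
    assert k % 2 == 1, 'odd kernel size assumed so the kernel has a center tap'
    c = k // 2
    pad = c * rate  # enough zeros that every output position sees a full window
    xp = np.pad(x, pad, mode='constant')
    y = np.zeros(len(x))
    for i in range(len(x)):
        for j in range(k):
            y[i] += w[j] * xp[i + pad + (j - c) * rate]
    return y

impulse = np.zeros(9)
impulse[4] = 1.0
print(atrous_conv1d(impulse, [1, 2, 3], rate=1))  # taps 1 sample apart
print(atrous_conv1d(impulse, [1, 2, 3], rate=2))  # same 3 weights, taps 2 samples apart
```

With rate=1 the impulse response occupies positions 3 to 5; with rate=2 the same three weights land on positions 2, 4 and 6, and both outputs have length 9.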

[2] - DeepLabv2 - DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs-TPAMI2017
Uses atrous spatial pyramid pooling (ASPP) to segment objects at multiple scales: ASPP probes the features with filters at multiple sampling rates and effective fields-of-view to capture multi-scale information.
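The "effective fields-of-view" follow from a simple formula: a kernel with k taps and rate r spans k + (k - 1)(r - 1) positions per side of the feature map it runs on. A small sketch, assuming 3x3 ASPP branches and the rate sets that appear in the model tables of this post:

```python
def effective_kernel_size(k, rate):
    """Span, in feature-map positions, covered by a k-tap kernel with dilation `rate`."""
    return k + (k - 1) * (rate - 1)

# ASPP rates for OS=16 and OS=8 as listed in the model tables of this post
for rates in ([6, 12, 18], [12, 24, 36]):
    print(rates, [effective_kernel_size(3, r) for r in rates])
```

For rates [6, 12, 18] a 3x3 kernel spans 13, 25 and 37 positions; for [12, 24, 36] it spans 25, 49 and 73.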

[3] - DeepLabv3 - Rethinking Atrous Convolution for Semantic Image Segmentation-2017
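Augments the ASPP module with image-level features to capture longer-range information, and uses batch normalization to speed up training.
Atrous convolution is used to extract output features at different output strides during training and evaluation: training with output_stride=16 makes training BN practical, and evaluating with output_stride=8 yields higher accuracy.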

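[4] - DeepLabv3+ - Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation-2018
Extends DeepLabv3 with a simple yet effective decoder module that refines the segmentation results, especially along object boundaries. Furthermore, in this encoder-decoder structure the resolution of the extracted encoder features can be controlled arbitrarily with atrous convolution, trading off precision against runtime.

Backbone networks used by the current TensorFlow semantic segmentation API implementation: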

[1] - MobileNetv2 - Inverted Residuals and Linear Bottlenecks: Mobile Networks for Classification, Detection and Segmentation-CVPR2018
A fast network architecture for mobile devices.

[2] - Xception: Deep Learning with Depthwise Separable Convolutions-CVPR2017
A powerful network architecture for server-side deployment.

[3] - Deformable Convolutional Networks -- COCO Detection and Segmentation Challenge 2017 Entry

The TensorFlow DeepLab API provides IoU evaluation and visualization of segmentation results.
The code uses the PASCAL VOC 2012 and Cityscapes benchmarks as examples.

1. DeepLab API Installation

1.1. Dependencies

  • Numpy
  • Pillow 1.0
  • tf Slim (located under "tensorflow/models/research/")
  • Jupyter notebook
  • Matplotlib
  • Tensorflow
# For CPU
pip install tensorflow
# For GPU
pip install tensorflow-gpu

sudo apt-get install python-pil python-numpy
sudo pip install jupyter
sudo pip install matplotlib

1.2. Environment Variables

Add tensorflow/models/research/ and tensorflow/models/research/slim to PYTHONPATH:

# From tensorflow/models/research/
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim

1.3. Testing the Installation

Run the following in tensorflow/models/research/:

# From tensorflow/models/research/
python deeplab/model_test.py  # quick test

Output:

....
2018-06-06 15:58:06.598998: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1053] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 1821 MB memory) -> physical GPU (device: 0, name: GeForce GTX 980 Ti, pci bus id: 0000:01:00.0, compute capability: 5.2)
....
----------------------------------------------------------------------
Ran 5 tests in 15.748s

OK

You can also run the complete pipeline on the PASCAL VOC 2012 dataset:

# From tensorflow/models/research/deeplab
sh local_test.sh

2. TensorFlow Model Zoo

Tensorflow DeepLab ModelZoo

TensorFlow provides models pretrained on the PASCAL VOC 2012, Cityscapes and ADE20K datasets.

2.1. DeepLab models trained on PASCAL VOC 2012

2.1.1. Model Description

| Checkpoint name | Network backbone | Pretrained dataset | ASPP | Decoder |
|---|---|---|---|---|
| mobilenetv2_dm05_coco_voc_trainaug | MobileNet-v2, Depth-Multiplier = 0.5 | MS-COCO, VOC 2012 train_aug set | N/A | N/A |
| mobilenetv2_dm05_coco_voc_trainval | MobileNet-v2, Depth-Multiplier = 0.5 | MS-COCO, VOC 2012 train_aug + trainval sets | N/A | N/A |
| mobilenetv2_coco_voc_trainaug | MobileNet-v2 | MS-COCO, VOC 2012 train_aug set | N/A | N/A |
| mobilenetv2_coco_voc_trainval | MobileNet-v2 | MS-COCO, VOC 2012 train_aug + trainval sets | N/A | N/A |
| xception65_coco_voc_trainaug | Xception_65 | MS-COCO, VOC 2012 train_aug set | [6,12,18] for OS=16; [12,24,36] for OS=8 | OS = 4 |
| xception65_coco_voc_trainval | Xception_65 | MS-COCO, VOC 2012 train_aug + trainval sets | [6,12,18] for OS=16; [12,24,36] for OS=8 | OS = 4 |

Here, OS denotes the output stride.
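OS is the ratio between the input resolution and the resolution of the encoder output. The sketch below assumes the (size - 1) // OS + 1 convention, under which DeepLab's default 513x513 input gives 33x33 features at OS=16; it also reproduces how the ASPP rates in the table double when OS drops from 16 to 8. The helper names are ours, not part of the DeepLab API.

```python
def feature_map_size(input_size, output_stride):
    """Spatial size of a feature map at the given output stride: (size - 1) // OS + 1."""
    return (input_size - 1) // output_stride + 1

def scaled_atrous_rates(output_stride, base_rates=(6, 12, 18), base_output_stride=16):
    """ASPP rates for another OS: halving the output stride doubles the rates."""
    return [r * base_output_stride // output_stride for r in base_rates]

for stride in (16, 8, 4):  # 4 is the decoder's OS in the tables
    print(stride, feature_map_size(513, stride), scaled_atrous_rates(stride))
```

At OS=16, 8 and 4 a 513-pixel side maps to 33, 65 and 129 positions, which is why the OS=8 rows (and the decoder) cost far more computation.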

2.1.2. Model Downloads

| Checkpoint name | Eval OS | Eval scales | Left-right Flip | Multiply-Adds | Runtime (sec) | PASCAL mIOU | File Size |
|---|---|---|---|---|---|---|---|
| mobilenetv2_dm05_coco_voc_trainaug | 16 | [1.0] | No | 0.88B | - | 70.19% (val) | 7.6MB |
| mobilenetv2_dm05_coco_voc_trainval | 8 | [1.0] | No | 2.84B | - | 71.83% (test) | 7.6MB |
| mobilenetv2_coco_voc_trainaug | 16 / 8 | [1.0] / [0.5:0.25:1.75] | No / Yes | 2.75B / 152.59B | 0.1 / 26.9 | 75.32% (val) / 77.33% (val) | 23MB |
| mobilenetv2_coco_voc_trainval | 8 | [0.5:0.25:1.75] | Yes | 152.59B | 26.9 | 80.25% (test) | 23MB |
| xception65_coco_voc_trainaug | 16 / 8 | [1.0] / [0.5:0.25:1.75] | No / Yes | 54.17B / 3055.35B | 0.7 / 223.2 | 82.20% (val) / 83.58% (val) | 439MB |
| xception65_coco_voc_trainval | 8 | [0.5:0.25:1.75] | Yes | 3055.35B | 223.2 | 87.80% (test) | 439MB |
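The Eval scales column uses a start:step:end shorthand, which we read as end-inclusive, so [0.5:0.25:1.75] stands for six scales. In multi-scale evaluation each scale (and, when Left-right Flip is Yes, its flipped copy) is a separate forward pass, and DeepLab averages the resulting class probabilities; this is why those rows cost far more Multiply-Adds and runtime. The parser below is a hypothetical helper for reading the table, not part of the DeepLab API.

```python
def parse_eval_scales(spec):
    """Expand a table entry such as '[0.5:0.25:1.75]' or '[1.0]' into a list of floats.

    Assumes the 'start:step:end' notation includes `end`.
    """
    body = spec.strip().strip('[]')
    if ':' not in body:
        return [float(v) for v in body.split(',')]
    start, step, end = (float(v) for v in body.split(':'))
    count = int(round((end - start) / step)) + 1
    return [round(start + i * step, 6) for i in range(count)]

scales = parse_eval_scales('[0.5:0.25:1.75]')
print(scales)           # the six eval scales
print(len(scales) * 2)  # forward passes per image when left-right flipping is also used
```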

2.1.3. Tarball Contents

Each downloaded .tar archive contains the following files:

[1] - frozen_inference_graph.pb.

All frozen inference graphs use output stride of 8 and a single eval scale of 1.0. No left-right flips are used, and MobileNet-v2 based models do not include the decoder module.

[2] - Checkpoint files: model.ckpt.data-00000-of-00001, model.ckpt.index.
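Before loading a model it can help to confirm that a downloaded archive really contains these files; a truncated or wrong download fails here instead of deep inside TensorFlow. A minimal sketch using only the standard `tarfile` module; the function names are ours, and the `frozen_inference_graph` prefix matches the name the demo's DeepLabModel class searches for.

```python
import os
import tarfile

FROZEN_GRAPH_NAME = 'frozen_inference_graph'  # same prefix DeepLabModel looks for

def list_tarball(tarball_path):
    """Return the sorted base names of the regular files inside a model tarball."""
    with tarfile.open(tarball_path) as tar:  # compression (e.g. .tar.gz) is detected automatically
        return sorted(os.path.basename(m.name) for m in tar.getmembers() if m.isfile())

def has_frozen_graph(tarball_path):
    return any(FROZEN_GRAPH_NAME in name for name in list_tarball(tarball_path))
```

For example, `list_tarball('/path/to/models_zoo/deeplab/deeplabv3_pascal_trainval_2018_01_04.tar.gz')` should list the .pb file and the two checkpoint files.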

2.2. DeepLab models trained on Cityscapes

2.2.1. Model Description

| Checkpoint name | Network backbone | Pretrained dataset | ASPP | Decoder |
|---|---|---|---|---|
| mobilenetv2_coco_cityscapes_trainfine | MobileNet-v2 | MS-COCO, Cityscapes train_fine set | N/A | N/A |
| xception65_cityscapes_trainfine | Xception_65 | ImageNet, Cityscapes train_fine set | [6, 12, 18] for OS=16; [12, 24, 36] for OS=8 | OS = 4 |
| xception71_dpc_cityscapes_trainfine | Xception_71 | ImageNet, MS-COCO, Cityscapes train_fine set | Dense Prediction Cell | OS = 4 |
| xception71_dpc_cityscapes_trainval | Xception_71 | ImageNet, MS-COCO, Cityscapes trainval_fine and coarse set | Dense Prediction Cell | OS = 4 |

2.2.2. Model Downloads

| Checkpoint name | Eval OS | Eval scales | Left-right Flip | Multiply-Adds | Runtime (sec) | Cityscapes mIOU | File Size |
|---|---|---|---|---|---|---|---|
| mobilenetv2_coco_cityscapes_trainfine | 16 / 8 | [1.0] / [0.75:0.25:1.25] | No / Yes | 21.27B / 433.24B | 0.8 / 51.12 | 70.71% (val) / 73.57% (val) | 23MB |
| xception65_cityscapes_trainfine | 16 / 8 | [1.0] / [0.75:0.25:1.25] | No / Yes | 418.64B / 8677.92B | 5.0 / 422.8 | 78.79% (val) / 80.42% (val) | 439MB |
| xception71_dpc_cityscapes_trainfine | 16 | [1.0] | No | 502.07B | - | 80.31% (val) | 445MB |
| xception71_dpc_cityscapes_trainval | 8 | [0.75:0.25:2] | Yes | - | - | 82.66% (test) | 446MB |

2.3. DeepLab models trained on ADE20K

2.3.1. Model Description

| Checkpoint name | Network backbone | Pretrained dataset | ASPP | Decoder | Input size |
|---|---|---|---|---|---|
| mobilenetv2_ade20k_train | MobileNet-v2 | ImageNet, ADE20K training set | N/A | OS = 4 | 257x257 |
| xception65_ade20k_train | Xception_65 | ImageNet, ADE20K training set | [6, 12, 18] for OS=16; [12, 24, 36] for OS=8 | OS = 4 | 513x513 |

2.3.2. Model Downloads

| Checkpoint name | Eval OS | Eval scales | Left-right Flip | mIOU | Pixel-wise Accuracy | File Size |
|---|---|---|---|---|---|---|
| mobilenetv2_ade20k_train | 16 | [1.0] | No | 32.04% (val) | 75.41% (val) | 24.8MB |
| xception65_ade20k_train | 8 | [0.5:0.25:1.75] | Yes | 45.65% (val) | 82.52% (val) | 439MB |
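Both the mIOU and the pixel-wise accuracy columns can be derived from a confusion matrix accumulated over the evaluation set. A minimal NumPy sketch (not the DeepLab evaluation code): `ignore_label=255` follows the PASCAL VOC convention for unlabeled pixels and is an assumption for other datasets, and classes absent from both prediction and ground truth are skipped when averaging.

```python
import numpy as np

def confusion_matrix(pred, label, num_classes, ignore_label=255):
    """Rows = ground-truth class, columns = predicted class; ignored pixels are dropped."""
    pred = np.asarray(pred).ravel()
    label = np.asarray(label).ravel()
    keep = label != ignore_label
    idx = num_classes * label[keep].astype(np.int64) + pred[keep].astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def pixel_accuracy(cm):
    """Fraction of (non-ignored) pixels whose predicted class is correct."""
    return np.diag(cm).sum() / float(cm.sum())

def mean_iou(cm):
    """Mean over classes of intersection / union."""
    inter = np.diag(cm).astype(float)
    union = cm.sum(axis=0) + cm.sum(axis=1) - inter
    present = union > 0  # skip classes that never occur
    return (inter[present] / union[present]).mean()

label = np.array([[0, 0, 1], [1, 1, 255]])  # 255 = ignored pixel
pred = np.array([[0, 1, 1], [1, 0, 0]])
cm = confusion_matrix(pred, label, num_classes=2)
print(pixel_accuracy(cm), mean_iou(cm))
```

Over a whole dataset one would sum the per-image matrices first and compute the metrics once at the end.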

2.4. Checkpoints pretrained on ImageNet

Each archive contains the checkpoint files model.ckpt.data-00000-of-00001 and model.ckpt.index.

| Model name | File Size |
|---|---|
| xception_41_imagenet | 288MB |
| xception_65_imagenet | 447MB |
| xception_65_imagenet_coco | 292MB |
| xception_71_imagenet | 474MB |
| resnet_v1_50_beta_imagenet | 274MB |
| resnet_v1_101_beta_imagenet | 477MB |

3. Demo.py

# -*- coding: utf-8 -*-

# Deeplab Demo 

import os
import tarfile

from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tempfile
from six.moves import urllib

import tensorflow as tf


class DeepLabModel(object):
    """
    加载 DeepLab 模型;
    推断 Inference.
    """
    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    INPUT_SIZE = 513
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'

    def __init__(self, tarball_path):
        """
        Creates and loads pretrained deeplab model.
        """
        self.graph = tf.Graph()

        graph_def = None
        # Extract the frozen graph from the tar archive.
        with tarfile.open(tarball_path) as tar_file:
            for tar_info in tar_file.getmembers():
                if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):
                    file_handle = tar_file.extractfile(tar_info)
                    # tf.compat.v1 keeps the demo working on TF 1.13+ as well as TF 2.x.
                    graph_def = tf.compat.v1.GraphDef.FromString(file_handle.read())
                    break

        if graph_def is None:
            raise RuntimeError('Cannot find inference graph in tar archive.')

        with self.graph.as_default():
            tf.compat.v1.import_graph_def(graph_def, name='')

        self.sess = tf.compat.v1.Session(graph=self.graph)


    def run(self, image):
        """
        Runs inference on a single image.

        Args:
        image: A PIL.Image object, raw input image.

        Returns:
        resized_image: RGB image resized from original input image.
        seg_map: Segmentation map of `resized_image`.
        """
        width, height = image.size
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        # Image.ANTIALIAS was an alias of Image.LANCZOS and has been removed in Pillow 10.
        resized_image = image.convert('RGB').resize(target_size, Image.LANCZOS)
        batch_seg_map = self.sess.run(self.OUTPUT_TENSOR_NAME,
                                      feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
        seg_map = batch_seg_map[0]
        return resized_image, seg_map


def create_pascal_label_colormap():
    """
    Creates a label colormap used in PASCAL VOC segmentation benchmark.

    Returns:
        A Colormap for visualizing segmentation results.
    """
    colormap = np.zeros((256, 3), dtype=int)
    ind = np.arange(256, dtype=int)

    for shift in reversed(range(8)):
        for channel in range(3):
            colormap[:, channel] |= ((ind >> channel) & 1) << shift
        ind >>= 3

    return colormap


def label_to_color_image(label):
    """
    Adds color defined by the dataset colormap to the label.

    Args:
        label: A 2D array with integer type, storing the segmentation label.

    Returns:
        result: A 2D array with floating type. The element of the array
        is the color indexed by the corresponding element in the input label
        to the PASCAL color map.

    Raises:
        ValueError: If label is not of rank 2 or its value is larger than color
        map maximum entry.
    """
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')

    colormap = create_pascal_label_colormap()

    if np.max(label) >= len(colormap):
        raise ValueError('label value too large.')

    return colormap[label]


def vis_segmentation(image, seg_map):
    """Visualizes input image, segmentation map and overlay view."""
    plt.figure(figsize=(15, 5))
    grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

    plt.subplot(grid_spec[0])
    plt.imshow(image)
    plt.axis('off')
    plt.title('input image')

    plt.subplot(grid_spec[1])
    seg_image = label_to_color_image(seg_map).astype(np.uint8)
    plt.imshow(seg_image)
    plt.axis('off')
    plt.title('segmentation map')

    plt.subplot(grid_spec[2])
    plt.imshow(image)
    plt.imshow(seg_image, alpha=0.7)
    plt.axis('off')
    plt.title('segmentation overlay')

    unique_labels = np.unique(seg_map)
    ax = plt.subplot(grid_spec[3])
    plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0.0)
    plt.grid(False)
    plt.show()


## PASCAL VOC class names and their colors
LABEL_NAMES = np.asarray(
    ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
     'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
     'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv' ])

FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)


## Pretrained models provided by TensorFlow
MODEL_NAME = 'xception_coco_voctrainval'
# ['mobilenetv2_coco_voctrainaug', 'mobilenetv2_coco_voctrainval', 'xception_coco_voctrainaug', 'xception_coco_voctrainval']

_DOWNLOAD_URL_PREFIX = 'http://download.tensorflow.org/models/'
_MODEL_URLS = {'mobilenetv2_coco_voctrainaug': 'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz',
               'mobilenetv2_coco_voctrainval': 'deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz',
               'xception_coco_voctrainaug': 'deeplabv3_pascal_train_aug_2018_01_04.tar.gz',
               'xception_coco_voctrainval': 'deeplabv3_pascal_trainval_2018_01_04.tar.gz', }
# Local directory for the model tarball; change it to suit your machine.
model_dir = '/path/to/models_zoo/deeplab'
download_path = os.path.join(model_dir, _MODEL_URLS[MODEL_NAME])

# Download the tarball only if it is not already present locally.
if not os.path.exists(download_path):
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    print('downloading model, this might take a while...')
    urllib.request.urlretrieve(_DOWNLOAD_URL_PREFIX + _MODEL_URLS[MODEL_NAME], download_path)
    print('download completed! loading DeepLab model...')

MODEL = DeepLabModel(download_path)
print('model loaded successfully!')


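## Run segmentation on every image in a directory and visualize the results
def run_visualization(imagefile):
    """
    Runs DeepLab semantic segmentation on one image file and visualizes the result.
    """
    original_im = Image.open(imagefile)
    print('running deeplab on image %s...' % imagefile)
    resized_im, seg_map = MODEL.run(original_im)

    vis_segmentation(resized_im, seg_map)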

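images_dir = '/path/to/images'
# Only pick up image files; os.listdir also returns hidden files and other non-images.
images = sorted(f for f in os.listdir(images_dir)
                if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')))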
for imgfile in images:
    run_visualization(os.path.join(images_dir, imgfile))

print('Done.')


Last modification: May 14th, 2019 at 09:30 am

22 comments

  1. Yuan Linfeng

    Hi, when the longer side of the original image is larger than 513, I get Invalid argument: padded_shape[1]=xxx is not divisible by block_shape[1]=2 (xxx is an integer that changes with the image size). What should I do?

    1. AIHGF
      @Yuan Linfeng

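      It seems the input is expected to be 513x513 by default. Add a preprocessing step that resizes the image to 513, then post-process the result back to the original image size.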
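The post-processing step suggested above can be sketched as follows. MODEL.run returns seg_map at the resized resolution, and scaling it back must use nearest-neighbour interpolation so that no in-between (non-existent) class ids are produced. A minimal PIL sketch; the helper name is hypothetical, and it does not by itself address the padded_shape error.

```python
import numpy as np
from PIL import Image

def resize_seg_map(seg_map, size):
    """Resize an HxW label map to size=(width, height), PIL's order, keeping label ids intact."""
    label_img = Image.fromarray(np.asarray(seg_map).astype(np.uint8))  # ids must fit in 0..255
    return np.asarray(label_img.resize(size, Image.NEAREST))
```

Usage with the demo: `full_seg_map = resize_seg_map(seg_map, original_im.size)`, where `original_im` is the PIL image passed to MODEL.run.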

  2. marc

    In demo.py, an error is raised if resize_ratio = 1.0, i.e. when the image is kept at its original size. How can this be solved?
    Error message:
    Invalid argument: padded_shape[1]=245 is not divisible by block_shape[1]=2

    1. AIHGF
      @marc

      What is the size of the original image?

  3. wyz

    Hi, how can I extract the region of the original image that corresponds to the mask?

    1. AIHGF
      @wyz

      A bitwise AND with the mask is enough.

      1. wyz
        @AIHGF

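        After ANDing the original image with the mask, the irrelevant regions turn black. Is there a way to make them white, or the dominant color of the mask region instead? Iterating over the pixels would be very slow for large images.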

        1. AIHGF
          @wyz

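          Many open-source segmentation projects include visualization demos that draw each segmented object in a different color, which is similar to what you need.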

      2. wyz
        @AIHGF

        Thanks a lot. Concretely, this is how I used it:
        crop_result = cv2.bitwise_and(oringnal_im, oringnal_im, mask=mask)

        1. 卓一凡
          @wyz

          Why does crop_result = cv2.bitwise_and(oringnal_im, oringnal_im, mask=mask) pass oringnal_im twice? Many examples online pass two different arguments. Also, the DeepLab output is a single-channel numpy.ndarray; how did you cut the object out of the image?
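On the questions in this thread: cv2.bitwise_and(src1, src2, mask=mask) computes src1 AND src2 only where the mask is non-zero, and pixels outside the mask come out as 0 (black). Since x AND x = x, passing the same image twice simply copies it through the mask. The mask must be a single-channel 8-bit array the same size as the image, e.g. `(seg_map == 15).astype(np.uint8)` for class 15 (person in LABEL_NAMES). The sketch below does the same with plain NumPy boolean indexing, with no per-pixel Python loop, and lets you choose the background color. The helper name is hypothetical; the simplest way to get matching sizes is to pair seg_map with `np.asarray(resized_im)`, which MODEL.run returns at the same resolution.

```python
import numpy as np

def extract_class(image, seg_map, class_id, fill=(255, 255, 255)):
    """Keep the pixels of `class_id` and paint everything else with `fill`.

    image: HxWx3 uint8 array; seg_map: HxW label map of the same height and width.
    """
    image = np.asarray(image)
    mask = np.asarray(seg_map) == class_id          # HxW boolean mask
    out = np.empty_like(image)
    out[...] = np.asarray(fill, dtype=image.dtype)  # background color everywhere
    out[mask] = image[mask]                         # copy the selected pixels through
    return out
```

For a background in the object's own color, a simple stand-in for its "main" color is the mean: `extract_class(img, seg_map, 15, fill=img[seg_map == 15].mean(axis=0))`.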

  4. zy

    Hi, how can I save only the middle segmentation image?

    1. AIHGF
      @zy

      By the middle segmentation image, do you mean the mask for the image, or the intermediate-layer features of the network?

      1. zy
        @AIHGF

        I mean the mask corresponding to the image.

        1. AIHGF
          @zy

          seg_map is the output segmentation mask: resized_im, seg_map = MODEL.run(orignal_im)

          1. zy
            @AIHGF

            OK, thanks a lot.
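Since seg_map is just an HxW array of class ids, saving only the mask means writing that array to a lossless format. A minimal PIL sketch with hypothetical helper names: PNG keeps the ids exact (JPEG compression would alter them). The raw-id image looks almost black because PASCAL ids only go up to 20, so the second helper writes a colored version using a colormap such as create_pascal_label_colormap() from demo.py.

```python
import numpy as np
from PIL import Image

def save_seg_map(seg_map, path):
    """Save raw class ids as a single-channel 8-bit PNG."""
    Image.fromarray(np.asarray(seg_map).astype(np.uint8)).save(path)

def save_color_seg_map(seg_map, colormap, path):
    """Save an RGB visualization: each class id is replaced by its colormap row."""
    rgb = np.asarray(colormap)[np.asarray(seg_map)].astype(np.uint8)  # HxWx3
    Image.fromarray(rgb).save(path)
```

Usage with the demo: `save_seg_map(seg_map, 'mask.png')` and `save_color_seg_map(seg_map, create_pascal_label_colormap(), 'mask_color.png')`.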

  5. kam

    Hi! When I run this demo in Jupyter, I get the error FileNotFoundError: [Errno 2] No such file or directory: '/path/to/models_zoo/deeplab\deeplabv3_pascal_trainval_2018_01_04.tar.gz'

    1. AIHGF
      @kam

      The model probably was not downloaded.

    2. kam
      @kam

      If I have misunderstood something, I would appreciate your advice. Thanks!

    3. kam
      @kam

      Do I need to download the TF Model Zoo and keep it locally, or can it be used directly through the URL?

      1. AIHGF
        @kam

        It can be downloaded automatically from the given URL.

  6. dennis

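    Hi, how should I implement the demo for Cityscapes? After replacing the path and the corresponding class names, I keep getting a binary error. I'd appreciate your guidance, thanks!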

    1. AIHGF
      @dennis

      A binary error? What exactly is the error message?
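For a Cityscapes checkpoint, replacing the class names alone is not enough: the PASCAL-specific colormap in demo.py also has to be swapped. A sketch, assuming the checkpoint outputs the 19 standard Cityscapes training ids (0-18) and using the usual Cityscapes palette; the names below are ours. In demo.py, label_to_color_image would then index this colormap instead of create_pascal_label_colormap(), and FULL_LABEL_MAP/FULL_COLOR_MAP would be rebuilt from CITYSCAPES_LABEL_NAMES. This does not diagnose the "binary error" above, whose exact message was not posted.

```python
import numpy as np

# Cityscapes 19 evaluation classes, indexed by training id 0-18.
CITYSCAPES_LABEL_NAMES = np.asarray([
    'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
    'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
    'truck', 'bus', 'train', 'motorcycle', 'bicycle'])

def create_cityscapes_label_colormap():
    """256x3 colormap; rows 0-18 hold the standard Cityscapes colors, the rest stay black."""
    colormap = np.zeros((256, 3), dtype=np.uint8)
    colormap[:19] = [
        [128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
        [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
        [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
        [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
        [0, 0, 230], [119, 11, 32]]
    return colormap
```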
