AIHGF

Github项目 - Mask R-CNN 的 Keras 实现之Demo
Github项目 - Mask R-CNN 的 Keras 实现 的 Demo 测试图片.这里的测试环境与 Git...
扫描右侧二维码阅读全文
31
2018/05

Github项目 - Mask R-CNN 的 Keras 实现之Demo

Github项目 - Mask R-CNN 的 Keras 实现 的 Demo 测试图片.

这里的测试环境与 Github 中不太一致, 但测试没出现什么问题.
修改了原来的 demo.ipynb.

环境:
- Ubuntu 14.04
- Python2.7.6
- Tensorflow=1.4.0
- Keras==2.1.3

Demo.py

# --*-- coding: utf-8 --*--

############################################################################
# Mask R-CNN Demo
# 利用 COCO 数据集得到的预训练模型测试
############################################################################

import os
import sys
import random
import tensorflow as tf
import matplotlib.pyplot as plt
import cv2


# 项目根目录
ROOT_DIR = os.path.abspath("../")

# Import Mask RCNN 模块
sys.path.append(ROOT_DIR)
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
# Import COCO config
sys.path.append(os.path.join(ROOT_DIR, "samples/coco/")) 
import coco


# 模型和 logs 日志保存路径
MODEL_DIR = os.path.join(ROOT_DIR, "logs")

# 训练模型的权重文件
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
# 如果没有找到 COCO 权重文件, 则自动下载
if not os.path.exists(COCO_MODEL_PATH):
    utils.download_trained_weights(COCO_MODEL_PATH)

# 测试图片所在路径
# IMAGE_DIR = os.path.join(ROOT_DIR, "images")
IMAGE_DIR = '/home/sh/Pictures/multi'

class InferenceConfig(coco.CocoConfig):
    # batchsize=1, 每次只测试单张图片
    # Batch size = GPU_COUNT * IMAGES_PER_GPU
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

config = InferenceConfig()
config.display()


DEVICE = "/gpu:0"  # /cpu:0 or /gpu:0

# 创建 inference 模式的模型对象
with tf.device(DEVICE):
    model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR,
                              config=config)

# 加载权重
print("Loading weights ", COCO_MODEL_PATH)
model.load_weights(COCO_MODEL_PATH, by_name=True)


# COCO Class names
class_names = ['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
               'bus', 'train', 'truck', 'boat', 'traffic light',
               'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
               'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
               'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
               'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
               'kite', 'baseball bat', 'baseball glove', 'skateboard',
               'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
               'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
               'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
               'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
               'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
               'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
               'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
               'teddy bear', 'hair drier', 'toothbrush']



# 随机加载一张图片
file_names = next(os.walk(IMAGE_DIR))[2]
image = cv2.imread(os.path.join(IMAGE_DIR, random.choice(file_names)))
image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_CUBIC) ##
image = image[:,:,::-1]

# Run detection
results = model.detect([image], verbose=1)

# 可视化结果
r = results[0]
visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'],
                            class_names, r['scores'], title="Predictions")


print('Done')



Last modification:October 9th, 2018 at 09:31 am

Leave a Comment