Paper: Stacked Hourglass Networks for Human Pose Estimation | Demo Code

Paper notes - Stacked Hourglass Networks for Human Pose Estimation

The main files and folders in pose-hg-demo:

This walkthrough uses Docker, Python, and pose-hg-demo.

<h2>1. Pull the Torch7 image</h2>

sudo nvidia-docker pull registry.cn-hangzhou.aliyuncs.com/docker_learning_aliyun/torch:v1

<h2>2. Run the demo on the MPII Human Pose dataset</h2>

Download the MPII Human Pose dataset and put the images in the images folder.

sudo nvidia-docker run -it --rm -v /path/to/pose-hg-demo-master:/media registry.cn-hangzhou.aliyuncs.com/docker_learning_aliyun/torch:v1

# Inside the Torch container
root@8f1548fc3b34:~/torch# 
cd /media  # i.e., pose-hg-demo-master on the host
th main.lua predict-test  # runs pose estimation and saves the results to 'preds/test.h5'

Visualize the pose estimation results with the following Python script:

#!/usr/bin/env python
import h5py
import matplotlib.pyplot as plt

# One image name per line; strip the trailing newline
test_images = [line.strip() for line in open('./annot/test_images.txt', 'r')]
images_path = './images/'

# Read the dataset into memory before closing the file
f = h5py.File('./preds/test.h5', 'r')
preds = f['preds'][:]
f.close()

assert len(test_images) == len(preds)
for i in range(len(test_images)):
    filename = images_path + test_images[i]
    im = plt.imread(filename)  # scipy.misc.imread was removed in SciPy 1.2
    pose = preds[i]  # 16 x 2 array of (x, y) joint coordinates

    plt.axis('off')
    plt.imshow(im)

    for j in range(16):
        # Skip joints whose coordinates are not positive
        if pose[j][0] > 0 and pose[j][1] > 0:
            plt.scatter(pose[j][0], pose[j][1], marker='o', color='r', s=15)
    plt.show()

print('Done.')

<h2>3. Pose estimation on your own images</h2>

Because the MPII Human Pose dataset annotates the scale and center of each person in the image, the input can be cropped directly in the way pose-hg-demo does:

inputImg = crop(img, center, scale, rot, res)
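To make the roles of center and scale concrete, here is a minimal sketch of the coordinate mapping behind crop. It assumes the MPII convention that 200 × scale pixels is roughly the person's height, so crop cuts a square of that side around center and resizes it to res × res. It ignores rotation (rot = 0) and any integer rounding done in img.lua, and to_crop / from_crop are hypothetical helper names, not functions from pose-hg-demo.

```python
def to_crop(pt, center, scale, res):
    """Map an (x, y) point in the original image to the res x res crop.

    Assumes the square crop has side 200 * scale pixels, centered on center.
    """
    h = 200.0 * scale  # side of the square crop in original-image pixels
    return (res * ((pt[0] - center[0]) / h + 0.5),
            res * ((pt[1] - center[1]) / h + 0.5))


def from_crop(pt, center, scale, res):
    """Inverse of to_crop: map a point in the crop back to the original image."""
    h = 200.0 * scale
    return (center[0] + h * (pt[0] / res - 0.5),
            center[1] + h * (pt[1] / res - 0.5))
```

For example, with center = (300, 200), scale = 1.5 and res = 256, the center maps to (128, 128) and the point (450, 350) maps to the bottom-right corner (256, 256) of the crop.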

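However, for one or more arbitrary images the person's scale and center are unknown, so they need separate handling. The approach here is to first detect the person bounding box (not implemented in this post), then preprocess the image in Python, resizing it so the longer side is 256 pixels and zero-padding it to 256×256, to use as the network input.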

  • Python script for preprocessing the images
#!/usr/bin/env python
# Resize each image so its longer side is 256 px, then zero-pad it to 256x256
import os
import numpy as np
import cv2

if __name__ == '__main__':
    orig_img_path = '/orig/images/path/'
    new_img_path = '/new/images/path_256/'

    boxsize = 256  # network input resolution
    if not os.path.isdir(new_img_path):
        os.makedirs(new_img_path)

    files = os.listdir(orig_img_path)
    for file in files:
        if file.lower().endswith('.jpg'):
            orig_img_name = os.path.join(orig_img_path, file)
            if os.path.isfile(orig_img_name):
                img = cv2.imread(orig_img_name)
                height, width = float(img.shape[0]), float(img.shape[1])
                # Largest scale at which the whole image still fits in boxsize x boxsize
                scale = min(boxsize / height, boxsize / width)

                # Explicit target size (width, height), clamped so rounding never exceeds boxsize
                new_w = min(boxsize, int(round(width * scale)))
                new_h = min(boxsize, int(round(height * scale)))
                img_resize = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
                h, w = img_resize.shape[0], img_resize.shape[1]

                # Split the leftover space evenly between up/down and left/right
                pad_up = (boxsize - h) // 2
                pad_down = boxsize - h - pad_up
                pad_left = (boxsize - w) // 2
                pad_right = boxsize - w - pad_left

                pad_img = np.pad(img_resize, ((pad_up, pad_down), (pad_left, pad_right), (0, 0)),
                                 'constant', constant_values=0)

                new_img_name = os.path.join(new_img_path, file)
                cv2.imwrite(new_img_name, pad_img)

    print('Done.')
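Because this preprocessing rescales and pads the image, predicted joints refer to the padded 256×256 image rather than the original photo. The sketch below mirrors the script's resize-and-pad arithmetic so a point can be mapped back (after first multiplying heatmap coordinates by 4, as in the visualization script of section 3.3). It assumes the resized size is rounded to the nearest pixel; letterbox_params, to_padded and to_original are hypothetical helper names.

```python
def letterbox_params(height, width, boxsize=256):
    """Return (scale, pad_left, pad_up) for resizing an image to fit boxsize and centering it."""
    scale = min(boxsize / float(height), boxsize / float(width))
    new_w = min(boxsize, int(round(width * scale)))
    new_h = min(boxsize, int(round(height * scale)))
    pad_left = (boxsize - new_w) // 2
    pad_up = (boxsize - new_h) // 2
    return scale, pad_left, pad_up


def to_padded(x, y, height, width, boxsize=256):
    """Map a point in the original image into the padded boxsize x boxsize image."""
    scale, pad_left, pad_up = letterbox_params(height, width, boxsize)
    return x * scale + pad_left, y * scale + pad_up


def to_original(x, y, height, width, boxsize=256):
    """Map a point in the padded image back to original-image pixels."""
    scale, pad_left, pad_up = letterbox_params(height, width, boxsize)
    return (x - pad_left) / scale, (y - pad_up) / scale
```

For a 500 × 300 (height × width) photo, the scale is 0.512 and the resized image is 154 pixels wide, so 51 pixels of padding are added on the left and nothing on top; the padded-image point (128, 128) maps back to about (150.39, 250.0).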

<h3>3.1 Pose estimation on a single image - demo.lua</h3>

require 'paths'
paths.dofile('util.lua')
paths.dofile('img.lua')

-- Load pre-trained model
m = torch.load('umich-stacked-hourglass.t7')   

-- Set up input image
local im = image.load('image/' .. arg[1])

-- Get network output
local out = m:forward(im:view(1,3,256,256):cuda())
cutorch.synchronize()
local hms = out[#out]:float()  -- one heatmap set per stacked hourglass; keep the last
hms[hms:lt(0)] = 0
--print(hms:size())

-- Get predictions (hm and img refer to the coordinate space)
if hms:size():size() == 3 then 
    hms = hms:view(1, hms:size(1), hms:size(2), hms:size(3)) 
end

---- Get locations of maximum activations
local max, idx = torch.max(hms:view(hms:size(1), hms:size(2), hms:size(3) * hms:size(4)), 3)
local preds = torch.repeatTensor(idx, 1, 1, 2):float()
preds[{{}, {}, 1}]:apply(function(x) return (x - 1) % hms:size(4) + 1 end)
preds[{{}, {}, 2}]:add(-1):div(hms:size(3)):floor():add(.5)

collectgarbage()

-- Save predictions
local predFile = hdf5.open('preds/pred.h5', 'w')
predFile:write('preds', preds)
predFile:write('img', im)
predFile:close()
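The "Get locations of maximum activations" step takes, for each joint, the arg-max of its heatmap and converts the flattened index back into a column (x) and row (y). The NumPy sketch below shows the same idea with 0-based coordinates (the Lua code adds 1 to x and 0.5 to y); heatmap_argmax is a hypothetical name, and returning the peak value as a rough confidence is an addition not present in demo.lua.

```python
import numpy as np


def heatmap_argmax(hms):
    """Return per-joint peak locations and peak values.

    hms: array of shape (N, J, H, W), e.g. (1, 16, 64, 64) for this model.
    Returns coords of shape (N, J, 2) with 0-based (x, y) in heatmap pixels,
    and vals of shape (N, J) with the peak activation of each joint.
    """
    hms = np.maximum(hms, 0)          # clamp negatives, like hms[hms:lt(0)] = 0
    n, j, h, w = hms.shape
    flat = hms.reshape(n, j, h * w)
    idx = flat.argmax(axis=2)         # flattened index of each joint's peak
    vals = flat.max(axis=2)
    xs = idx % w                      # column -> x
    ys = idx // w                     # row -> y
    coords = np.stack([xs, ys], axis=2).astype(np.float32)
    return coords, vals
```

Multiplying these heatmap coordinates by 4 gives approximate positions in the 256×256 input, which is the scaling used in the visualization script.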

<h3>3.2 Batch pose estimation on multiple images - demo_multi.lua</h3>

This requires adding a function loadImageNames to util.lua:

function loadImageNames(fileName)
    local a = {}
    -- Load in image file names
    a.images = {}
    local namesFile = io.open(fileName)
    local idxs = 1
    for line in namesFile:lines() do
        print(line)
        a.images[idxs] = line
        idxs = idxs + 1
    end
    namesFile:close()
    a.nsamples = idxs-1

    return a
end
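loadImageNames expects a text file with one image file name per line, and demo_multi.lua loads each name from the image/ folder. A small Python helper can build that list from a folder; write_image_list and the list file name are hypothetical, not part of pose-hg-demo.

```python
import os


def write_image_list(image_dir, list_path, ext='.jpg'):
    """Write the names of all image files in image_dir to list_path, one per line."""
    names = sorted(
        f for f in os.listdir(image_dir)
        if f.lower().endswith(ext) and os.path.isfile(os.path.join(image_dir, f))
    )
    with open(list_path, 'w') as fh:
        for name in names:
            fh.write(name + '\n')
    return names
```

For example, write_image_list('image', 'image_list.txt') followed by th demo_multi.lua image_list.txt, since the script reads the list path from arg[1].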

demo_multi.lua:

require 'paths'
paths.dofile('util.lua')
paths.dofile('img.lua')

--------------------------------------------------------------------------------
-- Initialization
--------------------------------------------------------------------------------
-- arg[1]: text file listing one image file name per line
a = loadImageNames(arg[1])

m = torch.load('umich-stacked-hourglass.t7')   -- Load pre-trained model

idxs = torch.range(1, a.nsamples)
nsamples = idxs:nElement()

-- Displays a convenient progress bar
xlua.progress(0,nsamples)
preds = torch.Tensor(nsamples,16,2)
imgs = torch.Tensor(nsamples,3,256,256)

--------------------------------------------------------------------------------
-- Main loop
--------------------------------------------------------------------------------
for i = 1,nsamples do
    -- Set up input image
    --print(a.images[i])
    local im = image.load('image/' .. a.images[i])

    -- Get network output
    local out = m:forward(im:view(1,3,256,256):cuda())
    cutorch.synchronize()
    local hms = out[#out]:float()  -- one heatmap set per stacked hourglass; keep the last
    hms[hms:lt(0)] = 0

    -- Get predictions (hm and img refer to the coordinate space)
    if hms:size():size() == 3 then
        hms = hms:view(1, hms:size(1), hms:size(2), hms:size(3))
    end

    ---- Get locations of maximum activations
    local max, idx = torch.max(hms:view(hms:size(1), hms:size(2), hms:size(3) * hms:size(4)), 3)
    local preds_img = torch.repeatTensor(idx, 1, 1, 2):float()
    preds_img[{{}, {}, 1}]:apply(function(x) return (x - 1) % hms:size(4) + 1 end)
    preds_img[{{}, {}, 2}]:add(-1):div(hms:size(3)):floor():add(.5)

    preds[i]:copy(preds_img)
    imgs[i]:copy(im)

    xlua.progress(i,nsamples)

    collectgarbage()
end

-- Save predictions
local predFile = hdf5.open('preds/preds.h5', 'w')
predFile:write('preds', preds)
predFile:write('imgs', imgs)
predFile:close()

<h3>3.3 Visualizing the results with Python</h3>

#!/usr/bin/env python
import h5py
import matplotlib.pyplot as plt

# Read both datasets into memory before closing the file
f = h5py.File('./preds/preds.h5', 'r')
imgs = f['imgs'][:]
preds = f['preds'][:]
f.close()

assert len(imgs) == len(preds)
for i in range(len(imgs)):
    pose = preds[i] * 4  # input is 256×256 but heatmaps are 64×64, so scale by 4
    img = imgs[i].transpose(1, 2, 0)  # CHW -> HWC for imshow
    plt.axis('off')
    plt.imshow(img)
    for j in range(16):
        # Skip joints whose coordinates are not positive
        if pose[j][0] > 0 and pose[j][1] > 0:
            plt.scatter(pose[j][0], pose[j][1], marker='o', color='r', s=15)
    plt.show()

print('Done.')
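The scripts above draw each joint as a dot. To draw a stick figure instead, joints can be connected into limbs. This sketch assumes the standard MPII joint order for the 16 outputs (0 r ankle, 1 r knee, 2 r hip, 3 l hip, 4 l knee, 5 l ankle, 6 pelvis, 7 thorax, 8 upper neck, 9 head top, 10 r wrist, 11 r elbow, 12 r shoulder, 13 l shoulder, 14 l elbow, 15 l wrist); MPII_LIMBS and limb_segments are hypothetical names, and the validity test reuses the "both coordinates positive" check from the scripts.

```python
# Pairs of joint indices to connect, assuming the MPII joint order above
MPII_LIMBS = [(0, 1), (1, 2), (2, 6), (6, 3), (3, 4), (4, 5),   # legs
              (6, 7), (7, 8), (8, 9),                            # torso and head
              (10, 11), (11, 12), (12, 7), (7, 13), (13, 14), (14, 15)]  # arms


def limb_segments(pose):
    """Return ((x0, y0), (x1, y1)) line segments for limbs whose two joints are valid."""
    def valid(p):
        return p[0] > 0 and p[1] > 0

    return [(tuple(pose[a]), tuple(pose[b]))
            for a, b in MPII_LIMBS
            if valid(pose[a]) and valid(pose[b])]
```

Inside the plotting loop, each segment (x0, y0), (x1, y1) can then be drawn with plt.plot([x0, x1], [y0, y1], 'r-').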

Results:



Last modification: October 9th, 2018 at 09:31 am