AIHGF

Dress Segmentation with Autoencoder and Keras [Translation]
04
2019/09


Original: Dress Segmentation with Autoencoder in Keras - 2019.06.01

Author: Marco Cerliani

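The fashion industry offers AI a wide range of applications and many interesting research opportunities.

In this post we build a system that segments dresses. It takes a raw image as input, such as one downloaded from the web or taken with a phone, and extracts the dress from it. Raw images can contain a lot of noise, which makes segmentation harder. We handle this problem with a trick applied during preprocessing.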

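1. Dataset

We built our own clothing dataset.

First, we collected images from the web that show women wearing different types of dresses in different settings. Then we created the masks. Every segmentation task needs masks, and they are essential when you train a model for a specific target task. The resulting dataset looks like this:

For this task, each image has to be split into background, skin and dress. When the goal is to extract the dress, background and skin are the main sources of noise. For example:

As a final step, the binary images for the three classes (background, skin and dress) are merged into a single image. The merged image encodes the features of the original image that we care about, because our goal is to separate background, skin and dress.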
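The sketch below shows this merge in NumPy. It assumes the person and dress masks are grayscale images in which pure white (255) is background, as in the training script in section 4.1. The helper names `binarize` and `combine_masks` are hypothetical. Computing skin with boolean logic (person AND NOT dress) is a variation on the script's `skin = body - dress`. That subtraction runs on uint8 arrays and can wrap around if a dress pixel falls outside the person mask.

```python
import numpy as np


def binarize(mask_gray):
    """Mirror the dataset encoding: pure white (255) -> 0, any other value -> 255."""
    out = mask_gray.copy()
    out[out == 255] = 0
    out[out > 0] = 255
    return out


def combine_masks(body, dress):
    """Stack skin, dress and background into one (H, W, 3) float image.

    body, dress: uint8 masks with 255 inside the region and 0 elsewhere.
    Channel order follows the training script: 0 = skin, 1 = dress, 2 = background.
    """
    body_on = body > 0
    dress_on = dress > 0
    skin_on = body_on & ~dress_on          # person pixels that are not dress
    gt = np.zeros(body.shape + (3,), dtype=np.float64)
    gt[:, :, 0] = skin_on
    gt[:, :, 1] = dress_on
    gt[:, :, 2] = ~body_on                 # everything outside the person
    return gt
```

Because the three channels are mutually exclusive, every pixel has exactly one channel set to 1.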

Applying the same processing to every image completes the dataset.

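2. Model

Building the segmentation model is straightforward.

The model takes a raw image as input and outputs a three-channel mask that splits the image into background, skin and dress. To segment the dress, which is the goal of the task, we only need to keep the dress channel of this output.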
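The sketch below assumes the channel order used by the training code in section 4.1 (0 = skin, 1 = dress, 2 = background). Under that order, extracting the dress amounts to thresholding channel 1. Section 4.2 uses a threshold of 0.90. The argmax variant is an alternative that the original does not use, and both function names are hypothetical.

```python
import numpy as np

# Channel order used in this article: 0 = skin, 1 = dress, 2 = background
DRESS_CHANNEL = 1


def dress_mask(pred, threshold=0.9):
    """Binary dress mask from a (H, W, 3) sigmoid output by thresholding channel 1."""
    return (pred[:, :, DRESS_CHANNEL] >= threshold).astype(np.uint8)


def dress_mask_argmax(pred):
    """Alternative: mark a pixel as dress when the dress channel has the highest score."""
    return (np.argmax(pred, axis=-1) == DRESS_CHANNEL).astype(np.uint8)
```

A high threshold such as 0.90 keeps only pixels the model is confident about. The argmax rule labels more pixels as dress when all three scores are low.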

We use a UNet, a deep convolutional autoencoder that is often used for segmentation tasks like this one. It is also easy to implement in Keras.
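The architecture has one practical consequence. `get_unet` in section 4.1 accepts inputs of any size (`Input((None, None, 3))`). However, it halves the resolution four times with 2×2 max pooling and then concatenates each upsampled map with the matching encoder map. Height and width should therefore be multiples of 2^4 = 16; 224 = 14 × 16 works. The hypothetical helpers below pad an arbitrary image up to such a size and crop the prediction back afterwards. Padding with edge replication is an assumption; the article does not do it.

```python
import numpy as np


def pad_to_multiple(img, multiple=16):
    """Pad H and W at the bottom/right (edge replication) up to the next multiple.

    Returns the padded image and the original (H, W) so the output can be cropped.
    """
    h, w = img.shape[:2]
    pad_h = (-h) % multiple
    pad_w = (-w) % multiple
    pad = [(0, pad_h), (0, pad_w)] + [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad, mode='edge'), (h, w)


def crop_back(pred, size):
    """Undo pad_to_multiple on a prediction of the padded image."""
    h, w = size
    return pred[:h, :w]
```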

Before training, we standardize all raw images by subtracting the mean RGB image of the training set.
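Here is a minimal NumPy sketch of this normalization. It computes the per-pixel, per-channel mean over all training images, giving a 224×224×3 "mean image" as in section 4.1. It then subtracts that mean from every input, and the same mean must be reused at prediction time. The function names are hypothetical. Unlike the training script, this sketch keeps the mean as a float instead of casting it to int.

```python
import numpy as np


def mean_image(images):
    """Per-pixel, per-channel mean over a stack of equally sized images."""
    stack = np.asarray(images, dtype=np.float64)   # (N, H, W, 3)
    return stack.mean(axis=0)                      # (H, W, 3)


def center(images, mean):
    """Subtract the training mean image from every input image."""
    return np.asarray(images, dtype=np.float32) - mean.astype(np.float32)
```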

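3. Results and Prediction

At prediction time, the model starts to struggle when the image is noisy, for example when the background or skin is blurred. Adding more training images would be a simple fix. Here we use a trick instead.

We use OpenCV's GrabCut algorithm. It models the image with Gaussian Mixture Models to separate foreground objects from the background. In our case, it helps locate the person in the image.

For example: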

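import cv2
import numpy as np

def cut(img):
    # Resize to the UNet input size
    img = cv2.resize(img,(224,224))
    # GrabCut writes per-pixel labels into `mask` and stores its GMMs in two (1, 65) buffers
    mask = np.zeros(img.shape[:2],np.uint8)
    bgdModel = np.zeros((1,65),np.float64)
    fgdModel = np.zeros((1,65),np.float64)
    height, width = img.shape[:2]
    # Initial rectangle (x, y, w, h): 50 px margin left/right, 10 px top/bottom
    rect = (50,10,width-100,height-20)
    # Run 5 GrabCut iterations, initialized from the rectangle
    cv2.grabCut(img,mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_RECT)
    # Labels 0 (background) and 2 (probable background) -> 0; 1 and 3 (foreground) -> 1
    mask2 = np.where((mask==2)|(mask==0),0,1).astype('uint8')
    # Keep the foreground pixels and paint the background white
    final = img*mask2[:,:,np.newaxis]
    final[mask2 == 0] = (255, 255, 255)

    return mask, final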
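In `cut`, GrabCut writes one of four labels into `mask`:

- definite background (0)
- definite foreground (1)
- probable background (2)
- probable foreground (3)

These values match OpenCV's `cv2.GC_BGD`, `cv2.GC_FGD`, `cv2.GC_PR_BGD` and `cv2.GC_PR_FGD`. The NumPy-only sketch below restates two pieces of logic from `cut` so they can be checked without OpenCV. The first collapses the labels to a binary foreground mask, and the second builds the initial `(x, y, w, h)` rectangle. The helper names and the `margin_x`/`margin_y` parameters are hypothetical.

```python
import numpy as np

# Label values GrabCut writes into its mask (same values as cv2.GC_BGD,
# cv2.GC_FGD, cv2.GC_PR_BGD and cv2.GC_PR_FGD)
GC_BGD, GC_FGD, GC_PR_BGD, GC_PR_FGD = 0, 1, 2, 3


def foreground_from_grabcut(mask):
    """1 where GrabCut labelled the pixel (probable) foreground, else 0."""
    return np.isin(mask, (GC_FGD, GC_PR_FGD)).astype(np.uint8)


def init_rect(width, height, margin_x=50, margin_y=10):
    """(x, y, w, h) rectangle used to initialise GrabCut, as in cut()."""
    return (margin_x, margin_y, width - 2 * margin_x, height - 2 * margin_y)
```

The fixed margins assume the person is roughly centered in the frame. Pixels outside the rectangle are treated as definite background from the start.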

Then the UNet predicts on the GrabCut output (Input - GrabCut + UNet - Final Dress):

4. Keras Implementation

4.1. UNet Training

# -*- coding: utf-8 -*-
import pickle

import cv2
import numpy as np
import tqdm
from matplotlib import pyplot as plt

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, ReLU, Input, Dropout, MaxPooling2D
from tensorflow.keras.layers import concatenate, Conv2DTranspose
from tensorflow.keras import backend as K
from tensorflow.keras import losses
from tensorflow.keras.optimizers import Adam

#[1] Dataset
# Show one example image with its person and dress masks
original = cv2.imread('./original77.jpg')
original = cv2.resize(original,(224,224))
dress = cv2.imread('./dress77.jpg')
dress = cv2.resize(dress,(224,224))
body = cv2.imread('./body77.jpg')
body = cv2.resize(body,(224,224))

# cv2.imread returns BGR; convert to RGB for matplotlib
plt.figure(figsize=(16,8))
plt.subplot(1,3,1)
plt.title('Original')
plt.imshow(cv2.cvtColor(original, cv2.COLOR_BGR2RGB))
plt.subplot(1,3,2)
plt.title('Person')
plt.imshow(cv2.cvtColor(body, cv2.COLOR_BGR2RGB))
plt.subplot(1,3,3)
plt.title('Dress')
plt.imshow(cv2.cvtColor(dress, cv2.COLOR_BGR2RGB))


#[2] Prepare the semantic segmentation masks (read as grayscale)
dress = cv2.imread('./dress77.jpg',0)
body = cv2.imread('./body77.jpg',0)

### ENCODE DRESS ###
dress[dress == 255] = 0
dress[dress > 0] = 255
dress = cv2.resize(dress,(224,224))

### ENCODE BODY ###
body[body == 255] = 0
body[body > 0] = 255
body = cv2.resize(body,(224,224))

### ENCODE SKIN ###
skin = body - dress

plt.figure(figsize=(16,8))
plt.subplot(1,3,1)
plt.title('Person/Background')
bg = (255 - body)/255
plt.imshow(bg)
plt.subplot(1,3,2)
plt.title('Skin')
skin = (255 - skin)/255
plt.imshow(skin)
plt.subplot(1,3,3)
plt.title('Dress')
dress = (255 - dress)/255
plt.imshow(dress)

### COMBINE BACKGROUND, SKIN, DRESS ###
gt = np.zeros((224,224,3))
gt[:,:,0] = (1-skin)
gt[:,:,1] = (1-dress)
gt[:,:,2] = bg

plt.figure(figsize=(6,6))
plt.imshow(gt)


### ENCODE BACKGROUND, SKIN, DRESS FOR ALL TRAIN IMAGES ###
images_original = []
images_gt = []

mean = np.zeros((224,224,3))
n_img = 81

for i in tqdm.tqdm(range(1,n_img+1)):
    original = cv2.imread('./data/original/original'+str(i)+'.jpg')
    original = cv2.resize(original,(224,224))
    images_original.append(original)
    # Accumulate per-pixel, per-channel sums for the mean image
    mean += original

    body = cv2.imread('./data/body/body'+str(i)+'.jpg',0)
    dress = cv2.imread('./data/dress/dress'+str(i)+'.jpg',0)

    # Binarize: white (255) background -> 0, annotated region -> 255
    dress[dress == 255] = 0
    dress[dress > 0] = 255
    dress = cv2.resize(dress,(224,224))

    body[body == 255] = 0
    body[body > 0] = 255
    body = cv2.resize(body,(224,224))

    skin = body - dress
    bg = (255 - body)/255
    skin = (255 - skin)/255
    dress = (255 - dress)/255

    # Ground-truth channels: 0 = skin, 1 = dress, 2 = background
    gt = np.zeros((224,224,3))
    gt[:,:,0] = (1-skin)
    gt[:,:,1] = (1-dress)
    gt[:,:,2] = bg

    images_gt.append(gt)

mean = mean / n_img
mean = mean.astype('int')

# Save the mean image; the prediction script in 4.2 loads it from mean81.pkl
with open('./mean81.pkl', 'wb') as f:
    pickle.dump(mean, f)

#[3] UNet training

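# Clipped ReLU with outputs in [0, 1]; defined but not used by get_unet below
def custom_activation(x):
    return K.relu(x, alpha=0.0, max_value=1)


# Binary focal loss, summed over all pixels; an alternative to binary
# cross-entropy (get_unet below compiles with binary_crossentropy)
def focal_loss(gamma=2., alpha=.25):
    def focal_loss_fixed(y_true, y_pred):
        # Clip predictions so that log() never sees exactly 0 or 1
        y_pred = K.clip(y_pred, K.epsilon(), 1. - K.epsilon())
        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
        return (-K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1))
                - K.sum((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0)))
    return focal_loss_fixed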

smooth = 1.

def get_unet(do=0, activation=ReLU):
    inputs = Input((None, None, 3))
    conv1 = Dropout(do)(activation()(Conv2D(32, (3, 3), padding='same')(inputs)))
    conv1 = Dropout(do)(activation()(Conv2D(32, (3, 3), padding='same')(conv1)))
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Dropout(do)(activation()(Conv2D(64, (3, 3), padding='same')(pool1)))
    conv2 = Dropout(do)(activation()(Conv2D(64, (3, 3), padding='same')(conv2)))
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Dropout(do)(activation()(Conv2D(128, (3, 3), padding='same')(pool2)))
    conv3 = Dropout(do)(activation()(Conv2D(128, (3, 3), padding='same')(conv3)))
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Dropout(do)(activation()(Conv2D(256, (3, 3), padding='same')(pool3)))
    conv4 = Dropout(do)(activation()(Conv2D(256, (3, 3), padding='same')(conv4)))
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Dropout(do)(activation()(Conv2D(512, (3, 3), padding='same')(pool4)))
    conv5 = Dropout(do)(activation()(Conv2D(512, (3, 3), padding='same')(conv5)))

    up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], axis=3)
    conv6 = Dropout(do)(activation()(Conv2D(256, (3, 3), padding='same')(up6)))
    conv6 = Dropout(do)(activation()(Conv2D(256, (3, 3), padding='same')(conv6)))

    up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3)
    conv7 = Dropout(do)(activation()(Conv2D(128, (3, 3), padding='same')(up7)))
    conv7 = Dropout(do)(activation()(Conv2D(128, (3, 3), padding='same')(conv7)))

    up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3)
    conv8 = Dropout(do)(activation()(Conv2D(64, (3, 3), padding='same')(up8)))
    conv8 = Dropout(do)(activation()(Conv2D(64, (3, 3), padding='same')(conv8)))

    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3)
    conv9 = Dropout(do)(activation()(Conv2D(32, (3, 3), padding='same')(up9)))
    conv9 = Dropout(do)(activation()(Conv2D(32, (3, 3), padding='same')(conv9)))

    conv10 = Conv2D(3, (1, 1), activation='sigmoid')(conv9)

    model = Model(inputs=[inputs], outputs=[conv10])

    model.compile(optimizer=Adam(learning_rate=1e-3), loss=losses.binary_crossentropy, metrics=['accuracy'])

    return model

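# Subtract the mean image from every input; stack the ground-truth masks
x_raw = np.asarray(images_original) - mean.reshape(-1,224,224,3)
x_gt = np.asarray(images_gt).reshape(-1,224,224,3)

model = get_unet()
model.summary()

# Train on all 81 images for 120 epochs
history = model.fit(x_raw, x_gt, epochs=120)

# Save the trained model for the prediction step
model.save('./fashion_unet.h5')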
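Section 3 notes that the instability on noisy images could also be reduced with more training data. The article does not do this, but horizontal flipping is a cheap way to enlarge the 81 training pairs. The flip must be applied identically to each image and its mask. This is a sketch under that assumption, and `hflip_augment` is a hypothetical helper. Flip the raw images before the mean is subtracted. That way every input, original or flipped, is centered with the same `mean` that prediction uses.

```python
import numpy as np


def hflip_augment(x, y):
    """Append left-right mirrored copies of images x (N, H, W, C) and masks y (N, H, W, K).

    The same flip is applied to each image and its mask, so labels stay aligned.
    """
    x_aug = np.concatenate([x, x[:, :, ::-1, :]], axis=0)
    y_aug = np.concatenate([y, y[:, :, ::-1, :]], axis=0)
    return x_aug, y_aug
```

For example, call `x_aug, y_aug = hflip_augment(np.asarray(images_original), np.asarray(images_gt))`. Then subtract `mean` from `x_aug` and pass both arrays to `model.fit`.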

Output of model.summary():

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_7 (InputLayer)            (None, None, None, 3 0                                            
__________________________________________________________________________________________________
conv2d_114 (Conv2D)             (None, None, None, 3 896         input_7[0][0]                    
__________________________________________________________________________________________________
re_lu_108 (ReLU)                (None, None, None, 3 0           conv2d_114[0][0]                 
__________________________________________________________________________________________________
dropout_108 (Dropout)           (None, None, None, 3 0           re_lu_108[0][0]                  
__________________________________________________________________________________________________
conv2d_115 (Conv2D)             (None, None, None, 3 9248        dropout_108[0][0]                
__________________________________________________________________________________________________
re_lu_109 (ReLU)                (None, None, None, 3 0           conv2d_115[0][0]                 
__________________________________________________________________________________________________
dropout_109 (Dropout)           (None, None, None, 3 0           re_lu_109[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_24 (MaxPooling2D) (None, None, None, 3 0           dropout_109[0][0]                
__________________________________________________________________________________________________
conv2d_116 (Conv2D)             (None, None, None, 6 18496       max_pooling2d_24[0][0]           
__________________________________________________________________________________________________
re_lu_110 (ReLU)                (None, None, None, 6 0           conv2d_116[0][0]                 
__________________________________________________________________________________________________
dropout_110 (Dropout)           (None, None, None, 6 0           re_lu_110[0][0]                  
__________________________________________________________________________________________________
conv2d_117 (Conv2D)             (None, None, None, 6 36928       dropout_110[0][0]                
__________________________________________________________________________________________________
re_lu_111 (ReLU)                (None, None, None, 6 0           conv2d_117[0][0]                 
__________________________________________________________________________________________________
dropout_111 (Dropout)           (None, None, None, 6 0           re_lu_111[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_25 (MaxPooling2D) (None, None, None, 6 0           dropout_111[0][0]                
__________________________________________________________________________________________________
conv2d_118 (Conv2D)             (None, None, None, 1 73856       max_pooling2d_25[0][0]           
__________________________________________________________________________________________________
re_lu_112 (ReLU)                (None, None, None, 1 0           conv2d_118[0][0]                 
__________________________________________________________________________________________________
dropout_112 (Dropout)           (None, None, None, 1 0           re_lu_112[0][0]                  
__________________________________________________________________________________________________
conv2d_119 (Conv2D)             (None, None, None, 1 147584      dropout_112[0][0]                
__________________________________________________________________________________________________
re_lu_113 (ReLU)                (None, None, None, 1 0           conv2d_119[0][0]                 
__________________________________________________________________________________________________
dropout_113 (Dropout)           (None, None, None, 1 0           re_lu_113[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_26 (MaxPooling2D) (None, None, None, 1 0           dropout_113[0][0]                
__________________________________________________________________________________________________
conv2d_120 (Conv2D)             (None, None, None, 2 295168      max_pooling2d_26[0][0]           
__________________________________________________________________________________________________
re_lu_114 (ReLU)                (None, None, None, 2 0           conv2d_120[0][0]                 
__________________________________________________________________________________________________
dropout_114 (Dropout)           (None, None, None, 2 0           re_lu_114[0][0]                  
__________________________________________________________________________________________________
conv2d_121 (Conv2D)             (None, None, None, 2 590080      dropout_114[0][0]                
__________________________________________________________________________________________________
re_lu_115 (ReLU)                (None, None, None, 2 0           conv2d_121[0][0]                 
__________________________________________________________________________________________________
dropout_115 (Dropout)           (None, None, None, 2 0           re_lu_115[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_27 (MaxPooling2D) (None, None, None, 2 0           dropout_115[0][0]                
__________________________________________________________________________________________________
conv2d_122 (Conv2D)             (None, None, None, 5 1180160     max_pooling2d_27[0][0]           
__________________________________________________________________________________________________
re_lu_116 (ReLU)                (None, None, None, 5 0           conv2d_122[0][0]                 
__________________________________________________________________________________________________
dropout_116 (Dropout)           (None, None, None, 5 0           re_lu_116[0][0]                  
__________________________________________________________________________________________________
conv2d_123 (Conv2D)             (None, None, None, 5 2359808     dropout_116[0][0]                
__________________________________________________________________________________________________
re_lu_117 (ReLU)                (None, None, None, 5 0           conv2d_123[0][0]                 
__________________________________________________________________________________________________
dropout_117 (Dropout)           (None, None, None, 5 0           re_lu_117[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_24 (Conv2DTran (None, None, None, 2 524544      dropout_117[0][0]                
__________________________________________________________________________________________________
concatenate_24 (Concatenate)    (None, None, None, 5 0           conv2d_transpose_24[0][0]        
                                                                 dropout_115[0][0]                
__________________________________________________________________________________________________
conv2d_124 (Conv2D)             (None, None, None, 2 1179904     concatenate_24[0][0]             
__________________________________________________________________________________________________
re_lu_118 (ReLU)                (None, None, None, 2 0           conv2d_124[0][0]                 
__________________________________________________________________________________________________
dropout_118 (Dropout)           (None, None, None, 2 0           re_lu_118[0][0]                  
__________________________________________________________________________________________________
conv2d_125 (Conv2D)             (None, None, None, 2 590080      dropout_118[0][0]                
__________________________________________________________________________________________________
re_lu_119 (ReLU)                (None, None, None, 2 0           conv2d_125[0][0]                 
__________________________________________________________________________________________________
dropout_119 (Dropout)           (None, None, None, 2 0           re_lu_119[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_25 (Conv2DTran (None, None, None, 1 131200      dropout_119[0][0]                
__________________________________________________________________________________________________
concatenate_25 (Concatenate)    (None, None, None, 2 0           conv2d_transpose_25[0][0]        
                                                                 dropout_113[0][0]                
__________________________________________________________________________________________________
conv2d_126 (Conv2D)             (None, None, None, 1 295040      concatenate_25[0][0]             
__________________________________________________________________________________________________
re_lu_120 (ReLU)                (None, None, None, 1 0           conv2d_126[0][0]                 
__________________________________________________________________________________________________
dropout_120 (Dropout)           (None, None, None, 1 0           re_lu_120[0][0]                  
__________________________________________________________________________________________________
conv2d_127 (Conv2D)             (None, None, None, 1 147584      dropout_120[0][0]                
__________________________________________________________________________________________________
re_lu_121 (ReLU)                (None, None, None, 1 0           conv2d_127[0][0]                 
__________________________________________________________________________________________________
dropout_121 (Dropout)           (None, None, None, 1 0           re_lu_121[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_26 (Conv2DTran (None, None, None, 6 32832       dropout_121[0][0]                
__________________________________________________________________________________________________
concatenate_26 (Concatenate)    (None, None, None, 1 0           conv2d_transpose_26[0][0]        
                                                                 dropout_111[0][0]                
__________________________________________________________________________________________________
conv2d_128 (Conv2D)             (None, None, None, 6 73792       concatenate_26[0][0]             
__________________________________________________________________________________________________
re_lu_122 (ReLU)                (None, None, None, 6 0           conv2d_128[0][0]                 
__________________________________________________________________________________________________
dropout_122 (Dropout)           (None, None, None, 6 0           re_lu_122[0][0]                  
__________________________________________________________________________________________________
conv2d_129 (Conv2D)             (None, None, None, 6 36928       dropout_122[0][0]                
__________________________________________________________________________________________________
re_lu_123 (ReLU)                (None, None, None, 6 0           conv2d_129[0][0]                 
__________________________________________________________________________________________________
dropout_123 (Dropout)           (None, None, None, 6 0           re_lu_123[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_27 (Conv2DTran (None, None, None, 3 8224        dropout_123[0][0]                
__________________________________________________________________________________________________
concatenate_27 (Concatenate)    (None, None, None, 6 0           conv2d_transpose_27[0][0]        
                                                                 dropout_109[0][0]                
__________________________________________________________________________________________________
conv2d_130 (Conv2D)             (None, None, None, 3 18464       concatenate_27[0][0]             
__________________________________________________________________________________________________
re_lu_124 (ReLU)                (None, None, None, 3 0           conv2d_130[0][0]                 
__________________________________________________________________________________________________
dropout_124 (Dropout)           (None, None, None, 3 0           re_lu_124[0][0]                  
__________________________________________________________________________________________________
conv2d_131 (Conv2D)             (None, None, None, 3 9248        dropout_124[0][0]                
__________________________________________________________________________________________________
re_lu_125 (ReLU)                (None, None, None, 3 0           conv2d_131[0][0]                 
__________________________________________________________________________________________________
dropout_125 (Dropout)           (None, None, None, 3 0           re_lu_125[0][0]                  
__________________________________________________________________________________________________
conv2d_132 (Conv2D)             (None, None, None, 3 99          dropout_125[0][0]                
==================================================================================================
Total params: 7,760,163
Trainable params: 7,760,163
Non-trainable params: 0

Training log from model.fit:

Epoch 1/120
81/81 [==============================] - 4s 52ms/sample - loss: 52.0267 - acc: 0.5463
Epoch 2/120
81/81 [==============================] - 3s 32ms/sample - loss: 33.1465 - acc: 0.7927
Epoch 3/120
81/81 [==============================] - 3s 32ms/sample - loss: 21.0042 - acc: 0.7896
Epoch 4/120
81/81 [==============================] - 3s 31ms/sample - loss: 17.7321 - acc: 0.7446
Epoch 5/120
81/81 [==============================] - 3s 31ms/sample - loss: 15.3721 - acc: 0.7675
Epoch 6/120
81/81 [==============================] - 3s 31ms/sample - loss: 14.4174 - acc: 0.8018
Epoch 7/120
81/81 [==============================] - 3s 31ms/sample - loss: 13.9712 - acc: 0.8106
Epoch 8/120
81/81 [==============================] - 3s 32ms/sample - loss: 13.8914 - acc: 0.8119
Epoch 9/120
81/81 [==============================] - 3s 32ms/sample - loss: 13.3527 - acc: 0.8203
Epoch 10/120
81/81 [==============================] - 3s 32ms/sample - loss: 13.1950 - acc: 0.8235
Epoch 11/120
81/81 [==============================] - 3s 32ms/sample - loss: 12.8550 - acc: 0.8192
Epoch 12/120
81/81 [==============================] - 3s 32ms/sample - loss: 12.6445 - acc: 0.8206
Epoch 13/120
81/81 [==============================] - 3s 32ms/sample - loss: 12.7519 - acc: 0.8322
Epoch 14/120
81/81 [==============================] - 3s 32ms/sample - loss: 12.5392 - acc: 0.8347
Epoch 15/120
81/81 [==============================] - 3s 31ms/sample - loss: 12.2611 - acc: 0.8331
Epoch 16/120
81/81 [==============================] - 3s 32ms/sample - loss: 12.2062 - acc: 0.8374
Epoch 17/120
81/81 [==============================] - 3s 32ms/sample - loss: 11.9259 - acc: 0.8380
Epoch 18/120
81/81 [==============================] - 3s 32ms/sample - loss: 12.0169 - acc: 0.8458
Epoch 19/120
81/81 [==============================] - 3s 32ms/sample - loss: 11.8951 - acc: 0.8453
Epoch 20/120
81/81 [==============================] - 3s 32ms/sample - loss: 11.5625 - acc: 0.8442
Epoch 21/120
81/81 [==============================] - 3s 32ms/sample - loss: 11.3960 - acc: 0.8500
Epoch 22/120
81/81 [==============================] - 3s 32ms/sample - loss: 11.6926 - acc: 0.8461
Epoch 23/120
81/81 [==============================] - 3s 32ms/sample - loss: 11.3838 - acc: 0.8522
Epoch 24/120
81/81 [==============================] - 3s 31ms/sample - loss: 11.2607 - acc: 0.8570
Epoch 25/120
81/81 [==============================] - 3s 32ms/sample - loss: 11.0254 - acc: 0.8585
Epoch 26/120
81/81 [==============================] - 3s 32ms/sample - loss: 11.2450 - acc: 0.8543
Epoch 27/120
81/81 [==============================] - 3s 31ms/sample - loss: 10.8368 - acc: 0.8589
...
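The training script defines `focal_loss`, but the model is compiled with `binary_crossentropy`. Mean binary cross-entropy normally stays well below the values in the log above. The losses in the tens may therefore come from a run that used the summed `focal_loss`, though the source does not say so.

To make the formula concrete, here is a NumPy transcription of `focal_loss_fixed`. Clipping is added as an assumption, to avoid `log(0)`. Like the Keras version, it only counts labels exactly equal to 0 or 1. Pixels with fractional values, for example from resizing the masks, contribute nothing. `focal_loss_np` is a hypothetical name.

```python
import numpy as np


def focal_loss_np(y_true, y_pred, gamma=2.0, alpha=0.25, eps=1e-7):
    """NumPy version of the binary focal loss defined in the training script.

    Positives (y == 1) contribute -alpha * (1 - p)**gamma * log(p);
    negatives (y == 0) contribute -(1 - alpha) * p**gamma * log(1 - p).
    Terms are summed, not averaged, matching K.sum in focal_loss_fixed.
    """
    p = np.clip(y_pred, eps, 1.0 - eps)
    pos = y_true == 1
    neg = y_true == 0
    loss_pos = -alpha * (1.0 - p[pos]) ** gamma * np.log(p[pos])
    loss_neg = -(1.0 - alpha) * p[neg] ** gamma * np.log(1.0 - p[neg])
    return loss_pos.sum() + loss_neg.sum()
```

With `gamma=0` and `alpha=0.5`, the focal loss reduces to half the summed binary cross-entropy. A larger `gamma` down-weights pixels that are already well classified.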

4.2. Prediction

import pickle

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow


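#[1] GrabCut function
def cut(img):
    # Resize to the UNet input size
    img = cv2.resize(img,(224,224))
    # GrabCut writes per-pixel labels into `mask` and stores its GMMs in two (1, 65) buffers
    mask = np.zeros(img.shape[:2],np.uint8)
    bgdModel = np.zeros((1,65),np.float64)
    fgdModel = np.zeros((1,65),np.float64)
    height, width = img.shape[:2]

    # Initial rectangle (x, y, w, h): 50 px margin left/right, 10 px top/bottom
    rect = (50,10,width-100,height-20)
    cv2.grabCut(img, mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_RECT)
    # Labels 0 (background) and 2 (probable background) -> 0; 1 and 3 (foreground) -> 1
    mask2 = np.where((mask==2)|(mask==0),0,1).astype('uint8')
    # Keep the foreground pixels and paint the background white
    final = img*mask2[:,:,np.newaxis]
    final[mask2 == 0] = (255, 255, 255)

    return mask, final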


#[2] Load the trained UNet model
UNET = tensorflow.keras.models.load_model('./fashion_unet.h5')

#[3] GrabCut preprocessing
plt.figure(figsize=(16,8))
original = cv2.imread('testGrubCut.jpg')
original = cv2.resize(original,(224,224))

plt.subplot(1,3,1)
plt.imshow(cv2.cvtColor(original, cv2.COLOR_BGR2RGB))
mask, final = cut(original)
# Raw GrabCut labels (0-3)
plt.subplot(1,3,2)
plt.imshow(mask)
plt.subplot(1,3,3)
plt.imshow(cv2.cvtColor(final, cv2.COLOR_BGR2RGB))

#[4] Load the mean RGB image saved during training
with open("./mean81.pkl", 'rb') as pickle_file:
    mean = pickle.load(pickle_file)
    
#[5] Read a new image
plt.figure(figsize=(16,8))
img = cv2.imread('test1.jpg')
plt.subplot(1,3,1)
plt.imshow(cv2.cvtColor(cv2.resize(img.copy(),(224,224)), cv2.COLOR_BGR2RGB))

#[6] GrabCut + UNet
mask_test, test = cut(img)
# Subtract the training mean; the reshape also adds the batch dimension -> (1, 224, 224, 3)
test = test - mean.reshape(-1,224,224,3)
pred = UNET.predict(test)[0]
plt.subplot(1,3,2)
plt.imshow(pred)

#[7] Threshold the dress channel to get the mask
pred_dress = pred.copy()[:,:,1]
pred_dress[pred_dress>=0.90]=1
pred_dress[pred_dress<0.90]=0
# Keep the dress pixels of the resized input and paint everything else white
real_dress = (cv2.resize(img.copy(), (224,224))*pred_dress[:,:,np.newaxis]).astype('int')
real_dress[pred_dress == 0] = (255, 255, 255)
plt.subplot(1,3,3)
plt.imshow(cv2.cvtColor(real_dress.astype('uint8'), cv2.COLOR_BGR2RGB))
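The `acc` reported during training is pixel accuracy, and background pixels can dominate it. Intersection over union (IoU) is a common alternative for judging a thresholded `pred_dress` against an annotated dress mask. The sketch below is not part of the original article. `iou` is a hypothetical helper that expects two binary masks of the same shape, for example `pred_dress` and a ground-truth dress mask resized to 224×224.

```python
import numpy as np


def iou(pred_mask, true_mask):
    """Intersection over union of two binary masks (1.0 if both are empty)."""
    pred = pred_mask.astype(bool)
    true = true_mask.astype(bool)
    union = np.logical_or(pred, true).sum()
    if union == 0:
        return 1.0
    return float(np.logical_and(pred, true).sum() / union)
```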
Last modification: September 4th, 2019 at 12:52 pm
