U-Net: Convolutional Networks for Biomedical Image Segmentation - 2015

UNet was first proposed for medical image analysis. Because the network is simple and easy to understand, it has been widely applied to many other semantic segmentation scenarios, e.g., the GitHub project - UNet for Kaggle vehicle boundary recognition.

U-Net Network Architecture

U-Net consists of a contracting path (encoder) that repeatedly applies unpadded 3x3 convolutions and 2x2 max-pooling, and a symmetric expanding path (decoder) that upsamples and concatenates the corresponding encoder feature maps via skip connections; see the figure in the original paper.

PyTorch Network Definition

The GitHub project - UNet for Kaggle vehicle boundary recognition provides one U-Net network definition and usage example.

Here we follow another definition - pytorch-semseg/ptsemseg/models/unet.py, which supports configurable deconvolution (transposed convolution) and batchnorm.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Two stacked 3x3 convolutions (no padding), each followed by ReLU,
# optionally with BatchNorm after each convolution.
class unetConv2(nn.Module):
    def __init__(self, in_size, out_size, is_batchnorm):
        super(unetConv2, self).__init__()

        if is_batchnorm:
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_size, out_size, 3, 1, 0),
                nn.BatchNorm2d(out_size),
                nn.ReLU(), )
            self.conv2 = nn.Sequential(
                nn.Conv2d(out_size, out_size, 3, 1, 0),
                nn.BatchNorm2d(out_size),
                nn.ReLU(), )
        else:
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_size, out_size, 3, 1, 0),
                nn.ReLU() )
            self.conv2 = nn.Sequential(
                nn.Conv2d(out_size, out_size, 3, 1, 0),
                nn.ReLU() )

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        return outputs


# Decoder block: upsample the coarser feature map (transposed convolution or
# bilinear), crop the encoder skip connection to match, concatenate, then
# apply unetConv2.
class unetUp(nn.Module):
    def __init__(self, in_size, out_size, is_deconv):
        super(unetUp, self).__init__()
        self.conv = unetConv2(in_size, out_size, False)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        else:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, inputs1, inputs2):
        # inputs1: encoder skip connection (larger spatial size)
        # inputs2: feature map from the layer below (to be upsampled)
        outputs2 = self.up(inputs2)
        # offset is negative here, so F.pad with negative values center-crops
        # the skip connection to the upsampled size before concatenation
        offset = outputs2.size()[2] - inputs1.size()[2]
        padding = 2 * [offset // 2, offset // 2]
        outputs1 = F.pad(inputs1, padding)
        return self.conv(torch.cat([outputs1, outputs2], 1))
    

class unet(nn.Module):
    def __init__(self,
                 feature_scale=1,
                 n_classes=2,
                 is_deconv=True,
                 in_channels=3,
                 is_batchnorm=True,
                ):
        super(unet, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [64, 128, 256, 512, 1024]
        # optionally thin the network; feature_scale=1 keeps the original widths
        filters = [int(x / self.feature_scale) for x in filters]

        # downsampling
        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)

        # upsampling
        self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv)
        self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv)
        self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv)
        self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv)

        # final conv (without any concat)
        self.final = nn.Conv2d(filters[0], n_classes, 1)

    def forward(self, inputs):
        conv1 = self.conv1(inputs)
        maxpool1 = self.maxpool1(conv1)

        conv2 = self.conv2(maxpool1)
        maxpool2 = self.maxpool2(conv2)

        conv3 = self.conv3(maxpool2)
        maxpool3 = self.maxpool3(conv3)

        conv4 = self.conv4(maxpool3)
        maxpool4 = self.maxpool4(conv4)

        center = self.center(maxpool4)
        up4 = self.up_concat4(conv4, center)
        up3 = self.up_concat3(conv3, up4)
        up2 = self.up_concat2(conv2, up3)
        up1 = self.up_concat1(conv1, up2)

        final = self.final(up1)

        return final
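
A quick sanity check of the forward pass (a minimal sketch; the 572x572 input size follows the original paper). Because every 3x3 convolution is unpadded, the prediction is smaller than the input:

import torch

model = unet(n_classes=2, in_channels=3)
x = torch.randn(1, 3, 572, 572)   # one RGB image
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 2, 388, 388])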

The torchsummary library can be used to print the output of each layer of the U-Net defined above.

Python library - torchsummary for printing PyTorch models

import torch
import torch.nn.functional as F
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
model = unet().to(device)

summary(model, (3, 572, 572))

The output is as follows:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 570, 570]           1,792
       BatchNorm2d-2         [-1, 64, 570, 570]             128
              ReLU-3         [-1, 64, 570, 570]               0
            Conv2d-4         [-1, 64, 568, 568]          36,928
       BatchNorm2d-5         [-1, 64, 568, 568]             128
              ReLU-6         [-1, 64, 568, 568]               0
         unetConv2-7         [-1, 64, 568, 568]               0
         MaxPool2d-8         [-1, 64, 284, 284]               0
            Conv2d-9        [-1, 128, 282, 282]          73,856
      BatchNorm2d-10        [-1, 128, 282, 282]             256
             ReLU-11        [-1, 128, 282, 282]               0
           Conv2d-12        [-1, 128, 280, 280]         147,584
      BatchNorm2d-13        [-1, 128, 280, 280]             256
             ReLU-14        [-1, 128, 280, 280]               0
        unetConv2-15        [-1, 128, 280, 280]               0
        MaxPool2d-16        [-1, 128, 140, 140]               0
           Conv2d-17        [-1, 256, 138, 138]         295,168
      BatchNorm2d-18        [-1, 256, 138, 138]             512
             ReLU-19        [-1, 256, 138, 138]               0
           Conv2d-20        [-1, 256, 136, 136]         590,080
      BatchNorm2d-21        [-1, 256, 136, 136]             512
             ReLU-22        [-1, 256, 136, 136]               0
        unetConv2-23        [-1, 256, 136, 136]               0
        MaxPool2d-24          [-1, 256, 68, 68]               0
           Conv2d-25          [-1, 512, 66, 66]       1,180,160
      BatchNorm2d-26          [-1, 512, 66, 66]           1,024
             ReLU-27          [-1, 512, 66, 66]               0
           Conv2d-28          [-1, 512, 64, 64]       2,359,808
      BatchNorm2d-29          [-1, 512, 64, 64]           1,024
             ReLU-30          [-1, 512, 64, 64]               0
        unetConv2-31          [-1, 512, 64, 64]               0
        MaxPool2d-32          [-1, 512, 32, 32]               0
           Conv2d-33         [-1, 1024, 30, 30]       4,719,616
      BatchNorm2d-34         [-1, 1024, 30, 30]           2,048
             ReLU-35         [-1, 1024, 30, 30]               0
           Conv2d-36         [-1, 1024, 28, 28]       9,438,208
      BatchNorm2d-37         [-1, 1024, 28, 28]           2,048
             ReLU-38         [-1, 1024, 28, 28]               0
        unetConv2-39         [-1, 1024, 28, 28]               0
  ConvTranspose2d-40          [-1, 512, 56, 56]       2,097,664
           Conv2d-41          [-1, 512, 54, 54]       4,719,104
             ReLU-42          [-1, 512, 54, 54]               0
           Conv2d-43          [-1, 512, 52, 52]       2,359,808
             ReLU-44          [-1, 512, 52, 52]               0
        unetConv2-45          [-1, 512, 52, 52]               0
           unetUp-46          [-1, 512, 52, 52]               0
  ConvTranspose2d-47        [-1, 256, 104, 104]         524,544
           Conv2d-48        [-1, 256, 102, 102]       1,179,904
             ReLU-49        [-1, 256, 102, 102]               0
           Conv2d-50        [-1, 256, 100, 100]         590,080
             ReLU-51        [-1, 256, 100, 100]               0
        unetConv2-52        [-1, 256, 100, 100]               0
           unetUp-53        [-1, 256, 100, 100]               0
  ConvTranspose2d-54        [-1, 128, 200, 200]         131,200
           Conv2d-55        [-1, 128, 198, 198]         295,040
             ReLU-56        [-1, 128, 198, 198]               0
           Conv2d-57        [-1, 128, 196, 196]         147,584
             ReLU-58        [-1, 128, 196, 196]               0
        unetConv2-59        [-1, 128, 196, 196]               0
           unetUp-60        [-1, 128, 196, 196]               0
  ConvTranspose2d-61         [-1, 64, 392, 392]          32,832
           Conv2d-62         [-1, 64, 390, 390]          73,792
             ReLU-63         [-1, 64, 390, 390]               0
           Conv2d-64         [-1, 64, 388, 388]          36,928
             ReLU-65         [-1, 64, 388, 388]               0
        unetConv2-66         [-1, 64, 388, 388]               0
           unetUp-67         [-1, 64, 388, 388]               0
           Conv2d-68          [-1, 2, 388, 388]             130
================================================================
Total params: 31,039,746
Trainable params: 31,039,746
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.74
Forward/backward pass size (MB): 3136.33
Params size (MB): 118.41
Estimated Total Size (MB): 3258.48
----------------------------------------------------------------
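
The layer shapes above follow directly from the architecture: each unetConv2 applies two unpadded 3x3 convolutions (losing 4 pixels per block), each 2x2 max-pooling halves the resolution, and each decoder stage doubles it again before the next unetConv2. The small helper below (written for this note, not part of pytorch-semseg) reproduces the size arithmetic and shows why a 572x572 input yields a 388x388 prediction:

def unet_output_size(n):
    # Size arithmetic of the unet defined above (helper for this note only).
    for _ in range(4):
        n -= 4      # unetConv2: two unpadded 3x3 convs, each removes 2 pixels
        n //= 2     # 2x2 max-pooling
    n -= 4          # center unetConv2
    for _ in range(4):
        n = n * 2   # upsampling by a factor of 2
        n -= 4      # unetConv2 after crop-and-concat
    return n

print(unet_output_size(572))  # 388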