skip to content
Liu Yang's Blog

[实现]U-Net Implementation in PyTorch

/ 5 min read

Updated:
Table of Contents

UNet 是用于图像分割的经典网络,采用编码器-解码器结构并通过跨越连接融合多尺度特征。

模型结构

image-20240929194644377.png

部分代码

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2

    Two 3x3 convolutions (padding=1, so spatial size is preserved), each
    followed by batch normalization and ReLU. ``bias=False`` on the convs
    because the following BatchNorm cancels any bias anyway.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        mid_channels: channels between the two convs; defaults to
            ``out_channels`` when not given (falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None) -> None:
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            # BUG FIX: was BatchNorm2d(mid_channels) — the second conv outputs
            # out_channels, so the norm must match or it crashes whenever
            # mid_channels != out_channels (e.g. the bilinear Up path).
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)
class Down(nn.Module):
    """Downscaling with maxpool then double conv.

    Halves the spatial resolution (2x2 max pooling) and then applies a
    DoubleConv to change the channel count.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Upscaling then double conv.

    Doubles the spatial size of ``x1``, pads it to match the skip tensor
    ``x2``, concatenates both along the channel axis and fuses them with a
    DoubleConv.
    """

    def __init__(self, in_channels, out_channels, bilinear=True) -> None:
        super().__init__()
        if bilinear:
            # Bilinear upsampling keeps the channel count, so the normal
            # convolutions inside DoubleConv reduce the number of channels
            # (mid_channels = in_channels // 2).
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            # Transposed conv halves the channels while upsampling.
            self.up = nn.ConvTranspose2d(
                in_channels, in_channels // 2, kernel_size=2, stride=2
            )
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Inputs are NCHW: pad x1 (centered) up to x2's spatial size so the
        # skip connection can be concatenated.
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        left, top = diff_w // 2, diff_h // 2
        x1 = F.pad(x1, [left, diff_w - left, top, diff_h - top])
        return self.conv(torch.cat([x2, x1], dim=1))
class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to per-class logits."""

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        logits = self.conv(x)
        return logits

模型代码

class UNet(nn.Module):
    """Classic U-Net: a 4-level encoder/decoder with skip connections.

    Args:
        n_channels: channels of the input image.
        n_classes: number of output classes (output channels of the logits).
        bilinear: if True, upsample bilinearly (factor = 2, so the deepest
            encoder level and each decoder stage emit half the channels);
            if False, use transposed convolutions (factor = 1).
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        factor = 2 if bilinear else 1
        # Encoder (contracting path).
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        # Decoder (expanding path).
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64 // factor, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every level's output for the skip connections.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))
        # Decoder: upsample and fuse with the matching encoder output,
        # deepest skip first.
        out = skips[4]
        for up, skip in zip(
            (self.up1, self.up2, self.up3, self.up4),
            (skips[3], skips[2], skips[1], skips[0]),
        ):
            out = up(out, skip)
        # Per-pixel class logits: (N, n_classes, H, W).
        return self.outc(out)

关键解释

F.pad

文中代码为

F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])

主要作用是对齐特征图的大小,方便跨越连接(skip connection)时做 cat。x1 即被操作的 tensor,后面的列表参数指定在各边填 0 的数量,使 x1 居中填充到与 x2 相同的大小。填充顺序从最后一个维度开始:对于二维图像,依次为左、右、上、下。

ConvTranspose2d

转置卷积常被称为"反卷积"(严格来说它并不是数学意义上卷积的逆运算),与双线性插值类似,都可以将特征图放大二倍 2024-09-30T07:24.png

输出结果

img = 3
pd = 0
k = 2
s = 2


def result_size_fn(img, k, s, pd):
    """Output spatial size of nn.ConvTranspose2d for a square input.

    Derivation (transposed conv viewed as an ordinary stride-1 conv on a
    dilated, padded input):
      * implicit zero padding of (k - 1) on each side of the input,
      * (s - 1) zeros inserted between adjacent input pixels (stride),
      * a sliding window of size k removes (k - 1),
      * the ``padding`` argument then crops ``pd`` pixels from EACH side,
        so 2 * pd must be subtracted in total.

    This simplifies to the PyTorch formula: out = (img - 1) * s - 2 * pd + k
    (with dilation=1, output_padding=0).
    """
    # (卷积核预填充) + (步长填充) - (滑动窗口) - (两侧 padding 裁剪)
    # BUG FIX: was "- pd" — ConvTranspose2d's padding crops BOTH sides,
    # so the correct term is "- 2 * pd" (only pd=0 was demoed, masking it).
    return (img + (k - 1) * 2) + (img - 1) * (s - 1) - (k - 1) - 2 * pd


Tconv = nn.ConvTranspose2d(
    in_channels=1, out_channels=1, kernel_size=k, stride=s, padding=pd
)
x = torch.rand(1, 1, img, img)
img, Tconv(x).shape[-1], result_size_fn(img, k, s, pd)

参考链接

  1. 一文搞懂反卷积和转置卷积-极市开发者社区 (cvmart.net)
  2. 卷积神经网络CNN——图像卷积与反卷积(后卷积,转置卷积) | 电子创新网 Imgtec 社区 (eetrend.com)
  3. 转置卷积(Transpose Convolution)
  4. 直接理解转置卷积(Transposed convolution)的各种情况

代码对是否使用双线性插值的不同处理

如果使用双线性插值,则factor为2,不使用则为1

当使用双线性插值时,跨越连接不会缩减输入x的维度,和跨越连接的x’,进行cat后的维度就是下一维度所需要的

当使用反卷积时,反卷积将输入 x 的通道数减半(同时把特征图的空间尺寸放大二倍),再与跨越连接的 x' 做 cat,这里应该是为了贴近论文 UNet 的实现

其他

论文:https://arxiv.org/pdf/1505.04597