[实现]U-Net Implementation in PyTorch
/ 5 min read
Updated:Table of Contents
UNet是用于图像分割的经典网络,本文基于PyTorch给出其实现,并解释其中的关键细节。
模型结构
部分代码
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2.

    Two 3x3 same-padding convolutions, each followed by batch norm and ReLU.
    If ``mid_channels`` is not given it defaults to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None) -> None:
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            # bias=False: the following BatchNorm cancels any additive bias.
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            # Fixed: was BatchNorm2d(mid_channels), which crashes whenever
            # mid_channels != out_channels (the second conv emits out_channels).
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)
class Down(nn.Module):
    """Encoder step: halve the spatial size with 2x2 max-pool, then DoubleConv."""

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Attribute name kept identical so state-dict keys stay compatible.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Decoder step: upsample x1, pad it to match the skip tensor x2,
    concatenate along channels, then apply DoubleConv."""

    def __init__(self, in_channels, out_channels, bilinear=True) -> None:
        super().__init__()
        if bilinear:
            # Bilinear upsampling keeps the channel count, so the normal
            # convolutions reduce channels (DoubleConv mid = in_channels // 2).
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            # Transposed convolution halves the channel count while upsampling.
            self.up = nn.ConvTranspose2d(
                in_channels, in_channels // 2, kernel_size=2, stride=2
            )
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Inputs are NCHW; zero-pad x1 (centered) so its H and W match x2,
        # which makes the skip-connection concatenation valid.
        diff_h = x2.size()[2] - x1.size()[2]
        diff_w = x2.size()[3] - x1.size()[3]
        # F.pad order starts from the last dim: left, right, top, bottom.
        x1 = F.pad(
            x1,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )
        return self.conv(torch.cat([x2, x1], dim=1))
class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to per-class logits."""

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
# Section heading in the original article: 模型代码 ("model code")
class UNet(nn.Module):
    """Classic U-Net: 4 downscaling steps, 4 upscaling steps with skips.

    Args:
        n_channels: number of input image channels.
        n_classes: number of output classes (channels of the logit map).
        bilinear: if True, upsample with bilinear interpolation instead of
            transposed convolutions.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        # With bilinear upsampling the Up blocks do not halve channels during
        # upsampling, so the encoder bottom emits half the channels (factor=2).
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        # Fixed: was Up(128, 64 // factor, bilinear); with bilinear=True that
        # yields 32 channels while OutConv expects 64, crashing in self.outc.
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder. Channel counts below are for bilinear=False (factor = 1).
        x1 = self.inc(x)     # 64
        x2 = self.down1(x1)  # 128
        x3 = self.down2(x2)  # 256
        x4 = self.down3(x3)  # 512
        x5 = self.down4(x4)  # 1024 // factor
        # Decoder with skip connections to the matching encoder stage.
        x = self.up1(x5, x4)  # 512 // factor
        x = self.up2(x, x3)   # 256 // factor
        x = self.up3(x, x2)   # 128 // factor
        x = self.up4(x, x1)   # 64
        logits = self.outc(x)  # (N, n_classes, H, W)
        return logits
# Section heading in the original article: 关键解释 ("key explanations")
F.pad
文中代码为
F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
主要作用是对齐特征图的大小,方便跨越连接时做cat,x1即操作的tensor,后面的参数是填0的位置,主要是将特征图填充到中间部分,顺序是从最后一个维度开始的,对于二维图像,顺序就是左,右,上,下。
ConvTranspose2d
转置卷积常被(不严格地)称为反卷积,但它并不是卷积在数学意义上的逆运算;它与双线性插值类似,都可用于上采样,例如在 stride=2 时可将特征图放大约两倍

输出结果
# Demo: compare ConvTranspose2d's actual output size with a closed-form formula.
img = 3
pd = 0
k = 2
s = 2


def result_size_fn(img, k, s, pd):
    """Output spatial size of ConvTranspose2d(kernel_size=k, stride=s, padding=pd).

    Derivation: pre-pad the input by k-1 on each side, stretch it by
    inserting s-1 zeros between pixels, slide the window, then crop padding:
    (kernel pre-pad) + (stride spacing) - (sliding window) - (cropped padding).

    Fixed: padding is cropped from BOTH sides, so the last term is 2 * pd
    (the original subtracted pd once, which is only correct for pd = 0).
    Algebraically equal to PyTorch's documented (img - 1) * s - 2 * pd + k
    (dilation = 1, output_padding = 0).
    """
    return (img + (k - 1) * 2) + (img - 1) * (s - 1) - (k - 1) - 2 * pd


Tconv = nn.ConvTranspose2d(
    in_channels=1, out_channels=1, kernel_size=k, stride=s, padding=pd
)
x = torch.rand(1, 1, img, img)
img, Tconv(x).shape[-1], result_size_fn(img, k, s, pd)
# Section heading in the original article: 参考链接 ("references")
- 一文搞懂反卷积和转置卷积-极市开发者社区 (cvmart.net)
- 卷积神经网络CNN——图像卷积与反卷积(后卷积,转置卷积) | 电子创新网 Imgtec 社区 (eetrend.com)
- 转置卷积(Transpose Convolution)
- 直接理解转置卷积(Transposed convolution)的各种情况
代码对是否使用双线性插值的不同处理
如果使用双线性插值,则factor为2,不使用则为1
当使用双线性插值时,跨越连接不会缩减输入x的维度,和跨越连接的x’,进行cat后的维度就是下一维度所需要的
当使用反卷积时,反卷积缩减输入x一半的特征图,再与跨越连接x’做cat,这里应该是为了贴近论文UNet的实现