
[Implementation] ResNet in PyTorch


Idea

The original approach has to learn a direct mapping f(x), which is relatively hard; learning a modification of x is easier. It is like any task: editing a template is easier than starting from scratch.

So an ordinary block only needs to learn the residual f(x) - x, which reduces the network's workload.
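
A minimal sketch of this idea (the residual function g here is hypothetical, standing in for a block's learned layers):

import torch

def residual_apply(g, x):
    # the block learns only the residual g(x) = f(x) - x;
    # adding x back recovers the full mapping f(x)
    return x + g(x)

# if the target mapping is close to the identity,
# g only has to produce values near zero
g = torch.zeros_like
x = torch.randn(4)
assert torch.equal(residual_apply(g, x), x)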


Main Code

import torch
import torch.nn as nn


class ShortcutProject(nn.Module):
    """Projects the input to the main path's output shape (channels and stride)."""

    def __init__(self, in_channels=1, out_channels=1, kernel_size=3, stride=1) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride,
            kernel_size=kernel_size,
            padding=1,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.bn(self.conv(x))
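
A quick shape check (illustrative values): with stride=2, the 3x3 projection halves the spatial dimensions, exactly matching what the main path's strided convolution produces.

x = torch.randn(1, 1, 28, 28)
proj = ShortcutProject(in_channels=1, out_channels=128, stride=2)
print(proj(x).shape)  # torch.Size([1, 128, 14, 14])
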
class ResNetBlock(nn.Module):
    """Two 3x3 convs with BN; the input is added back before the final ReLU."""

    def __init__(self, in_channels=1, out_channels=1, stride=1, kernel_size=3) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride,
            kernel_size=kernel_size,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            stride=1,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        if stride != 1 or in_channels != out_channels:
            # project the input x so its shape matches the main path's output
            self.shortcut = ShortcutProject(in_channels, out_channels, stride=stride)
        else:
            self.shortcut = nn.Identity()
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        shortcut = self.shortcut(x)
        x1 = self.relu1(self.bn1(self.conv1(x)))
        x1 = self.bn2(self.conv2(x1))
        return self.relu2(x1 + shortcut)
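
To confirm both shortcut paths line up with the main path, a small shape test (values are illustrative):

# identity shortcut: same channels, stride 1
block = ResNetBlock(in_channels=128, out_channels=128)
print(block(torch.randn(2, 128, 28, 28)).shape)  # torch.Size([2, 128, 28, 28])

# projection shortcut: channel change and stride trigger ShortcutProject
block = ResNetBlock(in_channels=1, out_channels=128, stride=2)
print(block(torch.randn(2, 1, 28, 28)).shape)  # torch.Size([2, 128, 14, 14])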

Defining the Model

class ResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.res1 = ResNetBlock(in_channels=1, out_channels=128, kernel_size=3)
        self.res2 = ResNetBlock(in_channels=128, out_channels=128, kernel_size=3)
        self.res3 = ResNetBlock(in_channels=128, out_channels=128, kernel_size=3)
        # replace the MLP with a 1x1 convolution to control the channel count
        self.con1x1 = nn.Conv2d(in_channels=128, out_channels=10, kernel_size=1)
        self.linear = nn.Linear(28 * 28 * 10, 10)

    def forward(self, x):
        x = self.res1(x)
        x = self.res2(x)
        x = self.res3(x)
        x = self.con1x1(x)
        x = x.view(x.shape[0], -1)
        return self.linear(x)
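
A sanity check on an MNIST-shaped batch (the 28 * 28 * 10 linear layer implies 1x28x28 inputs):

model = ResNet()
logits = model(torch.randn(8, 1, 28, 28))
print(logits.shape)  # torch.Size([8, 10]) -- one score per class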

Notes

  1. Because every convolution uses padding, the feature maps keep their spatial size and no manual expansion is needed; the 1x1 convolution controls the channel count, and the final MLP layer handles classification.
  2. The order is Conv -> BN -> ReLU. ShortcutProject must not contain a ReLU, only the first two; the ReLU is applied only after the shortcut has been added back.
  3. The inplace parameter of ReLU means the operation runs in place, saving time and memory, much like tensor.zero_(), which zeroes a tensor in place (see the sketch after this list).
  4. When writing forward, it is enough to think of x as a single sample; the batch dimension takes care of itself.
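
A small illustration of the in-place behavior from point 3:

x = torch.tensor([-1.0, 2.0])
y = nn.ReLU(inplace=True)(x)
print(x)       # tensor([0., 2.]) -- x itself was overwritten
print(y is x)  # True: no new tensor was allocated

t = torch.ones(3)
t.zero_()      # likewise zeroes t in place
print(t)       # tensor([0., 0., 0.])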