# [Implementation] ResNet in PyTorch
## Idea
The original approach has to learn a direct mapping H(x), which is hard to learn directly; learning the modification to x instead is easier. It is like any task: editing an existing template is easier than starting from scratch. So a plain block only needs to learn the residual F(x) = H(x) - x, which reduces the network's workload.
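In symbols (the standard residual formulation; the exact formulas were lost from the original text), the block's output is the learned residual plus the identity shortcut:

$$F(x) = H(x) - x, \qquad y = F(x) + x$$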
## Main Code
```python
import torch.nn as nn


class ShortcutProject(nn.Module):
    """Projects the shortcut input when the main branch changes shape."""

    def __init__(self, in_channels=1, out_channels=1, kernel_size=3, stride=1) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride,
            kernel_size=kernel_size,
            padding=1,
            bias=False,  # the following BatchNorm makes a conv bias redundant
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.bn(self.conv(x))
```
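A quick shape check (my own sketch, not part of the original post): with stride 2 and a wider output, the projection produces exactly the shape the main branch will have.

```python
import torch

proj = ShortcutProject(in_channels=64, out_channels=128, kernel_size=3, stride=2)
x = torch.randn(4, 64, 28, 28)
print(proj(x).shape)  # torch.Size([4, 128, 14, 14])
```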
```python
class ResNetBlock(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, stride=1, kernel_size=3) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride,
            kernel_size=kernel_size,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            stride=1,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        if stride != 1 or in_channels != out_channels:
            # project the input x so it matches the residual branch's shape
            self.shortcut = ShortcutProject(in_channels, out_channels, stride=stride)
        else:
            self.shortcut = nn.Identity()
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        shortcut = self.shortcut(x)
        x1 = self.relu1(self.bn1(self.conv1(x)))
        x1 = self.bn2(self.conv2(x1))
        return self.relu2(x1 + shortcut)
```
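Continuing the sketch above, the block picks the projection shortcut only when the shape actually changes, and the cheap identity otherwise:

```python
same = ResNetBlock(in_channels=128, out_channels=128, stride=1)
down = ResNetBlock(in_channels=64, out_channels=128, stride=2)
print(type(same.shortcut).__name__)  # Identity
print(type(down.shortcut).__name__)  # ShortcutProject
```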
## Model Definition

```python
class ResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.res1 = ResNetBlock(in_channels=1, out_channels=128, kernel_size=3)
        self.res2 = ResNetBlock(in_channels=128, out_channels=128, kernel_size=3)
        self.res3 = ResNetBlock(in_channels=128, out_channels=128, kernel_size=3)
        # a 1x1 convolution stands in for an MLP to reduce the channel count
        self.con1x1 = nn.Conv2d(in_channels=128, out_channels=10, kernel_size=1)
        self.linear = nn.Linear(28 * 28 * 10, 10)

    def forward(self, x):
        x = self.res1(x)
        x = self.res2(x)
        x = self.res3(x)
        x = self.con1x1(x)
        x = x.view(x.shape[0], -1)
        return self.linear(x)
```
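A forward-pass sanity check (my own addition): the hard-coded `nn.Linear(28 * 28 * 10, 10)` implies 28x28 single-channel inputs, i.e. MNIST-sized images.

```python
model = ResNet()
x = torch.randn(8, 1, 28, 28)  # a batch of 8 MNIST-sized images
logits = model(x)
print(logits.shape)  # torch.Size([8, 10])
```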
## Explanation

- Because every 3x3 convolution uses `padding=1`, the feature maps keep their spatial size and no extra resizing is needed; the 1x1 convolution controls the channel count, and the final linear layer does the classification.
- The order is Conv -> BN -> ReLU. `ShortcutProject` must not include a ReLU, only the first two; the ReLU is applied after the shortcut has been added to the residual branch.
- `inplace=True` in ReLU means the operation modifies its input in place, saving time and memory, much like `tensor.zero_()`, which zeroes a tensor in place (see the snippet after this list).
- When writing `forward`, it is enough to reason about a single sample `x`; there is no need to handle the whole batch, since the modules operate over the batch dimension automatically.
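A small demonstration of the in-place behavior (my own sketch):

```python
import torch

relu = torch.nn.ReLU(inplace=True)
t = torch.tensor([-1.0, 2.0])
relu(t)      # no reassignment needed
print(t)     # tensor([0., 2.]) -- t itself was overwritten
```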