没什么特别需要说明的，只有一点：Flatten 之后第一个 Linear 层的输入维度（这里硬编码为 3 * 28 * 28）可以改为由构造函数的形参（输入通道数、图像高和宽）计算得出，这样模型就能适配不同尺寸的输入。
class CNN(nn.Module):
    """Two 3x3 convolution + ReLU stages followed by a two-layer MLP head.

    Both convolutions use ``kernel_size=3, stride=1, padding=1``, so the
    spatial size of the feature map equals the input's ``(height, width)``.
    The first ``nn.Linear`` therefore receives
    ``conv2.out_channels * height * width`` features. That size is computed
    from the constructor arguments instead of being hard-coded. The defaults
    reproduce the original ``3 * 28 * 28 -> 512 -> 10`` architecture exactly,
    with the same submodule names and therefore the same ``state_dict`` keys.

    Args:
        in_channels: Number of channels in the input images (default 1).
        height: Input image height in pixels (default 28).
        width: Input image width in pixels (default 28).
        hidden_dim: Width of the hidden MLP layer (default 512).
        num_classes: Number of output features of the final layer (default 10).
    """

    def __init__(
        self,
        *,
        in_channels: int = 1,
        height: int = 28,
        width: int = 28,
        hidden_dim: int = 512,
        num_classes: int = 10,
    ) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=64,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(
            in_channels=64,
            out_channels=3,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.relu2 = nn.ReLU()

        # kernel 3 / stride 1 / padding 1 keeps H and W unchanged, so the
        # flattened size is channels * height * width of the *input* image.
        flat_features = self.conv2.out_channels * height * width
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Apply conv1/ReLU, conv2/ReLU, then flatten and run the MLP head.

        ``x`` is expected to be 4-D, ``(N, in_channels, height, width)``,
        matching the sizes given at construction. ``nn.Flatten`` keeps dim 0
        as the batch, and the first Linear's ``in_features`` is fixed.
        Returns a tensor of shape ``(N, num_classes)``.
        """
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        return self.mlp(x)