skip to content
Liu Yang's Blog

[实现]CNN Implementation in PyTorch

/ 1 min read

没什么好说的。唯一值得注意的是：Flatten 之后第一个 Linear 层的输入维度取决于输入图像的通道数和尺寸，可以由构造函数形参计算得出，而不必写死。

class CNN(nn.Module):
    """Two convolution + ReLU stages followed by a two-layer MLP classifier.

    Both convolutions use ``kernel_size=3, stride=1, padding=1``, which keeps
    the spatial size of the input unchanged. The first ``nn.Linear`` therefore
    receives ``feature_channels * height * width`` features. That size is
    computed from the constructor arguments instead of being hard-coded.

    The defaults reproduce the original model: a 1x28x28 input, 64 then 3
    conv channels, a 512-wide hidden layer and 10 output logits. The layer
    attribute names are unchanged, so existing ``state_dict`` checkpoints
    still load.

    Args:
        in_channels: Number of channels of the input image.
        height: Input image height in pixels.
        width: Input image width in pixels.
        num_classes: Number of output logits.
        hidden_channels: Output channels of ``conv1``.
        feature_channels: Output channels of ``conv2``.
        hidden_dim: Width of the hidden MLP layer.
    """

    def __init__(
        self,
        *,
        in_channels: int = 1,
        height: int = 28,
        width: int = 28,
        num_classes: int = 10,
        hidden_channels: int = 64,
        feature_channels: int = 3,
        hidden_dim: int = 512,
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=hidden_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(
            in_channels=hidden_channels,
            out_channels=feature_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.relu2 = nn.ReLU()
        # A 3x3 kernel with stride 1 and padding 1 leaves H and W unchanged,
        # so the flattened feature count follows directly from the arguments.
        flat_features = feature_channels * height * width
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Map a batch of shape (N, in_channels, height, width) to (N, num_classes) logits."""
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        return self.mlp(x)