skip to content
Liu Yang's Blog

PyTorch模型参数计算及model.parameters()方法解析

/ 1 min read

Updated:
Table of Contents

代码

pytorch_total_params = sum(p.numel() for p in model.parameters())
pytorch_total_train_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
pytorch_total_params, pytorch_total_train_params

借助parameter手动实现梯度下降

前提是loss使用反向传播之后

for p in rnn.parameters():
p.data.add_(p.grad.data, alpha=-learning_rate)

解释

  1. numel(),“number of elements”,即获得张量元素数量
  2. Parameters()返回的是一个生成器,每个元素是一个Parameter
  3. Parameter是Tensor的子类,当它们被指定为模块属性时,会自动添加到模块的参数列表中,并出现在parameters()中

其独有的属性为data和requires_grad,data是对应的权重Tensor,requires_grad为是否需要保存梯度,梯度数据可从tensor.grad获取。

因为是tensor的子类,所以也可以通过parameter直接进行tensor操作,与操作parameter.data相同