显存一直都不知道怎么算的,不如有空总结一下大概的计算方法,才懂得训练如何设置

以下假设是float32

  • 参数

  • 梯度

    梯度与参数量其实一一对应

  • 优化器

    每个参数要保存动量,方差等等,一个占用4字节

  • Forward 前向传播需要计算存储激活值用于后续反向传播

    其中, 是输出通道数, 是高度, 是宽度